1 /*
2 * This file is part of the openHiTLS project.
3 *
4 * openHiTLS is licensed under the Mulan PSL v2.
5 * You can use this software according to the terms and conditions of the Mulan PSL v2.
6 * You may obtain a copy of Mulan PSL v2 at:
7 *
8 * http://license.coscl.org.cn/MulanPSL2
9 *
10 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
11 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
12 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
13 * See the Mulan PSL v2 for more details.
14 */
15
16 #include "hitls_build.h"
17 #ifdef HITLS_CRYPTO_MLDSA
18 #include "securec.h"
19 #include "bsl_errno.h"
20 #include "bsl_sal.h"
21 #include "crypt_utils.h"
22 #include "crypt_sha3.h"
23 #include "crypt_errno.h"
24 #include "crypt_util_rand.h"
25 #include "bsl_err_internal.h"
26 #include "ml_dsa_local.h"
27 #include "eal_md_local.h"
28
29 #define BITS_OF_BYTE 8
/*
 * Point `ptr` at the current position of `buf`, then advance `buf` by one
 * polynomial (MLDSA_N elements).
 *
 * Wrapped in do { } while (0) so that the macro followed by ';' is exactly one
 * statement (safe inside an unbraced if/else), and the arguments are
 * parenthesized. `buf` is evaluated twice, so pass a plain lvalue.
 */
#define MLDSA_SET_VECTOR_MEM(ptr, buf) \
    do {                               \
        (ptr) = (buf);                 \
        (buf) += MLDSA_N;              \
    } while (0)
31
HashFuncH(const uint8_t * inPutA,uint32_t lenA,const uint8_t * inPutB,uint32_t lenB,uint8_t * out,uint32_t outLen)32 static int32_t HashFuncH(const uint8_t *inPutA, uint32_t lenA, const uint8_t *inPutB, uint32_t lenB,
33 uint8_t *out, uint32_t outLen)
34 {
35 uint32_t len = outLen;
36 int32_t ret = 0;
37 const EAL_MdMethod *hashMethod = EAL_MdFindMethod(CRYPT_MD_SHAKE256);
38 if (hashMethod == NULL) {
39 BSL_ERR_PUSH_ERROR(CRYPT_EAL_ALG_NOT_SUPPORT);
40 return CRYPT_EAL_ALG_NOT_SUPPORT;
41 }
42 void *mdCtx = hashMethod->newCtx();
43 if (mdCtx == NULL) {
44 BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
45 return CRYPT_MEM_ALLOC_FAIL;
46 }
47
48 GOTO_ERR_IF(hashMethod->init(mdCtx, NULL), ret);
49 GOTO_ERR_IF(hashMethod->update(mdCtx, inPutA, lenA), ret);
50 if (inPutB != NULL) {
51 GOTO_ERR_IF(hashMethod->update(mdCtx, inPutB, lenB), ret);
52 }
53 GOTO_ERR_IF(hashMethod->final(mdCtx, out, &len), ret);
54 ERR:
55 hashMethod->freeCtx(mdCtx);
56 return ret;
57 }
58
/*
 * Working storage for key generation. MLDSAKeyGenCreateMatrix() makes one
 * allocation and points every entry below at an MLDSA_N-element int32_t slice of it.
 */
typedef struct {
    int32_t *bufAddr;                          // Start of the allocation; kept so the memory can be freed.
    uint32_t bufSize;                          // Size in bytes requested for the allocation.
    int32_t *matrix[MLDSA_K_MAX][MLDSA_L_MAX]; // k x l polynomials, placed first in the buffer.
    int32_t *s2[MLDSA_K_MAX];                  // k polynomials.
    int32_t *t0[MLDSA_K_MAX];                  // k polynomials.
    int32_t *t1[MLDSA_K_MAX];                  // k polynomials.
    int32_t *s1[MLDSA_L_MAX];                  // l polynomials.
    int32_t *s1Ntt[MLDSA_L_MAX];               // l polynomials.
} MLDSA_KeyGenMatrixSt;
69
MLDSASetMatrixMem(uint8_t k,uint8_t l,int32_t * matrix[MLDSA_K_MAX][MLDSA_L_MAX],int32_t * buf)70 static void MLDSASetMatrixMem(uint8_t k, uint8_t l, int32_t *matrix[MLDSA_K_MAX][MLDSA_L_MAX], int32_t *buf)
71 {
72 for (uint8_t i = 0; i < k; i++) {
73 for (uint8_t j = 0; j < l; j++) {
74 matrix[i][j] = buf;
75 buf += MLDSA_N;
76 }
77 }
78 }
79
/*
 * Allocate the key-generation scratch buffer and wire up the views in `st`.
 *
 * Layout, in units of one polynomial (MLDSA_N int32_t):
 *   k * l matrix entries, then for each i < k: t0[i], t1[i], s2[i],
 *   then for each i < l: s1[i], s1Ntt[i].
 *
 * st->bufSize is written before the allocation is attempted. On allocation
 * failure BSL_MALLOC_FAIL is returned and the other fields are left untouched.
 * Returns CRYPT_SUCCESS otherwise.
 */
static int32_t MLDSAKeyGenCreateMatrix(uint8_t k, uint8_t l, MLDSA_KeyGenMatrixSt *st)
{
    // One k x l matrix, three vectors of k polynomials, two vectors of l polynomials.
    st->bufSize = (k * l + 3 * k + 2 * l) * MLDSA_N * sizeof(int32_t);
    int32_t *base = BSL_SAL_Malloc(st->bufSize);
    if (base == NULL) {
        return BSL_MALLOC_FAIL;
    }
    st->bufAddr = base; // Remembered so the buffer can be released later.
    MLDSASetMatrixMem(k, l, st->matrix, base);

    int32_t *cursor = base + k * l * MLDSA_N;
    for (uint8_t i = 0; i < k; i++, cursor += 3 * MLDSA_N) {
        st->t0[i] = cursor;
        st->t1[i] = cursor + MLDSA_N;
        st->s2[i] = cursor + 2 * MLDSA_N;
    }
    for (uint8_t i = 0; i < l; i++, cursor += 2 * MLDSA_N) {
        st->s1[i] = cursor;
        st->s1Ntt[i] = cursor + MLDSA_N;
    }
    return CRYPT_SUCCESS;
}
102
/*
 * Working storage for signing. MLDSASignCreateMatrix() makes one allocation
 * and points every entry below at an MLDSA_N-element int32_t slice of it.
 */
typedef struct {
    int32_t *bufAddr;                          // Start of the allocation; kept so the memory can be freed.
    uint32_t bufSize;                          // Size in bytes requested for the allocation.
    int32_t *matrix[MLDSA_K_MAX][MLDSA_L_MAX]; // k x l polynomials, placed first in the buffer.
    int32_t *t0[MLDSA_K_MAX];                  // k polynomials.
    int32_t *r0[MLDSA_K_MAX];                  // k polynomials.
    int32_t *s2[MLDSA_K_MAX];                  // k polynomials.
    int32_t *cs2[MLDSA_K_MAX];                 // k polynomials.
    int32_t *ct0[MLDSA_K_MAX];                 // k polynomials.
    int32_t *h[MLDSA_K_MAX];                   // k polynomials.
    int32_t *w[MLDSA_K_MAX];                   // k polynomials.
    int32_t *w1[MLDSA_K_MAX];                  // k polynomials.
    int32_t *s1[MLDSA_L_MAX];                  // l polynomials.
    int32_t *y[MLDSA_L_MAX];                   // l polynomials.
    int32_t *z[MLDSA_L_MAX];                   // l polynomials.
} MLDSA_SignMatrixSt;
119
/*
 * Allocate the signing scratch buffer and wire up the views in `st`.
 *
 * Layout, in units of one polynomial (MLDSA_N int32_t):
 *   k * l matrix entries, then for each i < k: r0, t0, s2, cs2, ct0, h, w, w1,
 *   then for each i < l: s1, y, z.
 *
 * st->bufSize is written before the allocation is attempted. On allocation
 * failure BSL_MALLOC_FAIL is returned and the other fields are left untouched.
 * Returns CRYPT_SUCCESS otherwise.
 */
static int32_t MLDSASignCreateMatrix(uint8_t k, uint8_t l, MLDSA_SignMatrixSt *st)
{
    // One k x l matrix, eight vectors of k polynomials, three vectors of l polynomials.
    st->bufSize = (k * l + 8 * k + 3 * l) * MLDSA_N * sizeof(int32_t);
    int32_t *base = BSL_SAL_Malloc(st->bufSize);
    if (base == NULL) {
        return BSL_MALLOC_FAIL;
    }
    st->bufAddr = base; // Remembered so the buffer can be released later.
    MLDSASetMatrixMem(k, l, st->matrix, base);

    int32_t *cursor = base + k * l * MLDSA_N;
    for (uint8_t i = 0; i < k; i++, cursor += 8 * MLDSA_N) {
        st->r0[i] = cursor;
        st->t0[i] = cursor + MLDSA_N;
        st->s2[i] = cursor + 2 * MLDSA_N;
        st->cs2[i] = cursor + 3 * MLDSA_N;
        st->ct0[i] = cursor + 4 * MLDSA_N;
        st->h[i] = cursor + 5 * MLDSA_N;
        st->w[i] = cursor + 6 * MLDSA_N;
        st->w1[i] = cursor + 7 * MLDSA_N;
    }
    for (uint8_t i = 0; i < l; i++, cursor += 3 * MLDSA_N) {
        st->s1[i] = cursor;
        st->y[i] = cursor + MLDSA_N;
        st->z[i] = cursor + 2 * MLDSA_N;
    }
    return CRYPT_SUCCESS;
}
148
/*
 * Working storage for verification. MLDSAVerifyCreateMatrix() makes one
 * allocation and points every entry below at an MLDSA_N-element int32_t slice of it.
 */
typedef struct {
    int32_t *bufAddr;                          // Start of the allocation; kept so the memory can be freed.
    uint32_t bufSize;                          // Size in bytes requested for the allocation.
    int32_t *matrix[MLDSA_K_MAX][MLDSA_L_MAX]; // k x l polynomials, placed first in the buffer.
    int32_t *t1[MLDSA_K_MAX];                  // k polynomials.
    int32_t *h[MLDSA_K_MAX];                   // k polynomials.
    int32_t *w[MLDSA_K_MAX];                   // k polynomials.
    int32_t *z[MLDSA_L_MAX];                   // l polynomials.
} MLDSA_VerifyMatrixSt;
158
/*
 * Allocate the verification scratch buffer and wire up the views in `st`.
 *
 * Layout, in units of one polynomial (MLDSA_N int32_t):
 *   k * l matrix entries, then for each i < k: t1[i], h[i], w[i],
 *   then for each i < l: z[i].
 *
 * st->bufSize is written before the allocation is attempted. On allocation
 * failure BSL_MALLOC_FAIL is returned and the other fields are left untouched.
 * Returns CRYPT_SUCCESS otherwise.
 */
static int32_t MLDSAVerifyCreateMatrix(uint8_t k, uint8_t l, MLDSA_VerifyMatrixSt *st)
{
    // One k x l matrix, three vectors of k polynomials, one vector of l polynomials.
    st->bufSize = (k * l + 3 * k + l) * MLDSA_N * sizeof(int32_t);
    int32_t *base = BSL_SAL_Malloc(st->bufSize);
    if (base == NULL) {
        return BSL_MALLOC_FAIL;
    }
    st->bufAddr = base; // Remembered so the buffer can be released later.
    MLDSASetMatrixMem(k, l, st->matrix, base);

    int32_t *cursor = base + k * l * MLDSA_N;
    for (uint8_t i = 0; i < k; i++, cursor += 3 * MLDSA_N) {
        st->t1[i] = cursor;
        st->h[i] = cursor + MLDSA_N;
        st->w[i] = cursor + 2 * MLDSA_N;
    }
    for (uint8_t i = 0; i < l; i++, cursor += MLDSA_N) {
        st->z[i] = cursor;
    }
    return CRYPT_SUCCESS;
}
181
/*
 * NIST.FIPS.204 Algorithm 14 CoeffFromThreeBytes(b0, b1, b2).
 *
 * Clears the top bit of b2 and assembles the 23-bit little-endian value
 * 2^16 * b2' + 2^8 * b1 + b0. The comparison against q is left to the caller.
 */
static int32_t CoeffFromThreeBytes(uint8_t b0, uint8_t b1, uint8_t b2)
{
    const int32_t high = (int32_t)(b2 & 0x7f); // b2' = b2 with bit 7 dropped.
    const int32_t mid = (int32_t)b1;
    const int32_t low = (int32_t)b0;
    return (high << 16) | (mid << 8) | low;
}
192
193 // NIST.FIPS.204 Algorithm 30 RejNTTPoly(ρ)
RejNTTPoly(int32_t a[MLDSA_N],uint8_t seed[MLDSA_SEED_EXTEND_BYTES_LEN])194 static int32_t RejNTTPoly(int32_t a[MLDSA_N], uint8_t seed[MLDSA_SEED_EXTEND_BYTES_LEN])
195 {
196 int32_t ret;
197 unsigned int buflen = CRYPT_SHAKE128_BLOCKSIZE;
198 uint8_t buf[CRYPT_SHAKE128_BLOCKSIZE];
199
200 const EAL_MdMethod *hashMethod = EAL_MdFindMethod(CRYPT_MD_SHAKE128);
201 if (hashMethod == NULL) {
202 BSL_ERR_PUSH_ERROR(CRYPT_EAL_ALG_NOT_SUPPORT);
203 return CRYPT_EAL_ALG_NOT_SUPPORT;
204 }
205 void *mdCtx = hashMethod->newCtx();
206 if (mdCtx == NULL) {
207 BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
208 return CRYPT_MEM_ALLOC_FAIL;
209 }
210 GOTO_ERR_IF(hashMethod->init(mdCtx, NULL), ret);
211 GOTO_ERR_IF(hashMethod->update(mdCtx, seed, MLDSA_SEED_EXTEND_BYTES_LEN), ret);
212 GOTO_ERR_IF(hashMethod->squeeze(mdCtx, buf, buflen), ret);
213 uint32_t j = 0;
214 for (uint32_t i = 0; i < MLDSA_N;) {
215 a[i] = CoeffFromThreeBytes(buf[j], buf[j + 1], buf[j + 2]); // Data from 3 uint8_t to int32_t.
216 j += 3;
217 if (a[i] < MLDSA_Q) { // a[i] is less than MLDSA_Q is an invalid value.
218 i++;
219 }
220 if (j >= CRYPT_SHAKE128_BLOCKSIZE) {
221 GOTO_ERR_IF(hashMethod->squeeze(mdCtx, buf, buflen), ret);
222 j = 0;
223 }
224 }
225 ERR:
226 hashMethod->freeCtx(mdCtx);
227 return ret;
228 }
229
230 // NIST.FIPS.204 Algorithm 32 ExpandA(ρ)
ExpandA(const CRYPT_ML_DSA_Ctx * ctx,const uint8_t * pubSeed,int32_t * matrix[MLDSA_K_MAX][MLDSA_L_MAX])231 static int32_t ExpandA(const CRYPT_ML_DSA_Ctx *ctx, const uint8_t *pubSeed, int32_t *matrix[MLDSA_K_MAX][MLDSA_L_MAX])
232 {
233 uint8_t k = ctx->info->k;
234 uint8_t l = ctx->info->l;
235 uint8_t seed[MLDSA_SEED_EXTEND_BYTES_LEN];
236 (void)memcpy_s(seed, sizeof(seed), pubSeed, MLDSA_PUBLIC_SEED_LEN);
237 for (uint8_t i = 0; i < k; i++) {
238 for (uint8_t j = 0; j < l; j++) {
239 seed[MLDSA_PUBLIC_SEED_LEN] = j;
240 seed[MLDSA_PUBLIC_SEED_LEN + 1] = i;
241 int32_t ret = RejNTTPoly(matrix[i][j], seed);
242 RETURN_RET_IF(ret != CRYPT_SUCCESS, ret);
243 }
244 }
245 return CRYPT_SUCCESS;
246 }
247
248 // NIST.FIPS.204 Algorithm 31 RejBoundedPoly(ρ)
RejBoundedPoly(const CRYPT_ML_DSA_Ctx * ctx,int32_t * a,uint8_t * s)249 static int32_t RejBoundedPoly(const CRYPT_ML_DSA_Ctx *ctx, int32_t *a, uint8_t *s)
250 {
251 uint8_t buf[CRYPT_SHAKE256_BLOCKSIZE];
252 uint32_t bufLen = CRYPT_SHAKE256_BLOCKSIZE;
253 int32_t ret = CRYPT_SUCCESS;
254 const EAL_MdMethod *hashMethod = EAL_MdFindMethod(CRYPT_MD_SHAKE256);
255 if (hashMethod == NULL) {
256 BSL_ERR_PUSH_ERROR(CRYPT_EAL_ALG_NOT_SUPPORT);
257 return CRYPT_EAL_ALG_NOT_SUPPORT;
258 }
259 void *mdCtx = hashMethod->newCtx();
260 if (mdCtx == NULL) {
261 BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
262 return CRYPT_MEM_ALLOC_FAIL;
263 }
264 GOTO_ERR_IF(hashMethod->init(mdCtx, NULL), ret);
265 GOTO_ERR_IF(hashMethod->update(mdCtx, s, MLDSA_PRIVATE_SEED_LEN + 2), ret); // k and l used 2 bytes.
266 GOTO_ERR_IF(hashMethod->squeeze(mdCtx, buf, bufLen), ret);
267 for (uint32_t i = 0, j = 0; i < MLDSA_N; j++) {
268 if (j == CRYPT_SHAKE256_BLOCKSIZE) {
269 GOTO_ERR_IF(hashMethod->squeeze(mdCtx, buf, CRYPT_SHAKE256_BLOCKSIZE), ret);
270 j = 0;
271 }
272 int32_t z0 = (int32_t)(buf[j] & 0x0F);
273 int32_t z1 = (int32_t)(buf[j] >> 4u);
274 // Algorithm 15 CoeffFromHalfByte(b)
275 // if = 2 and b < 15 then return 2 − (b mod 5)
276 if (ctx->info->eta == 2) {
277 if (z0 < 0x0F) {
278 // This is Barrett Modular Multiplication, 205 == 2^10 / 5
279 z0 = z0 - ((205 * z0) >> 10) * 5; // 2 − (b mod 5)
280 a[i] = 2 - z0;
281 i++;
282 }
283 if (z1 < 0x0F && i < MLDSA_N) {
284 // Barrett Modular Multiplication, 205 == 2^10 / 5
285 z1 = z1 - ((205 * z1) >> 10) * 5;
286 a[i] = 2 - z1; // 2 − (b mod 5)
287 i++;
288 }
289 } else {
290 if (z0 < 9) { // if = 4 and b < 9 then a[i] = 4 − b
291 a[i] = 4 - z0;
292 i++;
293 }
294 if (z1 < 9 && i < MLDSA_N) { // if = 4 and b < 9 then a[i + 1] = 4 − b
295 a[i] = 4 - z1;
296 i++;
297 }
298 }
299 }
300 ERR:
301 hashMethod->freeCtx(mdCtx);
302 return ret;
303 }
304
305 // Algorithm 33 ExpandS(ρ)
ExpandS(const CRYPT_ML_DSA_Ctx * ctx,const uint8_t * prvSeed,int32_t * s1[MLDSA_L_MAX],int32_t * s2[MLDSA_K_MAX])306 static int32_t ExpandS(const CRYPT_ML_DSA_Ctx *ctx, const uint8_t *prvSeed,
307 int32_t *s1[MLDSA_L_MAX], int32_t *s2[MLDSA_K_MAX])
308 {
309 int32_t ret;
310 uint8_t k = ctx->info->k;
311 uint8_t l = ctx->info->l;
312 uint8_t seed[MLDSA_PRIVATE_SEED_LEN + 2]; // 2 bytes are reserved.
313 (void)memcpy_s(seed, sizeof(seed), prvSeed, MLDSA_PRIVATE_SEED_LEN);
314 seed[MLDSA_PRIVATE_SEED_LEN + 1] = 0;
315 for (uint8_t i = 0; i < l; i++) {
316 seed[MLDSA_PRIVATE_SEED_LEN] = i;
317 ret = RejBoundedPoly(ctx, s1[i], seed);
318 RETURN_RET_IF(ret != CRYPT_SUCCESS, ret);
319 }
320 for (uint8_t i = 0; i < k; i++) {
321 seed[MLDSA_PRIVATE_SEED_LEN] = l + i;
322 ret = RejBoundedPoly(ctx, s2[i], seed);
323 RETURN_RET_IF(ret != CRYPT_SUCCESS, ret);
324 }
325 return CRYPT_SUCCESS;
326 }
327
ComputesNTT(const CRYPT_ML_DSA_Ctx * ctx,int32_t * s[MLDSA_L_MAX],int32_t * sOut[MLDSA_L_MAX])328 static void ComputesNTT(const CRYPT_ML_DSA_Ctx *ctx, int32_t *s[MLDSA_L_MAX], int32_t *sOut[MLDSA_L_MAX])
329 {
330 for (uint8_t i = 0; i < ctx->info->l; i++) {
331 (void)memcpy_s(sOut[i], sizeof(int32_t) * MLDSA_N, s[i], sizeof(int32_t) * MLDSA_N);
332 MLDSA_ComputesNTT(sOut[i]);
333 }
334 return;
335 }
336
VectorsMul(int32_t * t,int32_t * matrix,int32_t * s)337 static void VectorsMul(int32_t *t, int32_t *matrix, int32_t *s)
338 {
339 for (uint32_t i = 0; i < MLDSA_N; i++) {
340 t[i] = MLDSA_MontgomeryReduce((int64_t)matrix[i] * s[i]);
341 }
342 }
343
MatrixMul(const CRYPT_ML_DSA_Ctx * ctx,int32_t * t,int32_t * matrix[MLDSA_L_MAX],int32_t * s[MLDSA_L_MAX])344 static void MatrixMul(const CRYPT_ML_DSA_Ctx *ctx, int32_t *t, int32_t *matrix[MLDSA_L_MAX], int32_t *s[MLDSA_L_MAX])
345 {
346 int32_t tmp[MLDSA_N] = { 0 };
347 VectorsMul(t, matrix[0], s[0]);
348 for (uint32_t i = 1; i < ctx->info->l; i++) {
349 VectorsMul(tmp, matrix[i], s[i]);
350 for (uint32_t j = 0; j < MLDSA_N; j++) {
351 t[j] = t[j] + tmp[j];
352 }
353 }
354 for (uint32_t j = 0; j < MLDSA_N; j++) {
355 MLDSA_MOD_Q(t[j]);
356 }
357 }
358
ComputesT(const CRYPT_ML_DSA_Ctx * ctx,int32_t * t[MLDSA_K_MAX],int32_t * matrix[MLDSA_K_MAX][MLDSA_L_MAX],int32_t * s1[MLDSA_L_MAX],int32_t * s2[MLDSA_K_MAX])359 static void ComputesT(const CRYPT_ML_DSA_Ctx *ctx, int32_t *t[MLDSA_K_MAX], int32_t *matrix[MLDSA_K_MAX][MLDSA_L_MAX],
360 int32_t *s1[MLDSA_L_MAX], int32_t *s2[MLDSA_K_MAX])
361 {
362 for (uint8_t i = 0; i < ctx->info->k; i++) {
363 MatrixMul(ctx, t[i], matrix[i], s1);
364 MLDSA_ComputesINVNTT(t[i]);
365 for (int32_t j = 0; j < MLDSA_N; j++) {
366 t[i][j] = t[i][j] + s2[i][j];
367 t[i][j] = t[i][j] < 0 ? (t[i][j] + MLDSA_Q) : t[i][j];
368 }
369 }
370 }
371
ComputesPower2Round(const CRYPT_ML_DSA_Ctx * ctx,int32_t * t0[MLDSA_K_MAX],int32_t * t1[MLDSA_K_MAX])372 static void ComputesPower2Round(const CRYPT_ML_DSA_Ctx *ctx, int32_t *t0[MLDSA_K_MAX], int32_t *t1[MLDSA_K_MAX])
373 {
374 for (uint32_t i = 0; i < ctx->info->k; i++) {
375 for (int32_t j = 0; j < MLDSA_N; j++) {
376 int32_t t = (t1[i][j] + (1 << (MLDSA_D - 1)) - 1) >> MLDSA_D;
377 t0[i][j] = t1[i][j] - (t << MLDSA_D);
378 t1[i][j] = t;
379 }
380 }
381 }
382
383 // The following encoding function encodes MLDSA_N int32_t data into the uint8_t array.
ByteEncode(uint8_t * buf,uint32_t * t,uint32_t bits)384 static void ByteEncode(uint8_t *buf, uint32_t *t, uint32_t bits)
385 {
386 if (bits == 10u) {
387 for (uint32_t i = 0; i < MLDSA_N / 4; i++) {
388 buf[5 * i + 0] = (uint8_t)(t[4 * i + 0] >> 0);
389 buf[5 * i + 1u] = (uint8_t)((t[4 * i + 0] >> 8u) | (t[4 * i + 1u] << 2u));
390 buf[5 * i + 2u] = (uint8_t)((t[4 * i + 1u] >> 6u) | (t[4 * i + 2u] << 4u));
391 buf[5 * i + 3u] = (uint8_t)((t[4 * i + 2u] >> 4u) | (t[4 * i + 3u] << 6u));
392 buf[5 * i + 4u] = (uint8_t)(t[4 * i + 3u] >> 2u);
393 }
394 } else if (bits == 6u) {
395 for (uint32_t i = 0; i < MLDSA_N / 4; i++) {
396 buf[3 * i + 0] = (uint8_t)(t[4 * i] | (t[4 * i + 1] << 6u));
397 buf[3 * i + 1u] = (uint8_t)(t[4 * i + 1u] >> 2 | (t[4 * i + 2u] << 4u));
398 buf[3 * i + 2u] = (uint8_t)(t[4 * i + 2u] >> 4 | (t[4 * i + 3u] << 2u));
399 }
400 } else if (bits == 4u) {
401 for (uint32_t i = 0; i < MLDSA_N / 2; i++) {
402 buf[i] = (uint8_t)(t[2 * i] | (t[2 * i + 1] << 4u));
403 }
404 }
405 }
406
ByteDecode(uint8_t * buf,uint32_t * t,uint32_t bits)407 static void ByteDecode(uint8_t *buf, uint32_t *t, uint32_t bits)
408 {
409 if (bits == 10u) {
410 for (uint32_t i = 0; i < MLDSA_N / 4; i++) {
411 t[4 * i + 0] = (buf[5 * i + 0] | ((uint32_t)buf[5 * i + 1] << 8)) & 0x03ff;
412 t[4 * i + 1u] = ((buf[5 * i + 1u] >> 2u) | ((uint32_t)buf[5 * i + 2u] << 6u)) & 0x03ff;
413 t[4 * i + 2u] = ((buf[5 * i + 2u] >> 4u) | ((uint32_t)buf[5 * i + 3u] << 4u)) & 0x03ff;
414 t[4 * i + 3u] = ((buf[5 * i + 3u] >> 6u) | ((uint32_t)buf[5 * i + 4u] << 2u)) & 0x03ff;
415 }
416 }
417 }
418
BitPack(uint8_t * buf,uint32_t w[MLDSA_N],uint32_t bits,uint32_t b)419 static void BitPack(uint8_t *buf, uint32_t w[MLDSA_N], uint32_t bits, uint32_t b)
420 {
421 uint32_t t[8] = {0};
422 uint32_t i;
423 uint32_t n;
424 if (bits == 3u) {
425 for (i = 0; i < MLDSA_N / 8; i++) {
426 for (uint32_t j = 0; j < 8; j++) {
427 t[j] = b - (uint32_t)w[i * 8 + j];
428 }
429 n = bits * i;
430 buf[n + 0] = (uint8_t)((t[0]) | (t[1] << 3u) | (t[2] << 6u));
431 buf[n + 1u] = (uint8_t)((t[2] >> 2u) | (t[3] << 1u) | (t[4] << 4u) | (t[5] << 7u));
432 buf[n + 2u] = (uint8_t)((t[5] >> 1u) | (t[6] << 2u) | (t[7] << 5u));
433 }
434 } else if (bits == 4u) {
435 for (i = 0; i < MLDSA_N / 2; i++) {
436 t[0] = (int32_t)b - w[i * 2];
437 t[1] = (int32_t)b - w[i * 2 + 1];
438 buf[i] = (uint8_t)(t[0] | (t[1] << 4u));
439 }
440 } else if (bits == MLDSA_D) {
441 for (i = 0; i < MLDSA_N / 8; i++) {
442 for (uint32_t j = 0; j < 8; j++) {
443 t[j] = b - w[i * 8 + j];
444 }
445 n = bits * i;
446 buf[n + 0] = (uint8_t)t[0];
447 buf[n + 1] = (uint8_t)(t[0] >> 8u);
448 buf[n + 1] |= (uint8_t)(t[1] << 5u);
449 buf[n + 2] = (uint8_t)(t[1] >> 3u);
450 buf[n + 3] = (uint8_t)(t[1] >> 11u);
451 buf[n + 3] |= (uint8_t)(t[2] << 2u);
452 buf[n + 4] = (uint8_t)(t[2] >> 6u);
453 buf[n + 4] |= (uint8_t)(t[3] << 7u);
454 buf[n + 5] = (uint8_t)(t[3] >> 1u);
455 buf[n + 6] = (uint8_t)(t[3] >> 9u);
456 buf[n + 6] |= (uint8_t)(t[4] << 4u);
457 buf[n + 7] = (uint8_t)(t[4] >> 4u);
458 buf[n + 8] = (uint8_t)(t[4] >> 12u);
459 buf[n + 8] |= (uint8_t)(t[5] << 1u);
460 buf[n + 9] = (uint8_t)(t[5] >> 7u);
461 buf[n + 9] |= (uint8_t)(t[6] << 6u);
462 buf[n + 10] = (uint8_t)(t[6] >> 2u);
463 buf[n + 11] = (uint8_t)(t[6] >> 10u);
464 buf[n + 11] |= (uint8_t)(t[7] << 3u);
465 buf[n + 12] = (uint8_t)(t[7] >> 5u);
466 }
467 }
468 // bits has only this three values.
469 return;
470 }
471
BitUnPake(const uint8_t * v,uint32_t w[MLDSA_N],uint32_t bits,uint32_t b)472 static void BitUnPake(const uint8_t *v, uint32_t w[MLDSA_N], uint32_t bits, uint32_t b)
473 {
474 uint32_t t[8] = {0};
475 uint32_t i;
476 uint32_t n;
477 if (bits == 3u) {
478 for (i = 0; i < MLDSA_N / 8; i++) {
479 n = bits * i;
480 t[0] = (v[n + 0]) & 0x07;
481 t[1] = (v[n + 0] >> 3u) & 0x07;
482 t[2] = ((v[n + 0] >> 6u) | (v[n + 1] << 2u)) & 0x07;
483 t[3] = (v[n + 1u] >> 1u) & 0x07;
484 t[4] = (v[n + 1u] >> 4u) & 0x07;
485 t[5] = ((v[n + 1u] >> 7u) | (v[n + 2] << 1u)) & 0x07;
486 t[6] = (v[n + 2u] >> 2u) & 0x07;
487 t[7] = (v[n + 2u] >> 5u) & 0x07;
488
489 for (uint32_t j = 0; j < 8; j++) {
490 w[i * 8 + j] = b - t[j];
491 }
492 }
493 } else if (bits == 4u) {
494 for (i = 0; i < MLDSA_N / 2; i++) {
495 t[0] = v[i] & 0x0f;
496 t[1] = (v[i] >> 4u) & 0x0f;
497 w[i * 2] = b - t[0];
498 w[i * 2 + 1] = b - t[1];
499 }
500 } else if (bits == MLDSA_D) {
501 for (i = 0; i < MLDSA_N / 8; i++) {
502 n = bits * i;
503 t[0] = (v[n + 0] | ((uint32_t)v[n + 1] << 8u)) & 0x1fff;
504 t[1] = (v[n + 1] >> 5u | ((uint32_t)v[n + 2u] << 3u) |
505 ((uint32_t)v[n + 3u] << 11u)) & 0x1fff;
506 t[2] = (v[n + 3u] >> 2u | ((uint32_t)v[n + 4u] << 6u)) & 0x1fff;
507 t[3] = (v[n + 4u] >> 7u | ((uint32_t)v[n + 5u] << 1u) |
508 ((uint32_t)v[n + 6u] << 9u)) & 0x1fff;
509
510 t[4] = (v[n + 6u] >> 4u | ((uint32_t)v[n + 7u] << 4u) |
511 ((uint32_t)v[n + 8u] << 12u)) & 0x1fff;
512 t[5] = (v[n + 8u] >> 1u | ((uint32_t)v[n + 9u] << 7u)) & 0x1fff;
513 t[6] = (v[n + 9u] >> 6u | ((uint32_t)v[n + 10u] << 2u) |
514 ((uint32_t)v[n + 11u] << 10u)) & 0x1fff;
515 t[7] = (v[n + 11u] >> 3u | ((uint32_t)v[n + 12u] << 5u)) & 0x1fff;
516
517 for (uint32_t j = 0; j < 8; j++) {
518 w[i * 8 + j] = b - t[j];
519 }
520 }
521 }
522 // bits has only this three values.
523 return;
524 }
525
SignBitPack(uint8_t * buf,uint32_t w[MLDSA_N],uint32_t bits,uint32_t b)526 static void SignBitPack(uint8_t *buf, uint32_t w[MLDSA_N], uint32_t bits, uint32_t b)
527 {
528 uint32_t t[4] = {0};
529 uint32_t i;
530 uint32_t n;
531 if (bits == GAMMA_BITS_OF_MLDSA_44) {
532 for (i = 0; i < MLDSA_N / 4; i++) {
533 for (uint32_t j = 0; j < 4; j++) {
534 t[j] = b - w[i * 4 + j];
535 }
536 n = 9 * i;
537 buf[n + 0] = (uint8_t)t[0];
538 buf[n + 1u] = (uint8_t)(t[0] >> 8u);
539 buf[n + 2u] = (uint8_t)(t[0] >> 16u | t[1] << 2u);
540 buf[n + 3u] = (uint8_t)(t[1] >> 6u);
541 buf[n + 4u] = (uint8_t)(t[1] >> 14u | t[2] << 4u);
542 buf[n + 5u] = (uint8_t)(t[2] >> 4u);
543 buf[n + 6u] = (uint8_t)(t[2] >> 12u | t[3] << 6u);
544 buf[n + 7u] = (uint8_t)(t[3] >> 2u);
545 buf[n + 8u] = (uint8_t)(t[3] >> 10u);
546 }
547 } else if (bits == GAMMA_BITS_OF_MLDSA_65_87) {
548 for (i = 0; i < MLDSA_N / 2; i++) {
549 t[0] = b - w[i * 2];
550 t[1] = b - w[i * 2 + 1u];
551 n = 5 * i;
552 buf[n + 0] = (uint8_t)t[0];
553 buf[n + 1u] = (uint8_t)(t[0] >> 8u);
554 buf[n + 2u] = (uint8_t)(t[0] >> 16u | t[1] << 4u);
555 buf[n + 3u] = (uint8_t)(t[1] >> 4u);
556 buf[n + 4u] = (uint8_t)(t[1] >> 12u);
557 }
558 }
559 // bits has only this two values.
560 return;
561 }
562
SignBitUnPake(const uint8_t * v,uint32_t w[MLDSA_N],uint32_t bits,uint32_t b)563 static void SignBitUnPake(const uint8_t *v, uint32_t w[MLDSA_N], uint32_t bits, uint32_t b)
564 {
565 uint32_t t[4] = {0};
566 uint32_t i;
567 uint32_t n;
568 if (bits == GAMMA_BITS_OF_MLDSA_44) {
569 for (i = 0; i < MLDSA_N / 4; i++) {
570 n = 9 * i;
571 t[0] = (v[n + 0] | ((uint32_t)v[n + 1] << 8) | ((uint32_t)v[n + 2] << 16)) & 0x3ffff;
572 t[1] = (v[n + 2u] >> 2u | ((uint32_t)v[n + 3u] << 6u) | ((uint32_t)v[n + 4u] << 14u)) & 0x3ffff;
573 t[2] = (v[n + 4u] >> 4u | ((uint32_t)v[n + 5u] << 4u) | ((uint32_t)v[n + 6u] << 12u)) & 0x3ffff;
574 t[3] = (v[n + 6u] >> 6u | ((uint32_t)v[n + 7u] << 2u) | ((uint32_t)v[n + 8u] << 10u)) & 0x3ffff;
575
576 n = 4 * i;
577 w[n] = b - t[0];
578 w[n + 1u] = b - t[1];
579 w[n + 2u] = b - t[2];
580 w[n + 3u] = b - t[3];
581 }
582 } else if (bits == GAMMA_BITS_OF_MLDSA_65_87) {
583 for (i = 0; i < MLDSA_N / 2; i++) {
584 n = 5 * i;
585 t[0] = (v[n + 0] | ((uint32_t)v[n + 1] << 8u) | ((uint32_t)v[n + 2u] << 16u)) & 0xfffff;
586 t[1] = (v[n + 2u] >> 4u | ((uint32_t)v[n + 3u] << 4u) | ((uint32_t)v[n + 4u] << 12u)) & 0xfffff;
587
588 w[i * 2] = b - t[0];
589 w[i * 2 + 1u] = b - t[1];
590 }
591 }
592 // bits has only this two values.
593 return;
594 }
595
596 // Algorithm 22 pkEncode(ρ, t1)
PkEncode(const CRYPT_ML_DSA_Ctx * ctx,uint8_t * seed,int32_t * t[MLDSA_K_MAX])597 static void PkEncode(const CRYPT_ML_DSA_Ctx *ctx, uint8_t *seed, int32_t *t[MLDSA_K_MAX])
598 {
599 (void)memcpy_s(ctx->pubKey, ctx->pubLen, seed, MLDSA_PUBLIC_SEED_LEN);
600 for (int32_t i = 0; i < ctx->info->k; i++) {
601 // 10 is bitlen(−1) − d
602 ByteEncode(ctx->pubKey + MLDSA_PUBLIC_SEED_LEN + i * MLDSA_PUBKEY_POLYT_PACKEDBYTES, (uint32_t *)t[i], 10);
603 }
604 }
605
606 // Algorithm 23 pkDecode(pk)
PkDecode(const CRYPT_ML_DSA_Ctx * ctx,uint8_t * seed,int32_t * t[MLDSA_K_MAX])607 static void PkDecode(const CRYPT_ML_DSA_Ctx *ctx, uint8_t *seed, int32_t *t[MLDSA_K_MAX])
608 {
609 (void)memcpy_s(seed, MLDSA_PUBLIC_SEED_LEN, ctx->pubKey, MLDSA_PUBLIC_SEED_LEN);
610 for (int32_t i = 0; i < ctx->info->k; i++) {
611 // 10 is bitlen(−1) − d
612 ByteDecode(ctx->pubKey + MLDSA_PUBLIC_SEED_LEN + i * MLDSA_PUBKEY_POLYT_PACKEDBYTES, (uint32_t *)t[i], 10);
613 }
614 }
615
616 // Algorithm 24 skEncode(ρ, K,tr, 1, 2, t0)
SkEncode(const CRYPT_ML_DSA_Ctx * ctx,uint8_t * pubSeed,uint8_t * signSeed,uint8_t * tr,MLDSA_KeyGenMatrixSt * st)617 static void SkEncode(const CRYPT_ML_DSA_Ctx *ctx, uint8_t *pubSeed, uint8_t *signSeed, uint8_t *tr,
618 MLDSA_KeyGenMatrixSt *st)
619 {
620 uint32_t i;
621 uint32_t bitLen = ctx->info->eta == 2 ? 3 : 4; // 3 and 4 is bitlen(2)
622 uint32_t index = MLDSA_PUBLIC_SEED_LEN;
623 (void)memcpy_s(ctx->prvKey, ctx->prvLen, pubSeed, MLDSA_PUBLIC_SEED_LEN);
624 (void)memcpy_s(ctx->prvKey + index, ctx->prvLen - index, signSeed, MLDSA_SIGNING_SEED_LEN);
625 index += MLDSA_SIGNING_SEED_LEN;
626 (void)memcpy_s(ctx->prvKey + index, ctx->prvLen - index, tr, MLDSA_PRIVATE_SEED_LEN);
627 index += MLDSA_PRIVATE_SEED_LEN;
628 for (i = 0; i < ctx->info->l; i++) {
629 BitPack(ctx->prvKey + index, (uint32_t *)st->s1[i], bitLen, ctx->info->eta);
630 index += MLDSA_N_BYTE * bitLen;
631 }
632 for (i = 0; i < ctx->info->k; i++) {
633 BitPack(ctx->prvKey + index, (uint32_t *)st->s2[i], bitLen, ctx->info->eta);
634 index += MLDSA_N_BYTE * bitLen;
635 }
636 for (i = 0; i < ctx->info->k; i++) {
637 BitPack(ctx->prvKey + index, (uint32_t *)st->t0[i], MLDSA_D, 4096); // 2^(−1) == 4096
638 index += MLDSA_N_BYTE * MLDSA_D;
639 }
640 }
641
642 // Algorithm 25 skDecode(sk)
SkDecode(const CRYPT_ML_DSA_Ctx * ctx,uint8_t * pubSeed,uint8_t * signSeed,uint8_t * tr,MLDSA_SignMatrixSt * st)643 static void SkDecode(const CRYPT_ML_DSA_Ctx *ctx, uint8_t *pubSeed, uint8_t *signSeed, uint8_t *tr,
644 MLDSA_SignMatrixSt *st)
645 {
646 uint32_t i;
647 uint32_t bitLen = ctx->info->eta == 2 ? 3 : 4; // 3 and 4 is bitlen(2)
648 uint32_t index = MLDSA_PUBLIC_SEED_LEN;
649 (void)memcpy_s(pubSeed, MLDSA_PUBLIC_SEED_LEN, ctx->prvKey, MLDSA_PUBLIC_SEED_LEN);
650 (void)memcpy_s(signSeed, MLDSA_SIGNING_SEED_LEN, ctx->prvKey + index, MLDSA_SIGNING_SEED_LEN);
651
652 index += MLDSA_SIGNING_SEED_LEN;
653 (void)memcpy_s(tr, MLDSA_PRIVATE_SEED_LEN, ctx->prvKey + index, MLDSA_PRIVATE_SEED_LEN);
654 index += MLDSA_PRIVATE_SEED_LEN;
655
656 for (i = 0; i < ctx->info->l; i++) {
657 BitUnPake(ctx->prvKey + index, (uint32_t *)st->s1[i], bitLen, ctx->info->eta);
658 MLDSA_ComputesNTT(st->s1[i]);
659 index += MLDSA_N_BYTE * bitLen;
660 }
661 for (i = 0; i < ctx->info->k; i++) {
662 BitUnPake(ctx->prvKey + index, (uint32_t *)st->s2[i], bitLen, ctx->info->eta);
663 MLDSA_ComputesNTT(st->s2[i]);
664 index += MLDSA_N_BYTE * bitLen;
665 }
666 for (i = 0; i < ctx->info->k; i++) {
667 BitUnPake(ctx->prvKey + index, (uint32_t *)st->t0[i], MLDSA_D, 4096); // 2^(−1) == 4096
668 MLDSA_ComputesNTT(st->t0[i]);
669 index += MLDSA_N_BYTE * MLDSA_D;
670 }
671 }
672
673 // Algorithm 34 ExpandMask(ρ, μ)
ExpandMask(const CRYPT_ML_DSA_Ctx * ctx,int32_t * y[MLDSA_L_MAX],uint8_t * p,uint16_t u)674 static int32_t ExpandMask(const CRYPT_ML_DSA_Ctx *ctx, int32_t *y[MLDSA_L_MAX], uint8_t *p, uint16_t u)
675 {
676 uint16_t n = 0;
677 uint8_t v[640]; // The maximum length is 20 * 32 == 640 byte.
678 uint32_t bits = (ctx->info->k == K_VALUE_OF_MLDSA_44) ? GAMMA_BITS_OF_MLDSA_44 : GAMMA_BITS_OF_MLDSA_65_87;
679 for (uint16_t i = 0; i < ctx->info->l; i++) {
680 n = u + i;
681 p[MLDSA_PRIVATE_SEED_LEN] = (uint8_t)n;
682 p[MLDSA_PRIVATE_SEED_LEN + 1] = (uint8_t)(n >> BITS_OF_BYTE);
683 // ← H(ρ′, 32)
684 int32_t ret = HashFuncH(p, MLDSA_PRIVATE_SEED_LEN + 2, NULL, 0, v, 32 * bits);
685 if (ret != CRYPT_SUCCESS) {
686 return ret;
687 }
688 SignBitUnPake(v, (uint32_t *)y[i], bits, ctx->info->gamma1);
689 }
690 return CRYPT_SUCCESS;
691 }
692
693 // Algorithm 36 Decompose(r)
Decompose(const CRYPT_ML_DSA_Ctx * ctx,int32_t r,int32_t * r1,int32_t * r0)694 static void Decompose(const CRYPT_ML_DSA_Ctx *ctx, int32_t r, int32_t *r1, int32_t *r0)
695 {
696 int32_t t = (int32_t)(((uint32_t)r + 0x7f) >> 7u);
697 if (ctx->info->k == K_VALUE_OF_MLDSA_44) { // If is MLDSA44
698 // This is Barrett Modular Multiplication, mod is 22.
699 t = (t * 11275u + (1 << 23u)) >> 24u;
700 t ^= ((43 - t) >> 31u) & t;
701 } else {
702 t = (t * 1025u + (1 << 21u)) >> 22u;
703 t &= 0x0f;
704 }
705
706 *r0 = r - t * 2 * ctx->info->gamma2; // r1 ← (r+ − r0)/(22)
707 *r0 -= (((MLDSA_Q - 1) / 2 - *r0) >> 31u) & MLDSA_Q;
708 *r1 = t; // high bits.
709 return;
710 }
711
ComputesW(const CRYPT_ML_DSA_Ctx * ctx,int32_t * w[MLDSA_L_MAX],int32_t * w1[MLDSA_L_MAX],int32_t * matrix[MLDSA_K_MAX][MLDSA_L_MAX],int32_t * y[MLDSA_L_MAX])712 static void ComputesW(const CRYPT_ML_DSA_Ctx *ctx, int32_t *w[MLDSA_L_MAX], int32_t *w1[MLDSA_L_MAX],
713 int32_t *matrix[MLDSA_K_MAX][MLDSA_L_MAX], int32_t *y[MLDSA_L_MAX])
714 {
715 for (uint8_t i = 0; i < ctx->info->k; i++) {
716 MatrixMul(ctx, w[i], matrix[i], y);
717 MLDSA_ComputesINVNTT(w[i]);
718 for (int32_t j = 0; j < MLDSA_N; j++) {
719 w[i][j] = w[i][j] < 0 ? (w[i][j] + MLDSA_Q) : w[i][j];
720 Decompose(ctx, w[i][j], &w1[i][j], &w[i][j]);
721 }
722 }
723 }
724
725 // Algorithm 28 w1Encode(w1)
W1Encode(const CRYPT_ML_DSA_Ctx * ctx,uint8_t * buf,int32_t * w[MLDSA_K_MAX])726 static void W1Encode(const CRYPT_ML_DSA_Ctx *ctx, uint8_t *buf, int32_t *w[MLDSA_K_MAX])
727 {
728 uint32_t bitLen = ctx->info->k == K_VALUE_OF_MLDSA_44 ? 6 : 4; // Only the bitLen value of MLDSA44 is 6.
729 uint32_t blockSize = ctx->info->k == K_VALUE_OF_MLDSA_44 ? 192 : 128; // MLDSA44 blockSize is 192, other is 128.
730 for (uint32_t i = 0; i < ctx->info->k; i++) {
731 ByteEncode(buf + i * blockSize, (uint32_t *)w[i], bitLen);
732 }
733 }
734
735 // Algorithm 29 SampleInBall(ρ)
SampleInBall(const CRYPT_ML_DSA_Ctx * ctx,const uint8_t * p,uint32_t pLen,int32_t c[MLDSA_N])736 static int32_t SampleInBall(const CRYPT_ML_DSA_Ctx *ctx, const uint8_t *p, uint32_t pLen, int32_t c[MLDSA_N])
737 {
738 uint8_t s[CRYPT_SHAKE256_BLOCKSIZE] = {0};
739 uint32_t sLen = CRYPT_SHAKE256_BLOCKSIZE;
740 uint64_t h = 0;
741 uint32_t index = 0;
742 uint8_t j = 0;
743 int32_t ret;
744 const EAL_MdMethod *hashMethod = EAL_MdFindMethod(CRYPT_MD_SHAKE256);
745 if (hashMethod == NULL) {
746 BSL_ERR_PUSH_ERROR(CRYPT_EAL_ALG_NOT_SUPPORT);
747 return CRYPT_EAL_ALG_NOT_SUPPORT;
748 }
749 void *mdCtx = hashMethod->newCtx();
750 if (mdCtx == NULL) {
751 BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
752 return CRYPT_MEM_ALLOC_FAIL;
753 }
754 GOTO_ERR_IF(hashMethod->init(mdCtx, NULL), ret);
755 GOTO_ERR_IF(hashMethod->update(mdCtx, p, pLen), ret);
756 GOTO_ERR_IF(hashMethod->squeeze(mdCtx, s, sLen), ret);
757 for (index = 0; index < 8; index++) { // ← H.Squeeze(ctx, 8)
758 h = h | ((uint64_t)s[index] << (8 * index));
759 }
760 for (uint32_t i = MLDSA_N - ctx->info->tau; i < MLDSA_N; i++) {
761 do {
762 if (index == CRYPT_SHAKE256_BLOCKSIZE) {
763 GOTO_ERR_IF(hashMethod->squeeze(mdCtx, s, sLen), ret);
764 index = 0;
765 }
766 j = s[index];
767 index++;
768 } while (j > i);
769
770 c[i] = c[j];
771 c[j] = 1 - ((h & 1) << 1);
772 h >>= 1;
773 }
774 ERR:
775 hashMethod->freeCtx(mdCtx);
776 return ret;
777 }
778
MLDSA_VectorsAdd(int32_t * t,int32_t * a,int32_t * b)779 static void MLDSA_VectorsAdd(int32_t *t, int32_t *a, int32_t *b)
780 {
781 for (uint32_t i = 0; i < MLDSA_N; i++) {
782 t[i] = a[i] + b[i];
783 MLDSA_MOD_Q(t[i]);
784 }
785 }
786
MLDSA_VectorsSub(int32_t * t,int32_t * a,int32_t * b)787 static void MLDSA_VectorsSub(int32_t *t, int32_t *a, int32_t *b)
788 {
789 for (uint32_t i = 0; i < MLDSA_N; i++) {
790 t[i] = a[i] - b[i];
791 MLDSA_MOD_Q(t[i]);
792 }
793 }
794
ComputesZ(const CRYPT_ML_DSA_Ctx * ctx,int32_t * y[MLDSA_L_MAX],int32_t * c,int32_t * s[MLDSA_L_MAX],int32_t * z[MLDSA_L_MAX])795 static void ComputesZ(const CRYPT_ML_DSA_Ctx *ctx, int32_t *y[MLDSA_L_MAX], int32_t *c, int32_t *s[MLDSA_L_MAX],
796 int32_t *z[MLDSA_L_MAX])
797 {
798 for (uint8_t i = 0; i < ctx->info->l; i++) {
799 VectorsMul(z[i], c, s[i]);
800 MLDSA_ComputesINVNTT(z[i]);
801 MLDSA_VectorsAdd(z[i], y[i], z[i]);
802 }
803 }
804
ValidityChecks(int32_t * z,uint32_t t)805 static bool ValidityChecks(int32_t *z, uint32_t t)
806 {
807 uint32_t n;
808 for (uint32_t j = 0; j < MLDSA_N; j++) {
809 n = z[j] >> 31; // Shift rightwards by 31 bits.
810 n = z[j] - (n & ((uint32_t)z[j] << 1));
811 if (n >= t) {
812 return false;
813 }
814 }
815 return true;
816 }
817
ValidityChecksL(const CRYPT_ML_DSA_Ctx * ctx,int32_t * z[MLDSA_L_MAX],uint32_t t)818 static bool ValidityChecksL(const CRYPT_ML_DSA_Ctx *ctx, int32_t *z[MLDSA_L_MAX], uint32_t t)
819 {
820 for (uint8_t i = 0; i < ctx->info->l; i++) {
821 if (ValidityChecks(z[i], t) == false) {
822 return false;
823 }
824 }
825 return true;
826 }
827
ValidityChecksK(const CRYPT_ML_DSA_Ctx * ctx,int32_t * z[MLDSA_K_MAX],uint32_t t)828 static bool ValidityChecksK(const CRYPT_ML_DSA_Ctx *ctx, int32_t *z[MLDSA_K_MAX], uint32_t t)
829 {
830 for (uint8_t i = 0; i < ctx->info->k; i++) {
831 if (ValidityChecks(z[i], t) == false) {
832 return false;
833 }
834 }
835 return true;
836 }
837
/*
 * For each i < k: cs2[i] = INVNTT(c * s2[i]) and r0[i] = w[i] - cs2[i],
 * all stored in `st`. st->w and st->s2 are only read.
 */
static void ComputesR(const CRYPT_ML_DSA_Ctx *ctx, int32_t *c, MLDSA_SignMatrixSt *st)
{
    for (uint8_t row = 0; row < ctx->info->k; row++) {
        int32_t *product = st->cs2[row];
        VectorsMul(product, c, st->s2[row]);
        MLDSA_ComputesINVNTT(product);
        MLDSA_VectorsSub(st->r0[row], st->w[row], product);
    }
}
846
ComputesCT(const CRYPT_ML_DSA_Ctx * ctx,int32_t * c,int32_t * t[MLDSA_K_MAX],int32_t * ct[MLDSA_K_MAX])847 static void ComputesCT(const CRYPT_ML_DSA_Ctx *ctx, int32_t *c, int32_t *t[MLDSA_K_MAX], int32_t *ct[MLDSA_K_MAX])
848 {
849 for (uint8_t i = 0; i < ctx->info->k; i++) {
850 VectorsMul(ct[i], c, t[i]);
851 MLDSA_ComputesINVNTT(ct[i]);
852 for (uint32_t j = 0; j < MLDSA_N; j++) {
853 int32_t m = (int32_t)(((uint32_t)ct[i][j] + (1 << 22)) >> 23); // m = (ct + 2^22) / 2^23
854 ct[i][j] = ct[i][j] - m * MLDSA_Q;
855 }
856 }
857 }
858
/*
 * Builds the hint st->h and returns the number of hint bits that are set.
 * Each st->w[i] is first updated in place to st->w[i] + st->ct0[i] - st->cs2[i].
 * A hint bit is then set when the updated coefficient lies outside [-gamma2, gamma2],
 * or equals -gamma2 while the matching st->w1 coefficient is non-zero.
 */
static uint32_t MakeHint(const CRYPT_ML_DSA_Ctx *ctx, MLDSA_SignMatrixSt *st)
{
    const int32_t bound = (int32_t)ctx->info->gamma2;
    uint32_t hintCount = 0;
    for (uint32_t row = 0; row < ctx->info->k; row++) {
        MLDSA_VectorsAdd(st->w[row], st->w[row], st->ct0[row]);
        MLDSA_VectorsSub(st->w[row], st->w[row], st->cs2[row]);
        for (uint32_t col = 0; col < MLDSA_N; col++) {
            const int32_t value = st->w[row][col];
            int32_t bit = 0;
            if (value > bound || value < -bound) {
                bit = 1;
            } else if (value == -bound && st->w1[row][col] != 0) {
                bit = 1;
            }
            st->h[row][col] = bit;
            hintCount += (uint32_t)bit;
        }
    }
    return hintCount;
}
877
SigEncode(const CRYPT_ML_DSA_Ctx * ctx,uint8_t * out,uint32_t outLen,int32_t * z[MLDSA_L_MAX],int32_t * h[MLDSA_K_MAX])878 static void SigEncode(const CRYPT_ML_DSA_Ctx *ctx, uint8_t *out, uint32_t outLen, int32_t *z[MLDSA_L_MAX],
879 int32_t *h[MLDSA_K_MAX])
880 {
881 // // 1 bits of MLDSA44 is 18,1 bits of MLDSA65 and MLDSA87 is 20.
882 uint32_t bits = (ctx->info->k == K_VALUE_OF_MLDSA_44) ? GAMMA_BITS_OF_MLDSA_44 : GAMMA_BITS_OF_MLDSA_65_87;
883 uint32_t blockSize = MLDSA_N / BITS_OF_BYTE * bits;
884 uint8_t *ptr = out;
885 uint32_t index = 0;
886 for (uint32_t i = 0; i < ctx->info->l; i++) {
887 SignBitPack(ptr, (uint32_t *)z[i], bits, ctx->info->gamma1);
888 ptr += blockSize;
889 }
890
891 (void)memset_s(ptr, outLen - blockSize * ctx->info->l, 0, outLen - blockSize * ctx->info->l);
892 for (uint32_t i = 0; i < ctx->info->k; i++) {
893 for (uint32_t j = 0; j < MLDSA_N; j++) {
894 if (h[i][j] != 0) {
895 ptr[index] = j;
896 index++;
897 }
898 }
899 ptr[ctx->info->omega + i] = index;
900 }
901 }
902
SigDecode(const CRYPT_ML_DSA_Ctx * ctx,const uint8_t * in,int32_t * z[MLDSA_L_MAX],int32_t * h[MLDSA_K_MAX])903 static int32_t SigDecode(const CRYPT_ML_DSA_Ctx *ctx, const uint8_t *in, int32_t *z[MLDSA_L_MAX],
904 int32_t *h[MLDSA_K_MAX])
905 {
906 uint32_t bits = (ctx->info->k == K_VALUE_OF_MLDSA_44) ? GAMMA_BITS_OF_MLDSA_44 : GAMMA_BITS_OF_MLDSA_65_87;
907 uint32_t blockSize = MLDSA_N / BITS_OF_BYTE * bits;
908 const uint8_t *ptr = in;
909 uint32_t index = 0;
910
911 for (int32_t i = 0; i < ctx->info->l; i++) {
912 SignBitUnPake(ptr, (uint32_t *)z[i], bits, ctx->info->gamma1);
913 ptr += blockSize;
914 }
915
916 for (int32_t i = 0; i < ctx->info->k; i++) {
917 if (ptr[ctx->info->omega + i] < index || ptr[ctx->info->omega + i] > ctx->info->omega) {
918 BSL_ERR_PUSH_ERROR(CRYPT_MLDSA_SIGN_DATA_ERROR);
919 return CRYPT_MLDSA_SIGN_DATA_ERROR;
920 }
921 uint32_t first = index;
922 (void)memset_s(h[i], sizeof(int32_t) * MLDSA_N, 0, sizeof(int32_t) * MLDSA_N);
923 while (index < ptr[ctx->info->omega + i]) {
924 if (index > first && (ptr[index - 1] >= ptr[index])) {
925 BSL_ERR_PUSH_ERROR(CRYPT_MLDSA_SIGN_DATA_ERROR);
926 return CRYPT_MLDSA_SIGN_DATA_ERROR;
927 }
928 h[i][ptr[index]] = 1;
929 index++;
930 }
931 }
932 for (int32_t i = index; i < (ctx->info->omega - 1); i++) {
933 RETURN_RET_IF(ptr[i] != 0, CRYPT_MLDSA_SIGN_DATA_ERROR);
934 }
935 return CRYPT_SUCCESS;
936 }
937
/*
 * Verification helper: w[i] = NTT^-1(A[i] ∘ NTT(z) - NTT(c) ∘ NTT(t1[i] ⋅ 2^d)) for each of the
 * ctx->info->k rows.
 * Side effects: c, st->z and st->t1 are overwritten (c and st->z with their NTT images,
 * st->t1 with the product NTT(c) ∘ NTT(t1 ⋅ 2^d)).
 */
static void ComputesApproxW(const CRYPT_ML_DSA_Ctx *ctx, MLDSA_VerifyMatrixSt *st, int32_t *c, int32_t *w[MLDSA_K_MAX])
{
    const uint8_t cols = ctx->info->l;
    const uint8_t rows = ctx->info->k;

    // c <- NTT(c), z <- NTT(z)
    MLDSA_ComputesNTT(c);
    for (uint8_t col = 0; col < cols; col++) {
        MLDSA_ComputesNTT(st->z[col]);
    }

    for (uint8_t row = 0; row < rows; row++) {
        // t1 ⋅ 2^d
        for (uint32_t n = 0; n < MLDSA_N; n++) {
            const uint32_t scaled = (uint32_t)st->t1[row][n] << MLDSA_D;
            st->t1[row][n] = (int32_t)scaled;
        }
        // NTT(c) ∘ NTT(t1 ⋅ 2^d)
        MLDSA_ComputesNTT(st->t1[row]);
        VectorsMul(st->t1[row], st->t1[row], c);
        // A ∘ NTT(z)
        MatrixMul(ctx, w[row], st->matrix[row], st->z);
        // w' <- NTT^-1(A ∘ NTT(z) - NTT(c) ∘ NTT(t1 ⋅ 2^d))
        MLDSA_VectorsSub(w[row], w[row], st->t1[row]);
        MLDSA_ComputesINVNTT(w[row]);
    }
}
960
UseHint(const CRYPT_ML_DSA_Ctx * ctx,int32_t * h[MLDSA_K_MAX],int32_t * w[MLDSA_K_MAX])961 static void UseHint(const CRYPT_ML_DSA_Ctx *ctx, int32_t *h[MLDSA_K_MAX], int32_t *w[MLDSA_K_MAX])
962 {
963 int32_t r1;
964 int32_t r0;
965 for (uint8_t i = 0; i < ctx->info->k; i++) {
966 for (uint32_t j = 0; j < MLDSA_N; j++) {
967 if (w[i][j] < 0) {
968 w[i][j] += MLDSA_Q;
969 }
970 Decompose(ctx, w[i][j], &r1, &r0);
971 if (h[i][j] == 0) {
972 w[i][j] = r1;
973 continue;
974 }
975 if (ctx->info->gamma2 == 95232) { // 95232 is (MLDSA_Q-1) / 88;
976 // ← ( − 1)/(22) = 44
977 // If r0 > 0 return (r1 + 1) mod m else return (r1 − 1) mod m
978 w[i][j] = (r0 > 0) ? ((r1 == 43) ? 0 : (r1 + 1)) : ((r1 == 0) ? 43 : (r1 - 1)); // 43 is (m - 1)
979 continue;
980 }
981 w[i][j] = ((r0 > 0) ? (r1 + 1) : (r1 - 1)) & 0x0f;
982 }
983 }
984 }
985
// Referenced from NIST.FIPS.204 Algorithm 6 ML-DSA.KeyGen_internal(ξ)
/*
 * Deterministic ML-DSA key generation from a seed.
 * ctx: key context; its parameter set (ctx->info) selects k and l. The encoded keys are written
 *      through PkEncode / SkEncode, and ctx->pubKey / ctx->pubLen are read back to derive tr.
 * d:   seed, MLDSA_SEED_BYTES_LEN bytes are read.
 * Returns the error of the first failing step, otherwise the result of the last hash call.
 * All secret working buffers are wiped before returning.
 */
int32_t MLDSA_KeyGenInternal(CRYPT_ML_DSA_Ctx *ctx, uint8_t *d)
{
    int32_t ret;
    const uint8_t k = ctx->info->k;
    const uint8_t l = ctx->info->l;
    uint8_t xi[MLDSA_SEED_EXTEND_BYTES_LEN] = { 0 };
    uint8_t expanded[MLDSA_EXPANDED_SEED_BYTES_LEN] = { 0 };
    uint8_t tr[MLDSA_TR_MSG_LEN] = { 0 };
    // The expanded seed is laid out as ρ || ρ′ || K.
    uint8_t *rho = expanded;
    uint8_t *rhoPrime = rho + MLDSA_PUBLIC_SEED_LEN;
    uint8_t *keySeed = rhoPrime + MLDSA_PRIVATE_SEED_LEN;
    MLDSA_KeyGenMatrixSt st = { 0 };

    GOTO_ERR_IF(MLDSAKeyGenCreateMatrix(k, l, &st), ret);

    // Hash input: the random seed, then one byte 'k' and one byte 'l'.
    (void)memcpy_s(xi, sizeof(xi), d, MLDSA_SEED_BYTES_LEN);
    xi[MLDSA_SEED_BYTES_LEN] = k;
    xi[MLDSA_SEED_BYTES_LEN + 1] = l;
    // (ρ, ρ′, K) ∈ B32 × B64 × B32 ← H(ξ||IntegerToBytes(k, 1)||IntegerToBytes(ℓ, 1), 128)
    GOTO_ERR_IF(HashFuncH(xi, sizeof(xi), NULL, 0, expanded, MLDSA_EXPANDED_SEED_BYTES_LEN), ret);

    // A ← ExpandA(ρ)
    GOTO_ERR_IF(ExpandA(ctx, rho, st.matrix), ret);
    // (s1, s2) ← ExpandS(ρ′)
    GOTO_ERR_IF(ExpandS(ctx, rhoPrime, st.s1, st.s2), ret);

    // t ← NTT^−1(A ∘ NTT(s1)) + s2
    ComputesNTT(ctx, st.s1, st.s1Ntt);
    ComputesT(ctx, st.t1, st.matrix, st.s1Ntt, st.s2);

    // (t1, t0) ← Power2Round(t)
    ComputesPower2Round(ctx, st.t0, st.t1);
    // pk ← pkEncode(ρ, t1)
    PkEncode(ctx, rho, st.t1);

    // tr ← H(pk, 64)
    GOTO_ERR_IF(HashFuncH(ctx->pubKey, ctx->pubLen, NULL, 0, tr, MLDSA_TR_MSG_LEN), ret);

    // sk ← skEncode(ρ, K, tr, s1, s2, t0)
    SkEncode(ctx, rho, keySeed, tr, &st);
ERR:
    BSL_SAL_ClearFree(st.bufAddr, st.bufSize);
    BSL_SAL_CleanseData(xi, sizeof(xi));
    BSL_SAL_CleanseData(expanded, sizeof(expanded));
    return ret;
}
1033
// Referenced from NIST.FIPS.204 Algorithm 7 ML-DSA.Sign_internal(sk, M′, rnd)
/*
 * Produces an ML-DSA signature with the private key held in ctx.
 *
 * ctx:    signing context; ctx->info selects the parameter set, ctx->isMuMsg selects whether msg
 *         already is the message representative μ.
 * msg:    formatted message M′, or μ when ctx->isMuMsg is set.
 *         NOTE(review): in the μ case msg->len is assumed to be MLDSA_XOF_MSG_LEN and to have been
 *         validated by the caller -- confirm.
 * out:    receives the signature; it is also used as scratch for the challenge hash while
 *         candidate signatures are being rejected.
 * outLen: set to ctx->info->signatureLen on success.
 * rand:   MLDSA_SEED_BYTES_LEN bytes of per-signature randomness (rnd).
 *
 * Returns CRYPT_SUCCESS, CRYPT_MEM_ALLOC_FAIL, or the error of the first failing sub-step.
 * Secret intermediate values (K||rnd, ρ″, the challenge polynomial and the matrix workspace)
 * are wiped on every exit path.
 */
int32_t MLDSA_SignInternal(const CRYPT_ML_DSA_Ctx *ctx, CRYPT_Data *msg, uint8_t *out, uint32_t *outLen, uint8_t *rand)
{
    int32_t ret = CRYPT_SUCCESS;
    uint8_t pubSeed[MLDSA_PUBLIC_SEED_LEN];
    uint8_t uBuf[MLDSA_XOF_MSG_LEN] = { 0 };
    uint8_t tr[MLDSA_TR_MSG_LEN];
    uint8_t signSeed[MLDSA_SIGNING_SEED_LEN + MLDSA_SEED_BYTES_LEN];
    uint8_t p[MLDSA_XOF_MSG_LEN + 2] = { 0 }; // ρ″; the counter used 2 bytes.
    int32_t c[MLDSA_N] = { 0 };
    uint16_t u = 0;
    // The length of c~ is λ/4.
    uint32_t cBufLen = ctx->info->secBits / 4;
    MLDSA_SignMatrixSt st = { 0 };

    // The w1Len length of MLDSA44 and MLDSA65 is 768, and the w1Len length of MLDSA87 is 1024.
    uint32_t w1Len = (ctx->info->k == 4 || ctx->info->k == 6) ? 768 : 1024;
    uint8_t *w1Buf = BSL_SAL_Malloc(w1Len);
    RETURN_RET_IF(w1Buf == NULL, CRYPT_MEM_ALLOC_FAIL);

    // signSeed = K || rnd. rnd is copied only after the early return above, so that no secret
    // material is left on the stack without passing through the cleanup at ERR.
    (void)memcpy_s(signSeed + MLDSA_SIGNING_SEED_LEN, MLDSA_SEED_BYTES_LEN, rand, MLDSA_SEED_BYTES_LEN);

    GOTO_ERR_IF(MLDSASignCreateMatrix(ctx->info->k, ctx->info->l, &st), ret);

    // (ρ, K, tr, s1, s2, t0) ← skDecode(sk)
    SkDecode(ctx, pubSeed, signSeed, tr, &st);
    // A ← ExpandA(ρ)
    GOTO_ERR_IF(ExpandA(ctx, pubSeed, st.matrix), ret);
    if (ctx->isMuMsg) {
        (void)memcpy_s(uBuf, MLDSA_XOF_MSG_LEN, msg->data, msg->len);
    } else {
        // μ ← H(BytesToBits(tr)||M′, 64)
        GOTO_ERR_IF(HashFuncH(tr, MLDSA_TR_MSG_LEN, msg->data, msg->len, uBuf, MLDSA_XOF_MSG_LEN), ret);
    }
    // ρ″ ← H(K||rnd||μ, 64)
    GOTO_ERR_IF(HashFuncH(signSeed, sizeof(signSeed), uBuf, MLDSA_XOF_MSG_LEN, p, MLDSA_XOF_MSG_LEN), ret);

    // Rejection sampling loop: every 'continue' restarts with a fresh mask y.
    do {
        // y ← ExpandMask(ρ″, κ)
        GOTO_ERR_IF(ExpandMask(ctx, st.y, p, u), ret);
        u = u + ctx->info->l;
        ComputesNTT(ctx, st.y, st.z);
        // w ← NTT^−1(A ∘ NTT(y)); w1 ← HighBits(w)
        ComputesW(ctx, st.w, st.w1, st.matrix, st.z);

        // c~ ← H(μ||w1Encode(w1), λ/4)
        W1Encode(ctx, w1Buf, st.w1);
        GOTO_ERR_IF(HashFuncH(uBuf, MLDSA_XOF_MSG_LEN, w1Buf, w1Len, out, cBufLen), ret);
        (void)memset_s(c, sizeof(c), 0, sizeof(c));
        // c ← SampleInBall(c~)
        SampleInBall(ctx, out, cBufLen, c);
        // c ← NTT(c)
        MLDSA_ComputesNTT(c);

        // ⟨⟨cs1⟩⟩ ← NTT^−1(c ∘ s1); z ← y + ⟨⟨cs1⟩⟩
        ComputesZ(ctx, st.y, c, st.s1, st.z);
        // if ||z||∞ ≥ γ1 − β, reject
        if (ValidityChecksL(ctx, st.z, ctx->info->gamma1 - ctx->info->beta) == false) {
            continue;
        }
        // ⟨⟨cs2⟩⟩ ← NTT^−1(c ∘ s2); r0 ← LowBits(w − ⟨⟨cs2⟩⟩)
        ComputesR(ctx, c, &st);
        // if ||r0||∞ ≥ γ2 − β, reject
        if (ValidityChecksK(ctx, st.r0, ctx->info->gamma2 - ctx->info->beta) == false) {
            continue;
        }
        // ⟨⟨ct0⟩⟩ ← NTT^−1(c ∘ t0)
        ComputesCT(ctx, c, st.t0, st.ct0);
        // if ||⟨⟨ct0⟩⟩||∞ ≥ γ2, reject
        if (ValidityChecksK(ctx, st.ct0, ctx->info->gamma2) == false) {
            continue;
        }
        // h ← MakeHint(−⟨⟨ct0⟩⟩, w − ⟨⟨cs2⟩⟩ + ⟨⟨ct0⟩⟩); reject if it has more than ω ones
        if (MakeHint(ctx, &st) > ctx->info->omega) {
            continue;
        }
        break;
    } while (true);

    *outLen = ctx->info->signatureLen;
    // σ ← sigEncode(c~, z mod± q, h)
    SigEncode(ctx, out + cBufLen, *outLen - cBufLen, st.z, st.h);
ERR:
    BSL_SAL_ClearFree(st.bufAddr, st.bufSize);
    BSL_SAL_ClearFree(w1Buf, w1Len);
    BSL_SAL_CleanseData(signSeed, sizeof(signSeed));
    // ρ″ determines the mask y; together with a signature it would expose s1, so never leave it behind.
    BSL_SAL_CleanseData(p, sizeof(p));
    BSL_SAL_CleanseData(c, sizeof(c));
    return ret;
}
1121
// Referenced from NIST.FIPS.204 Algorithm 8 ML-DSA.Verify_internal(pk, M′, σ)
/*
 * Verifies an ML-DSA signature with the public key held in ctx.
 *
 * ctx:     verification context; ctx->info selects the parameter set, ctx->pubKey / ctx->pubLen
 *          are hashed into tr, and ctx->isMuMsg selects whether msg already is μ.
 * msg:     formatted message M′, or μ when ctx->isMuMsg is set.
 *          NOTE(review): in the μ case msg->len is assumed to be MLDSA_XOF_MSG_LEN and to have
 *          been validated by the caller -- confirm.
 * sign:    encoded signature σ = c~ || z || h.
 * signLen: length of sign; must equal ctx->info->signatureLen.
 *
 * Returns CRYPT_SUCCESS when the signature is valid, CRYPT_MLDSA_SIGN_DATA_ERROR when the
 * signature has the wrong length, is malformed or fails the norm bound on z,
 * CRYPT_MLDSA_VERIFY_FAIL when the recomputed challenge differs, CRYPT_MEM_ALLOC_FAIL, or the
 * error of the first failing sub-step.
 */
int32_t MLDSA_VerifyInternal(const CRYPT_ML_DSA_Ctx *ctx, CRYPT_Data *msg, const uint8_t *sign, uint32_t signLen)
{
    int32_t ret = CRYPT_SUCCESS;
    uint8_t k = ctx->info->k;
    uint8_t l = ctx->info->l;
    uint8_t pubSeed[MLDSA_PUBLIC_SEED_LEN];
    uint8_t uBuf[MLDSA_XOF_MSG_LEN] = { 0 };
    uint8_t cBuf[MLDSA_XOF_MSG_LEN];
    uint8_t tr[MLDSA_TR_MSG_LEN];
    uint32_t cBufLen = ctx->info->secBits / 4;
    MLDSA_VerifyMatrixSt st = { 0 };
    int32_t c[MLDSA_N] = { 0 };

    // Everything below reads a full-size signature from 'sign'; a buffer of any other length
    // must be rejected before it is parsed.
    if (signLen != ctx->info->signatureLen) {
        BSL_ERR_PUSH_ERROR(CRYPT_MLDSA_SIGN_DATA_ERROR);
        return CRYPT_MLDSA_SIGN_DATA_ERROR;
    }

    // The w1Len length of MLDSA44 and MLDSA65 is 768, and the w1Len length of MLDSA87 is 1024.
    uint32_t w1Len = (k == 4 || k == 6) ? 768 : 1024;
    uint8_t *w1Buf = BSL_SAL_Malloc(w1Len);
    RETURN_RET_IF(w1Buf == NULL, CRYPT_MEM_ALLOC_FAIL);

    GOTO_ERR_IF(MLDSAVerifyCreateMatrix(k, l, &st), ret);

    // (ρ, t1) ← pkDecode(pk)
    PkDecode(ctx, pubSeed, st.t1);
    // (c~, z, h) ← sigDecode(σ)
    GOTO_ERR_IF(SigDecode(ctx, sign + cBufLen, st.z, st.h), ret);

    // The signature is only acceptable if ||z||∞ < γ1 − β
    if (ValidityChecksL(ctx, st.z, ctx->info->gamma1 - ctx->info->beta) == false) {
        BSL_ERR_PUSH_ERROR(CRYPT_MLDSA_SIGN_DATA_ERROR);
        ret = CRYPT_MLDSA_SIGN_DATA_ERROR;
        goto ERR;
    }

    // A ← ExpandA(ρ)
    GOTO_ERR_IF(ExpandA(ctx, pubSeed, st.matrix), ret);
    if (ctx->isMuMsg) {
        (void)memcpy_s(uBuf, MLDSA_XOF_MSG_LEN, msg->data, msg->len);
    } else {
        // tr ← H(pk, 64)
        GOTO_ERR_IF(HashFuncH(ctx->pubKey, ctx->pubLen, NULL, 0, tr, MLDSA_TR_MSG_LEN), ret);
        // μ ← H(BytesToBits(tr)||M′, 64)
        GOTO_ERR_IF(HashFuncH(tr, MLDSA_TR_MSG_LEN, msg->data, msg->len, uBuf, MLDSA_XOF_MSG_LEN), ret);
    }

    // c ← SampleInBall(c~)
    SampleInBall(ctx, sign, cBufLen, c);
    // w′ ← NTT^−1(A ∘ NTT(z) − NTT(c) ∘ NTT(t1 ⋅ 2^d))
    ComputesApproxW(ctx, &st, c, st.w);
    // w1′ ← UseHint(h, w′)
    UseHint(ctx, st.h, st.w);
    // c~′ ← H(μ||w1Encode(w1′), λ/4)
    W1Encode(ctx, w1Buf, st.w);
    GOTO_ERR_IF(HashFuncH(uBuf, MLDSA_XOF_MSG_LEN, w1Buf, w1Len, cBuf, cBufLen), ret);

    // If c~ and c~′ are not equal, verification failed. Both values are public.
    if (memcmp(sign, cBuf, cBufLen) != 0) {
        BSL_ERR_PUSH_ERROR(CRYPT_MLDSA_VERIFY_FAIL);
        ret = CRYPT_MLDSA_VERIFY_FAIL;
        goto ERR;
    }
ERR:
    BSL_SAL_Free(st.bufAddr);
    BSL_SAL_Free(w1Buf);
    return ret;
}
1187
1188 #endif