• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * This file is part of the openHiTLS project.
3  *
4  * openHiTLS is licensed under the Mulan PSL v2.
5  * You can use this software according to the terms and conditions of the Mulan PSL v2.
6  * You may obtain a copy of Mulan PSL v2 at:
7  *
8  *     http://license.coscl.org.cn/MulanPSL2
9  *
10  * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
11  * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
12  * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
13  * See the Mulan PSL v2 for more details.
14  */
15 
16 #include "hitls_build.h"
17 #ifdef HITLS_CRYPTO_MLDSA
18 #include "securec.h"
19 #include "bsl_errno.h"
20 #include "bsl_sal.h"
21 #include "crypt_utils.h"
22 #include "crypt_sha3.h"
23 #include "crypt_errno.h"
24 #include "crypt_util_rand.h"
25 #include "bsl_err_internal.h"
26 #include "ml_dsa_local.h"
27 #include "eal_md_local.h"
28 
29 #define BITS_OF_BYTE 8
30 #define MLDSA_SET_VECTOR_MEM(ptr, buf) {ptr = buf; buf += MLDSA_N;}
31 
HashFuncH(const uint8_t * inPutA,uint32_t lenA,const uint8_t * inPutB,uint32_t lenB,uint8_t * out,uint32_t outLen)32 static int32_t HashFuncH(const uint8_t *inPutA, uint32_t lenA, const uint8_t *inPutB, uint32_t lenB,
33     uint8_t *out, uint32_t outLen)
34 {
35     uint32_t len = outLen;
36     int32_t ret = 0;
37     const EAL_MdMethod *hashMethod = EAL_MdFindMethod(CRYPT_MD_SHAKE256);
38     if (hashMethod == NULL) {
39         BSL_ERR_PUSH_ERROR(CRYPT_EAL_ALG_NOT_SUPPORT);
40         return CRYPT_EAL_ALG_NOT_SUPPORT;
41     }
42     void *mdCtx = hashMethod->newCtx();
43     if (mdCtx == NULL) {
44         BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
45         return CRYPT_MEM_ALLOC_FAIL;
46     }
47 
48     GOTO_ERR_IF(hashMethod->init(mdCtx, NULL), ret);
49     GOTO_ERR_IF(hashMethod->update(mdCtx, inPutA, lenA), ret);
50     if (inPutB != NULL) {
51         GOTO_ERR_IF(hashMethod->update(mdCtx, inPutB, lenB), ret);
52     }
53     GOTO_ERR_IF(hashMethod->final(mdCtx, out, &len), ret);
54 ERR:
55     hashMethod->freeCtx(mdCtx);
56     return ret;
57 }
58 
59 typedef struct {
60     int32_t *bufAddr;
61     uint32_t bufSize;
62     int32_t *matrix[MLDSA_K_MAX][MLDSA_L_MAX];
63     int32_t *s2[MLDSA_K_MAX];
64     int32_t *t0[MLDSA_K_MAX];
65     int32_t *t1[MLDSA_K_MAX];
66     int32_t *s1[MLDSA_L_MAX];
67     int32_t *s1Ntt[MLDSA_L_MAX];
68 } MLDSA_KeyGenMatrixSt;
69 
MLDSASetMatrixMem(uint8_t k,uint8_t l,int32_t * matrix[MLDSA_K_MAX][MLDSA_L_MAX],int32_t * buf)70 static void MLDSASetMatrixMem(uint8_t k, uint8_t l, int32_t *matrix[MLDSA_K_MAX][MLDSA_L_MAX], int32_t *buf)
71 {
72     for (uint8_t i = 0; i < k; i++) {
73         for (uint8_t j = 0; j < l; j++) {
74             matrix[i][j] = buf;
75             buf += MLDSA_N;
76         }
77     }
78 }
79 
MLDSAKeyGenCreateMatrix(uint8_t k,uint8_t l,MLDSA_KeyGenMatrixSt * st)80 static int32_t MLDSAKeyGenCreateMatrix(uint8_t k, uint8_t l, MLDSA_KeyGenMatrixSt *st)
81 {
82     // Key generation requires 3 two-dimensional arrays of length k and 2 of length l.
83     st->bufSize = (k * l + 3 * k + 2 * l) * MLDSA_N * sizeof(int32_t);
84     int32_t *buf = BSL_SAL_Malloc(st->bufSize);
85     if (buf == NULL) {
86         return BSL_MALLOC_FAIL;
87     }
88     st->bufAddr = buf;  // Used to free memory.
89     MLDSASetMatrixMem(k, l, st->matrix, buf);
90     buf += k * l * MLDSA_N;
91     for (uint8_t i = 0; i < k; i++) {
92         MLDSA_SET_VECTOR_MEM(st->t0[i], buf);
93         MLDSA_SET_VECTOR_MEM(st->t1[i], buf);
94         MLDSA_SET_VECTOR_MEM(st->s2[i], buf);
95     }
96     for (uint8_t i = 0; i < l; i++) {
97         MLDSA_SET_VECTOR_MEM(st->s1[i], buf);
98         MLDSA_SET_VECTOR_MEM(st->s1Ntt[i], buf);
99     }
100     return CRYPT_SUCCESS;
101 }
102 
103 typedef struct {
104     int32_t *bufAddr;
105     uint32_t bufSize;
106     int32_t *matrix[MLDSA_K_MAX][MLDSA_L_MAX];
107     int32_t *t0[MLDSA_K_MAX];
108     int32_t *r0[MLDSA_K_MAX];
109     int32_t *s2[MLDSA_K_MAX];
110     int32_t *cs2[MLDSA_K_MAX];
111     int32_t *ct0[MLDSA_K_MAX];
112     int32_t *h[MLDSA_K_MAX];
113     int32_t *w[MLDSA_K_MAX];
114     int32_t *w1[MLDSA_K_MAX];
115     int32_t *s1[MLDSA_L_MAX];
116     int32_t *y[MLDSA_L_MAX];
117     int32_t *z[MLDSA_L_MAX];
118 } MLDSA_SignMatrixSt;
119 
MLDSASignCreateMatrix(uint8_t k,uint8_t l,MLDSA_SignMatrixSt * st)120 static int32_t MLDSASignCreateMatrix(uint8_t k, uint8_t l, MLDSA_SignMatrixSt *st)
121 {
122     // The signature requires 8 two-dimensional arrays of length k and 3 of length l.
123     st->bufSize = (k * l + 8 * k + 3 * l) * MLDSA_N * sizeof(int32_t);
124     int32_t *buf = BSL_SAL_Malloc(st->bufSize);
125     if (buf == NULL) {
126         return BSL_MALLOC_FAIL;
127     }
128     st->bufAddr = buf;  // Used to free memory.
129     MLDSASetMatrixMem(k, l, st->matrix, buf);
130     buf += k * l * MLDSA_N;
131     for (uint8_t i = 0; i < k; i++) {
132         MLDSA_SET_VECTOR_MEM(st->r0[i], buf);
133         MLDSA_SET_VECTOR_MEM(st->t0[i], buf);
134         MLDSA_SET_VECTOR_MEM(st->s2[i], buf);
135         MLDSA_SET_VECTOR_MEM(st->cs2[i], buf);
136         MLDSA_SET_VECTOR_MEM(st->ct0[i], buf);
137         MLDSA_SET_VECTOR_MEM(st->h[i], buf);
138         MLDSA_SET_VECTOR_MEM(st->w[i], buf);
139         MLDSA_SET_VECTOR_MEM(st->w1[i], buf);
140     }
141     for (uint8_t i = 0; i < l; i++) {
142         MLDSA_SET_VECTOR_MEM(st->s1[i], buf);
143         MLDSA_SET_VECTOR_MEM(st->y[i], buf);
144         MLDSA_SET_VECTOR_MEM(st->z[i], buf);
145     }
146     return CRYPT_SUCCESS;
147 }
148 
149 typedef struct {
150     int32_t *bufAddr;
151     uint32_t bufSize;
152     int32_t *matrix[MLDSA_K_MAX][MLDSA_L_MAX];
153     int32_t *t1[MLDSA_K_MAX];
154     int32_t *h[MLDSA_K_MAX];
155     int32_t *w[MLDSA_K_MAX];
156     int32_t *z[MLDSA_L_MAX];
157 } MLDSA_VerifyMatrixSt;
158 
MLDSAVerifyCreateMatrix(uint8_t k,uint8_t l,MLDSA_VerifyMatrixSt * st)159 static int32_t MLDSAVerifyCreateMatrix(uint8_t k, uint8_t l, MLDSA_VerifyMatrixSt *st)
160 {
161     // Signature verification requires 3 two-dimensional arrays of length k and 1 of length l.
162     st->bufSize = (k * l + 3 * k + l) * MLDSA_N * sizeof(int32_t);
163     int32_t *buf = BSL_SAL_Malloc(st->bufSize);
164     if (buf == NULL) {
165         return BSL_MALLOC_FAIL;
166     }
167     st->bufAddr = buf;  // Used to free memory.
168     MLDSASetMatrixMem(k, l, st->matrix, buf);
169     buf += k * l * MLDSA_N;
170 
171     for (uint8_t i = 0; i < k; i++) {
172         MLDSA_SET_VECTOR_MEM(st->t1[i], buf);
173         MLDSA_SET_VECTOR_MEM(st->h[i], buf);
174         MLDSA_SET_VECTOR_MEM(st->w[i], buf);
175     }
176     for (uint8_t i = 0; i < l; i++) {
177         MLDSA_SET_VECTOR_MEM(st->z[i], buf);
178     }
179     return CRYPT_SUCCESS;
180 }
181 
182 // NIST.FIPS.204 Algorithm 14 CoeffFromThreeBytes(b0, b1, b2)
CoeffFromThreeBytes(uint8_t b0,uint8_t b1,uint8_t b2)183 static int32_t CoeffFromThreeBytes(uint8_t b0, uint8_t b1, uint8_t b2)
184 {
185     uint8_t b = b2;
186     if (b > 0x7f) {
187         b = b - 0x80;
188     }
189     // �� ← 2^16 ⋅ b2′ + 2^8 ⋅ b1 + b0
190     return (((int32_t)b << 16) | ((int32_t)b1 << 8)) | b0;
191 }
192 
193 // NIST.FIPS.204 Algorithm 30 RejNTTPoly(ρ)
RejNTTPoly(int32_t a[MLDSA_N],uint8_t seed[MLDSA_SEED_EXTEND_BYTES_LEN])194 static int32_t RejNTTPoly(int32_t a[MLDSA_N], uint8_t seed[MLDSA_SEED_EXTEND_BYTES_LEN])
195 {
196     int32_t ret;
197     unsigned int buflen = CRYPT_SHAKE128_BLOCKSIZE;
198     uint8_t buf[CRYPT_SHAKE128_BLOCKSIZE];
199 
200     const EAL_MdMethod *hashMethod = EAL_MdFindMethod(CRYPT_MD_SHAKE128);
201     if (hashMethod == NULL) {
202         BSL_ERR_PUSH_ERROR(CRYPT_EAL_ALG_NOT_SUPPORT);
203         return CRYPT_EAL_ALG_NOT_SUPPORT;
204     }
205     void *mdCtx = hashMethod->newCtx();
206     if (mdCtx == NULL) {
207         BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
208         return CRYPT_MEM_ALLOC_FAIL;
209     }
210     GOTO_ERR_IF(hashMethod->init(mdCtx, NULL), ret);
211     GOTO_ERR_IF(hashMethod->update(mdCtx, seed, MLDSA_SEED_EXTEND_BYTES_LEN), ret);
212     GOTO_ERR_IF(hashMethod->squeeze(mdCtx, buf, buflen), ret);
213     uint32_t j = 0;
214     for (uint32_t i = 0; i < MLDSA_N;) {
215         a[i] = CoeffFromThreeBytes(buf[j], buf[j + 1], buf[j + 2]); // Data from 3 uint8_t to int32_t.
216         j += 3;
217         if (a[i] < MLDSA_Q) {  // a[i] is less than MLDSA_Q is an invalid value.
218             i++;
219         }
220         if (j >= CRYPT_SHAKE128_BLOCKSIZE) {
221             GOTO_ERR_IF(hashMethod->squeeze(mdCtx, buf, buflen), ret);
222             j = 0;
223         }
224     }
225 ERR:
226     hashMethod->freeCtx(mdCtx);
227     return ret;
228 }
229 
230 // NIST.FIPS.204 Algorithm 32 ExpandA(ρ)
ExpandA(const CRYPT_ML_DSA_Ctx * ctx,const uint8_t * pubSeed,int32_t * matrix[MLDSA_K_MAX][MLDSA_L_MAX])231 static int32_t ExpandA(const CRYPT_ML_DSA_Ctx *ctx, const uint8_t *pubSeed, int32_t *matrix[MLDSA_K_MAX][MLDSA_L_MAX])
232 {
233     uint8_t k = ctx->info->k;
234     uint8_t l = ctx->info->l;
235     uint8_t seed[MLDSA_SEED_EXTEND_BYTES_LEN];
236     (void)memcpy_s(seed, sizeof(seed), pubSeed, MLDSA_PUBLIC_SEED_LEN);
237     for (uint8_t i = 0; i < k; i++) {
238         for (uint8_t j = 0; j < l; j++) {
239             seed[MLDSA_PUBLIC_SEED_LEN] = j;
240             seed[MLDSA_PUBLIC_SEED_LEN + 1] = i;
241             int32_t ret = RejNTTPoly(matrix[i][j], seed);
242             RETURN_RET_IF(ret != CRYPT_SUCCESS, ret);
243         }
244     }
245     return CRYPT_SUCCESS;
246 }
247 
248 // NIST.FIPS.204 Algorithm 31 RejBoundedPoly(ρ)
RejBoundedPoly(const CRYPT_ML_DSA_Ctx * ctx,int32_t * a,uint8_t * s)249 static int32_t RejBoundedPoly(const CRYPT_ML_DSA_Ctx *ctx, int32_t *a, uint8_t *s)
250 {
251     uint8_t buf[CRYPT_SHAKE256_BLOCKSIZE];
252     uint32_t bufLen = CRYPT_SHAKE256_BLOCKSIZE;
253     int32_t ret = CRYPT_SUCCESS;
254     const EAL_MdMethod *hashMethod = EAL_MdFindMethod(CRYPT_MD_SHAKE256);
255     if (hashMethod == NULL) {
256         BSL_ERR_PUSH_ERROR(CRYPT_EAL_ALG_NOT_SUPPORT);
257         return CRYPT_EAL_ALG_NOT_SUPPORT;
258     }
259     void *mdCtx = hashMethod->newCtx();
260     if (mdCtx == NULL) {
261         BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
262         return CRYPT_MEM_ALLOC_FAIL;
263     }
264     GOTO_ERR_IF(hashMethod->init(mdCtx, NULL), ret);
265     GOTO_ERR_IF(hashMethod->update(mdCtx, s, MLDSA_PRIVATE_SEED_LEN + 2), ret);  // k and l used 2 bytes.
266     GOTO_ERR_IF(hashMethod->squeeze(mdCtx, buf, bufLen), ret);
267     for (uint32_t i = 0, j = 0; i < MLDSA_N; j++) {
268         if (j == CRYPT_SHAKE256_BLOCKSIZE) {
269             GOTO_ERR_IF(hashMethod->squeeze(mdCtx, buf, CRYPT_SHAKE256_BLOCKSIZE), ret);
270             j = 0;
271         }
272         int32_t z0 = (int32_t)(buf[j] & 0x0F);
273         int32_t z1 = (int32_t)(buf[j] >> 4u);
274         // Algorithm 15 CoeffFromHalfByte(b)
275         // if �� = 2 and b < 15 then return 2 − (b mod 5)
276         if (ctx->info->eta == 2) {
277             if (z0 < 0x0F) {
278                 // This is Barrett Modular Multiplication, 205 == 2^10 / 5
279                 z0 = z0 - ((205 * z0) >> 10) * 5;  // 2 − (b mod 5)
280                 a[i] = 2 - z0;
281                 i++;
282             }
283             if (z1 < 0x0F && i < MLDSA_N) {
284                 // Barrett Modular Multiplication, 205 == 2^10 / 5
285                 z1 = z1 - ((205 * z1) >> 10) * 5;
286                 a[i] = 2 - z1;  // 2 − (b mod 5)
287                 i++;
288             }
289         } else {
290             if (z0 < 9) { // if �� = 4 and b < 9 then a[i] = 4 − b
291                 a[i] = 4 - z0;
292                 i++;
293             }
294             if (z1 < 9 && i < MLDSA_N) { // if �� = 4 and b < 9 then a[i + 1] = 4 − b
295                 a[i] = 4 - z1;
296                 i++;
297             }
298         }
299     }
300 ERR:
301     hashMethod->freeCtx(mdCtx);
302     return ret;
303 }
304 
305 // Algorithm 33 ExpandS(ρ)
ExpandS(const CRYPT_ML_DSA_Ctx * ctx,const uint8_t * prvSeed,int32_t * s1[MLDSA_L_MAX],int32_t * s2[MLDSA_K_MAX])306 static int32_t ExpandS(const CRYPT_ML_DSA_Ctx *ctx, const uint8_t *prvSeed,
307     int32_t *s1[MLDSA_L_MAX], int32_t *s2[MLDSA_K_MAX])
308 {
309     int32_t ret;
310     uint8_t k = ctx->info->k;
311     uint8_t l = ctx->info->l;
312     uint8_t seed[MLDSA_PRIVATE_SEED_LEN + 2]; // 2 bytes are reserved.
313     (void)memcpy_s(seed, sizeof(seed), prvSeed, MLDSA_PRIVATE_SEED_LEN);
314     seed[MLDSA_PRIVATE_SEED_LEN + 1] = 0;
315     for (uint8_t i = 0; i < l; i++) {
316         seed[MLDSA_PRIVATE_SEED_LEN] = i;
317         ret = RejBoundedPoly(ctx, s1[i], seed);
318         RETURN_RET_IF(ret != CRYPT_SUCCESS, ret);
319     }
320     for (uint8_t i = 0; i < k; i++) {
321         seed[MLDSA_PRIVATE_SEED_LEN] = l + i;
322         ret = RejBoundedPoly(ctx, s2[i], seed);
323         RETURN_RET_IF(ret != CRYPT_SUCCESS, ret);
324     }
325     return CRYPT_SUCCESS;
326 }
327 
ComputesNTT(const CRYPT_ML_DSA_Ctx * ctx,int32_t * s[MLDSA_L_MAX],int32_t * sOut[MLDSA_L_MAX])328 static void ComputesNTT(const CRYPT_ML_DSA_Ctx *ctx, int32_t *s[MLDSA_L_MAX], int32_t *sOut[MLDSA_L_MAX])
329 {
330     for (uint8_t i = 0; i < ctx->info->l; i++) {
331         (void)memcpy_s(sOut[i], sizeof(int32_t) * MLDSA_N, s[i], sizeof(int32_t) * MLDSA_N);
332         MLDSA_ComputesNTT(sOut[i]);
333     }
334     return;
335 }
336 
VectorsMul(int32_t * t,int32_t * matrix,int32_t * s)337 static void VectorsMul(int32_t *t, int32_t *matrix, int32_t *s)
338 {
339     for (uint32_t i = 0; i < MLDSA_N; i++) {
340         t[i] = MLDSA_MontgomeryReduce((int64_t)matrix[i] * s[i]);
341     }
342 }
343 
MatrixMul(const CRYPT_ML_DSA_Ctx * ctx,int32_t * t,int32_t * matrix[MLDSA_L_MAX],int32_t * s[MLDSA_L_MAX])344 static void MatrixMul(const CRYPT_ML_DSA_Ctx *ctx, int32_t *t, int32_t *matrix[MLDSA_L_MAX], int32_t *s[MLDSA_L_MAX])
345 {
346     int32_t tmp[MLDSA_N] = { 0 };
347     VectorsMul(t, matrix[0], s[0]);
348     for (uint32_t i = 1; i < ctx->info->l; i++) {
349         VectorsMul(tmp, matrix[i], s[i]);
350         for (uint32_t j = 0; j < MLDSA_N; j++) {
351             t[j] = t[j] + tmp[j];
352         }
353     }
354     for (uint32_t j = 0; j < MLDSA_N; j++) {
355         MLDSA_MOD_Q(t[j]);
356     }
357 }
358 
ComputesT(const CRYPT_ML_DSA_Ctx * ctx,int32_t * t[MLDSA_K_MAX],int32_t * matrix[MLDSA_K_MAX][MLDSA_L_MAX],int32_t * s1[MLDSA_L_MAX],int32_t * s2[MLDSA_K_MAX])359 static void ComputesT(const CRYPT_ML_DSA_Ctx *ctx, int32_t *t[MLDSA_K_MAX], int32_t *matrix[MLDSA_K_MAX][MLDSA_L_MAX],
360     int32_t *s1[MLDSA_L_MAX], int32_t *s2[MLDSA_K_MAX])
361 {
362     for (uint8_t i = 0; i < ctx->info->k; i++) {
363         MatrixMul(ctx, t[i], matrix[i], s1);
364         MLDSA_ComputesINVNTT(t[i]);
365         for (int32_t j = 0; j < MLDSA_N; j++) {
366             t[i][j] = t[i][j] + s2[i][j];
367             t[i][j] = t[i][j] < 0 ? (t[i][j] + MLDSA_Q) : t[i][j];
368         }
369     }
370 }
371 
ComputesPower2Round(const CRYPT_ML_DSA_Ctx * ctx,int32_t * t0[MLDSA_K_MAX],int32_t * t1[MLDSA_K_MAX])372 static void ComputesPower2Round(const CRYPT_ML_DSA_Ctx *ctx, int32_t *t0[MLDSA_K_MAX], int32_t *t1[MLDSA_K_MAX])
373 {
374     for (uint32_t i = 0; i < ctx->info->k; i++) {
375         for (int32_t j = 0; j < MLDSA_N; j++) {
376             int32_t t = (t1[i][j] + (1 << (MLDSA_D - 1)) - 1) >> MLDSA_D;
377             t0[i][j] = t1[i][j] - (t << MLDSA_D);
378             t1[i][j] = t;
379         }
380     }
381 }
382 
383 // The following encoding function encodes MLDSA_N int32_t data into the uint8_t array.
ByteEncode(uint8_t * buf,uint32_t * t,uint32_t bits)384 static void ByteEncode(uint8_t *buf, uint32_t *t, uint32_t bits)
385 {
386     if (bits == 10u) {
387         for (uint32_t i = 0; i < MLDSA_N / 4; i++) {
388             buf[5 * i + 0] = (uint8_t)(t[4 * i + 0] >> 0);
389             buf[5 * i + 1u] = (uint8_t)((t[4 * i + 0] >> 8u) | (t[4 * i + 1u] << 2u));
390             buf[5 * i + 2u] = (uint8_t)((t[4 * i + 1u] >> 6u) | (t[4 * i + 2u] << 4u));
391             buf[5 * i + 3u] = (uint8_t)((t[4 * i + 2u] >> 4u) | (t[4 * i + 3u] << 6u));
392             buf[5 * i + 4u] = (uint8_t)(t[4 * i + 3u] >> 2u);
393         }
394     } else if (bits == 6u) {
395         for (uint32_t i = 0; i < MLDSA_N / 4; i++) {
396             buf[3 * i + 0] = (uint8_t)(t[4 * i] | (t[4 * i + 1] << 6u));
397             buf[3 * i + 1u] = (uint8_t)(t[4 * i + 1u] >> 2 | (t[4 * i + 2u] << 4u));
398             buf[3 * i + 2u] = (uint8_t)(t[4 * i + 2u] >> 4 | (t[4 * i + 3u] << 2u));
399         }
400     } else if (bits == 4u) {
401         for (uint32_t i = 0; i < MLDSA_N / 2; i++) {
402             buf[i] = (uint8_t)(t[2 * i] | (t[2 * i + 1] << 4u));
403         }
404     }
405 }
406 
ByteDecode(uint8_t * buf,uint32_t * t,uint32_t bits)407 static void ByteDecode(uint8_t *buf, uint32_t *t, uint32_t bits)
408 {
409     if (bits == 10u) {
410         for (uint32_t i = 0; i < MLDSA_N / 4; i++) {
411             t[4 * i + 0] = (buf[5 * i + 0] | ((uint32_t)buf[5 * i + 1] << 8)) & 0x03ff;
412             t[4 * i + 1u] = ((buf[5 * i + 1u] >> 2u) | ((uint32_t)buf[5 * i + 2u] << 6u)) & 0x03ff;
413             t[4 * i + 2u] = ((buf[5 * i + 2u] >> 4u) | ((uint32_t)buf[5 * i + 3u] << 4u)) & 0x03ff;
414             t[4 * i + 3u] = ((buf[5 * i + 3u] >> 6u) | ((uint32_t)buf[5 * i + 4u] << 2u)) & 0x03ff;
415         }
416     }
417 }
418 
BitPack(uint8_t * buf,uint32_t w[MLDSA_N],uint32_t bits,uint32_t b)419 static void BitPack(uint8_t *buf, uint32_t w[MLDSA_N], uint32_t bits, uint32_t b)
420 {
421     uint32_t t[8] = {0};
422     uint32_t i;
423     uint32_t n;
424     if (bits == 3u) {
425         for (i = 0; i < MLDSA_N / 8; i++) {
426             for (uint32_t j = 0; j < 8; j++) {
427                 t[j] = b - (uint32_t)w[i * 8 + j];
428             }
429             n = bits * i;
430             buf[n + 0] = (uint8_t)((t[0]) | (t[1] << 3u) | (t[2] << 6u));
431             buf[n + 1u] = (uint8_t)((t[2] >> 2u) | (t[3] << 1u) | (t[4] << 4u) | (t[5] << 7u));
432             buf[n + 2u] = (uint8_t)((t[5] >> 1u) | (t[6] << 2u) | (t[7] << 5u));
433         }
434     } else if (bits == 4u) {
435         for (i = 0; i < MLDSA_N / 2; i++) {
436             t[0] = (int32_t)b - w[i * 2];
437             t[1] = (int32_t)b - w[i * 2 + 1];
438             buf[i] = (uint8_t)(t[0] | (t[1] << 4u));
439         }
440     } else if (bits == MLDSA_D) {
441         for (i = 0; i < MLDSA_N / 8; i++) {
442             for (uint32_t j = 0; j < 8; j++) {
443                 t[j] = b - w[i * 8 + j];
444             }
445             n = bits * i;
446             buf[n + 0] = (uint8_t)t[0];
447             buf[n + 1] = (uint8_t)(t[0] >> 8u);
448             buf[n + 1] |= (uint8_t)(t[1] << 5u);
449             buf[n + 2] = (uint8_t)(t[1] >> 3u);
450             buf[n + 3] = (uint8_t)(t[1] >> 11u);
451             buf[n + 3] |= (uint8_t)(t[2] << 2u);
452             buf[n + 4] = (uint8_t)(t[2] >> 6u);
453             buf[n + 4] |= (uint8_t)(t[3] << 7u);
454             buf[n + 5] = (uint8_t)(t[3] >> 1u);
455             buf[n + 6] = (uint8_t)(t[3] >> 9u);
456             buf[n + 6] |= (uint8_t)(t[4] << 4u);
457             buf[n + 7] = (uint8_t)(t[4] >> 4u);
458             buf[n + 8] = (uint8_t)(t[4] >> 12u);
459             buf[n + 8] |= (uint8_t)(t[5] << 1u);
460             buf[n + 9] = (uint8_t)(t[5] >> 7u);
461             buf[n + 9] |= (uint8_t)(t[6] << 6u);
462             buf[n + 10] = (uint8_t)(t[6] >> 2u);
463             buf[n + 11] = (uint8_t)(t[6] >> 10u);
464             buf[n + 11] |= (uint8_t)(t[7] << 3u);
465             buf[n + 12] = (uint8_t)(t[7] >> 5u);
466         }
467     }
468     // bits has only this three values.
469     return;
470 }
471 
BitUnPake(const uint8_t * v,uint32_t w[MLDSA_N],uint32_t bits,uint32_t b)472 static void BitUnPake(const uint8_t *v, uint32_t w[MLDSA_N], uint32_t bits, uint32_t b)
473 {
474     uint32_t t[8] = {0};
475     uint32_t i;
476     uint32_t n;
477     if (bits == 3u) {
478         for (i = 0; i < MLDSA_N / 8; i++) {
479             n = bits * i;
480             t[0] = (v[n + 0]) & 0x07;
481             t[1] = (v[n + 0] >> 3u) & 0x07;
482             t[2] = ((v[n + 0] >> 6u) | (v[n + 1] << 2u)) & 0x07;
483             t[3] = (v[n + 1u] >> 1u) & 0x07;
484             t[4] = (v[n + 1u] >> 4u) & 0x07;
485             t[5] = ((v[n + 1u] >> 7u) | (v[n + 2] << 1u)) & 0x07;
486             t[6] = (v[n + 2u] >> 2u) & 0x07;
487             t[7] = (v[n + 2u] >> 5u) & 0x07;
488 
489             for (uint32_t j = 0; j < 8; j++) {
490                 w[i * 8 + j] = b - t[j];
491             }
492         }
493     } else if (bits == 4u) {
494         for (i = 0; i < MLDSA_N / 2; i++) {
495             t[0] = v[i] & 0x0f;
496             t[1] = (v[i] >> 4u) & 0x0f;
497             w[i * 2] = b - t[0];
498             w[i * 2 + 1] = b - t[1];
499         }
500     } else if (bits == MLDSA_D) {
501         for (i = 0; i < MLDSA_N / 8; i++) {
502             n = bits * i;
503             t[0] = (v[n + 0] | ((uint32_t)v[n + 1] << 8u)) & 0x1fff;
504             t[1] = (v[n + 1] >> 5u | ((uint32_t)v[n + 2u] << 3u) |
505                 ((uint32_t)v[n + 3u] << 11u)) & 0x1fff;
506             t[2] = (v[n + 3u] >> 2u | ((uint32_t)v[n + 4u] << 6u)) & 0x1fff;
507             t[3] = (v[n + 4u] >> 7u | ((uint32_t)v[n + 5u] << 1u) |
508                 ((uint32_t)v[n + 6u] << 9u)) & 0x1fff;
509 
510             t[4] = (v[n + 6u] >> 4u | ((uint32_t)v[n + 7u] << 4u) |
511                 ((uint32_t)v[n + 8u] << 12u)) & 0x1fff;
512             t[5] = (v[n + 8u] >> 1u | ((uint32_t)v[n + 9u] << 7u)) & 0x1fff;
513             t[6] = (v[n + 9u] >> 6u | ((uint32_t)v[n + 10u] << 2u) |
514                 ((uint32_t)v[n + 11u] << 10u)) & 0x1fff;
515             t[7] = (v[n + 11u] >> 3u | ((uint32_t)v[n + 12u] << 5u)) & 0x1fff;
516 
517             for (uint32_t j = 0; j < 8; j++) {
518                 w[i * 8 + j] = b - t[j];
519             }
520         }
521     }
522     // bits has only this three values.
523     return;
524 }
525 
SignBitPack(uint8_t * buf,uint32_t w[MLDSA_N],uint32_t bits,uint32_t b)526 static void SignBitPack(uint8_t *buf, uint32_t w[MLDSA_N], uint32_t bits, uint32_t b)
527 {
528     uint32_t t[4] = {0};
529     uint32_t i;
530     uint32_t n;
531     if (bits == GAMMA_BITS_OF_MLDSA_44) {
532         for (i = 0; i < MLDSA_N / 4; i++) {
533             for (uint32_t j = 0; j < 4; j++) {
534                 t[j] = b - w[i * 4 + j];
535             }
536             n = 9 * i;
537             buf[n + 0] = (uint8_t)t[0];
538             buf[n + 1u] = (uint8_t)(t[0] >> 8u);
539             buf[n + 2u] = (uint8_t)(t[0] >> 16u | t[1] << 2u);
540             buf[n + 3u] = (uint8_t)(t[1] >> 6u);
541             buf[n + 4u] = (uint8_t)(t[1] >> 14u | t[2] << 4u);
542             buf[n + 5u] = (uint8_t)(t[2] >> 4u);
543             buf[n + 6u] = (uint8_t)(t[2] >> 12u | t[3] << 6u);
544             buf[n + 7u] = (uint8_t)(t[3] >> 2u);
545             buf[n + 8u] = (uint8_t)(t[3] >> 10u);
546         }
547     } else if (bits == GAMMA_BITS_OF_MLDSA_65_87) {
548         for (i = 0; i < MLDSA_N / 2; i++) {
549             t[0] = b - w[i * 2];
550             t[1] = b - w[i * 2 + 1u];
551             n = 5 * i;
552             buf[n + 0] = (uint8_t)t[0];
553             buf[n + 1u] = (uint8_t)(t[0] >> 8u);
554             buf[n + 2u] = (uint8_t)(t[0] >> 16u | t[1] << 4u);
555             buf[n + 3u] = (uint8_t)(t[1] >> 4u);
556             buf[n + 4u] = (uint8_t)(t[1] >> 12u);
557         }
558     }
559     // bits has only this two values.
560     return;
561 }
562 
SignBitUnPake(const uint8_t * v,uint32_t w[MLDSA_N],uint32_t bits,uint32_t b)563 static void SignBitUnPake(const uint8_t *v, uint32_t w[MLDSA_N], uint32_t bits, uint32_t b)
564 {
565     uint32_t t[4] = {0};
566     uint32_t i;
567     uint32_t n;
568     if (bits == GAMMA_BITS_OF_MLDSA_44) {
569         for (i = 0; i < MLDSA_N / 4; i++) {
570             n = 9 * i;
571             t[0] = (v[n + 0] | ((uint32_t)v[n + 1] << 8) | ((uint32_t)v[n + 2] << 16)) & 0x3ffff;
572             t[1] = (v[n + 2u] >> 2u | ((uint32_t)v[n + 3u] << 6u) | ((uint32_t)v[n + 4u] << 14u)) & 0x3ffff;
573             t[2] = (v[n + 4u] >> 4u | ((uint32_t)v[n + 5u] << 4u) | ((uint32_t)v[n + 6u] << 12u)) & 0x3ffff;
574             t[3] = (v[n + 6u] >> 6u | ((uint32_t)v[n + 7u] << 2u) | ((uint32_t)v[n + 8u] << 10u)) & 0x3ffff;
575 
576             n = 4 * i;
577             w[n] = b - t[0];
578             w[n + 1u] = b - t[1];
579             w[n + 2u] = b - t[2];
580             w[n + 3u] = b - t[3];
581         }
582     } else if (bits == GAMMA_BITS_OF_MLDSA_65_87) {
583         for (i = 0; i < MLDSA_N / 2; i++) {
584             n = 5 * i;
585             t[0] = (v[n + 0] | ((uint32_t)v[n + 1] << 8u) | ((uint32_t)v[n + 2u] << 16u)) & 0xfffff;
586             t[1] = (v[n + 2u] >> 4u | ((uint32_t)v[n + 3u] << 4u) | ((uint32_t)v[n + 4u] << 12u)) & 0xfffff;
587 
588             w[i * 2] = b - t[0];
589             w[i * 2 + 1u] = b - t[1];
590         }
591     }
592     // bits has only this two values.
593     return;
594 }
595 
596 // Algorithm 22 pkEncode(ρ, t1)
PkEncode(const CRYPT_ML_DSA_Ctx * ctx,uint8_t * seed,int32_t * t[MLDSA_K_MAX])597 static void PkEncode(const CRYPT_ML_DSA_Ctx *ctx, uint8_t *seed, int32_t *t[MLDSA_K_MAX])
598 {
599     (void)memcpy_s(ctx->pubKey, ctx->pubLen, seed, MLDSA_PUBLIC_SEED_LEN);
600     for (int32_t i = 0; i < ctx->info->k; i++) {
601         // 10 is bitlen(��−1) − d
602         ByteEncode(ctx->pubKey + MLDSA_PUBLIC_SEED_LEN + i * MLDSA_PUBKEY_POLYT_PACKEDBYTES, (uint32_t *)t[i], 10);
603     }
604 }
605 
606 // Algorithm 23 pkDecode(pk)
PkDecode(const CRYPT_ML_DSA_Ctx * ctx,uint8_t * seed,int32_t * t[MLDSA_K_MAX])607 static void PkDecode(const CRYPT_ML_DSA_Ctx *ctx, uint8_t *seed, int32_t *t[MLDSA_K_MAX])
608 {
609     (void)memcpy_s(seed, MLDSA_PUBLIC_SEED_LEN, ctx->pubKey, MLDSA_PUBLIC_SEED_LEN);
610     for (int32_t i = 0; i < ctx->info->k; i++) {
611         // 10 is bitlen(��−1) − d
612         ByteDecode(ctx->pubKey + MLDSA_PUBLIC_SEED_LEN + i * MLDSA_PUBKEY_POLYT_PACKEDBYTES, (uint32_t *)t[i], 10);
613     }
614 }
615 
616 // Algorithm 24 skEncode(ρ, K,tr, ��1, ��2, t0)
SkEncode(const CRYPT_ML_DSA_Ctx * ctx,uint8_t * pubSeed,uint8_t * signSeed,uint8_t * tr,MLDSA_KeyGenMatrixSt * st)617 static void SkEncode(const CRYPT_ML_DSA_Ctx *ctx, uint8_t *pubSeed, uint8_t *signSeed, uint8_t *tr,
618     MLDSA_KeyGenMatrixSt *st)
619 {
620     uint32_t i;
621     uint32_t bitLen = ctx->info->eta == 2 ? 3 : 4;  // 3 and 4 is bitlen(2��)
622     uint32_t index = MLDSA_PUBLIC_SEED_LEN;
623     (void)memcpy_s(ctx->prvKey, ctx->prvLen, pubSeed, MLDSA_PUBLIC_SEED_LEN);
624     (void)memcpy_s(ctx->prvKey + index, ctx->prvLen - index, signSeed, MLDSA_SIGNING_SEED_LEN);
625     index += MLDSA_SIGNING_SEED_LEN;
626     (void)memcpy_s(ctx->prvKey + index, ctx->prvLen - index, tr, MLDSA_PRIVATE_SEED_LEN);
627     index += MLDSA_PRIVATE_SEED_LEN;
628     for (i = 0; i < ctx->info->l; i++) {
629         BitPack(ctx->prvKey + index, (uint32_t *)st->s1[i], bitLen, ctx->info->eta);
630         index += MLDSA_N_BYTE * bitLen;
631     }
632     for (i = 0; i < ctx->info->k; i++) {
633         BitPack(ctx->prvKey + index, (uint32_t *)st->s2[i], bitLen, ctx->info->eta);
634         index += MLDSA_N_BYTE * bitLen;
635     }
636     for (i = 0; i < ctx->info->k; i++) {
637         BitPack(ctx->prvKey + index, (uint32_t *)st->t0[i], MLDSA_D, 4096);  // 2^(��−1) == 4096
638         index += MLDSA_N_BYTE * MLDSA_D;
639     }
640 }
641 
642 // Algorithm 25 skDecode(sk)
SkDecode(const CRYPT_ML_DSA_Ctx * ctx,uint8_t * pubSeed,uint8_t * signSeed,uint8_t * tr,MLDSA_SignMatrixSt * st)643 static void SkDecode(const CRYPT_ML_DSA_Ctx *ctx, uint8_t *pubSeed, uint8_t *signSeed, uint8_t *tr,
644     MLDSA_SignMatrixSt *st)
645 {
646     uint32_t i;
647     uint32_t bitLen = ctx->info->eta == 2 ? 3 : 4;  // 3 and 4 is bitlen(2��)
648     uint32_t index = MLDSA_PUBLIC_SEED_LEN;
649     (void)memcpy_s(pubSeed, MLDSA_PUBLIC_SEED_LEN, ctx->prvKey, MLDSA_PUBLIC_SEED_LEN);
650     (void)memcpy_s(signSeed, MLDSA_SIGNING_SEED_LEN, ctx->prvKey + index, MLDSA_SIGNING_SEED_LEN);
651 
652     index += MLDSA_SIGNING_SEED_LEN;
653     (void)memcpy_s(tr, MLDSA_PRIVATE_SEED_LEN, ctx->prvKey + index, MLDSA_PRIVATE_SEED_LEN);
654     index += MLDSA_PRIVATE_SEED_LEN;
655 
656     for (i = 0; i < ctx->info->l; i++) {
657         BitUnPake(ctx->prvKey + index, (uint32_t *)st->s1[i], bitLen, ctx->info->eta);
658         MLDSA_ComputesNTT(st->s1[i]);
659         index += MLDSA_N_BYTE * bitLen;
660     }
661     for (i = 0; i < ctx->info->k; i++) {
662         BitUnPake(ctx->prvKey + index, (uint32_t *)st->s2[i], bitLen, ctx->info->eta);
663         MLDSA_ComputesNTT(st->s2[i]);
664         index += MLDSA_N_BYTE * bitLen;
665     }
666     for (i = 0; i < ctx->info->k; i++) {
667         BitUnPake(ctx->prvKey + index, (uint32_t *)st->t0[i], MLDSA_D, 4096);  // 2^(��−1) == 4096
668         MLDSA_ComputesNTT(st->t0[i]);
669         index += MLDSA_N_BYTE * MLDSA_D;
670     }
671 }
672 
673 // Algorithm 34 ExpandMask(ρ, μ)
ExpandMask(const CRYPT_ML_DSA_Ctx * ctx,int32_t * y[MLDSA_L_MAX],uint8_t * p,uint16_t u)674 static int32_t ExpandMask(const CRYPT_ML_DSA_Ctx *ctx, int32_t *y[MLDSA_L_MAX], uint8_t *p, uint16_t u)
675 {
676     uint16_t n = 0;
677     uint8_t v[640];  // The maximum length is 20 * 32 == 640 byte.
678     uint32_t bits = (ctx->info->k == K_VALUE_OF_MLDSA_44) ? GAMMA_BITS_OF_MLDSA_44 : GAMMA_BITS_OF_MLDSA_65_87;
679     for (uint16_t i = 0; i < ctx->info->l; i++) {
680         n = u + i;
681         p[MLDSA_PRIVATE_SEED_LEN] = (uint8_t)n;
682         p[MLDSA_PRIVATE_SEED_LEN + 1] = (uint8_t)(n >> BITS_OF_BYTE);
683         // �� ← H(ρ′, 32��)
684         int32_t ret = HashFuncH(p, MLDSA_PRIVATE_SEED_LEN + 2, NULL, 0, v, 32 * bits);
685         if (ret != CRYPT_SUCCESS) {
686             return ret;
687         }
688         SignBitUnPake(v, (uint32_t *)y[i], bits, ctx->info->gamma1);
689     }
690     return CRYPT_SUCCESS;
691 }
692 
693 // Algorithm 36 Decompose(r)
Decompose(const CRYPT_ML_DSA_Ctx * ctx,int32_t r,int32_t * r1,int32_t * r0)694 static void Decompose(const CRYPT_ML_DSA_Ctx *ctx, int32_t r, int32_t *r1, int32_t *r0)
695 {
696     int32_t t = (int32_t)(((uint32_t)r + 0x7f) >> 7u);
697     if (ctx->info->k == K_VALUE_OF_MLDSA_44) {  // If is MLDSA44
698         // This is Barrett Modular Multiplication, mod is 2��2.
699         t = (t * 11275u + (1 << 23u)) >> 24u;
700         t ^= ((43 - t) >> 31u) & t;
701     } else {
702         t = (t * 1025u + (1 << 21u)) >> 22u;
703         t &= 0x0f;
704     }
705 
706     *r0 = r - t * 2 * ctx->info->gamma2;  // r1 ← (r+ − r0)/(2��2)
707     *r0 -= (((MLDSA_Q - 1) / 2 - *r0) >> 31u) & MLDSA_Q;
708     *r1 = t;  // high bits.
709     return;
710 }
711 
ComputesW(const CRYPT_ML_DSA_Ctx * ctx,int32_t * w[MLDSA_L_MAX],int32_t * w1[MLDSA_L_MAX],int32_t * matrix[MLDSA_K_MAX][MLDSA_L_MAX],int32_t * y[MLDSA_L_MAX])712 static void ComputesW(const CRYPT_ML_DSA_Ctx *ctx, int32_t *w[MLDSA_L_MAX], int32_t *w1[MLDSA_L_MAX],
713     int32_t *matrix[MLDSA_K_MAX][MLDSA_L_MAX], int32_t *y[MLDSA_L_MAX])
714 {
715     for (uint8_t i = 0; i < ctx->info->k; i++) {
716         MatrixMul(ctx, w[i], matrix[i], y);
717         MLDSA_ComputesINVNTT(w[i]);
718         for (int32_t j = 0; j < MLDSA_N; j++) {
719             w[i][j] = w[i][j] < 0 ? (w[i][j] + MLDSA_Q) : w[i][j];
720             Decompose(ctx, w[i][j], &w1[i][j], &w[i][j]);
721         }
722     }
723 }
724 
725 // Algorithm 28 w1Encode(w1)
W1Encode(const CRYPT_ML_DSA_Ctx * ctx,uint8_t * buf,int32_t * w[MLDSA_K_MAX])726 static void W1Encode(const CRYPT_ML_DSA_Ctx *ctx, uint8_t *buf, int32_t *w[MLDSA_K_MAX])
727 {
728     uint32_t bitLen = ctx->info->k == K_VALUE_OF_MLDSA_44 ? 6 : 4;  // Only the bitLen value of MLDSA44 is 6.
729     uint32_t blockSize = ctx->info->k == K_VALUE_OF_MLDSA_44 ? 192 : 128;  // MLDSA44 blockSize is 192, other is 128.
730     for (uint32_t i = 0; i < ctx->info->k; i++) {
731         ByteEncode(buf + i * blockSize, (uint32_t *)w[i], bitLen);
732     }
733 }
734 
735 // Algorithm 29 SampleInBall(ρ)
SampleInBall(const CRYPT_ML_DSA_Ctx * ctx,const uint8_t * p,uint32_t pLen,int32_t c[MLDSA_N])736 static int32_t SampleInBall(const CRYPT_ML_DSA_Ctx *ctx, const uint8_t *p, uint32_t pLen, int32_t c[MLDSA_N])
737 {
738     uint8_t s[CRYPT_SHAKE256_BLOCKSIZE] = {0};
739     uint32_t sLen = CRYPT_SHAKE256_BLOCKSIZE;
740     uint64_t h = 0;
741     uint32_t index = 0;
742     uint8_t j = 0;
743     int32_t ret;
744     const EAL_MdMethod *hashMethod = EAL_MdFindMethod(CRYPT_MD_SHAKE256);
745     if (hashMethod == NULL) {
746         BSL_ERR_PUSH_ERROR(CRYPT_EAL_ALG_NOT_SUPPORT);
747         return CRYPT_EAL_ALG_NOT_SUPPORT;
748     }
749     void *mdCtx = hashMethod->newCtx();
750     if (mdCtx == NULL) {
751         BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
752         return CRYPT_MEM_ALLOC_FAIL;
753     }
754     GOTO_ERR_IF(hashMethod->init(mdCtx, NULL), ret);
755     GOTO_ERR_IF(hashMethod->update(mdCtx, p, pLen), ret);
756     GOTO_ERR_IF(hashMethod->squeeze(mdCtx, s, sLen), ret);
757     for (index = 0; index < 8; index++) {    //  �� ← H.Squeeze(ctx, 8)
758         h = h | ((uint64_t)s[index] << (8 * index));
759     }
760     for (uint32_t i = MLDSA_N - ctx->info->tau; i < MLDSA_N; i++) {
761         do {
762             if (index == CRYPT_SHAKE256_BLOCKSIZE) {
763                 GOTO_ERR_IF(hashMethod->squeeze(mdCtx, s, sLen), ret);
764                 index = 0;
765             }
766             j = s[index];
767             index++;
768         } while (j > i);
769 
770         c[i] = c[j];
771         c[j] = 1 - ((h & 1) << 1);
772         h >>= 1;
773     }
774 ERR:
775     hashMethod->freeCtx(mdCtx);
776     return ret;
777 }
778 
MLDSA_VectorsAdd(int32_t * t,int32_t * a,int32_t * b)779 static void MLDSA_VectorsAdd(int32_t *t, int32_t *a, int32_t *b)
780 {
781     for (uint32_t i = 0; i < MLDSA_N; i++) {
782         t[i] = a[i] + b[i];
783         MLDSA_MOD_Q(t[i]);
784     }
785 }
786 
MLDSA_VectorsSub(int32_t * t,int32_t * a,int32_t * b)787 static void MLDSA_VectorsSub(int32_t *t, int32_t *a, int32_t *b)
788 {
789     for (uint32_t i = 0; i < MLDSA_N; i++) {
790         t[i] = a[i] - b[i];
791         MLDSA_MOD_Q(t[i]);
792     }
793 }
794 
ComputesZ(const CRYPT_ML_DSA_Ctx * ctx,int32_t * y[MLDSA_L_MAX],int32_t * c,int32_t * s[MLDSA_L_MAX],int32_t * z[MLDSA_L_MAX])795 static void ComputesZ(const CRYPT_ML_DSA_Ctx *ctx, int32_t *y[MLDSA_L_MAX], int32_t *c, int32_t *s[MLDSA_L_MAX],
796     int32_t *z[MLDSA_L_MAX])
797 {
798     for (uint8_t i = 0; i < ctx->info->l; i++) {
799         VectorsMul(z[i], c, s[i]);
800         MLDSA_ComputesINVNTT(z[i]);
801         MLDSA_VectorsAdd(z[i], y[i], z[i]);
802     }
803 }
804 
ValidityChecks(int32_t * z,uint32_t t)805 static bool ValidityChecks(int32_t *z, uint32_t t)
806 {
807     uint32_t n;
808     for (uint32_t j = 0; j < MLDSA_N; j++) {
809         n = z[j] >> 31;    // Shift rightwards by 31 bits.
810         n = z[j] - (n & ((uint32_t)z[j] << 1));
811         if (n >= t) {
812             return false;
813         }
814     }
815     return true;
816 }
817 
ValidityChecksL(const CRYPT_ML_DSA_Ctx * ctx,int32_t * z[MLDSA_L_MAX],uint32_t t)818 static bool ValidityChecksL(const CRYPT_ML_DSA_Ctx *ctx, int32_t *z[MLDSA_L_MAX], uint32_t t)
819 {
820     for (uint8_t i = 0; i < ctx->info->l; i++) {
821         if (ValidityChecks(z[i], t) == false) {
822             return false;
823         }
824     }
825     return true;
826 }
827 
ValidityChecksK(const CRYPT_ML_DSA_Ctx * ctx,int32_t * z[MLDSA_K_MAX],uint32_t t)828 static bool ValidityChecksK(const CRYPT_ML_DSA_Ctx *ctx, int32_t *z[MLDSA_K_MAX], uint32_t t)
829 {
830     for (uint8_t i = 0; i < ctx->info->k; i++) {
831         if (ValidityChecks(z[i], t) == false) {
832             return false;
833         }
834     }
835     return true;
836 }
837 
ComputesR(const CRYPT_ML_DSA_Ctx * ctx,int32_t * c,MLDSA_SignMatrixSt * st)838 static void ComputesR(const CRYPT_ML_DSA_Ctx *ctx, int32_t *c, MLDSA_SignMatrixSt *st)
839 {
840     for (uint8_t i = 0; i < ctx->info->k; i++) {
841         VectorsMul(st->cs2[i], c, st->s2[i]);
842         MLDSA_ComputesINVNTT(st->cs2[i]);
843         MLDSA_VectorsSub(st->r0[i], st->w[i], st->cs2[i]);
844     }
845 }
846 
ComputesCT(const CRYPT_ML_DSA_Ctx * ctx,int32_t * c,int32_t * t[MLDSA_K_MAX],int32_t * ct[MLDSA_K_MAX])847 static void ComputesCT(const CRYPT_ML_DSA_Ctx *ctx, int32_t *c, int32_t *t[MLDSA_K_MAX], int32_t *ct[MLDSA_K_MAX])
848 {
849     for (uint8_t i = 0; i < ctx->info->k; i++) {
850         VectorsMul(ct[i], c, t[i]);
851         MLDSA_ComputesINVNTT(ct[i]);
852         for (uint32_t j = 0; j < MLDSA_N; j++) {
853             int32_t m = (int32_t)(((uint32_t)ct[i][j] + (1 << 22)) >> 23);  // m = (ct + 2^22) / 2^23
854             ct[i][j] = ct[i][j] - m * MLDSA_Q;
855         }
856     }
857 }
858 
MakeHint(const CRYPT_ML_DSA_Ctx * ctx,MLDSA_SignMatrixSt * st)859 static uint32_t MakeHint(const CRYPT_ML_DSA_Ctx *ctx, MLDSA_SignMatrixSt *st)
860 {
861     uint32_t num = 0;
862     for (uint32_t i = 0; i < ctx->info->k; i++) {
863         MLDSA_VectorsAdd(st->w[i], st->w[i], st->ct0[i]);
864         MLDSA_VectorsSub(st->w[i], st->w[i], st->cs2[i]);
865         for (uint32_t j = 0; j < MLDSA_N; j++) {
866             if (st->w[i][j] > (int32_t)ctx->info->gamma2 || st->w[i][j] < (0 - (int32_t)ctx->info->gamma2) ||
867                 (st->w[i][j] == (0 - (int32_t)ctx->info->gamma2) && st->w1[i][j] != 0)) {
868                 st->h[i][j] = 1;
869                 num++;
870             } else {
871                 st->h[i][j] = 0;
872             }
873         }
874     }
875     return num;
876 }
877 
SigEncode(const CRYPT_ML_DSA_Ctx * ctx,uint8_t * out,uint32_t outLen,int32_t * z[MLDSA_L_MAX],int32_t * h[MLDSA_K_MAX])878 static void SigEncode(const CRYPT_ML_DSA_Ctx *ctx, uint8_t *out, uint32_t outLen, int32_t *z[MLDSA_L_MAX],
879     int32_t *h[MLDSA_K_MAX])
880 {
881     // // ��1 bits of MLDSA44 is 18,��1 bits of MLDSA65 and MLDSA87 is 20.
882     uint32_t bits = (ctx->info->k == K_VALUE_OF_MLDSA_44) ? GAMMA_BITS_OF_MLDSA_44 : GAMMA_BITS_OF_MLDSA_65_87;
883     uint32_t blockSize = MLDSA_N / BITS_OF_BYTE * bits;
884     uint8_t *ptr = out;
885     uint32_t index = 0;
886     for (uint32_t i = 0; i < ctx->info->l; i++) {
887         SignBitPack(ptr, (uint32_t *)z[i], bits, ctx->info->gamma1);
888         ptr += blockSize;
889     }
890 
891     (void)memset_s(ptr, outLen - blockSize * ctx->info->l, 0, outLen - blockSize * ctx->info->l);
892     for (uint32_t i = 0; i < ctx->info->k; i++) {
893         for (uint32_t j = 0; j < MLDSA_N; j++) {
894             if (h[i][j] != 0) {
895                 ptr[index] = j;
896                 index++;
897             }
898         }
899         ptr[ctx->info->omega + i] = index;
900     }
901 }
902 
SigDecode(const CRYPT_ML_DSA_Ctx * ctx,const uint8_t * in,int32_t * z[MLDSA_L_MAX],int32_t * h[MLDSA_K_MAX])903 static int32_t SigDecode(const CRYPT_ML_DSA_Ctx *ctx, const uint8_t *in, int32_t *z[MLDSA_L_MAX],
904     int32_t *h[MLDSA_K_MAX])
905 {
906     uint32_t bits = (ctx->info->k == K_VALUE_OF_MLDSA_44) ? GAMMA_BITS_OF_MLDSA_44 : GAMMA_BITS_OF_MLDSA_65_87;
907     uint32_t blockSize = MLDSA_N / BITS_OF_BYTE * bits;
908     const uint8_t *ptr = in;
909     uint32_t index = 0;
910 
911     for (int32_t i = 0; i < ctx->info->l; i++) {
912         SignBitUnPake(ptr, (uint32_t *)z[i], bits, ctx->info->gamma1);
913         ptr += blockSize;
914     }
915 
916     for (int32_t i = 0; i < ctx->info->k; i++) {
917         if (ptr[ctx->info->omega + i] < index || ptr[ctx->info->omega + i] > ctx->info->omega) {
918             BSL_ERR_PUSH_ERROR(CRYPT_MLDSA_SIGN_DATA_ERROR);
919             return CRYPT_MLDSA_SIGN_DATA_ERROR;
920         }
921         uint32_t first = index;
922         (void)memset_s(h[i], sizeof(int32_t) * MLDSA_N, 0, sizeof(int32_t) * MLDSA_N);
923         while (index < ptr[ctx->info->omega + i]) {
924             if (index > first && (ptr[index - 1] >= ptr[index])) {
925                 BSL_ERR_PUSH_ERROR(CRYPT_MLDSA_SIGN_DATA_ERROR);
926                 return CRYPT_MLDSA_SIGN_DATA_ERROR;
927             }
928             h[i][ptr[index]] = 1;
929             index++;
930         }
931     }
932     for (int32_t i = index; i < (ctx->info->omega - 1); i++) {
933         RETURN_RET_IF(ptr[i] != 0, CRYPT_MLDSA_SIGN_DATA_ERROR);
934     }
935     return CRYPT_SUCCESS;
936 }
937 
ComputesApproxW(const CRYPT_ML_DSA_Ctx * ctx,MLDSA_VerifyMatrixSt * st,int32_t * c,int32_t * w[MLDSA_K_MAX])938 static void ComputesApproxW(const CRYPT_ML_DSA_Ctx *ctx, MLDSA_VerifyMatrixSt *st, int32_t *c, int32_t *w[MLDSA_K_MAX])
939 {
940     MLDSA_ComputesNTT(c);
941     for (uint8_t i = 0; i < ctx->info->l; i++) {
942         MLDSA_ComputesNTT(st->z[i]);
943     }
944     for (uint8_t i = 0; i < ctx->info->k; i++) {
945         for (int32_t j = 0; j < MLDSA_N; j++) {
946             // t1 ⋅ 2^��
947             st->t1[i][j] = (int32_t)((uint32_t)st->t1[i][j] << MLDSA_D);
948         }
949         // NTT(t1 ⋅ 2^��)
950         MLDSA_ComputesNTT(st->t1[i]);
951         // NTT(��) ∘ NTT(t1 ⋅ 2^��)
952         VectorsMul(st->t1[i], st->t1[i], c);
953         // A ∘ NTT(z)
954         MatrixMul(ctx, w[i], st->matrix[i], st->z);
955 
956         MLDSA_VectorsSub(w[i], w[i], st->t1[i]);
957         MLDSA_ComputesINVNTT(w[i]);
958     }
959 }
960 
UseHint(const CRYPT_ML_DSA_Ctx * ctx,int32_t * h[MLDSA_K_MAX],int32_t * w[MLDSA_K_MAX])961 static void UseHint(const CRYPT_ML_DSA_Ctx *ctx, int32_t *h[MLDSA_K_MAX], int32_t *w[MLDSA_K_MAX])
962 {
963     int32_t r1;
964     int32_t r0;
965     for (uint8_t i = 0; i < ctx->info->k; i++) {
966         for (uint32_t j = 0; j < MLDSA_N; j++) {
967             if (w[i][j] < 0) {
968                 w[i][j] += MLDSA_Q;
969             }
970             Decompose(ctx, w[i][j], &r1, &r0);
971             if (h[i][j] == 0) {
972                 w[i][j] = r1;
973                 continue;
974             }
975             if (ctx->info->gamma2 == 95232) {  // 95232 is (MLDSA_Q-1) / 88;
976                 // �� ← (�� − 1)/(2��2) = 44
977                 // If r0 > 0 return (r1 + 1) mod m else return (r1 − 1) mod m
978                 w[i][j] = (r0 > 0) ? ((r1 == 43) ? 0 : (r1 + 1)) : ((r1 == 0) ? 43 : (r1 - 1)); // 43 is (m - 1)
979                 continue;
980             }
981             w[i][j] = ((r0 > 0) ? (r1 + 1) : (r1 - 1)) & 0x0f;
982         }
983     }
984 }
985 
986 // Referenced from NIST.FIPS.204 Algorithm 6 ML-DSA.KeyGen_internal(��)
MLDSA_KeyGenInternal(CRYPT_ML_DSA_Ctx * ctx,uint8_t * d)987 int32_t MLDSA_KeyGenInternal(CRYPT_ML_DSA_Ctx *ctx, uint8_t *d)
988 {
989     uint8_t k = ctx->info->k;
990     uint8_t l = ctx->info->l;
991     uint8_t seed[MLDSA_SEED_EXTEND_BYTES_LEN] = { 0 };
992     uint8_t digest[MLDSA_EXPANDED_SEED_BYTES_LEN] = { 0 };
993     uint8_t tr[MLDSA_TR_MSG_LEN] = { 0 };
994     MLDSA_KeyGenMatrixSt st = { 0 };
995     int32_t ret;
996 
997     GOTO_ERR_IF(MLDSAKeyGenCreateMatrix(k, l, &st), ret);
998     // 32-byte random seed + 1 byte 'k' + 1 byte 'l'
999     (void)memcpy_s(seed, sizeof(seed), d, MLDSA_SEED_BYTES_LEN);
1000     seed[MLDSA_SEED_BYTES_LEN] = k;
1001     seed[MLDSA_SEED_BYTES_LEN + 1] = l;
1002     // (ρ, ρ′, K) ∈ B32 × B64 × B32 ← H(��||IntegerToBytes(k, 1)||IntegerToBytes(ℓ, 1), 128)
1003     GOTO_ERR_IF(HashFuncH(seed, sizeof(seed), NULL, 0, digest, MLDSA_EXPANDED_SEED_BYTES_LEN), ret);
1004     uint8_t *pubSeed = digest;
1005     uint8_t *prvSeed = digest + MLDSA_PUBLIC_SEED_LEN;
1006     uint8_t *signSeed = digest + MLDSA_PUBLIC_SEED_LEN + MLDSA_PRIVATE_SEED_LEN;
1007 
1008     // A ← ExpandA(ρ)
1009     GOTO_ERR_IF(ExpandA(ctx, pubSeed, st.matrix), ret);
1010     // (��1, ��2) ← ExpandS(ρ′)
1011     GOTO_ERR_IF(ExpandS(ctx, prvSeed, st.s1, st.s2), ret);
1012 
1013     // t ← NTT^−1(A ∘ NTT(��1)) + ��2
1014     ComputesNTT(ctx, st.s1, st.s1Ntt);
1015     ComputesT(ctx, st.t1, st.matrix, st.s1Ntt, st.s2);  // t = As1 + s2
1016 
1017     // (t1, t0) ← Power2Round(t)
1018     ComputesPower2Round(ctx, st.t0, st.t1);
1019     // pk ← pkEncode(ρ, t1)
1020     PkEncode(ctx, pubSeed, st.t1);
1021 
1022     // tr ← H(pk, 64)
1023     GOTO_ERR_IF(HashFuncH(ctx->pubKey, ctx->pubLen, NULL, 0, tr, MLDSA_TR_MSG_LEN), ret);  // Step 9
1024 
1025     // sk ← skEncode(ρ, K, tr, ��1, ��2, t0)
1026     SkEncode(ctx, pubSeed, signSeed, tr, &st); // Step 10
1027 ERR:
1028     BSL_SAL_ClearFree(st.bufAddr, st.bufSize);
1029     BSL_SAL_CleanseData(seed, sizeof(seed));
1030     BSL_SAL_CleanseData(digest, sizeof(digest));
1031     return ret;
1032 }
1033 
1034 // Referenced from NIST.FIPS.204 Algorithm 7 ML-DSA.Sign_internal(sk, ��′, r����)
MLDSA_SignInternal(const CRYPT_ML_DSA_Ctx * ctx,CRYPT_Data * msg,uint8_t * out,uint32_t * outLen,uint8_t * rand)1035 int32_t MLDSA_SignInternal(const CRYPT_ML_DSA_Ctx *ctx, CRYPT_Data *msg, uint8_t *out, uint32_t *outLen, uint8_t *rand)
1036 {
1037     int32_t ret = CRYPT_SUCCESS;
1038     uint8_t pubSeed[MLDSA_PUBLIC_SEED_LEN];
1039     uint8_t uBuf[MLDSA_XOF_MSG_LEN];
1040     uint8_t tr[MLDSA_TR_MSG_LEN];
1041     uint8_t signSeed[MLDSA_SIGNING_SEED_LEN + MLDSA_SEED_BYTES_LEN];
1042     (void)memcpy_s(signSeed + MLDSA_SIGNING_SEED_LEN, MLDSA_SEED_BYTES_LEN, rand, MLDSA_SEED_BYTES_LEN);
1043 
1044     // The w1Len length of MLDSA44 and MLDSA65 is 768, and the w1Len length of MLDSA87 is 1024.
1045     uint32_t w1Len = (ctx->info->k == 4 || ctx->info->k == 6) ? 768 : 1024;
1046     uint8_t *w1Buf = BSL_SAL_Malloc(w1Len);
1047     RETURN_RET_IF(w1Buf == NULL, CRYPT_MEM_ALLOC_FAIL);
1048 
1049     MLDSA_SignMatrixSt st = { 0 };
1050     GOTO_ERR_IF(MLDSASignCreateMatrix(ctx->info->k, ctx->info->l, &st), ret);
1051 
1052     // (ρ, K, tr, ��1, ��2, t0) ← skDecode(sk)
1053     SkDecode(ctx, pubSeed, signSeed, tr, &st);
1054     // A ← ExpandA(ρ)
1055     GOTO_ERR_IF(ExpandA(ctx, pubSeed, st.matrix), ret);
1056     if (ctx->isMuMsg) {
1057         (void)memcpy_s(uBuf, MLDSA_XOF_MSG_LEN, msg->data, msg->len);
1058     } else {
1059         // μ ← H(BytesToBits(tr)||��′, 64)
1060         GOTO_ERR_IF(HashFuncH(tr, MLDSA_TR_MSG_LEN, msg->data, msg->len, uBuf, MLDSA_XOF_MSG_LEN), ret);
1061     }
1062     // ρ″ ← H(K||r����||μ, 64)
1063     uint8_t p[MLDSA_XOF_MSG_LEN + 2]; // The counter used 2 bytes.
1064     GOTO_ERR_IF(HashFuncH(signSeed, sizeof(signSeed), uBuf, MLDSA_XOF_MSG_LEN, p, MLDSA_XOF_MSG_LEN), ret);
1065 
1066     uint16_t u = 0;
1067     // The length of c is λ/4.
1068     uint32_t cBufLen = ctx->info->secBits / 4;
1069     int32_t c[MLDSA_N];
1070     do {
1071         // y ← ExpandMask(ρ″, ��)
1072         GOTO_ERR_IF(ExpandMask(ctx, st.y, p, u), ret);
1073         u = u + ctx->info->l;
1074         ComputesNTT(ctx, st.y, st.z);
1075         // w ← NTT−1(A ∘ NTT(y)); w1 ← HighBits(w)
1076         ComputesW(ctx, st.w, st.w1, st.matrix, st.z);
1077 
1078         // �� ← H(μ||w1Encode(w1), ��/4)
1079         W1Encode(ctx, w1Buf, st.w1);
1080         GOTO_ERR_IF(HashFuncH(uBuf, MLDSA_XOF_MSG_LEN, w1Buf, w1Len, out, cBufLen), ret);
1081         (void)memset_s(c, sizeof(c), 0, sizeof(c));
1082         // �� ∈ ���� ← SampleInBall(c)
1083         SampleInBall(ctx, out, cBufLen, c);
1084         // �� ← NTT(��)
1085         MLDSA_ComputesNTT(c);
1086 
1087         // ⟨⟨����1⟩⟩ ← NTT^−1(�� ∘ ��1); z ← y + ⟨⟨����1⟩⟩
1088         ComputesZ(ctx, st.y, c, st.s1, st.z);
1089         // if ||z||∞ ≥ ��1 − β
1090         if (ValidityChecksL(ctx, st.z, ctx->info->gamma1 - ctx->info->beta) == false) {
1091             continue;
1092         }
1093         // ⟨⟨����2⟩⟩ ← NTT^−1(�� ∘ ��2); ��0 ← LowBits(w − ⟨⟨����2⟩⟩)
1094         ComputesR(ctx, c, &st);
1095         // if ||��0||∞ ≥ ��2 − β
1096         if (ValidityChecksK(ctx, st.r0, ctx->info->gamma2 - ctx->info->beta) == false) {
1097             continue;
1098         }
1099         // ⟨⟨��t0⟩⟩ ← NTT^−1(�� ∘ t0)
1100         ComputesCT(ctx, c, st.t0, st.ct0);
1101         // if ||⟨⟨��t0⟩⟩||∞ ≥ ��2
1102         if (ValidityChecksK(ctx, st.ct0, ctx->info->gamma2) == false) {
1103             continue;
1104         }
1105         // h ← MakeHint(−⟨⟨��t0⟩⟩, w − ⟨⟨����2⟩⟩ + ⟨⟨��t0⟩⟩)
1106         if (MakeHint(ctx, &st) > ctx->info->omega) {
1107             continue;
1108         }
1109         break;
1110     } while (true);
1111 
1112     *outLen = ctx->info->signatureLen;
1113     // σ ← sigEncode(��, z̃ mod±��, h)
1114     SigEncode(ctx, out + cBufLen, *outLen - cBufLen, st.z, st.h);
1115 ERR:
1116     BSL_SAL_ClearFree(st.bufAddr, st.bufSize);
1117     BSL_SAL_ClearFree(w1Buf, w1Len);
1118     BSL_SAL_CleanseData(signSeed, sizeof(signSeed));
1119     return ret;
1120 }
1121 
1122 // Referenced from NIST.FIPS.204 Algorithm 8 ML-DSA.Verify_internal(pk, ��′, σ)
MLDSA_VerifyInternal(const CRYPT_ML_DSA_Ctx * ctx,CRYPT_Data * msg,const uint8_t * sign,uint32_t signLen)1123 int32_t MLDSA_VerifyInternal(const CRYPT_ML_DSA_Ctx *ctx, CRYPT_Data *msg, const uint8_t *sign, uint32_t signLen)
1124 {
1125     (void)signLen;
1126     uint8_t k = ctx->info->k;
1127     uint8_t l = ctx->info->l;
1128     uint8_t pubSeed[MLDSA_PUBLIC_SEED_LEN];
1129     uint8_t uBuf[MLDSA_XOF_MSG_LEN];
1130     uint8_t cBuf[MLDSA_XOF_MSG_LEN];
1131     uint8_t tr[MLDSA_TR_MSG_LEN];
1132     uint32_t cBufLen = ctx->info->secBits / 4;
1133     MLDSA_VerifyMatrixSt st = { 0 };
1134     int32_t c[MLDSA_N] = { 0 };
1135     int32_t ret;
1136 
1137     // The w1Len length of MLDSA44 and MLDSA65 is 768, and the w1Len length of MLDSA87 is 1024.
1138     uint32_t w1Len = (k == 4 || k == 6) ? 768 : 1024;
1139     uint8_t *w1Buf = BSL_SAL_Malloc(w1Len);
1140     RETURN_RET_IF(w1Buf == NULL, CRYPT_MEM_ALLOC_FAIL);
1141 
1142     GOTO_ERR_IF(MLDSAVerifyCreateMatrix(k, l, &st), ret);
1143 
1144     // (ρ, t1) ← pkDecode(pk)
1145     PkDecode(ctx, pubSeed, st.t1);
1146     // (c,z,h) ← sigDecode(σ)
1147     GOTO_ERR_IF(SigDecode(ctx, sign + cBufLen, st.z, st.h), ret);
1148 
1149     // if ||z||∞ < ��1 − β
1150     if (ValidityChecksL(ctx, st.z, ctx->info->gamma1 - ctx->info->beta) == false) {
1151         ret = CRYPT_MLDSA_SIGN_DATA_ERROR;
1152         goto ERR;
1153     }
1154 
1155     // A ← ExpandA(ρ)
1156     GOTO_ERR_IF(ExpandA(ctx, pubSeed, st.matrix), ret);
1157     if (ctx->isMuMsg) {
1158         (void)memcpy_s(uBuf, MLDSA_XOF_MSG_LEN, msg->data, msg->len);
1159     } else {
1160         // tr ← H(pk, 64)
1161         GOTO_ERR_IF(HashFuncH(ctx->pubKey, ctx->pubLen, NULL, 0, tr, MLDSA_TR_MSG_LEN), ret);
1162         // μ ← (H(BytesToBits(tr)||��′, 64))
1163         GOTO_ERR_IF(HashFuncH(tr, MLDSA_TR_MSG_LEN, msg->data, msg->len, uBuf, MLDSA_XOF_MSG_LEN), ret);
1164     }
1165 
1166     // �� ∈ ���� ← SampleInBall(��)
1167     SampleInBall(ctx, sign, cBufLen, c);
1168     // w′ ← NTT−1(A ∘ NTT(z) − NTT(��) ∘ NTT(t1 ⋅ 2��))
1169     ComputesApproxW(ctx, &st, c, st.w);
1170     // w1′ ← UseHint(h, w′)
1171     UseHint(ctx, st.h, st.w);
1172     // c′← H(μ||w1Encode(w1′), ��/4)
1173     W1Encode(ctx, w1Buf, st.w);
1174     GOTO_ERR_IF(HashFuncH(uBuf, MLDSA_XOF_MSG_LEN, w1Buf, w1Len, cBuf, cBufLen), ret);
1175 
1176     // If c and c' are not equal, verify failed.
1177     if (memcmp(sign, cBuf, cBufLen) != 0) {
1178         BSL_ERR_PUSH_ERROR(CRYPT_MLDSA_VERIFY_FAIL);
1179         ret = CRYPT_MLDSA_VERIFY_FAIL;
1180         goto ERR;
1181     }
1182 ERR:
1183     BSL_SAL_Free(st.bufAddr);
1184     BSL_SAL_Free(w1Buf);
1185     return ret;
1186 }
1187 
1188 #endif