• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * This file is part of the openHiTLS project.
3  *
4  * openHiTLS is licensed under the Mulan PSL v2.
5  * You can use this software according to the terms and conditions of the Mulan PSL v2.
6  * You may obtain a copy of Mulan PSL v2 at:
7  *
8  *     http://license.coscl.org.cn/MulanPSL2
9  *
10  * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
11  * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
12  * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
13  * See the Mulan PSL v2 for more details.
14  */
15 
16 #include "hitls_build.h"
17 #ifdef HITLS_CRYPTO_MLKEM
18 #include "securec.h"
19 #include "crypt_errno.h"
20 #include "bsl_sal.h"
21 #include "bsl_err_internal.h"
22 #include "eal_pkey_local.h"
23 #include "crypt_util_rand.h"
24 #include "crypt_utils.h"
25 #include "ml_kem_local.h"
26 
27 static const CRYPT_MlKemInfo ML_KEM_INFO[] = {
28     {2, 3, 2, 10, 4, 128, 800, 1632, 768, 32, 512},
29     {3, 2, 2, 10, 4, 192, 1184, 2400, 1088, 32, 768},
30     {4, 2, 2, 11, 5, 256, 1568, 3168, 1568, 32, 1024}
31 };
32 
MlKemGetInfo(uint32_t bits)33 static const CRYPT_MlKemInfo *MlKemGetInfo(uint32_t bits)
34 {
35     for (uint32_t i = 0; i < sizeof(ML_KEM_INFO) / sizeof(ML_KEM_INFO[0]); i++) {
36         if (ML_KEM_INFO[i].bits == bits) {
37             return &ML_KEM_INFO[i];
38         }
39     }
40     return NULL;
41 }
42 
CRYPT_ML_KEM_NewCtx(void)43 CRYPT_ML_KEM_Ctx *CRYPT_ML_KEM_NewCtx(void)
44 {
45     CRYPT_ML_KEM_Ctx *keyCtx = BSL_SAL_Malloc(sizeof(CRYPT_ML_KEM_Ctx));
46     if (keyCtx == NULL) {
47         BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
48         return NULL;
49     }
50     (void)memset_s(keyCtx, sizeof(CRYPT_ML_KEM_Ctx), 0, sizeof(CRYPT_ML_KEM_Ctx));
51     BSL_SAL_ReferencesInit(&(keyCtx->references));
52     return keyCtx;
53 }
54 
CRYPT_ML_KEM_NewCtxEx(void * libCtx)55 CRYPT_ML_KEM_Ctx *CRYPT_ML_KEM_NewCtxEx(void *libCtx)
56 {
57     CRYPT_ML_KEM_Ctx *ctx = CRYPT_ML_KEM_NewCtx();
58     if (ctx == NULL) {
59         return NULL;
60     }
61     ctx->libCtx = libCtx;
62     return ctx;
63 }
64 
/*
 * Drop one reference to ctx. When no references remain, the decapsulation
 * (private) key is wiped, and both key buffers and the context are released.
 * A NULL ctx is ignored.
 */
void CRYPT_ML_KEM_FreeCtx(CRYPT_ML_KEM_Ctx *ctx)
{
    int remaining = 0;

    if (ctx == NULL) {
        return;
    }
    BSL_SAL_AtomicDownReferences(&ctx->references, &remaining);
    if (remaining > 0) {
        return;  // other holders still reference this context
    }

    // dk is secret key material: scrub it before handing the memory back.
    BSL_SAL_CleanseData(ctx->dk, ctx->dkLen);
    BSL_SAL_FREE(ctx->dk);
    BSL_SAL_FREE(ctx->ek);
    BSL_SAL_ReferencesFree(&ctx->references);
    BSL_SAL_FREE(ctx);
}
81 
/*
 * CRYPT_CTRL_SET_PARA_BY_ID handler: select the ML-KEM parameter set.
 * val holds a CRYPT_KEM_TYPE_MLKEM_* identifier and len must be 4 bytes.
 * The parameter set can be chosen only once per context.
 *
 * Returns CRYPT_INVALID_ARG for a wrong len,
 * CRYPT_MLKEM_CTRL_INIT_REPEATED if a set is already selected, and
 * CRYPT_NOT_SUPPORT for an unknown identifier.
 */
static int32_t MlKemSetAlgInfo(CRYPT_ML_KEM_Ctx *ctx, void *val, uint32_t len)
{
    if (len != sizeof(uint32_t)) {
        BSL_ERR_PUSH_ERROR(CRYPT_INVALID_ARG);
        return CRYPT_INVALID_ARG;
    }
    if (ctx->info != NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_MLKEM_CTRL_INIT_REPEATED);
        return CRYPT_MLKEM_CTRL_INIT_REPEATED;
    }

    uint32_t bits;
    switch (*(int32_t *)val) {
        case CRYPT_KEM_TYPE_MLKEM_512:
            bits = 512;  // MLKEM512
            break;
        case CRYPT_KEM_TYPE_MLKEM_768:
            bits = 768;  // MLKEM768
            break;
        case CRYPT_KEM_TYPE_MLKEM_1024:
            bits = 1024;  // MLKEM1024
            break;
        default:
            bits = 0;  // unknown identifier: no table entry has bits == 0
            break;
    }

    const CRYPT_MlKemInfo *selected = MlKemGetInfo(bits);
    if (selected == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_NOT_SUPPORT);
        return CRYPT_NOT_SUPPORT;
    }
    ctx->info = selected;
    return CRYPT_SUCCESS;
}
109 
CRYPT_ML_KEM_DupCtx(CRYPT_ML_KEM_Ctx * ctx)110 CRYPT_ML_KEM_Ctx *CRYPT_ML_KEM_DupCtx(CRYPT_ML_KEM_Ctx *ctx)
111 {
112     if (ctx == NULL) {
113         BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
114         return NULL;
115     }
116     CRYPT_ML_KEM_Ctx *newCtx = CRYPT_ML_KEM_NewCtx();
117     if (newCtx == NULL) {
118         BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
119         return NULL;
120     }
121     if (ctx->info != NULL) {
122         newCtx->info = ctx->info;
123     }
124     if (ctx->ek != NULL) {
125         newCtx->ek = BSL_SAL_Dump(ctx->ek, ctx->ekLen);
126         if (newCtx->ek == NULL) {
127             CRYPT_ML_KEM_FreeCtx(newCtx);
128             return NULL;
129         }
130         newCtx->ekLen = ctx->ekLen;
131     }
132     if (ctx->dk != NULL) {
133         newCtx->dk = BSL_SAL_Dump(ctx->dk, ctx->dkLen);
134         if (newCtx->dk == NULL) {
135             CRYPT_ML_KEM_FreeCtx(newCtx);
136             return NULL;
137         }
138         newCtx->dkLen = ctx->dkLen;
139     }
140     return newCtx;
141 }
142 
MlKemGetEncapsKeyLen(CRYPT_ML_KEM_Ctx * ctx,void * val,uint32_t len)143 static int32_t MlKemGetEncapsKeyLen(CRYPT_ML_KEM_Ctx *ctx, void *val, uint32_t len)
144 {
145     if (ctx->info == NULL) {
146         BSL_ERR_PUSH_ERROR(CRYPT_MLKEM_KEYINFO_NOT_SET);
147         return CRYPT_MLKEM_KEYINFO_NOT_SET;
148     }
149     if (len != sizeof(uint32_t)) {
150         BSL_ERR_PUSH_ERROR(CRYPT_INVALID_ARG);
151         return CRYPT_INVALID_ARG;
152     }
153     *(uint32_t*)val = ctx->info->encapsKeyLen;
154     return CRYPT_SUCCESS;
155 }
156 
MlKemGetDecapsKeyLen(CRYPT_ML_KEM_Ctx * ctx,void * val,uint32_t len)157 static int32_t MlKemGetDecapsKeyLen(CRYPT_ML_KEM_Ctx *ctx, void *val, uint32_t len)
158 {
159     if (ctx->info == NULL) {
160         BSL_ERR_PUSH_ERROR(CRYPT_MLKEM_KEYINFO_NOT_SET);
161         return CRYPT_MLKEM_KEYINFO_NOT_SET;
162     }
163     if (len != sizeof(uint32_t)) {
164         BSL_ERR_PUSH_ERROR(CRYPT_INVALID_ARG);
165         return CRYPT_INVALID_ARG;
166     }
167     *(uint32_t*)val = ctx->info->decapsKeyLen;
168     return CRYPT_SUCCESS;
169 }
170 
MlKemGetCipherTextLen(CRYPT_ML_KEM_Ctx * ctx,void * val,uint32_t len)171 static int32_t MlKemGetCipherTextLen(CRYPT_ML_KEM_Ctx *ctx, void *val, uint32_t len)
172 {
173     if (ctx->info == NULL) {
174         BSL_ERR_PUSH_ERROR(CRYPT_MLKEM_KEYINFO_NOT_SET);
175         return CRYPT_MLKEM_KEYINFO_NOT_SET;
176     }
177     if (len != sizeof(uint32_t)) {
178         BSL_ERR_PUSH_ERROR(CRYPT_INVALID_ARG);
179         return CRYPT_INVALID_ARG;
180     }
181     *(uint32_t*)val = ctx->info->cipherLen;
182     return CRYPT_SUCCESS;
183 }
184 
MlKemGetSharedLen(CRYPT_ML_KEM_Ctx * ctx,void * val,uint32_t len)185 static int32_t MlKemGetSharedLen(CRYPT_ML_KEM_Ctx *ctx, void *val, uint32_t len)
186 {
187     if (ctx->info == NULL) {
188         BSL_ERR_PUSH_ERROR(CRYPT_MLKEM_KEYINFO_NOT_SET);
189         return CRYPT_MLKEM_KEYINFO_NOT_SET;
190     }
191     if (len != sizeof(uint32_t)) {
192         BSL_ERR_PUSH_ERROR(CRYPT_INVALID_ARG);
193         return CRYPT_INVALID_ARG;
194     }
195     *(uint32_t*)val = ctx->info->sharedLen;
196     return CRYPT_SUCCESS;
197 }
198 
/*
 * Import an encapsulation (public) key from the CRYPT_PARAM_ML_KEM_PUBKEY
 * entry of param. Any key already stored in ctx is replaced.
 *
 * The parameter set must be selected first (otherwise CRYPT_NULL_INPUT).
 * The key length must equal info->encapsKeyLen exactly
 * (otherwise CRYPT_MLKEM_KEYLEN_ERROR).
 */
int32_t CRYPT_ML_KEM_SetEncapsKey(CRYPT_ML_KEM_Ctx *ctx, const BSL_Param *param)
{
    if (ctx == NULL || ctx->info == NULL || param == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
        return CRYPT_NULL_INPUT;
    }

    const BSL_Param *pubParam = BSL_PARAM_FindConstParam(param, CRYPT_PARAM_ML_KEM_PUBKEY);
    if (pubParam == NULL || pubParam->value == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
        return CRYPT_NULL_INPUT;
    }
    if (pubParam->valueLen != ctx->info->encapsKeyLen) {
        BSL_ERR_PUSH_ERROR(CRYPT_MLKEM_KEYLEN_ERROR);
        return CRYPT_MLKEM_KEYLEN_ERROR;
    }

    // Copy first, so a failed allocation leaves the previous key untouched.
    uint8_t *keyCopy = BSL_SAL_Dump(pubParam->value, pubParam->valueLen);
    if (keyCopy == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
        return CRYPT_MEM_ALLOC_FAIL;
    }
    if (ctx->ek != NULL) {
        BSL_SAL_Free(ctx->ek);
    }
    ctx->ek = keyCopy;
    ctx->ekLen = pubParam->valueLen;
    return CRYPT_SUCCESS;
}
227 
/*
 * Export the stored encapsulation (public) key into the
 * CRYPT_PARAM_ML_KEM_PUBKEY entry of param, and set that entry's useLen.
 *
 * Checks run in this order:
 *  - destination capacity against info->encapsKeyLen
 *    (CRYPT_MLKEM_KEYLEN_ERROR);
 *  - presence of a stored key (CRYPT_MLKEM_KEY_NOT_SET);
 *  - the copy itself.
 */
int32_t CRYPT_ML_KEM_GetEncapsKey(const CRYPT_ML_KEM_Ctx *ctx, BSL_Param *param)
{
    if (ctx == NULL || ctx->info == NULL || param == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
        return CRYPT_NULL_INPUT;
    }
    BSL_Param *outParam = BSL_PARAM_FindParam(param, CRYPT_PARAM_ML_KEM_PUBKEY);
    if (outParam == NULL || outParam->value == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
        return CRYPT_NULL_INPUT;
    }

    int32_t err = CRYPT_SUCCESS;
    if (outParam->valueLen < ctx->info->encapsKeyLen) {
        err = CRYPT_MLKEM_KEYLEN_ERROR;
    } else if (ctx->ek == NULL) {
        err = CRYPT_MLKEM_KEY_NOT_SET;
    } else if (memcpy_s(outParam->value, outParam->valueLen, ctx->ek, ctx->ekLen) != EOK) {
        err = CRYPT_MLKEM_KEYLEN_ERROR;
    }
    if (err != CRYPT_SUCCESS) {
        BSL_ERR_PUSH_ERROR(err);
        return err;
    }

    outParam->useLen = ctx->ekLen;
    return CRYPT_SUCCESS;
}
255 
/*
 * Import a decapsulation (private) key from the CRYPT_PARAM_ML_KEM_PRVKEY
 * entry of param. A previously stored private key is wiped and released.
 *
 * The parameter set must be selected first (otherwise CRYPT_NULL_INPUT).
 * The key length must equal info->decapsKeyLen exactly
 * (otherwise CRYPT_MLKEM_KEYLEN_ERROR).
 */
int32_t CRYPT_ML_KEM_SetDecapsKey(CRYPT_ML_KEM_Ctx *ctx, const BSL_Param *param)
{
    if (ctx == NULL || ctx->info == NULL || param == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
        return CRYPT_NULL_INPUT;
    }

    const BSL_Param *prvParam = BSL_PARAM_FindConstParam(param, CRYPT_PARAM_ML_KEM_PRVKEY);
    if (prvParam == NULL || prvParam->value == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
        return CRYPT_NULL_INPUT;
    }
    if (prvParam->valueLen != ctx->info->decapsKeyLen) {
        BSL_ERR_PUSH_ERROR(CRYPT_MLKEM_KEYLEN_ERROR);
        return CRYPT_MLKEM_KEYLEN_ERROR;
    }

    uint8_t *keyCopy = BSL_SAL_Dump(prvParam->value, prvParam->valueLen);
    if (keyCopy == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
        return CRYPT_MEM_ALLOC_FAIL;
    }
    if (ctx->dk != NULL) {
        // The old private key is secret: scrub it before releasing.
        BSL_SAL_CleanseData(ctx->dk, ctx->dkLen);
        BSL_SAL_Free(ctx->dk);
    }
    ctx->dk = keyCopy;
    ctx->dkLen = prvParam->valueLen;
    return CRYPT_SUCCESS;
}
285 
/*
 * Export the stored decapsulation (private) key into the
 * CRYPT_PARAM_ML_KEM_PRVKEY entry of param, and set that entry's useLen.
 *
 * Checks run in this order:
 *  - destination capacity against info->decapsKeyLen
 *    (CRYPT_MLKEM_KEYLEN_ERROR);
 *  - presence of a stored key (CRYPT_MLKEM_KEY_NOT_SET);
 *  - the copy itself.
 */
int32_t CRYPT_ML_KEM_GetDecapsKey(const CRYPT_ML_KEM_Ctx *ctx, BSL_Param *param)
{
    if (ctx == NULL || ctx->info == NULL || param == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
        return CRYPT_NULL_INPUT;
    }
    BSL_Param *outParam = BSL_PARAM_FindParam(param, CRYPT_PARAM_ML_KEM_PRVKEY);
    if (outParam == NULL || outParam->value == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
        return CRYPT_NULL_INPUT;
    }

    int32_t err = CRYPT_SUCCESS;
    if (outParam->valueLen < ctx->info->decapsKeyLen) {
        err = CRYPT_MLKEM_KEYLEN_ERROR;
    } else if (ctx->dk == NULL) {
        err = CRYPT_MLKEM_KEY_NOT_SET;
    } else if (memcpy_s(outParam->value, outParam->valueLen, ctx->dk, ctx->dkLen) != EOK) {
        err = CRYPT_MLKEM_KEYLEN_ERROR;
    }
    if (err != CRYPT_SUCCESS) {
        BSL_ERR_PUSH_ERROR(err);
        return err;
    }

    outParam->useLen = ctx->dkLen;
    return CRYPT_SUCCESS;
}
313 
MlKemCmpKey(uint8_t * a,uint32_t aLen,uint8_t * b,uint32_t bLen)314 static int32_t MlKemCmpKey(uint8_t *a, uint32_t aLen, uint8_t *b, uint32_t bLen)
315 {
316     if (aLen != bLen) {
317         BSL_ERR_PUSH_ERROR(CRYPT_MLKEM_KEY_NOT_EQUAL);
318         return CRYPT_MLKEM_KEY_NOT_EQUAL;
319     }
320     if (a != NULL && b != NULL) {
321         if (memcmp(a, b, aLen) != 0) {
322             BSL_ERR_PUSH_ERROR(CRYPT_MLKEM_KEY_NOT_EQUAL);
323             return CRYPT_MLKEM_KEY_NOT_EQUAL;
324         }
325     }
326     if ((a != NULL) != (b != NULL)) {
327         BSL_ERR_PUSH_ERROR(CRYPT_MLKEM_KEY_NOT_EQUAL);
328         return CRYPT_MLKEM_KEY_NOT_EQUAL;
329     }
330     return CRYPT_SUCCESS;
331 }
332 
CRYPT_ML_KEM_Cmp(const CRYPT_ML_KEM_Ctx * a,const CRYPT_ML_KEM_Ctx * b)333 int32_t CRYPT_ML_KEM_Cmp(const CRYPT_ML_KEM_Ctx *a, const CRYPT_ML_KEM_Ctx *b)
334 {
335     if (a == NULL || b == NULL) {
336         BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
337         return CRYPT_NULL_INPUT;
338     }
339     if (a->info != b->info) {  // The value of info must be one of the ML_KEM_INFO arrays.
340         BSL_ERR_PUSH_ERROR(CRYPT_MLKEM_KEY_NOT_EQUAL);
341         return CRYPT_MLKEM_KEY_NOT_EQUAL;
342     }
343 
344     if (MlKemCmpKey(a->ek, a->ekLen, b->ek, b->ekLen) != CRYPT_SUCCESS) {
345         return CRYPT_MLKEM_KEY_NOT_EQUAL;
346     }
347     if (MlKemCmpKey(a->dk, a->dkLen, b->dk, b->dkLen) != CRYPT_SUCCESS) {
348         return CRYPT_MLKEM_KEY_NOT_EQUAL;
349     }
350     return CRYPT_SUCCESS;
351 }
352 
CRYPT_ML_KEM_GetSecBits(const CRYPT_ML_KEM_Ctx * ctx)353 int32_t CRYPT_ML_KEM_GetSecBits(const CRYPT_ML_KEM_Ctx *ctx)
354 {
355     if (ctx == NULL || ctx->info == NULL) {
356         return 0;
357     }
358     return (int32_t)ctx->info->secBits;
359 }
360 
/*
 * Helper for integer-valued ctrl queries: evaluate func(ctx) and store the
 * result in *(int32_t *)val.
 * val must be non-NULL and len must be sizeof(int32_t). Otherwise
 * CRYPT_NULL_INPUT is pushed and returned, for either failure.
 */
static int32_t CRYPT_ML_KEM_GetLen(const CRYPT_ML_KEM_Ctx *ctx, GetLenFunc func, void *val, uint32_t len)
{
    const int usable = (val != NULL) && (len == sizeof(int32_t));
    if (!usable) {
        BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
        return CRYPT_NULL_INPUT;
    }

    int32_t *out = (int32_t *)val;
    *out = func(ctx);
    return CRYPT_SUCCESS;
}
371 
/*
 * Control entry point for ML-KEM contexts.
 * It dispatches the supported CRYPT_CTRL_* options to their handlers.
 * Returns CRYPT_NULL_INPUT for a NULL ctx or val, and
 * CRYPT_MLKEM_CTRL_NOT_SUPPORT for any other option.
 */
int32_t CRYPT_ML_KEM_Ctrl(CRYPT_ML_KEM_Ctx *ctx, int32_t opt, void *val, uint32_t len)
{
    // A NULL val for GET_SECBITS would be rejected with the same
    // CRYPT_NULL_INPUT by CRYPT_ML_KEM_GetLen, so one up-front check covers all options.
    if (ctx == NULL || val == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
        return CRYPT_NULL_INPUT;
    }

    switch (opt) {
        case CRYPT_CTRL_GET_SECBITS:
            return CRYPT_ML_KEM_GetLen(ctx, (GetLenFunc)CRYPT_ML_KEM_GetSecBits, val, len);
        case CRYPT_CTRL_SET_PARA_BY_ID:
            return MlKemSetAlgInfo(ctx, val, len);
        case CRYPT_CTRL_GET_PUBKEY_LEN:
            return MlKemGetEncapsKeyLen(ctx, val, len);
        case CRYPT_CTRL_GET_PRVKEY_LEN:
            return MlKemGetDecapsKeyLen(ctx, val, len);
        case CRYPT_CTRL_GET_CIPHERTEXT_LEN:
            return MlKemGetCipherTextLen(ctx, val, len);
        case CRYPT_CTRL_GET_SHARED_KEY_LEN:
            return MlKemGetSharedLen(ctx, val, len);
        default:
            BSL_ERR_PUSH_ERROR(CRYPT_MLKEM_CTRL_NOT_SUPPORT);
            return CRYPT_MLKEM_CTRL_NOT_SUPPORT;
    }
}
401 
/*
 * Make sure ctx owns dk and ek buffers sized for the selected parameter set
 * before key generation writes into them. Buffers that already exist are
 * kept as they are.
 *
 * On failure, CRYPT_MEM_ALLOC_FAIL is pushed and returned. A dk buffer
 * allocated before an ek failure stays attached to ctx, so
 * CRYPT_ML_KEM_FreeCtx releases it.
 */
static int32_t MlKemCreateKeyBuf(CRYPT_ML_KEM_Ctx *ctx)
{
    if (ctx->dk == NULL) {
        ctx->dk = BSL_SAL_Malloc(ctx->info->decapsKeyLen);
        if (ctx->dk == NULL) {
            BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
            return CRYPT_MEM_ALLOC_FAIL;
        }
        ctx->dkLen = ctx->info->decapsKeyLen;
    }

    if (ctx->ek == NULL) {
        ctx->ek = BSL_SAL_Malloc(ctx->info->encapsKeyLen);
        if (ctx->ek == NULL) {
            BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
            return CRYPT_MEM_ALLOC_FAIL;
        }
        ctx->ekLen = ctx->info->encapsKeyLen;
    }
    return CRYPT_SUCCESS;
}
424 
/*
 * Generate a fresh ML-KEM key pair for the selected parameter set.
 *
 * Two MLKEM_SEED_LEN-byte seeds, d and z, are drawn through CRYPT_RandEx
 * using ctx->libCtx. They are passed to MLKEM_KeyGenInternal, together with
 * the ek/dk buffers prepared by MlKemCreateKeyBuf.
 *
 * The seeds are secret key material, so both are wiped from the stack on
 * every exit path. This includes the case where the second random draw
 * fails after d has already been filled.
 *
 * Returns CRYPT_NULL_INPUT, CRYPT_MLKEM_KEYINFO_NOT_SET,
 * CRYPT_MEM_ALLOC_FAIL, a CRYPT_RandEx error, or the result of
 * MLKEM_KeyGenInternal.
 */
int32_t CRYPT_ML_KEM_GenKey(CRYPT_ML_KEM_Ctx *ctx)
{
    if (ctx == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
        return CRYPT_NULL_INPUT;
    }
    if (ctx->info == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_MLKEM_KEYINFO_NOT_SET);
        return CRYPT_MLKEM_KEYINFO_NOT_SET;
    }
    int32_t ret = MlKemCreateKeyBuf(ctx);
    if (ret != CRYPT_SUCCESS) {
        return ret;  // MlKemCreateKeyBuf has already pushed the error
    }

    uint8_t d[MLKEM_SEED_LEN];
    uint8_t z[MLKEM_SEED_LEN];
    ret = CRYPT_RandEx(ctx->libCtx, d, MLKEM_SEED_LEN);
    if (ret == CRYPT_SUCCESS) {
        ret = CRYPT_RandEx(ctx->libCtx, z, MLKEM_SEED_LEN);
    }
    if (ret == CRYPT_SUCCESS) {
        ret = MLKEM_KeyGenInternal(ctx, d, z);
    } else {
        BSL_ERR_PUSH_ERROR(ret);
    }

    // Wipe on all paths: an early return here would leave a valid seed on the stack.
    BSL_SAL_CleanseData(d, MLKEM_SEED_LEN);
    BSL_SAL_CleanseData(z, MLKEM_SEED_LEN);
    return ret;
}
450 
EncCapsInputCheck(const CRYPT_ML_KEM_Ctx * ctx,uint8_t * ct,uint32_t * ctLen,uint8_t * sk,uint32_t * skLen)451 static int32_t EncCapsInputCheck(const CRYPT_ML_KEM_Ctx *ctx, uint8_t *ct, uint32_t *ctLen,
452     uint8_t *sk, uint32_t *skLen)
453 {
454     if (ctx == NULL || ctx->ek == NULL || ct == NULL || ctLen == NULL ||
455         sk == NULL || skLen == NULL) {
456         return CRYPT_NULL_INPUT;
457     }
458     if (ctx->info == NULL) {
459         return CRYPT_MLKEM_KEYINFO_NOT_SET;
460     }
461     if (*ctLen < ctx->info->cipherLen || *skLen < MLKEM_SHARED_KEY_LEN) {
462         return CRYPT_MLKEM_LEN_NOT_ENOUGH;
463     }
464     return CRYPT_SUCCESS;
465 }
466 
CRYPT_ML_KEM_Encaps(const CRYPT_ML_KEM_Ctx * ctx,uint8_t * cipher,uint32_t * cipherLen,uint8_t * share,uint32_t * shareLen)467 int32_t CRYPT_ML_KEM_Encaps(const CRYPT_ML_KEM_Ctx *ctx, uint8_t *cipher, uint32_t *cipherLen,
468     uint8_t *share, uint32_t *shareLen)
469 {
470     int32_t ret = EncCapsInputCheck(ctx, cipher, cipherLen, share, shareLen);
471     RETURN_RET_IF(ret != CRYPT_SUCCESS, ret);
472 
473     uint8_t m[MLKEM_SEED_LEN];
474     ret = CRYPT_RandEx(ctx->libCtx, m, MLKEM_SEED_LEN);
475     RETURN_RET_IF(ret != CRYPT_SUCCESS, ret);
476 
477     ret = MLKEM_EncapsInternal(ctx, cipher, cipherLen, share, shareLen, m);
478     BSL_SAL_CleanseData(m, MLKEM_SEED_LEN);
479     return ret;
480 }
481 
DecCapsInputCheck(const CRYPT_ML_KEM_Ctx * ctx,uint8_t * ct,uint32_t ctLen,uint8_t * sk,uint32_t * skLen)482 static int32_t DecCapsInputCheck(const CRYPT_ML_KEM_Ctx *ctx, uint8_t *ct, uint32_t ctLen,
483     uint8_t *sk, uint32_t *skLen)
484 {
485     if (ctx == NULL || ctx->dk == NULL || ct == NULL || sk == NULL || skLen == NULL) {
486         return CRYPT_NULL_INPUT;
487     }
488     if (ctx->info == NULL) {
489         return CRYPT_MLKEM_KEYINFO_NOT_SET;
490     }
491     if (ctLen != ctx->info->cipherLen || *skLen < MLKEM_SHARED_KEY_LEN) {
492         return CRYPT_MLKEM_LEN_NOT_ENOUGH;
493     }
494     return CRYPT_SUCCESS;
495 }
496 
CRYPT_ML_KEM_Decaps(const CRYPT_ML_KEM_Ctx * ctx,uint8_t * cipher,uint32_t cipherLen,uint8_t * share,uint32_t * shareLen)497 int32_t CRYPT_ML_KEM_Decaps(const CRYPT_ML_KEM_Ctx *ctx, uint8_t *cipher, uint32_t cipherLen,
498     uint8_t *share, uint32_t *shareLen)
499 {
500     int32_t ret = DecCapsInputCheck(ctx, cipher, cipherLen, share, shareLen);
501     RETURN_RET_IF(ret != CRYPT_SUCCESS, ret);
502 
503     return MLKEM_DecapsInternal(ctx, cipher, cipherLen, share, shareLen);
504 }
505 
506 #endif