1 /*
2 * This file is part of the openHiTLS project.
3 *
4 * openHiTLS is licensed under the Mulan PSL v2.
5 * You can use this software according to the terms and conditions of the Mulan PSL v2.
6 * You may obtain a copy of Mulan PSL v2 at:
7 *
8 * http://license.coscl.org.cn/MulanPSL2
9 *
10 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
11 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
12 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
13 * See the Mulan PSL v2 for more details.
14 */
15
16 #include "hitls_build.h"
17 #ifdef HITLS_CRYPTO_MLKEM
18 #include "securec.h"
19 #include "crypt_errno.h"
20 #include "bsl_sal.h"
21 #include "bsl_err_internal.h"
22 #include "eal_pkey_local.h"
23 #include "crypt_util_rand.h"
24 #include "crypt_utils.h"
25 #include "ml_kem_local.h"
26
/*
 * ML-KEM parameter sets, one row per security level. MlKemGetInfo selects a row by its `bits`
 * member, which is the last column (512 / 768 / 1024, the values MlKemSetAlgInfo maps the
 * CRYPT_KEM_TYPE_MLKEM_* ids to).
 * The numeric values correspond to FIPS 203 Tables 2 and 3 for ML-KEM-512 / -768 / -1024.
 * NOTE(review): column order presumably is
 *   {k, eta1, eta2, du, dv, secBits, encapsKeyLen, decapsKeyLen, cipherLen, sharedLen, bits}
 * -- confirm against the CRYPT_MlKemInfo definition in ml_kem_local.h.
 */
static const CRYPT_MlKemInfo ML_KEM_INFO[] = {
    {2, 3, 2, 10, 4, 128, 800, 1632, 768, 32, 512},   // ML-KEM-512
    {3, 2, 2, 10, 4, 192, 1184, 2400, 1088, 32, 768}, // ML-KEM-768
    {4, 2, 2, 11, 5, 256, 1568, 3168, 1568, 32, 1024} // ML-KEM-1024
};
32
MlKemGetInfo(uint32_t bits)33 static const CRYPT_MlKemInfo *MlKemGetInfo(uint32_t bits)
34 {
35 for (uint32_t i = 0; i < sizeof(ML_KEM_INFO) / sizeof(ML_KEM_INFO[0]); i++) {
36 if (ML_KEM_INFO[i].bits == bits) {
37 return &ML_KEM_INFO[i];
38 }
39 }
40 return NULL;
41 }
42
CRYPT_ML_KEM_NewCtx(void)43 CRYPT_ML_KEM_Ctx *CRYPT_ML_KEM_NewCtx(void)
44 {
45 CRYPT_ML_KEM_Ctx *keyCtx = BSL_SAL_Malloc(sizeof(CRYPT_ML_KEM_Ctx));
46 if (keyCtx == NULL) {
47 BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
48 return NULL;
49 }
50 (void)memset_s(keyCtx, sizeof(CRYPT_ML_KEM_Ctx), 0, sizeof(CRYPT_ML_KEM_Ctx));
51 BSL_SAL_ReferencesInit(&(keyCtx->references));
52 return keyCtx;
53 }
54
CRYPT_ML_KEM_NewCtxEx(void * libCtx)55 CRYPT_ML_KEM_Ctx *CRYPT_ML_KEM_NewCtxEx(void *libCtx)
56 {
57 CRYPT_ML_KEM_Ctx *ctx = CRYPT_ML_KEM_NewCtx();
58 if (ctx == NULL) {
59 return NULL;
60 }
61 ctx->libCtx = libCtx;
62 return ctx;
63 }
64
/*
 * Drop one reference to ctx and release it once no references remain.
 * The decapsulation (private) key is wiped before its buffer is freed. NULL is accepted and ignored.
 */
void CRYPT_ML_KEM_FreeCtx(CRYPT_ML_KEM_Ctx *ctx)
{
    if (ctx == NULL) {
        return;
    }
    int remaining = 0;
    BSL_SAL_AtomicDownReferences(&ctx->references, &remaining);
    if (remaining <= 0) {
        /* Last reference gone: scrub secret material, then tear everything down. */
        BSL_SAL_CleanseData(ctx->dk, ctx->dkLen);
        BSL_SAL_FREE(ctx->dk);
        BSL_SAL_FREE(ctx->ek);
        BSL_SAL_ReferencesFree(&ctx->references);
        BSL_SAL_FREE(ctx);
    }
}
81
/*
 * CRYPT_CTRL_SET_PARA_BY_ID handler: select the ML-KEM parameter set for ctx.
 * val points to an int32_t CRYPT_KEM_TYPE_MLKEM_* id and len must be sizeof(uint32_t).
 * The parameter set can be chosen only once per context.
 * Returns:
 *   CRYPT_INVALID_ARG on a bad length;
 *   CRYPT_MLKEM_CTRL_INIT_REPEATED if info is already set;
 *   CRYPT_NOT_SUPPORT for an unknown id.
 */
static int32_t MlKemSetAlgInfo(CRYPT_ML_KEM_Ctx *ctx, void *val, uint32_t len)
{
    if (len != sizeof(uint32_t)) {
        BSL_ERR_PUSH_ERROR(CRYPT_INVALID_ARG);
        return CRYPT_INVALID_ARG;
    }
    if (ctx->info != NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_MLKEM_CTRL_INIT_REPEATED);
        return CRYPT_MLKEM_CTRL_INIT_REPEATED;
    }

    uint32_t bits;
    switch (*(int32_t *)val) {
        case CRYPT_KEM_TYPE_MLKEM_512:
            bits = 512; // MLKEM512
            break;
        case CRYPT_KEM_TYPE_MLKEM_768:
            bits = 768; // MLKEM768
            break;
        case CRYPT_KEM_TYPE_MLKEM_1024:
            bits = 1024; // MLKEM1024
            break;
        default:
            bits = 0; // unknown id: the table lookup below yields NULL
            break;
    }

    const CRYPT_MlKemInfo *selected = MlKemGetInfo(bits);
    if (selected == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_NOT_SUPPORT);
        return CRYPT_NOT_SUPPORT;
    }
    ctx->info = selected;
    return CRYPT_SUCCESS;
}
109
/*
 * Create an independent copy of an ML-KEM context.
 *
 * info (a pointer into the static ML_KEM_INFO table) and libCtx are shared by reference.
 * The encapsulation and decapsulation key buffers are deep-copied. The copy starts with its own
 * fresh reference count.
 *
 * Returns NULL with an error pushed when ctx is NULL or any allocation fails. A partially built
 * copy is released (and its private key wiped) before returning.
 */
CRYPT_ML_KEM_Ctx *CRYPT_ML_KEM_DupCtx(CRYPT_ML_KEM_Ctx *ctx)
{
    if (ctx == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
        return NULL;
    }
    /* CRYPT_ML_KEM_NewCtx already pushes CRYPT_MEM_ALLOC_FAIL on failure; do not push it twice. */
    CRYPT_ML_KEM_Ctx *newCtx = CRYPT_ML_KEM_NewCtx();
    if (newCtx == NULL) {
        return NULL;
    }
    /*
     * Keep the library context: GenKey/Encaps pass ctx->libCtx to CRYPT_RandEx, so dropping it
     * would silently switch the copy to a different randomness source.
     */
    newCtx->libCtx = ctx->libCtx;
    newCtx->info = ctx->info;

    if (ctx->ek != NULL) {
        newCtx->ek = BSL_SAL_Dump(ctx->ek, ctx->ekLen);
        if (newCtx->ek == NULL) {
            BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
            CRYPT_ML_KEM_FreeCtx(newCtx);
            return NULL;
        }
        newCtx->ekLen = ctx->ekLen;
    }
    if (ctx->dk != NULL) {
        newCtx->dk = BSL_SAL_Dump(ctx->dk, ctx->dkLen);
        if (newCtx->dk == NULL) {
            BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
            CRYPT_ML_KEM_FreeCtx(newCtx);
            return NULL;
        }
        newCtx->dkLen = ctx->dkLen;
    }
    return newCtx;
}
142
/*
 * CRYPT_CTRL_GET_PUBKEY_LEN handler: write the encapsulation-key length of the selected parameter
 * set to *(uint32_t *)val.
 * Fails with CRYPT_MLKEM_KEYINFO_NOT_SET before a parameter set is chosen, or CRYPT_INVALID_ARG when
 * len is not sizeof(uint32_t).
 */
static int32_t MlKemGetEncapsKeyLen(CRYPT_ML_KEM_Ctx *ctx, void *val, uint32_t len)
{
    int32_t err;
    if (ctx->info == NULL) {
        err = CRYPT_MLKEM_KEYINFO_NOT_SET;
    } else if (len != sizeof(uint32_t)) {
        err = CRYPT_INVALID_ARG;
    } else {
        *(uint32_t *)val = ctx->info->encapsKeyLen;
        return CRYPT_SUCCESS;
    }
    BSL_ERR_PUSH_ERROR(err);
    return err;
}
156
/*
 * CRYPT_CTRL_GET_PRVKEY_LEN handler: write the decapsulation-key length of the selected parameter
 * set to *(uint32_t *)val.
 * Fails with CRYPT_MLKEM_KEYINFO_NOT_SET before a parameter set is chosen, or CRYPT_INVALID_ARG when
 * len is not sizeof(uint32_t).
 */
static int32_t MlKemGetDecapsKeyLen(CRYPT_ML_KEM_Ctx *ctx, void *val, uint32_t len)
{
    int32_t err;
    if (ctx->info == NULL) {
        err = CRYPT_MLKEM_KEYINFO_NOT_SET;
    } else if (len != sizeof(uint32_t)) {
        err = CRYPT_INVALID_ARG;
    } else {
        *(uint32_t *)val = ctx->info->decapsKeyLen;
        return CRYPT_SUCCESS;
    }
    BSL_ERR_PUSH_ERROR(err);
    return err;
}
170
/*
 * CRYPT_CTRL_GET_CIPHERTEXT_LEN handler: write the ciphertext length of the selected parameter set
 * to *(uint32_t *)val.
 * Fails with CRYPT_MLKEM_KEYINFO_NOT_SET before a parameter set is chosen, or CRYPT_INVALID_ARG when
 * len is not sizeof(uint32_t).
 */
static int32_t MlKemGetCipherTextLen(CRYPT_ML_KEM_Ctx *ctx, void *val, uint32_t len)
{
    int32_t err;
    if (ctx->info == NULL) {
        err = CRYPT_MLKEM_KEYINFO_NOT_SET;
    } else if (len != sizeof(uint32_t)) {
        err = CRYPT_INVALID_ARG;
    } else {
        *(uint32_t *)val = ctx->info->cipherLen;
        return CRYPT_SUCCESS;
    }
    BSL_ERR_PUSH_ERROR(err);
    return err;
}
184
/*
 * CRYPT_CTRL_GET_SHARED_KEY_LEN handler: write the shared-secret length of the selected parameter
 * set to *(uint32_t *)val.
 * Fails with CRYPT_MLKEM_KEYINFO_NOT_SET before a parameter set is chosen, or CRYPT_INVALID_ARG when
 * len is not sizeof(uint32_t).
 */
static int32_t MlKemGetSharedLen(CRYPT_ML_KEM_Ctx *ctx, void *val, uint32_t len)
{
    int32_t err;
    if (ctx->info == NULL) {
        err = CRYPT_MLKEM_KEYINFO_NOT_SET;
    } else if (len != sizeof(uint32_t)) {
        err = CRYPT_INVALID_ARG;
    } else {
        *(uint32_t *)val = ctx->info->sharedLen;
        return CRYPT_SUCCESS;
    }
    BSL_ERR_PUSH_ERROR(err);
    return err;
}
198
/*
 * Import an encapsulation (public) key from the CRYPT_PARAM_ML_KEM_PUBKEY entry of param.
 * The parameter set must already be selected, and the key length must equal its encapsKeyLen.
 * The bytes are copied, and any previously stored public key is released.
 * Returns:
 *   CRYPT_NULL_INPUT for missing ctx/info/param/key;
 *   CRYPT_MLKEM_KEYLEN_ERROR on a length mismatch;
 *   CRYPT_MEM_ALLOC_FAIL if the copy cannot be allocated.
 */
int32_t CRYPT_ML_KEM_SetEncapsKey(CRYPT_ML_KEM_Ctx *ctx, const BSL_Param *param)
{
    const BSL_Param *pub = NULL;
    if (ctx != NULL && ctx->info != NULL && param != NULL) {
        pub = BSL_PARAM_FindConstParam(param, CRYPT_PARAM_ML_KEM_PUBKEY);
    }
    if (pub == NULL || pub->value == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
        return CRYPT_NULL_INPUT;
    }
    if (pub->valueLen != ctx->info->encapsKeyLen) {
        BSL_ERR_PUSH_ERROR(CRYPT_MLKEM_KEYLEN_ERROR);
        return CRYPT_MLKEM_KEYLEN_ERROR;
    }

    uint8_t *copy = BSL_SAL_Dump(pub->value, pub->valueLen);
    if (copy == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
        return CRYPT_MEM_ALLOC_FAIL;
    }
    /* Only swap the stored key once the new copy exists, so a failure leaves ctx untouched. */
    if (ctx->ek != NULL) {
        BSL_SAL_Free(ctx->ek);
    }
    ctx->ek = copy;
    ctx->ekLen = pub->valueLen;
    return CRYPT_SUCCESS;
}
227
/*
 * Export the stored encapsulation (public) key into the CRYPT_PARAM_ML_KEM_PUBKEY entry of param.
 * The caller's buffer must hold at least encapsKeyLen bytes. On success, useLen is set to the
 * number of bytes written.
 * Returns:
 *   CRYPT_NULL_INPUT for missing ctx/info/param/buffer;
 *   CRYPT_MLKEM_KEYLEN_ERROR when the buffer is too small (or the copy fails);
 *   CRYPT_MLKEM_KEY_NOT_SET when no public key is held.
 */
int32_t CRYPT_ML_KEM_GetEncapsKey(const CRYPT_ML_KEM_Ctx *ctx, BSL_Param *param)
{
    BSL_Param *out = NULL;
    if (ctx != NULL && ctx->info != NULL && param != NULL) {
        out = BSL_PARAM_FindParam(param, CRYPT_PARAM_ML_KEM_PUBKEY);
    }
    if (out == NULL || out->value == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
        return CRYPT_NULL_INPUT;
    }
    if (out->valueLen < ctx->info->encapsKeyLen) {
        BSL_ERR_PUSH_ERROR(CRYPT_MLKEM_KEYLEN_ERROR);
        return CRYPT_MLKEM_KEYLEN_ERROR;
    }
    if (ctx->ek == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_MLKEM_KEY_NOT_SET);
        return CRYPT_MLKEM_KEY_NOT_SET;
    }

    if (memcpy_s(out->value, out->valueLen, ctx->ek, ctx->ekLen) == EOK) {
        out->useLen = ctx->ekLen;
        return CRYPT_SUCCESS;
    }
    BSL_ERR_PUSH_ERROR(CRYPT_MLKEM_KEYLEN_ERROR);
    return CRYPT_MLKEM_KEYLEN_ERROR;
}
255
/*
 * Import a decapsulation (private) key from the CRYPT_PARAM_ML_KEM_PRVKEY entry of param.
 * The parameter set must already be selected, and the key length must equal its decapsKeyLen.
 * The bytes are copied. Any previously stored private key is wiped and released.
 * Returns:
 *   CRYPT_NULL_INPUT for missing ctx/info/param/key;
 *   CRYPT_MLKEM_KEYLEN_ERROR on a length mismatch;
 *   CRYPT_MEM_ALLOC_FAIL if the copy cannot be allocated.
 */
int32_t CRYPT_ML_KEM_SetDecapsKey(CRYPT_ML_KEM_Ctx *ctx, const BSL_Param *param)
{
    const BSL_Param *prv = NULL;
    if (ctx != NULL && ctx->info != NULL && param != NULL) {
        prv = BSL_PARAM_FindConstParam(param, CRYPT_PARAM_ML_KEM_PRVKEY);
    }
    if (prv == NULL || prv->value == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
        return CRYPT_NULL_INPUT;
    }
    if (prv->valueLen != ctx->info->decapsKeyLen) {
        BSL_ERR_PUSH_ERROR(CRYPT_MLKEM_KEYLEN_ERROR);
        return CRYPT_MLKEM_KEYLEN_ERROR;
    }

    uint8_t *copy = BSL_SAL_Dump(prv->value, prv->valueLen);
    if (copy == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
        return CRYPT_MEM_ALLOC_FAIL;
    }
    /* Scrub the old private key before handing its memory back to the allocator. */
    if (ctx->dk != NULL) {
        BSL_SAL_CleanseData(ctx->dk, ctx->dkLen);
        BSL_SAL_Free(ctx->dk);
    }
    ctx->dk = copy;
    ctx->dkLen = prv->valueLen;
    return CRYPT_SUCCESS;
}
285
/*
 * Export the stored decapsulation (private) key into the CRYPT_PARAM_ML_KEM_PRVKEY entry of param.
 * The caller's buffer must hold at least decapsKeyLen bytes. On success, useLen is set to the
 * number of bytes written.
 * Returns:
 *   CRYPT_NULL_INPUT for missing ctx/info/param/buffer;
 *   CRYPT_MLKEM_KEYLEN_ERROR when the buffer is too small (or the copy fails);
 *   CRYPT_MLKEM_KEY_NOT_SET when no private key is held.
 */
int32_t CRYPT_ML_KEM_GetDecapsKey(const CRYPT_ML_KEM_Ctx *ctx, BSL_Param *param)
{
    BSL_Param *out = NULL;
    if (ctx != NULL && ctx->info != NULL && param != NULL) {
        out = BSL_PARAM_FindParam(param, CRYPT_PARAM_ML_KEM_PRVKEY);
    }
    if (out == NULL || out->value == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
        return CRYPT_NULL_INPUT;
    }
    if (out->valueLen < ctx->info->decapsKeyLen) {
        BSL_ERR_PUSH_ERROR(CRYPT_MLKEM_KEYLEN_ERROR);
        return CRYPT_MLKEM_KEYLEN_ERROR;
    }
    if (ctx->dk == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_MLKEM_KEY_NOT_SET);
        return CRYPT_MLKEM_KEY_NOT_SET;
    }

    if (memcpy_s(out->value, out->valueLen, ctx->dk, ctx->dkLen) == EOK) {
        out->useLen = ctx->dkLen;
        return CRYPT_SUCCESS;
    }
    BSL_ERR_PUSH_ERROR(CRYPT_MLKEM_KEYLEN_ERROR);
    return CRYPT_MLKEM_KEYLEN_ERROR;
}
313
MlKemCmpKey(uint8_t * a,uint32_t aLen,uint8_t * b,uint32_t bLen)314 static int32_t MlKemCmpKey(uint8_t *a, uint32_t aLen, uint8_t *b, uint32_t bLen)
315 {
316 if (aLen != bLen) {
317 BSL_ERR_PUSH_ERROR(CRYPT_MLKEM_KEY_NOT_EQUAL);
318 return CRYPT_MLKEM_KEY_NOT_EQUAL;
319 }
320 if (a != NULL && b != NULL) {
321 if (memcmp(a, b, aLen) != 0) {
322 BSL_ERR_PUSH_ERROR(CRYPT_MLKEM_KEY_NOT_EQUAL);
323 return CRYPT_MLKEM_KEY_NOT_EQUAL;
324 }
325 }
326 if ((a != NULL) != (b != NULL)) {
327 BSL_ERR_PUSH_ERROR(CRYPT_MLKEM_KEY_NOT_EQUAL);
328 return CRYPT_MLKEM_KEY_NOT_EQUAL;
329 }
330 return CRYPT_SUCCESS;
331 }
332
CRYPT_ML_KEM_Cmp(const CRYPT_ML_KEM_Ctx * a,const CRYPT_ML_KEM_Ctx * b)333 int32_t CRYPT_ML_KEM_Cmp(const CRYPT_ML_KEM_Ctx *a, const CRYPT_ML_KEM_Ctx *b)
334 {
335 if (a == NULL || b == NULL) {
336 BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
337 return CRYPT_NULL_INPUT;
338 }
339 if (a->info != b->info) { // The value of info must be one of the ML_KEM_INFO arrays.
340 BSL_ERR_PUSH_ERROR(CRYPT_MLKEM_KEY_NOT_EQUAL);
341 return CRYPT_MLKEM_KEY_NOT_EQUAL;
342 }
343
344 if (MlKemCmpKey(a->ek, a->ekLen, b->ek, b->ekLen) != CRYPT_SUCCESS) {
345 return CRYPT_MLKEM_KEY_NOT_EQUAL;
346 }
347 if (MlKemCmpKey(a->dk, a->dkLen, b->dk, b->dkLen) != CRYPT_SUCCESS) {
348 return CRYPT_MLKEM_KEY_NOT_EQUAL;
349 }
350 return CRYPT_SUCCESS;
351 }
352
CRYPT_ML_KEM_GetSecBits(const CRYPT_ML_KEM_Ctx * ctx)353 int32_t CRYPT_ML_KEM_GetSecBits(const CRYPT_ML_KEM_Ctx *ctx)
354 {
355 if (ctx == NULL || ctx->info == NULL) {
356 return 0;
357 }
358 return (int32_t)ctx->info->secBits;
359 }
360
/*
 * Generic int32_t query helper: store func(ctx) into *(int32_t *)val.
 * Fails with CRYPT_NULL_INPUT when val is NULL or len is not sizeof(int32_t).
 */
static int32_t CRYPT_ML_KEM_GetLen(const CRYPT_ML_KEM_Ctx *ctx, GetLenFunc func, void *val, uint32_t len)
{
    if (val != NULL && len == sizeof(int32_t)) {
        int32_t *out = (int32_t *)val;
        *out = func(ctx);
        return CRYPT_SUCCESS;
    }
    BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
    return CRYPT_NULL_INPUT;
}
371
/*
 * Control entry point for ML-KEM contexts.
 *
 * CRYPT_CTRL_GET_SECBITS is dispatched first; CRYPT_ML_KEM_GetLen validates val/len for it.
 * Every other option requires a non-NULL val and is routed to its handler.
 * Unknown options yield CRYPT_MLKEM_CTRL_NOT_SUPPORT.
 */
int32_t CRYPT_ML_KEM_Ctrl(CRYPT_ML_KEM_Ctx *ctx, int32_t opt, void *val, uint32_t len)
{
    if (ctx == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
        return CRYPT_NULL_INPUT;
    }
    if (opt == CRYPT_CTRL_GET_SECBITS) {
        return CRYPT_ML_KEM_GetLen(ctx, (GetLenFunc)CRYPT_ML_KEM_GetSecBits, val, len);
    }
    if (val == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
        return CRYPT_NULL_INPUT;
    }

    int32_t result;
    switch (opt) {
        case CRYPT_CTRL_SET_PARA_BY_ID:
            result = MlKemSetAlgInfo(ctx, val, len);
            break;
        case CRYPT_CTRL_GET_PUBKEY_LEN:
            result = MlKemGetEncapsKeyLen(ctx, val, len);
            break;
        case CRYPT_CTRL_GET_PRVKEY_LEN:
            result = MlKemGetDecapsKeyLen(ctx, val, len);
            break;
        case CRYPT_CTRL_GET_CIPHERTEXT_LEN:
            result = MlKemGetCipherTextLen(ctx, val, len);
            break;
        case CRYPT_CTRL_GET_SHARED_KEY_LEN:
            result = MlKemGetSharedLen(ctx, val, len);
            break;
        default:
            BSL_ERR_PUSH_ERROR(CRYPT_MLKEM_CTRL_NOT_SUPPORT);
            result = CRYPT_MLKEM_CTRL_NOT_SUPPORT;
            break;
    }
    return result;
}
401
/*
 * Ensure ctx owns decapsulation- and encapsulation-key buffers sized for the selected parameter set.
 *
 * Buffers that already exist are reused. Newly allocated buffers are left uninitialised, and the
 * caller is expected to fill them (CRYPT_ML_KEM_GenKey passes ctx to MLKEM_KeyGenInternal right
 * after this).
 *
 * Both allocations are completed before either is attached to ctx. If the second allocation fails,
 * the first is released, so ctx never ends up holding an uninitialised dk that GetDecapsKey or
 * Decaps would treat as a real key.
 *
 * Precondition: ctx->info != NULL.
 * Returns CRYPT_SUCCESS or CRYPT_MEM_ALLOC_FAIL (pushed).
 */
static int32_t MlKemCreateKeyBuf(CRYPT_ML_KEM_Ctx *ctx)
{
    uint8_t *newDk = NULL;
    uint8_t *newEk = NULL;

    if (ctx->dk == NULL) {
        newDk = BSL_SAL_Malloc(ctx->info->decapsKeyLen);
        if (newDk == NULL) {
            BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
            return CRYPT_MEM_ALLOC_FAIL;
        }
    }
    if (ctx->ek == NULL) {
        newEk = BSL_SAL_Malloc(ctx->info->encapsKeyLen);
        if (newEk == NULL) {
            if (newDk != NULL) {
                BSL_SAL_Free(newDk);
            }
            BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
            return CRYPT_MEM_ALLOC_FAIL;
        }
    }

    if (newDk != NULL) {
        ctx->dk = newDk;
        ctx->dkLen = ctx->info->decapsKeyLen;
    }
    if (newEk != NULL) {
        ctx->ek = newEk;
        ctx->ekLen = ctx->info->encapsKeyLen;
    }
    return CRYPT_SUCCESS;
}
424
/*
 * Generate a new ML-KEM key pair into ctx for the parameter set chosen via CRYPT_CTRL_SET_PARA_BY_ID.
 *
 * Two MLKEM_SEED_LEN-byte seeds (d, z) are drawn with CRYPT_RandEx from ctx->libCtx and handed to
 * MLKEM_KeyGenInternal together with ctx. That routine is expected to fill ctx->ek / ctx->dk, which
 * MlKemCreateKeyBuf sizes beforehand.
 *
 * Ordering and hygiene:
 *  - The seeds are drawn before any key buffer is allocated, so an RNG failure cannot leave
 *    uninitialised key buffers attached to ctx.
 *  - Both seeds are wiped on every exit path, including when the second RNG call fails.
 *
 * Returns:
 *   CRYPT_SUCCESS;
 *   CRYPT_NULL_INPUT or CRYPT_MLKEM_KEYINFO_NOT_SET;
 *   the RNG error code;
 *   CRYPT_MEM_ALLOC_FAIL;
 *   or the result of MLKEM_KeyGenInternal.
 */
int32_t CRYPT_ML_KEM_GenKey(CRYPT_ML_KEM_Ctx *ctx)
{
    if (ctx == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
        return CRYPT_NULL_INPUT;
    }
    if (ctx->info == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_MLKEM_KEYINFO_NOT_SET);
        return CRYPT_MLKEM_KEYINFO_NOT_SET;
    }

    uint8_t d[MLKEM_SEED_LEN] = {0};
    uint8_t z[MLKEM_SEED_LEN] = {0};
    int32_t ret = CRYPT_RandEx(ctx->libCtx, d, MLKEM_SEED_LEN);
    if (ret == CRYPT_SUCCESS) {
        ret = CRYPT_RandEx(ctx->libCtx, z, MLKEM_SEED_LEN);
    }
    if (ret != CRYPT_SUCCESS) {
        BSL_ERR_PUSH_ERROR(ret);
    } else {
        ret = MlKemCreateKeyBuf(ctx); // pushes its own error on failure
        if (ret == CRYPT_SUCCESS) {
            ret = MLKEM_KeyGenInternal(ctx, d, z);
        }
    }

    BSL_SAL_CleanseData(d, MLKEM_SEED_LEN);
    BSL_SAL_CleanseData(z, MLKEM_SEED_LEN);
    return ret;
}
450
EncCapsInputCheck(const CRYPT_ML_KEM_Ctx * ctx,uint8_t * ct,uint32_t * ctLen,uint8_t * sk,uint32_t * skLen)451 static int32_t EncCapsInputCheck(const CRYPT_ML_KEM_Ctx *ctx, uint8_t *ct, uint32_t *ctLen,
452 uint8_t *sk, uint32_t *skLen)
453 {
454 if (ctx == NULL || ctx->ek == NULL || ct == NULL || ctLen == NULL ||
455 sk == NULL || skLen == NULL) {
456 return CRYPT_NULL_INPUT;
457 }
458 if (ctx->info == NULL) {
459 return CRYPT_MLKEM_KEYINFO_NOT_SET;
460 }
461 if (*ctLen < ctx->info->cipherLen || *skLen < MLKEM_SHARED_KEY_LEN) {
462 return CRYPT_MLKEM_LEN_NOT_ENOUGH;
463 }
464 return CRYPT_SUCCESS;
465 }
466
CRYPT_ML_KEM_Encaps(const CRYPT_ML_KEM_Ctx * ctx,uint8_t * cipher,uint32_t * cipherLen,uint8_t * share,uint32_t * shareLen)467 int32_t CRYPT_ML_KEM_Encaps(const CRYPT_ML_KEM_Ctx *ctx, uint8_t *cipher, uint32_t *cipherLen,
468 uint8_t *share, uint32_t *shareLen)
469 {
470 int32_t ret = EncCapsInputCheck(ctx, cipher, cipherLen, share, shareLen);
471 RETURN_RET_IF(ret != CRYPT_SUCCESS, ret);
472
473 uint8_t m[MLKEM_SEED_LEN];
474 ret = CRYPT_RandEx(ctx->libCtx, m, MLKEM_SEED_LEN);
475 RETURN_RET_IF(ret != CRYPT_SUCCESS, ret);
476
477 ret = MLKEM_EncapsInternal(ctx, cipher, cipherLen, share, shareLen, m);
478 BSL_SAL_CleanseData(m, MLKEM_SEED_LEN);
479 return ret;
480 }
481
DecCapsInputCheck(const CRYPT_ML_KEM_Ctx * ctx,uint8_t * ct,uint32_t ctLen,uint8_t * sk,uint32_t * skLen)482 static int32_t DecCapsInputCheck(const CRYPT_ML_KEM_Ctx *ctx, uint8_t *ct, uint32_t ctLen,
483 uint8_t *sk, uint32_t *skLen)
484 {
485 if (ctx == NULL || ctx->dk == NULL || ct == NULL || sk == NULL || skLen == NULL) {
486 return CRYPT_NULL_INPUT;
487 }
488 if (ctx->info == NULL) {
489 return CRYPT_MLKEM_KEYINFO_NOT_SET;
490 }
491 if (ctLen != ctx->info->cipherLen || *skLen < MLKEM_SHARED_KEY_LEN) {
492 return CRYPT_MLKEM_LEN_NOT_ENOUGH;
493 }
494 return CRYPT_SUCCESS;
495 }
496
CRYPT_ML_KEM_Decaps(const CRYPT_ML_KEM_Ctx * ctx,uint8_t * cipher,uint32_t cipherLen,uint8_t * share,uint32_t * shareLen)497 int32_t CRYPT_ML_KEM_Decaps(const CRYPT_ML_KEM_Ctx *ctx, uint8_t *cipher, uint32_t cipherLen,
498 uint8_t *share, uint32_t *shareLen)
499 {
500 int32_t ret = DecCapsInputCheck(ctx, cipher, cipherLen, share, shareLen);
501 RETURN_RET_IF(ret != CRYPT_SUCCESS, ret);
502
503 return MLKEM_DecapsInternal(ctx, cipher, cipherLen, share, shareLen);
504 }
505
506 #endif