1 /*
2 * This file is part of the openHiTLS project.
3 *
4 * openHiTLS is licensed under the Mulan PSL v2.
5 * You can use this software according to the terms and conditions of the Mulan PSL v2.
6 * You may obtain a copy of Mulan PSL v2 at:
7 *
8 * http://license.coscl.org.cn/MulanPSL2
9 *
10 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
11 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
12 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
13 * See the Mulan PSL v2 for more details.
14 */
15 #include "hitls_build.h"
16 #ifdef HITLS_CRYPTO_SLH_DSA
17
18 #include <stddef.h>
19 #include "securec.h"
20 #include "bsl_err_internal.h"
21 #include "bsl_sal.h"
22 #include "bsl_obj_internal.h"
23 #include "bsl_asn1.h"
24 #include "crypt_errno.h"
25 #include "crypt_util_rand.h"
26 #include "eal_md_local.h"
27 #include "crypt_slh_dsa.h"
28 #include "slh_dsa_local.h"
29 #include "slh_dsa_hash.h"
30 #include "slh_dsa_fors.h"
31 #include "slh_dsa_xmss.h"
32 #include "slh_dsa_hypertree.h"
33
// Size of the local pre-hash buffer in MsgEncode (largest digest handled there), in bytes.
#define MAX_DIGEST_SIZE 64
// Bits per byte, used when converting between bit and byte lengths.
#define BYTE_BITS 8
// Length of the two-byte prefix (pre-hash flag byte, context-length byte) that MsgEncode puts before ctx.
#define SLH_DSA_PREFIX_LEN 2
// DER header size (tag byte + length byte) for the pre-hash hash OID encoding.
#define ASN1_HEADER_LEN 2
// Integer ceiling division of a by b; note that b is evaluated twice.
#define SPLIT_CEIL(a, b) (((a) + (b) - 1) / (b))
// Number of bytes needed to hold a bits.
#define SPLIT_BYTES(a) SPLIT_CEIL(a, BYTE_BITS)
40
/* Entries located in a caller's BSL_Param array for public-key get/set; filled by PubKeyParamCheck. */
typedef struct {
    BSL_Param *pubSeed; // CRYPT_PARAM_SLH_DSA_PUB_SEED entry (PK.seed)
    BSL_Param *pubRoot; // CRYPT_PARAM_SLH_DSA_PUB_ROOT entry (PK.root)
} SlhDsaPubKeyParam;

/* Entries located in a caller's BSL_Param array for private-key get/set; filled by PrvKeyParamCheck. */
typedef struct {
    BSL_Param *prvSeed; // CRYPT_PARAM_SLH_DSA_PRV_SEED entry (SK.seed)
    BSL_Param *prvPrf;  // CRYPT_PARAM_SLH_DSA_PRV_PRF entry (SK.prf)
    BSL_Param *pubSeed; // CRYPT_PARAM_SLH_DSA_PUB_SEED entry (PK.seed)
    BSL_Param *pubRoot; // CRYPT_PARAM_SLH_DSA_PUB_ROOT entry (PK.root)
} SlhDsaPrvKeyParam;
52
53 // reference to FIPS-205, table 2
54 static uint32_t g_slhDsaN[CRYPT_SLH_DSA_ALG_ID_MAX] = {16, 16, 16, 16, 24, 24, 24, 24, 32, 32, 32, 32};
55 static uint32_t g_slhDsaH[CRYPT_SLH_DSA_ALG_ID_MAX] = {63, 63, 66, 66, 63, 63, 66, 66, 64, 64, 68, 68};
56 static uint32_t g_slhDsaD[CRYPT_SLH_DSA_ALG_ID_MAX] = {7, 7, 22, 22, 7, 7, 22, 22, 8, 8, 17, 17};
57 static uint32_t g_slhDsaHp[CRYPT_SLH_DSA_ALG_ID_MAX] = {9, 9, 3, 3, 9, 9, 3, 3, 8, 8, 4, 4}; // xmss height
58 static uint32_t g_slhDsaA[CRYPT_SLH_DSA_ALG_ID_MAX] = {12, 12, 6, 6, 14, 14, 8, 8, 14, 14, 9, 9};
59 static uint32_t g_slhDsaK[CRYPT_SLH_DSA_ALG_ID_MAX] = {14, 14, 33, 33, 17, 17, 33, 33, 22, 22, 35, 35};
60 static uint32_t g_slhDsaM[CRYPT_SLH_DSA_ALG_ID_MAX] = {30, 30, 34, 34, 39, 39, 42, 42, 47, 47, 49, 49};
61 static uint32_t g_slhDsaPkBytes[CRYPT_SLH_DSA_ALG_ID_MAX] = {32, 32, 32, 32, 48, 48, 48, 48, 64, 64, 64, 64};
62 static uint32_t g_slhDsaSigBytes[CRYPT_SLH_DSA_ALG_ID_MAX] = {7856, 7856, 17088, 17088, 16224, 16224,
63 35664, 35664, 29792, 29792, 49856, 49856};
64 static uint8_t g_secCategory[] = {1, 1, 1, 1, 3, 3, 3, 3, 5, 5, 5, 5};
65
66 // "UC" means uncompressed
UCAdrsSetLayerAddr(SlhDsaAdrs * adrs,uint32_t layer)67 static void UCAdrsSetLayerAddr(SlhDsaAdrs *adrs, uint32_t layer)
68 {
69 PUT_UINT32_BE(layer, adrs->uc.layerAddr, 0);
70 }
71
UCAdrsSetTreeAddr(SlhDsaAdrs * adrs,uint64_t tree)72 static void UCAdrsSetTreeAddr(SlhDsaAdrs *adrs, uint64_t tree)
73 {
74 // Write 8-byte tree address starting from offset 4 in 12-byte treeAddr field
75 PUT_UINT64_BE(tree, adrs->uc.treeAddr, 4);
76 }
77
UCAdrsSetType(SlhDsaAdrs * adrs,AdrsType type)78 static void UCAdrsSetType(SlhDsaAdrs *adrs, AdrsType type)
79 {
80 PUT_UINT32_BE(type, adrs->uc.type, 0);
81 (void)memset_s(adrs->uc.padding, sizeof(adrs->uc.padding), 0, sizeof(adrs->uc.padding));
82 }
83
UCAdrsSetKeyPairAddr(SlhDsaAdrs * adrs,uint32_t keyPair)84 static void UCAdrsSetKeyPairAddr(SlhDsaAdrs *adrs, uint32_t keyPair)
85 {
86 PUT_UINT32_BE(keyPair, adrs->uc.padding, 0);
87 }
88
UCAdrsSetChainAddr(SlhDsaAdrs * adrs,uint32_t chain)89 static void UCAdrsSetChainAddr(SlhDsaAdrs *adrs, uint32_t chain)
90 {
91 PUT_UINT32_BE(chain, adrs->uc.padding, 4); // chain address is 4 bytes, start from 4-th byte
92 }
93
UCAdrsSetTreeHeight(SlhDsaAdrs * adrs,uint32_t height)94 static void UCAdrsSetTreeHeight(SlhDsaAdrs *adrs, uint32_t height)
95 {
96 PUT_UINT32_BE(height, adrs->uc.padding, 4); // tree height is 4 bytes, start from 4-th byte
97 }
98
UCAdrsSetHashAddr(SlhDsaAdrs * adrs,uint32_t hash)99 static void UCAdrsSetHashAddr(SlhDsaAdrs *adrs, uint32_t hash)
100 {
101 PUT_UINT32_BE(hash, adrs->uc.padding, 8); // hash address is 4 bytes, start from 8-th byte
102 }
103
UCAdrsSetTreeIndex(SlhDsaAdrs * adrs,uint32_t index)104 static void UCAdrsSetTreeIndex(SlhDsaAdrs *adrs, uint32_t index)
105 {
106 PUT_UINT32_BE(index, adrs->uc.padding, 8); // tree index is 4 bytes, start from 8-th byte
107 }
108
UCAdrsGetTreeHeight(const SlhDsaAdrs * adrs)109 static uint32_t UCAdrsGetTreeHeight(const SlhDsaAdrs *adrs)
110 {
111 return GET_UINT32_BE(adrs->uc.padding, 0);
112 }
113
UCAdrsGetTreeIndex(const SlhDsaAdrs * adrs)114 static uint32_t UCAdrsGetTreeIndex(const SlhDsaAdrs *adrs)
115 {
116 return GET_UINT32_BE(adrs->uc.padding, 8); // tree index is 4 bytes, start from 8-th byte
117 }
118
/*
 * Copy the key pair address of adrs2 into adrs (uncompressed layout). The key pair address is the first
 * 4 bytes of the trailing field; the rest of adrs is left untouched.
 */
static void UCAdrsCopyKeyPairAddr(SlhDsaAdrs *adrs, const SlhDsaAdrs *adrs2)
{
    (void)memcpy_s(adrs->uc.padding, sizeof(adrs->uc.padding), adrs2->uc.padding,
        4); // key pair address is 4 bytes, start from 0-th byte
}
124
UCAdrsGetAdrsLen(void)125 static uint32_t UCAdrsGetAdrsLen(void)
126 {
127 return SLH_DSA_ADRS_LEN;
128 }
129
130 // "C" means compressed
CAdrsSetLayerAddr(SlhDsaAdrs * adrs,uint32_t layer)131 static void CAdrsSetLayerAddr(SlhDsaAdrs *adrs, uint32_t layer)
132 {
133 adrs->c.layerAddr = (uint8_t)layer;
134 }
135
CAdrsSetTreeAddr(SlhDsaAdrs * adrs,uint64_t tree)136 static void CAdrsSetTreeAddr(SlhDsaAdrs *adrs, uint64_t tree)
137 {
138 // Write 8-byte tree address starting from offset 0 in 8-byte treeAddr field
139 PUT_UINT64_BE(tree, adrs->c.treeAddr, 0);
140 }
141
CAdrsSetType(SlhDsaAdrs * adrs,AdrsType type)142 static void CAdrsSetType(SlhDsaAdrs *adrs, AdrsType type)
143 {
144 adrs->c.type = type;
145 (void)memset_s(adrs->c.padding, sizeof(adrs->c.padding), 0, sizeof(adrs->c.padding));
146 }
147
CAdrsSetKeyPairAddr(SlhDsaAdrs * adrs,uint32_t keyPair)148 static void CAdrsSetKeyPairAddr(SlhDsaAdrs *adrs, uint32_t keyPair)
149 {
150 PUT_UINT32_BE(keyPair, adrs->c.padding, 0);
151 }
152
CAdrsSetChainAddr(SlhDsaAdrs * adrs,uint32_t chain)153 static void CAdrsSetChainAddr(SlhDsaAdrs *adrs, uint32_t chain)
154 {
155 PUT_UINT32_BE(chain, adrs->c.padding, 4); // chain address is 4 bytes, start from 4-th byte
156 }
157
CAdrsSetTreeHeight(SlhDsaAdrs * adrs,uint32_t height)158 static void CAdrsSetTreeHeight(SlhDsaAdrs *adrs, uint32_t height)
159 {
160 PUT_UINT32_BE(height, adrs->c.padding, 4); // tree height is 4 bytes, start from 4-th byte
161 }
162
CAdrsSetHashAddr(SlhDsaAdrs * adrs,uint32_t hash)163 static void CAdrsSetHashAddr(SlhDsaAdrs *adrs, uint32_t hash)
164 {
165 PUT_UINT32_BE(hash, adrs->c.padding, 8); // hash address is 4 bytes, start from 8-th byte
166 }
167
CAdrsSetTreeIndex(SlhDsaAdrs * adrs,uint32_t index)168 static void CAdrsSetTreeIndex(SlhDsaAdrs *adrs, uint32_t index)
169 {
170 PUT_UINT32_BE(index, adrs->c.padding, 8); // tree index is 4 bytes, start from 8-th byte
171 }
172
CAdrsGetTreeHeight(const SlhDsaAdrs * adrs)173 static uint32_t CAdrsGetTreeHeight(const SlhDsaAdrs *adrs)
174 {
175 return GET_UINT32_BE(adrs->c.padding, 0); // tree height is 4 bytes, start from 0-th byte
176 }
177
CAdrsGetTreeIndex(const SlhDsaAdrs * adrs)178 static uint32_t CAdrsGetTreeIndex(const SlhDsaAdrs *adrs)
179 {
180 return GET_UINT32_BE(adrs->c.padding, 8); // tree index is 4 bytes, start from 8-th byte
181 }
182
/*
 * Copy the key pair address of adrs2 into adrs (compressed layout). The key pair address is the first
 * 4 bytes of the trailing field; the rest of adrs is left untouched.
 */
static void CAdrsCopyKeyPairAddr(SlhDsaAdrs *adrs, const SlhDsaAdrs *adrs2)
{
    (void)memcpy_s(adrs->c.padding, sizeof(adrs->c.padding), adrs2->c.padding,
        4); // key pair address is 4 bytes, start from 0-th byte
}
188
CAdrsGetAdrsLen(void)189 static uint32_t CAdrsGetAdrsLen(void)
190 {
191 return SLH_DSA_ADRS_COMPRESSED_LEN;
192 }
193
194 static AdrsOps g_adrsOps[2] = {{
195 .setLayerAddr = UCAdrsSetLayerAddr,
196 .setTreeAddr = UCAdrsSetTreeAddr,
197 .setType = UCAdrsSetType,
198 .setKeyPairAddr = UCAdrsSetKeyPairAddr,
199 .setChainAddr = UCAdrsSetChainAddr,
200 .setTreeHeight = UCAdrsSetTreeHeight,
201 .setHashAddr = UCAdrsSetHashAddr,
202 .setTreeIndex = UCAdrsSetTreeIndex,
203 .getTreeHeight = UCAdrsGetTreeHeight,
204 .getTreeIndex = UCAdrsGetTreeIndex,
205 .copyKeyPairAddr = UCAdrsCopyKeyPairAddr,
206 .getAdrsLen = UCAdrsGetAdrsLen,
207 },
208 {
209 .setLayerAddr = CAdrsSetLayerAddr,
210 .setTreeAddr = CAdrsSetTreeAddr,
211 .setType = CAdrsSetType,
212 .setKeyPairAddr = CAdrsSetKeyPairAddr,
213 .setChainAddr = CAdrsSetChainAddr,
214 .setTreeHeight = CAdrsSetTreeHeight,
215 .setHashAddr = CAdrsSetHashAddr,
216 .setTreeIndex = CAdrsSetTreeIndex,
217 .getTreeHeight = CAdrsGetTreeHeight,
218 .getTreeIndex = CAdrsGetTreeIndex,
219 .copyKeyPairAddr = CAdrsCopyKeyPairAddr,
220 .getAdrsLen = CAdrsGetAdrsLen,
221 }};
222
BaseB(const uint8_t * x,uint32_t xLen,uint32_t b,uint32_t * out,uint32_t outLen)223 void BaseB(const uint8_t *x, uint32_t xLen, uint32_t b, uint32_t *out, uint32_t outLen)
224 {
225 uint32_t bit = 0;
226 uint32_t o = 0;
227 uint32_t xi = 0;
228 for (uint32_t i = 0; i < outLen; i++) {
229 while (bit < b && xi < xLen) {
230 o = (o << BYTE_BITS) + x[xi];
231 bit += 8;
232 xi++;
233 }
234 bit -= b;
235 out[i] = o >> bit;
236 // keep the remaining bits
237 o &= (1 << bit) - 1;
238 }
239 }
240
241 // ToInt(b[0:l]) mod 2^m
ToIntMod(const uint8_t * b,uint32_t l,uint32_t m)242 static uint64_t ToIntMod(const uint8_t *b, uint32_t l, uint32_t m)
243 {
244 uint64_t ret = 0;
245 for (uint32_t i = 0; i < l; i++) {
246 ret = (ret << BYTE_BITS) + b[i];
247 }
248
249 return ret & (~(uint64_t)0 >> (64 - m)); // mod 2^m is same to ~(uint64_t)0 >> (64 - m)
250 }
251
CRYPT_SLH_DSA_NewCtx(void)252 CryptSlhDsaCtx *CRYPT_SLH_DSA_NewCtx(void)
253 {
254 CryptSlhDsaCtx *ctx = (CryptSlhDsaCtx *)BSL_SAL_Calloc(sizeof(CryptSlhDsaCtx), 1);
255 if (ctx == NULL) {
256 BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
257 return NULL;
258 }
259 ctx->para.algId = CRYPT_SLH_DSA_ALG_ID_MAX;
260 ctx->isPrehash = false;
261 ctx->isDeterministic = false;
262 return ctx;
263 }
264
CRYPT_SLH_DSA_NewCtxEx(void * libCtx)265 CryptSlhDsaCtx *CRYPT_SLH_DSA_NewCtxEx(void *libCtx)
266 {
267 CryptSlhDsaCtx *ctx = CRYPT_SLH_DSA_NewCtx();
268 if (ctx == NULL) {
269 return NULL;
270 }
271 ctx->libCtx = libCtx;
272 return ctx;
273 }
274
/*
 * Release a context and everything it owns. The additional-randomness buffer and the secret seed and PRF key
 * are wiped before the memory is returned. A NULL ctx is ignored.
 */
void CRYPT_SLH_DSA_FreeCtx(CryptSlhDsaCtx *ctx)
{
    if (ctx != NULL) {
        BSL_SAL_Free(ctx->context);
        BSL_SAL_ClearFree(ctx->addrand, ctx->addrandLen);
        BSL_SAL_CleanseData(ctx->prvKey.seed, sizeof(ctx->prvKey.seed));
        BSL_SAL_CleanseData(ctx->prvKey.prf, sizeof(ctx->prvKey.prf));
        BSL_SAL_Free(ctx);
    }
}
286
/*
 * Generate a fresh key pair for the selected parameter set.
 * SK.seed, SK.prf and PK.seed are drawn from CRYPT_RandEx. PK.root is the root of the top-layer XMSS tree
 * (layer d - 1), as in FIPS-205 Algorithm 18.
 * Returns CRYPT_NULL_INPUT or CRYPT_SLHDSA_ERR_INVALID_ALGID for bad input, otherwise the first error from the
 * RNG or XmssNode. On failure the partially generated secret key is wiped, so the context never holds new
 * secret material paired with a stale root.
 */
int32_t CRYPT_SLH_DSA_Gen(CryptSlhDsaCtx *ctx)
{
    if (ctx == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
        return CRYPT_NULL_INPUT;
    }
    if (ctx->para.algId >= CRYPT_SLH_DSA_ALG_ID_MAX) {
        BSL_ERR_PUSH_ERROR(CRYPT_SLHDSA_ERR_INVALID_ALGID);
        return CRYPT_SLHDSA_ERR_INVALID_ALGID;
    }
    uint32_t n = ctx->para.n;
    SlhDsaAdrs adrs = {0};
    uint8_t node[SLH_DSA_MAX_N] = {0};

    int32_t ret = CRYPT_RandEx(ctx->libCtx, ctx->prvKey.seed, n);
    if (ret != CRYPT_SUCCESS) {
        goto ERR;
    }
    ret = CRYPT_RandEx(ctx->libCtx, ctx->prvKey.prf, n);
    if (ret != CRYPT_SUCCESS) {
        goto ERR;
    }
    ret = CRYPT_RandEx(ctx->libCtx, ctx->prvKey.pub.seed, n);
    if (ret != CRYPT_SUCCESS) {
        goto ERR;
    }

    ctx->adrsOps.setLayerAddr(&adrs, ctx->para.d - 1);
    ret = XmssNode(node, 0, ctx->para.hp, &adrs, ctx);
    if (ret != CRYPT_SUCCESS) {
        goto ERR;
    }
    (void)memcpy_s(ctx->prvKey.pub.root, n, node, n);
    return CRYPT_SUCCESS;

ERR:
    BSL_SAL_CleanseData(ctx->prvKey.seed, sizeof(ctx->prvKey.seed));
    BSL_SAL_CleanseData(ctx->prvKey.prf, sizeof(ctx->prvKey.prf));
    BSL_ERR_PUSH_ERROR(ret);
    return ret;
}
329
/*
 * Make ctx->addrand hold the n-byte opt_rand consumed by PRF_msg (FIPS-205 Algorithm 19, line 2).
 * - A buffer that is already present and exactly n bytes long (e.g. installed with
 *   CRYPT_CTRL_SET_SLH_DSA_ADDRAND) is used as-is.
 * - A present buffer of any other length (e.g. sized for a previously selected parameter set) is wiped and
 *   replaced, rather than being read past its end.
 * - A new buffer is a copy of PK.seed in deterministic mode, and fresh random bytes otherwise.
 * If random generation fails, the buffer is wiped and not stored, so a later call cannot pick up
 * uninitialised bytes.
 */
static int32_t GetAddRand(CryptSlhDsaCtx *ctx)
{
    uint32_t n = ctx->para.n;
    if (ctx->addrand != NULL) {
        if (ctx->addrandLen == n) {
            // the additional rand is set.
            return CRYPT_SUCCESS;
        }
        BSL_SAL_ClearFree(ctx->addrand, ctx->addrandLen);
        ctx->addrand = NULL;
        ctx->addrandLen = 0;
    }

    uint8_t *rand = (uint8_t *)BSL_SAL_Malloc(n);
    if (rand == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
        return CRYPT_MEM_ALLOC_FAIL;
    }
    if (ctx->isDeterministic) {
        // Deterministic variant: opt_rand = PK.seed.
        (void)memcpy_s(rand, n, ctx->prvKey.pub.seed, n);
    } else {
        int32_t ret = CRYPT_RandEx(ctx->libCtx, rand, n);
        if (ret != CRYPT_SUCCESS) {
            BSL_SAL_ClearFree(rand, n);
            BSL_ERR_PUSH_ERROR(ret);
            return ret;
        }
    }
    ctx->addrand = rand;
    ctx->addrandLen = n;
    return CRYPT_SUCCESS;
}
360
GetTreeAndLeafIdx(const uint8_t * digest,const CryptSlhDsaCtx * ctx,uint64_t * treeIdx,uint32_t * leafIdx)361 static void GetTreeAndLeafIdx(const uint8_t *digest, const CryptSlhDsaCtx *ctx, uint64_t *treeIdx, uint32_t *leafIdx)
362 {
363 uint32_t a = ctx->para.a;
364 uint32_t k = ctx->para.k;
365 uint32_t h = ctx->para.h;
366 uint32_t d = ctx->para.d;
367
368 uint32_t mdIdx = SPLIT_BYTES(k * a);
369 uint32_t treeIdxLen = SPLIT_BYTES(h - h / d);
370 uint32_t leafIdxLen = SPLIT_BYTES(h / d);
371 *treeIdx = ToIntMod(digest + mdIdx, treeIdxLen, h - h / d);
372 *leafIdx = (uint32_t)ToIntMod(digest + mdIdx + treeIdxLen, leafIdxLen, h / d);
373 }
374
/*
 * Compute R || SIG_FORS || SIG_HT into sig, using the opt_rand already stored in ctx->addrand
 * (FIPS-205 Algorithm 19).
 * Precondition: *sigLen >= ctx->para.sigBytes (checked by the caller). On entry *sigLen is the capacity of sig;
 * on success it is the number of bytes written.
 */
static int32_t SlhDsaSignWithAddRand(CryptSlhDsaCtx *ctx, const uint8_t *msg, uint32_t msgLen, uint8_t *sig,
    uint32_t *sigLen)
{
    uint32_t n = ctx->para.n;
    uint32_t mdIdx = SPLIT_BYTES(ctx->para.k * ctx->para.a);
    uint64_t treeIdx;
    uint32_t leafIdx;
    SlhDsaAdrs adrs = {0};
    uint8_t digest[SLH_DSA_MAX_M] = {0};
    uint8_t pk[SLH_DSA_MAX_N] = {0};
    uint32_t offset = 0;
    uint32_t left;

    // R = PRF_msg(SK.prf, opt_rand, M) occupies the first n bytes of the signature.
    int32_t ret = ctx->hashFuncs.prfmsg(ctx, ctx->addrand, msg, msgLen, sig);
    if (ret != CRYPT_SUCCESS) {
        BSL_ERR_PUSH_ERROR(ret);
        return ret;
    }
    offset += n;

    ret = ctx->hashFuncs.hmsg(ctx, sig, msg, msgLen, digest);
    if (ret != CRYPT_SUCCESS) {
        BSL_ERR_PUSH_ERROR(ret);
        return ret;
    }

    GetTreeAndLeafIdx(digest, ctx, &treeIdx, &leafIdx);
    ctx->adrsOps.setTreeAddr(&adrs, treeIdx);
    ctx->adrsOps.setType(&adrs, FORS_TREE);
    ctx->adrsOps.setKeyPairAddr(&adrs, leafIdx);

    // Only the space after R is available to the FORS signature.
    left = *sigLen - offset;
    ret = ForsSign(digest, mdIdx, &adrs, ctx, sig + offset, &left);
    if (ret != CRYPT_SUCCESS) {
        BSL_ERR_PUSH_ERROR(ret);
        return ret;
    }
    // `left` now holds the FORS signature length written by ForsSign (it is advanced past below).
    ret = ForsPkFromSig(sig + offset, left, digest, mdIdx, &adrs, ctx, pk);
    if (ret != CRYPT_SUCCESS) {
        BSL_ERR_PUSH_ERROR(ret);
        return ret;
    }
    offset += left;

    left = *sigLen - offset;
    ret = HypertreeSign(pk, n, treeIdx, leafIdx, ctx, sig + offset, &left);
    if (ret != CRYPT_SUCCESS) {
        BSL_ERR_PUSH_ERROR(ret);
        return ret;
    }
    *sigLen = offset + left;
    return CRYPT_SUCCESS;
}

/*
 * FIPS-205 slh_sign_internal over the already-encoded message msg.
 * Returns:
 * - CRYPT_SLHDSA_ERR_INVALID_ALGID if no parameter set has been selected (the hash function table is then
 *   not initialised).
 * - CRYPT_SLHDSA_ERR_INVALID_SIG_LEN if *sigLen is below the signature size.
 * An opt_rand installed by the caller (CRYPT_CTRL_SET_SLH_DSA_ADDRAND) stays in ctx. An opt_rand derived here
 * (fresh random bytes, or PK.seed in deterministic mode) is wiped after use, so each signature derives its own
 * instead of silently reusing the first one.
 */
static int32_t CRYPT_SLH_DSA_SignInternal(CryptSlhDsaCtx *ctx, const uint8_t *msg, uint32_t msgLen, uint8_t *sig,
    uint32_t *sigLen)
{
    if (ctx->para.algId >= CRYPT_SLH_DSA_ALG_ID_MAX) {
        BSL_ERR_PUSH_ERROR(CRYPT_SLHDSA_ERR_INVALID_ALGID);
        return CRYPT_SLHDSA_ERR_INVALID_ALGID;
    }
    if (*sigLen < ctx->para.sigBytes) {
        BSL_ERR_PUSH_ERROR(CRYPT_SLHDSA_ERR_INVALID_SIG_LEN);
        return CRYPT_SLHDSA_ERR_INVALID_SIG_LEN;
    }

    bool derivedRand = (ctx->addrand == NULL);
    int32_t ret = GetAddRand(ctx);
    if (ret == CRYPT_SUCCESS) {
        ret = SlhDsaSignWithAddRand(ctx, msg, msgLen, sig, sigLen);
    }
    if (derivedRand && ctx->addrand != NULL) {
        // GetAddRand allocates n bytes when it creates the buffer.
        BSL_SAL_ClearFree(ctx->addrand, ctx->para.n);
        ctx->addrand = NULL;
        ctx->addrandLen = 0;
    }
    return ret;
}
438
CRYPT_SLH_DSA_VerifyInternal(const CryptSlhDsaCtx * ctx,const uint8_t * msg,uint32_t msgLen,const uint8_t * sig,uint32_t sigLen)439 static int32_t CRYPT_SLH_DSA_VerifyInternal(const CryptSlhDsaCtx *ctx, const uint8_t *msg, uint32_t msgLen,
440 const uint8_t *sig, uint32_t sigLen)
441 {
442 int32_t ret;
443 uint32_t n = ctx->para.n;
444 uint32_t a = ctx->para.a;
445 uint32_t k = ctx->para.k;
446 uint32_t sigBytes = ctx->para.sigBytes;
447 uint32_t mdIdx = SPLIT_BYTES(k * a);
448 uint64_t treeIdx;
449 uint32_t leafIdx;
450
451 if (sigLen != sigBytes) {
452 BSL_ERR_PUSH_ERROR(CRYPT_SLHDSA_ERR_INVALID_SIG_LEN);
453 return CRYPT_SLHDSA_ERR_INVALID_SIG_LEN;
454 }
455
456 SlhDsaAdrs adrs = {0};
457 uint32_t offset = 0;
458
459 uint8_t digest[SLH_DSA_MAX_M] = {0};
460 ret = ctx->hashFuncs.hmsg(ctx, sig, msg, msgLen, digest);
461 if (ret != CRYPT_SUCCESS) {
462 BSL_ERR_PUSH_ERROR(ret);
463 return ret;
464 }
465 offset += n;
466
467 GetTreeAndLeafIdx(digest, ctx, &treeIdx, &leafIdx);
468 ctx->adrsOps.setTreeAddr(&adrs, treeIdx);
469 ctx->adrsOps.setType(&adrs, FORS_TREE);
470 ctx->adrsOps.setKeyPairAddr(&adrs, leafIdx);
471 uint8_t pk[SLH_DSA_MAX_N] = {0};
472 ret = ForsPkFromSig(sig + offset, (1 + a) * k * n, digest, mdIdx, &adrs, ctx, pk);
473 if (ret != CRYPT_SUCCESS) {
474 BSL_ERR_PUSH_ERROR(ret);
475 return ret;
476 }
477 offset += (1 + a) * k * n;
478 ret = HypertreeVerify(pk, n, sig + offset, sigLen - offset, treeIdx, leafIdx, ctx);
479 if (ret != CRYPT_SUCCESS) {
480 BSL_ERR_PUSH_ERROR(ret);
481 return ret;
482 }
483 return CRYPT_SUCCESS;
484 }
485
GetMdSize(const EAL_MdMethod * hashMethod,int32_t hashId)486 static uint32_t GetMdSize(const EAL_MdMethod *hashMethod, int32_t hashId)
487 {
488 if (hashId == CRYPT_MD_SHAKE128) {
489 return 32; // To use SHAKE128, generate a 32-byte digest.
490 } else if (hashId == CRYPT_MD_SHAKE256) {
491 return 64; // To use SHAKE256, generate a 64-byte digest.
492 }
493 return hashMethod->mdSize;
494 }
495
MsgEncode(const CryptSlhDsaCtx * ctx,int32_t algId,const uint8_t * data,uint32_t dataLen,uint8_t ** mpOut,uint32_t * mpLenOut)496 static int32_t MsgEncode(const CryptSlhDsaCtx *ctx, int32_t algId, const uint8_t *data, uint32_t dataLen,
497 uint8_t **mpOut, uint32_t *mpLenOut)
498 {
499 int32_t ret;
500 BslOidString *oid = NULL;
501 uint32_t offset = 0;
502 uint8_t prehash[MAX_DIGEST_SIZE] = {0};
503 uint32_t prehashLen = sizeof(prehash);
504
505 uint32_t mpLen = SLH_DSA_PREFIX_LEN + ctx->contextLen;
506 if (ctx->isPrehash) {
507 oid = BSL_OBJ_GetOID((BslCid)algId);
508 if (oid == NULL) {
509 BSL_ERR_PUSH_ERROR(CRYPT_SLHDSA_ERR_PREHASH_ID_NOT_SUPPORTED);
510 return CRYPT_SLHDSA_ERR_PREHASH_ID_NOT_SUPPORTED;
511 }
512 mpLen += 2 + oid->octetLen; // asn1 header length is 2
513 prehashLen = GetMdSize(EAL_MdFindMethod(algId), algId);
514 const CRYPT_ConstData constData = {data, dataLen};
515 ret = CRYPT_CalcHash(EAL_MdFindMethod(algId), &constData, 1, prehash, &prehashLen);
516 if (ret != CRYPT_SUCCESS) {
517 BSL_ERR_PUSH_ERROR(ret);
518 return ret;
519 }
520 mpLen += prehashLen;
521 } else {
522 mpLen += dataLen;
523 }
524
525 uint8_t *mp = (uint8_t *)BSL_SAL_Malloc(mpLen);
526 if (mp == NULL) {
527 BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
528 return CRYPT_MEM_ALLOC_FAIL;
529 }
530 mp[0] = ctx->isPrehash ? 1 : 0;
531 mp[1] = (uint8_t)ctx->contextLen;
532 (void)memcpy_s(mp + SLH_DSA_PREFIX_LEN, mpLen - SLH_DSA_PREFIX_LEN, ctx->context, ctx->contextLen);
533 offset += SLH_DSA_PREFIX_LEN + ctx->contextLen;
534
535 if (ctx->isPrehash) {
536 // asn1 encoding of hash oid
537 (mp + offset)[0] = BSL_ASN1_TAG_OBJECT_ID;
538 (mp + offset)[1] = (uint8_t)oid->octetLen;
539 offset += 2; // asn1 header length is 2
540 (void)memcpy_s(mp + offset, mpLen - offset, oid->octs, oid->octetLen);
541 offset += oid->octetLen;
542 (void)memcpy_s(mp + offset, mpLen - offset, prehash, prehashLen);
543 } else {
544 (void)memcpy_s(mp + offset, mpLen - offset, data, dataLen);
545 }
546 *mpOut = mp;
547 *mpLenOut = mpLen;
548 return CRYPT_SUCCESS;
549 }
550
/*
 * Sign data. algId selects the pre-hash function when pre-hash mode is enabled.
 * On entry *signLen is the capacity of sign; on success it is the signature length.
 * NULL pointers and an empty message are rejected with CRYPT_NULL_INPUT.
 */
int32_t CRYPT_SLH_DSA_Sign(CryptSlhDsaCtx *ctx, int32_t algId, const uint8_t *data, uint32_t dataLen, uint8_t *sign,
    uint32_t *signLen)
{
    if (ctx == NULL || data == NULL || dataLen == 0 || sign == NULL || signLen == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
        return CRYPT_NULL_INPUT;
    }

    uint8_t *encoded = NULL;
    uint32_t encodedLen = 0;
    int32_t ret = MsgEncode(ctx, algId, data, dataLen, &encoded, &encodedLen);
    if (ret != CRYPT_SUCCESS) {
        return ret;
    }

    ret = CRYPT_SLH_DSA_SignInternal(ctx, encoded, encodedLen, sign, signLen);
    BSL_SAL_Free(encoded);
    if (ret != CRYPT_SUCCESS) {
        BSL_ERR_PUSH_ERROR(ret);
    }
    return ret;
}
575
CRYPT_SLH_DSA_Verify(const CryptSlhDsaCtx * ctx,int32_t algId,const uint8_t * data,uint32_t dataLen,const uint8_t * sign,uint32_t signLen)576 int32_t CRYPT_SLH_DSA_Verify(const CryptSlhDsaCtx *ctx, int32_t algId, const uint8_t *data, uint32_t dataLen,
577 const uint8_t *sign, uint32_t signLen)
578 {
579 (void)algId;
580 int32_t ret;
581 uint8_t *mp = NULL;
582 uint32_t mpLen = 0;
583
584 if (ctx == NULL || data == NULL || dataLen == 0 || sign == NULL || signLen == 0) {
585 BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
586 return CRYPT_NULL_INPUT;
587 }
588
589 ret = MsgEncode(ctx, algId, data, dataLen, &mp, &mpLen);
590 if (ret != CRYPT_SUCCESS) {
591 return ret;
592 }
593 ret = CRYPT_SLH_DSA_VerifyInternal(ctx, mp, mpLen, sign, signLen);
594 BSL_SAL_Free(mp);
595 return ret;
596 }
597
/*
 * Load the parameter set algId (already range-checked by the caller) into ctx->para, bind the hash functions,
 * and pick the ADRS accessors. The choice follows para.isCompressed, which is presumably set by
 * SlhDsaInitHashFuncs (TODO confirm), so the two steps must stay in this order.
 */
static void SlhDsaSetAlgId(CryptSlhDsaCtx *ctx, CRYPT_SLH_DSA_AlgId algId)
{
    ctx->para.algId = algId;
    // tree geometry
    ctx->para.n = g_slhDsaN[algId];
    ctx->para.h = g_slhDsaH[algId];
    ctx->para.d = g_slhDsaD[algId];
    ctx->para.hp = g_slhDsaHp[algId];
    // FORS geometry and digest length
    ctx->para.a = g_slhDsaA[algId];
    ctx->para.k = g_slhDsaK[algId];
    ctx->para.m = g_slhDsaM[algId];
    // encoded sizes and strength
    ctx->para.pkBytes = g_slhDsaPkBytes[algId];
    ctx->para.sigBytes = g_slhDsaSigBytes[algId];
    ctx->para.secCategory = g_secCategory[algId];

    SlhDsaInitHashFuncs(ctx);
    ctx->adrsOps = g_adrsOps[ctx->para.isCompressed ? 1 : 0];
}
618
/* CRYPT_CTRL_SET_PARA_BY_ID: select the parameter set; val points to a CRYPT_SLH_DSA_AlgId. */
static int32_t SlhDsaCtrlSetParaById(CryptSlhDsaCtx *ctx, const void *val, uint32_t len)
{
    if (val == NULL || len != sizeof(CRYPT_SLH_DSA_AlgId)) {
        BSL_ERR_PUSH_ERROR(CRYPT_INVALID_ARG);
        return CRYPT_INVALID_ARG;
    }
    CRYPT_SLH_DSA_AlgId algId = *(const CRYPT_SLH_DSA_AlgId *)val;
    // algId indexes the parameter tables directly. Comparing as unsigned also rejects negative values,
    // whatever integer type the compiler picks for the enum.
    if ((uint32_t)algId >= (uint32_t)CRYPT_SLH_DSA_ALG_ID_MAX) {
        BSL_ERR_PUSH_ERROR(CRYPT_SLHDSA_ERR_INVALID_ALGID);
        return CRYPT_SLHDSA_ERR_INVALID_ALGID;
    }
    SlhDsaSetAlgId(ctx, algId);
    return CRYPT_SUCCESS;
}

/*
 * CRYPT_CTRL_SET_CTX_INFO: store a copy of the context string. It is limited to 255 bytes because MsgEncode
 * writes its length as a single byte. The previous string is released only after the new copy exists, so a
 * failed allocation leaves ctx unchanged. An empty string is stored as NULL with length 0.
 */
static int32_t SlhDsaCtrlSetCtxInfo(CryptSlhDsaCtx *ctx, const void *val, uint32_t len)
{
    if (val == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_INVALID_ARG);
        return CRYPT_INVALID_ARG;
    }
    if (len > 255) { // 255: largest length that fits in one byte
        BSL_ERR_PUSH_ERROR(CRYPT_SLHDSA_ERR_CONTEXT_LEN_OVERFLOW);
        return CRYPT_SLHDSA_ERR_CONTEXT_LEN_OVERFLOW;
    }
    uint8_t *copy = NULL;
    if (len > 0) {
        copy = (uint8_t *)BSL_SAL_Malloc(len);
        if (copy == NULL) {
            BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
            return CRYPT_MEM_ALLOC_FAIL;
        }
        (void)memcpy_s(copy, len, val, len);
    }
    BSL_SAL_Free(ctx->context);
    ctx->context = copy;
    ctx->contextLen = len;
    return CRYPT_SUCCESS;
}

/*
 * CRYPT_CTRL_SET_SLH_DSA_ADDRAND: install a caller-chosen opt_rand of exactly n bytes. The old buffer is wiped
 * and freed only after the new copy has been allocated, so ctx->addrand never points to freed memory.
 */
static int32_t SlhDsaCtrlSetAddRand(CryptSlhDsaCtx *ctx, const void *val, uint32_t len)
{
    // len == 0 could only match n when no parameter set has been selected yet (n is still 0).
    if (val == NULL || len == 0 || len != ctx->para.n) {
        BSL_ERR_PUSH_ERROR(CRYPT_INVALID_ARG);
        return CRYPT_INVALID_ARG;
    }
    uint8_t *rand = (uint8_t *)BSL_SAL_Malloc(len);
    if (rand == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
        return CRYPT_MEM_ALLOC_FAIL;
    }
    (void)memcpy_s(rand, len, val, len);
    BSL_SAL_ClearFree(ctx->addrand, ctx->addrandLen);
    ctx->addrand = rand;
    ctx->addrandLen = len;
    return CRYPT_SUCCESS;
}

/*
 * Context control entry point. Returns CRYPT_NULL_INPUT for a NULL ctx, CRYPT_INVALID_ARG for a malformed
 * val/len, and CRYPT_NOT_SUPPORT for unknown options.
 */
int32_t CRYPT_SLH_DSA_Ctrl(CryptSlhDsaCtx *ctx, int32_t opt, void *val, uint32_t len)
{
    if (ctx == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
        return CRYPT_NULL_INPUT;
    }
    switch (opt) {
        case CRYPT_CTRL_SET_PARA_BY_ID:
            return SlhDsaCtrlSetParaById(ctx, val, len);
        case CRYPT_CTRL_SET_PREHASH_FLAG:
            if (val == NULL || len != sizeof(int32_t)) {
                BSL_ERR_PUSH_ERROR(CRYPT_INVALID_ARG);
                return CRYPT_INVALID_ARG;
            }
            ctx->isPrehash = (*(int32_t *)val != 0);
            return CRYPT_SUCCESS;
        case CRYPT_CTRL_SET_CTX_INFO:
            return SlhDsaCtrlSetCtxInfo(ctx, val, len);
        case CRYPT_CTRL_GET_SLH_DSA_KEY_LEN:
            if (val == NULL || len != sizeof(uint32_t)) {
                BSL_ERR_PUSH_ERROR(CRYPT_INVALID_ARG);
                return CRYPT_INVALID_ARG;
            }
            *(uint32_t *)val = ctx->para.n;
            return CRYPT_SUCCESS;
        case CRYPT_CTRL_SET_DETERMINISTIC_FLAG:
            if (val == NULL || len != sizeof(int32_t)) {
                BSL_ERR_PUSH_ERROR(CRYPT_INVALID_ARG);
                return CRYPT_INVALID_ARG;
            }
            ctx->isDeterministic = (*(int32_t *)val != 0);
            return CRYPT_SUCCESS;
        case CRYPT_CTRL_SET_SLH_DSA_ADDRAND:
            return SlhDsaCtrlSetAddRand(ctx, val, len);
        default:
            BSL_ERR_PUSH_ERROR(CRYPT_NOT_SUPPORT);
            return CRYPT_NOT_SUPPORT;
    }
}
699
/*
 * Locate the PK.seed and PK.root entries in para and store them in *pub.
 * Returns CRYPT_NULL_INPUT if ctx/para is NULL or an entry or its value is missing, and
 * CRYPT_SLHDSA_ERR_INVALID_KEYLEN if an entry's valueLen differs from n.
 */
static int32_t PubKeyParamCheck(const CryptSlhDsaCtx *ctx, BSL_Param *para, SlhDsaPubKeyParam *pub)
{
    if (ctx == NULL || para == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
        return CRYPT_NULL_INPUT;
    }
    BSL_Param *seed = BSL_PARAM_FindParam(para, CRYPT_PARAM_SLH_DSA_PUB_SEED);
    BSL_Param *root = BSL_PARAM_FindParam(para, CRYPT_PARAM_SLH_DSA_PUB_ROOT);
    pub->pubSeed = seed;
    pub->pubRoot = root;

    bool seedPresent = (seed != NULL) && (seed->value != NULL);
    bool rootPresent = (root != NULL) && (root->value != NULL);
    if (!seedPresent || !rootPresent) {
        BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
        return CRYPT_NULL_INPUT;
    }
    if (seed->valueLen != ctx->para.n || root->valueLen != ctx->para.n) {
        BSL_ERR_PUSH_ERROR(CRYPT_SLHDSA_ERR_INVALID_KEYLEN);
        return CRYPT_SLHDSA_ERR_INVALID_KEYLEN;
    }
    return CRYPT_SUCCESS;
}
718
/*
 * Locate the SK.seed, SK.prf, PK.seed and PK.root entries in para and store them in *prv.
 * All four entries are checked for presence first (CRYPT_NULL_INPUT), then for a valueLen of exactly n
 * (CRYPT_SLHDSA_ERR_INVALID_KEYLEN).
 */
static int32_t PrvKeyParamCheck(const CryptSlhDsaCtx *ctx, BSL_Param *para, SlhDsaPrvKeyParam *prv)
{
    if (ctx == NULL || para == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
        return CRYPT_NULL_INPUT;
    }
    prv->prvSeed = BSL_PARAM_FindParam(para, CRYPT_PARAM_SLH_DSA_PRV_SEED);
    prv->prvPrf = BSL_PARAM_FindParam(para, CRYPT_PARAM_SLH_DSA_PRV_PRF);
    prv->pubSeed = BSL_PARAM_FindParam(para, CRYPT_PARAM_SLH_DSA_PUB_SEED);
    prv->pubRoot = BSL_PARAM_FindParam(para, CRYPT_PARAM_SLH_DSA_PUB_ROOT);

    BSL_Param *entries[] = {prv->prvSeed, prv->prvPrf, prv->pubSeed, prv->pubRoot};
    size_t count = sizeof(entries) / sizeof(entries[0]);
    for (size_t i = 0; i < count; i++) {
        if (entries[i] == NULL || entries[i]->value == NULL) {
            BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
            return CRYPT_NULL_INPUT;
        }
    }
    for (size_t i = 0; i < count; i++) {
        if (entries[i]->valueLen != ctx->para.n) {
            BSL_ERR_PUSH_ERROR(CRYPT_SLHDSA_ERR_INVALID_KEYLEN);
            return CRYPT_SLHDSA_ERR_INVALID_KEYLEN;
        }
    }
    return CRYPT_SUCCESS;
}
741
/* Export PK.seed and PK.root into the caller's params; each buffer must be exactly n bytes. */
int32_t CRYPT_SLH_DSA_GetPubKey(const CryptSlhDsaCtx *ctx, BSL_Param *para)
{
    SlhDsaPubKeyParam pub;
    int32_t ret = PubKeyParamCheck(ctx, para, &pub);
    if (ret != CRYPT_SUCCESS) {
        BSL_ERR_PUSH_ERROR(ret);
        return ret;
    }
    uint32_t n = ctx->para.n;
    pub.pubRoot->useLen = n;
    pub.pubSeed->useLen = n;
    (void)memcpy_s(pub.pubSeed->value, pub.pubSeed->valueLen, ctx->prvKey.pub.seed, n);
    (void)memcpy_s(pub.pubRoot->value, pub.pubRoot->valueLen, ctx->prvKey.pub.root, n);
    return CRYPT_SUCCESS;
}
756
/* Export SK.seed, SK.prf, PK.seed and PK.root into the caller's params; each buffer must be exactly n bytes. */
int32_t CRYPT_SLH_DSA_GetPrvKey(const CryptSlhDsaCtx *ctx, BSL_Param *para)
{
    SlhDsaPrvKeyParam prv;
    int32_t ret = PrvKeyParamCheck(ctx, para, &prv);
    if (ret != CRYPT_SUCCESS) {
        BSL_ERR_PUSH_ERROR(ret);
        return ret;
    }

    uint32_t n = ctx->para.n;
    BSL_Param *dst[] = {prv.prvSeed, prv.prvPrf, prv.pubSeed, prv.pubRoot};
    const uint8_t *src[] = {ctx->prvKey.seed, ctx->prvKey.prf, ctx->prvKey.pub.seed, ctx->prvKey.pub.root};
    for (size_t i = 0; i < sizeof(dst) / sizeof(dst[0]); i++) {
        dst[i]->useLen = n;
        (void)memcpy_s(dst[i]->value, dst[i]->valueLen, src[i], n);
    }
    return CRYPT_SUCCESS;
}
777
/* Import PK.seed and PK.root (n bytes each) from the caller's params into ctx. */
int32_t CRYPT_SLH_DSA_SetPubKey(CryptSlhDsaCtx *ctx, const BSL_Param *para)
{
    SlhDsaPubKeyParam pub;
    // PubKeyParamCheck only looks entries up; the cast drops const without writing through it.
    int32_t ret = PubKeyParamCheck(ctx, (BSL_Param *)(uintptr_t)para, &pub);
    if (ret != CRYPT_SUCCESS) {
        BSL_ERR_PUSH_ERROR(ret);
        return ret;
    }
    uint32_t n = ctx->para.n;
    (void)memcpy_s(ctx->prvKey.pub.seed, n, pub.pubSeed->value, n);
    (void)memcpy_s(ctx->prvKey.pub.root, n, pub.pubRoot->value, n);
    return CRYPT_SUCCESS;
}
791
/* Import SK.seed, SK.prf, PK.seed and PK.root (n bytes each) from the caller's params into ctx. */
int32_t CRYPT_SLH_DSA_SetPrvKey(CryptSlhDsaCtx *ctx, const BSL_Param *para)
{
    SlhDsaPrvKeyParam prv;
    // PrvKeyParamCheck only looks entries up; the cast drops const without writing through it.
    int32_t ret = PrvKeyParamCheck(ctx, (BSL_Param *)(uintptr_t)para, &prv);
    if (ret != CRYPT_SUCCESS) {
        BSL_ERR_PUSH_ERROR(ret);
        return ret;
    }

    uint32_t n = ctx->para.n;
    // secret half
    (void)memcpy_s(ctx->prvKey.seed, sizeof(ctx->prvKey.seed), prv.prvSeed->value, n);
    (void)memcpy_s(ctx->prvKey.prf, sizeof(ctx->prvKey.prf), prv.prvPrf->value, n);
    // public half
    (void)memcpy_s(ctx->prvKey.pub.seed, sizeof(ctx->prvKey.pub.seed), prv.pubSeed->value, n);
    (void)memcpy_s(ctx->prvKey.pub.root, sizeof(ctx->prvKey.pub.root), prv.pubRoot->value, n);
    return CRYPT_SUCCESS;
}
808
809 #endif // HITLS_CRYPTO_SLH_DSA