• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * This file is part of the openHiTLS project.
3  *
4  * openHiTLS is licensed under the Mulan PSL v2.
5  * You can use this software according to the terms and conditions of the Mulan PSL v2.
6  * You may obtain a copy of Mulan PSL v2 at:
7  *
8  *     http://license.coscl.org.cn/MulanPSL2
9  *
10  * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
11  * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
12  * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
13  * See the Mulan PSL v2 for more details.
14  */
15 #include "hitls_build.h"
16 #ifdef HITLS_CRYPTO_SLH_DSA
17 
18 #include <stddef.h>
19 #include "securec.h"
20 #include "bsl_err_internal.h"
21 #include "bsl_sal.h"
22 #include "bsl_obj_internal.h"
23 #include "bsl_asn1.h"
24 #include "crypt_errno.h"
25 #include "crypt_util_rand.h"
26 #include "eal_md_local.h"
27 #include "crypt_slh_dsa.h"
28 #include "slh_dsa_local.h"
29 #include "slh_dsa_hash.h"
30 #include "slh_dsa_fors.h"
31 #include "slh_dsa_xmss.h"
32 #include "slh_dsa_hypertree.h"
33 
34 #define MAX_DIGEST_SIZE 64
35 #define BYTE_BITS          8
36 #define SLH_DSA_PREFIX_LEN 2
37 #define ASN1_HEADER_LEN    2
38 #define SPLIT_CEIL(a, b)   (((a) + (b) - 1) / (b))
39 #define SPLIT_BYTES(a)     SPLIT_CEIL(a, BYTE_BITS)
40 
41 typedef struct {
42     BSL_Param *pubSeed;
43     BSL_Param *pubRoot;
44 } SlhDsaPubKeyParam;
45 
46 typedef struct {
47     BSL_Param *prvSeed;
48     BSL_Param *prvPrf;
49     BSL_Param *pubSeed;
50     BSL_Param *pubRoot;
51 } SlhDsaPrvKeyParam;
52 
53 // reference to FIPS-205, table 2
54 static uint32_t g_slhDsaN[CRYPT_SLH_DSA_ALG_ID_MAX] = {16, 16, 16, 16, 24, 24, 24, 24, 32, 32, 32, 32};
55 static uint32_t g_slhDsaH[CRYPT_SLH_DSA_ALG_ID_MAX] = {63, 63, 66, 66, 63, 63, 66, 66, 64, 64, 68, 68};
56 static uint32_t g_slhDsaD[CRYPT_SLH_DSA_ALG_ID_MAX] = {7, 7, 22, 22, 7, 7, 22, 22, 8, 8, 17, 17};
57 static uint32_t g_slhDsaHp[CRYPT_SLH_DSA_ALG_ID_MAX] = {9, 9, 3, 3, 9, 9, 3, 3, 8, 8, 4, 4}; // xmss height
58 static uint32_t g_slhDsaA[CRYPT_SLH_DSA_ALG_ID_MAX] = {12, 12, 6, 6, 14, 14, 8, 8, 14, 14, 9, 9};
59 static uint32_t g_slhDsaK[CRYPT_SLH_DSA_ALG_ID_MAX] = {14, 14, 33, 33, 17, 17, 33, 33, 22, 22, 35, 35};
60 static uint32_t g_slhDsaM[CRYPT_SLH_DSA_ALG_ID_MAX] = {30, 30, 34, 34, 39, 39, 42, 42, 47, 47, 49, 49};
61 static uint32_t g_slhDsaPkBytes[CRYPT_SLH_DSA_ALG_ID_MAX] = {32, 32, 32, 32, 48, 48, 48, 48, 64, 64, 64, 64};
62 static uint32_t g_slhDsaSigBytes[CRYPT_SLH_DSA_ALG_ID_MAX] = {7856,  7856,  17088, 17088, 16224, 16224,
63                                                               35664, 35664, 29792, 29792, 49856, 49856};
64 static uint8_t g_secCategory[] = {1, 1, 1, 1, 3, 3, 3, 3, 5, 5, 5, 5};
65 
66 // "UC" means uncompressed
UCAdrsSetLayerAddr(SlhDsaAdrs * adrs,uint32_t layer)67 static void UCAdrsSetLayerAddr(SlhDsaAdrs *adrs, uint32_t layer)
68 {
69     PUT_UINT32_BE(layer, adrs->uc.layerAddr, 0);
70 }
71 
UCAdrsSetTreeAddr(SlhDsaAdrs * adrs,uint64_t tree)72 static void UCAdrsSetTreeAddr(SlhDsaAdrs *adrs, uint64_t tree)
73 {
74     // Write 8-byte tree address starting from offset 4 in 12-byte treeAddr field
75     PUT_UINT64_BE(tree, adrs->uc.treeAddr, 4);
76 }
77 
UCAdrsSetType(SlhDsaAdrs * adrs,AdrsType type)78 static void UCAdrsSetType(SlhDsaAdrs *adrs, AdrsType type)
79 {
80     PUT_UINT32_BE(type, adrs->uc.type, 0);
81     (void)memset_s(adrs->uc.padding, sizeof(adrs->uc.padding), 0, sizeof(adrs->uc.padding));
82 }
83 
UCAdrsSetKeyPairAddr(SlhDsaAdrs * adrs,uint32_t keyPair)84 static void UCAdrsSetKeyPairAddr(SlhDsaAdrs *adrs, uint32_t keyPair)
85 {
86     PUT_UINT32_BE(keyPair, adrs->uc.padding, 0);
87 }
88 
UCAdrsSetChainAddr(SlhDsaAdrs * adrs,uint32_t chain)89 static void UCAdrsSetChainAddr(SlhDsaAdrs *adrs, uint32_t chain)
90 {
91     PUT_UINT32_BE(chain, adrs->uc.padding, 4); // chain address is 4 bytes, start from 4-th byte
92 }
93 
UCAdrsSetTreeHeight(SlhDsaAdrs * adrs,uint32_t height)94 static void UCAdrsSetTreeHeight(SlhDsaAdrs *adrs, uint32_t height)
95 {
96     PUT_UINT32_BE(height, adrs->uc.padding, 4); // tree height is 4 bytes, start from 4-th byte
97 }
98 
UCAdrsSetHashAddr(SlhDsaAdrs * adrs,uint32_t hash)99 static void UCAdrsSetHashAddr(SlhDsaAdrs *adrs, uint32_t hash)
100 {
101     PUT_UINT32_BE(hash, adrs->uc.padding, 8); // hash address is 4 bytes, start from 8-th byte
102 }
103 
UCAdrsSetTreeIndex(SlhDsaAdrs * adrs,uint32_t index)104 static void UCAdrsSetTreeIndex(SlhDsaAdrs *adrs, uint32_t index)
105 {
106     PUT_UINT32_BE(index, adrs->uc.padding, 8); // tree index is 4 bytes, start from 8-th byte
107 }
108 
UCAdrsGetTreeHeight(const SlhDsaAdrs * adrs)109 static uint32_t UCAdrsGetTreeHeight(const SlhDsaAdrs *adrs)
110 {
111     return GET_UINT32_BE(adrs->uc.padding, 0);
112 }
113 
UCAdrsGetTreeIndex(const SlhDsaAdrs * adrs)114 static uint32_t UCAdrsGetTreeIndex(const SlhDsaAdrs *adrs)
115 {
116     return GET_UINT32_BE(adrs->uc.padding, 8); // tree index is 4 bytes, start from 8-th byte
117 }
118 
UCAdrsCopyKeyPairAddr(SlhDsaAdrs * adrs,const SlhDsaAdrs * adrs2)119 static void UCAdrsCopyKeyPairAddr(SlhDsaAdrs *adrs, const SlhDsaAdrs *adrs2)
120 {
121     (void)memcpy_s(adrs->uc.padding, sizeof(adrs->uc.padding), adrs2->uc.padding,
122                    4); // key pair address is 4 bytes, start from 4-th byte
123 }
124 
UCAdrsGetAdrsLen(void)125 static uint32_t UCAdrsGetAdrsLen(void)
126 {
127     return SLH_DSA_ADRS_LEN;
128 }
129 
130 // "C" means compressed
CAdrsSetLayerAddr(SlhDsaAdrs * adrs,uint32_t layer)131 static void CAdrsSetLayerAddr(SlhDsaAdrs *adrs, uint32_t layer)
132 {
133     adrs->c.layerAddr = (uint8_t)layer;
134 }
135 
CAdrsSetTreeAddr(SlhDsaAdrs * adrs,uint64_t tree)136 static void CAdrsSetTreeAddr(SlhDsaAdrs *adrs, uint64_t tree)
137 {
138     // Write 8-byte tree address starting from offset 0 in 8-byte treeAddr field
139     PUT_UINT64_BE(tree, adrs->c.treeAddr, 0);
140 }
141 
CAdrsSetType(SlhDsaAdrs * adrs,AdrsType type)142 static void CAdrsSetType(SlhDsaAdrs *adrs, AdrsType type)
143 {
144     adrs->c.type = type;
145     (void)memset_s(adrs->c.padding, sizeof(adrs->c.padding), 0, sizeof(adrs->c.padding));
146 }
147 
CAdrsSetKeyPairAddr(SlhDsaAdrs * adrs,uint32_t keyPair)148 static void CAdrsSetKeyPairAddr(SlhDsaAdrs *adrs, uint32_t keyPair)
149 {
150     PUT_UINT32_BE(keyPair, adrs->c.padding, 0);
151 }
152 
CAdrsSetChainAddr(SlhDsaAdrs * adrs,uint32_t chain)153 static void CAdrsSetChainAddr(SlhDsaAdrs *adrs, uint32_t chain)
154 {
155     PUT_UINT32_BE(chain, adrs->c.padding, 4); // chain address is 4 bytes, start from 4-th byte
156 }
157 
CAdrsSetTreeHeight(SlhDsaAdrs * adrs,uint32_t height)158 static void CAdrsSetTreeHeight(SlhDsaAdrs *adrs, uint32_t height)
159 {
160     PUT_UINT32_BE(height, adrs->c.padding, 4); // tree height is 4 bytes, start from 4-th byte
161 }
162 
CAdrsSetHashAddr(SlhDsaAdrs * adrs,uint32_t hash)163 static void CAdrsSetHashAddr(SlhDsaAdrs *adrs, uint32_t hash)
164 {
165     PUT_UINT32_BE(hash, adrs->c.padding, 8); // hash address is 4 bytes, start from 8-th byte
166 }
167 
CAdrsSetTreeIndex(SlhDsaAdrs * adrs,uint32_t index)168 static void CAdrsSetTreeIndex(SlhDsaAdrs *adrs, uint32_t index)
169 {
170     PUT_UINT32_BE(index, adrs->c.padding, 8); // tree index is 4 bytes, start from 8-th byte
171 }
172 
CAdrsGetTreeHeight(const SlhDsaAdrs * adrs)173 static uint32_t CAdrsGetTreeHeight(const SlhDsaAdrs *adrs)
174 {
175     return GET_UINT32_BE(adrs->c.padding, 0); // tree height is 4 bytes, start from 0-th byte
176 }
177 
CAdrsGetTreeIndex(const SlhDsaAdrs * adrs)178 static uint32_t CAdrsGetTreeIndex(const SlhDsaAdrs *adrs)
179 {
180     return GET_UINT32_BE(adrs->c.padding, 8); // tree index is 4 bytes, start from 8-th byte
181 }
182 
CAdrsCopyKeyPairAddr(SlhDsaAdrs * adrs,const SlhDsaAdrs * adrs2)183 static void CAdrsCopyKeyPairAddr(SlhDsaAdrs *adrs, const SlhDsaAdrs *adrs2)
184 {
185     (void)memcpy_s(adrs->c.padding, sizeof(adrs->c.padding), adrs2->c.padding,
186                    4); // key pair address is 4 bytes, start from 4-th byte
187 }
188 
CAdrsGetAdrsLen(void)189 static uint32_t CAdrsGetAdrsLen(void)
190 {
191     return SLH_DSA_ADRS_COMPRESSED_LEN;
192 }
193 
194 static AdrsOps g_adrsOps[2] = {{
195     .setLayerAddr = UCAdrsSetLayerAddr,
196     .setTreeAddr = UCAdrsSetTreeAddr,
197     .setType = UCAdrsSetType,
198     .setKeyPairAddr = UCAdrsSetKeyPairAddr,
199     .setChainAddr = UCAdrsSetChainAddr,
200     .setTreeHeight = UCAdrsSetTreeHeight,
201     .setHashAddr = UCAdrsSetHashAddr,
202     .setTreeIndex = UCAdrsSetTreeIndex,
203     .getTreeHeight = UCAdrsGetTreeHeight,
204     .getTreeIndex = UCAdrsGetTreeIndex,
205     .copyKeyPairAddr = UCAdrsCopyKeyPairAddr,
206     .getAdrsLen = UCAdrsGetAdrsLen,
207 },
208 {
209     .setLayerAddr = CAdrsSetLayerAddr,
210     .setTreeAddr = CAdrsSetTreeAddr,
211     .setType = CAdrsSetType,
212     .setKeyPairAddr = CAdrsSetKeyPairAddr,
213     .setChainAddr = CAdrsSetChainAddr,
214     .setTreeHeight = CAdrsSetTreeHeight,
215     .setHashAddr = CAdrsSetHashAddr,
216     .setTreeIndex = CAdrsSetTreeIndex,
217     .getTreeHeight = CAdrsGetTreeHeight,
218     .getTreeIndex = CAdrsGetTreeIndex,
219     .copyKeyPairAddr = CAdrsCopyKeyPairAddr,
220     .getAdrsLen = CAdrsGetAdrsLen,
221 }};
222 
/*
 * Splits the big-endian bit string x[0:xLen] into outLen unsigned integers of b bits each
 * (FIPS 205 base_2b). Bits are consumed most-significant first.
 *
 * FIPS 205 requires xLen >= ceil(outLen * b / 8). If the input is shorter, the missing bytes are
 * treated as zero instead of underflowing the bit counter (which was undefined behaviour).
 * Assumes 0 <= b <= 24 so the 32-bit accumulator cannot overflow.
 */
void BaseB(const uint8_t *x, uint32_t xLen, uint32_t b, uint32_t *out, uint32_t outLen)
{
    uint32_t acc = 0;     // pending bits, right-aligned
    uint32_t accBits = 0; // number of valid bits in acc
    uint32_t xi = 0;
    for (uint32_t i = 0; i < outLen; i++) {
        while (accBits < b) {
            uint32_t byte = 0; // zero-extend past the end of x
            if (xi < xLen) {
                byte = x[xi];
                xi++;
            }
            acc = (acc << 8) | byte;
            accBits += 8;
        }
        accBits -= b;
        out[i] = acc >> accBits;
        // keep only the bits not yet emitted (accBits < 8 here)
        acc &= (1u << accBits) - 1u;
    }
}
240 
/*
 * Returns ToInt(b[0:l]) mod 2^m: the first l bytes of b read as a big-endian integer, reduced to
 * its low m bits. Only the last 8 bytes can affect the result because the accumulator is 64 bits.
 * m == 0 yields 0 and m >= 64 leaves the value unmasked. (m == 0 previously shifted a 64-bit value
 * by 64, which is undefined behaviour.)
 */
static uint64_t ToIntMod(const uint8_t *b, uint32_t l, uint32_t m)
{
    uint64_t ret = 0;
    for (uint32_t i = 0; i < l; i++) {
        ret = (ret << 8) | b[i]; // append one byte
    }
    if (m == 0) {
        return 0;
    }
    if (m >= 64) {
        return ret;
    }
    return ret & (((uint64_t)1 << m) - 1);
}
251 
CRYPT_SLH_DSA_NewCtx(void)252 CryptSlhDsaCtx *CRYPT_SLH_DSA_NewCtx(void)
253 {
254     CryptSlhDsaCtx *ctx = (CryptSlhDsaCtx *)BSL_SAL_Calloc(sizeof(CryptSlhDsaCtx), 1);
255     if (ctx == NULL) {
256         BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
257         return NULL;
258     }
259     ctx->para.algId = CRYPT_SLH_DSA_ALG_ID_MAX;
260     ctx->isPrehash = false;
261     ctx->isDeterministic = false;
262     return ctx;
263 }
264 
CRYPT_SLH_DSA_NewCtxEx(void * libCtx)265 CryptSlhDsaCtx *CRYPT_SLH_DSA_NewCtxEx(void *libCtx)
266 {
267     CryptSlhDsaCtx *ctx = CRYPT_SLH_DSA_NewCtx();
268     if (ctx == NULL) {
269         return NULL;
270     }
271     ctx->libCtx = libCtx;
272     return ctx;
273 }
274 
/*
 * Releases an SLH-DSA context. The additional random is wiped before it is freed, and the secret
 * seed and PRF key are cleansed before the context memory is returned. A NULL ctx is a no-op.
 */
void CRYPT_SLH_DSA_FreeCtx(CryptSlhDsaCtx *ctx)
{
    if (ctx != NULL) {
        BSL_SAL_Free(ctx->context);
        BSL_SAL_ClearFree(ctx->addrand, ctx->addrandLen);
        // scrub private key material held inline in the context
        BSL_SAL_CleanseData(ctx->prvKey.seed, sizeof(ctx->prvKey.seed));
        BSL_SAL_CleanseData(ctx->prvKey.prf, sizeof(ctx->prvKey.prf));
        BSL_SAL_Free(ctx);
    }
}
286 
/*
 * Generates a fresh SLH-DSA key pair in ctx (FIPS 205, Algorithm 18 slh_keygen_internal).
 * SK.seed, SK.prf and PK.seed are drawn from the RNG. PK.root is the root of the top-layer
 * (layer d - 1) XMSS tree of height h'.
 *
 * Returns CRYPT_SUCCESS, CRYPT_NULL_INPUT, CRYPT_SLHDSA_ERR_INVALID_ALGID (no parameter set
 * selected), or the error reported by CRYPT_RandEx / XmssNode. On any failure after key
 * generation started, all key material in ctx is wiped, so a half-generated key (mixing new and
 * stale components) can never be used for signing.
 */
int32_t CRYPT_SLH_DSA_Gen(CryptSlhDsaCtx *ctx)
{
    SlhDsaAdrs adrs = {0};
    uint8_t node[SLH_DSA_MAX_N] = {0};
    int32_t ret;

    if (ctx == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
        return CRYPT_NULL_INPUT;
    }
    if (ctx->para.algId >= CRYPT_SLH_DSA_ALG_ID_MAX) {
        BSL_ERR_PUSH_ERROR(CRYPT_SLHDSA_ERR_INVALID_ALGID);
        return CRYPT_SLHDSA_ERR_INVALID_ALGID;
    }
    uint32_t n = ctx->para.n;

    ret = CRYPT_RandEx(ctx->libCtx, ctx->prvKey.seed, n);
    if (ret != CRYPT_SUCCESS) {
        goto ERR;
    }
    ret = CRYPT_RandEx(ctx->libCtx, ctx->prvKey.prf, n);
    if (ret != CRYPT_SUCCESS) {
        goto ERR;
    }
    ret = CRYPT_RandEx(ctx->libCtx, ctx->prvKey.pub.seed, n);
    if (ret != CRYPT_SUCCESS) {
        goto ERR;
    }

    // PK.root = xmss_node(SK.seed, 0, h', PK.seed, ADRS) with ADRS.layer = d - 1
    ctx->adrsOps.setLayerAddr(&adrs, ctx->para.d - 1);
    ret = XmssNode(node, 0, ctx->para.hp, &adrs, ctx);
    if (ret != CRYPT_SUCCESS) {
        goto ERR;
    }
    (void)memcpy_s(ctx->prvKey.pub.root, n, node, n);
    return CRYPT_SUCCESS;

ERR:
    BSL_ERR_PUSH_ERROR(ret);
    BSL_SAL_CleanseData(ctx->prvKey.seed, sizeof(ctx->prvKey.seed));
    BSL_SAL_CleanseData(ctx->prvKey.prf, sizeof(ctx->prvKey.prf));
    BSL_SAL_CleanseData(ctx->prvKey.pub.seed, sizeof(ctx->prvKey.pub.seed));
    BSL_SAL_CleanseData(ctx->prvKey.pub.root, sizeof(ctx->prvKey.pub.root));
    return ret;
}
329 
/*
 * Ensures ctx->addrand (opt_rand, n bytes) is available for signing.
 * - If the caller already installed one, it is used as-is.
 * - Deterministic mode: FIPS-205, Algorithm 19 line 2. The public key seed is used as opt_rand.
 * - Otherwise n fresh bytes are drawn from the RNG.
 * ctx->addrand and ctx->addrandLen are only updated once the buffer is fully initialised. An RNG
 * failure therefore can no longer leave an uninitialised buffer cached in ctx, which the next call
 * would otherwise have accepted as valid randomness.
 *
 * NOTE(review): once set, ctx->addrand is returned unchanged on every later call. In
 * non-deterministic mode, one random value is therefore reused for all subsequent signatures on
 * this context. Confirm this caching is intended.
 */
static int32_t GetAddRand(CryptSlhDsaCtx *ctx)
{
    if (ctx->addrand != NULL) {
        // the additional rand is set.
        return CRYPT_SUCCESS;
    }
    uint32_t n = ctx->para.n;
    uint8_t *rand = (uint8_t *)BSL_SAL_Malloc(n);
    if (rand == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
        return CRYPT_MEM_ALLOC_FAIL;
    }
    if (ctx->isDeterministic) {
        (void)memcpy_s(rand, n, ctx->prvKey.pub.seed, n);
    } else {
        int32_t ret = CRYPT_RandEx(ctx->libCtx, rand, n);
        if (ret != CRYPT_SUCCESS) {
            BSL_SAL_ClearFree(rand, n);
            BSL_ERR_PUSH_ERROR(ret);
            return ret;
        }
    }
    ctx->addrand = rand;
    ctx->addrandLen = n;
    return CRYPT_SUCCESS;
}
360 
GetTreeAndLeafIdx(const uint8_t * digest,const CryptSlhDsaCtx * ctx,uint64_t * treeIdx,uint32_t * leafIdx)361 static void GetTreeAndLeafIdx(const uint8_t *digest, const CryptSlhDsaCtx *ctx, uint64_t *treeIdx, uint32_t *leafIdx)
362 {
363     uint32_t a = ctx->para.a;
364     uint32_t k = ctx->para.k;
365     uint32_t h = ctx->para.h;
366     uint32_t d = ctx->para.d;
367 
368     uint32_t mdIdx = SPLIT_BYTES(k * a);
369     uint32_t treeIdxLen = SPLIT_BYTES(h - h / d);
370     uint32_t leafIdxLen = SPLIT_BYTES(h / d);
371     *treeIdx = ToIntMod(digest + mdIdx, treeIdxLen, h - h / d);
372     *leafIdx = (uint32_t)ToIntMod(digest + mdIdx + treeIdxLen, leafIdxLen, h / d);
373 }
374 
/*
 * Signs the encoded message msg (FIPS 205, Algorithm 19 slh_sign_internal).
 * The signature is written as R || SIG_FORS || SIG_HT.
 * *sigLen is the capacity of sig on input (must be at least para.sigBytes) and the number of
 * bytes written on success.
 *
 * The capacity handed to ForsSign is the space left after R. Previously the whole *sigLen was
 * passed although ForsSign writes at sig + n.
 */
static int32_t CRYPT_SLH_DSA_SignInternal(CryptSlhDsaCtx *ctx, const uint8_t *msg, uint32_t msgLen, uint8_t *sig,
                                          uint32_t *sigLen)
{
    int32_t ret;
    uint32_t n = ctx->para.n;
    uint32_t mdIdx = SPLIT_BYTES(ctx->para.k * ctx->para.a); // byte length of md in the digest
    uint64_t treeIdx;
    uint32_t leafIdx;
    SlhDsaAdrs adrs = {0};
    uint8_t digest[SLH_DSA_MAX_M] = {0};
    uint8_t pk[SLH_DSA_MAX_N] = {0};

    if (*sigLen < ctx->para.sigBytes) {
        BSL_ERR_PUSH_ERROR(CRYPT_SLHDSA_ERR_INVALID_SIG_LEN);
        return CRYPT_SLHDSA_ERR_INVALID_SIG_LEN;
    }

    ret = GetAddRand(ctx);
    if (ret != CRYPT_SUCCESS) {
        return ret;
    }

    // R = PRF_msg(SK.prf, opt_rand, M) occupies sig[0:n]
    ret = ctx->hashFuncs.prfmsg(ctx, ctx->addrand, msg, msgLen, sig);
    if (ret != CRYPT_SUCCESS) {
        BSL_ERR_PUSH_ERROR(ret);
        return ret;
    }
    uint32_t offset = n;

    ret = ctx->hashFuncs.hmsg(ctx, sig, msg, msgLen, digest);
    if (ret != CRYPT_SUCCESS) {
        BSL_ERR_PUSH_ERROR(ret);
        return ret;
    }

    GetTreeAndLeafIdx(digest, ctx, &treeIdx, &leafIdx);
    ctx->adrsOps.setTreeAddr(&adrs, treeIdx);
    ctx->adrsOps.setType(&adrs, FORS_TREE);
    ctx->adrsOps.setKeyPairAddr(&adrs, leafIdx);

    // SIG_FORS: `left` is the remaining capacity on input and the FORS signature length on output.
    uint32_t left = *sigLen - offset;
    ret = ForsSign(digest, mdIdx, &adrs, ctx, sig + offset, &left);
    if (ret != CRYPT_SUCCESS) {
        BSL_ERR_PUSH_ERROR(ret);
        return ret;
    }
    // The FORS public key is the message signed by the hypertree.
    ret = ForsPkFromSig(sig + offset, left, digest, mdIdx, &adrs, ctx, pk);
    if (ret != CRYPT_SUCCESS) {
        BSL_ERR_PUSH_ERROR(ret);
        return ret;
    }
    offset += left;

    // SIG_HT
    left = *sigLen - offset;
    ret = HypertreeSign(pk, n, treeIdx, leafIdx, ctx, sig + offset, &left);
    if (ret != CRYPT_SUCCESS) {
        BSL_ERR_PUSH_ERROR(ret);
        return ret;
    }
    *sigLen = offset + left;
    return CRYPT_SUCCESS;
}
438 
CRYPT_SLH_DSA_VerifyInternal(const CryptSlhDsaCtx * ctx,const uint8_t * msg,uint32_t msgLen,const uint8_t * sig,uint32_t sigLen)439 static int32_t CRYPT_SLH_DSA_VerifyInternal(const CryptSlhDsaCtx *ctx, const uint8_t *msg, uint32_t msgLen,
440                                             const uint8_t *sig, uint32_t sigLen)
441 {
442     int32_t ret;
443     uint32_t n = ctx->para.n;
444     uint32_t a = ctx->para.a;
445     uint32_t k = ctx->para.k;
446     uint32_t sigBytes = ctx->para.sigBytes;
447     uint32_t mdIdx = SPLIT_BYTES(k * a);
448     uint64_t treeIdx;
449     uint32_t leafIdx;
450 
451     if (sigLen != sigBytes) {
452         BSL_ERR_PUSH_ERROR(CRYPT_SLHDSA_ERR_INVALID_SIG_LEN);
453         return CRYPT_SLHDSA_ERR_INVALID_SIG_LEN;
454     }
455 
456     SlhDsaAdrs adrs = {0};
457     uint32_t offset = 0;
458 
459     uint8_t digest[SLH_DSA_MAX_M] = {0};
460     ret = ctx->hashFuncs.hmsg(ctx, sig, msg, msgLen, digest);
461     if (ret != CRYPT_SUCCESS) {
462         BSL_ERR_PUSH_ERROR(ret);
463         return ret;
464     }
465     offset += n;
466 
467     GetTreeAndLeafIdx(digest, ctx, &treeIdx, &leafIdx);
468     ctx->adrsOps.setTreeAddr(&adrs, treeIdx);
469     ctx->adrsOps.setType(&adrs, FORS_TREE);
470     ctx->adrsOps.setKeyPairAddr(&adrs, leafIdx);
471     uint8_t pk[SLH_DSA_MAX_N] = {0};
472     ret = ForsPkFromSig(sig + offset, (1 + a) * k * n, digest, mdIdx, &adrs, ctx, pk);
473     if (ret != CRYPT_SUCCESS) {
474         BSL_ERR_PUSH_ERROR(ret);
475         return ret;
476     }
477     offset += (1 + a) * k * n;
478     ret = HypertreeVerify(pk, n, sig + offset, sigLen - offset, treeIdx, leafIdx, ctx);
479     if (ret != CRYPT_SUCCESS) {
480         BSL_ERR_PUSH_ERROR(ret);
481         return ret;
482     }
483     return CRYPT_SUCCESS;
484 }
485 
GetMdSize(const EAL_MdMethod * hashMethod,int32_t hashId)486 static uint32_t GetMdSize(const EAL_MdMethod *hashMethod, int32_t hashId)
487 {
488     if (hashId == CRYPT_MD_SHAKE128) {
489         return 32;  // To use SHAKE128, generate a 32-byte digest.
490     } else if (hashId == CRYPT_MD_SHAKE256) {
491         return 64;  // To use SHAKE256, generate a 64-byte digest.
492     }
493     return hashMethod->mdSize;
494 }
495 
MsgEncode(const CryptSlhDsaCtx * ctx,int32_t algId,const uint8_t * data,uint32_t dataLen,uint8_t ** mpOut,uint32_t * mpLenOut)496 static int32_t MsgEncode(const CryptSlhDsaCtx *ctx, int32_t algId, const uint8_t *data, uint32_t dataLen,
497     uint8_t **mpOut, uint32_t *mpLenOut)
498 {
499     int32_t ret;
500     BslOidString *oid = NULL;
501     uint32_t offset = 0;
502     uint8_t prehash[MAX_DIGEST_SIZE] = {0};
503     uint32_t prehashLen = sizeof(prehash);
504 
505     uint32_t mpLen = SLH_DSA_PREFIX_LEN + ctx->contextLen;
506     if (ctx->isPrehash) {
507         oid = BSL_OBJ_GetOID((BslCid)algId);
508         if (oid == NULL) {
509             BSL_ERR_PUSH_ERROR(CRYPT_SLHDSA_ERR_PREHASH_ID_NOT_SUPPORTED);
510             return CRYPT_SLHDSA_ERR_PREHASH_ID_NOT_SUPPORTED;
511         }
512         mpLen += 2 + oid->octetLen; // asn1 header length is 2
513         prehashLen = GetMdSize(EAL_MdFindMethod(algId), algId);
514         const CRYPT_ConstData constData = {data, dataLen};
515         ret = CRYPT_CalcHash(EAL_MdFindMethod(algId), &constData, 1, prehash, &prehashLen);
516         if (ret != CRYPT_SUCCESS) {
517             BSL_ERR_PUSH_ERROR(ret);
518             return ret;
519         }
520         mpLen += prehashLen;
521     } else {
522         mpLen += dataLen;
523     }
524 
525     uint8_t *mp = (uint8_t *)BSL_SAL_Malloc(mpLen);
526     if (mp == NULL) {
527         BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
528         return CRYPT_MEM_ALLOC_FAIL;
529     }
530     mp[0] = ctx->isPrehash ? 1 : 0;
531     mp[1] = (uint8_t)ctx->contextLen;
532     (void)memcpy_s(mp + SLH_DSA_PREFIX_LEN, mpLen - SLH_DSA_PREFIX_LEN, ctx->context, ctx->contextLen);
533     offset += SLH_DSA_PREFIX_LEN + ctx->contextLen;
534 
535     if (ctx->isPrehash) {
536         // asn1 encoding of hash oid
537         (mp + offset)[0] = BSL_ASN1_TAG_OBJECT_ID;
538         (mp + offset)[1] = (uint8_t)oid->octetLen;
539         offset += 2; // asn1 header length is 2
540         (void)memcpy_s(mp + offset, mpLen - offset, oid->octs, oid->octetLen);
541         offset += oid->octetLen;
542         (void)memcpy_s(mp + offset, mpLen - offset, prehash, prehashLen);
543     } else {
544         (void)memcpy_s(mp + offset, mpLen - offset, data, dataLen);
545     }
546     *mpOut = mp;
547     *mpLenOut = mpLen;
548     return CRYPT_SUCCESS;
549 }
550 
/*
 * Signs data with the key in ctx (FIPS 205 slh_sign / hash_slh_sign depending on ctx->isPrehash).
 * algId is the pre-hash algorithm and is only used in pre-hash mode.
 * *signLen is the capacity of sign on input and the signature length on output.
 *
 * Rejects a context with no parameter set selected. Without that check, para.sigBytes is 0, the
 * length check passes, and the unset hash function pointers would be called.
 */
int32_t CRYPT_SLH_DSA_Sign(CryptSlhDsaCtx *ctx, int32_t algId, const uint8_t *data, uint32_t dataLen, uint8_t *sign,
                           uint32_t *signLen)
{
    uint8_t *mp = NULL;
    uint32_t mpLen = 0;

    if (ctx == NULL || data == NULL || dataLen == 0 || sign == NULL || signLen == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
        return CRYPT_NULL_INPUT;
    }
    if (ctx->para.algId >= CRYPT_SLH_DSA_ALG_ID_MAX) {
        BSL_ERR_PUSH_ERROR(CRYPT_SLHDSA_ERR_INVALID_ALGID);
        return CRYPT_SLHDSA_ERR_INVALID_ALGID;
    }

    int32_t ret = MsgEncode(ctx, algId, data, dataLen, &mp, &mpLen);
    if (ret != CRYPT_SUCCESS) {
        return ret;
    }
    ret = CRYPT_SLH_DSA_SignInternal(ctx, mp, mpLen, sign, signLen);
    BSL_SAL_Free(mp);
    if (ret != CRYPT_SUCCESS) {
        BSL_ERR_PUSH_ERROR(ret);
    }
    return ret;
}
575 
CRYPT_SLH_DSA_Verify(const CryptSlhDsaCtx * ctx,int32_t algId,const uint8_t * data,uint32_t dataLen,const uint8_t * sign,uint32_t signLen)576 int32_t CRYPT_SLH_DSA_Verify(const CryptSlhDsaCtx *ctx, int32_t algId, const uint8_t *data, uint32_t dataLen,
577                              const uint8_t *sign, uint32_t signLen)
578 {
579     (void)algId;
580     int32_t ret;
581     uint8_t *mp = NULL;
582     uint32_t mpLen = 0;
583 
584     if (ctx == NULL || data == NULL || dataLen == 0 || sign == NULL || signLen == 0) {
585         BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
586         return CRYPT_NULL_INPUT;
587     }
588 
589     ret = MsgEncode(ctx, algId, data, dataLen, &mp, &mpLen);
590     if (ret != CRYPT_SUCCESS) {
591         return ret;
592     }
593     ret = CRYPT_SLH_DSA_VerifyInternal(ctx, mp, mpLen, sign, signLen);
594     BSL_SAL_Free(mp);
595     return ret;
596 }
597 
/*
 * Loads the FIPS 205 parameter set algId (which must be < CRYPT_SLH_DSA_ALG_ID_MAX) into ctx. It
 * then initialises the hash functions and selects the compressed or uncompressed ADRS accessors
 * based on ctx->para.isCompressed.
 *
 * A previously cached additional random whose length no longer matches the new n is discarded.
 * Otherwise PRF_msg would read n bytes from a shorter buffer (heap over-read).
 */
static void SlhDsaSetAlgId(CryptSlhDsaCtx *ctx, CRYPT_SLH_DSA_AlgId algId)
{
    ctx->para.algId = algId;
    ctx->para.n = g_slhDsaN[algId];
    ctx->para.h = g_slhDsaH[algId];
    ctx->para.d = g_slhDsaD[algId];
    ctx->para.hp = g_slhDsaHp[algId];
    ctx->para.a = g_slhDsaA[algId];
    ctx->para.k = g_slhDsaK[algId];
    ctx->para.m = g_slhDsaM[algId];
    ctx->para.pkBytes = g_slhDsaPkBytes[algId];
    ctx->para.sigBytes = g_slhDsaSigBytes[algId];
    ctx->para.secCategory = g_secCategory[algId];
    // Must run before isCompressed is consulted below.
    // NOTE(review): presumably this is where isCompressed is set; confirm.
    SlhDsaInitHashFuncs(ctx);
    if (ctx->para.isCompressed) {
        ctx->adrsOps = g_adrsOps[1];
    } else {
        ctx->adrsOps = g_adrsOps[0];
    }

    if (ctx->addrand != NULL && ctx->addrandLen != ctx->para.n) {
        BSL_SAL_ClearFree(ctx->addrand, ctx->addrandLen);
        ctx->addrand = NULL;
        ctx->addrandLen = 0;
    }
}
618 
/* Handles CRYPT_CTRL_SET_PARA_BY_ID: validates and applies an SLH-DSA parameter set. */
static int32_t SlhDsaCtrlSetParaById(CryptSlhDsaCtx *ctx, const void *val, uint32_t len)
{
    if (val == NULL || len != sizeof(CRYPT_SLH_DSA_AlgId)) {
        BSL_ERR_PUSH_ERROR(CRYPT_INVALID_ARG);
        return CRYPT_INVALID_ARG;
    }
    CRYPT_SLH_DSA_AlgId algId = *(const CRYPT_SLH_DSA_AlgId *)val;
    if (algId >= CRYPT_SLH_DSA_ALG_ID_MAX) {
        BSL_ERR_PUSH_ERROR(CRYPT_SLHDSA_ERR_INVALID_ALGID);
        return CRYPT_SLHDSA_ERR_INVALID_ALGID;
    }
    SlhDsaSetAlgId(ctx, algId);
    return CRYPT_SUCCESS;
}

/*
 * Handles CRYPT_CTRL_SET_CTX_INFO: replaces the context string (at most 255 bytes, because its
 * length is encoded in one byte of M'). The new copy is built before the old one is released, so
 * an allocation failure leaves ctx unchanged. An empty context (len == 0) is stored as NULL/0.
 */
static int32_t SlhDsaCtrlSetCtxInfo(CryptSlhDsaCtx *ctx, const void *val, uint32_t len)
{
    if (val == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_INVALID_ARG);
        return CRYPT_INVALID_ARG;
    }
    if (len > 255) {
        BSL_ERR_PUSH_ERROR(CRYPT_SLHDSA_ERR_CONTEXT_LEN_OVERFLOW);
        return CRYPT_SLHDSA_ERR_CONTEXT_LEN_OVERFLOW;
    }
    uint8_t *context = NULL;
    if (len > 0) {
        context = (uint8_t *)BSL_SAL_Malloc(len);
        if (context == NULL) {
            BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
            return CRYPT_MEM_ALLOC_FAIL;
        }
        (void)memcpy_s(context, len, val, len);
    }
    BSL_SAL_Free(ctx->context);
    ctx->context = context;
    ctx->contextLen = len;
    return CRYPT_SUCCESS;
}

/*
 * Handles CRYPT_CTRL_SET_SLH_DSA_ADDRAND: installs a caller-chosen additional random of exactly
 * para.n bytes. The replacement is allocated before the old buffer is wiped and freed. Previously
 * the old buffer was freed first, so an allocation failure left ctx->addrand dangling.
 */
static int32_t SlhDsaCtrlSetAddRand(CryptSlhDsaCtx *ctx, const void *val, uint32_t len)
{
    if (val == NULL || len != ctx->para.n) {
        BSL_ERR_PUSH_ERROR(CRYPT_INVALID_ARG);
        return CRYPT_INVALID_ARG;
    }
    uint8_t *rand = (uint8_t *)BSL_SAL_Malloc(len);
    if (rand == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
        return CRYPT_MEM_ALLOC_FAIL;
    }
    (void)memcpy_s(rand, len, val, len);
    BSL_SAL_ClearFree(ctx->addrand, ctx->addrandLen);
    ctx->addrand = rand;
    ctx->addrandLen = len;
    return CRYPT_SUCCESS;
}

/*
 * Control entry point for the SLH-DSA context. Boolean flags are passed as int32_t (non-zero
 * means on). Returns CRYPT_NOT_SUPPORT for unknown options.
 */
int32_t CRYPT_SLH_DSA_Ctrl(CryptSlhDsaCtx *ctx, int32_t opt, void *val, uint32_t len)
{
    if (ctx == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
        return CRYPT_NULL_INPUT;
    }
    switch (opt) {
        case CRYPT_CTRL_SET_PARA_BY_ID:
            return SlhDsaCtrlSetParaById(ctx, val, len);
        case CRYPT_CTRL_SET_PREHASH_FLAG:
            if (val == NULL || len != sizeof(int32_t)) {
                BSL_ERR_PUSH_ERROR(CRYPT_INVALID_ARG);
                return CRYPT_INVALID_ARG;
            }
            ctx->isPrehash = (*(int32_t *)val != 0);
            return CRYPT_SUCCESS;
        case CRYPT_CTRL_SET_CTX_INFO:
            return SlhDsaCtrlSetCtxInfo(ctx, val, len);
        case CRYPT_CTRL_GET_SLH_DSA_KEY_LEN:
            if (val == NULL || len != sizeof(uint32_t)) {
                BSL_ERR_PUSH_ERROR(CRYPT_INVALID_ARG);
                return CRYPT_INVALID_ARG;
            }
            *(uint32_t *)val = ctx->para.n;
            return CRYPT_SUCCESS;
        case CRYPT_CTRL_SET_DETERMINISTIC_FLAG:
            if (val == NULL || len != sizeof(int32_t)) {
                BSL_ERR_PUSH_ERROR(CRYPT_INVALID_ARG);
                return CRYPT_INVALID_ARG;
            }
            ctx->isDeterministic = (*(int32_t *)val != 0);
            return CRYPT_SUCCESS;
        case CRYPT_CTRL_SET_SLH_DSA_ADDRAND:
            return SlhDsaCtrlSetAddRand(ctx, val, len);
        default:
            BSL_ERR_PUSH_ERROR(CRYPT_NOT_SUPPORT);
            return CRYPT_NOT_SUPPORT;
    }
}
699 
/*
 * Looks up the PK.seed and PK.root entries in para and stores them in *pub.
 * Both entries must exist, have a non-NULL value buffer, and have valueLen == para.n.
 * Returns CRYPT_NULL_INPUT or CRYPT_SLHDSA_ERR_INVALID_KEYLEN otherwise.
 */
static int32_t PubKeyParamCheck(const CryptSlhDsaCtx *ctx, BSL_Param *para, SlhDsaPubKeyParam *pub)
{
    if (ctx == NULL || para == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
        return CRYPT_NULL_INPUT;
    }
    BSL_Param *seed = BSL_PARAM_FindParam(para, CRYPT_PARAM_SLH_DSA_PUB_SEED);
    BSL_Param *root = BSL_PARAM_FindParam(para, CRYPT_PARAM_SLH_DSA_PUB_ROOT);
    pub->pubSeed = seed;
    pub->pubRoot = root;

    int32_t ret = CRYPT_SUCCESS;
    if (seed == NULL || seed->value == NULL || root == NULL || root->value == NULL) {
        ret = CRYPT_NULL_INPUT;
    } else if (seed->valueLen != ctx->para.n || root->valueLen != ctx->para.n) {
        ret = CRYPT_SLHDSA_ERR_INVALID_KEYLEN;
    }
    if (ret != CRYPT_SUCCESS) {
        BSL_ERR_PUSH_ERROR(ret);
    }
    return ret;
}
718 
/*
 * Looks up SK.seed, SK.prf, PK.seed and PK.root in para and stores them in *prv.
 * All four must exist with a non-NULL value buffer (otherwise CRYPT_NULL_INPUT). All four must
 * have valueLen == para.n (otherwise CRYPT_SLHDSA_ERR_INVALID_KEYLEN). The presence check covers
 * every entry before any length is examined.
 */
static int32_t PrvKeyParamCheck(const CryptSlhDsaCtx *ctx, BSL_Param *para, SlhDsaPrvKeyParam *prv)
{
    if (ctx == NULL || para == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
        return CRYPT_NULL_INPUT;
    }
    prv->prvSeed = BSL_PARAM_FindParam(para, CRYPT_PARAM_SLH_DSA_PRV_SEED);
    prv->prvPrf = BSL_PARAM_FindParam(para, CRYPT_PARAM_SLH_DSA_PRV_PRF);
    prv->pubSeed = BSL_PARAM_FindParam(para, CRYPT_PARAM_SLH_DSA_PUB_SEED);
    prv->pubRoot = BSL_PARAM_FindParam(para, CRYPT_PARAM_SLH_DSA_PUB_ROOT);

    BSL_Param *entries[] = {prv->prvSeed, prv->prvPrf, prv->pubSeed, prv->pubRoot};
    const size_t count = sizeof(entries) / sizeof(entries[0]);

    for (size_t i = 0; i < count; i++) {
        if (entries[i] == NULL || entries[i]->value == NULL) {
            BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
            return CRYPT_NULL_INPUT;
        }
    }
    for (size_t i = 0; i < count; i++) {
        if (entries[i]->valueLen != ctx->para.n) {
            BSL_ERR_PUSH_ERROR(CRYPT_SLHDSA_ERR_INVALID_KEYLEN);
            return CRYPT_SLHDSA_ERR_INVALID_KEYLEN;
        }
    }
    return CRYPT_SUCCESS;
}
741 
/*
 * Exports the public key (PK.seed, PK.root) into the matching entries of para.
 * Each destination must have valueLen == para.n. Its useLen is set to n on success.
 */
int32_t CRYPT_SLH_DSA_GetPubKey(const CryptSlhDsaCtx *ctx, BSL_Param *para)
{
    SlhDsaPubKeyParam out;
    int32_t ret = PubKeyParamCheck(ctx, para, &out);
    if (ret != CRYPT_SUCCESS) {
        BSL_ERR_PUSH_ERROR(ret);
        return ret;
    }
    const uint32_t n = ctx->para.n;
    out.pubRoot->useLen = n;
    out.pubSeed->useLen = n;
    (void)memcpy_s(out.pubSeed->value, out.pubSeed->valueLen, ctx->prvKey.pub.seed, n);
    (void)memcpy_s(out.pubRoot->value, out.pubRoot->valueLen, ctx->prvKey.pub.root, n);
    return CRYPT_SUCCESS;
}
756 
/*
 * Exports the full private key (SK.seed, SK.prf, PK.seed, PK.root) into the matching entries of
 * para. Each destination must have valueLen == para.n. Its useLen is set to n on success.
 */
int32_t CRYPT_SLH_DSA_GetPrvKey(const CryptSlhDsaCtx *ctx, BSL_Param *para)
{
    SlhDsaPrvKeyParam out;
    int32_t ret = PrvKeyParamCheck(ctx, para, &out);
    if (ret != CRYPT_SUCCESS) {
        BSL_ERR_PUSH_ERROR(ret);
        return ret;
    }

    const uint32_t n = ctx->para.n;
    struct {
        BSL_Param *dst;
        const uint8_t *src;
    } copies[] = {
        {out.prvSeed, ctx->prvKey.seed},
        {out.prvPrf, ctx->prvKey.prf},
        {out.pubSeed, ctx->prvKey.pub.seed},
        {out.pubRoot, ctx->prvKey.pub.root},
    };
    for (size_t i = 0; i < sizeof(copies) / sizeof(copies[0]); i++) {
        copies[i].dst->useLen = n;
        (void)memcpy_s(copies[i].dst->value, copies[i].dst->valueLen, copies[i].src, n);
    }
    return CRYPT_SUCCESS;
}
777 
/*
 * Imports a public key (PK.seed, PK.root) from para into ctx. Both entries must have
 * valueLen == para.n, so a parameter set must already be selected.
 */
int32_t CRYPT_SLH_DSA_SetPubKey(CryptSlhDsaCtx *ctx, const BSL_Param *para)
{
    SlhDsaPubKeyParam in;
    // PubKeyParamCheck only reads through para; the cast drops const to match its signature.
    int32_t ret = PubKeyParamCheck(ctx, (BSL_Param *)(uintptr_t)para, &in);
    if (ret != CRYPT_SUCCESS) {
        BSL_ERR_PUSH_ERROR(ret);
        return ret;
    }
    const uint32_t n = ctx->para.n;
    (void)memcpy_s(ctx->prvKey.pub.seed, n, in.pubSeed->value, n);
    (void)memcpy_s(ctx->prvKey.pub.root, n, in.pubRoot->value, n);
    return CRYPT_SUCCESS;
}
791 
/*
 * Imports a full private key (SK.seed, SK.prf, PK.seed, PK.root) from para into ctx.
 * All entries must have valueLen == para.n, so a parameter set must already be selected.
 */
int32_t CRYPT_SLH_DSA_SetPrvKey(CryptSlhDsaCtx *ctx, const BSL_Param *para)
{
    SlhDsaPrvKeyParam in;
    // PrvKeyParamCheck only reads through para; the cast drops const to match its signature.
    int32_t ret = PrvKeyParamCheck(ctx, (BSL_Param *)(uintptr_t)para, &in);
    if (ret != CRYPT_SUCCESS) {
        BSL_ERR_PUSH_ERROR(ret);
        return ret;
    }

    struct {
        uint8_t *dst;
        size_t dstMax;
        const BSL_Param *src;
    } copies[] = {
        {ctx->prvKey.seed, sizeof(ctx->prvKey.seed), in.prvSeed},
        {ctx->prvKey.prf, sizeof(ctx->prvKey.prf), in.prvPrf},
        {ctx->prvKey.pub.seed, sizeof(ctx->prvKey.pub.seed), in.pubSeed},
        {ctx->prvKey.pub.root, sizeof(ctx->prvKey.pub.root), in.pubRoot},
    };
    for (size_t i = 0; i < sizeof(copies) / sizeof(copies[0]); i++) {
        (void)memcpy_s(copies[i].dst, copies[i].dstMax, copies[i].src->value, ctx->para.n);
    }
    return CRYPT_SUCCESS;
}
808 
809 #endif // HITLS_CRYPTO_SLH_DSA