• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * This file is part of the openHiTLS project.
3  *
4  * openHiTLS is licensed under the Mulan PSL v2.
5  * You can use this software according to the terms and conditions of the Mulan PSL v2.
6  * You may obtain a copy of Mulan PSL v2 at:
7  *
8  *     http://license.coscl.org.cn/MulanPSL2
9  *
10  * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
11  * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
12  * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
13  * See the Mulan PSL v2 for more details.
14  */
15 
16 #include "hitls_build.h"
17 #ifdef HITLS_CRYPTO_SLH_DSA
18 
19 #include <stdint.h>
20 #include <stddef.h>
21 #include "securec.h"
22 #include "bsl_err_internal.h"
23 #include "bsl_sal.h"
24 #include "crypt_errno.h"
25 #include "slh_dsa_local.h"
26 #include "slh_dsa_fors.h"
27 
/*
 * Sign one tree of the FORS forest.
 *
 * Writes the leaf secret selected by `leafIdx` in tree `treeIdx`, followed by the a authentication-path
 * nodes for that leaf. The output is (a + 1) * n bytes at `out`.
 * Returns 0 when every step succeeds, otherwise the first non-zero code from ForsGenPrvKey/ForsNode.
 */
static int32_t ForsSignTree(uint32_t treeIdx, uint32_t leafIdx, SlhDsaAdrs *adrs, const CryptSlhDsaCtx *ctx,
                            uint8_t *out)
{
    uint32_t n = ctx->para.n;
    uint32_t a = ctx->para.a;

    /* Leaves of tree t occupy the global FORS leaf range [t * 2^a, (t + 1) * 2^a). */
    int32_t ret = ForsGenPrvKey(adrs, leafIdx + (treeIdx << a), ctx, out);
    for (uint32_t height = 0; ret == 0 && height < a; height++) {
        out += n;
        /* Sibling of the path node at this height; its global index is offset by the tree's base at that height. */
        uint32_t sibling = (leafIdx >> height) ^ 1;
        ret = ForsNode((treeIdx << (a - height)) + sibling, height, adrs, ctx, out);
    }
    return ret;
}

/*
 * Produce a FORS signature over the message digest `md` (mirrors fors_sign in FIPS 205).
 *
 * md/mdLen  Digest handed to BaseB, which fills one a-bit index per tree (k indices) - see BaseB.
 * adrs      FORS address; ForsNode updates its tree height/index while building authentication paths.
 * ctx       Supplies n, a, k and the address/hash callbacks.
 * sig       Output: k consecutive tree segments of (a + 1) * n bytes each.
 * sigLen    In: capacity of sig. Out: k * (a + 1) * n, written only when every tree has been signed.
 *
 * Returns CRYPT_SLHDSA_ERR_SIG_LEN_NOT_ENOUGH if the capacity is below k * (a + 1) * n, and BSL_MALLOC_FAIL if
 * the index buffer cannot be allocated. Otherwise it returns the first non-zero code from ForsGenPrvKey or
 * ForsNode. With k == 0 nothing is signed: CRYPT_SLHDSA_ERR_INVALID_SIG_LEN is returned and *sigLen becomes 0.
 */
int32_t ForsSign(const uint8_t *md, uint32_t mdLen, SlhDsaAdrs *adrs, const CryptSlhDsaCtx *ctx, uint8_t *sig,
                 uint32_t *sigLen)
{
    const uint32_t n = ctx->para.n;
    const uint32_t a = ctx->para.a;
    const uint32_t k = ctx->para.k;
    const uint32_t treeBytes = (a + 1) * n;

    if (treeBytes * k > *sigLen) {
        return CRYPT_SLHDSA_ERR_SIG_LEN_NOT_ENOUGH;
    }

    uint32_t *indices = (uint32_t *)BSL_SAL_Malloc(k * sizeof(uint32_t));
    if (indices == NULL) {
        return BSL_MALLOC_FAIL;
    }
    BaseB(md, mdLen, a, indices, k);

    /* Stays as this code only when there is no tree to sign (k == 0). */
    int32_t ret = CRYPT_SLHDSA_ERR_INVALID_SIG_LEN;
    uint32_t tree;
    for (tree = 0; tree < k; tree++) {
        ret = ForsSignTree(tree, indices[tree], adrs, ctx, sig + tree * treeBytes);
        if (ret != 0) {
            break;
        }
    }
    if (tree == k) {
        /* Every tree was emitted: report the number of bytes written. */
        *sigLen = k * treeBytes;
    }

    BSL_SAL_Free(indices);
    return ret;
}
67 
/*
 * Recompute the root of FORS tree `treeIdx` from its signature segment.
 *
 * `treeSig` holds an n-byte leaf secret followed by a authentication-path nodes. `leafIdx` is the
 * leaf selected in that tree.
 * `node` is an SLH_DSA_MAX_N-byte buffer that the F/H callbacks write into. On success its first n bytes
 * hold the tree root.
 * The tree height/index of `adrs` are updated level by level while climbing to the root.
 * Returns the code of the last hash call, or the first non-zero one.
 */
static int32_t ForsRootFromAuthPath(const uint8_t *treeSig, uint32_t treeIdx, uint32_t leafIdx, SlhDsaAdrs *adrs,
                                    const CryptSlhDsaCtx *ctx, uint8_t *node)
{
    uint32_t n = ctx->para.n;
    uint32_t a = ctx->para.a;
    const uint8_t *authPath = treeSig + n;

    /* Leaf: hash the revealed secret at its global leaf index. */
    ctx->adrsOps.setTreeHeight(adrs, 0);
    ctx->adrsOps.setTreeIndex(adrs, (treeIdx << a) + leafIdx);
    int32_t ret = ctx->hashFuncs.f(ctx, adrs, treeSig, n, node);

    for (uint32_t height = 0; ret == 0 && height < a; height++) {
        uint8_t pair[SLH_DSA_MAX_N * 2];
        const uint8_t *sibling = authPath + height * n;
        const uint8_t *left = node;
        const uint8_t *right = sibling;

        ctx->adrsOps.setTreeHeight(adrs, height + 1);
        if (((leafIdx >> height) & 1) != 0) {
            /* Current node is a right child: the sibling goes first. */
            ctx->adrsOps.setTreeIndex(adrs, (ctx->adrsOps.getTreeIndex(adrs) - 1) >> 1);
            left = sibling;
            right = node;
        } else {
            ctx->adrsOps.setTreeIndex(adrs, ctx->adrsOps.getTreeIndex(adrs) >> 1);
        }
        (void)memcpy_s(pair, sizeof(pair), left, n);
        (void)memcpy_s(pair + n, sizeof(pair) - n, right, n);

        /* pair already holds the children, so the parent can overwrite node directly. */
        ret = ctx->hashFuncs.h(ctx, adrs, pair, 2 * n, node);
    }
    return ret;
}

/*
 * Recompute a FORS public key from a FORS signature and message digest (mirrors fors_pkFromSig in FIPS 205).
 *
 * sig/sigLen  FORS signature: k segments of (a + 1) * n bytes (leaf secret followed by authentication path).
 * md/mdLen    Digest handed to BaseB, which fills one a-bit index per tree (k indices) - see BaseB.
 * adrs        FORS address; its tree height/index are overwritten while roots are rebuilt.
 * pk          Output of the T_l callback over the k concatenated roots (presumably n bytes - confirm with tl).
 *
 * Returns CRYPT_SLHDSA_ERR_SIG_LEN_NOT_ENOUGH when sigLen is below k * (a + 1) * n, BSL_MALLOC_FAIL if a scratch
 * buffer cannot be allocated, otherwise the first non-zero hash-callback code or the result of tl.
 */
int32_t ForsPkFromSig(const uint8_t *sig, uint32_t sigLen, const uint8_t *md, uint32_t mdLen, SlhDsaAdrs *adrs,
                      const CryptSlhDsaCtx *ctx, uint8_t *pk)
{
    const uint32_t n = ctx->para.n;
    const uint32_t a = ctx->para.a;
    const uint32_t k = ctx->para.k;
    const uint32_t treeSigLen = (a + 1) * n;

    if (sigLen < treeSigLen * k) {
        return CRYPT_SLHDSA_ERR_SIG_LEN_NOT_ENOUGH;
    }

    int32_t ret = BSL_MALLOC_FAIL;
    uint8_t *roots = NULL;
    uint8_t node[SLH_DSA_MAX_N] = {0};
    SlhDsaAdrs rootsAdrs;

    uint32_t *indices = (uint32_t *)BSL_SAL_Malloc(k * sizeof(uint32_t));
    if (indices == NULL) {
        goto EXIT;
    }
    roots = (uint8_t *)BSL_SAL_Malloc(n * k);
    if (roots == NULL) {
        goto EXIT;
    }

    BaseB(md, mdLen, a, indices, k);

    for (uint32_t tree = 0; tree < k; tree++) {
        ret = ForsRootFromAuthPath(sig + tree * treeSigLen, tree, indices[tree], adrs, ctx, node);
        if (ret != 0) {
            goto EXIT;
        }
        (void)memcpy_s(roots + tree * n, (k - tree) * n, node, n);
    }

    /* Compress all roots under a FORS_ROOTS address that keeps the key-pair address of adrs. */
    rootsAdrs = *adrs;
    ctx->adrsOps.setType(&rootsAdrs, FORS_ROOTS);
    ctx->adrsOps.copyKeyPairAddr(&rootsAdrs, adrs);
    ret = ctx->hashFuncs.tl(ctx, &rootsAdrs, roots, n * k, pk);

EXIT:
    BSL_SAL_Free(indices);
    BSL_SAL_Free(roots);
    return ret;
}
143 
ForsGenPrvKey(const SlhDsaAdrs * adrs,uint32_t idx,const CryptSlhDsaCtx * ctx,uint8_t * sk)144 int32_t ForsGenPrvKey(const SlhDsaAdrs *adrs, uint32_t idx, const CryptSlhDsaCtx *ctx, uint8_t *sk)
145 {
146     SlhDsaAdrs skadrs = *adrs;
147     ctx->adrsOps.setType(&skadrs, FORS_PRF);
148     ctx->adrsOps.copyKeyPairAddr(&skadrs, adrs);
149     ctx->adrsOps.setTreeIndex(&skadrs, idx);
150 
151     return ctx->hashFuncs.prf(ctx, &skadrs, sk);
152 }
153 
/*
 * Compute the FORS tree node at height `height` and global index `idx` (mirrors fors_node in FIPS 205).
 *
 * A leaf (height 0) is F applied to the secret value from ForsGenPrvKey. An inner node is H applied
 * to its two children, which are computed recursively (recursion depth equals `height`).
 *
 * adrs  FORS address. The recursive calls overwrite its tree height/index. On a successful return they
 *       are (height, idx).
 * node  Output buffer passed to the final F/H callback; receives the n-byte node value.
 *
 * Returns the code of the final hash call, or the first non-zero code from an earlier step.
 */
int32_t ForsNode(uint32_t idx, uint32_t height, SlhDsaAdrs *adrs, const CryptSlhDsaCtx *ctx, uint8_t *node)
{
    int32_t ret;
    uint32_t n = ctx->para.n;

    if (height == 0) {
        /*
         * The leaf secret must stay secret for every leaf not revealed in a signature. Wipe the stack copy
         * on every exit path so no key material is left behind.
         */
        uint8_t sk[SLH_DSA_MAX_N] = {0};
        ret = ForsGenPrvKey(adrs, idx, ctx, sk);
        if (ret == 0) {
            ctx->adrsOps.setTreeHeight(adrs, height);
            ctx->adrsOps.setTreeIndex(adrs, idx);
            ret = ctx->hashFuncs.f(ctx, adrs, sk, n, node);
        }
        BSL_SAL_CleanseData(sk, (uint32_t)sizeof(sk));
        return ret;
    }

    /* Children are 2 * idx (left) and 2 * idx + 1 (right), stored back to back for H. */
    uint8_t dnode[SLH_DSA_MAX_N * 2];
    ret = ForsNode(idx * 2, height - 1, adrs, ctx, dnode);
    if (ret != 0) {
        return ret;
    }
    ret = ForsNode(idx * 2 + 1, height - 1, adrs, ctx, dnode + n);
    if (ret != 0) {
        return ret;
    }
    /* The recursion clobbered adrs; point it back at this node before hashing. */
    ctx->adrsOps.setTreeHeight(adrs, height);
    ctx->adrsOps.setTreeIndex(adrs, idx);
    return ctx->hashFuncs.h(ctx, adrs, dnode, 2 * n, node);
}
183 #endif // HITLS_CRYPTO_SLH_DSA
184