1 /*
2 * This file is part of the openHiTLS project.
3 *
4 * openHiTLS is licensed under the Mulan PSL v2.
5 * You can use this software according to the terms and conditions of the Mulan PSL v2.
6 * You may obtain a copy of Mulan PSL v2 at:
7 *
8 * http://license.coscl.org.cn/MulanPSL2
9 *
10 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
11 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
12 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
13 * See the Mulan PSL v2 for more details.
14 */
15
16 #include "hitls_build.h"
17 #ifdef HITLS_CRYPTO_SLH_DSA
18
19 #include <stdint.h>
20 #include <stddef.h>
21 #include "securec.h"
22 #include "bsl_err_internal.h"
23 #include "bsl_sal.h"
24 #include "crypt_errno.h"
25 #include "slh_dsa_local.h"
26 #include "slh_dsa_fors.h"
27
/*
 * ForsSign: build a FORS signature over the message digest md.
 *
 * BaseB turns md into k leaf indices, each presumably a bits wide (FIPS 205 base_2b; assumed).
 * For every tree t the output receives:
 *   - the secret value of leaf (t << a) + index[t], n bytes, produced by ForsGenPrvKey;
 *   - a authentication-path nodes, n bytes each, produced by ForsNode.
 *
 * @param md      digest handed to BaseB
 * @param mdLen   length of md in bytes
 * @param adrs    address structure; its tree height/index fields are updated by ForsNode
 * @param ctx     SLH-DSA context providing n, a, k, address ops and hash functions
 * @param sig     output buffer
 * @param sigLen  in: capacity of sig; out: number of bytes written (updated only on success)
 * @return 0 on success;
 *         CRYPT_SLHDSA_ERR_SIG_LEN_NOT_ENOUGH if *sigLen < (a + 1) * n * k;
 *         BSL_MALLOC_FAIL if the index buffer cannot be allocated;
 *         CRYPT_SLHDSA_ERR_INVALID_SIG_LEN if k == 0 (no tree is processed);
 *         otherwise the first non-zero error from ForsGenPrvKey / ForsNode.
 */
int32_t ForsSign(const uint8_t *md, uint32_t mdLen, SlhDsaAdrs *adrs, const CryptSlhDsaCtx *ctx, uint8_t *sig,
    uint32_t *sigLen)
{
    const uint32_t n = ctx->para.n;
    const uint32_t a = ctx->para.a;
    const uint32_t k = ctx->para.k;
    /* Bytes contributed by one tree: the revealed secret value plus a auth-path nodes. */
    const uint32_t treeSigLen = (a + 1) * n;

    if (*sigLen < treeSigLen * k) {
        return CRYPT_SLHDSA_ERR_SIG_LEN_NOT_ENOUGH;
    }

    uint32_t *leafIdx = (uint32_t *)BSL_SAL_Malloc(k * sizeof(uint32_t));
    if (leafIdx == NULL) {
        return BSL_MALLOC_FAIL;
    }
    BaseB(md, mdLen, a, leafIdx, k);

    /* Stays at this value only when k == 0 and the loop body never runs. */
    int32_t ret = CRYPT_SLHDSA_ERR_INVALID_SIG_LEN;
    uint8_t *cursor = sig;
    for (uint32_t t = 0; t < k; t++) {
        /* Reveal the secret value of the selected leaf in tree t. */
        ret = ForsGenPrvKey(adrs, (t << a) + leafIdx[t], ctx, cursor);
        if (ret != 0) {
            goto EXIT;
        }
        cursor += n;

        /* At each level, emit the sibling of the node on the leaf-to-root path. */
        for (uint32_t lvl = 0; lvl < a; lvl++) {
            uint32_t sibling = (leafIdx[t] >> lvl) ^ 1;
            ret = ForsNode((t << (a - lvl)) + sibling, lvl, adrs, ctx, cursor);
            if (ret != 0) {
                goto EXIT;
            }
            cursor += n;
        }
    }
    *sigLen = (uint32_t)(cursor - sig);

EXIT:
    BSL_SAL_Free(leafIdx);
    return ret;
}
67
/*
 * ForsPkFromSig: recompute the FORS public key implied by a FORS signature.
 *
 * For each of the k trees, the revealed secret value is hashed with F into its leaf. The leaf is
 * then combined with its a authentication-path nodes using H, level by level, to reach the tree
 * root. The k roots are compressed with ctx->hashFuncs.tl under a FORS_ROOTS address to form pk.
 *
 * @param sig     FORS signature: k records of (a + 1) * n bytes (secret value, then auth path)
 * @param sigLen  length of sig in bytes
 * @param md      digest handed to BaseB to obtain the k leaf indices
 * @param mdLen   length of md in bytes
 * @param adrs    address structure; its tree height/index fields are overwritten during the walk
 * @param ctx     SLH-DSA context providing n, a, k, address ops and hash functions
 * @param pk      output for the recomputed FORS public key
 * @return 0 on success;
 *         CRYPT_SLHDSA_ERR_SIG_LEN_NOT_ENOUGH if sigLen < (a + 1) * n * k;
 *         BSL_MALLOC_FAIL on allocation failure;
 *         otherwise the first non-zero error from F / H / tl.
 */
int32_t ForsPkFromSig(const uint8_t *sig, uint32_t sigLen, const uint8_t *md, uint32_t mdLen, SlhDsaAdrs *adrs,
    const CryptSlhDsaCtx *ctx, uint8_t *pk)
{
    const uint32_t n = ctx->para.n;
    const uint32_t a = ctx->para.a;
    const uint32_t k = ctx->para.k;
    /* Bytes per tree in the signature: one revealed secret value followed by a auth-path nodes. */
    const uint32_t treeSigLen = (a + 1) * n;

    if (sigLen < treeSigLen * k) {
        return CRYPT_SLHDSA_ERR_SIG_LEN_NOT_ENOUGH;
    }

    int32_t ret = BSL_MALLOC_FAIL;
    uint8_t *roots = NULL;
    uint8_t cur[SLH_DSA_MAX_N] = {0};
    uint8_t next[SLH_DSA_MAX_N] = {0};
    SlhDsaAdrs rootsAdrs;

    uint32_t *leafIdx = (uint32_t *)BSL_SAL_Malloc(k * sizeof(uint32_t));
    if (leafIdx == NULL) {
        goto EXIT;
    }
    roots = (uint8_t *)BSL_SAL_Malloc(n * k);
    if (roots == NULL) {
        goto EXIT;
    }

    BaseB(md, mdLen, a, leafIdx, k);

    for (uint32_t t = 0; t < k; t++) {
        const uint8_t *treeSig = sig + treeSigLen * t;
        const uint8_t *authPath = treeSig + n;

        /* Leaf node: F over the revealed secret value, addressed at height 0. */
        ctx->adrsOps.setTreeHeight(adrs, 0);
        ctx->adrsOps.setTreeIndex(adrs, (t << a) + leafIdx[t]);
        ret = ctx->hashFuncs.f(ctx, adrs, treeSig, n, cur);
        if (ret != 0) {
            goto EXIT;
        }

        for (uint32_t lvl = 0; lvl < a; lvl++) {
            uint8_t pair[SLH_DSA_MAX_N * 2];
            const uint8_t *sibling = authPath + lvl * n;
            const uint8_t *lhs = cur;
            const uint8_t *rhs = sibling;

            ctx->adrsOps.setTreeHeight(adrs, lvl + 1);
            if (((leafIdx[t] >> lvl) & 1) != 0) {
                /* Current node is a right child: the sibling is hashed first. */
                ctx->adrsOps.setTreeIndex(adrs, (ctx->adrsOps.getTreeIndex(adrs) - 1) >> 1);
                lhs = sibling;
                rhs = cur;
            } else {
                ctx->adrsOps.setTreeIndex(adrs, ctx->adrsOps.getTreeIndex(adrs) >> 1);
            }
            (void)memcpy_s(pair, sizeof(pair), lhs, n);
            (void)memcpy_s(pair + n, sizeof(pair) - n, rhs, n);

            ret = ctx->hashFuncs.h(ctx, adrs, pair, 2 * n, next);
            if (ret != 0) {
                goto EXIT;
            }
            (void)memcpy_s(cur, sizeof(cur), next, sizeof(next));
        }
        (void)memcpy_s(roots + t * n, (k - t) * n, cur, n);
    }

    /* Compress all roots under a FORS_ROOTS address that keeps the caller's key-pair address. */
    rootsAdrs = *adrs;
    ctx->adrsOps.setType(&rootsAdrs, FORS_ROOTS);
    ctx->adrsOps.copyKeyPairAddr(&rootsAdrs, adrs);
    ret = ctx->hashFuncs.tl(ctx, &rootsAdrs, roots, n * k, pk);

EXIT:
    BSL_SAL_Free(leafIdx);
    BSL_SAL_Free(roots);
    return ret;
}
143
ForsGenPrvKey(const SlhDsaAdrs * adrs,uint32_t idx,const CryptSlhDsaCtx * ctx,uint8_t * sk)144 int32_t ForsGenPrvKey(const SlhDsaAdrs *adrs, uint32_t idx, const CryptSlhDsaCtx *ctx, uint8_t *sk)
145 {
146 SlhDsaAdrs skadrs = *adrs;
147 ctx->adrsOps.setType(&skadrs, FORS_PRF);
148 ctx->adrsOps.copyKeyPairAddr(&skadrs, adrs);
149 ctx->adrsOps.setTreeIndex(&skadrs, idx);
150
151 return ctx->hashFuncs.prf(ctx, &skadrs, sk);
152 }
153
/*
 * ForsNode: compute the FORS tree node at position idx and level height into node.
 *
 * height == 0: node = F(secret value of leaf idx), with the address set to (height 0, index idx).
 * height  > 0: the children (2*idx, height-1) and (2*idx+1, height-1) are computed recursively,
 *              and node = H(left || right), with the address set to (height, idx).
 * The recursion depth equals height.
 *
 * Side effect: the tree height/index fields of adrs are left at (height, idx) on success.
 * The leaf secret value is wiped from the stack before returning. Only one leaf per tree is
 * ever revealed in a signature; every other leaf's secret must not linger in memory.
 *
 * @return 0 on success, or the first non-zero error from ForsGenPrvKey / F / H.
 */
int32_t ForsNode(uint32_t idx, uint32_t height, SlhDsaAdrs *adrs, const CryptSlhDsaCtx *ctx, uint8_t *node)
{
    int32_t ret;
    uint32_t n = ctx->para.n;

    if (height == 0) {
        uint8_t sk[SLH_DSA_MAX_N] = {0};
        ret = ForsGenPrvKey(adrs, idx, ctx, sk);
        if (ret == 0) {
            ctx->adrsOps.setTreeHeight(adrs, height);
            ctx->adrsOps.setTreeIndex(adrs, idx);
            ret = ctx->hashFuncs.f(ctx, adrs, sk, n, node);
        }
        /*
         * Wipe the secret value on every path (success or error). Volatile stores keep the
         * compiler from eliding this as a dead store (CERT MSC06-C).
         */
        volatile uint8_t *wipe = sk;
        for (size_t i = 0; i < sizeof(sk); i++) {
            wipe[i] = 0;
        }
        return ret;
    }

    /* Left child fills the first n bytes, right child the next n bytes. */
    uint8_t dnode[SLH_DSA_MAX_N * 2];
    ret = ForsNode(idx * 2, height - 1, adrs, ctx, dnode);
    if (ret != 0) {
        return ret;
    }
    ret = ForsNode(idx * 2 + 1, height - 1, adrs, ctx, dnode + n);
    if (ret != 0) {
        return ret;
    }
    /* The children overwrote the address; restore this node's position before hashing. */
    ctx->adrsOps.setTreeHeight(adrs, height);
    ctx->adrsOps.setTreeIndex(adrs, idx);
    return ctx->hashFuncs.h(ctx, adrs, dnode, 2 * n, node);
}
183 #endif // HITLS_CRYPTO_SLH_DSA
184