1 /*
2 * This file is part of the openHiTLS project.
3 *
4 * openHiTLS is licensed under the Mulan PSL v2.
5 * You can use this software according to the terms and conditions of the Mulan PSL v2.
6 * You may obtain a copy of Mulan PSL v2 at:
7 *
8 * http://license.coscl.org.cn/MulanPSL2
9 *
10 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
11 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
12 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
13 * See the Mulan PSL v2 for more details.
14 */
15 #include "hitls_build.h"
16 #ifdef HITLS_CRYPTO_SLH_DSA
17
18 #include <stdint.h>
19 #include <string.h>
20 #include "securec.h"
21 #include "bsl_errno.h"
22 #include "crypt_errno.h"
23 #include "bsl_sal.h"
24 #include "slh_dsa_local.h"
25 #include "slh_dsa_wots.h"
26
MsgToBaseW(const CryptSlhDsaCtx * ctx,const uint8_t * msg,uint32_t msgLen,uint32_t * out)27 static int32_t MsgToBaseW(const CryptSlhDsaCtx *ctx, const uint8_t *msg, uint32_t msgLen, uint32_t *out)
28 {
29 uint32_t n = ctx->para.n;
30 uint32_t len1 = 2 * n;
31 uint32_t len2 = 3;
32
33 BaseB(msg, msgLen, SLH_DSA_LGW, out, len1);
34
35 // todo: check if csum overflow
36 uint64_t csum = 0;
37 for (uint32_t i = 0; i < len1; i++) {
38 csum += SLH_DSA_W - 1 - out[i];
39 }
40 csum <<= SLH_DSA_LGW;
41 uint8_t csumBytes[2];
42 csumBytes[0] = (uint8_t)(csum >> 8);
43 csumBytes[1] = (uint8_t)csum;
44
45 BaseB(csumBytes, 2, SLH_DSA_LGW, out + len1, len2);
46 return 0;
47 }
48
/*
 * WOTS+ chaining function (FIPS 205, Algorithm 5 chain).
 *
 * Applies ctx->hashFuncs.f `end` times to the xLen-byte value x. Before each call the hash
 * address is set to start, start + 1, ..., start + end - 1. The result is written to output.
 * NOTE: despite its name, `end` is the number of steps (s in FIPS 205), not a final index;
 * end == 0 simply copies x to output.
 *
 * x, xLen [in]     chain input; xLen must not exceed SLH_DSA_MAX_N.
 * start   [in]     index of the first chain step.
 * seed    [in]     unused here; presumably F obtains the public seed through ctx (TODO confirm).
 * adrs    [in/out] address whose hash-address field is overwritten.
 * ctx     [in]     SLH-DSA context supplying the address operations and F.
 * output  [out]    receives xLen bytes.
 *
 * Returns 0 on success, CRYPT_INVALID_ARG if xLen exceeds SLH_DSA_MAX_N, or the error from F.
 */
int32_t WotsChain(const uint8_t *x, uint32_t xLen, uint32_t start, uint32_t end, const uint8_t *seed, SlhDsaAdrs *adrs,
    const CryptSlhDsaCtx *ctx, uint8_t *output)
{
    (void)seed;
    uint8_t tmp[SLH_DSA_MAX_N];

    // tmp is a fixed-size stack buffer. A longer input would make F and the final copy
    // access memory past its end (and the initial memcpy_s would fail, leaving tmp uninitialized).
    if (xLen > sizeof(tmp)) {
        return CRYPT_INVALID_ARG;
    }
    (void)memcpy_s(tmp, sizeof(tmp), x, xLen);

    for (uint32_t i = start; i < start + end; i++) {
        ctx->adrsOps.setHashAddr(adrs, i);
        // F is applied in place: tmp is both its input and its output.
        int32_t ret = ctx->hashFuncs.f(ctx, adrs, tmp, xLen, tmp);
        if (ret != 0) {
            return ret;
        }
    }

    (void)memcpy_s(output, xLen, tmp, xLen);
    return 0;
}
69
/*
 * Generate a compressed WOTS+ public key (FIPS 205, Algorithm 6 wots_pkGen).
 *
 * For each of the len = 2n + 3 chains:
 *   1. derive the chain's secret start value with PRF under a WOTS_PRF address;
 *   2. iterate it SLH_DSA_W - 1 times with WotsChain.
 * The len chain ends are then compressed with T_l under a WOTS_PK address.
 *
 * pub  [out]    receives the output of ctx->hashFuncs.tl (presumably n bytes; confirm against tl).
 * adrs [in/out] WOTS hash address; its chain address (and, through WotsChain, hash address) is overwritten.
 * ctx  [in]     SLH-DSA context supplying n, the address operations and the hash functions.
 *
 * Returns 0 on success, BSL_MALLOC_FAIL if the scratch buffer cannot be allocated, or the
 * error returned by PRF / WotsChain / T_l.
 */
int WotsGeneratePublicKey(uint8_t *pub, SlhDsaAdrs *adrs, const CryptSlhDsaCtx *ctx)
{
    int32_t ret;
    uint32_t n = ctx->para.n;
    uint32_t len = 2 * n + 3;
    // Secret chain start value. It is reused for every chain and wiped before returning.
    uint8_t sk[SLH_DSA_MAX_N] = {0};

    SlhDsaAdrs skAdrs = *adrs;
    ctx->adrsOps.setType(&skAdrs, WOTS_PRF);
    ctx->adrsOps.copyKeyPairAddr(&skAdrs, adrs);

    uint8_t *tmp = (uint8_t *)BSL_SAL_Malloc(len * n);
    if (tmp == NULL) {
        return BSL_MALLOC_FAIL;
    }

    for (uint32_t i = 0; i < len; i++) {
        ctx->adrsOps.setChainAddr(&skAdrs, i);
        ret = ctx->hashFuncs.prf(ctx, &skAdrs, sk);
        if (ret != 0) {
            goto ERR;
        }
        ctx->adrsOps.setChainAddr(adrs, i);
        ret = WotsChain(sk, n, 0, SLH_DSA_W - 1, ctx->prvKey.pub.seed, adrs, ctx, tmp + i * n);
        if (ret != 0) {
            goto ERR;
        }
    }

    // Compress the len chain ends into the public key.
    SlhDsaAdrs wotspk = *adrs;
    ctx->adrsOps.setType(&wotspk, WOTS_PK);
    ctx->adrsOps.copyKeyPairAddr(&wotspk, adrs);
    ret = ctx->hashFuncs.tl(ctx, &wotspk, tmp, len * n, pub);

ERR:
    // sk is secret key material; FIPS 205 requires such intermediate values to be destroyed.
    // A plain memset on a dying local could be optimized away, hence the cleanse helper.
    BSL_SAL_CleanseData(sk, (uint32_t)sizeof(sk));
    BSL_SAL_Free(tmp);
    return ret;
}
110
/*
 * Sign a message with WOTS+ (FIPS 205, Algorithm 7 wots_sign).
 *
 * The message is split into len = 2n + 3 base-w digits (including checksum). For each digit,
 * the chain's secret start value is derived with PRF and advanced by that digit with WotsChain.
 *
 * sig    [out]    receives len * n bytes (one n-byte chain value per digit).
 * sigLen [in/out] in: capacity of sig; out: bytes written. It is updated on success only.
 * msg    [in]     message to sign; msgLen is its length in bytes.
 * adrs   [in/out] WOTS hash address; its chain/hash address fields are overwritten.
 * ctx    [in]     SLH-DSA context supplying n, the address operations and the hash functions.
 *
 * Returns 0 on success, CRYPT_BN_BUFF_LEN_NOT_ENOUGH if sig is too small, BSL_MALLOC_FAIL on
 * allocation failure, or the error from MsgToBaseW / PRF / WotsChain.
 */
int32_t WotsSign(uint8_t *sig, uint32_t *sigLen, const uint8_t *msg, uint32_t msgLen, SlhDsaAdrs *adrs,
    const CryptSlhDsaCtx *ctx)
{
    int32_t ret;
    uint32_t n = ctx->para.n;
    uint32_t len = 2 * n + 3;
    // Secret chain start value. It is reused for every chain and wiped before returning.
    uint8_t sk[SLH_DSA_MAX_N] = {0};

    if (*sigLen < len * n) {
        // NOTE(review): this reuses a BN error code, whereas WotsPubKeyFromSig reports
        // CRYPT_SLHDSA_ERR_SIG_LEN_NOT_ENOUGH. It is kept unchanged because callers may test for it.
        return CRYPT_BN_BUFF_LEN_NOT_ENOUGH;
    }

    uint32_t *msgw = (uint32_t *)BSL_SAL_Malloc(len * sizeof(uint32_t));
    if (msgw == NULL) {
        return BSL_MALLOC_FAIL;
    }
    ret = MsgToBaseW(ctx, msg, msgLen, msgw);
    if (ret != 0) {
        goto ERR;
    }

    SlhDsaAdrs skAdrs = *adrs;
    ctx->adrsOps.setType(&skAdrs, WOTS_PRF);
    ctx->adrsOps.copyKeyPairAddr(&skAdrs, adrs);
    for (uint32_t i = 0; i < len; i++) {
        ctx->adrsOps.setChainAddr(&skAdrs, i);
        ret = ctx->hashFuncs.prf(ctx, &skAdrs, sk);
        if (ret != 0) {
            goto ERR;
        }
        ctx->adrsOps.setChainAddr(adrs, i);
        ret = WotsChain(sk, n, 0, msgw[i], ctx->prvKey.pub.seed, adrs, ctx, sig + i * n);
        if (ret != 0) {
            goto ERR;
        }
    }
    // Report the signature length only once the whole signature has been produced,
    // matching the early buffer-too-small return, which also leaves *sigLen untouched.
    *sigLen = len * n;

ERR:
    // sk is secret key material; FIPS 205 requires such intermediate values to be destroyed.
    BSL_SAL_CleanseData(sk, (uint32_t)sizeof(sk));
    BSL_SAL_Free(msgw);
    return ret;
}
152
/*
 * Recompute a compressed WOTS+ public key from a signature (FIPS 205, Algorithm 8 wots_pkFromSig).
 *
 * Each n-byte signature element was produced by advancing its chain by one message digit.
 * Finishing the chain up to step SLH_DSA_W - 1 yields the chain end. The len = 2n + 3
 * chain ends are then compressed with T_l under a WOTS_PK address.
 *
 * msg, msgLen [in]     signed message and its length in bytes.
 * sig, sigLen [in]     signature; at least len * n bytes are required.
 * adrs        [in/out] WOTS hash address; its chain/hash address fields are overwritten.
 * ctx         [in]     SLH-DSA context supplying n, the address operations and the hash functions.
 * pub         [out]    receives the output of ctx->hashFuncs.tl.
 *
 * Returns 0 on success, CRYPT_SLHDSA_ERR_SIG_LEN_NOT_ENOUGH if sig is too short,
 * BSL_MALLOC_FAIL on allocation failure, or the error from MsgToBaseW / WotsChain / T_l.
 */
int32_t WotsPubKeyFromSig(const uint8_t *msg, uint32_t msgLen, const uint8_t *sig, uint32_t sigLen, SlhDsaAdrs *adrs,
    const CryptSlhDsaCtx *ctx, uint8_t *pub)
{
    const uint32_t n = ctx->para.n;
    const uint32_t chainCount = 2 * n + 3;
    const uint32_t totalBytes = chainCount * n;
    uint32_t *digits = NULL;
    uint8_t *chainEnds = NULL;
    int32_t status;

    if (sigLen < totalBytes) {
        return CRYPT_SLHDSA_ERR_SIG_LEN_NOT_ENOUGH;
    }

    digits = (uint32_t *)BSL_SAL_Malloc(chainCount * sizeof(uint32_t));
    if (digits == NULL) {
        return BSL_MALLOC_FAIL;
    }

    status = MsgToBaseW(ctx, msg, msgLen, digits);
    if (status == 0) {
        chainEnds = (uint8_t *)BSL_SAL_Malloc(totalBytes);
        status = (chainEnds == NULL) ? BSL_MALLOC_FAIL : 0;
    }

    // Advance every chain from the digit it was signed at up to the top step (w - 1).
    for (uint32_t chain = 0; status == 0 && chain < chainCount; chain++) {
        const uint32_t offset = chain * n;
        const uint32_t from = digits[chain];
        ctx->adrsOps.setChainAddr(adrs, chain);
        status = WotsChain(sig + offset, n, from, SLH_DSA_W - 1 - from, ctx->prvKey.pub.seed, adrs, ctx,
            chainEnds + offset);
    }

    if (status == 0) {
        SlhDsaAdrs pkAdrs = *adrs;
        ctx->adrsOps.setType(&pkAdrs, WOTS_PK);
        ctx->adrsOps.copyKeyPairAddr(&pkAdrs, adrs);
        status = ctx->hashFuncs.tl(ctx, &pkAdrs, chainEnds, totalBytes, pub);
    }

    BSL_SAL_Free(digits);
    if (chainEnds != NULL) {
        BSL_SAL_Free(chainEnds);
    }
    return status;
}
199
200 #endif // HITLS_CRYPTO_SLH_DSA
201