• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * This file is part of the openHiTLS project.
3  *
4  * openHiTLS is licensed under the Mulan PSL v2.
5  * You can use this software according to the terms and conditions of the Mulan PSL v2.
6  * You may obtain a copy of Mulan PSL v2 at:
7  *
8  *     http://license.coscl.org.cn/MulanPSL2
9  *
10  * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
11  * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
12  * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
13  * See the Mulan PSL v2 for more details.
14  */
15 #include "hitls_build.h"
16 #ifdef HITLS_CRYPTO_SLH_DSA
17 
18 #include <stdint.h>
19 #include <string.h>
20 #include "securec.h"
21 #include "bsl_errno.h"
22 #include "crypt_errno.h"
23 #include "bsl_sal.h"
24 #include "slh_dsa_local.h"
25 #include "slh_dsa_wots.h"
26 
MsgToBaseW(const CryptSlhDsaCtx * ctx,const uint8_t * msg,uint32_t msgLen,uint32_t * out)27 static int32_t MsgToBaseW(const CryptSlhDsaCtx *ctx, const uint8_t *msg, uint32_t msgLen, uint32_t *out)
28 {
29     uint32_t n = ctx->para.n;
30     uint32_t len1 = 2 * n;
31     uint32_t len2 = 3;
32 
33     BaseB(msg, msgLen, SLH_DSA_LGW, out, len1);
34 
35     // todo: check if csum overflow
36     uint64_t csum = 0;
37     for (uint32_t i = 0; i < len1; i++) {
38         csum += SLH_DSA_W - 1 - out[i];
39     }
40     csum <<= SLH_DSA_LGW;
41     uint8_t csumBytes[2];
42     csumBytes[0] = (uint8_t)(csum >> 8);
43     csumBytes[1] = (uint8_t)csum;
44 
45     BaseB(csumBytes, 2, SLH_DSA_LGW, out + len1, len2);
46     return 0;
47 }
48 
/*
 * Run the WOTS+ chaining function over a value.
 *
 * Starting from x, ctx->hashFuncs.f is applied once for every hash address i in
 * [start, start + end), each call consuming the previous result, and the final value is
 * copied to output. Note that `end` is the number of steps to perform, not the final index.
 *
 * x       starting value of the chain
 * xLen    length of x in bytes; must not exceed SLH_DSA_MAX_N (size of the local working buffer)
 * start   hash address of the first step
 * end     number of steps to perform
 * seed    unused by this function
 * adrs    address structure; its hash address is overwritten for every step
 * ctx     provides the address operations and the function F
 * output  receives xLen bytes
 *
 * Returns 0 on success, CRYPT_SECUREC_FAIL if x cannot be staged in the working buffer
 * (NULL pointer or xLen larger than SLH_DSA_MAX_N), or the error reported by F.
 */
int32_t WotsChain(const uint8_t *x, uint32_t xLen, uint32_t start, uint32_t end, const uint8_t *seed, SlhDsaAdrs *adrs,
                  const CryptSlhDsaCtx *ctx, uint8_t *output)
{
    (void)seed;
    uint8_t tmp[SLH_DSA_MAX_N];

    // A failed copy must stop here: continuing would hash xLen bytes of a buffer that
    // does not hold the input (and may be shorter than xLen).
    if (memcpy_s(tmp, sizeof(tmp), x, xLen) != EOK) {
        return CRYPT_SECUREC_FAIL;
    }

    for (uint32_t i = start; i < start + end; i++) {
        ctx->adrsOps.setHashAddr(adrs, i);
        // F reads from and writes to the same working buffer.
        int32_t ret = ctx->hashFuncs.f(ctx, adrs, tmp, xLen, tmp);
        if (ret != 0) {
            return ret;
        }
    }

    (void)memcpy_s(output, xLen, tmp, xLen);
    return 0;
}
69 
/*
 * Compute a compressed WOTS+ public key.
 *
 * For each of the 2 * n + 3 chains, a per-chain secret value is obtained from
 * ctx->hashFuncs.prf (using a copy of adrs retyped to WOTS_PRF) and chained for
 * SLH_DSA_W - 1 steps with WotsChain(). The concatenated chain ends are then compressed
 * into pub by ctx->hashFuncs.tl under a copy of adrs retyped to WOTS_PK.
 *
 * pub   receives the output of ctx->hashFuncs.tl
 * adrs  address of the key pair; its chain address (and, through WotsChain, hash address)
 *       is modified
 * ctx   provides n, the address operations, the hash functions and the public seed
 *
 * Returns 0 on success, BSL_MALLOC_FAIL if the working buffer cannot be allocated, or the
 * error reported by PRF, WotsChain or T_l.
 */
int WotsGeneratePublicKey(uint8_t *pub, SlhDsaAdrs *adrs, const CryptSlhDsaCtx *ctx)
{
    int32_t ret = 0;
    uint32_t n = ctx->para.n;
    uint32_t len = 2 * n + 3;
    uint8_t sk[SLH_DSA_MAX_N] = {0};

    SlhDsaAdrs skAdrs = *adrs;
    ctx->adrsOps.setType(&skAdrs, WOTS_PRF);
    ctx->adrsOps.copyKeyPairAddr(&skAdrs, adrs);

    uint8_t *tmp = (uint8_t *)BSL_SAL_Malloc(len * n);
    if (tmp == NULL) {
        return BSL_MALLOC_FAIL;
    }

    for (uint32_t i = 0; i < len; i++) {
        ctx->adrsOps.setChainAddr(&skAdrs, i);
        ret = ctx->hashFuncs.prf(ctx, &skAdrs, sk);
        if (ret == 0) {
            ctx->adrsOps.setChainAddr(adrs, i);
            ret = WotsChain(sk, n, 0, SLH_DSA_W - 1, ctx->prvKey.pub.seed, adrs, ctx, (tmp + i * n));
        }
        // The chain secret must not outlive its use, on success or failure. Wiping it also
        // leaves the buffer zeroed for the next iteration.
        BSL_SAL_CleanseData(sk, sizeof(sk));
        if (ret != 0) {
            goto ERR;
        }
    }

    // compress public key
    SlhDsaAdrs wotspk = *adrs;
    ctx->adrsOps.setType(&wotspk, WOTS_PK);
    ctx->adrsOps.copyKeyPairAddr(&wotspk, adrs);

    ret = ctx->hashFuncs.tl(ctx, &wotspk, tmp, len * n, pub);

ERR:
    BSL_SAL_Free(tmp);
    return ret;
}
110 
/*
 * Produce a WOTS+ signature over msg.
 *
 * msg is expanded into 2 * n + 3 digits by MsgToBaseW(). For chain i, a per-chain secret
 * value is obtained from ctx->hashFuncs.prf (using a copy of adrs retyped to WOTS_PRF) and
 * chained for msgw[i] steps with WotsChain(); the result is written to sig + i * n.
 *
 * sig     receives (2 * n + 3) * n bytes
 * sigLen  in: capacity of sig in bytes; out: number of bytes written (only updated on success)
 * msg     message bytes
 * msgLen  length of msg in bytes
 * adrs    address of the key pair; its chain address (and, through WotsChain, hash address)
 *         is modified
 * ctx     provides n, the address operations, the hash functions and the public seed
 *
 * Returns 0 on success, CRYPT_BN_BUFF_LEN_NOT_ENOUGH if sig is too small, BSL_MALLOC_FAIL if
 * the digit buffer cannot be allocated, or the error reported by PRF or WotsChain.
 */
int32_t WotsSign(uint8_t *sig, uint32_t *sigLen, const uint8_t *msg, uint32_t msgLen, SlhDsaAdrs *adrs,
                 const CryptSlhDsaCtx *ctx)
{
    int32_t ret;
    uint32_t n = ctx->para.n;
    uint32_t len = 2 * n + 3;
    uint8_t sk[SLH_DSA_MAX_N] = {0};

    if (*sigLen < len * n) {
        // NOTE(review): WotsPubKeyFromSig reports CRYPT_SLHDSA_ERR_SIG_LEN_NOT_ENOUGH for the
        // analogous check; the BN code is kept here so existing callers see the same value.
        return CRYPT_BN_BUFF_LEN_NOT_ENOUGH;
    }

    uint32_t *msgw = (uint32_t *)BSL_SAL_Malloc(len * sizeof(uint32_t));
    if (msgw == NULL) {
        return BSL_MALLOC_FAIL;
    }
    ret = MsgToBaseW(ctx, msg, msgLen, msgw);
    if (ret != 0) {
        goto ERR;
    }

    SlhDsaAdrs skAdrs = *adrs;
    ctx->adrsOps.setType(&skAdrs, WOTS_PRF);
    ctx->adrsOps.copyKeyPairAddr(&skAdrs, adrs);
    for (uint32_t i = 0; i < len; i++) {
        ctx->adrsOps.setChainAddr(&skAdrs, i);
        ret = ctx->hashFuncs.prf(ctx, &skAdrs, sk);
        if (ret == 0) {
            ctx->adrsOps.setChainAddr(adrs, i);
            ret = WotsChain(sk, n, 0, msgw[i], ctx->prvKey.pub.seed, adrs, ctx, sig + i * n);
        }
        // The chain secret must not outlive its use, on success or failure. Wiping it also
        // leaves the buffer zeroed for the next iteration.
        BSL_SAL_CleanseData(sk, sizeof(sk));
        if (ret != 0) {
            goto ERR;
        }
    }

ERR:
    BSL_SAL_Free(msgw);
    // Report a signature length only when a complete signature was actually written.
    if (ret == 0) {
        *sigLen = len * n;
    }
    return ret;
}
152 
/*
 * Recompute the compressed WOTS+ public key implied by a signature.
 *
 * msg is expanded into 2 * n + 3 digits by MsgToBaseW(). Chain i of the signature
 * (sig + i * n, n bytes) is advanced from hash address digit[i] for
 * SLH_DSA_W - 1 - digit[i] further steps with WotsChain(). The concatenated chain ends are
 * compressed into pub by ctx->hashFuncs.tl under a copy of adrs retyped to WOTS_PK.
 *
 * msg     message bytes
 * msgLen  length of msg in bytes
 * sig     signature, at least (2 * n + 3) * n bytes
 * sigLen  length of sig in bytes
 * adrs    address of the key pair; its chain address (and, through WotsChain, hash address)
 *         is modified
 * ctx     provides n, the address operations, the hash functions and the public seed
 * pub     receives the output of ctx->hashFuncs.tl
 *
 * Returns 0 on success, CRYPT_SLHDSA_ERR_SIG_LEN_NOT_ENOUGH if sig is too short,
 * BSL_MALLOC_FAIL if a working buffer cannot be allocated, or the error reported by
 * WotsChain or T_l.
 */
int32_t WotsPubKeyFromSig(const uint8_t *msg, uint32_t msgLen, const uint8_t *sig, uint32_t sigLen, SlhDsaAdrs *adrs,
                      const CryptSlhDsaCtx *ctx, uint8_t *pub)
{
    const uint32_t n = ctx->para.n;
    const uint32_t chainCount = 2 * n + 3;
    const uint32_t chainBytes = chainCount * n;
    uint8_t *chainEnds = NULL;
    uint32_t *digits = NULL;
    int32_t ret;

    if (sigLen < chainBytes) {
        return CRYPT_SLHDSA_ERR_SIG_LEN_NOT_ENOUGH;
    }

    digits = (uint32_t *)BSL_SAL_Malloc(chainCount * sizeof(uint32_t));
    if (digits == NULL) {
        return BSL_MALLOC_FAIL;
    }

    ret = MsgToBaseW(ctx, msg, msgLen, digits);
    if (ret != 0) {
        goto EXIT;
    }

    chainEnds = (uint8_t *)BSL_SAL_Malloc(chainBytes);
    if (chainEnds == NULL) {
        ret = BSL_MALLOC_FAIL;
        goto EXIT;
    }

    for (uint32_t chain = 0; chain < chainCount; chain++) {
        const uint32_t done = digits[chain];
        const uint32_t offset = chain * n;

        ctx->adrsOps.setChainAddr(adrs, chain);
        // Continue the chain from where the signer stopped up to the chain end.
        ret = WotsChain(&sig[offset], n, done, SLH_DSA_W - 1 - done, ctx->prvKey.pub.seed, adrs, ctx,
                        &chainEnds[offset]);
        if (ret != 0) {
            goto EXIT;
        }
    }

    // Compress the chain ends under the public-key address.
    SlhDsaAdrs pkAdrs = *adrs;
    ctx->adrsOps.setType(&pkAdrs, WOTS_PK);
    ctx->adrsOps.copyKeyPairAddr(&pkAdrs, adrs);
    ret = ctx->hashFuncs.tl(ctx, &pkAdrs, chainEnds, chainBytes, pub);

EXIT:
    BSL_SAL_Free(digits);
    if (chainEnds != NULL) {
        BSL_SAL_Free(chainEnds);
    }
    return ret;
}
199 
200 #endif // HITLS_CRYPTO_SLH_DSA
201