1 /*
2 * This file is part of the openHiTLS project.
3 *
4 * openHiTLS is licensed under the Mulan PSL v2.
5 * You can use this software according to the terms and conditions of the Mulan PSL v2.
6 * You may obtain a copy of Mulan PSL v2 at:
7 *
8 * http://license.coscl.org.cn/MulanPSL2
9 *
10 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
11 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
12 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
13 * See the Mulan PSL v2 for more details.
14 */
15
16 #include "hitls_build.h"
17 #if defined(HITLS_CRYPTO_SM2_EXCH) || defined(HITLS_CRYPTO_SM2_CRYPT)
18
19 #include <stdbool.h>
20 #include "crypt_errno.h"
21 #include "crypt_types.h"
22 #include "crypt_utils.h"
23 #include "securec.h"
24 #include "bsl_sal.h"
25 #include "bsl_err_internal.h"
26 #include "crypt_bn.h"
27 #include "crypt_ecc.h"
28 #include "crypt_ecc_pkey.h"
29 #include "crypt_local_types.h"
30 #include "crypt_sm2.h"
31 #include "sm2_local.h"
32
33 /* GM/T003_2012 Defined Key Derive Function */
KdfGmt0032012(uint8_t * out,const uint32_t * outlen,const uint8_t * z,uint32_t zlen,const EAL_MdMethod * hashMethod)34 int32_t KdfGmt0032012(uint8_t *out, const uint32_t *outlen, const uint8_t *z, uint32_t zlen,
35 const EAL_MdMethod *hashMethod)
36 {
37 if (out == NULL || outlen == NULL || *outlen == 0 || (z == NULL && zlen != 0) || hashMethod == NULL) {
38 BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
39 return CRYPT_NULL_INPUT;
40 }
41 uint32_t counter;
42 uint8_t ctr[4];
43 uint32_t mdlen;
44 int32_t ret;
45 uint32_t len = MAX_MD_SIZE;
46 void *mdCtx = hashMethod->newCtx();
47 uint8_t dgst[MAX_MD_SIZE];
48 uint8_t *tmp = out;
49 uint32_t tmplen = *outlen;
50 if (mdCtx == NULL) {
51 ret = CRYPT_MEM_ALLOC_FAIL;
52 BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
53 goto ERR;
54 }
55 mdlen = (uint32_t)hashMethod->mdSize;
56 for (counter = 1;; counter++) {
57 GOTO_ERR_IF(hashMethod->init(mdCtx, NULL), ret);
58 PUT_UINT32_BE(counter, ctr, 0);
59 GOTO_ERR_IF(hashMethod->update(mdCtx, z, zlen), ret);
60 GOTO_ERR_IF(hashMethod->update(mdCtx, ctr, sizeof(ctr)), ret);
61 GOTO_ERR_IF(hashMethod->final(mdCtx, dgst, &len), ret);
62 if (tmplen > mdlen) {
63 (void)memcpy_s(tmp, tmplen, dgst, mdlen);
64 tmp += mdlen;
65 tmplen -= mdlen;
66 } else {
67 (void)memcpy_s(tmp, tmplen, dgst, tmplen);
68 (void)memset_s(dgst, mdlen, 0, mdlen);
69 break;
70 }
71 }
72 ERR:
73 hashMethod->freeCtx(mdCtx);
74 return ret;
75 }
76
/*
 * Release the random value r and the point R that ctx holds for one key exchange, and reset both
 * fields to NULL so they can be neither reused nor released a second time.
 */
void Sm2CleanR(CRYPT_SM2_Ctx *ctx)
{
    ECC_FreePoint(ctx->pointR);
    ctx->pointR = NULL;

    BN_Destroy(ctx->r);
    ctx->r = NULL;
}
85
/*
 * Derive the shared key of SM2 key exchange:
 *   K = KdfGmt0032012(x || y || Z_A || Z_B, *outlen)
 * where (x, y) are the coordinates of the shared point uorv (U or V), and Z_A / Z_B are the Z digests of
 * the initiator and the responder. selfCtx->server == 1 marks the initiator (side A, its own Z first);
 * selfCtx->server == 0 marks the responder (side B, the peer's Z first).
 *
 * selfCtx [in]  local context (curve parameters, hash method, own Z digest)
 * peerCtx [in]  peer context (peer Z digest)
 * uorv    [in]  shared point computed by the caller
 * out     [out] receives *outlen bytes of key material
 * outlen  [in]  requested key length in bytes
 */
static int32_t Sm2CalculateKey(const CRYPT_SM2_Ctx *selfCtx, const CRYPT_SM2_Ctx *peerCtx, ECC_Point *uorv,
    uint8_t *out, uint32_t *outlen)
{
    uint32_t keyBits = CRYPT_SM2_GetBits(selfCtx);
    uint32_t elementLen = (keyBits + 7) / 8; // Convert bits to bytes, rounding up.
    /* 1-byte point tag + 2 coordinates (x, y) + 2 Z digests (Z_A, Z_B) */
    uint32_t bufLen = elementLen * 2 + SM3_MD_SIZE * 2 + 1;
    uint32_t dataLen = 0; // number of bytes already written to buf
    uint32_t curLen;      // capacity offered to / length produced by the current step
    int32_t ret;
    uint8_t *buf = (uint8_t *)BSL_SAL_Calloc(bufLen, sizeof(uint8_t));
    if (buf == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
        return CRYPT_MEM_ALLOC_FAIL;
    }
    /* Encode uorv uncompressed; the first byte is the point-format tag and is excluded from the KDF input. */
    curLen = elementLen * 2 + 1;
    GOTO_ERR_IF(ECC_EncodePoint(selfCtx->pkey->para, uorv, buf, &curLen, CRYPT_POINT_UNCOMPRESSED), ret);
    dataLen += curLen;
    if (selfCtx->server == 1) {
        /* SIDE A, Z_A || Z_B, server is initiator(Z_A), client is responder(Z_B) */
        curLen = SM3_MD_SIZE;
        GOTO_ERR_IF_EX(Sm2ComputeZDigest(selfCtx, buf + dataLen, &curLen), ret);
        dataLen += curLen;
    }
    /* Calculate the peer Z */
    curLen = SM3_MD_SIZE;
    GOTO_ERR_IF_EX(Sm2ComputeZDigest(peerCtx, buf + dataLen, &curLen), ret);
    dataLen += curLen;
    if (selfCtx->server == 0) {
        /* SIDE B: the own Z (Z_B) goes after the peer Z (Z_A) */
        curLen = SM3_MD_SIZE;
        GOTO_ERR_IF_EX(Sm2ComputeZDigest(selfCtx, buf + dataLen, &curLen), ret);
        dataLen += curLen;
    }
    GOTO_ERR_IF(KdfGmt0032012(out, outlen, (const uint8_t *)(buf + 1), dataLen - 1, selfCtx->hashMethod), ret);
ERR:
    /* buf contains the coordinates of the shared point: wipe it before releasing the memory. */
    BSL_SAL_ClearFree(buf, bufLen);
    return ret;
}
126
IsParamValid(const CRYPT_SM2_Ctx * selfCtx,const CRYPT_SM2_Ctx * peerCtx)127 static int32_t IsParamValid(const CRYPT_SM2_Ctx *selfCtx, const CRYPT_SM2_Ctx *peerCtx)
128 {
129 if (selfCtx->pkey->prvkey == NULL || peerCtx->pkey->pubkey == NULL) {
130 BSL_ERR_PUSH_ERROR(CRYPT_SM2_ERR_EMPTY_KEY);
131 return CRYPT_SM2_ERR_EMPTY_KEY;
132 }
133
134 if (selfCtx->hashMethod == NULL || peerCtx->hashMethod == NULL) {
135 BSL_ERR_PUSH_ERROR(CRYPT_SM2_ERR_NO_HASH_METHOD);
136 return CRYPT_SM2_ERR_NO_HASH_METHOD;
137 }
138
139 if (peerCtx->pointR == NULL || selfCtx->r == NULL) {
140 BSL_ERR_PUSH_ERROR(CRYPT_SM2_R_NOT_SET);
141 return CRYPT_SM2_R_NOT_SET;
142 }
143
144 if (selfCtx->pkey->pubkey == NULL) {
145 int32_t ret = ECC_GenPublicKey(selfCtx->pkey);
146 if (ret != CRYPT_SUCCESS) {
147 BSL_ERR_PUSH_ERROR(ret);
148 return ret;
149 }
150 }
151
152 return CRYPT_SUCCESS;
153 }
154
/*
 * Destroy the five big numbers used by CRYPT_SM2_KapComputeKey, in argument order.
 * On partial-allocation failure the caller passes NULL entries, so this presumes BN_Destroy accepts NULL.
 */
void BnMemDestroy(BN_BigNum *xs, BN_BigNum *xp, BN_BigNum *t,
    BN_BigNum *twoPowerW, BN_BigNum *order)
{
    BN_BigNum *victims[] = {xs, xp, t, twoPowerW, order};
    for (uint32_t idx = 0; idx < (uint32_t)(sizeof(victims) / sizeof(victims[0])); idx++) {
        BN_Destroy(victims[idx]);
    }
}
164
Sm3MsgHash(const EAL_MdMethod * hashMethod,const uint8_t * yBuf,const uint8_t * hashBuf,uint8_t * out,uint32_t * outlen,uint8_t tag)165 static int32_t Sm3MsgHash(const EAL_MdMethod *hashMethod, const uint8_t *yBuf, const uint8_t *hashBuf,
166 uint8_t *out, uint32_t *outlen, uint8_t tag)
167 {
168 int32_t ret;
169 void *mdCtx = hashMethod->newCtx();
170 if (mdCtx == NULL) {
171 ret = CRYPT_MEM_ALLOC_FAIL;
172 BSL_ERR_PUSH_ERROR(ret);
173 return ret;
174 }
175 GOTO_ERR_IF(hashMethod->init(mdCtx, NULL), ret);
176 GOTO_ERR_IF(hashMethod->update(mdCtx, &tag, 1), ret);
177 GOTO_ERR_IF(hashMethod->update(mdCtx, yBuf, SM3_MD_SIZE), ret);
178 GOTO_ERR_IF(hashMethod->update(mdCtx, hashBuf, SM3_MD_SIZE), ret);
179 GOTO_ERR_IF(hashMethod->final(mdCtx, out, outlen), ret);
180 ERR:
181 hashMethod->freeCtx(mdCtx);
182 return ret;
183 }
184
Sm3InnerHash(const EAL_MdMethod * hashMethod,const uint8_t * coordinate,const uint8_t * zBuf,uint32_t zlen,const uint8_t * rBuf,uint8_t * out,uint32_t * outlen)185 static int32_t Sm3InnerHash(const EAL_MdMethod *hashMethod, const uint8_t *coordinate, const uint8_t *zBuf,
186 uint32_t zlen, const uint8_t *rBuf, uint8_t *out, uint32_t *outlen)
187 {
188 int32_t ret;
189 void *mdCtx = hashMethod->newCtx();
190 if (mdCtx == NULL) {
191 ret = CRYPT_MEM_ALLOC_FAIL;
192 BSL_ERR_PUSH_ERROR(ret);
193 return ret;
194 }
195 GOTO_ERR_IF(hashMethod->init(mdCtx, NULL), ret);
196 GOTO_ERR_IF(hashMethod->update(mdCtx, coordinate, SM2_X_LEN), ret);
197 GOTO_ERR_IF(hashMethod->update(mdCtx, zBuf, zlen), ret);
198 GOTO_ERR_IF(hashMethod->update(mdCtx, rBuf, SM2_TWO_POINT_COORDINATE_LEN), ret);
199 GOTO_ERR_IF(hashMethod->final(mdCtx, out, outlen), ret);
200 ERR:
201 hashMethod->freeCtx(mdCtx);
202 return ret;
203 }
204
/*
 * Compute the SM2 key-confirmation values from the shared point uorv:
 *   inner    = H(x || Z_A || Z_B || x1 || y1 || x2 || y2)
 *   sumSend  = H(tag1 || y || inner)   value sent to the peer
 *   sumCheck = H(tag2 || y || inner)   value expected from the peer
 * (x, y) are the coordinates of uorv. (x1, y1) is the initiator's point R and (x2, y2) the responder's.
 * The initiator (sCtx->server == 1) uses tag1 = 0x03, tag2 = 0x02; the responder (server == 0) uses
 * tag1 = 0x02, tag2 = 0x03.
 *
 * On success sCtx->sumSend and sCtx->sumCheck are filled and sCtx->isSumValid is set to 1; on failure
 * isSumValid is reset to 0. All buffers are sized by the fixed SM2_* constants.
 */
int32_t Sm2KapFinalCheck(CRYPT_SM2_Ctx *sCtx, CRYPT_SM2_Ctx *pCtx, ECC_Point *uorv)
{
    int32_t ret;
    uint32_t len = SM3_MD_SIZE;
    uint8_t r1Buf[SM2_POINT_COORDINATE_LEN];
    uint8_t r2Buf[SM2_POINT_COORDINATE_LEN];
    uint8_t rBuf[SM2_TWO_POINT_COORDINATE_LEN];
    uint8_t xBuf[SM2_X_LEN];
    uint8_t yBuf[SM2_X_LEN];
    uint8_t zBuf[SM2_POINT_COORDINATE_LEN - 1]; // room for Z_A || Z_B
    uint8_t stmpBuf[SM3_MD_SIZE];
    uint32_t buflen = SM2_POINT_COORDINATE_LEN;
    uint32_t zlen = 0;
    uint8_t tag1 = 0x03;
    uint8_t tag2 = 0x02;
    // Split the uncompressed encoding of uorv (tag || x || y) into x and y.
    GOTO_ERR_IF(ECC_EncodePoint(sCtx->pkey->para, uorv, r1Buf, &buflen, CRYPT_POINT_UNCOMPRESSED), ret);
    (void)memcpy_s(xBuf, SM2_X_LEN, r1Buf + 1, SM2_X_LEN);
    (void)memcpy_s(yBuf, SM2_X_LEN, r1Buf + 1 + SM2_X_LEN, SM2_X_LEN);
    // Calculate ZA || ZB
    if (sCtx->server == 1) {
        /* SIDE A, Z_A || Z_B, server is initiator(Z_A), client is responder(Z_B) */
        GOTO_ERR_IF_EX(Sm2ComputeZDigest(sCtx, zBuf, &len), ret);
        zlen += len;
        GOTO_ERR_IF(ECC_EncodePoint(sCtx->pkey->para, sCtx->pointR, r1Buf, &buflen, CRYPT_POINT_UNCOMPRESSED), ret);
        GOTO_ERR_IF(ECC_EncodePoint(sCtx->pkey->para, pCtx->pointR, r2Buf, &buflen, CRYPT_POINT_UNCOMPRESSED), ret);
    }
    /* Calculate the peer Z */
    GOTO_ERR_IF_EX(Sm2ComputeZDigest(pCtx, zBuf + zlen, &len), ret);
    zlen += len;
    if (sCtx->server == 0) {
        /* SIDE B: the peer is the initiator, so its R comes first and the tags are swapped */
        GOTO_ERR_IF_EX(Sm2ComputeZDigest(sCtx, zBuf + zlen, &len), ret);
        zlen += len;
        GOTO_ERR_IF(ECC_EncodePoint(sCtx->pkey->para, pCtx->pointR, r1Buf, &buflen, CRYPT_POINT_UNCOMPRESSED), ret);
        GOTO_ERR_IF(ECC_EncodePoint(sCtx->pkey->para, sCtx->pointR, r2Buf, &buflen, CRYPT_POINT_UNCOMPRESSED), ret);
        tag1 = 0x02;
        tag2 = 0x03;
    }
    // rBuf = x1 || y1 || x2 || y2 (point tags dropped)
    (void)memcpy_s(rBuf, SM2_TWO_POINT_COORDINATE_LEN, r1Buf + 1, SM2_POINT_COORDINATE_LEN - 1);
    (void)memcpy_s(rBuf + SM2_POINT_COORDINATE_LEN - 1, SM2_TWO_POINT_COORDINATE_LEN - SM2_POINT_COORDINATE_LEN + 1,
        r2Buf + 1, SM2_POINT_COORDINATE_LEN - 1);
    // Calculate the hash value.
    GOTO_ERR_IF_EX(Sm3InnerHash(sCtx->hashMethod, xBuf, zBuf, zlen, rBuf, stmpBuf, &len), ret);
    // Calculate the hash value sent to the peer end.
    GOTO_ERR_IF_EX(Sm3MsgHash(sCtx->hashMethod, yBuf, stmpBuf, sCtx->sumSend, &len, tag1), ret);
    // Compute the hash value used for validation.
    GOTO_ERR_IF_EX(Sm3MsgHash(sCtx->hashMethod, yBuf, stmpBuf, sCtx->sumCheck, &len, tag2), ret);
    sCtx->isSumValid = 1;
ERR:
    if (ret != CRYPT_SUCCESS) {
        sCtx->isSumValid = 0; // Reset checksum validity flag
    }
    /* xBuf, yBuf (and r1Buf on some paths) hold shared-point coordinates, stmpBuf a value derived from them. */
    (void)memset_s(r1Buf, sizeof(r1Buf), 0, sizeof(r1Buf));
    (void)memset_s(xBuf, sizeof(xBuf), 0, sizeof(xBuf));
    (void)memset_s(yBuf, sizeof(yBuf), 0, sizeof(yBuf));
    (void)memset_s(stmpBuf, sizeof(stmpBuf), 0, sizeof(stmpBuf));
    return ret;
}
259
SM2_PKG_Kdf(const CRYPT_SM2_Ctx * ctx,uint8_t * in,const uint32_t inLen,uint8_t * out,uint32_t * outLen)260 static int SM2_PKG_Kdf(const CRYPT_SM2_Ctx *ctx, uint8_t *in, const uint32_t inLen, uint8_t *out, uint32_t *outLen)
261 {
262 int32_t ret;
263 const uint32_t shareKeyLen = 16;
264 const EAL_MdMethod *hashMethod = ctx->hashMethod;
265 uint8_t *tmp = BSL_SAL_Malloc(hashMethod->mdSize);
266 uint32_t tmpLen = hashMethod->mdSize;
267 void *mdCtx = hashMethod->newCtx();
268 if (mdCtx == NULL || tmp == NULL) {
269 ret = CRYPT_MEM_ALLOC_FAIL;
270 BSL_ERR_PUSH_ERROR(ret);
271 goto ERR;
272 }
273 GOTO_ERR_IF(hashMethod->init(mdCtx, NULL), ret);
274 GOTO_ERR_IF(hashMethod->update(mdCtx, in, inLen), ret);
275 GOTO_ERR_IF(hashMethod->final(mdCtx, tmp, &tmpLen), ret);
276 if (memcpy_s(out, *outLen, tmp, shareKeyLen) != EOK) {
277 ret = CRYPT_SECUREC_FAIL;
278 BSL_ERR_PUSH_ERROR(ret);
279 goto ERR;
280 }
281 *outLen = shareKeyLen;
282 ERR:
283 hashMethod->freeCtx(mdCtx);
284 BSL_SAL_ClearFree(tmp, hashMethod->mdSize);
285 return ret;
286 }
287
SM2_PKGComputeKey(const CRYPT_SM2_Ctx * selfCtx,const CRYPT_SM2_Ctx * peerCtx,uint8_t * out,uint32_t * outlen)288 static int32_t SM2_PKGComputeKey(const CRYPT_SM2_Ctx *selfCtx, const CRYPT_SM2_Ctx *peerCtx,
289 uint8_t *out, uint32_t *outlen)
290 {
291 if (selfCtx->pkey == NULL || peerCtx->pkey == NULL) {
292 BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
293 return CRYPT_NULL_INPUT;
294 }
295 if (selfCtx->hashMethod == NULL) {
296 BSL_ERR_PUSH_ERROR(CRYPT_SM2_ERR_NO_HASH_METHOD);
297 return CRYPT_SM2_ERR_NO_HASH_METHOD;
298 }
299 int32_t ret;
300 uint8_t sharePointCode[65] = {0};
301 uint32_t codeLen = sizeof(sharePointCode);
302 const ECC_Pkey *eccPkey = selfCtx->pkey;
303 BN_BigNum *tmpPrvkey = BN_Dup(eccPkey->prvkey);
304 ECC_Point *sharePoint = ECC_NewPoint(eccPkey->para);
305 if ((tmpPrvkey == NULL) || (sharePoint == NULL)) {
306 ret = CRYPT_MEM_ALLOC_FAIL;
307 BSL_ERR_PUSH_ERROR(ret);
308 goto ERR;
309 }
310 GOTO_ERR_IF(ECC_PointMul(eccPkey->para, sharePoint, eccPkey->prvkey, peerCtx->pkey->pubkey), ret);
311 GOTO_ERR_IF(ECC_PointCheck(sharePoint), ret);
312 GOTO_ERR_IF_EX(ECC_EncodePoint(eccPkey->para, sharePoint, sharePointCode, &codeLen, CRYPT_POINT_UNCOMPRESSED), ret);
313 GOTO_ERR_IF_EX(SM2_PKG_Kdf(selfCtx, sharePointCode + 1, codeLen - 1, out, outlen), ret);
314 ERR:
315 BN_Destroy(tmpPrvkey);
316 ECC_FreePoint(sharePoint);
317 return ret;
318 }
319
CRYPT_SM2_KapComputeKey(const CRYPT_SM2_Ctx * selfCtx,const CRYPT_SM2_Ctx * peerCtx,uint8_t * out,uint32_t * outlen)320 int32_t CRYPT_SM2_KapComputeKey(const CRYPT_SM2_Ctx *selfCtx, const CRYPT_SM2_Ctx *peerCtx,
321 uint8_t *out, uint32_t *outlen)
322 {
323 if (selfCtx == NULL || peerCtx == NULL || out == NULL || outlen == NULL || *outlen == 0) {
324 BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
325 return CRYPT_NULL_INPUT;
326 }
327 if (selfCtx->pkgImpl != 0) {
328 return SM2_PKGComputeKey(selfCtx, peerCtx, out, outlen);
329 }
330 ECC_Point *uorv = ECC_NewPoint(selfCtx->pkey->para);
331 uint32_t keyBits = CRYPT_SM2_GetBits(selfCtx);
332 BN_BigNum *xs = BN_Create(keyBits);
333 BN_BigNum *xp = BN_Create(keyBits);
334 BN_BigNum *t = BN_Create(keyBits);
335 BN_BigNum *twoPowerW = BN_Create(keyBits);
336 BN_BigNum *order = ECC_GetParaN(selfCtx->pkey->para);
337 uint32_t w;
338 int32_t ret;
339 BN_Optimizer *opt = BN_OptimizerCreate();
340 if (uorv == NULL || xs == NULL || xp == NULL || t == NULL || twoPowerW == NULL ||
341 order == NULL || opt == NULL) {
342 ret = CRYPT_MEM_ALLOC_FAIL;
343 BSL_ERR_PUSH_ERROR(ret);
344 goto ERR;
345 }
346 GOTO_ERR_IF(IsParamValid(selfCtx, peerCtx), ret);
347 /* Second: Caculate -- w */
348 // w is equal to the number of digits of n rounded up, divided by 2, and then subtracted by 1.
349 w = (BN_Bits(order) + 1) / 2 - 1;
350 GOTO_ERR_IF(BN_Zeroize(twoPowerW), ret);
351 GOTO_ERR_IF(BN_SetBit(twoPowerW, w), ret);
352 /* Third: Caculate -- X = 2 ^ w + (x & (2 ^ w - 1)) = 2 ^ w + (x mod 2 ^ w) */
353 /* Get x */
354 GOTO_ERR_IF(ECC_GetPointDataX(selfCtx->pkey->para, selfCtx->pointR, xs), ret);
355 GOTO_ERR_IF(ECC_GetPointDataX(peerCtx->pkey->para, peerCtx->pointR, xp), ret);
356 /* x mod 2 ^ w */
357 /* Caculate Self x */
358 GOTO_ERR_IF(BN_Mod(xs, xs, twoPowerW, opt), ret);
359 GOTO_ERR_IF(BN_Add(xs, xs, twoPowerW), ret);
360 /* Caculate Peer x */
361 GOTO_ERR_IF(BN_Mod(xp, xp, twoPowerW, opt), ret);
362 GOTO_ERR_IF(BN_Add(xp, xp, twoPowerW), ret);
363 /* Forth: Caculate t */
364 GOTO_ERR_IF(BN_ModMul(t, xs, selfCtx->r, order, opt), ret);
365 GOTO_ERR_IF(BN_ModAddQuick(t, t, selfCtx->pkey->prvkey, order, opt), ret);
366 /* Fifth: Caculate V or U */
367 GOTO_ERR_IF(ECC_PointMul(peerCtx->pkey->para, uorv, xp, peerCtx->pointR), ret);
368 /* P + [x]R */
369 GOTO_ERR_IF(ECC_PointAddAffine(selfCtx->pkey->para, uorv, uorv, peerCtx->pkey->pubkey), ret);
370 GOTO_ERR_IF(ECC_PointMul(selfCtx->pkey->para, uorv, t, uorv), ret);
371 /* Detect uorv is in */
372 GOTO_ERR_IF(ECC_PointCheck(uorv), ret);
373 /* Sixth: Caculate Key -- Need Xuorv, Yuorv, Zc, Zs, klen */
374 GOTO_ERR_IF_EX(Sm2CalculateKey(selfCtx, peerCtx, uorv, out, outlen), ret);
375 GOTO_ERR_IF_EX(Sm2KapFinalCheck((CRYPT_SM2_Ctx *)(uintptr_t)selfCtx, (CRYPT_SM2_Ctx *)(uintptr_t)peerCtx, uorv),
376 ret);
377 ERR:
378 BnMemDestroy(xs, xp, t, twoPowerW, order);
379 ECC_FreePoint(uorv);
380 Sm2CleanR((CRYPT_SM2_Ctx *)(uintptr_t)selfCtx);
381 BN_OptimizerDestroy(opt);
382 return ret;
383 }
384 #endif
385