1 /*
2 * Copyright (c) 2025 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
#include "hdc_ssl_ut.h"
#include <memory>
#include "securec.h"
using namespace testing::ext;
using namespace testing;
19 namespace Hdc {
// Alias of size_t; presumably kept for securec / C11 Annex K style signatures — TODO confirm.
using rsize_t = size_t;
21 class MockHdcSSLBase : public HdcSSLBase {
22 public:
23 MOCK_METHOD5(RsaPubkeyEncrypt, int(const unsigned char* in, int inLen,
24 unsigned char* out, int outBufSize, const std::string& pubkey));
25 MOCK_METHOD0(IsHandshakeFinish, bool());
26 MOCK_METHOD0(ShowSSLInfo, void());
27 public:
MockHdcSSLBase(SSLInfoPtr hSSLInfo)28 explicit MockHdcSSLBase(SSLInfoPtr hSSLInfo) : HdcSSLBase(hSSLInfo)
29 {
30 }
31
~MockHdcSSLBase()32 ~MockHdcSSLBase()
33 {
34 }
35
SetPskCallback()36 bool SetPskCallback() override
37 {
38 if (SSL_CTX_set_ex_data(sslCtx, 0, preSharedKey) != 1) {
39 return false;
40 }
41 SSL_CTX_set_psk_client_callback(sslCtx, PskClientCallback);
42 return true;
43 }
44
SetSSLState()45 void SetSSLState() override
46 {
47 SSL_set_connect_state(ssl);
48 }
49
SetSSLMethod()50 const SSL_METHOD *SetSSLMethod() override
51 {
52 return TLS_client_method();
53 }
54 };
55
HdcSSLTest()56 HdcSSLTest::HdcSSLTest() {}
57
~HdcSSLTest()58 HdcSSLTest::~HdcSSLTest() {}
59
SetUpTestCase()60 void HdcSSLTest::SetUpTestCase() {}
61
TearDownTestCase()62 void HdcSSLTest::TearDownTestCase() {}
63
SetUp()64 void HdcSSLTest::SetUp() {}
65
TearDown()66 void HdcSSLTest::TearDown() {}
67
GenerateRSAKeyPair(std::string & publicKey,std::string & privateKey)68 void GenerateRSAKeyPair(std::string& publicKey, std::string& privateKey)
69 {
70 EVP_PKEY *pkey = EVP_PKEY_new();
71 BIGNUM *exponent = BN_new();
72 RSA *rsa = RSA_new();
73 int bits = RSA_KEY_BITS;
74
75 BN_set_word(exponent, RSA_F4);
76 RSA_generate_key_ex(rsa, bits, exponent, nullptr);
77 EVP_PKEY_set1_RSA(pkey, rsa);
78 BIO *bio = BIO_new(BIO_s_mem());
79 ASSERT_TRUE(bio != nullptr);
80 ASSERT_TRUE(PEM_write_bio_PUBKEY(bio, pkey));
81 char *pubkeyStr;
82 long pubkeyLen = BIO_get_mem_data(bio, &pubkeyStr);
83 publicKey.assign(pubkeyStr, pubkeyLen);
84
85 BIO_free(bio);
86 bio = BIO_new(BIO_s_mem());
87 ASSERT_TRUE(bio != nullptr);
88 ASSERT_TRUE(PEM_write_bio_PrivateKey(bio, pkey,
89 nullptr, nullptr, 0, nullptr, nullptr));
90 char *privkeyStr;
91 long privkeyLen = BIO_get_mem_data(bio, &privkeyStr);
92 privateKey.assign(privkeyStr, privkeyLen);
93
94 BIO_free(bio);
95 EVP_PKEY_free(pkey);
96 BN_free(exponent);
97 }
98
ReadPrivateKeyFromString(const std::string & privateKeyPEM)99 EVP_PKEY* ReadPrivateKeyFromString(const std::string& privateKeyPEM)
100 {
101 BIO *bio = BIO_new_mem_buf(privateKeyPEM.c_str(), -1);
102 if (!bio) {
103 std::cerr << "Error: BIO_new_mem_buf failed" << std::endl;
104 return nullptr;
105 }
106
107 EVP_PKEY *pkey = PEM_read_bio_PrivateKey(bio, nullptr, nullptr, nullptr);
108 if (!pkey) {
109 std::cerr << "Error: PEM_read_bio_PrivateKey failed" << std::endl;
110 BIO_free(bio);
111 return nullptr;
112 }
113
114 BIO_free(bio);
115 return pkey;
116 }
117
RsaPrikeyDecrypt(const unsigned char * in,int inLen,unsigned char * out,EVP_PKEY * priKey)118 int RsaPrikeyDecrypt(const unsigned char* in, int inLen, unsigned char* out, EVP_PKEY* priKey)
119 {
120 RSA *rsa = EVP_PKEY_get1_RSA(priKey);
121 if (!rsa) {
122 unsigned long err = ERR_get_error();
123 char errbuf[120];
124 ERR_error_string_n(err, errbuf, sizeof(errbuf));
125 WRITE_LOG(LOG_WARN, "Error: EVP_PKEY_get1_RSA failed %s", errbuf);
126 return 0;
127 }
128 int outLen = RSA_private_decrypt(inLen, in, out, rsa, RSA_PKCS1_OAEP_PADDING);
129 RSA_free(rsa);
130 return outLen;
131 }
132
SSLHandShakeEmulate(HdcSSLBase * sslClient,HdcSSLBase * sslServer)133 void SSLHandShakeEmulate(HdcSSLBase *sslClient, HdcSSLBase *sslServer)
134 {
135 vector<uint8_t> buf;
136 ASSERT_EQ(sslClient->PerformHandshake(buf), RET_SUCCESS);
137 ASSERT_EQ(sslServer->DoBIOWrite(buf.data(), buf.size()), buf.size()); // step 1
138 ASSERT_EQ(sslServer->PerformHandshake(buf), RET_SUCCESS);
139 ASSERT_EQ(sslClient->DoBIOWrite(buf.data(), buf.size()), buf.size()); // step 2
140 ASSERT_EQ(sslClient->PerformHandshake(buf), RET_SUCCESS);
141 ASSERT_EQ(sslServer->DoBIOWrite(buf.data(), buf.size()), buf.size()); // step 3
142 ASSERT_EQ(sslServer->PerformHandshake(buf), RET_SUCCESS);
143 ASSERT_EQ(sslClient->DoBIOWrite(buf.data(), buf.size()), buf.size()); // step 4
144 ASSERT_EQ(sslClient->PerformHandshake(buf), RET_SSL_HANDSHAKE_FINISHED);
145 ASSERT_EQ(sslClient->IsHandshakeFinish(), true);
146 ASSERT_EQ(sslServer->IsHandshakeFinish(), true);
147 }
148
149 /**
150 * @tc.name: SetSSLInfoTest001
151 * @tc.desc: test SetSSLInfo add info
152 * @tc.type: FUNC
153 */
154 HWTEST_F(HdcSSLTest, SetSSLInfoTest001, TestSize.Level0)
155 {
156 SSLInfoPtr hSSLInfo = new HdcSSLInfo();
157 HSession hSession = new HdcSession();
158 hSession->serverOrDaemon = false;
159 hSession->sessionId = 123;
160 HdcSSLBase::SetSSLInfo(hSSLInfo, hSession);
161 ASSERT_EQ(hSSLInfo->cipher, TLS_AES_128_GCM_SHA256);
162 ASSERT_EQ(hSSLInfo->isDaemon, true);
163 ASSERT_EQ(hSSLInfo->sessionId, hSession->sessionId);
164 delete hSSLInfo;
165 delete hSession;
166 }
167
168 /**
169 * @tc.name: InitSSLTest001
170 * @tc.desc: test InitSSL as daemon role
171 * @tc.type: FUNC
172 */
173 HWTEST_F(HdcSSLTest, InitSSLTest001, TestSize.Level0)
174 {
175 SSLInfoPtr hSSLInfo = new HdcSSLInfo();
176 HSession hSession = new HdcSession();
177 HdcSSLBase::SetSSLInfo(hSSLInfo, hSession);
178 HdcSSLBase *sslBase = new (std::nothrow) HdcDaemonSSL(hSSLInfo);
179 ASSERT_EQ(sslBase->InitSSL(), RET_SUCCESS);
180 ASSERT_EQ(sslBase->isInited, true);
181 ASSERT_NE(sslBase->ssl, nullptr);
182 ASSERT_NE(sslBase->sslCtx, nullptr);
183 ASSERT_NE(sslBase->inBIO, nullptr);
184 ASSERT_NE(sslBase->outBIO, nullptr);
185 delete sslBase;
186 delete hSSLInfo;
187 delete hSession;
188 }
189
190 /**
191 * @tc.name: InitSSLTest002
192 * @tc.desc: test InitSSL as host role
193 * @tc.type: FUNC
194 */
195 HWTEST_F(HdcSSLTest, InitSSLTest002, TestSize.Level0)
196 {
197 SSLInfoPtr hSSLInfo = new HdcSSLInfo();
198 HSession hSession = new HdcSession();
199 HdcSSLBase::SetSSLInfo(hSSLInfo, hSession);
200 HdcSSLBase *sslBase = new (std::nothrow) HdcHostSSL(hSSLInfo);
201 ASSERT_EQ(sslBase->InitSSL(), RET_SUCCESS);
202 ASSERT_EQ(sslBase->isInited, true);
203 ASSERT_NE(sslBase->ssl, nullptr);
204 ASSERT_NE(sslBase->sslCtx, nullptr);
205 ASSERT_NE(sslBase->inBIO, nullptr);
206 ASSERT_NE(sslBase->outBIO, nullptr);
207 delete sslBase;
208 delete hSSLInfo;
209 delete hSession;
210 }
211
212 /**
213 * @tc.name: ClearSSLTest001
214 * @tc.desc: test ~HdcSSLBase as host role
215 * @tc.type: FUNC
216 */
217 HWTEST_F(HdcSSLTest, ClearSSLTest001, TestSize.Level0)
218 {
219 SSLInfoPtr hSSLInfo = new HdcSSLInfo();
220 HSession hSession = new HdcSession();
221 HdcSSLBase::SetSSLInfo(hSSLInfo, hSession);
222 HdcSSLBase *sslBase = new (std::nothrow) HdcHostSSL(hSSLInfo);
223 ASSERT_EQ(sslBase->InitSSL(), RET_SUCCESS);
224 sslBase->~HdcSSLBase();
225 ASSERT_EQ(sslBase->isInited, false);
226 ASSERT_EQ(sslBase->ssl, nullptr);
227 ASSERT_EQ(sslBase->sslCtx, nullptr);
228 ASSERT_EQ(sslBase->inBIO, nullptr);
229 ASSERT_EQ(sslBase->outBIO, nullptr);
230 sslBase = nullptr;
231 ASSERT_EQ(sslBase, nullptr);
232 delete hSSLInfo;
233 delete hSession;
234 }
235
236 /**
237 * @tc.name: ClearSSLTest002
238 * @tc.desc: test ~HdcSSLBase as daemon role
239 * @tc.type: FUNC
240 */
241 HWTEST_F(HdcSSLTest, ClearSSLTest002, TestSize.Level0)
242 {
243 SSLInfoPtr hSSLInfo = new HdcSSLInfo();
244 HSession hSession = new HdcSession();
245 HdcSSLBase::SetSSLInfo(hSSLInfo, hSession);
246 HdcSSLBase *sslBase = new (std::nothrow) HdcDaemonSSL(hSSLInfo);
247 ASSERT_EQ(sslBase->InitSSL(), RET_SUCCESS);
248 sslBase->~HdcSSLBase();
249 ASSERT_EQ(sslBase->isInited, false);
250 ASSERT_EQ(sslBase->ssl, nullptr);
251 ASSERT_EQ(sslBase->sslCtx, nullptr);
252 ASSERT_EQ(sslBase->inBIO, nullptr);
253 ASSERT_EQ(sslBase->outBIO, nullptr);
254 sslBase = nullptr;
255 ASSERT_EQ(sslBase, nullptr);
256 delete hSSLInfo;
257 delete hSession;
258 }
259
260 /**
261 * @tc.name: DoSSLHandshakeTest001
262 * @tc.desc: test SSLHandshake with step 1~6
263 * @tc.type: FUNC
264 */
265 // host( ) ---(TLS handshake client hello )--> hdcd( ) step 1
266 // host( ) <--(TLS handshake server hello )--- hdcd( ) step 2
267 // host(ok) ---(TLS handshake change cipher)--> hdcd( ) step 3
268 // host(ok) <--(TLS handshake change cipher)--- hdcd(ok) step 4
269 HWTEST_F(HdcSSLTest, DoSSLHandshakeTest001, TestSize.Level0)
270 {
271 SSLInfoPtr hSSLInfoDaemon = new HdcSSLInfo();
272 SSLInfoPtr hSSLInfoHost = new HdcSSLInfo();
273 HSession hSessionDaemon = new HdcSession();
274 HSession hSessionHost = new HdcSession();
275 HdcSSLBase::SetSSLInfo(hSSLInfoDaemon, hSessionDaemon);
276 HdcSSLBase::SetSSLInfo(hSSLInfoHost, hSessionHost);
277 HdcSSLBase *sslServer = new (std::nothrow) HdcDaemonSSL(hSSLInfoDaemon);
278 HdcSSLBase *sslClient = new (std::nothrow) HdcHostSSL(hSSLInfoHost);
279 std::vector<unsigned char> psk(32);
280 fill(psk.begin(), psk.end(), 0);
281 sslClient->InputPsk(psk.data(), psk.size());
282 sslServer->InputPsk(psk.data(), psk.size());
283 int pskClientRet = memcmp(sslClient->preSharedKey, psk.data(), psk.size());
284 int pskServerRet = memcmp(sslServer->preSharedKey, psk.data(), psk.size());
285 ASSERT_EQ(pskClientRet, 0);
286 ASSERT_EQ(pskServerRet, 0);
287 ASSERT_EQ(sslServer->InitSSL(), RET_SUCCESS);
288 ASSERT_EQ(sslClient->InitSSL(), RET_SUCCESS);
289 SSLHandShakeEmulate(sslClient, sslServer);
290 std::vector<uint8_t> plainTextOriginal;
291 std::vector<uint8_t> plainTextAltered;
292 std::string str = "hello world";
293 int targetSize = HdcSSLBase::GetSSLBufLen(str.size());
294 plainTextOriginal.assign(str.begin(), str.end());
295 plainTextAltered.assign(str.begin(), str.end());
296 int sourceSize = plainTextAltered.size();
297 plainTextAltered.resize(targetSize);
298 ASSERT_EQ(sslClient->Encrypt(sourceSize, plainTextAltered.data()), targetSize);
299 int diffRet = memcmp(plainTextOriginal.data(), plainTextAltered.data(), plainTextAltered.size());
300 ASSERT_NE(diffRet, 0);
301 int index = 0;
302 ASSERT_EQ(sslServer->Decrypt(targetSize, BUF_SIZE_DEFAULT16, plainTextAltered.data(), index),
303 RET_SUCCESS);
304 int sameRet = memcmp(plainTextOriginal.data(), plainTextAltered.data(), str.size());
305 ASSERT_EQ(sameRet, 0);
306 delete sslClient;
307 delete sslServer;
308 delete hSSLInfoHost;
309 delete hSSLInfoDaemon;
310 delete hSessionHost;
311 delete hSessionDaemon;
312 }
313
314 /**
315 * @tc.name: InputPskTest001
316 * @tc.desc: test InputPsk function with huge size input and normal size input.
317 * @tc.type: FUNC
318 */
319 HWTEST_F(HdcSSLTest, InputPskTest001, TestSize.Level0)
320 {
321 SSLInfoPtr hSSLInfo = new HdcSSLInfo();
322 HSession hSession = new HdcSession();
323 HdcSSLBase::SetSSLInfo(hSSLInfo, hSession);
324 HdcSSLBase *sslClient = new (std::nothrow) HdcHostSSL(hSSLInfo);
325 std::vector<unsigned char> pskHuge(BUF_SIZE_PSK * 2); // 2 times of psk size
326 std::vector<unsigned char> psk(BUF_SIZE_PSK);
327 fill(psk.begin(), psk.end(), 0);
328 fill(pskHuge.begin(), pskHuge.end(), 0);
329 ASSERT_EQ(sslClient->InputPsk(pskHuge.data(), pskHuge.size()), false);
330 ASSERT_EQ(sslClient->InputPsk(psk.data(), psk.size()), true);
331 for (int i = 0; i < psk.size(); ++i) {
332 ASSERT_EQ(sslClient->preSharedKey[i], psk[i]);
333 }
334 ASSERT_EQ(sslClient->ClearPsk(), true);
335 for (int i = 0; i < psk.size(); ++i) {
336 ASSERT_EQ(sslClient->preSharedKey[i], 0);
337 }
338 delete sslClient;
339 delete hSSLInfo;
340 delete hSession;
341 }
342
343 /**
344 * @tc.name: PskServerCallbackTest001
345 * @tc.desc: test PskServerCallback function with normal and error input.
346 * @tc.type: FUNC
347 */
348 HWTEST_F(HdcSSLTest, PskServerCallbackTest001, TestSize.Level0)
349 {
350 SSL_library_init();
351 OpenSSL_add_all_algorithms();
352 SSL_load_error_strings();
353 SSL *ssl;
354 SSL_CTX *sslCtx;
355 const SSL_METHOD *method;
356 method = TLS_server_method();
357 sslCtx = SSL_CTX_new(method);
358 std::string pskInputStr = "01234567890123456789012345678912"; // set data
359 unsigned char pskInput[BUF_SIZE_PSK];
360 std::copy(pskInputStr.begin(), pskInputStr.end(), pskInput);
361 ASSERT_EQ(SSL_CTX_set_ex_data(sslCtx, 0, pskInput), true);
362 SSL_CTX_set_psk_server_callback(sslCtx, HdcSSLBase::PskServerCallback);
363 ssl = SSL_new(sslCtx);
364 SSL_set_accept_state(ssl);
365 unsigned char psk[BUF_SIZE_PSK];
366 char identityValid[BUF_SIZE_PSK];
367 unsigned int maxPskLen = BUF_SIZE_PSK;
368
369 unsigned int ret = HdcSSLBase::PskServerCallback(ssl, STR_PSK_IDENTITY.c_str(), psk, maxPskLen);
370 ASSERT_EQ(ret, BUF_SIZE_PSK);
371 for (int i = 0; i < BUF_SIZE_PSK; ++i) {
372 ASSERT_EQ(psk[i], pskInput[i]);
373 }
374 unsigned int validLen = 0; // 无效的keyLen
375 ASSERT_EQ(HdcSSLBase::PskServerCallback(ssl, STR_PSK_IDENTITY.c_str(), psk, validLen), 0);
376 ASSERT_EQ(HdcSSLBase::PskServerCallback(ssl, identityValid, psk, maxPskLen), 0);
377 SSL_shutdown(ssl);
378 SSL_free(ssl);
379 SSL_CTX_free(sslCtx);
380 }
381
382 /**
383 * @tc.name: PskServerCallbackTest002
384 * @tc.desc: test PskServerCallback function with no pskInput
385 * @tc.type: FUNC
386 */
387 HWTEST_F(HdcSSLTest, PskServerCallbackTest002, TestSize.Level0)
388 {
389 SSL_library_init();
390 OpenSSL_add_all_algorithms();
391 SSL_load_error_strings();
392 SSL *ssl;
393 SSL_CTX *sslCtx;
394 const SSL_METHOD *method;
395 method = TLS_server_method();
396 sslCtx = SSL_CTX_new(method);
397 ssl = SSL_new(sslCtx);
398 SSL_set_accept_state(ssl);
399 unsigned char psk[BUF_SIZE_PSK];
400 unsigned int maxPskLen = BUF_SIZE_PSK;
401 ASSERT_EQ(HdcSSLBase::PskServerCallback(ssl, STR_PSK_IDENTITY.c_str(), psk, maxPskLen), 0);
402 SSL_shutdown(ssl);
403 SSL_free(ssl);
404 SSL_CTX_free(sslCtx);
405 }
406
407 /**
408 * @tc.name: PskClientCallbackTest001
409 * @tc.desc: test PskClientCallback function with normal and error input.
410 * @tc.type: FUNC
411 */
412 HWTEST_F(HdcSSLTest, PskClientCallbackTest001, TestSize.Level0)
413 {
414 SSL_library_init();
415 OpenSSL_add_all_algorithms();
416 SSL_load_error_strings();
417 SSL *ssl;
418 SSL_CTX *sslCtx;
419 const SSL_METHOD *method;
420 method = TLS_client_method();
421 sslCtx = SSL_CTX_new(method);
422 std::string pskInputStr = "01234567890123456789012345678912";
423 unsigned char pskInput[BUF_SIZE_PSK];
424 std::copy(pskInputStr.begin(), pskInputStr.end(), pskInput);
425 ASSERT_EQ(SSL_CTX_set_ex_data(sslCtx, 0, pskInput), true);
426 SSL_CTX_set_psk_client_callback(sslCtx, HdcSSLBase::PskClientCallback);
427 ssl = SSL_new(sslCtx);
428 SSL_set_connect_state(ssl);
429 const char* hint = STR_PSK_IDENTITY.c_str();
430 char identity[BUF_SIZE_PSK];
431 unsigned int maxIdentityLen = BUF_SIZE_PSK;
432 unsigned char psk[BUF_SIZE_PSK];
433 unsigned char pskValid[BUF_SIZE_MICRO];
434 unsigned int maxPskLen = BUF_SIZE_PSK;
435
436 unsigned int ret = HdcSSLBase::PskClientCallback(ssl, hint, identity, maxIdentityLen, psk, maxPskLen);
437 ASSERT_EQ(ret, BUF_SIZE_PSK);
438 for (int i = 0; i < BUF_SIZE_PSK; ++i) {
439 ASSERT_EQ(psk[i], pskInput[i]);
440 }
441 unsigned int validLen = 0; // valid keyLen
442 ASSERT_EQ(HdcSSLBase::PskClientCallback(ssl, hint, identity, maxIdentityLen, psk, validLen), 0);
443 ASSERT_EQ(HdcSSLBase::PskClientCallback(ssl, hint, identity, validLen, psk, maxPskLen), 0);
444 ASSERT_EQ(HdcSSLBase::PskClientCallback(ssl, hint, identity, validLen, pskValid, maxPskLen), 0);
445 ASSERT_EQ(HdcSSLBase::PskClientCallback(ssl, hint, identity, STR_PSK_IDENTITY.size(), pskValid, maxPskLen), 0);
446 SSL_shutdown(ssl);
447 SSL_free(ssl);
448 SSL_CTX_free(sslCtx);
449 }
450
451 /**
452 * @tc.name: PskClientCallbackTest001
453 * @tc.desc: test PskClientCallback function with no pskInput.
454 * @tc.type: FUNC
455 */
456 HWTEST_F(HdcSSLTest, PskClientCallbackTest002, TestSize.Level0)
457 {
458 SSL_library_init();
459 OpenSSL_add_all_algorithms();
460 SSL_load_error_strings();
461 SSL *ssl;
462 SSL_CTX *sslCtx;
463 const SSL_METHOD *method;
464 method = TLS_client_method();
465 sslCtx = SSL_CTX_new(method);
466 ssl = SSL_new(sslCtx);
467 SSL_set_connect_state(ssl);
468 const char* hint = STR_PSK_IDENTITY.c_str();
469 char identity[BUF_SIZE_PSK];
470 unsigned int maxIdentityLen = BUF_SIZE_PSK;
471 unsigned char psk[BUF_SIZE_PSK];
472 unsigned int maxPskLen = BUF_SIZE_PSK;
473 unsigned int ret = HdcSSLBase::PskClientCallback(ssl, hint, identity, maxIdentityLen, psk, maxPskLen);
474 ASSERT_EQ(ret, 0);
475 SSL_shutdown(ssl);
476 SSL_free(ssl);
477 SSL_CTX_free(sslCtx);
478 }
479
480 /**
481 * @tc.name: RsaPrikeyDecryptTest001
482 * @tc.desc: test RsaPrikeyDecrypt function with normal input.
483 * @tc.type: FUNC
484 */
485 HWTEST_F(HdcSSLTest, RsaPrikeyDecryptTest001, TestSize.Level0)
486 {
487 SSLInfoPtr hSSLInfo = new HdcSSLInfo();
488 HSession hSession = new HdcSession();
489 HdcSSLBase::SetSSLInfo(hSSLInfo, hSession);
490 MockHdcSSLBase *sslBase = new (std::nothrow) MockHdcSSLBase(hSSLInfo);
491 unsigned char in[BUF_SIZE_DEFAULT2] = "test data";
492 int inLen = strlen((char*)in);
493 unsigned char out[BUF_SIZE_DEFAULT2];
494 int ret = sslBase->RsaPrikeyDecrypt(in, inLen, out, BUF_SIZE_DEFAULT2);
495 ASSERT_EQ(ret, ERR_GENERIC);
496 }
497
498 /**
499 * @tc.name: RsaPubkeyEncryptTest001
500 * @tc.desc: test RsaPubkeyEncrypt function with normal input.
501 * @tc.type: FUNC
502 */
503 HWTEST_F(HdcSSLTest, RsaPubkeyEncryptTest001, TestSize.Level0)
504 {
505 SSLInfoPtr hSSLInfo = new HdcSSLInfo();
506 HSession hSession = new HdcSession();
507 HdcSSLBase::SetSSLInfo(hSSLInfo, hSession);
508 MockHdcSSLBase *sslBase = new (std::nothrow) MockHdcSSLBase(hSSLInfo);
509 unsigned char in[BUF_SIZE_DEFAULT2] = "test data";
510 int inLen = strlen((char*)in);
511 unsigned char out[BUF_SIZE_DEFAULT2];
512 std::string pubkey = "public key";
513
514 EXPECT_CALL(*sslBase, RsaPubkeyEncrypt(testing::_, testing::_, testing::_, testing::_, testing::_))
515 .WillOnce(testing::Return(inLen));
516
517 int ret = sslBase->RsaPubkeyEncrypt(in, inLen, out, BUF_SIZE_DEFAULT2, pubkey);
518 ASSERT_EQ(ret, inLen);
519 delete sslBase;
520 delete hSSLInfo;
521 delete hSession;
522 }
523
524 /**
525 * @tc.name: SetHandshakeLabelTest001
526 * @tc.desc: test SetHandshakeLabel function when handshake ok.
527 * @tc.type: FUNC
528 */
529 HWTEST_F(HdcSSLTest, SetHandshakeLabelTest001, TestSize.Level0)
530 {
531 SSLInfoPtr hSSLInfoDaemon = new HdcSSLInfo();
532 SSLInfoPtr hSSLInfoHost = new HdcSSLInfo();
533 HSession hSessionDaemon = new HdcSession();
534 HSession hSessionHost = new HdcSession();
535 HdcSSLBase::SetSSLInfo(hSSLInfoDaemon, hSessionDaemon);
536 HdcSSLBase::SetSSLInfo(hSSLInfoHost, hSessionHost);
537 HdcSSLBase *sslServer = new (std::nothrow) HdcDaemonSSL(hSSLInfoDaemon);
538 HdcSSLBase *sslClient = new (std::nothrow) HdcHostSSL(hSSLInfoHost);
539 std::vector<unsigned char> psk(32);
540 fill(psk.begin(), psk.end(), 0);
541 sslClient->InputPsk(psk.data(), psk.size());
542 sslServer->InputPsk(psk.data(), psk.size());
543 sslServer->InitSSL();
544 sslClient->InitSSL();
545 SSLHandShakeEmulate(sslClient, sslServer);
546
547 ASSERT_EQ(sslServer->SetHandshakeLabel(hSessionDaemon), true);
548 ASSERT_EQ(sslClient->SetHandshakeLabel(hSessionHost), true);
549 ASSERT_EQ(hSessionDaemon->sslHandshake, true);
550 ASSERT_EQ(hSessionHost->sslHandshake, true);
551
552 delete sslClient;
553 delete sslServer;
554 delete hSSLInfoHost;
555 delete hSSLInfoDaemon;
556 delete hSessionHost;
557 delete hSessionDaemon;
558 }
559
560 /**
561 * @tc.name: SetHandshakeLabelTest002
562 * @tc.desc: test SetHandshakeLabel function when handshake not ok.
563 * @tc.type: FUNC
564 */
565 HWTEST_F(HdcSSLTest, SetHandshakeLabelTest002, TestSize.Level0)
566 {
567 SSLInfoPtr hSSLInfoDaemon = new HdcSSLInfo();
568 SSLInfoPtr hSSLInfoHost = new HdcSSLInfo();
569 HSession hSessionDaemon = new HdcSession();
570 HSession hSessionHost = new HdcSession();
571 HdcSSLBase::SetSSLInfo(hSSLInfoDaemon, hSessionDaemon);
572 HdcSSLBase::SetSSLInfo(hSSLInfoHost, hSessionHost);
573 HdcSSLBase *sslServer = new (std::nothrow) HdcDaemonSSL(hSSLInfoDaemon);
574 HdcSSLBase *sslClient = new (std::nothrow) HdcHostSSL(hSSLInfoHost);
575 std::vector<unsigned char> psk(32);
576 fill(psk.begin(), psk.end(), 0);
577 sslClient->InputPsk(psk.data(), psk.size());
578 sslServer->InputPsk(psk.data(), psk.size());
579 sslServer->InitSSL();
580 sslClient->InitSSL();
581 ASSERT_EQ(sslServer->SetHandshakeLabel(hSessionDaemon), false);
582 ASSERT_EQ(sslClient->SetHandshakeLabel(hSessionHost), false);
583 ASSERT_EQ(hSessionDaemon->sslHandshake, false);
584 ASSERT_EQ(hSessionHost->sslHandshake, false);
585 delete sslServer;
586 delete sslClient;
587 delete hSSLInfoDaemon;
588 delete hSSLInfoHost;
589 delete hSessionDaemon;
590 delete hSessionHost;
591 }
592
593 /**
594 * @tc.name: GetPskEncryptTest001
595 * @tc.desc: test GetPskEncrypt function using generated public key and private key
596 * @tc.type: FUNC
597 */
598 HWTEST_F(HdcSSLTest, GetPskEncryptTest001, TestSize.Level0)
599 {
600 SSLInfoPtr hSSLInfo = new HdcSSLInfo();
601 HSession hSession = new HdcSession();
602 HdcSSLBase::SetSSLInfo(hSSLInfo, hSession);
603 hSSLInfo->isDaemon = false;
604 HdcSSLBase *sslBase = new (std::nothrow) HdcHostSSL(hSSLInfo);
605 ASSERT_EQ(sslBase->GenPsk(), true);
606 string publicKey;
607 string privateKey;
608 GenerateRSAKeyPair(publicKey, privateKey);
609 std::unique_ptr<unsigned char[]> payload(std::make_unique<unsigned char[]>(BUF_SIZE_DEFAULT2));
610 int payloadSize = sslBase->GetPskEncrypt(payload.get(), BUF_SIZE_DEFAULT2, publicKey);
611 ASSERT_GT(payloadSize, 0);
612 unsigned char tokenDecode[BUF_SIZE_DEFAULT] = { 0 };
613 std::unique_ptr<unsigned char[]> out(std::make_unique<unsigned char[]>(BUF_SIZE_DEFAULT2));
614 int tbytes = EVP_DecodeBlock(tokenDecode, payload.get(), payloadSize);
615 ASSERT_TRUE(tbytes > 0);
616 EVP_PKEY *priKey = ReadPrivateKeyFromString(privateKey);
617 int outLen = RsaPrikeyDecrypt(tokenDecode, tbytes, out.get(), priKey);
618 ASSERT_EQ(outLen, BUF_SIZE_PSK);
619 for (int i = 0; i < BUF_SIZE_PSK; ++i) {
620 ASSERT_EQ(out.get()[i], sslBase->preSharedKey[i]);
621 }
622 EVP_PKEY_free(priKey);
623 delete sslBase;
624 delete hSession;
625 delete hSSLInfo;
626 }
627 } // namespace Hdc