1 /*
2 * Copyright (c) 2022 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include <fstream>
17 #include <gtest/gtest.h>
18 #include <iostream>
19 #include <openssl/rsa.h>
20 #include <openssl/ssl.h>
21 #include <sstream>
22 #include <string>
23 #include <string_view>
24 #include <unistd.h>
25 #include <vector>
26
27 #include "net_address.h"
28 #include "secure_data.h"
29 #include "socket_error.h"
30 #include "socket_state_base.h"
31 #include "tls.h"
32 #include "tls_certificate.h"
33 #include "tls_configuration.h"
34 #include "tls_key.h"
35 #include "tls_socket.h"
36
37 namespace OHOS {
38 namespace NetStack {
39 namespace TlsSocket {
40 namespace {
// Fixture files for the certificate-chain tests: private key, CA certificate,
// intermediate CA certificate, and the certificate handed to SetCert().
constexpr std::string_view PRIVATE_KEY_PEM_CHAIN = "/data/ClientCertChain/privekey.pem.unsecure";
constexpr std::string_view CA_PATH_CHAIN = "/data/ClientCertChain/cacert.crt";
constexpr std::string_view MID_CA_PATH_CHAIN = "/data/ClientCertChain/caMidcert.crt";
constexpr std::string_view CLIENT_CRT_CHAIN = "/data/ClientCertChain/secondServer.crt";
// NOTE: despite their names these are file *paths*; the files contain the peer
// IP address and port as text.
constexpr std::string_view IP_ADDRESS = "/data/Ip/address.txt";
constexpr std::string_view PORT = "/data/Ip/port.txt";
47
CheckCaFileExistence(const char * function)48 inline bool CheckCaFileExistence(const char *function)
49 {
50 if (access(CA_PATH_CHAIN.data(), 0)) {
51 std::cout << "CA file does not exist! (" << function << ")";
52 return false;
53 }
54 return true;
55 }
56
ReadFileContent(const std::string_view fileName)57 std::string ReadFileContent(const std::string_view fileName)
58 {
59 std::ifstream file;
60 file.open(fileName);
61 std::stringstream ss;
62 ss << file.rdbuf();
63 std::string infos = ss.str();
64 file.close();
65 return infos;
66 }
67
/**
 * Normalizes an IP address read from a text file.
 *
 * Surrounding whitespace (such as a trailing "\n" or "\r\n") is removed rather than
 * blindly dropping the last character. As a result, a file written without a trailing
 * newline still yields the full address.
 *
 * @param ip Raw file content.
 * @return The address with leading and trailing whitespace removed; empty if @p ip
 *         is empty or contains only whitespace.
 */
std::string GetIp(std::string ip)
{
    static constexpr char WHITESPACE[] = " \t\r\n\v\f";
    const std::string::size_type last = ip.find_last_not_of(WHITESPACE);
    if (last == std::string::npos) {
        return std::string();
    }
    // A non-whitespace character exists, so `first` is valid and <= `last`.
    const std::string::size_type first = ip.find_first_not_of(WHITESPACE);
    return ip.substr(first, last - first + 1);
}
72 } // namespace
73
74 class TlsSocketTest : public testing::Test {
75 public:
SetUpTestCase()76 static void SetUpTestCase() {}
77
TearDownTestCase()78 static void TearDownTestCase() {}
79
SetUp()80 virtual void SetUp() {}
81
TearDown()82 virtual void TearDown() {}
83 };
84
SetCertChainHwTestShortParam(TLSSocket & server)85 void SetCertChainHwTestShortParam(TLSSocket &server)
86 {
87 TLSConnectOptions options;
88 TLSSecureOptions secureOption;
89 Socket::NetAddress address;
90
91 address.SetAddress(GetIp(ReadFileContent(IP_ADDRESS)));
92 address.SetPort(std::atoi(ReadFileContent(PORT).c_str()));
93 address.SetFamilyBySaFamily(AF_INET);
94
95 secureOption.SetKey(SecureData(ReadFileContent(PRIVATE_KEY_PEM_CHAIN)));
96 std::vector<std::string> caVec = {ReadFileContent(CA_PATH_CHAIN), ReadFileContent(MID_CA_PATH_CHAIN)};
97 secureOption.SetCaChain(caVec);
98 secureOption.SetCert(ReadFileContent(CLIENT_CRT_CHAIN));
99
100 options.SetNetAddress(address);
101 options.SetTlsSecureOptions(secureOption);
102
103 server.Bind(address, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
104 server.Connect(options, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
105 }
106
SetCertChainHwTestLongParam(TLSSocket & server)107 void SetCertChainHwTestLongParam(TLSSocket &server)
108 {
109 TLSConnectOptions options;
110 TLSSecureOptions secureOption;
111 Socket::NetAddress address;
112
113 address.SetAddress(GetIp(ReadFileContent(IP_ADDRESS)));
114 address.SetPort(std::atoi(ReadFileContent(PORT).c_str()));
115 address.SetFamilyBySaFamily(AF_INET);
116
117 secureOption.SetKey(SecureData(ReadFileContent(PRIVATE_KEY_PEM_CHAIN)));
118 std::vector<std::string> caVec = {ReadFileContent(CA_PATH_CHAIN), ReadFileContent(MID_CA_PATH_CHAIN)};
119 secureOption.SetCaChain(caVec);
120 secureOption.SetCert(ReadFileContent(CLIENT_CRT_CHAIN));
121 secureOption.SetCipherSuite("AES256-SHA256");
122 std::string protocolV13 = "TLSv1.3";
123 std::vector<std::string> protocolVec = {protocolV13};
124 secureOption.SetProtocolChain(protocolVec);
125
126 options.SetNetAddress(address);
127 options.SetTlsSecureOptions(secureOption);
128
129 server.Bind(address, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
130 server.Connect(options, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
131 }
132
133 HWTEST_F(TlsSocketTest, bindInterface, testing::ext::TestSize.Level2)
134 {
135 if (!CheckCaFileExistence("bindInterface")) {
136 return;
137 }
138
139 TLSSocket server;
140 Socket::NetAddress address;
141
142 address.SetAddress(GetIp(ReadFileContent(IP_ADDRESS)));
143 address.SetPort(std::atoi(ReadFileContent(PORT).c_str()));
144 address.SetFamilyBySaFamily(AF_INET);
145
__anond1e1d4740602(int32_t errCode) 146 server.Bind(address, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
147 }
148
149 HWTEST_F(TlsSocketTest, connectInterface, testing::ext::TestSize.Level2)
150 {
151 if (!CheckCaFileExistence("connectInterface")) {
152 return;
153 }
154 TLSSocket server;
155 SetCertChainHwTestShortParam(server);
156
157 const std::string data = "how do you do? this is connectInterface";
158 Socket::TCPSendOptions tcpSendOptions;
159 tcpSendOptions.SetData(data);
__anond1e1d4740702(int32_t errCode) 160 server.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
161 sleep(2);
162
__anond1e1d4740802(int32_t errCode) 163 (void)server.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
164 sleep(2);
165 }
166
167 HWTEST_F(TlsSocketTest, closeInterface, testing::ext::TestSize.Level2)
168 {
169 if (!CheckCaFileExistence("closeInterface")) {
170 return;
171 }
172 TLSSocket server;
173 SetCertChainHwTestShortParam(server);
174
175 const std::string data = "how do you do? this is closeInterface";
176 Socket::TCPSendOptions tcpSendOptions;
177 tcpSendOptions.SetData(data);
178
__anond1e1d4740902(int32_t errCode) 179 server.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
180 sleep(2);
181
__anond1e1d4740a02(int32_t errCode) 182 (void)server.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
183 }
184
185 HWTEST_F(TlsSocketTest, sendInterface, testing::ext::TestSize.Level2)
186 {
187 if (!CheckCaFileExistence("sendInterface")) {
188 return;
189 }
190 TLSSocket server;
191 SetCertChainHwTestShortParam(server);
192
193 const std::string data = "how do you do? this is sendInterface";
194 Socket::TCPSendOptions tcpSendOptions;
195 tcpSendOptions.SetData(data);
196
__anond1e1d4740b02(int32_t errCode) 197 server.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
198 sleep(2);
199
__anond1e1d4740c02(int32_t errCode) 200 (void)server.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
201 }
202
203 HWTEST_F(TlsSocketTest, getRemoteAddressInterface, testing::ext::TestSize.Level2)
204 {
205 if (!CheckCaFileExistence("getRemoteAddressInterface")) {
206 return;
207 }
208 TLSSocket server;
209 TLSConnectOptions options;
210 TLSSecureOptions secureOption;
211 Socket::NetAddress address;
212
213 address.SetAddress(GetIp(ReadFileContent(IP_ADDRESS)));
214 address.SetPort(std::atoi(ReadFileContent(PORT).c_str()));
215 address.SetFamilyBySaFamily(AF_INET);
216
217 secureOption.SetKey(SecureData(ReadFileContent(PRIVATE_KEY_PEM_CHAIN)));
218 std::vector<std::string> caVec = {ReadFileContent(CA_PATH_CHAIN), ReadFileContent(MID_CA_PATH_CHAIN)};
219 secureOption.SetCaChain(caVec);
220 secureOption.SetCert(ReadFileContent(CLIENT_CRT_CHAIN));
221
222 options.SetNetAddress(address);
223 options.SetTlsSecureOptions(secureOption);
224
__anond1e1d4740d02(int32_t errCode) 225 server.Bind(address, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
__anond1e1d4740e02(int32_t errCode) 226 server.Connect(options, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
227
228 Socket::NetAddress netAddress;
__anond1e1d4740f02(int32_t errCode, const Socket::NetAddress &address) 229 server.GetRemoteAddress([&netAddress](int32_t errCode, const Socket::NetAddress &address) {
230 EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS);
231 netAddress.SetAddress(address.GetAddress());
232 netAddress.SetPort(address.GetPort());
233 netAddress.SetFamilyBySaFamily(address.GetSaFamily());
234 });
235 EXPECT_STREQ(netAddress.GetAddress().c_str(), GetIp(ReadFileContent(IP_ADDRESS)).c_str());
236 EXPECT_EQ(address.GetPort(), std::atoi(ReadFileContent(PORT).c_str()));
237 EXPECT_EQ(netAddress.GetSaFamily(), AF_INET);
238
239 const std::string data = "how do you do? this is getRemoteAddressInterface";
240 Socket::TCPSendOptions tcpSendOptions;
241 tcpSendOptions.SetData(data);
242
__anond1e1d4741002(int32_t errCode) 243 server.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
244
__anond1e1d4741102(int32_t errCode) 245 (void)server.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
246 }
247
248 HWTEST_F(TlsSocketTest, getStateInterface, testing::ext::TestSize.Level2)
249 {
250 if (!CheckCaFileExistence("getRemoteAddressInterface")) {
251 return;
252 }
253 TLSSocket server;
254 SetCertChainHwTestShortParam(server);
255
256 Socket::SocketStateBase TlsSocketstate;
__anond1e1d4741202(int32_t errCode, const Socket::SocketStateBase &state) 257 server.GetState([&TlsSocketstate](int32_t errCode, const Socket::SocketStateBase &state) {
258 EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS);
259 TlsSocketstate = state;
260 });
261 std::cout << "TlsSocketstate.IsClose(): " << TlsSocketstate.IsClose() << std::endl;
262 EXPECT_TRUE(TlsSocketstate.IsBound());
263 EXPECT_TRUE(!TlsSocketstate.IsClose());
264 EXPECT_TRUE(TlsSocketstate.IsConnected());
265
266 const std::string data = "how do you do? this is getStateInterface";
267 Socket::TCPSendOptions tcpSendOptions;
268 tcpSendOptions.SetData(data);
__anond1e1d4741302(int32_t errCode) 269 server.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
270
271 sleep(2);
272
__anond1e1d4741402(int32_t errCode) 273 (void)server.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
274 }
275
276 HWTEST_F(TlsSocketTest, getCertificateInterface, testing::ext::TestSize.Level2)
277 {
278 if (!CheckCaFileExistence("getCertificateInterface")) {
279 return;
280 }
281 TLSSocket server;
282 SetCertChainHwTestShortParam(server);
283 Socket::TCPSendOptions tcpSendOptions;
284 const std::string data = "how do you do? This is UT test getCertificateInterface";
285
286 tcpSendOptions.SetData(data);
__anond1e1d4741502(int32_t errCode) 287 server.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
288
289 server.GetCertificate(
__anond1e1d4741602(int32_t errCode, const X509CertRawData &cert) 290 [](int32_t errCode, const X509CertRawData &cert) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
291
292 sleep(2);
__anond1e1d4741702(int32_t errCode) 293 (void)server.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
294 }
295
296 HWTEST_F(TlsSocketTest, getRemoteCertificateInterface, testing::ext::TestSize.Level2)
297 {
298 if (!CheckCaFileExistence("getRemoteCertificateInterface")) {
299 return;
300 }
301 TLSSocket server;
302 SetCertChainHwTestShortParam(server);
303 Socket::TCPSendOptions tcpSendOptions;
304 const std::string data = "how do you do? This is UT test getRemoteCertificateInterface";
305 tcpSendOptions.SetData(data);
306
__anond1e1d4741802(int32_t errCode) 307 server.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
308
309 server.GetRemoteCertificate(
__anond1e1d4741902(int32_t errCode, const X509CertRawData &cert) 310 [](int32_t errCode, const X509CertRawData &cert) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
311
312 sleep(2);
__anond1e1d4741a02(int32_t errCode) 313 (void)server.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
314 }
315
316 HWTEST_F(TlsSocketTest, protocolInterface, testing::ext::TestSize.Level2)
317 {
318 if (!CheckCaFileExistence("protocolInterface")) {
319 return;
320 }
321 TLSSocket server;
322 SetCertChainHwTestLongParam(server);
323
324 const std::string data = "how do you do? this is protocolInterface";
325 Socket::TCPSendOptions tcpSendOptions;
326 tcpSendOptions.SetData(data);
327
__anond1e1d4741b02(int32_t errCode) 328 server.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
329 std::string getProtocolVal;
__anond1e1d4741c02(int32_t errCode, const std::string &protocol) 330 server.GetProtocol([&getProtocolVal](int32_t errCode, const std::string &protocol) {
331 EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS);
332 getProtocolVal = protocol;
333 });
334 EXPECT_STREQ(getProtocolVal.c_str(), "TLSv1.3");
335 sleep(2);
336
__anond1e1d4741d02(int32_t errCode) 337 (void)server.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
338 }
339
340 HWTEST_F(TlsSocketTest, getCipherSuiteInterface, testing::ext::TestSize.Level2)
341 {
342 if (!CheckCaFileExistence("getCipherSuiteInterface")) {
343 return;
344 }
345 TLSSocket server;
346 SetCertChainHwTestLongParam(server);
347
348 bool flag = false;
349 const std::string data = "how do you do? This is getCipherSuiteInterface";
350 Socket::TCPSendOptions tcpSendOptions;
351 tcpSendOptions.SetData(data);
__anond1e1d4741e02(int32_t errCode) 352 server.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
353
354 std::vector<std::string> cipherSuite;
__anond1e1d4741f02(int32_t errCode, const std::vector<std::string> &suite) 355 server.GetCipherSuite([&cipherSuite](int32_t errCode, const std::vector<std::string> &suite) {
356 EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS);
357 cipherSuite = suite;
358 });
359
360 for (auto const &iter : cipherSuite) {
361 if (iter == "AES256-SHA256") {
362 flag = true;
363 }
364 }
365
366 EXPECT_TRUE(flag);
367 sleep(2);
368
__anond1e1d4742002(int32_t errCode) 369 (void)server.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
370 }
371
372 HWTEST_F(TlsSocketTest, getSignatureAlgorithmsInterface, testing::ext::TestSize.Level2)
373 {
374 if (!CheckCaFileExistence("getSignatureAlgorithmsInterface")) {
375 return;
376 }
377 TLSConnectOptions options;
378 TLSSocket server;
379 TLSSecureOptions secureOption;
380 Socket::NetAddress address;
381
382 address.SetAddress(GetIp(ReadFileContent(IP_ADDRESS)));
383 address.SetPort(std::atoi(ReadFileContent(PORT).c_str()));
384 address.SetFamilyBySaFamily(AF_INET);
385
386 std::string signatureAlgorithmVec = {"rsa_pss_rsae_sha256:ECDSA+SHA256"};
387 secureOption.SetSignatureAlgorithms(signatureAlgorithmVec);
388 secureOption.SetKey(SecureData(ReadFileContent(PRIVATE_KEY_PEM_CHAIN)));
389 std::vector<std::string> caVec = {ReadFileContent(CA_PATH_CHAIN), ReadFileContent(MID_CA_PATH_CHAIN)};
390 secureOption.SetCaChain(caVec);
391 secureOption.SetCert(ReadFileContent(CLIENT_CRT_CHAIN));
392 std::string protocolV13 = "TLSv1.3";
393 std::vector<std::string> protocolVec = {protocolV13};
394 secureOption.SetProtocolChain(protocolVec);
395
396 options.SetNetAddress(address);
397 options.SetTlsSecureOptions(secureOption);
398
399 bool flag = false;
__anond1e1d4742102(int32_t errCode) 400 server.Bind(address, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
401
__anond1e1d4742202(int32_t errCode) 402 server.Connect(options, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
403
404 const std::string data = "how do you do? this is getSignatureAlgorithmsInterface";
405 Socket::TCPSendOptions tcpSendOptions;
406 tcpSendOptions.SetData(data);
__anond1e1d4742302(int32_t errCode) 407 server.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
408
409 std::vector<std::string> signatureAlgorithms;
__anond1e1d4742402(int32_t errCode, const std::vector<std::string> &algorithms) 410 server.GetSignatureAlgorithms([&signatureAlgorithms](int32_t errCode, const std::vector<std::string> &algorithms) {
411 if (errCode == TLSSOCKET_SUCCESS) {
412 signatureAlgorithms = algorithms;
413 }
414 });
415 for (auto const &iter : signatureAlgorithms) {
416 if (iter == "ECDSA+SHA256") {
417 flag = true;
418 }
419 }
420 EXPECT_TRUE(flag);
421 sleep(2);
__anond1e1d4742502(int32_t errCode) 422 (void)server.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
423 }
424
425 HWTEST_F(TlsSocketTest, onMessageDataInterface, testing::ext::TestSize.Level2)
426 {
427 if (!CheckCaFileExistence("tlsSocketOnMessageData")) {
428 return;
429 }
430 std::string getData = "server->client";
431 TLSSocket server;
432 SetCertChainHwTestLongParam(server);
__anond1e1d4742602(const std::string &data, const Socket::SocketRemoteInfo &remoteInfo) 433 server.OnMessage([&getData](const std::string &data, const Socket::SocketRemoteInfo &remoteInfo) {
434 if (data == getData) {
435 EXPECT_TRUE(true);
436 } else {
437 EXPECT_TRUE(false);
438 }
439 });
440
441 const std::string data = "how do you do? this is tlsSocketOnMessageData";
442 Socket::TCPSendOptions tcpSendOptions;
443 tcpSendOptions.SetData(data);
__anond1e1d4742702(int32_t errCode) 444 server.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
445
446 sleep(2);
__anond1e1d4742802(int32_t errCode) 447 (void)server.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
448 }
449 } // namespace TlsSocket
450 } // namespace NetStack
451 } // namespace OHOS
452