1 /*
2 * Copyright (c) 2022 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
#include <algorithm>
#include <cstdlib>
#include <fstream>
#include <gtest/gtest.h>
#include <iostream>
#include <openssl/rsa.h>
#include <openssl/ssl.h>
#include <sstream>
#include <string>
#include <string_view>
#include <unistd.h>
#include <vector>
26
27 #include "net_address.h"
28 #include "secure_data.h"
29 #include "socket_error.h"
30 #include "socket_state_base.h"
31 #include "tls.h"
32 #include "tls_certificate.h"
33 #include "tls_configuration.h"
34 #include "tls_key.h"
35 #include "tls_socket.h"
36
37 namespace OHOS {
38 namespace NetStack {
39 namespace TlsSocket {
40 namespace {
// Device-side files that must be provisioned before these tests can talk to the TLS peer.
// Every constant below is a file *path*; the tests read the file contents at runtime.

// Client private key (PEM) handed to TLSSecureOptions::SetKey.
constexpr std::string_view PRIVATE_KEY_PEM_CHAIN = "/data/ClientCertChain/privekey.pem.unsecure";
// CA certificate; its presence is what CheckCaFileExistence() probes before each test.
constexpr std::string_view CA_PATH_CHAIN = "/data/ClientCertChain/cacert.crt";
// Second ("Mid") CA certificate, appended after CA_PATH_CHAIN in the CA chain.
constexpr std::string_view MID_CA_PATH_CHAIN = "/data/ClientCertChain/caMidcert.crt";
// Client certificate handed to TLSSecureOptions::SetCert.
constexpr std::string_view CLIENT_CRT_CHAIN = "/data/ClientCertChain/secondServer.crt";
// File whose content is the peer IP address (passed through GetIp before use).
constexpr std::string_view IP_ADDRESS = "/data/Ip/address.txt";
// File whose content is the peer port number (parsed with std::atoi).
constexpr std::string_view PORT = "/data/Ip/port.txt";
47
CheckCaFileExistence(const char * function)48 inline bool CheckCaFileExistence(const char *function)
49 {
50 if (access(CA_PATH_CHAIN.data(), 0)) {
51 std::cout << "CA file does not exist! (" << function << ")";
52 return false;
53 }
54 return true;
55 }
56
// Reads the whole file at `fileName` and returns its content verbatim.
// A missing or unreadable file yields an empty string.
std::string ReadFileContent(const std::string_view fileName)
{
    // std::string_view is not guaranteed to be null-terminated, and std::ifstream::open has no
    // string_view overload in the standard, so copy the name into an owning std::string first.
    const std::string path(fileName);
    std::ifstream file(path);
    if (!file.is_open()) {
        return {};
    }
    std::stringstream ss;
    ss << file.rdbuf();
    // `file` is closed by its destructor.
    return ss.str();
}
67
// Normalises the raw content of the IP provisioning file into a bare address string.
// Trailing whitespace is stripped; the file content presumably ends with a line break
// ("\n" or "\r\n"), but an address stored without one is now kept intact instead of
// losing its final character.
std::string GetIp(std::string ip)
{
    const std::string::size_type last = ip.find_last_not_of(" \t\r\n");
    // npos means the string is empty or all whitespace: nothing to keep.
    ip.erase(last == std::string::npos ? 0 : last + 1);
    return ip;
}
72 } // namespace
73
74 class TlsSocketTest : public testing::Test {
75 public:
SetUpTestCase()76 static void SetUpTestCase() {}
77
TearDownTestCase()78 static void TearDownTestCase() {}
79
SetUp()80 virtual void SetUp() {}
81
TearDown()82 virtual void TearDown() {}
83 };
84
SetCertChainOneWayHwTestShortParam(TLSSocket & server)85 void SetCertChainOneWayHwTestShortParam(TLSSocket &server)
86 {
87 TLSConnectOptions options;
88 TLSSecureOptions secureOption;
89 Socket::NetAddress address;
90
91 address.SetAddress(GetIp(ReadFileContent(IP_ADDRESS)));
92 address.SetPort(std::atoi(ReadFileContent(PORT).c_str()));
93 address.SetFamilyBySaFamily(AF_INET);
94
95 secureOption.SetKey(SecureData(ReadFileContent(PRIVATE_KEY_PEM_CHAIN)));
96 std::vector<std::string> caVec = {ReadFileContent(CA_PATH_CHAIN), ReadFileContent(MID_CA_PATH_CHAIN)};
97 secureOption.SetCaChain(caVec);
98 secureOption.SetCert(ReadFileContent(CLIENT_CRT_CHAIN));
99
100 options.SetNetAddress(address);
101 options.SetTlsSecureOptions(secureOption);
102
103 server.Bind(address, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
104 server.Connect(options, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
105 }
106
SetCertChainOneWayHwTestLongParam(TLSSocket & server)107 void SetCertChainOneWayHwTestLongParam(TLSSocket &server)
108 {
109 TLSConnectOptions options;
110 TLSSecureOptions secureOption;
111 Socket::NetAddress address;
112
113 address.SetAddress(GetIp(ReadFileContent(IP_ADDRESS)));
114 address.SetPort(std::atoi(ReadFileContent(PORT).c_str()));
115 address.SetFamilyBySaFamily(AF_INET);
116
117 secureOption.SetKey(SecureData(ReadFileContent(PRIVATE_KEY_PEM_CHAIN)));
118 std::vector<std::string> caVec = {ReadFileContent(CA_PATH_CHAIN), ReadFileContent(MID_CA_PATH_CHAIN)};
119 secureOption.SetCaChain(caVec);
120 secureOption.SetCert(ReadFileContent(CLIENT_CRT_CHAIN));
121 secureOption.SetCipherSuite("AES256-SHA256");
122 std::string protocolV13 = "TLSv1.3";
123 std::vector<std::string> protocolVec = {protocolV13};
124 secureOption.SetProtocolChain(protocolVec);
125
126 options.SetNetAddress(address);
127 options.SetTlsSecureOptions(secureOption);
128
129 server.Bind(address, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
130 server.Connect(options, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
131 }
132
133 HWTEST_F(TlsSocketTest, bindInterface, testing::ext::TestSize.Level2)
134 {
135 if (!CheckCaFileExistence("bindInterface")) {
136 return;
137 }
138
139 TLSSocket server;
140 Socket::NetAddress address;
141
142 address.SetAddress(GetIp(ReadFileContent(IP_ADDRESS)));
143 address.SetPort(std::atoi(ReadFileContent(PORT).c_str()));
144 address.SetFamilyBySaFamily(AF_INET);
145
__anonab02f0270602(int32_t errCode) 146 server.Bind(address, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
147 }
148
149 HWTEST_F(TlsSocketTest, connectInterface, testing::ext::TestSize.Level2)
150 {
151 if (!CheckCaFileExistence("connectInterface")) {
152 return;
153 }
154 TLSSocket server;
155 SetCertChainOneWayHwTestShortParam(server);
156
157 const std::string data = "how do you do? this is connectInterface";
158 Socket::TCPSendOptions tcpSendOptions;
159 tcpSendOptions.SetData(data);
__anonab02f0270702(int32_t errCode) 160 server.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
161 sleep(2);
162
__anonab02f0270802(int32_t errCode) 163 (void)server.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
164 sleep(2);
165 }
166
167 HWTEST_F(TlsSocketTest, closeInterface, testing::ext::TestSize.Level2)
168 {
169 if (!CheckCaFileExistence("closeInterface")) {
170 return;
171 }
172 TLSSocket server;
173 SetCertChainOneWayHwTestShortParam(server);
174
175 const std::string data = "how do you do? this is closeInterface";
176 Socket::TCPSendOptions tcpSendOptions;
177 tcpSendOptions.SetData(data);
178
__anonab02f0270902(int32_t errCode) 179 server.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
180 sleep(2);
181
__anonab02f0270a02(int32_t errCode) 182 (void)server.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
183 }
184
185 HWTEST_F(TlsSocketTest, sendInterface, testing::ext::TestSize.Level2)
186 {
187 if (!CheckCaFileExistence("sendInterface")) {
188 return;
189 }
190 TLSSocket server;
191 SetCertChainOneWayHwTestShortParam(server);
192
193 const std::string data = "how do you do? this is sendInterface";
194 Socket::TCPSendOptions tcpSendOptions;
195 tcpSendOptions.SetData(data);
196
__anonab02f0270b02(int32_t errCode) 197 server.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
198 sleep(2);
199
__anonab02f0270c02(int32_t errCode) 200 (void)server.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
201 }
202
203 HWTEST_F(TlsSocketTest, getRemoteAddressInterface, testing::ext::TestSize.Level2)
204 {
205 if (!CheckCaFileExistence("getRemoteAddressInterface")) {
206 return;
207 }
208 TLSSocket server;
209 TLSConnectOptions options;
210 TLSSecureOptions secureOption;
211 Socket::NetAddress address;
212
213 address.SetAddress(GetIp(ReadFileContent(IP_ADDRESS)));
214 address.SetPort(std::atoi(ReadFileContent(PORT).c_str()));
215 address.SetFamilyBySaFamily(AF_INET);
216
217 secureOption.SetKey(SecureData(ReadFileContent(PRIVATE_KEY_PEM_CHAIN)));
218 std::vector<std::string> caVec = {ReadFileContent(CA_PATH_CHAIN), ReadFileContent(MID_CA_PATH_CHAIN)};
219 secureOption.SetCaChain(caVec);
220 secureOption.SetCert(ReadFileContent(CLIENT_CRT_CHAIN));
221
222 options.SetNetAddress(address);
223 options.SetTlsSecureOptions(secureOption);
224
__anonab02f0270d02(int32_t errCode) 225 server.Bind(address, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
__anonab02f0270e02(int32_t errCode) 226 server.Connect(options, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
227
228 Socket::NetAddress netAddress;
__anonab02f0270f02(int32_t errCode, const Socket::NetAddress &address) 229 server.GetRemoteAddress([&netAddress](int32_t errCode, const Socket::NetAddress &address) {
230 EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS);
231 netAddress.SetAddress(address.GetAddress());
232 netAddress.SetPort(address.GetPort());
233 netAddress.SetFamilyBySaFamily(address.GetSaFamily());
234 });
235 EXPECT_STREQ(netAddress.GetAddress().c_str(), GetIp(ReadFileContent(IP_ADDRESS)).c_str());
236 EXPECT_EQ(address.GetPort(), std::atoi(ReadFileContent(PORT).c_str()));
237 EXPECT_EQ(netAddress.GetSaFamily(), AF_INET);
238
239 const std::string data = "how do you do? this is getRemoteAddressInterface";
240 Socket::TCPSendOptions tcpSendOptions;
241 tcpSendOptions.SetData(data);
242
__anonab02f0271002(int32_t errCode) 243 server.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
244 sleep(2);
245
__anonab02f0271102(int32_t errCode) 246 (void)server.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
247 }
248
249 HWTEST_F(TlsSocketTest, getStateInterface, testing::ext::TestSize.Level2)
250 {
251 if (!CheckCaFileExistence("getRemoteAddressInterface")) {
252 return;
253 }
254 TLSSocket server;
255 SetCertChainOneWayHwTestShortParam(server);
256
257 Socket::SocketStateBase TlsSocketstate;
__anonab02f0271202(int32_t errCode, const Socket::SocketStateBase &state) 258 server.GetState([&TlsSocketstate](int32_t errCode, const Socket::SocketStateBase &state) {
259 EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS);
260 TlsSocketstate = state;
261 });
262 std::cout << "TlsSocketstate.IsClose(): " << TlsSocketstate.IsClose() << std::endl;
263 EXPECT_TRUE(TlsSocketstate.IsBound());
264 EXPECT_TRUE(!TlsSocketstate.IsClose());
265 EXPECT_TRUE(TlsSocketstate.IsConnected());
266
267 const std::string data = "how do you do? this is getStateInterface";
268 Socket::TCPSendOptions tcpSendOptions;
269 tcpSendOptions.SetData(data);
__anonab02f0271302(int32_t errCode) 270 server.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
271
272 sleep(2);
273
__anonab02f0271402(int32_t errCode) 274 (void)server.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
275 }
276
277 HWTEST_F(TlsSocketTest, getRemoteCertificateInterface, testing::ext::TestSize.Level2)
278 {
279 if (!CheckCaFileExistence("getRemoteCertificateInterface")) {
280 return;
281 }
282 TLSSocket server;
283 SetCertChainOneWayHwTestShortParam(server);
284 Socket::TCPSendOptions tcpSendOptions;
285
286 const std::string data = "how do you do? This is UT test getRemoteCertificateInterface";
287 tcpSendOptions.SetData(data);
__anonab02f0271502(int32_t errCode) 288 server.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
289
290 server.GetRemoteCertificate(
__anonab02f0271602(int32_t errCode, const X509CertRawData &cert) 291 [](int32_t errCode, const X509CertRawData &cert) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
292
293 sleep(2);
__anonab02f0271702(int32_t errCode) 294 (void)server.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
295 }
296
297 HWTEST_F(TlsSocketTest, protocolInterface, testing::ext::TestSize.Level2)
298 {
299 if (!CheckCaFileExistence("protocolInterface")) {
300 return;
301 }
302 TLSSocket server;
303 SetCertChainOneWayHwTestLongParam(server);
304
305 const std::string data = "how do you do? this is protocolInterface";
306 Socket::TCPSendOptions tcpSendOptions;
307 tcpSendOptions.SetData(data);
308
__anonab02f0271802(int32_t errCode) 309 server.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
310 std::string getProtocolVal;
__anonab02f0271902(int32_t errCode, const std::string &protocol) 311 server.GetProtocol([&getProtocolVal](int32_t errCode, const std::string &protocol) {
312 EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS);
313 getProtocolVal = protocol;
314 });
315 EXPECT_STREQ(getProtocolVal.c_str(), "TLSv1.3");
316 sleep(2);
317
__anonab02f0271a02(int32_t errCode) 318 (void)server.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
319 }
320
321 HWTEST_F(TlsSocketTest, getCipherSuiteInterface, testing::ext::TestSize.Level2)
322 {
323 if (!CheckCaFileExistence("getCipherSuiteInterface")) {
324 return;
325 }
326 TLSSocket server;
327 SetCertChainOneWayHwTestLongParam(server);
328
329 bool flag = false;
330 const std::string data = "how do you do? This is getCipherSuiteInterface";
331 Socket::TCPSendOptions tcpSendOptions;
332 tcpSendOptions.SetData(data);
__anonab02f0271b02(int32_t errCode) 333 server.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
334
335 std::vector<std::string> cipherSuite;
__anonab02f0271c02(int32_t errCode, const std::vector<std::string> &suite) 336 server.GetCipherSuite([&cipherSuite](int32_t errCode, const std::vector<std::string> &suite) {
337 EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS);
338 cipherSuite = suite;
339 });
340
341 for (auto const &iter : cipherSuite) {
342 if (iter == "AES256-SHA256") {
343 flag = true;
344 }
345 }
346
347 EXPECT_TRUE(flag);
348 sleep(2);
349
__anonab02f0271d02(int32_t errCode) 350 (void)server.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
351 }
352
353 HWTEST_F(TlsSocketTest, onMessageDataInterface, testing::ext::TestSize.Level2)
354 {
355 if (!CheckCaFileExistence("tlsSocketOnMessageData")) {
356 return;
357 }
358 std::string getData = "server->client";
359 TLSSocket server;
360 SetCertChainOneWayHwTestLongParam(server);
361
__anonab02f0271e02(const std::string &data, const Socket::SocketRemoteInfo &remoteInfo) 362 server.OnMessage([&getData](const std::string &data, const Socket::SocketRemoteInfo &remoteInfo) {
363 if (data == getData) {
364 EXPECT_TRUE(true);
365 } else {
366 EXPECT_TRUE(false);
367 }
368 });
369
370 const std::string data = "how do you do? this is tlsSocketOnMessageData";
371 Socket::TCPSendOptions tcpSendOptions;
372 tcpSendOptions.SetData(data);
__anonab02f0271f02(int32_t errCode) 373 server.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
374
375 sleep(2);
__anonab02f0272002(int32_t errCode) 376 (void)server.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
377 }
378 } // namespace TlsSocket
379 } // namespace NetStack
380 } // namespace OHOS
381