1 /*
2 * Copyright (c) 2022 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
#include <algorithm>
#include <cstdlib>
#include <fstream>
#include <gtest/gtest.h>
#include <iostream>
#include <openssl/rsa.h>
#include <openssl/ssl.h>
#include <sstream>
#include <string>
#include <string_view>
#include <unistd.h>
#include <vector>

#include "accesstoken_kit.h"
#include "net_address.h"
#include "secure_data.h"
#include "socket_error.h"
#include "socket_state_base.h"
#include "tls.h"
#include "tls_certificate.h"
#include "tls_configuration.h"
#include "tls_key.h"
#include "tls_socket.h"
#include "token_setproc.h"
38
39 namespace OHOS {
40 namespace NetStack {
41 namespace TlsSocket {
42 namespace {
43 using namespace testing::ext;
44 using namespace Security::AccessToken;
// Client certificate-chain material: private key, root CA, intermediate CA and client certificate.
constexpr std::string_view PRIVATE_KEY_PEM_CHAIN = "/data/ClientCertChain/privekey.pem.unsecure";
constexpr std::string_view CA_PATH_CHAIN = "/data/ClientCertChain/RootCa.pem";
constexpr std::string_view MID_CA_PATH_CHAIN = "/data/ClientCertChain/MidCa.pem";
constexpr std::string_view CLIENT_CRT_CHAIN = "/data/ClientCertChain/secondServer.crt";
// Text files holding the peer endpoint: the IP address (read via GetIp) and the port (parsed with atoi).
constexpr std::string_view IP_ADDRESS = "/data/Ip/address.txt";
constexpr std::string_view PORT = "/data/Ip/port.txt";
51
CheckCaFileExistence(const char * function)52 inline bool CheckCaFileExistence(const char *function)
53 {
54 if (access(CA_PATH_CHAIN.data(), 0)) {
55 std::cout << "CA file does not exist! (" << function << ")";
56 return false;
57 }
58 return true;
59 }
60
/**
 * Reads the entire content of a file into a string.
 *
 * @param fileName Path of the file to read.
 * @return The file content, or an empty string if the file cannot be opened.
 */
std::string ReadFileContent(const std::string_view fileName)
{
    // std::ifstream has no std::string_view overload, and a string_view is not guaranteed to
    // be NUL-terminated, so build an owning std::string for the path.
    std::ifstream file{std::string(fileName)};
    if (!file.is_open()) {
        return {};
    }
    std::ostringstream content;
    content << file.rdbuf();
    return content.str();  // file is closed by its destructor
}
71
/**
 * Normalises an IP address read from a text file by stripping trailing whitespace.
 *
 * Only trailing spaces, tabs, CR and LF are removed; the last character is not dropped
 * unconditionally. A file without a final newline therefore keeps its full address, and
 * a CRLF-terminated file does not leave a stray '\r'.
 *
 * @param ip Raw file content, e.g. "192.168.1.10\n".
 * @return The address without trailing whitespace; empty if nothing else remains.
 */
std::string GetIp(std::string ip)
{
    const std::string::size_type lastKept = ip.find_last_not_of(" \t\r\n");
    if (lastKept == std::string::npos) {
        return {};
    }
    ip.erase(lastKept + 1);
    return ip;
}
76 } // namespace
77
78 class TlsSocketTest : public testing::Test {
79 public:
SetUpTestCase()80 static void SetUpTestCase() {}
81
TearDownTestCase()82 static void TearDownTestCase() {}
83
SetUp()84 virtual void SetUp() {}
85
TearDown()86 virtual void TearDown() {}
87 };
88
SetUnilateralHwTestShortParam(TLSSocket & server)89 void SetUnilateralHwTestShortParam(TLSSocket &server)
90 {
91 TLSConnectOptions options;
92 TLSSecureOptions secureOption;
93 Socket::NetAddress address;
94
95 address.SetAddress(GetIp(ReadFileContent(IP_ADDRESS)));
96 address.SetPort(std::atoi(ReadFileContent(PORT).c_str()));
97 address.SetFamilyBySaFamily(AF_INET);
98
99 secureOption.SetKey(SecureData(ReadFileContent(PRIVATE_KEY_PEM_CHAIN)));
100 std::vector<std::string> caVec = {ReadFileContent(CA_PATH_CHAIN), ReadFileContent(MID_CA_PATH_CHAIN)};
101 secureOption.SetCaChain(caVec);
102 secureOption.SetCert(ReadFileContent(CLIENT_CRT_CHAIN));
103
104 options.SetNetAddress(address);
105 options.SetTlsSecureOptions(secureOption);
106
107 server.Bind(address, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
108 server.Connect(options, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
109 }
110
111 HapInfoParams testInfoParms = {.bundleName = "TlsSocketBranchTest",
112 .userID = 1,
113 .instIndex = 0,
114 .appIDDesc = "test",
115 .isSystemApp = true};
116
117 PermissionDef testPermDef = {
118 .permissionName = "ohos.permission.INTERNET",
119 .bundleName = "TlsSocketBranchTest",
120 .grantMode = 1,
121 .label = "label",
122 .labelId = 1,
123 .description = "Test Tls Socket Branch",
124 .descriptionId = 1,
125 .availableLevel = APL_SYSTEM_BASIC,
126 };
127
128 PermissionStateFull testState = {
129 .grantFlags = {2},
130 .grantStatus = {PermissionState::PERMISSION_GRANTED},
131 .isGeneral = true,
132 .permissionName = "ohos.permission.INTERNET",
133 .resDeviceID = {"local"},
134 };
135
136 HapPolicyParams testPolicyPrams = {
137 .apl = APL_SYSTEM_BASIC,
138 .domain = "test.domain",
139 .permList = {testPermDef},
140 .permStateList = {testState},
141 };
142
143 class AccessToken {
144 public:
AccessToken()145 AccessToken() : currentID_(GetSelfTokenID())
146 {
147 AccessTokenIDEx tokenIdEx = AccessTokenKit::AllocHapToken(testInfoParms, testPolicyPrams);
148 accessID_ = tokenIdEx.tokenIdExStruct.tokenID;
149 SetSelfTokenID(tokenIdEx.tokenIDEx);
150 }
~AccessToken()151 ~AccessToken()
152 {
153 AccessTokenKit::DeleteToken(accessID_);
154 SetSelfTokenID(currentID_);
155 }
156
157 private:
158 AccessTokenID currentID_;
159 AccessTokenID accessID_ = 0;
160 };
161
162 class TlsSocketBranchTest : public testing::Test {
163 public:
SetUpTestCase()164 static void SetUpTestCase() {}
165
TearDownTestCase()166 static void TearDownTestCase() {}
167
SetUp()168 virtual void SetUp() {}
169
TearDown()170 virtual void TearDown() {}
171 };
172
173 HWTEST_F(TlsSocketTest, bindInterface, testing::ext::TestSize.Level2)
174 {
175 if (!CheckCaFileExistence("bindInterface")) {
176 return;
177 }
178
179 TLSSocket server;
180 Socket::NetAddress address;
181
182 address.SetAddress(GetIp(ReadFileContent(IP_ADDRESS)));
183 address.SetPort(std::atoi(ReadFileContent(PORT).c_str()));
184 address.SetFamilyBySaFamily(AF_INET);
185
186 AccessToken token;
__anonad3c62a40402(int32_t errCode) 187 server.Bind(address, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
188 }
189
190 HWTEST_F(TlsSocketTest, connectInterface, testing::ext::TestSize.Level2)
191 {
192 if (!CheckCaFileExistence("connectInterface")) {
193 return;
194 }
195 TLSSocket server;
196 SetUnilateralHwTestShortParam(server);
197
198 AccessToken token;
199 const std::string data = "GET / HTTP/1.1\r\nHost: www.baidu.com\r\nConnection: keep-alive\r\n\r\n";
200 Socket::TCPSendOptions tcpSendOptions;
201 tcpSendOptions.SetData(data);
__anonad3c62a40502(int32_t errCode) 202 server.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
203
__anonad3c62a40602(int32_t errCode) 204 (void)server.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
205 }
206
207 HWTEST_F(TlsSocketTest, closeInterface, testing::ext::TestSize.Level2)
208 {
209 if (!CheckCaFileExistence("closeInterface")) {
210 return;
211 }
212 TLSSocket server;
213 SetUnilateralHwTestShortParam(server);
214
215 AccessToken token;
216 const std::string data = "GET / HTTP/1.1\r\nHost: www.baidu.com\r\nConnection: keep-alive\r\n\r\n";
217 ;
218 Socket::TCPSendOptions tcpSendOptions;
219 tcpSendOptions.SetData(data);
220
__anonad3c62a40702(int32_t errCode) 221 server.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
222
__anonad3c62a40802(int32_t errCode) 223 (void)server.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
224 }
225
226 HWTEST_F(TlsSocketTest, sendInterface, testing::ext::TestSize.Level2)
227 {
228 if (!CheckCaFileExistence("sendInterface")) {
229 return;
230 }
231 TLSSocket server;
232 SetUnilateralHwTestShortParam(server);
233
234 AccessToken token;
235 const std::string data = "GET / HTTP/1.1\r\nHost: www.baidu.com\r\nConnection: keep-alive\r\n\r\n";
236 Socket::TCPSendOptions tcpSendOptions;
237 tcpSendOptions.SetData(data);
238
__anonad3c62a40902(int32_t errCode) 239 server.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
240
__anonad3c62a40a02(int32_t errCode) 241 (void)server.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
242 }
243
244 HWTEST_F(TlsSocketTest, getRemoteAddressInterface, testing::ext::TestSize.Level2)
245 {
246 if (!CheckCaFileExistence("getRemoteAddressInterface")) {
247 return;
248 }
249 TLSSocket server;
250 TLSConnectOptions options;
251 TLSSecureOptions secureOption;
252 Socket::NetAddress address;
253
254 address.SetAddress(GetIp(ReadFileContent(IP_ADDRESS)));
255 address.SetPort(std::atoi(ReadFileContent(PORT).c_str()));
256 address.SetFamilyBySaFamily(AF_INET);
257
258 secureOption.SetKey(SecureData(ReadFileContent(PRIVATE_KEY_PEM_CHAIN)));
259 std::vector<std::string> caVec = {ReadFileContent(CA_PATH_CHAIN), ReadFileContent(MID_CA_PATH_CHAIN)};
260 secureOption.SetCaChain(caVec);
261 secureOption.SetCert(ReadFileContent(CLIENT_CRT_CHAIN));
262
263 options.SetNetAddress(address);
264 options.SetTlsSecureOptions(secureOption);
265
__anonad3c62a40b02(int32_t errCode) 266 server.Bind(address, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
__anonad3c62a40c02(int32_t errCode) 267 server.Connect(options, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
268
269 AccessToken token;
270 Socket::NetAddress netAddress;
__anonad3c62a40d02(int32_t errCode, const Socket::NetAddress &address) 271 server.GetRemoteAddress([&netAddress](int32_t errCode, const Socket::NetAddress &address) {
272 EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS);
273 netAddress.SetAddress(address.GetAddress());
274 netAddress.SetPort(address.GetPort());
275 netAddress.SetFamilyBySaFamily(address.GetSaFamily());
276 });
277 EXPECT_STREQ(netAddress.GetAddress().c_str(), GetIp(ReadFileContent(IP_ADDRESS)).c_str());
278 EXPECT_EQ(address.GetPort(), std::atoi(ReadFileContent(PORT).c_str()));
279 EXPECT_EQ(netAddress.GetSaFamily(), AF_INET);
280
281 const std::string data = "GET / HTTP/1.1\r\nHost: www.baidu.com\r\nConnection: keep-alive\r\n\r\n";
282 Socket::TCPSendOptions tcpSendOptions;
283 tcpSendOptions.SetData(data);
284
__anonad3c62a40e02(int32_t errCode) 285 server.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
286
__anonad3c62a40f02(int32_t errCode) 287 (void)server.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
288 }
289
290 HWTEST_F(TlsSocketTest, getStateInterface, testing::ext::TestSize.Level2)
291 {
292 if (!CheckCaFileExistence("getRemoteAddressInterface")) {
293 return;
294 }
295
296 TLSSocket server;
297 SetUnilateralHwTestShortParam(server);
298
299 AccessToken token;
300 Socket::SocketStateBase TlsSocketstate;
__anonad3c62a41002(int32_t errCode, const Socket::SocketStateBase &state) 301 server.GetState([&TlsSocketstate](int32_t errCode, const Socket::SocketStateBase &state) {
302 EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS);
303 TlsSocketstate = state;
304 });
305 std::cout << "TlsSocketstate.IsClose(): " << TlsSocketstate.IsClose() << std::endl;
306 EXPECT_TRUE(TlsSocketstate.IsBound());
307 EXPECT_TRUE(!TlsSocketstate.IsClose());
308 EXPECT_TRUE(TlsSocketstate.IsConnected());
309
310 const std::string data = "GET / HTTP/1.1\r\nHost: www.baidu.com\r\nConnection: keep-alive\r\n\r\n";
311 Socket::TCPSendOptions tcpSendOptions;
312 tcpSendOptions.SetData(data);
__anonad3c62a41102(int32_t errCode) 313 server.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
314
__anonad3c62a41202(int32_t errCode) 315 (void)server.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
316 }
317
318 HWTEST_F(TlsSocketTest, getRemoteCertificateInterface, testing::ext::TestSize.Level2)
319 {
320 if (!CheckCaFileExistence("getRemoteCertificateInterface")) {
321 return;
322 }
323 TLSSocket server;
324 SetUnilateralHwTestShortParam(server);
325 Socket::TCPSendOptions tcpSendOptions;
326 const std::string data = "GET / HTTP/1.1\r\nHost: www.baidu.com\r\nConnection: keep-alive\r\n\r\n";
327
328 AccessToken token;
329 tcpSendOptions.SetData(data);
330
__anonad3c62a41302(int32_t errCode) 331 server.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
332
333 server.GetRemoteCertificate(
__anonad3c62a41402(int32_t errCode, const X509CertRawData &cert) 334 [](int32_t errCode, const X509CertRawData &cert) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
335
__anonad3c62a41502(int32_t errCode) 336 (void)server.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
337 }
338
339 HWTEST_F(TlsSocketTest, protocolInterface, testing::ext::TestSize.Level2)
340 {
341 if (!CheckCaFileExistence("protocolInterface")) {
342 return;
343 }
344 TLSConnectOptions options;
345 TLSSocket server;
346 TLSSecureOptions secureOption;
347 Socket::NetAddress address;
348
349 address.SetAddress(GetIp(ReadFileContent(IP_ADDRESS)));
350 address.SetPort(std::atoi(ReadFileContent(PORT).c_str()));
351 address.SetFamilyBySaFamily(AF_INET);
352
353 secureOption.SetKey(SecureData(ReadFileContent(PRIVATE_KEY_PEM_CHAIN)));
354 std::vector<std::string> caVec = {ReadFileContent(CA_PATH_CHAIN), ReadFileContent(MID_CA_PATH_CHAIN)};
355 secureOption.SetCaChain(caVec);
356 secureOption.SetCert(ReadFileContent(CLIENT_CRT_CHAIN));
357 std::string protocolV13 = "TLSv1.2";
358 std::vector<std::string> protocolVec = {protocolV13};
359 secureOption.SetProtocolChain(protocolVec);
360
361 options.SetNetAddress(address);
362 options.SetTlsSecureOptions(secureOption);
363
364 AccessToken token;
__anonad3c62a41602(int32_t errCode) 365 server.Bind(address, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
366
__anonad3c62a41702(int32_t errCode) 367 server.Connect(options, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
368
369 const std::string data = "GET / HTTP/1.1\r\nHost: www.baidu.com\r\nConnection: keep-alive\r\n\r\n";
370 Socket::TCPSendOptions tcpSendOptions;
371 tcpSendOptions.SetData(data);
372
__anonad3c62a41802(int32_t errCode) 373 server.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
374 std::string getProtocolVal;
__anonad3c62a41902(int32_t errCode, const std::string &protocol) 375 server.GetProtocol([&getProtocolVal](int32_t errCode, const std::string &protocol) {
376 EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS);
377 getProtocolVal = protocol;
378 });
379 EXPECT_STREQ(getProtocolVal.c_str(), "TLSv1.2");
380
__anonad3c62a41a02(int32_t errCode) 381 (void)server.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
382 }
383
384 HWTEST_F(TlsSocketTest, getCipherSuiteInterface, testing::ext::TestSize.Level2)
385 {
386 if (!CheckCaFileExistence("getCipherSuiteInterface")) {
387 return;
388 }
389
390 TLSConnectOptions options;
391 TLSSocket server;
392 TLSSecureOptions secureOption;
393 Socket::NetAddress address;
394
395 address.SetAddress(GetIp(ReadFileContent(IP_ADDRESS)));
396 address.SetPort(std::atoi(ReadFileContent(PORT).c_str()));
397 address.SetFamilyBySaFamily(AF_INET);
398
399 secureOption.SetKey(SecureData(ReadFileContent(PRIVATE_KEY_PEM_CHAIN)));
400 std::vector<std::string> caVec = {ReadFileContent(CA_PATH_CHAIN), ReadFileContent(MID_CA_PATH_CHAIN)};
401 secureOption.SetCaChain(caVec);
402 secureOption.SetCert(ReadFileContent(CLIENT_CRT_CHAIN));
403 secureOption.SetCipherSuite("ECDHE-RSA-AES128-GCM-SHA256");
404
405 options.SetNetAddress(address);
406 options.SetTlsSecureOptions(secureOption);
407
408 bool flag = false;
409 AccessToken token;
__anonad3c62a41b02(int32_t errCode) 410 server.Bind(address, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
__anonad3c62a41c02(int32_t errCode) 411 server.Connect(options, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
412
413 const std::string data = "GET / HTTP/1.1\r\nHost: www.baidu.com\r\nConnection: keep-alive\r\n\r\n";
414 Socket::TCPSendOptions tcpSendOptions;
415 tcpSendOptions.SetData(data);
__anonad3c62a41d02(int32_t errCode) 416 server.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
417
418 std::vector<std::string> cipherSuite;
__anonad3c62a41e02(int32_t errCode, const std::vector<std::string> &suite) 419 server.GetCipherSuite([&cipherSuite](int32_t errCode, const std::vector<std::string> &suite) {
420 EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS);
421 cipherSuite = suite;
422 });
423
424 for (auto const &iter : cipherSuite) {
425 if (iter == "ECDHE-RSA-AES128-GCM-SHA256") {
426 flag = true;
427 }
428 }
429
430 EXPECT_TRUE(flag);
431
__anonad3c62a41f02(int32_t errCode) 432 (void)server.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
433 }
434 } // namespace TlsSocket
435 } // namespace NetStack
436 } // namespace OHOS
437
438