1 /*
2 * Copyright (c) 2022 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "tls_socket_fuzzer.h"
17
18 #include <securec.h>
19
20 #include "netstack_log.h"
21 #include "tls_socket.h"
22
23 namespace OHOS {
24 namespace NetStack {
namespace {
// Shared reader state consumed by GetData() and GetStringFromData().
// A fuzz test points these at its input before it starts reading.
const uint8_t *g_baseFuzzData = nullptr; // start of the current fuzz input (nullptr = no input)
size_t g_baseFuzzSize = 0;               // number of bytes available at g_baseFuzzData
size_t g_baseFuzzPos = 0;                // read cursor into g_baseFuzzData (static storage starts at 0 anyway)
// Length argument handed to GetStringFromData(): room for STR_LEN - 1 fuzzed characters.
constexpr size_t STR_LEN = 10;
} // namespace
31 template <class T>
GetData()32 T GetData()
33 {
34 T object{};
35 size_t objectSize = sizeof(object);
36 if (g_baseFuzzData == nullptr || objectSize > g_baseFuzzSize - g_baseFuzzPos) {
37 return object;
38 }
39 if (memcpy_s(&object, objectSize, g_baseFuzzData + g_baseFuzzPos, objectSize)) {
40 return {};
41 }
42 g_baseFuzzPos += objectSize;
43 return object;
44 }
45
GetStringFromData(int strlen)46 std::string GetStringFromData(int strlen)
47 {
48 char cstr[strlen];
49 cstr[strlen - 1] = '\0';
50 for (int i = 0; i < strlen - 1; i++) {
51 cstr[i] = GetData<char>();
52 }
53 std::string str(cstr);
54 return str;
55 }
56
BindFuzzTest(const uint8_t * data,size_t size)57 void BindFuzzTest(const uint8_t *data, size_t size)
58 {
59 NETSTACK_LOGD("BindFuzzTest:enter");
60 if ((data == nullptr) || (size == 0)) {
61 return;
62 }
63
64 g_baseFuzzData = data;
65 g_baseFuzzSize = size;
66 g_baseFuzzPos = 0;
67
68 TLSSocket tlsSocket;
69 NetAddress netAddress;
70 std::string str = GetStringFromData(STR_LEN);
71 netAddress.SetAddress(str);
72 netAddress.SetFamilyByJsValue(GetData<uint32_t>());
73 netAddress.SetFamilyBySaFamily(GetData<sa_family_t>());
74 netAddress.SetPort(GetData<uint16_t>());
75 tlsSocket.Bind(netAddress, [](bool ok) {
76 NETSTACK_LOGD("Calback received");
77 });
78 tlsSocket.Close([](int32_t errorNumber) {});
79 tlsSocket.GetRemoteAddress([](int32_t errorNumber, const NetAddress &address) {});
80 tlsSocket.GetCertificate([](int32_t errorNumber, const X509CertRawData &cert) {});
81 tlsSocket.GetRemoteCertificate([](int32_t errorNumber, const X509CertRawData &cert) {});
82 tlsSocket.GetProtocol([](int32_t errorNumber, const std::string &protocol) {});
83 tlsSocket.GetCipherSuite([](int32_t errorNumber, const std::vector<std::string> &suite) {});
84 tlsSocket.GetSignatureAlgorithms([](int32_t errorNumber, const std::vector<std::string> &algorithms) {});
85 tlsSocket.OnMessage([](const std::string &data, const SocketRemoteInfo &remoteInfo) {});
86 tlsSocket.OnConnect([]() {});
87 tlsSocket.OnClose([]() {});
88 tlsSocket.OnError([](int32_t errorNumber, const std::string &errorString) {});
89 tlsSocket.OffMessage();
90 tlsSocket.OffConnect();
91 tlsSocket.OffClose();
92 tlsSocket.OffError();
93 }
94
ConnectFuzzTest(const uint8_t * data,size_t size)95 void ConnectFuzzTest(const uint8_t *data, size_t size)
96 {
97 NETSTACK_LOGD("ConnectFuzzTest:enter");
98 if ((data == nullptr) || (size == 0)) {
99 return;
100 }
101
102 g_baseFuzzData = data;
103 g_baseFuzzSize = size;
104 g_baseFuzzPos = 0;
105
106 TLSSocket tlsSocket;
107 NetAddress netAddress;
108 std::string str = GetStringFromData(STR_LEN);
109 netAddress.SetAddress(str);
110 netAddress.SetFamilyByJsValue(GetData<uint32_t>());
111 netAddress.SetFamilyBySaFamily(GetData<sa_family_t>());
112 netAddress.SetPort(GetData<uint16_t>());
113 TLSConnectOptions options;
114 options.SetNetAddress(netAddress);
115 options.SetCheckServerIdentity([](const std::string &hostName, const std::vector<std::string> &x509Certificates) {
116 NETSTACK_LOGD("Calback received");
117 });
118 std::vector<std::string> alpnProtocols(STR_LEN, str);
119 options.SetAlpnProtocols(alpnProtocols);
120 tlsSocket.Connect(options, [](bool ok) {
121 NETSTACK_LOGD("Calback received");
122 });
123 }
124
SendFuzzTest(const uint8_t * data,size_t size)125 void SendFuzzTest(const uint8_t *data, size_t size)
126 {
127 if ((data == nullptr) || (size == 0)) {
128 return;
129 }
130 TLSSocket tlsSocket;
131 TCPSendOptions options;
132 std::string str = GetStringFromData(STR_LEN);
133 options.SetData(str);
134 options.SetEncoding(str);
135 tlsSocket.Send(options, [](bool ok) {
136 NETSTACK_LOGD("Calback received");
137 });
138 }
139
SetExtraOptionsFuzzTest(const uint8_t * data,size_t size)140 void SetExtraOptionsFuzzTest(const uint8_t *data, size_t size)
141 {
142 if ((data == nullptr) || (size == 0)) {
143 return;
144 }
145 TLSSocket tlsSocket;
146 TCPExtraOptions options;
147 options.SetKeepAlive(*(reinterpret_cast<const bool *>(data)));
148 options.SetOOBInline(*(reinterpret_cast<const bool *>(data)));
149 options.SetTCPNoDelay(*(reinterpret_cast<const bool *>(data)));
150 tlsSocket.SetExtraOptions(options, [](bool ok) {
151 NETSTACK_LOGD("Calback received");
152 });
153 }
154
SetCaChainFuzzTest(const uint8_t * data,size_t size)155 void SetCaChainFuzzTest(const uint8_t *data, size_t size)
156 {
157 NETSTACK_LOGD("SetCaChainFuzzTest:enter");
158 if ((data == nullptr) || (size == 0)) {
159 return;
160 }
161 g_baseFuzzData = data;
162 g_baseFuzzSize = size;
163 g_baseFuzzPos = 0;
164 std::string str = GetStringFromData(STR_LEN);
165 uint32_t count = GetData<uint32_t>() % 10;
166 std::vector<std::string> caChain;
167 caChain.reserve(count);
168 for (size_t i = 0; i < count; i++) {
169 caChain.emplace_back(str);
170 }
171 TLSSecureOptions option;
172 option.SetCaChain(caChain);
173 auto ret = option.GetCaChain();
174 }
175
SetCertFuzzTest(const uint8_t * data,size_t size)176 void SetCertFuzzTest(const uint8_t *data, size_t size)
177 {
178 NETSTACK_LOGD("SetCertFuzzTest:enter");
179 if ((data == nullptr) || (size == 0)) {
180 return;
181 }
182 g_baseFuzzData = data;
183 g_baseFuzzSize = size;
184 g_baseFuzzPos = 0;
185 std::string cert = GetStringFromData(STR_LEN);
186 TLSSecureOptions option;
187 option.SetCert(cert);
188 auto ret = option.GetCert();
189 }
190
SetKeyFuzzTest(const uint8_t * data,size_t size)191 void SetKeyFuzzTest(const uint8_t *data, size_t size)
192 {
193 NETSTACK_LOGD("SetKeyFuzzTest:enter");
194 if ((data == nullptr) || (size == 0)) {
195 return;
196 }
197 g_baseFuzzData = data;
198 g_baseFuzzSize = size;
199 g_baseFuzzPos = 0;
200 std::string str = GetStringFromData(STR_LEN);
201 SecureData secureData(str);
202 TLSSecureOptions option;
203 option.SetKey(secureData);
204 auto ret = option.GetKey();
205 }
206
SetKeyPassFuzzTest(const uint8_t * data,size_t size)207 void SetKeyPassFuzzTest(const uint8_t *data, size_t size)
208 {
209 NETSTACK_LOGD("SetKeyPassFuzzTest:enter");
210 if ((data == nullptr) || (size == 0)) {
211 return;
212 }
213 g_baseFuzzData = data;
214 g_baseFuzzSize = size;
215 g_baseFuzzPos = 0;
216 std::string str = GetStringFromData(STR_LEN);
217 SecureData secureData(str);
218 TLSSecureOptions option;
219 option.SetKeyPass(secureData);
220 auto ret = option.GetKeyPass();
221 }
222
SetProtocolChainFuzzTest(const uint8_t * data,size_t size)223 void SetProtocolChainFuzzTest(const uint8_t *data, size_t size)
224 {
225 NETSTACK_LOGD("SetProtocolChainFuzzTest:enter");
226 if ((data == nullptr) || (size == 0)) {
227 return;
228 }
229 g_baseFuzzData = data;
230 g_baseFuzzSize = size;
231 g_baseFuzzPos = 0;
232 std::string str = GetStringFromData(STR_LEN);
233 uint32_t count = GetData<uint32_t>() % 10;
234 std::vector<std::string> caChain;
235 caChain.reserve(count);
236 for (size_t i = 0; i < count; i++) {
237 caChain.emplace_back(str);
238 }
239 TLSSecureOptions option;
240 option.SetProtocolChain(caChain);
241 auto ret = option.GetProtocolChain();
242 }
243
SetUseRemoteCipherPreferFuzzTest(const uint8_t * data,size_t size)244 void SetUseRemoteCipherPreferFuzzTest(const uint8_t *data, size_t size)
245 {
246 NETSTACK_LOGD("SetUseRemoteCipherPreferFuzzTest:enter");
247 if ((data == nullptr) || (size == 0)) {
248 return;
249 }
250 g_baseFuzzData = data;
251 g_baseFuzzSize = size;
252 g_baseFuzzPos = 0;
253 bool useRemoteCipherPrefer = GetData<int32_t>() % 2 == 0;
254 TLSSecureOptions option;
255 option.SetUseRemoteCipherPrefer(useRemoteCipherPrefer);
256 bool ret = option.UseRemoteCipherPrefer();
257 NETSTACK_LOGD("ret:%{public}s", ret ? "true" : "false");
258 }
259
SetSignatureAlgorithmsFuzzTest(const uint8_t * data,size_t size)260 void SetSignatureAlgorithmsFuzzTest(const uint8_t *data, size_t size)
261 {
262 NETSTACK_LOGD("SetSignatureAlgorithmsFuzzTest:enter");
263 if ((data == nullptr) || (size == 0)) {
264 return;
265 }
266 g_baseFuzzData = data;
267 g_baseFuzzSize = size;
268 g_baseFuzzPos = 0;
269 std::string str = GetStringFromData(STR_LEN);
270 TLSSecureOptions option;
271 option.SetSignatureAlgorithms(str);
272 auto ret = option.GetSignatureAlgorithms();
273 }
274
SetCipherSuiteFuzzTest(const uint8_t * data,size_t size)275 void SetCipherSuiteFuzzTest(const uint8_t *data, size_t size)
276 {
277 NETSTACK_LOGD("SetCipherSuiteFuzzTest:enter");
278 if ((data == nullptr) || (size == 0)) {
279 return;
280 }
281 g_baseFuzzData = data;
282 g_baseFuzzSize = size;
283 g_baseFuzzPos = 0;
284 std::string str = GetStringFromData(STR_LEN);
285 TLSSecureOptions option;
286 option.SetCipherSuite(str);
287 auto ret = option.GetCipherSuite();
288 }
289
SetCrlChainFuzzTest(const uint8_t * data,size_t size)290 void SetCrlChainFuzzTest(const uint8_t *data, size_t size)
291 {
292 NETSTACK_LOGD("SetCrlChainFuzzTest:enter");
293 if ((data == nullptr) || (size == 0)) {
294 return;
295 }
296 g_baseFuzzData = data;
297 g_baseFuzzSize = size;
298 g_baseFuzzPos = 0;
299 std::string str = GetStringFromData(STR_LEN);
300 uint32_t count = GetData<uint32_t>() % 10;
301 std::vector<std::string> caChain;
302 caChain.reserve(count);
303 for (size_t i = 0; i < count; i++) {
304 caChain.emplace_back(str);
305 }
306 TLSSecureOptions option;
307 option.SetCrlChain(caChain);
308 auto ret = option.GetCrlChain();
309 }
310
SetNetAddressFuzzTest(const uint8_t * data,size_t size)311 void SetNetAddressFuzzTest(const uint8_t *data, size_t size)
312 {
313 NETSTACK_LOGD("SetNetAddressFuzzTest:enter");
314 if ((data == nullptr) || (size == 0)) {
315 return;
316 }
317 g_baseFuzzData = data;
318 g_baseFuzzSize = size;
319 g_baseFuzzPos = 0;
320 NetAddress address;
321 std::string str = GetStringFromData(STR_LEN);
322 uint32_t num = GetData<uint32_t>();
323 uint16_t port = GetData<uint16_t>();
324 address.SetAddress(str);
325 address.SetFamilyByJsValue(num);
326 address.SetPort(port);
327 TLSConnectOptions option;
328 option.SetNetAddress(address);
329 auto ret = option.GetNetAddress();
330 }
331
SetTlsSecureOptionsFuzzTest(const uint8_t * data,size_t size)332 void SetTlsSecureOptionsFuzzTest(const uint8_t *data, size_t size)
333 {
334 NETSTACK_LOGD("SetTlsSecureOptionsFuzzTest:enter");
335 if ((data == nullptr) || (size == 0)) {
336 return;
337 }
338 g_baseFuzzData = data;
339 g_baseFuzzSize = size;
340 g_baseFuzzPos = 0;
341 TLSSecureOptions tls;
342 std::string str = GetStringFromData(STR_LEN);
343 tls.SetCipherSuite(str);
344 tls.SetSignatureAlgorithms(str);
345 tls.SetCert(str);
346 TLSConnectOptions option;
347 option.SetTlsSecureOptions(tls);
348 auto ret = option.GetTlsSecureOptions();
349 option.SetCheckServerIdentity([](const std::string, const std::vector<std::string>) {});
350 }
351
SetAlpnProtocolsFuzzTest(const uint8_t * data,size_t size)352 void SetAlpnProtocolsFuzzTest(const uint8_t *data, size_t size)
353 {
354 NETSTACK_LOGD("SetAlpnProtocolsFuzzTest:enter");
355 if ((data == nullptr) || (size == 0)) {
356 return;
357 }
358 g_baseFuzzData = data;
359 g_baseFuzzSize = size;
360 g_baseFuzzPos = 0;
361 std::string str = GetStringFromData(STR_LEN);
362 uint32_t count = GetData<uint32_t>() % 10;
363 std::vector<std::string> strs;
364 strs.reserve(count);
365 for (size_t i = 0; i < count; i++) {
366 strs.emplace_back(str);
367 }
368 TLSConnectOptions option;
369 option.SetAlpnProtocols(strs);
370 auto ret = option.GetCheckServerIdentity();
371 }
372 } // NetStack
373 } // OHOS
374
375 /* Fuzzer entry point */
LLVMFuzzerTestOneInput(const uint8_t * data,size_t size)376 extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size)
377 {
378 /* Run your code on data */
379 OHOS::NetStack::BindFuzzTest(data, size);
380 OHOS::NetStack::ConnectFuzzTest(data, size);
381 OHOS::NetStack::SendFuzzTest(data, size);
382 OHOS::NetStack::SetExtraOptionsFuzzTest(data, size);
383 OHOS::NetStack::SetCaChainFuzzTest(data, size);
384 OHOS::NetStack::SetCertFuzzTest(data, size);
385 OHOS::NetStack::SetKeyFuzzTest(data, size);
386 OHOS::NetStack::SetKeyPassFuzzTest(data, size);
387 OHOS::NetStack::SetProtocolChainFuzzTest(data, size);
388 OHOS::NetStack::SetUseRemoteCipherPreferFuzzTest(data, size);
389 OHOS::NetStack::SetSignatureAlgorithmsFuzzTest(data, size);
390 OHOS::NetStack::SetCipherSuiteFuzzTest(data, size);
391 OHOS::NetStack::SetCrlChainFuzzTest(data, size);
392 OHOS::NetStack::SetNetAddressFuzzTest(data, size);
393 OHOS::NetStack::SetTlsSecureOptionsFuzzTest(data, size);
394 OHOS::NetStack::SetAlpnProtocolsFuzzTest(data, size);
395 return 0;
396 }
397