• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2022 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "tls_socket_fuzzer.h"
17 
18 #include <securec.h>
19 
20 #include "netstack_log.h"
21 #include "tls_socket.h"
22 
23 namespace OHOS {
24 namespace NetStack {
namespace {
// Cursor over the fuzz input that GetData() consumes sequentially.
const uint8_t *g_baseFuzzData = nullptr; // start of the input bytes handed in by a *FuzzTest entry point
size_t g_baseFuzzSize = 0;               // number of bytes available at g_baseFuzzData
size_t g_baseFuzzPos = 0;                // offset of the next unread byte
// Buffer length passed to GetStringFromData(); one slot is reserved for the terminator.
constexpr size_t STR_LEN = 10;
} // namespace
31 template <class T>
GetData()32 T GetData()
33 {
34     T object{};
35     size_t objectSize = sizeof(object);
36     if (g_baseFuzzData == nullptr || objectSize > g_baseFuzzSize - g_baseFuzzPos) {
37         return object;
38     }
39     if (memcpy_s(&object, objectSize, g_baseFuzzData + g_baseFuzzPos, objectSize)) {
40         return {};
41     }
42     g_baseFuzzPos += objectSize;
43     return object;
44 }
45 
GetStringFromData(int strlen)46 std::string GetStringFromData(int strlen)
47 {
48     char cstr[strlen];
49     cstr[strlen - 1] = '\0';
50     for (int i = 0; i < strlen - 1; i++) {
51         cstr[i] = GetData<char>();
52     }
53     std::string str(cstr);
54     return str;
55 }
56 
BindFuzzTest(const uint8_t * data,size_t size)57 void BindFuzzTest(const uint8_t *data, size_t size)
58 {
59     NETSTACK_LOGD("BindFuzzTest:enter");
60     if ((data == nullptr) || (size == 0)) {
61         return;
62     }
63 
64     g_baseFuzzData = data;
65     g_baseFuzzSize = size;
66     g_baseFuzzPos = 0;
67 
68     TLSSocket tlsSocket;
69     NetAddress netAddress;
70     std::string str = GetStringFromData(STR_LEN);
71     netAddress.SetAddress(str);
72     netAddress.SetFamilyByJsValue(GetData<uint32_t>());
73     netAddress.SetFamilyBySaFamily(GetData<sa_family_t>());
74     netAddress.SetPort(GetData<uint16_t>());
75     tlsSocket.Bind(netAddress, [](bool ok) {
76         NETSTACK_LOGD("Calback received");
77     });
78     tlsSocket.Close([](int32_t errorNumber) {});
79     tlsSocket.GetRemoteAddress([](int32_t errorNumber, const NetAddress &address) {});
80     tlsSocket.GetCertificate([](int32_t errorNumber, const X509CertRawData &cert) {});
81     tlsSocket.GetRemoteCertificate([](int32_t errorNumber, const X509CertRawData &cert) {});
82     tlsSocket.GetProtocol([](int32_t errorNumber, const std::string &protocol) {});
83     tlsSocket.GetCipherSuite([](int32_t errorNumber, const std::vector<std::string> &suite) {});
84     tlsSocket.GetSignatureAlgorithms([](int32_t errorNumber, const std::vector<std::string> &algorithms) {});
85     tlsSocket.OnMessage([](const std::string &data, const SocketRemoteInfo &remoteInfo) {});
86     tlsSocket.OnConnect([]() {});
87     tlsSocket.OnClose([]() {});
88     tlsSocket.OnError([](int32_t errorNumber, const std::string &errorString) {});
89     tlsSocket.OffMessage();
90     tlsSocket.OffConnect();
91     tlsSocket.OffClose();
92     tlsSocket.OffError();
93 }
94 
ConnectFuzzTest(const uint8_t * data,size_t size)95 void ConnectFuzzTest(const uint8_t *data, size_t size)
96 {
97     NETSTACK_LOGD("ConnectFuzzTest:enter");
98     if ((data == nullptr) || (size == 0)) {
99         return;
100     }
101 
102     g_baseFuzzData = data;
103     g_baseFuzzSize = size;
104     g_baseFuzzPos = 0;
105 
106     TLSSocket tlsSocket;
107     NetAddress netAddress;
108     std::string str = GetStringFromData(STR_LEN);
109     netAddress.SetAddress(str);
110     netAddress.SetFamilyByJsValue(GetData<uint32_t>());
111     netAddress.SetFamilyBySaFamily(GetData<sa_family_t>());
112     netAddress.SetPort(GetData<uint16_t>());
113     TLSConnectOptions options;
114     options.SetNetAddress(netAddress);
115     options.SetCheckServerIdentity([](const std::string &hostName, const std::vector<std::string> &x509Certificates) {
116             NETSTACK_LOGD("Calback received");
117         });
118     std::vector<std::string> alpnProtocols(STR_LEN, str);
119     options.SetAlpnProtocols(alpnProtocols);
120     tlsSocket.Connect(options, [](bool ok) {
121         NETSTACK_LOGD("Calback received");
122     });
123 }
124 
SendFuzzTest(const uint8_t * data,size_t size)125 void SendFuzzTest(const uint8_t *data, size_t size)
126 {
127     if ((data == nullptr) || (size == 0)) {
128         return;
129     }
130     TLSSocket tlsSocket;
131     TCPSendOptions options;
132     std::string str = GetStringFromData(STR_LEN);
133     options.SetData(str);
134     options.SetEncoding(str);
135     tlsSocket.Send(options, [](bool ok) {
136         NETSTACK_LOGD("Calback received");
137     });
138 }
139 
SetExtraOptionsFuzzTest(const uint8_t * data,size_t size)140 void SetExtraOptionsFuzzTest(const uint8_t *data, size_t size)
141 {
142     if ((data == nullptr) || (size == 0)) {
143         return;
144     }
145     TLSSocket tlsSocket;
146     TCPExtraOptions options;
147     options.SetKeepAlive(*(reinterpret_cast<const bool *>(data)));
148     options.SetOOBInline(*(reinterpret_cast<const bool *>(data)));
149     options.SetTCPNoDelay(*(reinterpret_cast<const bool *>(data)));
150     tlsSocket.SetExtraOptions(options, [](bool ok) {
151         NETSTACK_LOGD("Calback received");
152     });
153 }
154 
SetCaChainFuzzTest(const uint8_t * data,size_t size)155 void SetCaChainFuzzTest(const uint8_t *data, size_t size)
156 {
157     NETSTACK_LOGD("SetCaChainFuzzTest:enter");
158     if ((data == nullptr) || (size == 0)) {
159         return;
160     }
161     g_baseFuzzData = data;
162     g_baseFuzzSize = size;
163     g_baseFuzzPos = 0;
164     std::string str = GetStringFromData(STR_LEN);
165     uint32_t count = GetData<uint32_t>() % 10;
166     std::vector<std::string> caChain;
167     caChain.reserve(count);
168     for (size_t i = 0; i < count; i++) {
169         caChain.emplace_back(str);
170     }
171     TLSSecureOptions option;
172     option.SetCaChain(caChain);
173     auto ret = option.GetCaChain();
174 }
175 
SetCertFuzzTest(const uint8_t * data,size_t size)176 void SetCertFuzzTest(const uint8_t *data, size_t size)
177 {
178     NETSTACK_LOGD("SetCertFuzzTest:enter");
179     if ((data == nullptr) || (size == 0)) {
180         return;
181     }
182     g_baseFuzzData = data;
183     g_baseFuzzSize = size;
184     g_baseFuzzPos = 0;
185     std::string cert = GetStringFromData(STR_LEN);
186     TLSSecureOptions option;
187     option.SetCert(cert);
188     auto ret = option.GetCert();
189 }
190 
SetKeyFuzzTest(const uint8_t * data,size_t size)191 void SetKeyFuzzTest(const uint8_t *data, size_t size)
192 {
193     NETSTACK_LOGD("SetKeyFuzzTest:enter");
194     if ((data == nullptr) || (size == 0)) {
195         return;
196     }
197     g_baseFuzzData = data;
198     g_baseFuzzSize = size;
199     g_baseFuzzPos = 0;
200     std::string str = GetStringFromData(STR_LEN);
201     SecureData secureData(str);
202     TLSSecureOptions option;
203     option.SetKey(secureData);
204     auto ret = option.GetKey();
205 }
206 
SetKeyPassFuzzTest(const uint8_t * data,size_t size)207 void SetKeyPassFuzzTest(const uint8_t *data, size_t size)
208 {
209     NETSTACK_LOGD("SetKeyPassFuzzTest:enter");
210     if ((data == nullptr) || (size == 0)) {
211         return;
212     }
213     g_baseFuzzData = data;
214     g_baseFuzzSize = size;
215     g_baseFuzzPos = 0;
216     std::string str = GetStringFromData(STR_LEN);
217     SecureData secureData(str);
218     TLSSecureOptions option;
219     option.SetKeyPass(secureData);
220     auto ret = option.GetKeyPass();
221 }
222 
SetProtocolChainFuzzTest(const uint8_t * data,size_t size)223 void SetProtocolChainFuzzTest(const uint8_t *data, size_t size)
224 {
225     NETSTACK_LOGD("SetProtocolChainFuzzTest:enter");
226     if ((data == nullptr) || (size == 0)) {
227         return;
228     }
229     g_baseFuzzData = data;
230     g_baseFuzzSize = size;
231     g_baseFuzzPos = 0;
232     std::string str = GetStringFromData(STR_LEN);
233     uint32_t count = GetData<uint32_t>() % 10;
234     std::vector<std::string> caChain;
235     caChain.reserve(count);
236     for (size_t i = 0; i < count; i++) {
237         caChain.emplace_back(str);
238     }
239     TLSSecureOptions option;
240     option.SetProtocolChain(caChain);
241     auto ret = option.GetProtocolChain();
242 }
243 
SetUseRemoteCipherPreferFuzzTest(const uint8_t * data,size_t size)244 void SetUseRemoteCipherPreferFuzzTest(const uint8_t *data, size_t size)
245 {
246     NETSTACK_LOGD("SetUseRemoteCipherPreferFuzzTest:enter");
247     if ((data == nullptr) || (size == 0)) {
248         return;
249     }
250     g_baseFuzzData = data;
251     g_baseFuzzSize = size;
252     g_baseFuzzPos = 0;
253     bool useRemoteCipherPrefer = GetData<int32_t>() % 2 == 0;
254     TLSSecureOptions option;
255     option.SetUseRemoteCipherPrefer(useRemoteCipherPrefer);
256     bool ret = option.UseRemoteCipherPrefer();
257     NETSTACK_LOGD("ret:%{public}s", ret ? "true" : "false");
258 }
259 
SetSignatureAlgorithmsFuzzTest(const uint8_t * data,size_t size)260 void SetSignatureAlgorithmsFuzzTest(const uint8_t *data, size_t size)
261 {
262     NETSTACK_LOGD("SetSignatureAlgorithmsFuzzTest:enter");
263     if ((data == nullptr) || (size == 0)) {
264         return;
265     }
266     g_baseFuzzData = data;
267     g_baseFuzzSize = size;
268     g_baseFuzzPos = 0;
269     std::string str = GetStringFromData(STR_LEN);
270     TLSSecureOptions option;
271     option.SetSignatureAlgorithms(str);
272     auto ret = option.GetSignatureAlgorithms();
273 }
274 
SetCipherSuiteFuzzTest(const uint8_t * data,size_t size)275 void SetCipherSuiteFuzzTest(const uint8_t *data, size_t size)
276 {
277     NETSTACK_LOGD("SetCipherSuiteFuzzTest:enter");
278     if ((data == nullptr) || (size == 0)) {
279         return;
280     }
281     g_baseFuzzData = data;
282     g_baseFuzzSize = size;
283     g_baseFuzzPos = 0;
284     std::string str = GetStringFromData(STR_LEN);
285     TLSSecureOptions option;
286     option.SetCipherSuite(str);
287     auto ret = option.GetCipherSuite();
288 }
289 
SetCrlChainFuzzTest(const uint8_t * data,size_t size)290 void SetCrlChainFuzzTest(const uint8_t *data, size_t size)
291 {
292     NETSTACK_LOGD("SetCrlChainFuzzTest:enter");
293     if ((data == nullptr) || (size == 0)) {
294         return;
295     }
296     g_baseFuzzData = data;
297     g_baseFuzzSize = size;
298     g_baseFuzzPos = 0;
299     std::string str = GetStringFromData(STR_LEN);
300     uint32_t count = GetData<uint32_t>() % 10;
301     std::vector<std::string> caChain;
302     caChain.reserve(count);
303     for (size_t i = 0; i < count; i++) {
304         caChain.emplace_back(str);
305     }
306     TLSSecureOptions option;
307     option.SetCrlChain(caChain);
308     auto ret = option.GetCrlChain();
309 }
310 
SetNetAddressFuzzTest(const uint8_t * data,size_t size)311 void SetNetAddressFuzzTest(const uint8_t *data, size_t size)
312 {
313     NETSTACK_LOGD("SetNetAddressFuzzTest:enter");
314     if ((data == nullptr) || (size == 0)) {
315         return;
316     }
317     g_baseFuzzData = data;
318     g_baseFuzzSize = size;
319     g_baseFuzzPos = 0;
320     NetAddress address;
321     std::string str = GetStringFromData(STR_LEN);
322     uint32_t num = GetData<uint32_t>();
323     uint16_t port = GetData<uint16_t>();
324     address.SetAddress(str);
325     address.SetFamilyByJsValue(num);
326     address.SetPort(port);
327     TLSConnectOptions option;
328     option.SetNetAddress(address);
329     auto ret = option.GetNetAddress();
330 }
331 
SetTlsSecureOptionsFuzzTest(const uint8_t * data,size_t size)332 void SetTlsSecureOptionsFuzzTest(const uint8_t *data, size_t size)
333 {
334     NETSTACK_LOGD("SetTlsSecureOptionsFuzzTest:enter");
335     if ((data == nullptr) || (size == 0)) {
336         return;
337     }
338     g_baseFuzzData = data;
339     g_baseFuzzSize = size;
340     g_baseFuzzPos = 0;
341     TLSSecureOptions tls;
342     std::string str = GetStringFromData(STR_LEN);
343     tls.SetCipherSuite(str);
344     tls.SetSignatureAlgorithms(str);
345     tls.SetCert(str);
346     TLSConnectOptions option;
347     option.SetTlsSecureOptions(tls);
348     auto ret = option.GetTlsSecureOptions();
349     option.SetCheckServerIdentity([](const std::string, const std::vector<std::string>) {});
350 }
351 
SetAlpnProtocolsFuzzTest(const uint8_t * data,size_t size)352 void SetAlpnProtocolsFuzzTest(const uint8_t *data, size_t size)
353 {
354     NETSTACK_LOGD("SetAlpnProtocolsFuzzTest:enter");
355     if ((data == nullptr) || (size == 0)) {
356         return;
357     }
358     g_baseFuzzData = data;
359     g_baseFuzzSize = size;
360     g_baseFuzzPos = 0;
361     std::string str = GetStringFromData(STR_LEN);
362     uint32_t count = GetData<uint32_t>() % 10;
363     std::vector<std::string> strs;
364     strs.reserve(count);
365     for (size_t i = 0; i < count; i++) {
366         strs.emplace_back(str);
367     }
368     TLSConnectOptions option;
369     option.SetAlpnProtocols(strs);
370     auto ret = option.GetCheckServerIdentity();
371 }
372 } // NetStack
373 } // OHOS
374 
375 /* Fuzzer entry point */
LLVMFuzzerTestOneInput(const uint8_t * data,size_t size)376 extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size)
377 {
378     /* Run your code on data */
379     OHOS::NetStack::BindFuzzTest(data, size);
380     OHOS::NetStack::ConnectFuzzTest(data, size);
381     OHOS::NetStack::SendFuzzTest(data, size);
382     OHOS::NetStack::SetExtraOptionsFuzzTest(data, size);
383     OHOS::NetStack::SetCaChainFuzzTest(data, size);
384     OHOS::NetStack::SetCertFuzzTest(data, size);
385     OHOS::NetStack::SetKeyFuzzTest(data, size);
386     OHOS::NetStack::SetKeyPassFuzzTest(data, size);
387     OHOS::NetStack::SetProtocolChainFuzzTest(data, size);
388     OHOS::NetStack::SetUseRemoteCipherPreferFuzzTest(data, size);
389     OHOS::NetStack::SetSignatureAlgorithmsFuzzTest(data, size);
390     OHOS::NetStack::SetCipherSuiteFuzzTest(data, size);
391     OHOS::NetStack::SetCrlChainFuzzTest(data, size);
392     OHOS::NetStack::SetNetAddressFuzzTest(data, size);
393     OHOS::NetStack::SetTlsSecureOptionsFuzzTest(data, size);
394     OHOS::NetStack::SetAlpnProtocolsFuzzTest(data, size);
395     return 0;
396 }
397