• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2022 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <MtpDataPacket.h>
18 #include <MtpDevHandle.h>
19 #include <MtpPacketFuzzerUtils.h>
20 #include <fuzzer/FuzzedDataProvider.h>
21 #include <utils/String16.h>
22 
23 using namespace android;
24 
25 class MtpDataPacketFuzzer : MtpPacketFuzzerUtils {
26   public:
MtpDataPacketFuzzer(const uint8_t * data,size_t size)27     MtpDataPacketFuzzer(const uint8_t* data, size_t size) : mFdp(data, size) {
28         mUsbDevFsUrb = (struct usbdevfs_urb*)malloc(sizeof(struct usbdevfs_urb) +
29                                                    sizeof(struct usbdevfs_iso_packet_desc));
30     };
~MtpDataPacketFuzzer()31     ~MtpDataPacketFuzzer() { free(mUsbDevFsUrb); };
32     void process();
33 
34   private:
35     FuzzedDataProvider mFdp;
36 };
37 
process()38 void MtpDataPacketFuzzer::process() {
39     MtpDataPacket mtpDataPacket;
40     while (mFdp.remaining_bytes() > 0) {
41         auto mtpDataAPI = mFdp.PickValueInArray<const std::function<void()>>({
42                 [&]() { mtpDataPacket.allocate(mFdp.ConsumeIntegralInRange(kMinSize, kMaxSize)); },
43                 [&]() { mtpDataPacket.reset(); },
44                 [&]() {
45                     mtpDataPacket.setOperationCode(mFdp.ConsumeIntegralInRange(kMinSize, kMaxSize));
46                 },
47                 [&]() {
48                     mtpDataPacket.setTransactionID(mFdp.ConsumeIntegralInRange(kMinSize, kMaxSize));
49                 },
50                 [&]() {
51                     Int8List* result = mtpDataPacket.getAInt8();
52                     delete result;
53                 },
54                 [&]() {
55                     Int16List* result = mtpDataPacket.getAInt16();
56                     delete result;
57                 },
58                 [&]() {
59                     Int32List* result = mtpDataPacket.getAInt32();
60                     delete result;
61                 },
62                 [&]() {
63                     Int64List* result = mtpDataPacket.getAInt64();
64                     delete result;
65                 },
66                 [&]() {
67                     UInt8List* result = mtpDataPacket.getAUInt8();
68                     delete result;
69                 },
70                 [&]() {
71                     UInt16List* result = mtpDataPacket.getAUInt16();
72                     delete result;
73                 },
74                 [&]() {
75                     UInt32List* result = mtpDataPacket.getAUInt32();
76                     delete result;
77                 },
78                 [&]() {
79                     UInt64List* result = mtpDataPacket.getAUInt64();
80                     delete result;
81                 },
82                 [&]() {
83                     if (mFdp.ConsumeBool()) {
84                         std::vector<uint8_t> initData =
85                                 mFdp.ConsumeBytes<uint8_t>(mFdp.ConsumeIntegral<uint8_t>());
86                         mtpDataPacket.putAUInt8(initData.data(), initData.size());
87                     } else {
88                         mtpDataPacket.putAUInt8(nullptr, 0);
89                     }
90                 },
91                 [&]() {
92                     if (mFdp.ConsumeBool()) {
93                         size_t size = mFdp.ConsumeIntegralInRange<size_t>(kMinSize, kMaxSize);
94                         uint16_t arr[size];
95                         for (size_t idx = 0; idx < size; ++idx) {
96                             arr[idx] = mFdp.ConsumeIntegral<uint16_t>();
97                         }
98                         mtpDataPacket.putAUInt16(arr, size);
99                     } else {
100                         mtpDataPacket.putAUInt16(nullptr, 0);
101                     }
102                 },
103                 [&]() {
104                     if (mFdp.ConsumeBool()) {
105                         size_t size = mFdp.ConsumeIntegralInRange<size_t>(kMinSize, kMaxSize);
106                         uint32_t arr[size];
107                         for (size_t idx = 0; idx < size; ++idx) {
108                             arr[idx] = mFdp.ConsumeIntegral<uint32_t>();
109                         }
110                         mtpDataPacket.putAUInt32(arr, size);
111                     } else {
112                         mtpDataPacket.putAUInt32(nullptr, 0);
113                     }
114                 },
115                 [&]() {
116                     if (mFdp.ConsumeBool()) {
117                         size_t size = mFdp.ConsumeIntegralInRange<size_t>(kMinSize, kMaxSize);
118                         uint64_t arr[size];
119                         for (size_t idx = 0; idx < size; ++idx) {
120                             arr[idx] = mFdp.ConsumeIntegral<uint64_t>();
121                         }
122                         mtpDataPacket.putAUInt64(arr, size);
123                     } else {
124                         mtpDataPacket.putAUInt64(nullptr, 0);
125                     }
126                 },
127                 [&]() {
128                     if (mFdp.ConsumeBool()) {
129                         size_t size = mFdp.ConsumeIntegralInRange<size_t>(kMinSize, kMaxSize);
130                         int64_t arr[size];
131                         for (size_t idx = 0; idx < size; ++idx) {
132                             arr[idx] = mFdp.ConsumeIntegral<int64_t>();
133                         }
134                         mtpDataPacket.putAInt64(arr, size);
135                     } else {
136                         mtpDataPacket.putAInt64(nullptr, 0);
137                     }
138                 },
139                 [&]() {
140                     if (mFdp.ConsumeBool()) {
141                         std::vector<uint16_t> arr;
142                         size_t size = mFdp.ConsumeIntegralInRange<size_t>(kMinSize, kMaxSize);
143                         for (size_t idx = 0; idx < size; ++idx) {
144                             arr.push_back(mFdp.ConsumeIntegral<uint16_t>());
145                         }
146                         mtpDataPacket.putAUInt16(&arr);
147                     } else {
148                         mtpDataPacket.putAUInt16(nullptr);
149                     }
150                 },
151                 [&]() {
152                     if (mFdp.ConsumeBool()) {
153                         std::vector<uint32_t> arr;
154                         size_t size = mFdp.ConsumeIntegralInRange<size_t>(kMinSize, kMaxSize);
155                         for (size_t idx = 0; idx < size; ++idx) {
156                             arr.push_back(mFdp.ConsumeIntegral<uint32_t>());
157                         }
158                         mtpDataPacket.putAUInt32(&arr);
159                     } else {
160                         mtpDataPacket.putAUInt32(nullptr);
161                     }
162                 },
163 
164                 [&]() {
165                     if (mFdp.ConsumeBool()) {
166                         size_t size = mFdp.ConsumeIntegralInRange<size_t>(kMinSize, kMaxSize);
167                         int32_t arr[size];
168                         for (size_t idx = 0; idx < size; ++idx) {
169                             arr[idx] = mFdp.ConsumeIntegral<int32_t>();
170                         }
171                         mtpDataPacket.putAInt32(arr, size);
172                     } else {
173                         mtpDataPacket.putAInt32(nullptr, 0);
174                     }
175                 },
176                 [&]() {
177                     if (mFdp.ConsumeBool()) {
178                         mtpDataPacket.putString(
179                                 (mFdp.ConsumeRandomLengthString(kMaxLength)).c_str());
180                     } else {
181                         mtpDataPacket.putString(static_cast<char*>(nullptr));
182                     }
183                 },
184                 [&]() {
185                     android::MtpStringBuffer sBuffer(
186                             (mFdp.ConsumeRandomLengthString(kMaxLength)).c_str());
187                     if (mFdp.ConsumeBool()) {
188                         mtpDataPacket.getString(sBuffer);
189                     } else {
190                         mtpDataPacket.putString(sBuffer);
191                     }
192                 },
193                 [&]() {
194                     MtpDevHandle handle;
195                     handle.start(mFdp.ConsumeBool());
196                     std::string text = mFdp.ConsumeRandomLengthString(kMaxLength);
197                     char* data = const_cast<char*>(text.c_str());
198                     handle.read(static_cast<void*>(data), text.length());
199                     if (mFdp.ConsumeBool()) {
200                         mtpDataPacket.read(&handle);
201                     } else if (mFdp.ConsumeBool()) {
202                         mtpDataPacket.write(&handle);
203                     } else {
204                         std::string textData = mFdp.ConsumeRandomLengthString(kMaxLength);
205                         char* Data = const_cast<char*>(textData.c_str());
206                         mtpDataPacket.writeData(&handle, static_cast<void*>(Data),
207                                                 textData.length());
208                     }
209                     handle.close();
210                 },
211                 [&]() {
212                     if (mFdp.ConsumeBool()) {
213                         std::string str = mFdp.ConsumeRandomLengthString(kMaxLength);
214                         android::String16 s(str.c_str());
215                         char16_t* data = const_cast<char16_t*>(s.string());
216                         mtpDataPacket.putString(reinterpret_cast<uint16_t*>(data));
217                     } else {
218                         mtpDataPacket.putString(static_cast<uint16_t*>(nullptr));
219                     }
220                 },
221                 [&]() {
222                     if (mFdp.ConsumeBool()) {
223                         std::vector<int8_t> data = mFdp.ConsumeBytes<int8_t>(
224                                 mFdp.ConsumeIntegralInRange<size_t>(kMinSize, kMaxSize));
225                         mtpDataPacket.putAInt8(data.data(), data.size());
226                     } else {
227                         mtpDataPacket.putAInt8(nullptr, 0);
228                     }
229                 },
230                 [&]() {
231                     if (mFdp.ConsumeBool()) {
232                         std::vector<uint8_t> data = mFdp.ConsumeBytes<uint8_t>(
233                                 mFdp.ConsumeIntegralInRange<size_t>(kMinSize, kMaxSize));
234                         mtpDataPacket.putAUInt8(data.data(), data.size());
235                     } else {
236                         mtpDataPacket.putAUInt8(nullptr, 0);
237                     }
238                 },
239                 [&]() {
240                     fillFilePath(&mFdp);
241                     int32_t fd = memfd_create(mPath.c_str(), MFD_ALLOW_SEALING);
242                     fillUsbRequest(fd, &mFdp);
243                     mUsbRequest.dev = usb_device_new(mPath.c_str(), fd);
244                     std::vector<int8_t> data = mFdp.ConsumeBytes<int8_t>(
245                             mFdp.ConsumeIntegralInRange<size_t>(kMinSize, kMaxSize));
246                     mtpDataPacket.readData(&mUsbRequest, data.data(), data.size());
247                     usb_device_close(mUsbRequest.dev);
248                 },
249                 [&]() {
250                     fillFilePath(&mFdp);
251                     int32_t fd = memfd_create(mPath.c_str(), MFD_ALLOW_SEALING);
252                     fillUsbRequest(fd, &mFdp);
253                     mUsbRequest.dev = usb_device_new(mPath.c_str(), fd);
254                     mtpDataPacket.write(
255                             &mUsbRequest,
256                             mFdp.PickValueInArray<UrbPacketDivisionMode>(kUrbPacketDivisionModes),
257                             fd, mFdp.ConsumeIntegralInRange(kMinSize, kMaxSize));
258                     usb_device_close(mUsbRequest.dev);
259                 },
260                 [&]() {
261                     fillFilePath(&mFdp);
262                     int32_t fd = memfd_create(mPath.c_str(), MFD_ALLOW_SEALING);
263                     fillUsbRequest(fd, &mFdp);
264                     mUsbRequest.dev = usb_device_new(mPath.c_str(), fd);
265                     mtpDataPacket.read(&mUsbRequest);
266                     usb_device_close(mUsbRequest.dev);
267                 },
268                 [&]() {
269                     fillFilePath(&mFdp);
270                     int32_t fd = memfd_create(mPath.c_str(), MFD_ALLOW_SEALING);
271                     fillUsbRequest(fd, &mFdp);
272                     mUsbRequest.dev = usb_device_new(mPath.c_str(), fd);
273                     mtpDataPacket.write(&mUsbRequest, mFdp.PickValueInArray<UrbPacketDivisionMode>(
274                                                              kUrbPacketDivisionModes));
275                     usb_device_close(mUsbRequest.dev);
276                 },
277                 [&]() {
278                     fillFilePath(&mFdp);
279                     int32_t fd = memfd_create(mPath.c_str(), MFD_ALLOW_SEALING);
280                     fillUsbRequest(fd, &mFdp);
281                     mUsbRequest.dev = usb_device_new(mPath.c_str(), fd);
282                     mtpDataPacket.readDataHeader(&mUsbRequest);
283                     usb_device_close(mUsbRequest.dev);
284                 },
285                 [&]() {
286                     fillFilePath(&mFdp);
287                     int32_t fd = memfd_create(mPath.c_str(), MFD_ALLOW_SEALING);
288                     fillUsbRequest(fd, &mFdp);
289                     mUsbRequest.dev = usb_device_new(mPath.c_str(), fd);
290                     mtpDataPacket.readDataAsync(&mUsbRequest);
291                     usb_device_close(mUsbRequest.dev);
292                 },
293                 [&]() {
294                     fillFilePath(&mFdp);
295                     int32_t fd = memfd_create(mPath.c_str(), MFD_ALLOW_SEALING);
296                     fillUsbRequest(fd, &mFdp);
297                     mUsbRequest.dev = usb_device_new(mPath.c_str(), fd);
298                     mtpDataPacket.readDataWait(mUsbRequest.dev);
299                     usb_device_close(mUsbRequest.dev);
300                 },
301                 [&]() {
302                     if (mFdp.ConsumeBool()) {
303                         std::vector<int16_t> data;
304                         for (size_t idx = 0;
305                              idx < mFdp.ConsumeIntegralInRange<size_t>(kMinSize, kMaxSize); ++idx) {
306                             data.push_back(mFdp.ConsumeIntegral<int16_t>());
307                         }
308                         mtpDataPacket.putAInt16(data.data(), data.size());
309                     } else {
310                         mtpDataPacket.putAInt16(nullptr, 0);
311                     }
312                 },
313                 [&]() {
314                     int32_t arr[4];
315                     for (size_t idx = 0; idx < 4; ++idx) {
316                         arr[idx] = mFdp.ConsumeIntegral<int32_t>();
317                     }
318                     mtpDataPacket.putInt128(arr);
319                 },
320                 [&]() { mtpDataPacket.putInt64(mFdp.ConsumeIntegral<int64_t>()); },
321                 [&]() {
322                     int16_t out;
323                     mtpDataPacket.getInt16(out);
324                 },
325                 [&]() {
326                     int32_t out;
327                     mtpDataPacket.getInt32(out);
328                 },
329                 [&]() {
330                     int8_t out;
331                     mtpDataPacket.getInt8(out);
332                 },
333                 [&]() {
334                     uint32_t arr[4];
335                     for (size_t idx = 0; idx < 4; ++idx) {
336                         arr[idx] = mFdp.ConsumeIntegral<uint32_t>();
337                     }
338                     if (mFdp.ConsumeBool()) {
339                         mtpDataPacket.putUInt128(arr);
340                     } else {
341                         mtpDataPacket.getUInt128(arr);
342                     }
343                 },
344                 [&]() { mtpDataPacket.putUInt64(mFdp.ConsumeIntegral<uint64_t>()); },
345                 [&]() {
346                     uint64_t out;
347                     mtpDataPacket.getUInt64(out);
348                 },
349                 [&]() { mtpDataPacket.putInt128(mFdp.ConsumeIntegral<int64_t>()); },
350                 [&]() { mtpDataPacket.putUInt128(mFdp.ConsumeIntegral<uint64_t>()); },
351                 [&]() {
352                     int32_t length;
353                     void* data = mtpDataPacket.getData(&length);
354                     free(data);
355                 },
356         });
357         mtpDataAPI();
358     }
359 }
360 
LLVMFuzzerTestOneInput(const uint8_t * data,size_t size)361 extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) {
362     MtpDataPacketFuzzer mtpDataPacketFuzzer(data, size);
363     mtpDataPacketFuzzer.process();
364     return 0;
365 }
366