• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2021 Code Intelligence GmbH
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Modified from
16 // https://raw.githubusercontent.com/google/atheris/034284dc4bb1ad4f4ab6ba5d34fb4dca7c633660/fuzzed_data_provider.cc
17 //
18 // Original license and copyright notices:
19 //
20 // Copyright 2020 Google LLC
21 //
22 // Licensed under the Apache License, Version 2.0 (the "License");
23 // you may not use this file except in compliance with the License.
24 // You may obtain a copy of the License at
25 //
26 //      http://www.apache.org/licenses/LICENSE-2.0
27 //
28 // Unless required by applicable law or agreed to in writing, software
29 // distributed under the License is distributed on an "AS IS" BASIS,
30 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31 // See the License for the specific language governing permissions and
32 // limitations under the License.
33 //
34 // Modified from
35 // https://github.com/llvm/llvm-project/blob/70de7e0d9a95b7fcd7c105b06bd90fdf4e01f563/compiler-rt/include/fuzzer/FuzzedDataProvider.h
36 //
37 // Original license and copyright notices:
38 //
39 //===- FuzzedDataProvider.h - Utility header for fuzz targets ---*- C++ -* ===//
40 //
41 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
42 // See https://llvm.org/LICENSE.txt for license information.
43 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44 //
45 
#include "fuzzed_data_provider.h"

#include <algorithm>
#include <cstdint>
#include <limits>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "absl/strings/str_format.h"
55 
56 namespace {
57 
// Points at the first byte of the fuzzer input that has not been consumed from
// the front yet.
const uint8_t *gDataPtr = nullptr;
// Number of bytes of the fuzzer input that are still available, counted from
// gDataPtr.
std::size_t gRemainingBytes = 0;

// Consumes `bytes` bytes from the front of the buffer. If fewer than `bytes`
// bytes are left, the buffer is marked as fully consumed instead: the remaining
// byte count drops to zero while the data pointer is left where it is.
void Advance(const std::size_t bytes) {
  if (bytes <= gRemainingBytes) {
    gDataPtr += bytes;
    gRemainingBytes -= bytes;
    return;
  }
  gRemainingBytes = 0;
}
71 
ThrowIllegalArgumentException(JNIEnv & env,const std::string & message)72 void ThrowIllegalArgumentException(JNIEnv &env, const std::string &message) {
73   jclass illegal_argument_exception =
74       env.FindClass("java/lang/IllegalArgumentException");
75   env.ThrowNew(illegal_argument_exception, message.c_str());
76 }
77 
78 template <typename T>
79 struct JniArrayType {};
80 
81 #define JNI_ARRAY_TYPE(lower_case, sentence_case)                    \
82   template <>                                                        \
83   struct JniArrayType<j##lower_case> {                               \
84     typedef j##lower_case type;                                      \
85     typedef j##lower_case##Array array_type;                         \
86     static constexpr array_type (JNIEnv::*kNewArrayFunc)(jsize) =    \
87         &JNIEnv::New##sentence_case##Array;                          \
88     static constexpr void (JNIEnv::*kSetArrayRegionFunc)(            \
89         array_type array, jsize start, jsize len,                    \
90         const type *buf) = &JNIEnv::Set##sentence_case##ArrayRegion; \
91   };
92 
93 JNI_ARRAY_TYPE(boolean, Boolean);
94 JNI_ARRAY_TYPE(byte, Byte);
95 JNI_ARRAY_TYPE(short, Short);
96 JNI_ARRAY_TYPE(int, Int);
97 JNI_ARRAY_TYPE(long, Long);
98 
99 template <typename T>
100 typename JniArrayType<T>::array_type JNICALL
ConsumeIntegralArray(JNIEnv & env,jobject self,jint max_length)101 ConsumeIntegralArray(JNIEnv &env, jobject self, jint max_length) {
102   if (max_length < 0) {
103     ThrowIllegalArgumentException(env, "maxLength must not be negative");
104     return nullptr;
105   }
106   // Arrays of integral types are considered data and thus consumed from the
107   // beginning of the buffer.
108   std::size_t max_num_bytes = std::min(sizeof(T) * max_length, gRemainingBytes);
109   jsize actual_length = max_num_bytes / sizeof(T);
110   std::size_t actual_num_bytes = sizeof(T) * actual_length;
111   auto array = (env.*(JniArrayType<T>::kNewArrayFunc))(actual_length);
112   (env.*(JniArrayType<T>::kSetArrayRegionFunc))(
113       array, 0, actual_length, reinterpret_cast<const T *>(gDataPtr));
114   Advance(actual_num_bytes);
115   return array;
116 }
117 
118 template <typename T>
ConsumeRemainingAsArray(JNIEnv & env,jobject self)119 jbyteArray JNICALL ConsumeRemainingAsArray(JNIEnv &env, jobject self) {
120   return ConsumeIntegralArray<T>(env, self, std::numeric_limits<jint>::max());
121 }
122 
123 template <typename T>
ConsumeIntegralInRange(JNIEnv & env,jobject self,T min,T max)124 T JNICALL ConsumeIntegralInRange(JNIEnv &env, jobject self, T min, T max) {
125   if (min > max) {
126     ThrowIllegalArgumentException(
127         env, absl::StrFormat(
128                  "Consume*InRange: min must be <= max (got min: %d, max: %d)",
129                  min, max));
130     return 0;
131   }
132 
133   uint64_t range = static_cast<uint64_t>(max) - min;
134   uint64_t result = 0;
135   std::size_t offset = 0;
136 
137   while (offset < 8 * sizeof(T) && (range >> offset) > 0 &&
138          gRemainingBytes != 0) {
139     --gRemainingBytes;
140     result = (result << 8u) | gDataPtr[gRemainingBytes];
141     offset += 8;
142   }
143 
144   if (range != std::numeric_limits<T>::max())
145     // We accept modulo bias in favor of reading a dynamic number of bytes as
146     // this would make it harder for the fuzzer to mutate towards values from
147     // the table of recent compares.
148     result = result % (range + 1);
149 
150   return static_cast<T>(min + result);
151 }
152 
153 template <typename T>
ConsumeIntegral(JNIEnv & env,jobject self)154 T JNICALL ConsumeIntegral(JNIEnv &env, jobject self) {
155   // First generate an unsigned value and then (safely) cast it to a signed
156   // integral type. By doing this rather than calling ConsumeIntegralInRange
157   // with bounds [signed_min, signed_max], we ensure that there is a direct
158   // correspondence between the consumed raw bytes and the result (e.g., 0
159   // corresponds to 0 and not to signed_min). This should help mutating
160   // towards entries of the table of recent compares.
161   using UnsignedT = typename std::make_unsigned<T>::type;
162   static_assert(
163       std::numeric_limits<UnsignedT>::is_modulo,
164       "Unsigned to signed conversion requires modulo-based overflow handling");
165   return static_cast<T>(ConsumeIntegralInRange<UnsignedT>(
166       env, self, 0, std::numeric_limits<UnsignedT>::max()));
167 }
168 
ConsumeBool(JNIEnv & env,jobject self)169 bool JNICALL ConsumeBool(JNIEnv &env, jobject self) {
170   return ConsumeIntegral<uint8_t>(env, self) & 1u;
171 }
172 
ConsumeCharInternal(JNIEnv & env,jobject self,bool filter_surrogates)173 jchar ConsumeCharInternal(JNIEnv &env, jobject self, bool filter_surrogates) {
174   auto raw_codepoint = ConsumeIntegral<jchar>(env, self);
175   if (filter_surrogates && raw_codepoint >= 0xd800 && raw_codepoint < 0xe000)
176     raw_codepoint -= 0xd800;
177   return raw_codepoint;
178 }
179 
ConsumeChar(JNIEnv & env,jobject self)180 jchar JNICALL ConsumeChar(JNIEnv &env, jobject self) {
181   return ConsumeCharInternal(env, self, false);
182 }
183 
ConsumeCharNoSurrogates(JNIEnv & env,jobject self)184 jchar JNICALL ConsumeCharNoSurrogates(JNIEnv &env, jobject self) {
185   return ConsumeCharInternal(env, self, true);
186 }
187 
188 template <typename T>
ConsumeProbability(JNIEnv & env,jobject self)189 T JNICALL ConsumeProbability(JNIEnv &env, jobject self) {
190   using IntegralType =
191       typename std::conditional<(sizeof(T) <= sizeof(uint32_t)), uint32_t,
192                                 uint64_t>::type;
193   T result = static_cast<T>(ConsumeIntegral<IntegralType>(env, self));
194   result /= static_cast<T>(std::numeric_limits<IntegralType>::max());
195   return result;
196 }
197 
198 template <typename T>
ConsumeFloatInRange(JNIEnv & env,jobject self,T min,T max)199 T JNICALL ConsumeFloatInRange(JNIEnv &env, jobject self, T min, T max) {
200   if (min > max) {
201     ThrowIllegalArgumentException(
202         env, absl::StrFormat(
203                  "Consume*InRange: min must be <= max (got min: %f, max: %f)",
204                  min, max));
205     return 0.0;
206   }
207 
208   T range;
209   T result = min;
210 
211   // Deal with overflow, in the event min and max are very far apart
212   if (min < 0 && max > 0 && min + std::numeric_limits<T>::max() < max) {
213     range = (max / 2) - (min / 2);
214     if (ConsumeBool(env, self)) {
215       result += range;
216     }
217   } else {
218     range = max - min;
219   }
220 
221   T probability = ConsumeProbability<T>(env, self);
222   return result + range * probability;
223 }
224 
225 template <typename T>
ConsumeRegularFloat(JNIEnv & env,jobject self)226 T JNICALL ConsumeRegularFloat(JNIEnv &env, jobject self) {
227   return ConsumeFloatInRange(env, self, std::numeric_limits<T>::lowest(),
228                              std::numeric_limits<T>::max());
229 }
230 
231 template <typename T>
ConsumeFloat(JNIEnv & env,jobject self)232 T JNICALL ConsumeFloat(JNIEnv &env, jobject self) {
233   if (!gRemainingBytes) return 0.0;
234 
235   auto type_val = ConsumeIntegral<uint8_t>(env, self);
236 
237   if (type_val <= 10) {
238     // Consume the same amount of bytes as for a regular float/double
239     ConsumeRegularFloat<T>(env, self);
240 
241     switch (type_val) {
242       case 0:
243         return 0.0;
244       case 1:
245         return -0.0;
246       case 2:
247         return std::numeric_limits<T>::infinity();
248       case 3:
249         return -std::numeric_limits<T>::infinity();
250       case 4:
251         return std::numeric_limits<T>::quiet_NaN();
252       case 5:
253         return std::numeric_limits<T>::denorm_min();
254       case 6:
255         return -std::numeric_limits<T>::denorm_min();
256       case 7:
257         return std::numeric_limits<T>::min();
258       case 8:
259         return -std::numeric_limits<T>::min();
260       case 9:
261         return std::numeric_limits<T>::max();
262       case 10:
263         return -std::numeric_limits<T>::max();
264       default:
265         abort();
266     }
267   }
268 
269   T regular = ConsumeRegularFloat<T>(env, self);
270   return regular;
271 }
272 
// Polyfill for C++20 std::countl_one restricted to bytes: returns the number
// of consecutive one bits in `byte`, starting from the most significant bit.
inline __attribute__((always_inline)) uint8_t countl_one(uint8_t byte) {
  // The leading ones of `byte` are the leading zeros of its complement.
  const unsigned int inverted = static_cast<uint8_t>(~byte);
  // The result of __builtin_clz is undefined for 0, i.e. for byte == 0xFF.
  if (inverted == 0) return 8;
  // __builtin_clz counts within an unsigned int; subtract the 24 bits above
  // the low byte.
  return __builtin_clz(inverted) - 24;
}
280 
// Turns `byte` into a valid UTF-8 continuation byte (0b10XXXXXX) by replacing
// its two most significant bits and keeping the six payload bits.
inline __attribute__((always_inline)) void ForceContinuationByte(
    uint8_t &byte) {
  constexpr unsigned int kPayloadMask = 0x3Fu;
  constexpr unsigned int kContinuationHeader = 0x80u;
  byte = static_cast<uint8_t>((byte & kPayloadMask) | kContinuationHeader);
}
286 
// First byte (0b11000000) of the two-byte encoding of the zero character.
constexpr uint8_t kTwoByteZeroLeadingByte = 0xC0;
// Second byte (0b10000000) of the two-byte encoding of the zero character.
constexpr uint8_t kTwoByteZeroContinuationByte = 0x80;
// Leading byte (0b11100000) of a three-byte sequence with all payload bits
// cleared.
constexpr uint8_t kThreeByteLowLeadingByte = 0xE0;
// Leading byte (0b11101101) shared by all three-byte encoded surrogates.
constexpr uint8_t kSurrogateLeadingByte = 0xED;
291 
// States of the state machine in FixUpModifiedUtf8. Each state describes which
// kind of byte is expected next, given the bytes emitted so far.
enum class Utf8GenerationState {
  // The next byte starts a new character. This is the initial state.
  LeadingByte_Generic,
  // The previous byte was a backslash that has not been emitted: a second
  // backslash is emitted as a single one, any other byte ends the string.
  LeadingByte_AfterBackslash,
  // The next byte is the continuation byte of a 2-byte sequence.
  ContinuationByte_Generic,
  // Like ContinuationByte_Generic, but the leading byte has to be fixed up if
  // the sequence would otherwise encode an ASCII character on two bytes.
  ContinuationByte_LowLeadingByte,
  // The next byte is the first continuation byte of a 3-byte sequence started
  // by kThreeByteLowLeadingByte.
  FirstContinuationByte_LowLeadingByte,
  // The next byte is the first continuation byte of a 3-byte sequence started
  // by kSurrogateLeadingByte and decides whether a surrogate is generated.
  FirstContinuationByte_SurrogateLeadingByte,
  // The next byte is the first continuation byte of any other 3-byte sequence.
  FirstContinuationByte_Generic,
  // The next byte is the last byte of a 3-byte sequence that is not a
  // surrogate.
  SecondContinuationByte_Generic,
  // A high surrogate has been emitted; the next byte is replaced by the
  // leading byte of the low surrogate that has to follow it.
  LeadingByte_LowSurrogate,
  // The next byte is the first continuation byte of a low surrogate.
  FirstContinuationByte_LowSurrogate,
  // The next byte is the last byte of a high surrogate.
  SecondContinuationByte_HighSurrogate,
  // The next byte is the last byte of a low surrogate.
  SecondContinuationByte_LowSurrogate,
};
306 
307 // Consumes up to `max_bytes` arbitrary bytes pointed to by `ptr` and returns a
308 // valid "modified UTF-8" string of length at most `max_length` that resembles
309 // the input bytes as closely as possible as well as the number of consumed
310 // bytes. If `stop_on_slash` is true, then the string will end on the first
311 // single consumed '\'.
312 //
313 // "Modified UTF-8" is the string encoding used by the JNI. It is the same as
314 // the legacy encoding CESU-8, but with `\0` coded on two bytes. In these
315 // encodings, code points requiring 4 bytes in modern UTF-8 are represented as
316 // two surrogates, each of which is coded on 3 bytes.
317 //
318 // This function has been designed with the following goals in mind:
319 // 1. The generated string should be biased towards containing ASCII characters
320 //    as these are often the ones that affect control flow directly.
321 // 2. Correctly encoded data (e.g. taken from the table of recent compares)
322 //    should be emitted unchanged.
323 // 3. The raw fuzzer input should be preserved as far as possible, but the
324 //    output must always be correctly encoded.
325 //
326 // The JVM accepts string in two encodings: UTF-16 and modified UTF-8.
327 // Generating UTF-16 would make it harder to fulfill the first design goal and
328 // would potentially hinder compatibility with corpora using the much more
329 // widely used UTF-8 encoding, which is reasonably similar to modified UTF-8. As
330 // a result, this function uses modified UTF-8.
331 //
332 // See Algorithm 1 of https://arxiv.org/pdf/2010.03090.pdf for more details on
333 // the individual cases involved in determining the validity of a UTF-8 string.
334 template <bool ascii_only, bool stop_on_backslash>
FixUpModifiedUtf8(const uint8_t * data,std::size_t max_bytes,jint max_length)335 std::pair<std::string, std::size_t> FixUpModifiedUtf8(const uint8_t *data,
336                                                       std::size_t max_bytes,
337                                                       jint max_length) {
338   std::string str;
339   // Every character in modified UTF-8 is coded on at most six bytes. Every
340   // consumed byte is transformed into at most one code unit, except for the
341   // case of a zero byte which requires two bytes.
342   if (max_bytes > std::numeric_limits<std::size_t>::max() / 2)
343     max_bytes = std::numeric_limits<std::size_t>::max() / 2;
344   if (ascii_only) {
345     str.reserve(
346         std::min(2 * static_cast<std::size_t>(max_length), 2 * max_bytes));
347   } else {
348     str.reserve(
349         std::min(6 * static_cast<std::size_t>(max_length), 2 * max_bytes));
350   }
351 
352   Utf8GenerationState state = Utf8GenerationState::LeadingByte_Generic;
353   const uint8_t *pos = data;
354   const auto data_end = data + max_bytes;
355   for (std::size_t length = 0; length < max_length && pos != data_end; ++pos) {
356     uint8_t c = *pos;
357     if (ascii_only) {
358       // Clamp to 7-bit ASCII range.
359       c &= 0x7Fu;
360     }
361     // Fix up c or previously read bytes according to the value of c and the
362     // current state. In the end, add the fixed up code unit c to the string.
363     // Exception: The zero character has to be coded on two bytes and is the
364     // only case in which an iteration of the loop adds two code units.
365     switch (state) {
366       case Utf8GenerationState::LeadingByte_Generic: {
367         switch (ascii_only ? 0 : countl_one(c)) {
368           case 0: {
369             // valid - 1-byte code point (ASCII)
370             // The zero character has to be coded on two bytes in modified
371             // UTF-8.
372             if (c == 0) {
373               str += static_cast<char>(kTwoByteZeroLeadingByte);
374               c = kTwoByteZeroContinuationByte;
375             } else if (stop_on_backslash && c == '\\') {
376               state = Utf8GenerationState::LeadingByte_AfterBackslash;
377               // The slash either signals the end of the string or is skipped,
378               // so don't append anything.
379               continue;
380             }
381             // Remain in state LeadingByte.
382             ++length;
383             break;
384           }
385           case 1: {
386             // invalid - continuation byte at leader byte position
387             // Fix it up to be of the form 0b110XXXXX and fall through to the
388             // case of a 2-byte sequence.
389             c |= 1u << 6u;
390             c &= ~(1u << 5u);
391             [[fallthrough]];
392           }
393           case 2: {
394             // (most likely) valid - start of a 2-byte sequence
395             // ASCII characters must be coded on a single byte, so we must
396             // ensure that the lower two bits combined with the six non-header
397             // bits of the following byte do not form a 7-bit ASCII value. This
398             // could only be the case if at most the lowest bit is set.
399             if ((c & 0b00011110u) == 0) {
400               state = Utf8GenerationState::ContinuationByte_LowLeadingByte;
401             } else {
402               state = Utf8GenerationState::ContinuationByte_Generic;
403             }
404             break;
405           }
406           // The default case falls through to the case of three leading ones
407           // coming right after.
408           default: {
409             // invalid - at least four leading ones
410             // In the case of exactly four leading ones, this would be valid
411             // UTF-8, but is not valid in the JVM's modified UTF-8 encoding.
412             // Fix it up by clearing the fourth leading one and falling through
413             // to the 3-byte case.
414             c &= ~(1u << 4u);
415             [[fallthrough]];
416           }
417           case 3: {
418             // valid - start of a 3-byte sequence
419             if (c == kThreeByteLowLeadingByte) {
420               state = Utf8GenerationState::FirstContinuationByte_LowLeadingByte;
421             } else if (c == kSurrogateLeadingByte) {
422               state = Utf8GenerationState::
423                   FirstContinuationByte_SurrogateLeadingByte;
424             } else {
425               state = Utf8GenerationState::FirstContinuationByte_Generic;
426             }
427             break;
428           }
429         }
430         break;
431       }
432       case Utf8GenerationState::LeadingByte_AfterBackslash: {
433         if (c != '\\') {
434           // Mark the current byte as consumed.
435           ++pos;
436           goto done;
437         }
438         // A double backslash is consumed as a single one. As we skipped the
439         // first one, emit the second one as usual.
440         state = Utf8GenerationState::LeadingByte_Generic;
441         ++length;
442         break;
443       }
444       case Utf8GenerationState::ContinuationByte_LowLeadingByte: {
445         ForceContinuationByte(c);
446         // Preserve the zero character, which is coded on two bytes in modified
447         // UTF-8. In all other cases ensure that we are not incorrectly encoding
448         // an ASCII character on two bytes by setting the eigth least
449         // significant bit of the encoded value (second least significant bit of
450         // the leading byte).
451         auto previous_c = static_cast<uint8_t>(str.back());
452         if (previous_c != kTwoByteZeroLeadingByte ||
453             c != kTwoByteZeroContinuationByte) {
454           str.back() = static_cast<char>(previous_c | (1u << 1u));
455         }
456         state = Utf8GenerationState::LeadingByte_Generic;
457         ++length;
458         break;
459       }
460       case Utf8GenerationState::ContinuationByte_Generic: {
461         ForceContinuationByte(c);
462         state = Utf8GenerationState::LeadingByte_Generic;
463         ++length;
464         break;
465       }
466       case Utf8GenerationState::FirstContinuationByte_LowLeadingByte: {
467         ForceContinuationByte(c);
468         // Ensure that the current code point could not have been coded on two
469         // bytes. As two bytes encode up to 11 bits and three bytes encode up
470         // to 16 bits, we thus have to make it such that the five highest bits
471         // are not all zero. Four of these bits are the non-header bits of the
472         // leader byte. Thus, set the highest non-header bit in this byte (fifth
473         // highest in the encoded value).
474         c |= 1u << 5u;
475         state = Utf8GenerationState::SecondContinuationByte_Generic;
476         break;
477       }
478       case Utf8GenerationState::FirstContinuationByte_SurrogateLeadingByte: {
479         ForceContinuationByte(c);
480         if (c & (1u << 5u)) {
481           // Start with a high surrogate (0xD800-0xDBFF). c contains the second
482           // byte and the first two bits of the third byte. The first two bits
483           // of this second byte are fixed to 10 (in 0x8-0xB).
484           c |= 1u << 5u;
485           c &= ~(1u << 4u);
486           // The high surrogate must be followed by a low surrogate.
487           state = Utf8GenerationState::SecondContinuationByte_HighSurrogate;
488         } else {
489           state = Utf8GenerationState::SecondContinuationByte_Generic;
490         }
491         break;
492       }
493       case Utf8GenerationState::FirstContinuationByte_Generic: {
494         ForceContinuationByte(c);
495         state = Utf8GenerationState::SecondContinuationByte_Generic;
496         break;
497       }
498       case Utf8GenerationState::SecondContinuationByte_HighSurrogate: {
499         ForceContinuationByte(c);
500         state = Utf8GenerationState::LeadingByte_LowSurrogate;
501         ++length;
502         break;
503       }
504       case Utf8GenerationState::SecondContinuationByte_LowSurrogate:
505       case Utf8GenerationState::SecondContinuationByte_Generic: {
506         ForceContinuationByte(c);
507         state = Utf8GenerationState::LeadingByte_Generic;
508         ++length;
509         break;
510       }
511       case Utf8GenerationState::LeadingByte_LowSurrogate: {
512         // We have to emit a low surrogate leading byte, which is a fixed value.
513         // We still consume a byte from the input to make fuzzer changes more
514         // stable and preserve valid surrogate pairs picked up from e.g. the
515         // table of recent compares.
516         c = kSurrogateLeadingByte;
517         state = Utf8GenerationState::FirstContinuationByte_LowSurrogate;
518         break;
519       }
520       case Utf8GenerationState::FirstContinuationByte_LowSurrogate: {
521         ForceContinuationByte(c);
522         // Low surrogates are code points in the range 0xDC00-0xDFFF. c contains
523         // the second byte and the first two bits of the third byte. The first
524         // two bits of this second byte are fixed to 11 (in 0xC-0xF).
525         c |= (1u << 5u) | (1u << 4u);
526         // The second continuation byte of a low surrogate is not restricted,
527         // but we need to track it differently to allow for correct backtracking
528         // if it isn't completed.
529         state = Utf8GenerationState::SecondContinuationByte_LowSurrogate;
530         break;
531       }
532     }
533     str += static_cast<uint8_t>(c);
534   }
535 
536   // Backtrack the current incomplete character.
537   switch (state) {
538     case Utf8GenerationState::SecondContinuationByte_LowSurrogate:
539       str.pop_back();
540       [[fallthrough]];
541     case Utf8GenerationState::FirstContinuationByte_LowSurrogate:
542       str.pop_back();
543       [[fallthrough]];
544     case Utf8GenerationState::LeadingByte_LowSurrogate:
545       str.pop_back();
546       [[fallthrough]];
547     case Utf8GenerationState::SecondContinuationByte_Generic:
548     case Utf8GenerationState::SecondContinuationByte_HighSurrogate:
549       str.pop_back();
550       [[fallthrough]];
551     case Utf8GenerationState::ContinuationByte_Generic:
552     case Utf8GenerationState::ContinuationByte_LowLeadingByte:
553     case Utf8GenerationState::FirstContinuationByte_Generic:
554     case Utf8GenerationState::FirstContinuationByte_LowLeadingByte:
555     case Utf8GenerationState::FirstContinuationByte_SurrogateLeadingByte:
556       str.pop_back();
557       [[fallthrough]];
558     case Utf8GenerationState::LeadingByte_Generic:
559     case Utf8GenerationState::LeadingByte_AfterBackslash:
560       // No backtracking required.
561       break;
562   }
563 
564 done:
565   return std::make_pair(str, pos - data);
566 }
567 }  // namespace
568 
569 namespace jazzer {
570 // Exposed for testing only.
FixUpModifiedUtf8(const uint8_t * data,std::size_t max_bytes,jint max_length,bool ascii_only,bool stop_on_backslash)571 std::pair<std::string, std::size_t> FixUpModifiedUtf8(const uint8_t *data,
572                                                       std::size_t max_bytes,
573                                                       jint max_length,
574                                                       bool ascii_only,
575                                                       bool stop_on_backslash) {
576   if (ascii_only) {
577     if (stop_on_backslash) {
578       return ::FixUpModifiedUtf8<true, true>(data, max_bytes, max_length);
579     } else {
580       return ::FixUpModifiedUtf8<true, false>(data, max_bytes, max_length);
581     }
582   } else {
583     if (stop_on_backslash) {
584       return ::FixUpModifiedUtf8<false, true>(data, max_bytes, max_length);
585     } else {
586       return ::FixUpModifiedUtf8<false, false>(data, max_bytes, max_length);
587     }
588   }
589 }
590 }  // namespace jazzer
591 
592 namespace {
ConsumeStringInternal(JNIEnv & env,jint max_length,bool ascii_only,bool stop_on_backslash)593 jstring ConsumeStringInternal(JNIEnv &env, jint max_length, bool ascii_only,
594                               bool stop_on_backslash) {
595   if (max_length < 0) {
596     ThrowIllegalArgumentException(env, "maxLength must not be negative");
597     return nullptr;
598   }
599 
600   if (max_length == 0 || gRemainingBytes == 0) return env.NewStringUTF("");
601 
602   if (gRemainingBytes == 1) {
603     Advance(1);
604     return env.NewStringUTF("");
605   }
606 
607   std::size_t max_bytes = gRemainingBytes;
608   std::string str;
609   std::size_t consumed_bytes;
610   std::tie(str, consumed_bytes) = jazzer::FixUpModifiedUtf8(
611       gDataPtr, max_bytes, max_length, ascii_only, stop_on_backslash);
612   Advance(consumed_bytes);
613   return env.NewStringUTF(str.c_str());
614 }
615 
ConsumeAsciiString(JNIEnv & env,jobject self,jint max_length)616 jstring JNICALL ConsumeAsciiString(JNIEnv &env, jobject self, jint max_length) {
617   return ConsumeStringInternal(env, max_length, true, true);
618 }
619 
ConsumeString(JNIEnv & env,jobject self,jint max_length)620 jstring JNICALL ConsumeString(JNIEnv &env, jobject self, jint max_length) {
621   return ConsumeStringInternal(env, max_length, false, true);
622 }
623 
ConsumeRemainingAsAsciiString(JNIEnv & env,jobject self)624 jstring JNICALL ConsumeRemainingAsAsciiString(JNIEnv &env, jobject self) {
625   return ConsumeStringInternal(env, std::numeric_limits<jint>::max(), true,
626                                false);
627 }
628 
ConsumeRemainingAsString(JNIEnv & env,jobject self)629 jstring JNICALL ConsumeRemainingAsString(JNIEnv &env, jobject self) {
630   return ConsumeStringInternal(env, std::numeric_limits<jint>::max(), false,
631                                false);
632 }
633 
RemainingBytes(JNIEnv & env,jobject self)634 std::size_t RemainingBytes(JNIEnv &env, jobject self) {
635   return gRemainingBytes;
636 }
637 
638 const JNINativeMethod kFuzzedDataMethods[]{
639     {(char *)"consumeBoolean", (char *)"()Z", (void *)&ConsumeBool},
640     {(char *)"consumeByte", (char *)"()B", (void *)&ConsumeIntegral<jbyte>},
641     {(char *)"consumeByte", (char *)"(BB)B",
642      (void *)&ConsumeIntegralInRange<jbyte>},
643     {(char *)"consumeShort", (char *)"()S", (void *)&ConsumeIntegral<jshort>},
644     {(char *)"consumeShort", (char *)"(SS)S",
645      (void *)&ConsumeIntegralInRange<jshort>},
646     {(char *)"consumeInt", (char *)"()I", (void *)&ConsumeIntegral<jint>},
647     {(char *)"consumeInt", (char *)"(II)I",
648      (void *)&ConsumeIntegralInRange<jint>},
649     {(char *)"consumeLong", (char *)"()J", (void *)&ConsumeIntegral<jlong>},
650     {(char *)"consumeLong", (char *)"(JJ)J",
651      (void *)&ConsumeIntegralInRange<jlong>},
652     {(char *)"consumeFloat", (char *)"()F", (void *)&ConsumeFloat<jfloat>},
653     {(char *)"consumeRegularFloat", (char *)"()F",
654      (void *)&ConsumeRegularFloat<jfloat>},
655     {(char *)"consumeRegularFloat", (char *)"(FF)F",
656      (void *)&ConsumeFloatInRange<jfloat>},
657     {(char *)"consumeProbabilityFloat", (char *)"()F",
658      (void *)&ConsumeProbability<jfloat>},
659     {(char *)"consumeDouble", (char *)"()D", (void *)&ConsumeFloat<jdouble>},
660     {(char *)"consumeRegularDouble", (char *)"()D",
661      (void *)&ConsumeRegularFloat<jdouble>},
662     {(char *)"consumeRegularDouble", (char *)"(DD)D",
663      (void *)&ConsumeFloatInRange<jdouble>},
664     {(char *)"consumeProbabilityDouble", (char *)"()D",
665      (void *)&ConsumeProbability<jdouble>},
666     {(char *)"consumeChar", (char *)"()C", (void *)&ConsumeChar},
667     {(char *)"consumeChar", (char *)"(CC)C",
668      (void *)&ConsumeIntegralInRange<jchar>},
669     {(char *)"consumeCharNoSurrogates", (char *)"()C",
670      (void *)&ConsumeCharNoSurrogates},
671     {(char *)"consumeAsciiString", (char *)"(I)Ljava/lang/String;",
672      (void *)&ConsumeAsciiString},
673     {(char *)"consumeRemainingAsAsciiString", (char *)"()Ljava/lang/String;",
674      (void *)&ConsumeRemainingAsAsciiString},
675     {(char *)"consumeString", (char *)"(I)Ljava/lang/String;",
676      (void *)&ConsumeString},
677     {(char *)"consumeRemainingAsString", (char *)"()Ljava/lang/String;",
678      (void *)&ConsumeRemainingAsString},
679     {(char *)"consumeBooleans", (char *)"(I)[Z",
680      (void *)&ConsumeIntegralArray<jboolean>},
681     {(char *)"consumeBytes", (char *)"(I)[B",
682      (void *)&ConsumeIntegralArray<jbyte>},
683     {(char *)"consumeShorts", (char *)"(I)[S",
684      (void *)&ConsumeIntegralArray<jshort>},
685     {(char *)"consumeInts", (char *)"(I)[I",
686      (void *)&ConsumeIntegralArray<jint>},
687     {(char *)"consumeLongs", (char *)"(I)[J",
688      (void *)&ConsumeIntegralArray<jlong>},
689     {(char *)"consumeRemainingAsBytes", (char *)"()[B",
690      (void *)&ConsumeRemainingAsArray<jbyte>},
691     {(char *)"remainingBytes", (char *)"()I", (void *)&RemainingBytes},
692 };
693 const jint kNumFuzzedDataMethods =
694     sizeof(kFuzzedDataMethods) / sizeof(kFuzzedDataMethods[0]);
695 }  // namespace
696 
697 namespace jazzer {
698 
SetUpFuzzedDataProvider(JNIEnv & env)699 void SetUpFuzzedDataProvider(JNIEnv &env) {
700   jclass fuzzed_data_provider_class =
701       env.FindClass(kFuzzedDataProviderImplClass);
702   if (env.ExceptionCheck()) {
703     env.ExceptionDescribe();
704     throw std::runtime_error("failed to find FuzzedDataProviderImpl class");
705   }
706   env.RegisterNatives(fuzzed_data_provider_class, kFuzzedDataMethods,
707                       kNumFuzzedDataMethods);
708   if (env.ExceptionCheck()) {
709     env.ExceptionDescribe();
710     throw std::runtime_error(
711         "could not register native callbacks for FuzzedDataProvider");
712   }
713 }
714 
FeedFuzzedDataProvider(const uint8_t * data,std::size_t size)715 void FeedFuzzedDataProvider(const uint8_t *data, std::size_t size) {
716   gDataPtr = data;
717   gRemainingBytes = size;
718 }
719 }  // namespace jazzer
720