• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use std::{
2     cmp,
3     io::{self, Read as _},
4     iter,
5 };
6 
7 use rand::{Rng as _, RngCore as _};
8 
9 use super::decoder::{DecoderReader, BUF_SIZE};
10 use crate::{
11     engine::{general_purpose::STANDARD, Engine, GeneralPurpose},
12     tests::{random_alphabet, random_config, random_engine},
13     DecodeError,
14 };
15 
/// Decodes a handful of known base64 strings through `DecoderReader`, using every read-buffer
/// size from 1 byte up to the length of the encoded input.
#[test]
fn simple() {
    let tests: &[(&[u8], &[u8])] = &[
        (&b"0"[..], &b"MA=="[..]),
        (b"01", b"MDE="),
        (b"012", b"MDEy"),
        (b"0123", b"MDEyMw=="),
        (b"01234", b"MDEyMzQ="),
        (b"012345", b"MDEyMzQ1"),
        (b"0123456", b"MDEyMzQ1Ng=="),
        (b"01234567", b"MDEyMzQ1Njc="),
        (b"012345678", b"MDEyMzQ1Njc4"),
        (b"0123456789", b"MDEyMzQ1Njc4OQ=="),
    ][..];

    for (text_expected, base64data) in tests.iter() {
        // Read n bytes at a time.
        for n in 1..=base64data.len() {
            let mut wrapped_reader = io::Cursor::new(base64data);
            let mut decoder = DecoderReader::new(&mut wrapped_reader, &STANDARD);

            let mut text_got = Vec::new();
            let mut buffer = vec![0u8; n];
            loop {
                // All inputs here are valid, so a read error is a bug: report it directly
                // instead of silently stopping and failing later on a partial-output diff.
                let read = decoder
                    .read(&mut buffer[..])
                    .unwrap_or_else(|e| panic!("read with buffer size {} failed: {}", n, e));
                if read == 0 {
                    break;
                }
                text_got.extend_from_slice(&buffer[..read]);
            }

            assert_eq!(
                text_got,
                *text_expected,
                "\nGot: {}\nExpected: {}",
                String::from_utf8_lossy(&text_got[..]),
                String::from_utf8_lossy(text_expected)
            );
        }
    }
}
57 
// Make sure we error out on trailing junk, whatever read-buffer size is used.
#[test]
fn trailing_junk() {
    let tests: &[&[u8]] = &[&b"MDEyMzQ1Njc4*!@#$%^&"[..], b"MDEyMzQ1Njc4OQ== "][..];

    for base64data in tests.iter() {
        // Read n bytes at a time.
        for n in 1..=base64data.len() {
            let mut wrapped_reader = io::Cursor::new(base64data);
            let mut decoder = DecoderReader::new(&mut wrapped_reader, &STANDARD);

            let mut buffer = vec![0u8; n];
            // Drain the decoder until it either fails (expected) or reports EOF (a bug: the
            // junk was accepted).
            let saw_error = loop {
                match decoder.read(&mut buffer[..]) {
                    Err(_) => break true,
                    Ok(0) => break false,
                    Ok(_) => (),
                }
            };

            assert!(
                saw_error,
                "no error for {:?} with buffer size {}",
                String::from_utf8_lossy(base64data),
                n
            );
        }
    }
}
87 
/// Decoding must still yield the complete output when the wrapped reader hands back only a
/// few bytes per `read` call.
#[test]
fn handles_short_read_from_delegate() {
    let mut rng = rand::thread_rng();
    let mut bytes = Vec::new();
    let mut b64 = String::new();
    let mut decoded = Vec::new();

    for _ in 0..10_000 {
        bytes.clear();
        b64.clear();
        decoded.clear();

        let size = rng.gen_range(0..(10 * BUF_SIZE));
        // `bytes` was just cleared, so this yields exactly `size` bytes to randomize.
        bytes.resize(size, 0);
        rng.fill_bytes(&mut bytes[..]);
        assert_eq!(size, bytes.len());

        let engine = random_engine(&mut rng);
        engine.encode_string(&bytes[..], &mut b64);

        let mut wrapped_reader = io::Cursor::new(b64.as_bytes());
        let mut short_reader = RandomShortRead {
            delegate: &mut wrapped_reader,
            rng: &mut rng,
        };

        let mut decoder = DecoderReader::new(&mut short_reader, &engine);

        let decoded_len = decoder.read_to_end(&mut decoded).unwrap();
        assert_eq!(size, decoded_len);
        assert_eq!(&bytes[..], &decoded[..]);
    }
}
122 
/// Pulls decoded output out of a `DecoderReader` in randomly sized chunks and checks that the
/// reassembled result matches the original random input.
#[test]
fn read_in_short_increments() {
    let mut rng = rand::thread_rng();
    let mut input = Vec::new();
    let mut encoded = String::new();
    let mut output = Vec::new();

    for _ in 0..10_000 {
        input.clear();
        encoded.clear();
        output.clear();

        let len = rng.gen_range(0..(10 * BUF_SIZE));
        input.resize(len, 0);
        // Over-allocate so the helper can request reads larger than what is left to decode.
        output.resize(len * 3, 0);

        rng.fill_bytes(&mut input[..]);
        assert_eq!(len, input.len());

        let engine = random_engine(&mut rng);
        engine.encode_string(&input[..], &mut encoded);

        let mut source = io::Cursor::new(&encoded[..]);
        let mut decoder = DecoderReader::new(&mut source, &engine);

        consume_with_short_reads_and_validate(&mut rng, &input[..], &mut output, &mut decoder);
    }
}
153 
/// Like `read_in_short_increments`, but the decoder is additionally wrapped in a
/// `RandomShortRead`, so each consumer read is truncated a second time before it reaches the
/// decoder.
#[test]
fn read_in_short_increments_with_short_delegate_reads() {
    let mut rng = rand::thread_rng();
    let mut input = Vec::new();
    let mut encoded = String::new();
    let mut output = Vec::new();

    for _ in 0..10_000 {
        input.clear();
        encoded.clear();
        output.clear();

        let len = rng.gen_range(0..(10 * BUF_SIZE));
        input.resize(len, 0);
        // Over-allocate so the helper can request reads larger than what is left to decode.
        output.resize(len * 3, 0);

        rng.fill_bytes(&mut input[..]);
        assert_eq!(len, input.len());

        let engine = random_engine(&mut rng);
        engine.encode_string(&input[..], &mut encoded);

        let mut source = io::Cursor::new(&encoded[..]);
        let mut decoder = DecoderReader::new(&mut source, &engine);
        // `rng` is lent to the consumer helper below, so the wrapper needs its own handle.
        let mut wrapper_rng = rand::thread_rng();
        let mut truncating_reader = RandomShortRead {
            delegate: &mut decoder,
            rng: &mut wrapper_rng,
        };

        consume_with_short_reads_and_validate(
            &mut rng,
            &input[..],
            &mut output,
            &mut truncating_reader,
        );
    }
}
193 
/// Replaces the final symbol of an unpadded encoding with every alphabet symbol in turn. It
/// then checks that streaming decoding reports exactly the same outcome as bulk decoding, and
/// the same bytes when both succeed.
#[test]
fn reports_invalid_last_symbol_correctly() {
    let mut rng = rand::thread_rng();
    let mut bytes = Vec::new();
    let mut b64 = String::new();
    let mut b64_bytes = Vec::new();
    let mut decoded = Vec::new();
    let mut bulk_decoded = Vec::new();

    for _ in 0..1_000 {
        bytes.clear();
        b64.clear();
        b64_bytes.clear();

        let size = rng.gen_range(1..(10 * BUF_SIZE));
        bytes.extend(iter::repeat(0).take(size));
        rng.fill_bytes(&mut bytes[..]);
        assert_eq!(size, bytes.len());

        let config = random_config(&mut rng);
        let alphabet = random_alphabet(&mut rng);
        // changing padding will cause invalid padding errors when we twiddle the last byte
        let engine = GeneralPurpose::new(alphabet, config.with_encode_padding(false));
        engine.encode_string(&bytes[..], &mut b64);
        b64_bytes.extend(b64.bytes());
        assert_eq!(b64_bytes.len(), b64.len());

        // change the last character to every possible symbol. Should behave the same as bulk
        // decoding whether invalid or valid.
        for &s1 in alphabet.symbols.iter() {
            // Both output buffers are reset here, before every decode, so no pre-sizing is
            // needed in the outer loop.
            decoded.clear();
            bulk_decoded.clear();

            // replace the last
            *b64_bytes.last_mut().unwrap() = s1;
            let bulk_res = engine.decode_vec(&b64_bytes[..], &mut bulk_decoded);
            let bulk_succeeded = bulk_res.is_ok();

            let mut wrapped_reader = io::Cursor::new(&b64_bytes[..]);
            let mut decoder = DecoderReader::new(&mut wrapped_reader, &engine);

            let stream_res = decoder.read_to_end(&mut decoded).map(|_| ()).map_err(|e| {
                e.into_inner()
                    .and_then(|e| e.downcast::<DecodeError>().ok())
            });

            assert_eq!(bulk_res.map_err(|e| Some(Box::new(e))), stream_res);

            if bulk_succeeded {
                // Both decoders accepted the input; they must also agree on the output bytes.
                assert_eq!(bulk_decoded, decoded);
            }
        }
    }
}
244 
/// Corrupts one random position of a valid encoding with `*` and checks that streaming
/// decoding reports the same `DecodeError` as bulk decoding, wrapped as
/// `io::ErrorKind::InvalidData`.
#[test]
fn reports_invalid_byte_correctly() {
    let mut rng = rand::thread_rng();
    let mut bytes = Vec::new();
    let mut b64 = String::new();
    let mut decoded = Vec::new();
    let mut bulk_buf = Vec::new();

    for _ in 0..10_000 {
        bytes.clear();
        b64.clear();
        decoded.clear();
        bulk_buf.clear();

        let size = rng.gen_range(1..(10 * BUF_SIZE));
        bytes.extend(iter::repeat(0).take(size));
        rng.fill_bytes(&mut bytes[..]);
        assert_eq!(size, bytes.len());

        let engine = random_engine(&mut rng);

        engine.encode_string(&bytes[..], &mut b64);
        // replace one byte, somewhere, with '*', which is invalid
        let bad_byte_pos = rng.gen_range(0..b64.len());
        let mut b64_bytes = b64.bytes().collect::<Vec<u8>>();
        b64_bytes[bad_byte_pos] = b'*';

        // Borrow the corrupted input instead of cloning it; the bulk decode below only needs
        // a shared borrow as well.
        let mut wrapped_reader = io::Cursor::new(&b64_bytes[..]);
        let mut decoder = DecoderReader::new(&mut wrapped_reader, &engine);

        // Turn the io::Error into (DecodeError, ErrorKind). The kind has to be read before
        // `into_inner` consumes the error.
        let read_decode_err = decoder.read_to_end(&mut decoded).err().and_then(|e| {
            let kind = e.kind();
            e.into_inner()
                .and_then(|inner| inner.downcast::<DecodeError>().ok())
                .map(|inner| (*inner, kind))
        });

        let bulk_decode_err = engine.decode_vec(&b64_bytes[..], &mut bulk_buf).err();

        // it's tricky to predict where the invalid data's offset will be since if it's in the last
        // chunk it will be reported at the first padding location because it's treated as invalid
        // padding. So, we just check that it's the same as it is for decoding all at once.
        assert_eq!(
            bulk_decode_err.map(|e| (e, io::ErrorKind::InvalidData)),
            read_decode_err
        );
    }
}
298 
consume_with_short_reads_and_validate<R: io::Read>( rng: &mut rand::rngs::ThreadRng, expected_bytes: &[u8], decoded: &mut [u8], short_reader: &mut R, )299 fn consume_with_short_reads_and_validate<R: io::Read>(
300     rng: &mut rand::rngs::ThreadRng,
301     expected_bytes: &[u8],
302     decoded: &mut [u8],
303     short_reader: &mut R,
304 ) {
305     let mut total_read = 0_usize;
306     loop {
307         assert!(
308             total_read <= expected_bytes.len(),
309             "tr {} size {}",
310             total_read,
311             expected_bytes.len()
312         );
313         if total_read == expected_bytes.len() {
314             assert_eq!(expected_bytes, &decoded[..total_read]);
315             // should be done
316             assert_eq!(0, short_reader.read(&mut *decoded).unwrap());
317             // didn't write anything
318             assert_eq!(expected_bytes, &decoded[..total_read]);
319 
320             break;
321         }
322         let decode_len = rng.gen_range(1..cmp::max(2, expected_bytes.len() * 2));
323 
324         let read = short_reader
325             .read(&mut decoded[total_read..total_read + decode_len])
326             .unwrap();
327         total_read += read;
328     }
329 }
330 
331 /// Limits how many bytes a reader will provide in each read call.
332 /// Useful for shaking out code that may work fine only with typical input sources that always fill
333 /// the buffer.
334 struct RandomShortRead<'a, 'b, R: io::Read, N: rand::Rng> {
335     delegate: &'b mut R,
336     rng: &'a mut N,
337 }
338 
339 impl<'a, 'b, R: io::Read, N: rand::Rng> io::Read for RandomShortRead<'a, 'b, R, N> {
read(&mut self, buf: &mut [u8]) -> Result<usize, io::Error>340     fn read(&mut self, buf: &mut [u8]) -> Result<usize, io::Error> {
341         // avoid 0 since it means EOF for non-empty buffers
342         let effective_len = cmp::min(self.rng.gen_range(1..20), buf.len());
343 
344         self.delegate.read(&mut buf[..effective_len])
345     }
346 }
347