1 use std::{
2 cmp,
3 io::{self, Read as _},
4 iter,
5 };
6
7 use rand::{Rng as _, RngCore as _};
8
9 use super::decoder::{DecoderReader, BUF_SIZE};
10 use crate::{
11 engine::{general_purpose::STANDARD, Engine, GeneralPurpose},
12 tests::{random_alphabet, random_config, random_engine},
13 DecodeError,
14 };
15
/// Decodes a fixed table of valid standard-alphabet inputs through `DecoderReader`, using every
/// read-buffer size from 1 byte up to the encoded length, and checks the decoded text each time.
#[test]
fn simple() {
    let tests: &[(&[u8], &[u8])] = &[
        (&b"0"[..], &b"MA=="[..]),
        (b"01", b"MDE="),
        (b"012", b"MDEy"),
        (b"0123", b"MDEyMw=="),
        (b"01234", b"MDEyMzQ="),
        (b"012345", b"MDEyMzQ1"),
        (b"0123456", b"MDEyMzQ1Ng=="),
        (b"01234567", b"MDEyMzQ1Njc="),
        (b"012345678", b"MDEyMzQ1Njc4"),
        (b"0123456789", b"MDEyMzQ1Njc4OQ=="),
    ][..];

    for (text_expected, base64data) in tests.iter() {
        // Read n bytes at a time.
        for n in 1..=base64data.len() {
            let mut wrapped_reader = io::Cursor::new(base64data);
            let mut decoder = DecoderReader::new(&mut wrapped_reader, &STANDARD);

            let mut text_got = Vec::new();
            let mut buffer = vec![0u8; n];
            loop {
                // Every input here is valid, so any read error is a test failure. Report it
                // directly instead of ending the loop and producing a confusing partial-output
                // mismatch below.
                let read = decoder
                    .read(&mut buffer[..])
                    .unwrap_or_else(|e| panic!("read failed with buffer size {}: {}", n, e));
                if read == 0 {
                    break;
                }
                text_got.extend_from_slice(&buffer[..read]);
            }

            assert_eq!(
                text_got,
                *text_expected,
                "\nGot: {}\nExpected: {}",
                String::from_utf8_lossy(&text_got[..]),
                String::from_utf8_lossy(text_expected)
            );
        }
    }
}
57
/// Checks that `DecoderReader` eventually returns an error, rather than a clean end of stream,
/// when the input has junk after otherwise-valid base64, for every read-buffer size.
#[test]
fn trailing_junk() {
    let tests: &[&[u8]] = &[&b"MDEyMzQ1Njc4*!@#$%^&"[..], b"MDEyMzQ1Njc4OQ== "][..];

    for base64data in tests {
        // Try every read size from a single byte up to the whole input.
        for n in 1..=base64data.len() {
            let mut source = io::Cursor::new(base64data);
            let mut decoder = DecoderReader::new(&mut source, &STANDARD);
            let mut chunk = vec![0u8; n];

            // Keep pulling until the decoder either fails (expected) or reports end of stream.
            let saw_error = loop {
                match decoder.read(&mut chunk[..]) {
                    Err(_) => break true,
                    Ok(0) => break false,
                    Ok(_) => {}
                }
            };

            assert!(saw_error);
        }
    }
}
87
/// Round-trips random data: it encodes with a random engine, then decodes through
/// `DecoderReader` whose underlying reader is wrapped in `RandomShortRead`, so each delegate
/// read returns only a few bytes. The decoded output must match the original input.
#[test]
fn handles_short_read_from_delegate() {
    let mut rng = rand::thread_rng();
    let mut original = Vec::new();
    let mut encoded = String::new();
    let mut output = Vec::new();

    for _ in 0..10_000 {
        original.clear();
        encoded.clear();
        output.clear();

        let len = rng.gen_range(0..(10 * BUF_SIZE));
        // `original` is empty here, so this yields exactly `len` zero bytes to overwrite.
        original.resize(len, 0);
        rng.fill_bytes(&mut original[..]);
        assert_eq!(len, original.len());

        let engine = random_engine(&mut rng);
        engine.encode_string(&original[..], &mut encoded);

        let mut source = io::Cursor::new(encoded.as_bytes());
        let mut stingy_source = RandomShortRead {
            delegate: &mut source,
            rng: &mut rng,
        };
        let mut decoder = DecoderReader::new(&mut stingy_source, &engine);

        let n_decoded = decoder.read_to_end(&mut output).unwrap();
        assert_eq!(len, n_decoded);
        assert_eq!(&original[..], &output[..]);
    }
}
122
/// Round-trips random data through a random engine, then consumes `DecoderReader` with randomly
/// sized reads via `consume_with_short_reads_and_validate`.
#[test]
fn read_in_short_increments() {
    let mut rng = rand::thread_rng();
    let mut plain = Vec::new();
    let mut encoded = String::new();
    let mut scratch = Vec::new();

    for _ in 0..10_000 {
        plain.clear();
        encoded.clear();
        scratch.clear();

        let len = rng.gen_range(0..(10 * BUF_SIZE));
        plain.resize(len, 0);
        // Oversized scratch space so the validator can request reads larger than what remains.
        scratch.resize(len * 3, 0);

        rng.fill_bytes(&mut plain[..]);
        assert_eq!(len, plain.len());

        let engine = random_engine(&mut rng);
        engine.encode_string(&plain[..], &mut encoded);

        let mut cursor = io::Cursor::new(&encoded[..]);
        let mut decoder = DecoderReader::new(&mut cursor, &engine);

        consume_with_short_reads_and_validate(&mut rng, &plain[..], &mut scratch, &mut decoder);
    }
}
153
/// Like `read_in_short_increments`, except the randomly sized reads go through a
/// `RandomShortRead` wrapped around the decoder, so each request may be cut even shorter.
#[test]
fn read_in_short_increments_with_short_delegate_reads() {
    let mut rng = rand::thread_rng();
    let mut plain = Vec::new();
    let mut encoded = String::new();
    let mut scratch = Vec::new();

    for _ in 0..10_000 {
        plain.clear();
        encoded.clear();
        scratch.clear();

        let len = rng.gen_range(0..(10 * BUF_SIZE));
        plain.resize(len, 0);
        // Oversized scratch space so the validator can request reads larger than what remains.
        scratch.resize(len * 3, 0);

        rng.fill_bytes(&mut plain[..]);
        assert_eq!(len, plain.len());

        let engine = random_engine(&mut rng);
        engine.encode_string(&plain[..], &mut encoded);

        let mut cursor = io::Cursor::new(&encoded[..]);
        let mut decoder = DecoderReader::new(&mut cursor, &engine);
        // `rng` is lent to the validator below, so the wrapper needs its own handle.
        let mut wrapper_rng = rand::thread_rng();
        let mut chopped = RandomShortRead {
            delegate: &mut decoder,
            rng: &mut wrapper_rng,
        };

        consume_with_short_reads_and_validate(&mut rng, &plain[..], &mut scratch, &mut chopped);
    }
}
193
/// Replaces the final symbol of unpadded encoded data with every symbol of the alphabet. For
/// each replacement, it checks that streaming decoding through `DecoderReader` yields the same
/// outcome (success, or the same `DecodeError`) as bulk decoding with `decode_vec`.
#[test]
fn reports_invalid_last_symbol_correctly() {
    let mut rng = rand::thread_rng();
    let mut bytes = Vec::new();
    let mut b64 = String::new();
    let mut b64_bytes = Vec::new();
    let mut decoded = Vec::new();
    let mut bulk_decoded = Vec::new();

    for _ in 0..1_000 {
        bytes.clear();
        b64.clear();
        b64_bytes.clear();

        // At least one input byte, so the encoding always has a last symbol to replace.
        let size = rng.gen_range(1..(10 * BUF_SIZE));
        bytes.extend(iter::repeat(0).take(size));
        rng.fill_bytes(&mut bytes[..]);
        assert_eq!(size, bytes.len());

        let config = random_config(&mut rng);
        let alphabet = random_alphabet(&mut rng);
        // Encode without padding. Otherwise the last byte could be padding, and overwriting it
        // would cause invalid-padding errors rather than exercising last-symbol handling.
        let engine = GeneralPurpose::new(alphabet, config.with_encode_padding(false));
        engine.encode_string(&bytes[..], &mut b64);
        b64_bytes.extend(b64.bytes());
        assert_eq!(b64_bytes.len(), b64.len());

        // change the last character to every possible symbol. Should behave the same as bulk
        // decoding whether invalid or valid.
        for &s1 in alphabet.symbols.iter() {
            // Both outputs are cleared before every use; nothing needs pre-filling.
            decoded.clear();
            bulk_decoded.clear();

            // replace the last
            *b64_bytes.last_mut().unwrap() = s1;
            let bulk_res = engine.decode_vec(&b64_bytes[..], &mut bulk_decoded);

            let mut wrapped_reader = io::Cursor::new(&b64_bytes[..]);
            let mut decoder = DecoderReader::new(&mut wrapped_reader, &engine);

            let stream_res = decoder.read_to_end(&mut decoded).map(|_| ()).map_err(|e| {
                e.into_inner()
                    .and_then(|e| e.downcast::<DecodeError>().ok())
            });

            assert_eq!(bulk_res.map_err(|e| Some(Box::new(e))), stream_res);
        }
    }
}
244
/// Corrupts one random byte of valid encoded data with `*`. Checks that streaming decoding
/// through `DecoderReader` fails with the same `DecodeError` as bulk decoding with
/// `decode_vec`, wrapped in an `io::Error` of kind `InvalidData`.
#[test]
fn reports_invalid_byte_correctly() {
    let mut rng = rand::thread_rng();
    let mut bytes = Vec::new();
    let mut b64 = String::new();
    let mut decoded = Vec::new();

    for _ in 0..10_000 {
        bytes.clear();
        b64.clear();
        decoded.clear();

        let size = rng.gen_range(1..(10 * BUF_SIZE));
        bytes.extend(iter::repeat(0).take(size));
        rng.fill_bytes(&mut bytes[..size]);
        assert_eq!(size, bytes.len());

        let engine = random_engine(&mut rng);

        engine.encode_string(&bytes[..], &mut b64);
        // replace one byte, somewhere, with '*', which is invalid
        let bad_byte_pos = rng.gen_range(0..b64.len());
        let mut b64_bytes = b64.as_bytes().to_vec();
        b64_bytes[bad_byte_pos] = b'*';

        // The corrupted bytes are only read from here on, so borrow them instead of cloning.
        let mut wrapped_reader = io::Cursor::new(&b64_bytes[..]);
        let mut decoder = DecoderReader::new(&mut wrapped_reader, &engine);

        // Extract the `DecodeError` and `io::ErrorKind` from the streaming error. The kind is
        // read first because `into_inner` consumes the `io::Error`.
        let read_decode_err = decoder.read_to_end(&mut decoded).err().and_then(|e| {
            let kind = e.kind();
            e.into_inner()
                .and_then(|inner| inner.downcast::<DecodeError>().ok())
                .map(|inner| (*inner, kind))
        });

        let mut bulk_buf = Vec::new();
        let bulk_decode_err = engine.decode_vec(&b64_bytes[..], &mut bulk_buf).err();

        // it's tricky to predict where the invalid data's offset will be since if it's in the last
        // chunk it will be reported at the first padding location because it's treated as invalid
        // padding. So, we just check that it's the same as it is for decoding all at once.
        assert_eq!(
            bulk_decode_err.map(|e| (e, io::ErrorKind::InvalidData)),
            read_decode_err
        );
    }
}
298
consume_with_short_reads_and_validate<R: io::Read>( rng: &mut rand::rngs::ThreadRng, expected_bytes: &[u8], decoded: &mut [u8], short_reader: &mut R, )299 fn consume_with_short_reads_and_validate<R: io::Read>(
300 rng: &mut rand::rngs::ThreadRng,
301 expected_bytes: &[u8],
302 decoded: &mut [u8],
303 short_reader: &mut R,
304 ) {
305 let mut total_read = 0_usize;
306 loop {
307 assert!(
308 total_read <= expected_bytes.len(),
309 "tr {} size {}",
310 total_read,
311 expected_bytes.len()
312 );
313 if total_read == expected_bytes.len() {
314 assert_eq!(expected_bytes, &decoded[..total_read]);
315 // should be done
316 assert_eq!(0, short_reader.read(&mut *decoded).unwrap());
317 // didn't write anything
318 assert_eq!(expected_bytes, &decoded[..total_read]);
319
320 break;
321 }
322 let decode_len = rng.gen_range(1..cmp::max(2, expected_bytes.len() * 2));
323
324 let read = short_reader
325 .read(&mut decoded[total_read..total_read + decode_len])
326 .unwrap();
327 total_read += read;
328 }
329 }
330
/// Limits how many bytes a reader will provide in each read call.
/// Useful for shaking out code that may work fine only with typical input sources that always fill
/// the buffer.
///
/// Its `io::Read` implementation forwards each call to `delegate`, limiting the buffer to a
/// random number of bytes chosen with `rng`. The trait bounds live on that `impl` rather than
/// here, so the struct itself places no requirements on `R` or `N`.
struct RandomShortRead<'a, 'b, R, N> {
    /// Reader that actually supplies the bytes.
    delegate: &'b mut R,
    /// Source of randomness used to pick each read's length limit.
    rng: &'a mut N,
}
338
339 impl<'a, 'b, R: io::Read, N: rand::Rng> io::Read for RandomShortRead<'a, 'b, R, N> {
read(&mut self, buf: &mut [u8]) -> Result<usize, io::Error>340 fn read(&mut self, buf: &mut [u8]) -> Result<usize, io::Error> {
341 // avoid 0 since it means EOF for non-empty buffers
342 let effective_len = cmp::min(self.rng.gen_range(1..20), buf.len());
343
344 self.delegate.read(&mut buf[..effective_len])
345 }
346 }
347