• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use crate::{
2     engine::{general_purpose::INVALID_VALUE, DecodeEstimate, DecodePaddingMode},
3     DecodeError, PAD_BYTE,
4 };
5 
// The decode logic operates on chunks of 8 input bytes without padding.
const INPUT_CHUNK_LEN: usize = 8;
const DECODED_CHUNK_LEN: usize = 6;

// We read a u64 and write a u64, but a u64 of input only yields 6 bytes of output, so the last
// 2 bytes of any output u64 should not be counted as written to (but must be available in a
// slice).
const DECODED_CHUNK_SUFFIX: usize = 2;

// How many u64s of input to handle at a time in the unrolled fast loop.
const CHUNKS_PER_FAST_LOOP_BLOCK: usize = 4;

// Number of input bytes consumed by one iteration of the unrolled fast loop.
const INPUT_BLOCK_LEN: usize = CHUNKS_PER_FAST_LOOP_BLOCK * INPUT_CHUNK_LEN;

// Output bytes touched by one iteration of the unrolled fast loop; includes the trailing 2 bytes
// of the final u64 write.
const DECODED_BLOCK_LEN: usize =
    CHUNKS_PER_FAST_LOOP_BLOCK * DECODED_CHUNK_LEN + DECODED_CHUNK_SUFFIX;

/// Sizing information derived from the length of an encoded input.
#[doc(hidden)]
pub struct GeneralPurposeEstimate {
    /// Total number of decode chunks, including a possibly partial last chunk
    num_chunks: usize,
    /// 3 decoded bytes for every started group of 4 encoded bytes.
    decoded_len_estimate: usize,
}

impl GeneralPurposeEstimate {
    /// Computes the chunk count and decoded-length estimate for `encoded_len` encoded bytes.
    ///
    /// Both fields are ceiling divisions. They are computed as `len / d + (len % d > 0)` rather
    /// than `(len + d - 1) / d` so that no intermediate value can overflow. This means the
    /// function never panics, even for `encoded_len` close to `usize::MAX`; the final `* 3`
    /// cannot overflow either, because `(len / 4 + 1) * 3` is at most about `3/4 * len + 3`.
    pub(crate) fn new(encoded_len: usize) -> Self {
        Self {
            num_chunks: encoded_len / INPUT_CHUNK_LEN
                + (encoded_len % INPUT_CHUNK_LEN > 0) as usize,
            decoded_len_estimate: (encoded_len / 4 + (encoded_len % 4 > 0) as usize) * 3,
        }
    }
}
46 
47 impl DecodeEstimate for GeneralPurposeEstimate {
decoded_len_estimate(&self) -> usize48     fn decoded_len_estimate(&self) -> usize {
49         self.decoded_len_estimate
50     }
51 }
52 
/// Helper to avoid duplicating num_chunks calculation, which is costly on short inputs.
/// Returns the number of bytes written, or an error.
///
/// The input is decoded in stages:
///
/// 1. Fast loop, stage 1: blocks of `CHUNKS_PER_FAST_LOOP_BLOCK` 8-byte chunks. Each chunk is
///    decoded with `decode_chunk`, which writes 8 output bytes of which only the first 6 are data.
/// 2. Fast loop, stage 2: single 8-byte chunks, again with `decode_chunk`.
/// 3. Stage 3: the remaining chunks except the last one, decoded with `decode_chunk_precise`,
///    which writes exactly 6 bytes.
/// 4. The last (possibly partial, possibly padded) chunk is handed to `decode_suffix`.
///
/// The fast stages stop early enough (see `trailing_bytes_to_skip`) that the 2 scratch bytes
/// written after each chunk are always overwritten by later, valid output.
///
/// # Errors
///
/// * `DecodeError::InvalidLength` when `input.len() % 8` is 1 or 5. In that case, if the last
///   byte is neither `PAD_BYTE` nor a valid symbol, `DecodeError::InvalidByte` is reported for it
///   instead.
/// * `DecodeError::InvalidByte` for an invalid symbol in any chunk decoded by stages 1-3.
/// * Whatever `decode_suffix` returns for the final chunk.
///
/// # Panics
///
/// `output` is sliced directly, so this panics if it is too short for what the stages write.
/// NOTE(review): presumably callers size `output` from `estimate`, and `estimate` is assumed to
/// have been built from `input.len()` — confirm at the call sites.
// We're on the fragile edge of compiler heuristics here. If this is not inlined, slow. If this is
// inlined(always), a different slow. plain ol' inline makes the benchmarks happiest at the moment,
// but this is fragile and the best setting changes with only minor code modifications.
#[inline]
pub(crate) fn decode_helper(
    input: &[u8],
    estimate: GeneralPurposeEstimate,
    output: &mut [u8],
    decode_table: &[u8; 256],
    decode_allow_trailing_bits: bool,
    padding_mode: DecodePaddingMode,
) -> Result<usize, DecodeError> {
    // Length of the trailing partial chunk (0 when the input is a whole number of chunks).
    let remainder_len = input.len() % INPUT_CHUNK_LEN;

    // Because the fast decode loop writes in groups of 8 bytes (unrolled to
    // CHUNKS_PER_FAST_LOOP_BLOCK times 8 bytes, where possible) and outputs 8 bytes at a time (of
    // which only 6 are valid data), we need to be sure that we stop using the fast decode loop
    // soon enough that there will always be 2 more bytes of valid data written after that loop.
    let trailing_bytes_to_skip = match remainder_len {
        // if input is a multiple of the chunk size, ignore the last chunk as it may have padding,
        // and the fast decode logic cannot handle padding
        0 => INPUT_CHUNK_LEN,
        // 1 and 5 trailing bytes are illegal: can't decode 6 bits of input into a byte
        1 | 5 => {
            // trailing whitespace is so common that it's worth it to check the last byte to
            // possibly return a better error message
            if let Some(b) = input.last() {
                if *b != PAD_BYTE && decode_table[*b as usize] == INVALID_VALUE {
                    return Err(DecodeError::InvalidByte(input.len() - 1, *b));
                }
            }

            return Err(DecodeError::InvalidLength);
        }
        // This will decode to one output byte, which isn't enough to overwrite the 2 extra bytes
        // written by the fast decode loop. So, we have to ignore both these 2 bytes and the
        // previous chunk.
        2 => INPUT_CHUNK_LEN + 2,
        // If this is 3 un-padded chars, then it would actually decode to 2 bytes. However, if this
        // is an erroneous 2 chars + 1 pad char that would decode to 1 byte, then it should fail
        // with an error, not panic from going past the bounds of the output slice, so we let it
        // use stage 3 + 4.
        3 => INPUT_CHUNK_LEN + 3,
        // This can also decode to one output byte because it may be 2 input chars + 2 padding
        // chars, which would decode to 1 byte.
        4 => INPUT_CHUNK_LEN + 4,
        // Everything else is a legal decode len (given that we don't require padding), and will
        // decode to at least 2 bytes of output.
        _ => remainder_len,
    };

    // rounded up to include partial chunks
    // (each fast/precise chunk decoded below decrements this; stage 3 runs until one is left)
    let mut remaining_chunks = estimate.num_chunks;

    let mut input_index = 0;
    let mut output_index = 0;

    {
        // Prefix of the input that the fast loops may consume without touching the tail that
        // has to go through the precise/suffix paths.
        let length_of_fast_decode_chunks = input.len().saturating_sub(trailing_bytes_to_skip);

        // Fast loop, stage 1
        // manual unroll to CHUNKS_PER_FAST_LOOP_BLOCK of u64s to amortize slice bounds checks
        if let Some(max_start_index) = length_of_fast_decode_chunks.checked_sub(INPUT_BLOCK_LEN) {
            while input_index <= max_start_index {
                let input_slice = &input[input_index..(input_index + INPUT_BLOCK_LEN)];
                let output_slice = &mut output[output_index..(output_index + DECODED_BLOCK_LEN)];

                decode_chunk(
                    &input_slice[0..],
                    input_index,
                    decode_table,
                    &mut output_slice[0..],
                )?;
                decode_chunk(
                    &input_slice[8..],
                    input_index + 8,
                    decode_table,
                    &mut output_slice[6..],
                )?;
                decode_chunk(
                    &input_slice[16..],
                    input_index + 16,
                    decode_table,
                    &mut output_slice[12..],
                )?;
                decode_chunk(
                    &input_slice[24..],
                    input_index + 24,
                    decode_table,
                    &mut output_slice[18..],
                )?;

                input_index += INPUT_BLOCK_LEN;
                // Advance only past the valid bytes; the 2 scratch bytes get overwritten next.
                output_index += DECODED_BLOCK_LEN - DECODED_CHUNK_SUFFIX;
                remaining_chunks -= CHUNKS_PER_FAST_LOOP_BLOCK;
            }
        }

        // Fast loop, stage 2 (aka still pretty fast loop)
        // 8 bytes at a time for whatever we didn't do in stage 1.
        if let Some(max_start_index) = length_of_fast_decode_chunks.checked_sub(INPUT_CHUNK_LEN) {
            // NOTE(review): unlike stage 1's `<=`, this is a strict `<`, so a chunk starting
            // exactly at `max_start_index` is left for stage 3. That looks deliberately
            // conservative — confirm before changing.
            while input_index < max_start_index {
                decode_chunk(
                    &input[input_index..(input_index + INPUT_CHUNK_LEN)],
                    input_index,
                    decode_table,
                    &mut output
                        [output_index..(output_index + DECODED_CHUNK_LEN + DECODED_CHUNK_SUFFIX)],
                )?;

                output_index += DECODED_CHUNK_LEN;
                input_index += INPUT_CHUNK_LEN;
                remaining_chunks -= 1;
            }
        }
    }

    // Stage 3
    // If input length was such that a chunk had to be deferred until after the fast loop
    // because decoding it would have produced 2 trailing bytes that wouldn't then be
    // overwritten, we decode that chunk here. This way is slower but doesn't write the 2
    // trailing bytes.
    // However, we still need to avoid the last chunk (partial or complete) because it could
    // have padding, so we always do 1 fewer to avoid the last chunk.
    for _ in 1..remaining_chunks {
        decode_chunk_precise(
            &input[input_index..],
            input_index,
            decode_table,
            &mut output[output_index..(output_index + DECODED_CHUNK_LEN)],
        )?;

        input_index += INPUT_CHUNK_LEN;
        output_index += DECODED_CHUNK_LEN;
    }

    // always have one more (possibly partial) block of 8 input
    debug_assert!(input.len() - input_index > 1 || input.is_empty());
    debug_assert!(input.len() - input_index <= 8);

    // Stage 4: the final chunk, which may be partial and/or padded.
    super::decode_suffix::decode_suffix(
        input,
        input_index,
        output,
        output_index,
        decode_table,
        decode_allow_trailing_bits,
        padding_mode,
    )
}
205 
206 /// Decode 8 bytes of input into 6 bytes of output. 8 bytes of output will be written, but only the
207 /// first 6 of those contain meaningful data.
208 ///
209 /// `input` is the bytes to decode, of which the first 8 bytes will be processed.
210 /// `index_at_start_of_input` is the offset in the overall input (used for reporting errors
211 /// accurately)
212 /// `decode_table` is the lookup table for the particular base64 alphabet.
213 /// `output` will have its first 8 bytes overwritten, of which only the first 6 are valid decoded
214 /// data.
215 // yes, really inline (worth 30-50% speedup)
216 #[inline(always)]
decode_chunk( input: &[u8], index_at_start_of_input: usize, decode_table: &[u8; 256], output: &mut [u8], ) -> Result<(), DecodeError>217 fn decode_chunk(
218     input: &[u8],
219     index_at_start_of_input: usize,
220     decode_table: &[u8; 256],
221     output: &mut [u8],
222 ) -> Result<(), DecodeError> {
223     let morsel = decode_table[input[0] as usize];
224     if morsel == INVALID_VALUE {
225         return Err(DecodeError::InvalidByte(index_at_start_of_input, input[0]));
226     }
227     let mut accum = (morsel as u64) << 58;
228 
229     let morsel = decode_table[input[1] as usize];
230     if morsel == INVALID_VALUE {
231         return Err(DecodeError::InvalidByte(
232             index_at_start_of_input + 1,
233             input[1],
234         ));
235     }
236     accum |= (morsel as u64) << 52;
237 
238     let morsel = decode_table[input[2] as usize];
239     if morsel == INVALID_VALUE {
240         return Err(DecodeError::InvalidByte(
241             index_at_start_of_input + 2,
242             input[2],
243         ));
244     }
245     accum |= (morsel as u64) << 46;
246 
247     let morsel = decode_table[input[3] as usize];
248     if morsel == INVALID_VALUE {
249         return Err(DecodeError::InvalidByte(
250             index_at_start_of_input + 3,
251             input[3],
252         ));
253     }
254     accum |= (morsel as u64) << 40;
255 
256     let morsel = decode_table[input[4] as usize];
257     if morsel == INVALID_VALUE {
258         return Err(DecodeError::InvalidByte(
259             index_at_start_of_input + 4,
260             input[4],
261         ));
262     }
263     accum |= (morsel as u64) << 34;
264 
265     let morsel = decode_table[input[5] as usize];
266     if morsel == INVALID_VALUE {
267         return Err(DecodeError::InvalidByte(
268             index_at_start_of_input + 5,
269             input[5],
270         ));
271     }
272     accum |= (morsel as u64) << 28;
273 
274     let morsel = decode_table[input[6] as usize];
275     if morsel == INVALID_VALUE {
276         return Err(DecodeError::InvalidByte(
277             index_at_start_of_input + 6,
278             input[6],
279         ));
280     }
281     accum |= (morsel as u64) << 22;
282 
283     let morsel = decode_table[input[7] as usize];
284     if morsel == INVALID_VALUE {
285         return Err(DecodeError::InvalidByte(
286             index_at_start_of_input + 7,
287             input[7],
288         ));
289     }
290     accum |= (morsel as u64) << 16;
291 
292     write_u64(output, accum);
293 
294     Ok(())
295 }
296 
297 /// Decode an 8-byte chunk, but only write the 6 bytes actually decoded instead of including 2
298 /// trailing garbage bytes.
299 #[inline]
decode_chunk_precise( input: &[u8], index_at_start_of_input: usize, decode_table: &[u8; 256], output: &mut [u8], ) -> Result<(), DecodeError>300 fn decode_chunk_precise(
301     input: &[u8],
302     index_at_start_of_input: usize,
303     decode_table: &[u8; 256],
304     output: &mut [u8],
305 ) -> Result<(), DecodeError> {
306     let mut tmp_buf = [0_u8; 8];
307 
308     decode_chunk(
309         input,
310         index_at_start_of_input,
311         decode_table,
312         &mut tmp_buf[..],
313     )?;
314 
315     output[0..6].copy_from_slice(&tmp_buf[0..6]);
316 
317     Ok(())
318 }
319 
/// Stores `value` into `output[0..8]` in big-endian order, most significant byte first.
///
/// Bytes of `output` after index 7 are untouched. Panics if `output` is shorter than 8 bytes.
#[inline]
fn write_u64(output: &mut [u8], value: u64) {
    let bytes = value.to_be_bytes();
    let (head, _rest) = output.split_at_mut(bytes.len());
    head.copy_from_slice(&bytes);
}
324 
#[cfg(test)]
mod tests {
    use super::*;

    use crate::engine::general_purpose::STANDARD;

    #[test]
    fn decode_chunk_precise_writes_only_6_bytes() {
        let input = b"Zm9vYmFy"; // "foobar"
        let mut output = [0_u8, 1, 2, 3, 4, 5, 6, 7];

        decode_chunk_precise(&input[..], 0, &STANDARD.decode_table, &mut output).unwrap();
        // The last 2 bytes keep their previous contents.
        assert_eq!(*b"foobar\x06\x07", output);
    }

    #[test]
    fn decode_chunk_writes_8_bytes() {
        let input = b"Zm9vYmFy"; // "foobar"
        let mut output = [0_u8, 1, 2, 3, 4, 5, 6, 7];

        decode_chunk(&input[..], 0, &STANDARD.decode_table, &mut output).unwrap();
        // The 2 scratch bytes after the decoded data are overwritten with zeros.
        assert_eq!(*b"foobar\0\0", output);
    }

    #[test]
    fn decode_chunk_reports_absolute_index_of_invalid_byte() {
        let input = b"Zm9v*mFy";
        let mut output = [0_u8, 1, 2, 3, 4, 5, 6, 7];

        // `*` sits at offset 4 of a chunk that starts at absolute index 10.
        assert_eq!(
            Err(DecodeError::InvalidByte(14, b'*')),
            decode_chunk(&input[..], 10, &STANDARD.decode_table, &mut output)
        );
        // Nothing is written when decoding fails.
        assert_eq!([0_u8, 1, 2, 3, 4, 5, 6, 7], output);
    }

    #[test]
    fn decode_chunk_precise_leaves_output_untouched_on_error() {
        let input = b"Zm9vYmF*";
        let mut output = [0_u8, 1, 2, 3, 4, 5, 6, 7];

        assert_eq!(
            Err(DecodeError::InvalidByte(7, b'*')),
            decode_chunk_precise(&input[..], 0, &STANDARD.decode_table, &mut output)
        );
        assert_eq!([0_u8, 1, 2, 3, 4, 5, 6, 7], output);
    }
}
349