1 use crate::{
2 engine::{general_purpose::INVALID_VALUE, DecodeEstimate, DecodePaddingMode},
3 DecodeError, PAD_BYTE,
4 };
5
// The decode logic operates on chunks of 8 input bytes without padding; each such chunk decodes
// to 6 output bytes.
const INPUT_CHUNK_LEN: usize = 8;
const DECODED_CHUNK_LEN: usize = 6;

// The fast path consumes 8 input bytes and writes a whole u64, but 8 bytes of input only yield
// 6 bytes of output, so the last 2 bytes of any output u64 must not be counted as written
// (although they must still be available in the output slice).
const DECODED_CHUNK_SUFFIX: usize = 2;

// How many u64s' worth of input to handle per iteration of the unrolled fast loop.
const CHUNKS_PER_FAST_LOOP_BLOCK: usize = 4;

// Number of input bytes consumed by one iteration of the unrolled fast loop.
const INPUT_BLOCK_LEN: usize = CHUNKS_PER_FAST_LOOP_BLOCK * INPUT_CHUNK_LEN;

// Number of output bytes touched by one iteration of the unrolled fast loop. This includes the
// trailing 2 bytes of the final u64 write.
const DECODED_BLOCK_LEN: usize =
    CHUNKS_PER_FAST_LOOP_BLOCK * DECODED_CHUNK_LEN + DECODED_CHUNK_SUFFIX;
23
24 #[doc(hidden)]
25 pub struct GeneralPurposeEstimate {
26 /// Total number of decode chunks, including a possibly partial last chunk
27 num_chunks: usize,
28 decoded_len_estimate: usize,
29 }
30
31 impl GeneralPurposeEstimate {
new(encoded_len: usize) -> Self32 pub(crate) fn new(encoded_len: usize) -> Self {
33 Self {
34 num_chunks: encoded_len
35 .checked_add(INPUT_CHUNK_LEN - 1)
36 .expect("Overflow when calculating number of chunks in input")
37 / INPUT_CHUNK_LEN,
38 decoded_len_estimate: encoded_len
39 .checked_add(3)
40 .expect("Overflow when calculating decoded len estimate")
41 / 4
42 * 3,
43 }
44 }
45 }
46
47 impl DecodeEstimate for GeneralPurposeEstimate {
decoded_len_estimate(&self) -> usize48 fn decoded_len_estimate(&self) -> usize {
49 self.decoded_len_estimate
50 }
51 }
52
/// Helper to avoid duplicating num_chunks calculation, which is costly on short inputs.
/// Returns the number of bytes written, or an error.
///
/// Decoding proceeds in four stages:
///
/// 1. an unrolled fast loop that handles `INPUT_BLOCK_LEN` input bytes per iteration,
/// 2. a fast loop that handles one `INPUT_CHUNK_LEN`-byte chunk per iteration,
/// 3. a precise loop (`decode_chunk_precise`) for the remaining complete chunks except the last,
/// 4. `decode_suffix` for the final, possibly partial, chunk.
///
/// Parameters:
///
/// * `input` is the encoded bytes to decode.
/// * `estimate` supplies the total chunk count (`num_chunks`).
///   NOTE(review): presumably built from `input.len()`; confirm at the call sites.
/// * `output` is the destination buffer. The fast loops write 8 bytes per chunk, of which only
///   the first 6 are decoded data.
/// * `decode_table` maps each input byte to its 6-bit value, or to `INVALID_VALUE`.
/// * `decode_allow_trailing_bits` and `padding_mode` are passed through unchanged to
///   `decode_suffix`.
///
/// # Errors
///
/// * When `input.len() % INPUT_CHUNK_LEN` is 1 or 5: `DecodeError::InvalidByte` for the last
///   input byte if it is neither `PAD_BYTE` nor a valid symbol, otherwise
///   `DecodeError::InvalidLength`.
/// * `DecodeError::InvalidByte` for the first invalid symbol found in a chunk decoded in stages
///   1 to 3.
/// * Whatever error `decode_suffix` returns for the final chunk.
///
/// # Panics
///
/// The output slicing below panics if `output` is too short for the writes performed.
// We're on the fragile edge of compiler heuristics here. If this is not inlined, slow. If this is
// inlined(always), a different slow. plain ol' inline makes the benchmarks happiest at the moment,
// but this is fragile and the best setting changes with only minor code modifications.
#[inline]
pub(crate) fn decode_helper(
    input: &[u8],
    estimate: GeneralPurposeEstimate,
    output: &mut [u8],
    decode_table: &[u8; 256],
    decode_allow_trailing_bits: bool,
    padding_mode: DecodePaddingMode,
) -> Result<usize, DecodeError> {
    let remainder_len = input.len() % INPUT_CHUNK_LEN;

    // Because the fast decode loop reads in groups of 8 bytes (unrolled to
    // CHUNKS_PER_FAST_LOOP_BLOCK times 8 bytes, where possible) and outputs 8 bytes at a time (of
    // which only 6 are valid data), we need to be sure that we stop using the fast decode loop
    // soon enough that there will always be 2 more bytes of valid data written after that loop.
    let trailing_bytes_to_skip = match remainder_len {
        // if input is a multiple of the chunk size, ignore the last chunk as it may have padding,
        // and the fast decode logic cannot handle padding
        0 => INPUT_CHUNK_LEN,
        // 1 and 5 trailing bytes are illegal: can't decode 6 bits of input into a byte
        1 | 5 => {
            // trailing whitespace is so common that it's worth it to check the last byte to
            // possibly return a better error message
            if let Some(b) = input.last() {
                if *b != PAD_BYTE && decode_table[*b as usize] == INVALID_VALUE {
                    return Err(DecodeError::InvalidByte(input.len() - 1, *b));
                }
            }

            return Err(DecodeError::InvalidLength);
        }
        // This will decode to one output byte, which isn't enough to overwrite the 2 extra bytes
        // written by the fast decode loop. So, we have to ignore both these 2 bytes and the
        // previous chunk.
        2 => INPUT_CHUNK_LEN + 2,
        // If this is 3 un-padded chars, then it would actually decode to 2 bytes. However, if this
        // is an erroneous 2 chars + 1 pad char that would decode to 1 byte, then it should fail
        // with an error, not panic from going past the bounds of the output slice, so we let it
        // use stage 3 + 4.
        3 => INPUT_CHUNK_LEN + 3,
        // This can also decode to one output byte because it may be 2 input chars + 2 padding
        // chars, which would decode to 1 byte.
        4 => INPUT_CHUNK_LEN + 4,
        // Everything else is a legal decode len (given that we don't require padding), and will
        // decode to at least 2 bytes of output.
        _ => remainder_len,
    };

    // rounded up to include partial chunks
    let mut remaining_chunks = estimate.num_chunks;

    let mut input_index = 0;
    let mut output_index = 0;

    {
        // Number of leading input bytes the fast loops may consume; saturates to 0 when the input
        // is shorter than the bytes that must be skipped.
        let length_of_fast_decode_chunks = input.len().saturating_sub(trailing_bytes_to_skip);

        // Fast loop, stage 1
        // manual unroll to CHUNKS_PER_FAST_LOOP_BLOCK of u64s to amortize slice bounds checks
        if let Some(max_start_index) = length_of_fast_decode_chunks.checked_sub(INPUT_BLOCK_LEN) {
            while input_index <= max_start_index {
                let input_slice = &input[input_index..(input_index + INPUT_BLOCK_LEN)];
                let output_slice = &mut output[output_index..(output_index + DECODED_BLOCK_LEN)];

                // Each chunk's output starts 6 bytes after the previous one, so the 2 filler
                // bytes of one write are overwritten by the next chunk's data.
                decode_chunk(
                    &input_slice[0..],
                    input_index,
                    decode_table,
                    &mut output_slice[0..],
                )?;
                decode_chunk(
                    &input_slice[8..],
                    input_index + 8,
                    decode_table,
                    &mut output_slice[6..],
                )?;
                decode_chunk(
                    &input_slice[16..],
                    input_index + 16,
                    decode_table,
                    &mut output_slice[12..],
                )?;
                decode_chunk(
                    &input_slice[24..],
                    input_index + 24,
                    decode_table,
                    &mut output_slice[18..],
                )?;

                input_index += INPUT_BLOCK_LEN;
                // Only the decoded bytes count as written, not the 2-byte suffix.
                output_index += DECODED_BLOCK_LEN - DECODED_CHUNK_SUFFIX;
                remaining_chunks -= CHUNKS_PER_FAST_LOOP_BLOCK;
            }
        }

        // Fast loop, stage 2 (aka still pretty fast loop)
        // 8 bytes at a time for whatever we didn't do in stage 1.
        if let Some(max_start_index) = length_of_fast_decode_chunks.checked_sub(INPUT_CHUNK_LEN) {
            // Note the strict `<`: a chunk starting exactly at `max_start_index` is not decoded
            // here and is left for the later stages.
            while input_index < max_start_index {
                decode_chunk(
                    &input[input_index..(input_index + INPUT_CHUNK_LEN)],
                    input_index,
                    decode_table,
                    &mut output
                        [output_index..(output_index + DECODED_CHUNK_LEN + DECODED_CHUNK_SUFFIX)],
                )?;

                output_index += DECODED_CHUNK_LEN;
                input_index += INPUT_CHUNK_LEN;
                remaining_chunks -= 1;
            }
        }
    }

    // Stage 3
    // If input length was such that a chunk had to be deferred until after the fast loop
    // because decoding it would have produced 2 trailing bytes that wouldn't then be
    // overwritten, we decode that chunk here. This way is slower but doesn't write the 2
    // trailing bytes.
    // However, we still need to avoid the last chunk (partial or complete) because it could
    // have padding, so we always do 1 fewer to avoid the last chunk.
    for _ in 1..remaining_chunks {
        decode_chunk_precise(
            &input[input_index..],
            input_index,
            decode_table,
            &mut output[output_index..(output_index + DECODED_CHUNK_LEN)],
        )?;

        input_index += INPUT_CHUNK_LEN;
        output_index += DECODED_CHUNK_LEN;
    }

    // always have one more (possibly partial) block of 8 input
    debug_assert!(input.len() - input_index > 1 || input.is_empty());
    debug_assert!(input.len() - input_index <= 8);

    // Stage 4
    // Hand the final (possibly partial) chunk to decode_suffix, together with the positions
    // reached so far; its result is this function's result.
    super::decode_suffix::decode_suffix(
        input,
        input_index,
        output,
        output_index,
        decode_table,
        decode_allow_trailing_bits,
        padding_mode,
    )
}
205
206 /// Decode 8 bytes of input into 6 bytes of output. 8 bytes of output will be written, but only the
207 /// first 6 of those contain meaningful data.
208 ///
209 /// `input` is the bytes to decode, of which the first 8 bytes will be processed.
210 /// `index_at_start_of_input` is the offset in the overall input (used for reporting errors
211 /// accurately)
212 /// `decode_table` is the lookup table for the particular base64 alphabet.
213 /// `output` will have its first 8 bytes overwritten, of which only the first 6 are valid decoded
214 /// data.
215 // yes, really inline (worth 30-50% speedup)
216 #[inline(always)]
decode_chunk( input: &[u8], index_at_start_of_input: usize, decode_table: &[u8; 256], output: &mut [u8], ) -> Result<(), DecodeError>217 fn decode_chunk(
218 input: &[u8],
219 index_at_start_of_input: usize,
220 decode_table: &[u8; 256],
221 output: &mut [u8],
222 ) -> Result<(), DecodeError> {
223 let morsel = decode_table[input[0] as usize];
224 if morsel == INVALID_VALUE {
225 return Err(DecodeError::InvalidByte(index_at_start_of_input, input[0]));
226 }
227 let mut accum = (morsel as u64) << 58;
228
229 let morsel = decode_table[input[1] as usize];
230 if morsel == INVALID_VALUE {
231 return Err(DecodeError::InvalidByte(
232 index_at_start_of_input + 1,
233 input[1],
234 ));
235 }
236 accum |= (morsel as u64) << 52;
237
238 let morsel = decode_table[input[2] as usize];
239 if morsel == INVALID_VALUE {
240 return Err(DecodeError::InvalidByte(
241 index_at_start_of_input + 2,
242 input[2],
243 ));
244 }
245 accum |= (morsel as u64) << 46;
246
247 let morsel = decode_table[input[3] as usize];
248 if morsel == INVALID_VALUE {
249 return Err(DecodeError::InvalidByte(
250 index_at_start_of_input + 3,
251 input[3],
252 ));
253 }
254 accum |= (morsel as u64) << 40;
255
256 let morsel = decode_table[input[4] as usize];
257 if morsel == INVALID_VALUE {
258 return Err(DecodeError::InvalidByte(
259 index_at_start_of_input + 4,
260 input[4],
261 ));
262 }
263 accum |= (morsel as u64) << 34;
264
265 let morsel = decode_table[input[5] as usize];
266 if morsel == INVALID_VALUE {
267 return Err(DecodeError::InvalidByte(
268 index_at_start_of_input + 5,
269 input[5],
270 ));
271 }
272 accum |= (morsel as u64) << 28;
273
274 let morsel = decode_table[input[6] as usize];
275 if morsel == INVALID_VALUE {
276 return Err(DecodeError::InvalidByte(
277 index_at_start_of_input + 6,
278 input[6],
279 ));
280 }
281 accum |= (morsel as u64) << 22;
282
283 let morsel = decode_table[input[7] as usize];
284 if morsel == INVALID_VALUE {
285 return Err(DecodeError::InvalidByte(
286 index_at_start_of_input + 7,
287 input[7],
288 ));
289 }
290 accum |= (morsel as u64) << 16;
291
292 write_u64(output, accum);
293
294 Ok(())
295 }
296
297 /// Decode an 8-byte chunk, but only write the 6 bytes actually decoded instead of including 2
298 /// trailing garbage bytes.
299 #[inline]
decode_chunk_precise( input: &[u8], index_at_start_of_input: usize, decode_table: &[u8; 256], output: &mut [u8], ) -> Result<(), DecodeError>300 fn decode_chunk_precise(
301 input: &[u8],
302 index_at_start_of_input: usize,
303 decode_table: &[u8; 256],
304 output: &mut [u8],
305 ) -> Result<(), DecodeError> {
306 let mut tmp_buf = [0_u8; 8];
307
308 decode_chunk(
309 input,
310 index_at_start_of_input,
311 decode_table,
312 &mut tmp_buf[..],
313 )?;
314
315 output[0..6].copy_from_slice(&tmp_buf[0..6]);
316
317 Ok(())
318 }
319
/// Store `value` in big-endian byte order in the first 8 bytes of `output`.
///
/// Bytes of `output` past the first 8 are left untouched. Panics if `output` holds fewer than
/// 8 bytes.
#[inline]
fn write_u64(output: &mut [u8], value: u64) {
    let big_endian = value.to_be_bytes();
    let destination = &mut output[..big_endian.len()];
    destination.copy_from_slice(&big_endian);
}
324
#[cfg(test)]
mod tests {
    use super::*;

    use crate::engine::general_purpose::STANDARD;

    // Exactly one 8-byte chunk of encoded input: "foobar".
    const FOOBAR_ENCODED: &[u8] = b"Zm9vYmFy";

    #[test]
    fn decode_chunk_precise_writes_only_6_bytes() {
        let mut output = [0_u8, 1, 2, 3, 4, 5, 6, 7];

        decode_chunk_precise(FOOBAR_ENCODED, 0, &STANDARD.decode_table, &mut output).unwrap();

        assert_eq!(&b"foobar"[..], &output[..6]);
        // The sentinel values past the decoded data must survive.
        assert_eq!(&[6_u8, 7][..], &output[6..]);
    }

    #[test]
    fn decode_chunk_writes_8_bytes() {
        let mut output = [0_u8, 1, 2, 3, 4, 5, 6, 7];

        decode_chunk(FOOBAR_ENCODED, 0, &STANDARD.decode_table, &mut output).unwrap();

        assert_eq!(&b"foobar"[..], &output[..6]);
        // The full u64 write replaces the sentinels with zero filler.
        assert_eq!(&[0_u8, 0][..], &output[6..]);
    }
}
349