• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
//! Base64 encoding and decoding helpers, copy-pasted from the internet.
/// Selects which base64 alphabet to use.
///
/// Not referenced by the code visible in this file: `encode` always uses the
/// standard alphabet. The leading underscores keep the unused-item lints quiet.
#[derive(Debug, Clone, Copy)]
enum _CharacterSet {
    /// Standard alphabet: values 62 and 63 are written as `+` and `/`.
    _Standard,
    /// URL-safe alphabet: values 62 and 63 are written as `-` and `_`.
    _UrlSafe,
}
10 
/// Standard base64 alphabet, indexed by 6-bit value (`+` is 62, `/` is 63).
static STANDARD_CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

/// URL-safe base64 alphabet, indexed by 6-bit value (`-` is 62, `_` is 63).
static _URLSAFE_CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
18 
encode(input: &[u8]) -> String19 pub fn encode(input: &[u8]) -> String {
20     let bytes = STANDARD_CHARS;
21 
22     let len = input.len();
23 
24     // Preallocate memory.
25     let prealloc_len = (len + 2) / 3 * 4;
26     let mut out_bytes = vec![b'='; prealloc_len];
27 
28     // Deal with padding bytes
29     let mod_len = len % 3;
30 
31     // Use iterators to reduce branching
32     {
33         let mut s_in = input[..len - mod_len].iter().map(|&x| x as u32);
34         let mut s_out = out_bytes.iter_mut();
35 
36         // Convenient shorthand
37         let enc = |val| bytes[val as usize];
38         let mut write = |val| *s_out.next().unwrap() = val;
39 
40         // Iterate though blocks of 4
41         while let (Some(first), Some(second), Some(third)) = (s_in.next(), s_in.next(), s_in.next())
42         {
43             let n = first << 16 | second << 8 | third;
44 
45             // This 24-bit number gets separated into four 6-bit numbers.
46             write(enc((n >> 18) & 63));
47             write(enc((n >> 12) & 63));
48             write(enc((n >> 6) & 63));
49             write(enc((n >> 0) & 63));
50         }
51 
52         // Heh, would be cool if we knew this was exhaustive
53         // (the dream of bounded integer types)
54         match mod_len {
55             0 => (),
56             1 => {
57                 let n = (input[len - 1] as u32) << 16;
58                 write(enc((n >> 18) & 63));
59                 write(enc((n >> 12) & 63));
60             }
61             2 => {
62                 let n = (input[len - 2] as u32) << 16 | (input[len - 1] as u32) << 8;
63                 write(enc((n >> 18) & 63));
64                 write(enc((n >> 12) & 63));
65                 write(enc((n >> 6) & 63));
66             }
67             _ => panic!("Algebra is broken, please alert the math police"),
68         }
69     }
70 
71     // `out_bytes` vec is prepopulated with `=` symbols and then only updated
72     // with base64 chars, so this unsafe is safe.
73     unsafe { String::from_utf8_unchecked(out_bytes) }
74 }
75 
76 /// Errors that can occur when decoding a base64 encoded string
77 #[derive(Clone, Copy, Debug, thiserror::Error)]
78 pub enum FromBase64Error {
79     /// The input contained a character not part of the base64 format
80     #[error("Invalid base64 byte")]
81     InvalidBase64Byte(u8, usize),
82     /// The input had an invalid length
83     #[error("Invalid base64 length")]
84     InvalidBase64Length,
85 }
86 
decode(input: &str) -> Result<Vec<u8>, FromBase64Error>87 pub fn decode(input: &str) -> Result<Vec<u8>, FromBase64Error> {
88     let mut r = Vec::with_capacity(input.len());
89     let mut buf: u32 = 0;
90     let mut modulus = 0;
91 
92     let mut it = input.as_bytes().iter();
93     for byte in it.by_ref() {
94         let code = DECODE_TABLE[*byte as usize];
95         if code >= SPECIAL_CODES_START {
96             match code {
97                 NEWLINE_CODE => continue,
98                 EQUALS_CODE => break,
99                 INVALID_CODE => {
100                     return Err(FromBase64Error::InvalidBase64Byte(
101                         *byte,
102                         (byte as *const _ as usize) - input.as_ptr() as usize,
103                     ))
104                 }
105                 _ => unreachable!(),
106             }
107         }
108         buf = (buf | code as u32) << 6;
109         modulus += 1;
110         if modulus == 4 {
111             modulus = 0;
112             r.push((buf >> 22) as u8);
113             r.push((buf >> 14) as u8);
114             r.push((buf >> 6) as u8);
115         }
116     }
117 
118     for byte in it {
119         match *byte {
120             b'=' | b'\r' | b'\n' => continue,
121             _ => {
122                 return Err(FromBase64Error::InvalidBase64Byte(
123                     *byte,
124                     (byte as *const _ as usize) - input.as_ptr() as usize,
125                 ))
126             }
127         }
128     }
129 
130     match modulus {
131         2 => {
132             r.push((buf >> 10) as u8);
133         }
134         3 => {
135             r.push((buf >> 16) as u8);
136             r.push((buf >> 8) as u8);
137         }
138         0 => (),
139         _ => return Err(FromBase64Error::InvalidBase64Length),
140     }
141 
142     Ok(r)
143 }
144 
/// Table marker for bytes that are not accepted as base64 input.
const INVALID_CODE: u8 = 0xFF;
/// Table marker for the padding character `=`.
const EQUALS_CODE: u8 = 0xFE;
/// Table marker for the line-break bytes `\n` and `\r`.
const NEWLINE_CODE: u8 = 0xFD;
/// Smallest marker value; every table entry at or above it is a marker, not a sextet.
const SPECIAL_CODES_START: u8 = NEWLINE_CODE;

/// Maps every input byte to its 6-bit base64 value or to one of the marker codes.
///
/// Both alphabets are accepted: `+` and `-` map to 62, `/` and `_` map to 63.
const DECODE_TABLE: [u8; 256] = build_decode_table();

/// Builds [`DECODE_TABLE`] at compile time.
const fn build_decode_table() -> [u8; 256] {
    let mut table = [INVALID_CODE; 256];

    // `A`-`Z` take values 0-25 and `a`-`z` take values 26-51.
    let mut letter = 0;
    while letter < 26 {
        table[b'A' as usize + letter] = letter as u8;
        table[b'a' as usize + letter] = 26 + letter as u8;
        letter += 1;
    }

    // `0`-`9` take values 52-61.
    let mut digit = 0;
    while digit < 10 {
        table[b'0' as usize + digit] = 52 + digit as u8;
        digit += 1;
    }

    // Values 62 and 63 in the standard and the URL-safe spelling.
    table[b'+' as usize] = 62;
    table[b'-' as usize] = 62;
    table[b'/' as usize] = 63;
    table[b'_' as usize] = 63;

    table[b'=' as usize] = EQUALS_CODE;
    table[b'\n' as usize] = NEWLINE_CODE;
    table[b'\r' as usize] = NEWLINE_CODE;

    table
}
167 
#[cfg(test)]
mod tests {
    use super::*;

    /// `(plain text, padded base64)` pairs shared by the encode and decode tests.
    const VECTORS: [(&str, &str); 7] = [
        ("", ""),
        ("f", "Zg=="),
        ("fo", "Zm8="),
        ("foo", "Zm9v"),
        ("foob", "Zm9vYg=="),
        ("fooba", "Zm9vYmE="),
        ("foobar", "Zm9vYmFy"),
    ];

    #[test]
    fn test_encode_basic() {
        for &(plain, encoded) in VECTORS.iter() {
            assert_eq!(encode(plain.as_bytes()), encoded);
        }
    }

    #[test]
    fn test_encode_standard_safe() {
        // These bytes produce the sextets 62 and 63, i.e. the alphabet-specific symbols.
        let high_bits = [251, 255];
        assert_eq!(encode(&high_bits), "+/8=");
    }

    #[test]
    fn test_decode_basic() {
        for &(plain, encoded) in VECTORS.iter() {
            assert_eq!(decode(encoded).unwrap(), plain.as_bytes());
        }
    }

    #[test]
    fn test_decode() {
        let decoded = decode("Zm9vYmFy").unwrap();
        assert_eq!(decoded, b"foobar");
    }

    #[test]
    fn test_decode_newlines() {
        // Line breaks are tolerated both inside the data and after the padding.
        let cases = [
            ("Zm9v\r\nYmFy", "foobar"),
            ("Zm9vYg==\r\n", "foob"),
            ("Zm9v\nYmFy", "foobar"),
            ("Zm9vYg==\n", "foob"),
        ];
        for &(encoded, plain) in cases.iter() {
            assert_eq!(decode(encoded).unwrap(), plain.as_bytes());
        }
    }

    #[test]
    fn test_decode_urlsafe() {
        let url_safe = decode("-_8").unwrap();
        let standard = decode("+/8=").unwrap();
        assert_eq!(url_safe, standard);
    }

    #[test]
    fn test_from_base64_invalid_char() {
        // An invalid byte is rejected both in the data and after the padding.
        for encoded in ["Zm$=", "Zg==$"].iter() {
            assert!(decode(encoded).is_err());
        }
    }

    #[test]
    fn test_decode_invalid_padding() {
        // A single symbol before the padding cannot encode a whole byte.
        let result = decode("Z===");
        assert!(result.is_err());
    }
}
228