1 use super::encoder::EncoderWriter; 2 use crate::engine::Engine; 3 use std::io; 4 5 /// A `Write` implementation that base64-encodes data using the provided config and accumulates the 6 /// resulting base64 utf8 `&str` in a [StrConsumer] implementation (typically `String`), which is 7 /// then exposed via `into_inner()`. 8 /// 9 /// # Examples 10 /// 11 /// Buffer base64 in a new String: 12 /// 13 /// ``` 14 /// use std::io::Write; 15 /// use base64::engine::general_purpose; 16 /// 17 /// let mut enc = base64::write::EncoderStringWriter::new(&general_purpose::STANDARD); 18 /// 19 /// enc.write_all(b"asdf").unwrap(); 20 /// 21 /// // get the resulting String 22 /// let b64_string = enc.into_inner(); 23 /// 24 /// assert_eq!("YXNkZg==", &b64_string); 25 /// ``` 26 /// 27 /// Or, append to an existing `String`, which implements `StrConsumer`: 28 /// 29 /// ``` 30 /// use std::io::Write; 31 /// use base64::engine::general_purpose; 32 /// 33 /// let mut buf = String::from("base64: "); 34 /// 35 /// let mut enc = base64::write::EncoderStringWriter::from_consumer( 36 /// &mut buf, 37 /// &general_purpose::STANDARD); 38 /// 39 /// enc.write_all(b"asdf").unwrap(); 40 /// 41 /// // release the &mut reference on buf 42 /// let _ = enc.into_inner(); 43 /// 44 /// assert_eq!("base64: YXNkZg==", &buf); 45 /// ``` 46 /// 47 /// # Performance 48 /// 49 /// Because it has to validate that the base64 is UTF-8, it is about 80% as fast as writing plain 50 /// bytes to a `io::Write`. 51 pub struct EncoderStringWriter<'e, E: Engine, S: StrConsumer> { 52 encoder: EncoderWriter<'e, E, Utf8SingleCodeUnitWriter<S>>, 53 } 54 55 impl<'e, E: Engine, S: StrConsumer> EncoderStringWriter<'e, E, S> { 56 /// Create a EncoderStringWriter that will append to the provided `StrConsumer`. from_consumer(str_consumer: S, engine: &'e E) -> Self57 pub fn from_consumer(str_consumer: S, engine: &'e E) -> Self { 58 EncoderStringWriter { 59 encoder: EncoderWriter::new(Utf8SingleCodeUnitWriter { str_consumer }, engine), 60 } 61 } 62 63 /// Encode all remaining buffered data, including any trailing incomplete input triples and 64 /// associated padding. 65 /// 66 /// Returns the base64-encoded form of the accumulated written data. into_inner(mut self) -> S67 pub fn into_inner(mut self) -> S { 68 self.encoder 69 .finish() 70 .expect("Writing to a consumer should never fail") 71 .str_consumer 72 } 73 } 74 75 impl<'e, E: Engine> EncoderStringWriter<'e, E, String> { 76 /// Create a EncoderStringWriter that will encode into a new `String` with the provided config. new(engine: &'e E) -> Self77 pub fn new(engine: &'e E) -> Self { 78 EncoderStringWriter::from_consumer(String::new(), engine) 79 } 80 } 81 82 impl<'e, E: Engine, S: StrConsumer> io::Write for EncoderStringWriter<'e, E, S> { write(&mut self, buf: &[u8]) -> io::Result<usize>83 fn write(&mut self, buf: &[u8]) -> io::Result<usize> { 84 self.encoder.write(buf) 85 } 86 flush(&mut self) -> io::Result<()>87 fn flush(&mut self) -> io::Result<()> { 88 self.encoder.flush() 89 } 90 } 91 92 /// An abstraction around consuming `str`s produced by base64 encoding. 93 pub trait StrConsumer { 94 /// Consume the base64 encoded data in `buf` consume(&mut self, buf: &str)95 fn consume(&mut self, buf: &str); 96 } 97 98 /// As for io::Write, `StrConsumer` is implemented automatically for `&mut S`. 99 impl<S: StrConsumer + ?Sized> StrConsumer for &mut S { consume(&mut self, buf: &str)100 fn consume(&mut self, buf: &str) { 101 (**self).consume(buf); 102 } 103 } 104 105 /// Pushes the str onto the end of the String 106 impl StrConsumer for String { consume(&mut self, buf: &str)107 fn consume(&mut self, buf: &str) { 108 self.push_str(buf); 109 } 110 } 111 112 /// A `Write` that only can handle bytes that are valid single-byte UTF-8 code units. 113 /// 114 /// This is safe because we only use it when writing base64, which is always valid UTF-8. 115 struct Utf8SingleCodeUnitWriter<S: StrConsumer> { 116 str_consumer: S, 117 } 118 119 impl<S: StrConsumer> io::Write for Utf8SingleCodeUnitWriter<S> { write(&mut self, buf: &[u8]) -> io::Result<usize>120 fn write(&mut self, buf: &[u8]) -> io::Result<usize> { 121 // Because we expect all input to be valid utf-8 individual bytes, we can encode any buffer 122 // length 123 let s = std::str::from_utf8(buf).expect("Input must be valid UTF-8"); 124 125 self.str_consumer.consume(s); 126 127 Ok(buf.len()) 128 } 129 flush(&mut self) -> io::Result<()>130 fn flush(&mut self) -> io::Result<()> { 131 // no op 132 Ok(()) 133 } 134 } 135 136 #[cfg(test)] 137 mod tests { 138 use crate::{ 139 engine::Engine, tests::random_engine, write::encoder_string_writer::EncoderStringWriter, 140 }; 141 use rand::Rng; 142 use std::cmp; 143 use std::io::Write; 144 145 #[test] every_possible_split_of_input()146 fn every_possible_split_of_input() { 147 let mut rng = rand::thread_rng(); 148 let mut orig_data = Vec::<u8>::new(); 149 let mut normal_encoded = String::new(); 150 151 let size = 5_000; 152 153 for i in 0..size { 154 orig_data.clear(); 155 normal_encoded.clear(); 156 157 orig_data.resize(size, 0); 158 rng.fill(&mut orig_data[..]); 159 160 let engine = random_engine(&mut rng); 161 engine.encode_string(&orig_data, &mut normal_encoded); 162 163 let mut stream_encoder = EncoderStringWriter::new(&engine); 164 // Write the first i bytes, then the rest 165 stream_encoder.write_all(&orig_data[0..i]).unwrap(); 166 stream_encoder.write_all(&orig_data[i..]).unwrap(); 167 168 let stream_encoded = stream_encoder.into_inner(); 169 170 assert_eq!(normal_encoded, stream_encoded); 171 } 172 } 173 #[test] incremental_writes()174 fn incremental_writes() { 175 let mut rng = rand::thread_rng(); 176 let mut orig_data = Vec::<u8>::new(); 177 let mut normal_encoded = String::new(); 178 179 let size = 5_000; 180 181 for _ in 0..size { 182 orig_data.clear(); 183 normal_encoded.clear(); 184 185 orig_data.resize(size, 0); 186 rng.fill(&mut orig_data[..]); 187 188 let engine = random_engine(&mut rng); 189 engine.encode_string(&orig_data, &mut normal_encoded); 190 191 let mut stream_encoder = EncoderStringWriter::new(&engine); 192 // write small nibbles of data 193 let mut offset = 0; 194 while offset < size { 195 let nibble_size = cmp::min(rng.gen_range(0..=64), size - offset); 196 let len = stream_encoder 197 .write(&orig_data[offset..offset + nibble_size]) 198 .unwrap(); 199 offset += len; 200 } 201 202 let stream_encoded = stream_encoder.into_inner(); 203 204 assert_eq!(normal_encoded, stream_encoded); 205 } 206 } 207 } 208