• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use super::encoder::EncoderWriter;
2 use crate::engine::Engine;
3 use std::io;
4 
5 /// A `Write` implementation that base64-encodes data using the provided config and accumulates the
6 /// resulting base64 utf8 `&str` in a [StrConsumer] implementation (typically `String`), which is
7 /// then exposed via `into_inner()`.
8 ///
9 /// # Examples
10 ///
11 /// Buffer base64 in a new String:
12 ///
13 /// ```
14 /// use std::io::Write;
15 /// use base64::engine::general_purpose;
16 ///
17 /// let mut enc = base64::write::EncoderStringWriter::new(&general_purpose::STANDARD);
18 ///
19 /// enc.write_all(b"asdf").unwrap();
20 ///
21 /// // get the resulting String
22 /// let b64_string = enc.into_inner();
23 ///
24 /// assert_eq!("YXNkZg==", &b64_string);
25 /// ```
26 ///
27 /// Or, append to an existing `String`, which implements `StrConsumer`:
28 ///
29 /// ```
30 /// use std::io::Write;
31 /// use base64::engine::general_purpose;
32 ///
33 /// let mut buf = String::from("base64: ");
34 ///
35 /// let mut enc = base64::write::EncoderStringWriter::from_consumer(
36 ///     &mut buf,
37 ///     &general_purpose::STANDARD);
38 ///
39 /// enc.write_all(b"asdf").unwrap();
40 ///
41 /// // release the &mut reference on buf
42 /// let _ = enc.into_inner();
43 ///
44 /// assert_eq!("base64: YXNkZg==", &buf);
45 /// ```
46 ///
47 /// # Performance
48 ///
49 /// Because it has to validate that the base64 is UTF-8, it is about 80% as fast as writing plain
50 /// bytes to a `io::Write`.
51 pub struct EncoderStringWriter<'e, E: Engine, S: StrConsumer> {
52     encoder: EncoderWriter<'e, E, Utf8SingleCodeUnitWriter<S>>,
53 }
54 
55 impl<'e, E: Engine, S: StrConsumer> EncoderStringWriter<'e, E, S> {
56     /// Create a EncoderStringWriter that will append to the provided `StrConsumer`.
from_consumer(str_consumer: S, engine: &'e E) -> Self57     pub fn from_consumer(str_consumer: S, engine: &'e E) -> Self {
58         EncoderStringWriter {
59             encoder: EncoderWriter::new(Utf8SingleCodeUnitWriter { str_consumer }, engine),
60         }
61     }
62 
63     /// Encode all remaining buffered data, including any trailing incomplete input triples and
64     /// associated padding.
65     ///
66     /// Returns the base64-encoded form of the accumulated written data.
into_inner(mut self) -> S67     pub fn into_inner(mut self) -> S {
68         self.encoder
69             .finish()
70             .expect("Writing to a consumer should never fail")
71             .str_consumer
72     }
73 }
74 
75 impl<'e, E: Engine> EncoderStringWriter<'e, E, String> {
76     /// Create a EncoderStringWriter that will encode into a new `String` with the provided config.
new(engine: &'e E) -> Self77     pub fn new(engine: &'e E) -> Self {
78         EncoderStringWriter::from_consumer(String::new(), engine)
79     }
80 }
81 
82 impl<'e, E: Engine, S: StrConsumer> io::Write for EncoderStringWriter<'e, E, S> {
write(&mut self, buf: &[u8]) -> io::Result<usize>83     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
84         self.encoder.write(buf)
85     }
86 
flush(&mut self) -> io::Result<()>87     fn flush(&mut self) -> io::Result<()> {
88         self.encoder.flush()
89     }
90 }
91 
/// A destination for the text produced by base64 encoding.
///
/// Implementors receive the encoded output as one or more `&str` chunks and decide what to do
/// with them (for example, appending them to a `String`).
pub trait StrConsumer {
    /// Accepts the next chunk of base64-encoded text in `buf`.
    fn consume(&mut self, buf: &str);
}
97 
98 /// As for io::Write, `StrConsumer` is implemented automatically for `&mut S`.
99 impl<S: StrConsumer + ?Sized> StrConsumer for &mut S {
consume(&mut self, buf: &str)100     fn consume(&mut self, buf: &str) {
101         (**self).consume(buf);
102     }
103 }
104 
105 /// Pushes the str onto the end of the String
106 impl StrConsumer for String {
consume(&mut self, buf: &str)107     fn consume(&mut self, buf: &str) {
108         self.push_str(buf);
109     }
110 }
111 
112 /// A `Write` that only can handle bytes that are valid single-byte UTF-8 code units.
113 ///
114 /// This is safe because we only use it when writing base64, which is always valid UTF-8.
115 struct Utf8SingleCodeUnitWriter<S: StrConsumer> {
116     str_consumer: S,
117 }
118 
119 impl<S: StrConsumer> io::Write for Utf8SingleCodeUnitWriter<S> {
write(&mut self, buf: &[u8]) -> io::Result<usize>120     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
121         // Because we expect all input to be valid utf-8 individual bytes, we can encode any buffer
122         // length
123         let s = std::str::from_utf8(buf).expect("Input must be valid UTF-8");
124 
125         self.str_consumer.consume(s);
126 
127         Ok(buf.len())
128     }
129 
flush(&mut self) -> io::Result<()>130     fn flush(&mut self) -> io::Result<()> {
131         // no op
132         Ok(())
133     }
134 }
135 
#[cfg(test)]
mod tests {
    use crate::{
        engine::Engine, tests::random_engine, write::encoder_string_writer::EncoderStringWriter,
    };
    use rand::Rng;
    use std::cmp;
    use std::io::Write;

    /// Splitting the input into two writes at every offset, including both ends where one of the
    /// writes is empty, must produce the same output as one-shot encoding.
    #[test]
    fn every_possible_split_of_input() {
        let mut rng = rand::thread_rng();
        let mut orig_data = Vec::<u8>::new();
        let mut normal_encoded = String::new();

        let size = 5_000;

        // Inclusive range: `i == size` writes everything first and then an empty tail, which
        // `0..size` would skip.
        for i in 0..=size {
            orig_data.clear();
            normal_encoded.clear();

            orig_data.resize(size, 0);
            rng.fill(&mut orig_data[..]);

            let engine = random_engine(&mut rng);
            engine.encode_string(&orig_data, &mut normal_encoded);

            let mut stream_encoder = EncoderStringWriter::new(&engine);
            // Write the first i bytes, then the rest
            stream_encoder.write_all(&orig_data[0..i]).unwrap();
            stream_encoder.write_all(&orig_data[i..]).unwrap();

            let stream_encoded = stream_encoder.into_inner();

            assert_eq!(normal_encoded, stream_encoded);
        }
    }

    /// Writing the input in random small chunks (possibly empty), and advancing only by the
    /// byte count each `write` reports, must produce the same output as one-shot encoding.
    #[test]
    fn incremental_writes() {
        let mut rng = rand::thread_rng();
        let mut orig_data = Vec::<u8>::new();
        let mut normal_encoded = String::new();

        let size = 5_000;

        for _ in 0..size {
            orig_data.clear();
            normal_encoded.clear();

            orig_data.resize(size, 0);
            rng.fill(&mut orig_data[..]);

            let engine = random_engine(&mut rng);
            engine.encode_string(&orig_data, &mut normal_encoded);

            let mut stream_encoder = EncoderStringWriter::new(&engine);
            // write small nibbles of data
            let mut offset = 0;
            while offset < size {
                let nibble_size = cmp::min(rng.gen_range(0..=64), size - offset);
                let len = stream_encoder
                    .write(&orig_data[offset..offset + nibble_size])
                    .unwrap();
                // `write` may accept fewer bytes than offered; advance only by what it took.
                offset += len;
            }

            let stream_encoded = stream_encoder.into_inner();

            assert_eq!(normal_encoded, stream_encoded);
        }
    }
}
208