• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use std::fmt;
2 use std::io::{self, Read};
3 use std::str::{self, FromStr};
4 
/// Reasons why reading a single character from a byte stream can fail.
#[derive(Debug)]
pub enum CharReadError {
    /// The stream ended after some, but not all, bytes of a character were read.
    UnexpectedEof,
    /// The bytes read do not form valid UTF-8.
    Utf8(str::Utf8Error),
    /// The underlying reader failed, or the data is invalid for a non-UTF-8
    /// encoding (reported with `io::ErrorKind::InvalidData`).
    Io(io::Error),
}
11 
12 impl From<str::Utf8Error> for CharReadError {
13     #[cold]
from(e: str::Utf8Error) -> CharReadError14     fn from(e: str::Utf8Error) -> CharReadError {
15         CharReadError::Utf8(e)
16     }
17 }
18 
19 impl From<io::Error> for CharReadError {
20     #[cold]
from(e: io::Error) -> CharReadError21     fn from(e: io::Error) -> CharReadError {
22         CharReadError::Io(e)
23     }
24 }
25 
26 impl fmt::Display for CharReadError {
27     #[cold]
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result28     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
29         use self::CharReadError::{Io, UnexpectedEof, Utf8};
30         match *self {
31             UnexpectedEof => write!(f, "unexpected end of stream"),
32             Utf8(ref e) => write!(f, "UTF-8 decoding error: {e}"),
33             Io(ref e) => write!(f, "I/O error: {e}"),
34         }
35     }
36 }
37 
/// Character encoding used for parsing
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[non_exhaustive]
pub enum Encoding {
    /// Explicitly UTF-8 only
    Utf8,
    /// UTF-8 fallback, but can be any 8-bit encoding
    Default,
    /// ISO-8859-1
    Latin1,
    /// US-ASCII
    Ascii,
    /// UTF-16, big-endian byte order
    Utf16Be,
    /// UTF-16, little-endian byte order
    Utf16Le,
    /// UTF-16 whose byte order is not known yet and will be sniffed
    Utf16,
    /// Not determined yet, may be sniffed to be anything
    Unknown,
}
59 
/// Returns `true` when `varcase` equals `lower`, ignoring ASCII case.
///
/// Only `varcase` is lowercased, so `lower` must already be lowercase.
/// The lengths are compared first: `zip` stops at the shorter input, so
/// without that check any prefix of `lower` (including the empty string)
/// would compare equal.
// Rustc inlines eq_ignore_ascii_case and creates kilobytes of code!
#[inline(never)]
fn icmp(lower: &str, varcase: &str) -> bool {
    lower.len() == varcase.len()
        && lower.bytes().zip(varcase.bytes()).all(|(l, v)| l == v.to_ascii_lowercase())
}
65 
66 impl FromStr for Encoding {
67     type Err = &'static str;
68 
from_str(val: &str) -> Result<Self, Self::Err>69     fn from_str(val: &str) -> Result<Self, Self::Err> {
70         if ["utf-8", "utf8"].into_iter().any(move |label| icmp(label, val)) {
71             Ok(Encoding::Utf8)
72         } else if ["iso-8859-1", "latin1"].into_iter().any(move |label| icmp(label, val)) {
73             Ok(Encoding::Latin1)
74         } else if ["utf-16", "utf16"].into_iter().any(move |label| icmp(label, val)) {
75             Ok(Encoding::Utf16)
76         } else if ["ascii", "us-ascii"].into_iter().any(move |label| icmp(label, val)) {
77             Ok(Encoding::Ascii)
78         } else {
79             Err("unknown encoding name")
80         }
81     }
82 }
83 
84 impl fmt::Display for Encoding {
85     #[cold]
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result86     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
87         f.write_str(match self {
88             Encoding::Utf8 => "UTF-8",
89             Encoding::Default => "UTF-8",
90             Encoding::Latin1 => "ISO-8859-1",
91             Encoding::Ascii => "US-ASCII",
92             Encoding::Utf16Be => "UTF-16",
93             Encoding::Utf16Le => "UTF-16",
94             Encoding::Utf16 => "UTF-16",
95             Encoding::Unknown => "(unknown)",
96         })
97     }
98 }
99 
100 pub(crate) struct CharReader {
101     pub encoding: Encoding,
102 }
103 
104 impl CharReader {
new() -> Self105     pub fn new() -> Self {
106         Self {
107             encoding: Encoding::Unknown,
108         }
109     }
110 
next_char_from<R: Read>(&mut self, source: &mut R) -> Result<Option<char>, CharReadError>111     pub fn next_char_from<R: Read>(&mut self, source: &mut R) -> Result<Option<char>, CharReadError> {
112         let mut bytes = source.bytes();
113         const MAX_CODEPOINT_LEN: usize = 4;
114 
115         let mut buf = [0u8; MAX_CODEPOINT_LEN];
116         let mut pos = 0;
117         loop {
118             let next = match bytes.next() {
119                 Some(Ok(b)) => b,
120                 Some(Err(e)) => return Err(e.into()),
121                 None if pos == 0 => return Ok(None),
122                 None => return Err(CharReadError::UnexpectedEof),
123             };
124 
125             match self.encoding {
126                 Encoding::Utf8 | Encoding::Default => {
127                     // fast path for ASCII subset
128                     if pos == 0 && next.is_ascii() {
129                         return Ok(Some(next.into()));
130                     }
131 
132                     buf[pos] = next;
133                     pos += 1;
134 
135                     match str::from_utf8(&buf[..pos]) {
136                         Ok(s) => return Ok(s.chars().next()), // always Some(..)
137                         Err(_) if pos < MAX_CODEPOINT_LEN => continue,
138                         Err(e) => return Err(e.into()),
139                     }
140                 },
141                 Encoding::Latin1 => {
142                     return Ok(Some(next.into()));
143                 },
144                 Encoding::Ascii => {
145                     if next.is_ascii() {
146                         return Ok(Some(next.into()));
147                     } else {
148                         return Err(CharReadError::Io(io::Error::new(io::ErrorKind::InvalidData, "char is not ASCII")));
149                     }
150                 },
151                 Encoding::Unknown | Encoding::Utf16 => {
152                     buf[pos] = next;
153                     pos += 1;
154 
155                     // sniff BOM
156                     if pos <= 3 && buf[..pos] == [0xEF, 0xBB, 0xBF][..pos] {
157                         if pos == 3 && self.encoding != Encoding::Utf16 {
158                             pos = 0;
159                             self.encoding = Encoding::Utf8;
160                         }
161                     } else if pos <= 2 && buf[..pos] == [0xFE, 0xFF][..pos] {
162                         if pos == 2 {
163                             pos = 0;
164                             self.encoding = Encoding::Utf16Be;
165                         }
166                     } else if pos <= 2 && buf[..pos] == [0xFF, 0xFE][..pos] {
167                         if pos == 2 {
168                             pos = 0;
169                             self.encoding = Encoding::Utf16Le;
170                         }
171                     } else if pos == 1 && self.encoding == Encoding::Utf16 {
172                         // sniff ASCII char in UTF-16
173                         self.encoding = if next == 0 { Encoding::Utf16Be } else { Encoding::Utf16Le };
174                     } else {
175                         // UTF-8 is the default, but XML decl can change it to other 8-bit encoding
176                         self.encoding = Encoding::Default;
177                         if pos == 1 && next.is_ascii() {
178                             return Ok(Some(next.into()));
179                         }
180                     }
181                 },
182                 Encoding::Utf16Be => {
183                     buf[pos] = next;
184                     pos += 1;
185                     if pos == 2 {
186                         if let Some(Ok(c)) = char::decode_utf16([u16::from_be_bytes(buf[..2].try_into().unwrap())]).next() {
187                             return Ok(Some(c));
188                         }
189                     } else if pos == 4 { // surrogate
190                         return char::decode_utf16([u16::from_be_bytes(buf[..2].try_into().unwrap()), u16::from_be_bytes(buf[2..4].try_into().unwrap())])
191                             .next().transpose()
192                             .map_err(|e| CharReadError::Io(io::Error::new(io::ErrorKind::InvalidData, e)));
193                     }
194                 },
195                 Encoding::Utf16Le => {
196                     buf[pos] = next;
197                     pos += 1;
198                     if pos == 2 {
199                         if let Some(Ok(c)) = char::decode_utf16([u16::from_le_bytes(buf[..2].try_into().unwrap())]).next() {
200                             return Ok(Some(c));
201                         }
202                     } else if pos == 4 { // surrogate
203                         return char::decode_utf16([u16::from_le_bytes(buf[..2].try_into().unwrap()), u16::from_le_bytes(buf[2..4].try_into().unwrap())])
204                             .next().transpose()
205                             .map_err(|e| CharReadError::Io(io::Error::new(io::ErrorKind::InvalidData, e)));
206                     }
207                 },
208             }
209         }
210     }
211 }
212 
#[cfg(test)]
mod tests {
    use super::{CharReadError, CharReader, Encoding};

    /// Exercises `CharReader::next_char_from` on the first character of
    /// various inputs: BOM handling, UTF-16 sniffing, and error reporting.
    #[test]
    fn test_next_char_from() {
        use std::io;

        let mut bytes: &[u8] = "correct".as_bytes();    // correct ASCII
        assert_eq!(CharReader::new().next_char_from(&mut bytes).unwrap(), Some('c'));

        let mut bytes: &[u8] = b"\xEF\xBB\xBF\xE2\x80\xA2!";  // BOM
        assert_eq!(CharReader::new().next_char_from(&mut bytes).unwrap(), Some('•'));

        let mut bytes: &[u8] = b"\xEF\xBB\xBFx123";  // BOM
        assert_eq!(CharReader::new().next_char_from(&mut bytes).unwrap(), Some('x'));

        let mut bytes: &[u8] = b"\xEF\xBB\xBF";  // Nothing after BOM
        assert_eq!(CharReader::new().next_char_from(&mut bytes).unwrap(), None);

        let mut bytes: &[u8] = b"\xEF\xBB";  // Truncated BOM
        assert!(matches!(CharReader::new().next_char_from(&mut bytes), Err(CharReadError::UnexpectedEof)));

        let mut bytes: &[u8] = b"\xEF\xBB\x42";  // BOM prefix followed by a wrong byte
        assert!(matches!(CharReader::new().next_char_from(&mut bytes), Err(_)));

        let mut bytes: &[u8] = b"\xFE\xFF\x00\x42";  // UTF-16 BE BOM
        assert_eq!(CharReader::new().next_char_from(&mut bytes).unwrap(), Some('B'));

        let mut bytes: &[u8] = b"\xFF\xFE\x42\x00";  // UTF-16 LE BOM
        assert_eq!(CharReader::new().next_char_from(&mut bytes).unwrap(), Some('B'));

        let mut bytes: &[u8] = b"\xFF\xFE";  // Nothing after UTF-16 BOM
        assert_eq!(CharReader::new().next_char_from(&mut bytes).unwrap(), None);

        let mut bytes: &[u8] = b"\xFF\xFE\x00";  // Half a UTF-16 code unit
        assert!(matches!(CharReader::new().next_char_from(&mut bytes), Err(CharReadError::UnexpectedEof)));

        let mut bytes: &[u8] = "правильно".as_bytes();  // correct BMP
        assert_eq!(CharReader::new().next_char_from(&mut bytes).unwrap(), Some('п'));

        // The same bytes read as UTF-16 give a different (but valid) character.
        let mut bytes: &[u8] = "правильно".as_bytes();
        assert_eq!(CharReader { encoding: Encoding::Utf16Be }.next_char_from(&mut bytes).unwrap(), Some('킿'));

        let mut bytes: &[u8] = "правильно".as_bytes();
        assert_eq!(CharReader { encoding: Encoding::Utf16Le }.next_char_from(&mut bytes).unwrap(), Some('뿐'));

        let mut bytes: &[u8] = b"\xD8\xD8\x80";  // surrogate without its pair
        assert!(matches!(CharReader { encoding: Encoding::Utf16 }.next_char_from(&mut bytes), Err(_)));

        let mut bytes: &[u8] = b"\x00\x42";  // endianness sniffed from an ASCII char
        assert_eq!(CharReader { encoding: Encoding::Utf16 }.next_char_from(&mut bytes).unwrap(), Some('B'));

        let mut bytes: &[u8] = b"\x42\x00";
        assert_eq!(CharReader { encoding: Encoding::Utf16 }.next_char_from(&mut bytes).unwrap(), Some('B'));

        let mut bytes: &[u8] = b"\x00";
        assert!(matches!(CharReader { encoding: Encoding::Utf16Be }.next_char_from(&mut bytes), Err(_)));

        let mut bytes: &[u8] = "\u{1F60A}".as_bytes();  // correct non-BMP
        assert_eq!(CharReader::new().next_char_from(&mut bytes).unwrap(), Some('\u{1F60A}'));

        let mut bytes: &[u8] = b"";                     // empty
        assert_eq!(CharReader::new().next_char_from(&mut bytes).unwrap(), None);

        let mut bytes: &[u8] = b"\xf0\x9f\x98";         // incomplete code point
        match CharReader::new().next_char_from(&mut bytes).unwrap_err() {
            CharReadError::UnexpectedEof => {},
            e => panic!("Unexpected result: {e:?}")
        };

        let mut bytes: &[u8] = b"\xff\x9f\x98\x32";     // invalid code point
        match CharReader::new().next_char_from(&mut bytes).unwrap_err() {
            CharReadError::Utf8(_) => {},
            e => panic!("Unexpected result: {e:?}")
        };

        // error during read
        struct ErrorReader;
        impl io::Read for ErrorReader {
            fn read(&mut self, _: &mut [u8]) -> io::Result<usize> {
                Err(io::Error::new(io::ErrorKind::Other, "test error"))
            }
        }

        let mut r = ErrorReader;
        match CharReader::new().next_char_from(&mut r).unwrap_err() {
            CharReadError::Io(ref e) if e.kind() == io::ErrorKind::Other &&
                                        e.to_string().contains("test error") => {},
            e => panic!("Unexpected result: {e:?}")
        }
    }
}
306