• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 use core::mem::take;
15 use std::collections::hash_map::IntoIter;
16 
17 use crate::body::mime::common::{
18     consume_crlf, data_copy, trim_front_lwsp, BytesResult, TokenResult,
19 };
20 use crate::body::mime::{CR, LF};
21 use crate::body::TokenStatus;
22 use crate::error::{ErrorKind, HttpError};
23 use crate::h1::response::decoder::{HEADER_NAME_BYTES, HEADER_VALUE_BYTES};
24 use crate::headers::{HeaderName, HeaderValue, Headers};
25 
/// Stages shared by the header encoder ([`EncodeHeaders`]) and decoder
/// ([`DecodeHeaders`]) state machines.
#[derive(PartialEq, Debug)]
pub(crate) enum HeaderStatus {
    /// Before a header line (or before the blank line that ends the section).
    Start,
    /// Inside a header name.
    Name,
    /// At the `:` separator; only the encoder passes through this stage.
    Colon,
    /// Inside a header value.
    Value,
    /// At the line terminator following a value.
    Crlf,
    /// Every header has been processed.
    End,
}
35 
36 #[derive(Debug)]
37 pub(crate) struct EncodeHeaders {
38     pub(crate) stage: HeaderStatus,
39     pub(crate) into_iter: IntoIter<HeaderName, HeaderValue>,
40     pub(crate) value: Option<HeaderValue>,
41     pub(crate) src: Vec<u8>,
42     pub(crate) src_idx: usize,
43 }
44 
45 impl EncodeHeaders {
new(headers: Headers) -> Self46     pub(crate) fn new(headers: Headers) -> Self {
47         EncodeHeaders {
48             stage: HeaderStatus::Start,
49             into_iter: headers.into_iter(),
50             value: None,
51             src: vec![],
52             src_idx: 0,
53         }
54     }
55 
56     // when the encode stage go to next
check_next(&mut self)57     fn check_next(&mut self) {
58         match self.stage {
59             HeaderStatus::Start => match self.into_iter.next() {
60                 Some((name, value)) => {
61                     self.src = name.into_bytes();
62                     self.src_idx = 0;
63                     self.value = Some(value);
64                     self.stage = HeaderStatus::Name;
65                 }
66                 None => {
67                     self.stage = HeaderStatus::End;
68                 }
69             },
70             HeaderStatus::Name => {
71                 self.stage = HeaderStatus::Colon;
72                 self.src = b":".to_vec();
73                 self.src_idx = 0;
74             }
75             HeaderStatus::Colon => {
76                 self.stage = HeaderStatus::Value;
77                 match self.value.take() {
78                     Some(v) => {
79                         self.src = v.to_vec();
80                     }
81                     None => {
82                         self.src = vec![];
83                     }
84                 }
85                 self.src_idx = 0;
86             }
87             HeaderStatus::Value => {
88                 self.stage = HeaderStatus::Crlf;
89                 self.src = b"\r\n".to_vec();
90                 self.src_idx = 0;
91             }
92             HeaderStatus::Crlf => {
93                 self.stage = HeaderStatus::Start;
94             }
95             HeaderStatus::End => {}
96         }
97     }
98 
encode(&mut self, dst: &mut [u8]) -> TokenResult<usize>99     pub(crate) fn encode(&mut self, dst: &mut [u8]) -> TokenResult<usize> {
100         match self.stage {
101             HeaderStatus::Start => {
102                 self.check_next();
103                 Ok(TokenStatus::Partial(0))
104             }
105             HeaderStatus::Name | HeaderStatus::Colon | HeaderStatus::Value | HeaderStatus::Crlf => {
106                 match data_copy(&self.src, &mut self.src_idx, dst)? {
107                     TokenStatus::Partial(size) => Ok(TokenStatus::Partial(size)),
108                     TokenStatus::Complete(size) => {
109                         self.check_next();
110                         Ok(TokenStatus::Partial(size))
111                     }
112                 }
113             }
114             HeaderStatus::End => Ok(TokenStatus::Complete(0)),
115         }
116     }
117 }
118 
119 #[derive(Debug, PartialEq)]
120 pub(crate) struct DecodeHeaders {
121     pub(crate) stage: HeaderStatus,
122     pub(crate) name_src: Vec<u8>,
123     pub(crate) src: Vec<u8>,
124     pub(crate) headers: Headers,
125 }
126 
127 impl DecodeHeaders {
new() -> Self128     pub(crate) fn new() -> Self {
129         DecodeHeaders {
130             stage: HeaderStatus::Start,
131             headers: Headers::new(),
132             name_src: vec![],
133             src: vec![],
134         }
135     }
136 
137     // when the decode stage go to next
check_next(&mut self)138     fn check_next(&mut self) {
139         match self.stage {
140             HeaderStatus::Start => {
141                 self.stage = HeaderStatus::Name;
142             }
143             HeaderStatus::Name => {
144                 self.stage = HeaderStatus::Value;
145             }
146             HeaderStatus::Colon => {}
147             HeaderStatus::Value => {
148                 self.stage = HeaderStatus::Crlf;
149             }
150             HeaderStatus::Crlf => {
151                 self.stage = HeaderStatus::Start;
152                 self.src = vec![];
153             }
154             HeaderStatus::End => {}
155         }
156     }
157 
decode<'a>( &mut self, buf: &'a [u8], ) -> Result<(TokenStatus<Headers, ()>, &'a [u8]), HttpError>158     pub(crate) fn decode<'a>(
159         &mut self,
160         buf: &'a [u8],
161     ) -> Result<(TokenStatus<Headers, ()>, &'a [u8]), HttpError> {
162         if buf.is_empty() {
163             return Err(ErrorKind::InvalidInput.into());
164         }
165 
166         let mut results = TokenStatus::Partial(());
167         let mut remains = buf;
168         loop {
169             let rest = match self.stage {
170                 HeaderStatus::Start => self.start_decode(remains),
171                 HeaderStatus::Name => self.name_decode(remains),
172                 // not use
173                 HeaderStatus::Colon => Ok(remains),
174                 HeaderStatus::Value => self.value_decode(remains),
175                 HeaderStatus::Crlf => self.crlf_decode(remains),
176                 HeaderStatus::End => {
177                     results = TokenStatus::Complete(take(&mut self.headers));
178                     break;
179                 }
180             }?;
181             remains = rest;
182             if remains.is_empty() && self.stage != HeaderStatus::End {
183                 break;
184             }
185         }
186         Ok((results, remains))
187     }
188 
start_decode<'a>(&mut self, buf: &'a [u8]) -> Result<&'a [u8], HttpError>189     fn start_decode<'a>(&mut self, buf: &'a [u8]) -> Result<&'a [u8], HttpError> {
190         let buf = if self.src.is_empty() {
191             trim_front_lwsp(buf)
192         } else {
193             buf
194         };
195 
196         let cr_meet = self.is_cr_meet();
197         match buf[0] {
198             CR => {
199                 if cr_meet {
200                     Err(ErrorKind::InvalidInput.into())
201                 } else if buf.len() == 1 {
202                     self.src.push(CR);
203                     Ok(&[])
204                 } else if buf[1] == LF {
205                     self.stage = HeaderStatus::End;
206                     Ok(&buf[2..])
207                 } else {
208                     Err(ErrorKind::InvalidInput.into())
209                 }
210             }
211             LF => {
212                 self.stage = HeaderStatus::End;
213                 Ok(&buf[1..])
214             }
215             _ => {
216                 if cr_meet {
217                     Err(ErrorKind::InvalidInput.into())
218                 } else {
219                     self.check_next();
220                     Ok(buf)
221                 }
222             }
223         }
224     }
225 
226     // check '\r'
is_cr_meet(&self) -> bool227     fn is_cr_meet(&self) -> bool {
228         self.src.len() == 1 && self.src[0] == CR
229     }
230 
crlf_decode<'a>(&mut self, buf: &'a [u8]) -> Result<&'a [u8], HttpError>231     fn crlf_decode<'a>(&mut self, buf: &'a [u8]) -> Result<&'a [u8], HttpError> {
232         let cr_meet = self.is_cr_meet();
233         match consume_crlf(buf, cr_meet)? {
234             TokenStatus::Partial(_size) => Ok(&[]),
235             TokenStatus::Complete(unparsed) => {
236                 self.check_next();
237                 Ok(unparsed)
238             }
239         }
240     }
241 
name_decode<'a>(&mut self, buf: &'a [u8]) -> Result<&'a [u8], HttpError>242     fn name_decode<'a>(&mut self, buf: &'a [u8]) -> Result<&'a [u8], HttpError> {
243         let buf = if self.src.is_empty() {
244             trim_front_lwsp(buf)
245         } else {
246             buf
247         };
248 
249         match Self::get_header_name(buf)? {
250             TokenStatus::Partial(unparsed) => {
251                 self.src.extend_from_slice(unparsed);
252                 Ok(&[])
253             }
254             TokenStatus::Complete((src, unparsed)) => {
255                 // clone in this.
256                 self.src.extend_from_slice(src);
257                 self.name_src = take(&mut self.src);
258                 self.check_next();
259                 Ok(unparsed)
260             }
261         }
262     }
263 
value_decode<'a>(&mut self, buf: &'a [u8]) -> Result<&'a [u8], HttpError>264     fn value_decode<'a>(&mut self, buf: &'a [u8]) -> Result<&'a [u8], HttpError> {
265         let buf = if self.src.is_empty() {
266             trim_front_lwsp(buf)
267         } else {
268             buf
269         };
270 
271         match Self::get_header_value(buf)? {
272             TokenStatus::Partial(unparsed) => {
273                 self.src.extend_from_slice(unparsed);
274                 Ok(&[])
275             }
276             TokenStatus::Complete((src, unparsed)) => {
277                 // clone in this.
278                 self.src.extend_from_slice(src);
279                 let value = take(&mut self.src);
280                 let name = take(&mut self.name_src);
281 
282                 self.headers
283                     .insert(trim_front_lwsp(&name), trim_front_lwsp(&value))?;
284                 self.check_next();
285                 Ok(unparsed)
286             }
287         }
288     }
289 
290     // end with ":"
get_header_name(buf: &[u8]) -> BytesResult291     fn get_header_name(buf: &[u8]) -> BytesResult {
292         for (i, b) in buf.iter().enumerate() {
293             if *b == b':' {
294                 // match "k:v" or "k: v"
295                 return Ok(TokenStatus::Complete((&buf[..i], &buf[i + 1..])));
296             } else if !HEADER_NAME_BYTES[*b as usize] {
297                 return Err(ErrorKind::InvalidInput.into());
298             }
299         }
300         Ok(TokenStatus::Partial(buf))
301     }
302 
303     // end with "\r" or "\n" or "\r\n"
get_header_value(buf: &[u8]) -> BytesResult304     fn get_header_value(buf: &[u8]) -> BytesResult {
305         for (i, b) in buf.iter().enumerate() {
306             if *b == CR || *b == LF {
307                 return Ok(TokenStatus::Complete((&buf[..i], &buf[i..])));
308             } else if !HEADER_VALUE_BYTES[*b as usize] {
309                 return Err(ErrorKind::InvalidInput.into());
310             }
311         }
312         Ok(TokenStatus::Partial(buf))
313     }
314 }
315 
#[cfg(test)]
mod ut_decode_headers {
    use crate::body::mime::common::headers::DecodeHeaders;
    use crate::body::TokenStatus;
    use crate::headers::Headers;

    /// Builds the `Headers` a decode is expected to yield from name/value pairs.
    fn expected(pairs: &[(&'static str, &'static str)]) -> Headers {
        let mut headers = Headers::new();
        for &(name, value) in pairs {
            headers.insert(name, value).unwrap();
        }
        headers
    }

    /// An empty header section terminated by "\r\n", with and without leading
    /// LWSP, completes with no headers and leaves the following bytes.
    #[test]
    fn ut_decode_headers_new() {
        let cases: [&[u8]; 2] = [b"\r\nabcd", b"    \r\nabcd"];
        for input in cases.iter() {
            let mut decoder = DecodeHeaders::new();
            let (headers, rest) = decoder.decode(input).unwrap();
            assert_eq!(headers, TokenStatus::Complete(Headers::new()));
            assert_eq!(rest, b"abcd");
        }
    }

    /// Same as above, but the section is terminated by a bare "\n".
    #[test]
    fn ut_decode_headers_new2() {
        let cases: [&[u8]; 2] = [b"\nabcd", b"    \nabcd"];
        for input in cases.iter() {
            let mut decoder = DecodeHeaders::new();
            let (headers, rest) = decoder.decode(input).unwrap();
            assert_eq!(headers, TokenStatus::Complete(Headers::new()));
            assert_eq!(rest, b"abcd");
        }
    }

    /// Headers with LWSP (b' ' / b'\t') and with "\r\n", "\n" or mixed line
    /// endings all decode to the same two headers.
    #[test]
    fn ut_decode_headers_decode() {
        let cases: [&[u8]; 3] = [
            // all "\r\n"
            b"    name1:   value1\r\n    name2:    value2\r\n\r\n",
            // all "\n"
            b"name1:value1\nname2:value2\n\n",
            // mixed
            b"name1:value1\nname2:value2\r\n\n",
        ];
        for input in cases.iter() {
            let mut decoder = DecodeHeaders::new();
            let (headers, rest) = decoder.decode(input).unwrap();
            assert_eq!(
                headers,
                TokenStatus::Complete(expected(&[("name1", "value1"), ("name2", "value2")]))
            );
            assert!(rest.is_empty());
        }
    }

    /// Decoding stops right after the terminating blank line.
    #[test]
    fn ut_decode_headers_decode2() {
        let mut decoder = DecodeHeaders::new();
        let (headers, rest) = decoder.decode(b"name1:value1\r\n\r\n\r\naaaa").unwrap();
        assert_eq!(headers, TokenStatus::Complete(expected(&[("name1", "value1")])));
        assert_eq!(rest, b"\r\naaaa");
    }

    /// Input split across several calls, including splits inside a name and
    /// between "\r" and "\n".
    #[test]
    fn ut_decode_headers_decode3() {
        let buf = b"name1:value1\r\nname2:value2\r\n\r\naaaa";
        let mut decoder = DecodeHeaders::new();

        // "nam", "e1:value1\r", "\nname2:value2\r\n\r": each needs more input.
        for &(start, end) in &[(0, 3), (3, 13), (13, 29)] {
            let (headers, rest) = decoder.decode(&buf[start..end]).unwrap();
            assert_eq!(headers, TokenStatus::Partial(()));
            assert!(rest.is_empty());
        }

        // "\n" closes the section.
        let (headers, rest) = decoder.decode(&buf[29..30]).unwrap();
        assert_eq!(
            headers,
            TokenStatus::Complete(expected(&[("name1", "value1"), ("name2", "value2")]))
        );
        assert!(rest.is_empty());

        // "aaaa" after completion is handed back untouched.
        let (headers, rest) = decoder.decode(&buf[30..34]).unwrap();
        assert_eq!(headers, TokenStatus::Complete(Headers::new()));
        assert_eq!(rest, b"aaaa");
    }
}
491