1 // Copyright (c) 2023 Huawei Device Co., Ltd. 2 // Licensed under the Apache License, Version 2.0 (the "License"); 3 // you may not use this file except in compliance with the License. 4 // You may obtain a copy of the License at 5 // 6 // http://www.apache.org/licenses/LICENSE-2.0 7 // 8 // Unless required by applicable law or agreed to in writing, software 9 // distributed under the License is distributed on an "AS IS" BASIS, 10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 // See the License for the specific language governing permissions and 12 // limitations under the License. 13 14 use core::mem::take; 15 use std::collections::hash_map::IntoIter; 16 17 use crate::body::mime::common::{ 18 consume_crlf, data_copy, trim_front_lwsp, BytesResult, TokenResult, 19 }; 20 use crate::body::mime::{CR, LF}; 21 use crate::body::TokenStatus; 22 use crate::error::{ErrorKind, HttpError}; 23 use crate::h1::response::decoder::{HEADER_NAME_BYTES, HEADER_VALUE_BYTES}; 24 use crate::headers::{HeaderName, HeaderValue, Headers}; 25 26 #[derive(Debug, PartialEq)] 27 pub(crate) enum HeaderStatus { 28 Start, 29 Name, 30 Colon, 31 Value, 32 Crlf, 33 End, 34 } 35 36 #[derive(Debug)] 37 pub(crate) struct EncodeHeaders { 38 pub(crate) stage: HeaderStatus, 39 pub(crate) into_iter: IntoIter<HeaderName, HeaderValue>, 40 pub(crate) value: Option<HeaderValue>, 41 pub(crate) src: Vec<u8>, 42 pub(crate) src_idx: usize, 43 } 44 45 impl EncodeHeaders { new(headers: Headers) -> Self46 pub(crate) fn new(headers: Headers) -> Self { 47 EncodeHeaders { 48 stage: HeaderStatus::Start, 49 into_iter: headers.into_iter(), 50 value: None, 51 src: vec![], 52 src_idx: 0, 53 } 54 } 55 56 // when the encode stage go to next check_next(&mut self)57 fn check_next(&mut self) { 58 match self.stage { 59 HeaderStatus::Start => match self.into_iter.next() { 60 Some((name, value)) => { 61 self.src = name.into_bytes(); 62 self.src_idx = 0; 63 self.value = Some(value); 64 self.stage = HeaderStatus::Name; 65 } 66 None => { 67 self.stage = HeaderStatus::End; 68 } 69 }, 70 HeaderStatus::Name => { 71 self.stage = HeaderStatus::Colon; 72 self.src = b":".to_vec(); 73 self.src_idx = 0; 74 } 75 HeaderStatus::Colon => { 76 self.stage = HeaderStatus::Value; 77 match self.value.take() { 78 Some(v) => { 79 self.src = v.to_vec(); 80 } 81 None => { 82 self.src = vec![]; 83 } 84 } 85 self.src_idx = 0; 86 } 87 HeaderStatus::Value => { 88 self.stage = HeaderStatus::Crlf; 89 self.src = b"\r\n".to_vec(); 90 self.src_idx = 0; 91 } 92 HeaderStatus::Crlf => { 93 self.stage = HeaderStatus::Start; 94 } 95 HeaderStatus::End => {} 96 } 97 } 98 encode(&mut self, dst: &mut [u8]) -> TokenResult<usize>99 pub(crate) fn encode(&mut self, dst: &mut [u8]) -> TokenResult<usize> { 100 match self.stage { 101 HeaderStatus::Start => { 102 self.check_next(); 103 Ok(TokenStatus::Partial(0)) 104 } 105 HeaderStatus::Name | HeaderStatus::Colon | HeaderStatus::Value | HeaderStatus::Crlf => { 106 match data_copy(&self.src, &mut self.src_idx, dst)? { 107 TokenStatus::Partial(size) => Ok(TokenStatus::Partial(size)), 108 TokenStatus::Complete(size) => { 109 self.check_next(); 110 Ok(TokenStatus::Partial(size)) 111 } 112 } 113 } 114 HeaderStatus::End => Ok(TokenStatus::Complete(0)), 115 } 116 } 117 } 118 119 #[derive(Debug, PartialEq)] 120 pub(crate) struct DecodeHeaders { 121 pub(crate) stage: HeaderStatus, 122 pub(crate) name_src: Vec<u8>, 123 pub(crate) src: Vec<u8>, 124 pub(crate) headers: Headers, 125 } 126 127 impl DecodeHeaders { new() -> Self128 pub(crate) fn new() -> Self { 129 DecodeHeaders { 130 stage: HeaderStatus::Start, 131 headers: Headers::new(), 132 name_src: vec![], 133 src: vec![], 134 } 135 } 136 137 // when the decode stage go to next check_next(&mut self)138 fn check_next(&mut self) { 139 match self.stage { 140 HeaderStatus::Start => { 141 self.stage = HeaderStatus::Name; 142 } 143 HeaderStatus::Name => { 144 self.stage = HeaderStatus::Value; 145 } 146 HeaderStatus::Colon => {} 147 HeaderStatus::Value => { 148 self.stage = HeaderStatus::Crlf; 149 } 150 HeaderStatus::Crlf => { 151 self.stage = HeaderStatus::Start; 152 self.src = vec![]; 153 } 154 HeaderStatus::End => {} 155 } 156 } 157 decode<'a>( &mut self, buf: &'a [u8], ) -> Result<(TokenStatus<Headers, ()>, &'a [u8]), HttpError>158 pub(crate) fn decode<'a>( 159 &mut self, 160 buf: &'a [u8], 161 ) -> Result<(TokenStatus<Headers, ()>, &'a [u8]), HttpError> { 162 if buf.is_empty() { 163 return Err(ErrorKind::InvalidInput.into()); 164 } 165 166 let mut results = TokenStatus::Partial(()); 167 let mut remains = buf; 168 loop { 169 let rest = match self.stage { 170 HeaderStatus::Start => self.start_decode(remains), 171 HeaderStatus::Name => self.name_decode(remains), 172 // not use 173 HeaderStatus::Colon => Ok(remains), 174 HeaderStatus::Value => self.value_decode(remains), 175 HeaderStatus::Crlf => self.crlf_decode(remains), 176 HeaderStatus::End => { 177 results = TokenStatus::Complete(take(&mut self.headers)); 178 break; 179 } 180 }?; 181 remains = rest; 182 if remains.is_empty() && self.stage != HeaderStatus::End { 183 break; 184 } 185 } 186 Ok((results, remains)) 187 } 188 start_decode<'a>(&mut self, buf: &'a [u8]) -> Result<&'a [u8], HttpError>189 fn start_decode<'a>(&mut self, buf: &'a [u8]) -> Result<&'a [u8], HttpError> { 190 let buf = if self.src.is_empty() { 191 trim_front_lwsp(buf) 192 } else { 193 buf 194 }; 195 196 let cr_meet = self.is_cr_meet(); 197 match buf[0] { 198 CR => { 199 if cr_meet { 200 Err(ErrorKind::InvalidInput.into()) 201 } else if buf.len() == 1 { 202 self.src.push(CR); 203 Ok(&[]) 204 } else if buf[1] == LF { 205 self.stage = HeaderStatus::End; 206 Ok(&buf[2..]) 207 } else { 208 Err(ErrorKind::InvalidInput.into()) 209 } 210 } 211 LF => { 212 self.stage = HeaderStatus::End; 213 Ok(&buf[1..]) 214 } 215 _ => { 216 if cr_meet { 217 Err(ErrorKind::InvalidInput.into()) 218 } else { 219 self.check_next(); 220 Ok(buf) 221 } 222 } 223 } 224 } 225 226 // check '\r' is_cr_meet(&self) -> bool227 fn is_cr_meet(&self) -> bool { 228 self.src.len() == 1 && self.src[0] == CR 229 } 230 crlf_decode<'a>(&mut self, buf: &'a [u8]) -> Result<&'a [u8], HttpError>231 fn crlf_decode<'a>(&mut self, buf: &'a [u8]) -> Result<&'a [u8], HttpError> { 232 let cr_meet = self.is_cr_meet(); 233 match consume_crlf(buf, cr_meet)? { 234 TokenStatus::Partial(_size) => Ok(&[]), 235 TokenStatus::Complete(unparsed) => { 236 self.check_next(); 237 Ok(unparsed) 238 } 239 } 240 } 241 name_decode<'a>(&mut self, buf: &'a [u8]) -> Result<&'a [u8], HttpError>242 fn name_decode<'a>(&mut self, buf: &'a [u8]) -> Result<&'a [u8], HttpError> { 243 let buf = if self.src.is_empty() { 244 trim_front_lwsp(buf) 245 } else { 246 buf 247 }; 248 249 match Self::get_header_name(buf)? { 250 TokenStatus::Partial(unparsed) => { 251 self.src.extend_from_slice(unparsed); 252 Ok(&[]) 253 } 254 TokenStatus::Complete((src, unparsed)) => { 255 // clone in this. 256 self.src.extend_from_slice(src); 257 self.name_src = take(&mut self.src); 258 self.check_next(); 259 Ok(unparsed) 260 } 261 } 262 } 263 value_decode<'a>(&mut self, buf: &'a [u8]) -> Result<&'a [u8], HttpError>264 fn value_decode<'a>(&mut self, buf: &'a [u8]) -> Result<&'a [u8], HttpError> { 265 let buf = if self.src.is_empty() { 266 trim_front_lwsp(buf) 267 } else { 268 buf 269 }; 270 271 match Self::get_header_value(buf)? { 272 TokenStatus::Partial(unparsed) => { 273 self.src.extend_from_slice(unparsed); 274 Ok(&[]) 275 } 276 TokenStatus::Complete((src, unparsed)) => { 277 // clone in this. 278 self.src.extend_from_slice(src); 279 let value = take(&mut self.src); 280 let name = take(&mut self.name_src); 281 282 self.headers 283 .insert(trim_front_lwsp(&name), trim_front_lwsp(&value))?; 284 self.check_next(); 285 Ok(unparsed) 286 } 287 } 288 } 289 290 // end with ":" get_header_name(buf: &[u8]) -> BytesResult291 fn get_header_name(buf: &[u8]) -> BytesResult { 292 for (i, b) in buf.iter().enumerate() { 293 if *b == b':' { 294 // match "k:v" or "k: v" 295 return Ok(TokenStatus::Complete((&buf[..i], &buf[i + 1..]))); 296 } else if !HEADER_NAME_BYTES[*b as usize] { 297 return Err(ErrorKind::InvalidInput.into()); 298 } 299 } 300 Ok(TokenStatus::Partial(buf)) 301 } 302 303 // end with "\r" or "\n" or "\r\n" get_header_value(buf: &[u8]) -> BytesResult304 fn get_header_value(buf: &[u8]) -> BytesResult { 305 for (i, b) in buf.iter().enumerate() { 306 if *b == CR || *b == LF { 307 return Ok(TokenStatus::Complete((&buf[..i], &buf[i..]))); 308 } else if !HEADER_VALUE_BYTES[*b as usize] { 309 return Err(ErrorKind::InvalidInput.into()); 310 } 311 } 312 Ok(TokenStatus::Partial(buf)) 313 } 314 } 315 316 #[cfg(test)] 317 mod ut_decode_headers { 318 use crate::body::mime::common::headers::DecodeHeaders; 319 use crate::body::TokenStatus; 320 use crate::headers::Headers; 321 322 /// UT test cases for `DecodeHeaders::decode`. 323 /// 324 /// # Brief 325 /// 1. Creates a `DecodeHeaders` by `DecodeHeaders::new`. 326 /// 2. Uses `DecodeHeaders::decode` to decode headers. 327 /// 3. The headers is divided by "\r\n". 328 /// 4. Checks whether the result is correct. 329 #[test] ut_decode_headers_new()330 fn ut_decode_headers_new() { 331 let buf = b"\r\nabcd"; 332 let mut decoder = DecodeHeaders::new(); 333 let (headers, rest) = decoder.decode(buf).unwrap(); 334 assert_eq!(headers, TokenStatus::Complete(Headers::new())); 335 assert_eq!(rest, b"abcd"); 336 337 // has LWSP 338 let buf = b" \r\nabcd"; 339 let mut decoder = DecodeHeaders::new(); 340 let (headers, rest) = decoder.decode(buf).unwrap(); 341 assert_eq!(headers, TokenStatus::Complete(Headers::new())); 342 assert_eq!(rest, b"abcd"); 343 } 344 345 /// UT test cases for `DecodeHeaders::decode`. 346 /// 347 /// # Brief 348 /// 1. Creates a `DecodeHeaders` by `DecodeHeaders::new`. 349 /// 2. Uses `DecodeHeaders::decode` to decode headers. 350 /// 3. The headers is divided by "\n". 351 /// 4. Checks whether the result is correct. 352 #[test] ut_decode_headers_new2()353 fn ut_decode_headers_new2() { 354 let buf = b"\nabcd"; 355 let mut decoder = DecodeHeaders::new(); 356 let (headers, rest) = decoder.decode(buf).unwrap(); 357 assert_eq!(headers, TokenStatus::Complete(Headers::new())); 358 assert_eq!(rest, b"abcd"); 359 360 // has LWSP 361 let buf = b" \nabcd"; 362 let mut decoder = DecodeHeaders::new(); 363 let (headers, rest) = decoder.decode(buf).unwrap(); 364 assert_eq!(headers, TokenStatus::Complete(Headers::new())); 365 assert_eq!(rest, b"abcd"); 366 } 367 368 /// UT test cases for `DecodeHeaders::decode`. 369 /// 370 /// # Brief 371 /// 1. Creates a `DecodeHeaders` by `DecodeHeaders::new`. 372 /// 2. Uses `DecodeHeaders::decode` to decode headers. 373 /// 3. The headers has *LWSP-char(b' ' or b'\t'). 374 /// 4. Checks whether the result is correct. 375 #[test] ut_decode_headers_decode()376 fn ut_decode_headers_decode() { 377 // all use "\r\n" 378 let buf = b" name1: value1\r\n name2: value2\r\n\r\n"; 379 let mut decoder = DecodeHeaders::new(); 380 let (headers, rest) = decoder.decode(buf).unwrap(); 381 assert_eq!( 382 headers, 383 TokenStatus::Complete({ 384 let mut headers = Headers::new(); 385 headers.insert("name1", "value1").unwrap(); 386 headers.insert("name2", "value2").unwrap(); 387 headers 388 }) 389 ); 390 assert_eq!(std::str::from_utf8(rest).unwrap(), ""); 391 392 // all use "\n" 393 let buf = b"name1:value1\nname2:value2\n\n"; 394 let mut decoder = DecodeHeaders::new(); 395 let (headers, rest) = decoder.decode(buf).unwrap(); 396 assert_eq!( 397 headers, 398 TokenStatus::Complete({ 399 let mut headers = Headers::new(); 400 headers.insert("name1", "value1").unwrap(); 401 headers.insert("name2", "value2").unwrap(); 402 headers 403 }) 404 ); 405 assert_eq!(std::str::from_utf8(rest).unwrap(), ""); 406 407 // some use "\r\n" 408 let buf = b"name1:value1\nname2:value2\r\n\n"; 409 let mut decoder = DecodeHeaders::new(); 410 let (headers, rest) = decoder.decode(buf).unwrap(); 411 assert_eq!( 412 headers, 413 TokenStatus::Complete({ 414 let mut headers = Headers::new(); 415 headers.insert("name1", "value1").unwrap(); 416 headers.insert("name2", "value2").unwrap(); 417 headers 418 }) 419 ); 420 assert_eq!(std::str::from_utf8(rest).unwrap(), ""); 421 } 422 423 /// UT test cases for `DecodeHeaders::decode`. 424 /// 425 /// # Brief 426 /// 1. Creates a `DecodeHeaders` by `DecodeHeaders::new`. 427 /// 2. Uses `DecodeHeaders::decode` to decode headers. 428 /// 3. The headers is common. 429 /// 4. Checks whether the result is correct. 430 #[test] ut_decode_headers_decode2()431 fn ut_decode_headers_decode2() { 432 let buf = b"name1:value1\r\n\r\n\r\naaaa"; 433 let mut decoder = DecodeHeaders::new(); 434 let (headers, rest) = decoder.decode(buf).unwrap(); 435 assert_eq!( 436 headers, 437 TokenStatus::Complete({ 438 let mut headers = Headers::new(); 439 headers.insert("name1", "value1").unwrap(); 440 headers 441 }) 442 ); 443 assert_eq!(std::str::from_utf8(rest).unwrap(), "\r\naaaa"); 444 } 445 446 /// UT test cases for `DecodeHeaders::decode`. 447 /// 448 /// # Brief 449 /// 1. Creates a `DecodeHeaders` by `DecodeHeaders::new`. 450 /// 2. Uses `DecodeHeaders::decode` to decode headers. 451 /// 3. The decode bytes are divided into several executions. 452 /// 4. Checks whether the result is correct. 453 #[test] ut_decode_headers_decode3()454 fn ut_decode_headers_decode3() { 455 let buf = b"name1:value1\r\nname2:value2\r\n\r\naaaa"; 456 let mut decoder = DecodeHeaders::new(); 457 // nam 458 let (headers, rest) = decoder.decode(&buf[0..3]).unwrap(); 459 assert_eq!(headers, TokenStatus::Partial(())); 460 assert_eq!(std::str::from_utf8(rest).unwrap(), ""); 461 462 // e1:value1\r 463 let (headers, rest) = decoder.decode(&buf[3..13]).unwrap(); 464 assert_eq!(headers, TokenStatus::Partial(())); 465 assert_eq!(std::str::from_utf8(rest).unwrap(), ""); 466 467 // \nname2:value2\r\n\r 468 let (headers, rest) = decoder.decode(&buf[13..29]).unwrap(); 469 assert_eq!(headers, TokenStatus::Partial(())); 470 assert_eq!(std::str::from_utf8(rest).unwrap(), ""); 471 472 // \n 473 let (headers, rest) = decoder.decode(&buf[29..30]).unwrap(); 474 assert_eq!( 475 headers, 476 TokenStatus::Complete({ 477 let mut headers = Headers::new(); 478 headers.insert("name1", "value1").unwrap(); 479 headers.insert("name2", "value2").unwrap(); 480 headers 481 }) 482 ); 483 assert_eq!(std::str::from_utf8(rest).unwrap(), ""); 484 485 // aaaa 486 let (headers, rest) = decoder.decode(&buf[30..34]).unwrap(); 487 assert_eq!(headers, TokenStatus::Complete(Headers::new())); 488 assert_eq!(std::str::from_utf8(rest).unwrap(), "aaaa"); 489 } 490 } 491