• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 use std::cmp::min;
15 use std::mem::take;
16 use std::ops::Deref;
17 use std::pin::Pin;
18 use std::sync::atomic::Ordering;
19 use std::task::{Context, Poll};
20 use std::time::Instant;
21 
22 use ylong_http::error::HttpError;
23 use ylong_http::h2;
24 use ylong_http::h2::{ErrorCode, Frame, FrameFlags, H2Error, Payload, PseudoHeaders, StreamId};
25 use ylong_http::headers::Headers;
26 use ylong_http::request::uri::Scheme;
27 use ylong_http::request::RequestPart;
28 use ylong_http::response::status::StatusCode;
29 use ylong_http::response::ResponsePart;
30 
31 use crate::async_impl::conn::StreamData;
32 use crate::async_impl::request::Message;
33 use crate::async_impl::{HttpBody, Response};
34 use crate::error::{ErrorKind, HttpClientError};
35 use crate::runtime::{AsyncRead, ReadBuf};
36 use crate::util::config::HttpVersion;
37 use crate::util::data_ref::BodyDataRef;
38 use crate::util::dispatcher::http2::Http2Conn;
39 use crate::util::h2::RequestWrapper;
40 use crate::util::normalizer::BodyLengthParser;
41 use crate::ErrorKind::BodyTransfer;
42 
/// Frame-flag byte with no bits set; HEADERS frame flags are built up from it.
const UNUSED_FLAG: u8 = 0;
44 
request<S>( mut conn: Http2Conn<S>, mut message: Message, ) -> Result<Response, HttpClientError> where S: Sync + Send + Unpin + 'static,45 pub(crate) async fn request<S>(
46     mut conn: Http2Conn<S>,
47     mut message: Message,
48 ) -> Result<Response, HttpClientError>
49 where
50     S: Sync + Send + Unpin + 'static,
51 {
52     message
53         .interceptor
54         .intercept_request(message.request.ref_mut())?;
55     let part = message.request.ref_mut().part().clone();
56 
57     // TODO Implement trailer.
58     let is_end_stream = message.request.ref_mut().body().is_empty();
59     let (flag, payload) = build_headers_payload(part, is_end_stream)
60         .map_err(|e| HttpClientError::from_error(ErrorKind::Request, e))?;
61     let data = BodyDataRef::new(message.request.clone(), conn.speed_controller.clone());
62     let stream = RequestWrapper {
63         flag,
64         payload,
65         data,
66     };
67     message
68         .request
69         .ref_mut()
70         .time_group_mut()
71         .set_transfer_start(Instant::now());
72     conn.send_frame_to_controller(stream)?;
73     let frame = conn.receiver.recv().await?;
74     message
75         .request
76         .ref_mut()
77         .time_group_mut()
78         .set_transfer_end(Instant::now());
79     frame_2_response(conn, frame, message)
80 }
81 
frame_2_response<S>( conn: Http2Conn<S>, headers_frame: Frame, mut message: Message, ) -> Result<Response, HttpClientError> where S: Sync + Send + Unpin + 'static,82 fn frame_2_response<S>(
83     conn: Http2Conn<S>,
84     headers_frame: Frame,
85     mut message: Message,
86 ) -> Result<Response, HttpClientError>
87 where
88     S: Sync + Send + Unpin + 'static,
89 {
90     let part = match headers_frame.payload() {
91         Payload::Headers(headers) => {
92             let (pseudo, fields) = headers.parts();
93             let status_code = match pseudo.status() {
94                 Some(status) => StatusCode::from_bytes(status.as_bytes())
95                     .map_err(|e| HttpClientError::from_error(ErrorKind::Request, e))?,
96                 None => {
97                     return Err(build_client_error(
98                         headers_frame.stream_id(),
99                         ErrorCode::ProtocolError,
100                     ));
101                 }
102             };
103             ResponsePart {
104                 version: ylong_http::version::Version::HTTP2,
105                 status: status_code,
106                 headers: fields.clone(),
107             }
108         }
109         Payload::RstStream(reset) => {
110             return Err(build_client_error(
111                 headers_frame.stream_id(),
112                 ErrorCode::try_from(reset.error_code()).unwrap_or(ErrorCode::ProtocolError),
113             ));
114         }
115         _ => {
116             return Err(build_client_error(
117                 headers_frame.stream_id(),
118                 ErrorCode::ProtocolError,
119             ));
120         }
121     };
122 
123     let text_io = TextIo::new(conn);
124     let length = match BodyLengthParser::new(message.request.ref_mut().method(), &part).parse() {
125         Ok(length) => length,
126         Err(e) => {
127             return Err(e);
128         }
129     };
130     let time_group = take(message.request.ref_mut().time_group_mut());
131     let body = HttpBody::new(message.interceptor, length, Box::new(text_io), &[0u8; 0])?;
132 
133     let mut response = Response::new(ylong_http::response::Response::from_raw_parts(part, body));
134     response.set_time_group(time_group);
135     Ok(response)
136 }
137 
build_headers_payload( mut part: RequestPart, is_end_stream: bool, ) -> Result<(FrameFlags, Payload), HttpError>138 pub(crate) fn build_headers_payload(
139     mut part: RequestPart,
140     is_end_stream: bool,
141 ) -> Result<(FrameFlags, Payload), HttpError> {
142     remove_connection_specific_headers(&mut part.headers)?;
143     let pseudo = build_pseudo_headers(&mut part)?;
144     let mut header_part = h2::Parts::new();
145     header_part.set_header_lines(part.headers);
146     header_part.set_pseudo(pseudo);
147     let headers_payload = h2::Headers::new(header_part);
148 
149     let mut flag = FrameFlags::new(UNUSED_FLAG);
150     flag.set_end_headers(true);
151     if is_end_stream {
152         flag.set_end_stream(true);
153     }
154     Ok((flag, Payload::Headers(headers_payload)))
155 }
156 
157 // Illegal headers validation in http2.
158 // [`Connection-Specific Headers`] implementation.
159 //
160 // [`Connection-Specific Headers`]: https://www.rfc-editor.org/rfc/rfc9113.html#name-connection-specific-header-
remove_connection_specific_headers(headers: &mut Headers) -> Result<(), HttpError>161 fn remove_connection_specific_headers(headers: &mut Headers) -> Result<(), HttpError> {
162     const CONNECTION_SPECIFIC_HEADERS: &[&str; 5] = &[
163         "connection",
164         "keep-alive",
165         "proxy-connection",
166         "upgrade",
167         "transfer-encoding",
168     ];
169     for specific_header in CONNECTION_SPECIFIC_HEADERS.iter() {
170         headers.remove(*specific_header);
171     }
172 
173     if let Some(te_ref) = headers.get("te") {
174         let te = te_ref.to_string()?;
175         if te.as_str() != "trailers" {
176             headers.remove("te");
177         }
178     }
179     Ok(())
180 }
181 
build_pseudo_headers(request_part: &mut RequestPart) -> Result<PseudoHeaders, HttpError>182 fn build_pseudo_headers(request_part: &mut RequestPart) -> Result<PseudoHeaders, HttpError> {
183     let mut pseudo = PseudoHeaders::default();
184     match request_part.uri.scheme() {
185         Some(scheme) => {
186             pseudo.set_scheme(Some(String::from(scheme.as_str())));
187         }
188         None => pseudo.set_scheme(Some(String::from(Scheme::HTTP.as_str()))),
189     }
190     pseudo.set_method(Some(String::from(request_part.method.as_str())));
191     pseudo.set_path(
192         request_part
193             .uri
194             .path_and_query()
195             .or_else(|| Some(String::from("/"))),
196     );
197     let host = request_part
198         .headers
199         .remove("host")
200         .and_then(|auth| auth.to_string().ok());
201     pseudo.set_authority(host);
202     Ok(pseudo)
203 }
204 
build_client_error(id: StreamId, code: ErrorCode) -> HttpClientError205 fn build_client_error(id: StreamId, code: ErrorCode) -> HttpClientError {
206     HttpClientError::from_error(
207         ErrorKind::Request,
208         HttpError::from(H2Error::StreamError(id, code)),
209     )
210 }
211 
212 struct TextIo<S> {
213     pub(crate) handle: Http2Conn<S>,
214     pub(crate) offset: usize,
215     pub(crate) remain: Option<Frame>,
216     pub(crate) is_closed: bool,
217 }
218 
219 struct HttpReadBuf<'a, 'b> {
220     buf: &'a mut ReadBuf<'b>,
221 }
222 
223 impl<'a, 'b> HttpReadBuf<'a, 'b> {
append_slice(&mut self, buf: &[u8])224     pub(crate) fn append_slice(&mut self, buf: &[u8]) {
225         #[cfg(feature = "ylong_base")]
226         self.buf.append(buf);
227 
228         #[cfg(feature = "tokio_base")]
229         self.buf.put_slice(buf);
230     }
231 }
232 
233 impl<'a, 'b> Deref for HttpReadBuf<'a, 'b> {
234     type Target = ReadBuf<'b>;
235 
deref(&self) -> &Self::Target236     fn deref(&self) -> &Self::Target {
237         self.buf
238     }
239 }
240 
241 impl<S> TextIo<S>
242 where
243     S: Sync + Send + Unpin + 'static,
244 {
new(handle: Http2Conn<S>) -> Self245     pub(crate) fn new(handle: Http2Conn<S>) -> Self {
246         Self {
247             handle,
248             offset: 0,
249             remain: None,
250             is_closed: false,
251         }
252     }
253 
match_channel_message( poll_result: Poll<Frame>, text_io: &mut TextIo<S>, buf: &mut HttpReadBuf, ) -> Option<Poll<std::io::Result<()>>>254     fn match_channel_message(
255         poll_result: Poll<Frame>,
256         text_io: &mut TextIo<S>,
257         buf: &mut HttpReadBuf,
258     ) -> Option<Poll<std::io::Result<()>>> {
259         match poll_result {
260             Poll::Ready(frame) => match frame.payload() {
261                 Payload::Headers(_) => {
262                     text_io.remain = Some(frame);
263                     text_io.offset = 0;
264                     Some(Poll::Ready(Ok(())))
265                 }
266                 Payload::Data(data) => {
267                     let data = data.data();
268                     let unfilled_len = buf.remaining();
269                     let data_len = data.len();
270                     let fill_len = min(data_len, unfilled_len);
271                     if unfilled_len < data_len {
272                         buf.append_slice(&data[..fill_len]);
273                         text_io.offset += fill_len;
274                         text_io.remain = Some(frame);
275                         Some(Poll::Ready(Ok(())))
276                     } else {
277                         buf.append_slice(&data[..fill_len]);
278                         Self::end_read(text_io, frame.flags().is_end_stream(), data_len)
279                     }
280                 }
281                 Payload::RstStream(reset) => {
282                     if reset.is_no_error() {
283                         text_io.is_closed = true;
284                         Some(Poll::Ready(Ok(())))
285                     } else {
286                         Some(Poll::Ready(Err(std::io::Error::new(
287                             std::io::ErrorKind::Other,
288                             HttpError::from(H2Error::ConnectionError(ErrorCode::ProtocolError)),
289                         ))))
290                     }
291                 }
292                 _ => Some(Poll::Ready(Err(std::io::Error::new(
293                     std::io::ErrorKind::Other,
294                     HttpError::from(H2Error::ConnectionError(ErrorCode::ProtocolError)),
295                 )))),
296             },
297             Poll::Pending => Some(Poll::Pending),
298         }
299     }
300 
end_read( text_io: &mut TextIo<S>, end_stream: bool, data_len: usize, ) -> Option<Poll<std::io::Result<()>>>301     fn end_read(
302         text_io: &mut TextIo<S>,
303         end_stream: bool,
304         data_len: usize,
305     ) -> Option<Poll<std::io::Result<()>>> {
306         text_io.offset = 0;
307         text_io.remain = None;
308         if end_stream {
309             text_io.is_closed = true;
310             Some(Poll::Ready(Ok(())))
311         } else if data_len == 0 {
312             // no data read and is not end stream.
313             None
314         } else {
315             Some(Poll::Ready(Ok(())))
316         }
317     }
318 
read_remaining_data( text_io: &mut TextIo<S>, buf: &mut HttpReadBuf, ) -> Option<Poll<std::io::Result<()>>>319     fn read_remaining_data(
320         text_io: &mut TextIo<S>,
321         buf: &mut HttpReadBuf,
322     ) -> Option<Poll<std::io::Result<()>>> {
323         if let Some(frame) = &text_io.remain {
324             return match frame.payload() {
325                 Payload::Headers(_) => Some(Poll::Ready(Ok(()))),
326                 Payload::Data(data) => {
327                     let data = data.data();
328                     let unfilled_len = buf.remaining();
329                     let data_len = data.len() - text_io.offset;
330                     let fill_len = min(unfilled_len, data_len);
331                     // The peripheral function already ensures that the remaing of buf will not be
332                     // 0.
333                     if unfilled_len < data_len {
334                         buf.append_slice(&data[text_io.offset..text_io.offset + fill_len]);
335                         text_io.offset += fill_len;
336                         Some(Poll::Ready(Ok(())))
337                     } else {
338                         buf.append_slice(&data[text_io.offset..text_io.offset + fill_len]);
339                         Self::end_read(text_io, frame.flags().is_end_stream(), data_len)
340                     }
341                 }
342                 _ => Some(Poll::Ready(Err(std::io::Error::new(
343                     std::io::ErrorKind::Other,
344                     HttpError::from(H2Error::ConnectionError(ErrorCode::ProtocolError)),
345                 )))),
346             };
347         }
348         None
349     }
350 }
351 
352 impl<S: Sync + Send + Unpin + 'static> StreamData for TextIo<S> {
shutdown(&self)353     fn shutdown(&self) {
354         self.handle.io_shutdown.store(true, Ordering::Release);
355     }
356 
is_stream_closable(&self) -> bool357     fn is_stream_closable(&self) -> bool {
358         self.is_closed
359     }
360 
http_version(&self) -> HttpVersion361     fn http_version(&self) -> HttpVersion {
362         HttpVersion::Http2
363     }
364 }
365 
366 impl<S: Sync + Send + Unpin + 'static> AsyncRead for TextIo<S> {
poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<std::io::Result<()>>367     fn poll_read(
368         self: Pin<&mut Self>,
369         cx: &mut Context<'_>,
370         buf: &mut ReadBuf<'_>,
371     ) -> Poll<std::io::Result<()>> {
372         let mut buf = HttpReadBuf { buf };
373         let text_io = self.get_mut();
374         if buf.remaining() == 0 || text_io.is_closed {
375             return Poll::Ready(Ok(()));
376         }
377         if text_io
378             .handle
379             .speed_controller
380             .poll_recv_pending_timeout(cx)
381         {
382             return Poll::Ready(Err(std::io::Error::new(
383                 std::io::ErrorKind::TimedOut,
384                 HttpClientError::from_str(BodyTransfer, "Below low speed limit"),
385             )));
386         }
387         // Min speed contains the max speed limit sleep time.
388         text_io.handle.speed_controller.init_min_recv_if_not_start();
389         if text_io
390             .handle
391             .speed_controller
392             .poll_max_recv_delay_time(cx)
393             .is_pending()
394         {
395             return Poll::Pending;
396         }
397         text_io.handle.speed_controller.init_max_recv_if_not_start();
398         while buf.remaining() != 0 {
399             if let Some(result) = Self::read_remaining_data(text_io, &mut buf) {
400                 return match result {
401                     Poll::Ready(Ok(_)) => {
402                         let filled: usize = buf.filled().len();
403                         text_io
404                             .handle
405                             .speed_controller
406                             .min_recv_speed_limit(filled)
407                             .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
408                         text_io
409                             .handle
410                             .speed_controller
411                             .delay_max_recv_speed_limit(filled);
412                         text_io.handle.speed_controller.reset_recv_pending_timeout();
413                         Poll::Ready(Ok(()))
414                     }
415                     Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
416                     Poll::Pending => Poll::Pending,
417                 };
418             }
419 
420             let poll_result = text_io
421                 .handle
422                 .receiver
423                 .poll_recv(cx)
424                 .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
425 
426             if let Some(result) = Self::match_channel_message(poll_result, text_io, &mut buf) {
427                 return match result {
428                     Poll::Ready(Ok(_)) => {
429                         let filled: usize = buf.filled().len();
430                         text_io
431                             .handle
432                             .speed_controller
433                             .min_recv_speed_limit(filled)
434                             .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
435                         text_io
436                             .handle
437                             .speed_controller
438                             .delay_max_recv_speed_limit(filled);
439                         text_io.handle.speed_controller.reset_recv_pending_timeout();
440                         Poll::Ready(Ok(()))
441                     }
442                     Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
443                     Poll::Pending => Poll::Pending,
444                 };
445             }
446         }
447         Poll::Ready(Ok(()))
448     }
449 }
450 
#[cfg(feature = "http2")]
#[cfg(test)]
mod ut_http2 {
    use ylong_http::body::TextBody;
    use ylong_http::h2::Payload;
    use ylong_http::request::RequestBuilder;

    use crate::async_impl::conn::http2::build_headers_payload;

    macro_rules! build_request {
        (
            Request: {
                Method: $method: expr,
                Uri: $uri:expr,
                Version: $version: expr,
                $(
                    Header: $req_n: expr, $req_v: expr,
                )*
                Body: $req_body: expr,
            }
        ) => {
            RequestBuilder::new()
                .method($method)
                .url($uri)
                .version($version)
                $(.header($req_n, $req_v))*
                .body(TextBody::from_bytes($req_body.as_bytes()))
                .expect("Request build failed")
        }
    }

    /// UT for `build_headers_payload`.
    ///
    /// # Brief
    /// 1. Builds a GET request carrying `te: trailers`, `host` and the
    ///    connection-specific `connection` header.
    /// 2. Builds the HEADERS payload without and with END_STREAM and checks the
    ///    frame flags.
    /// 3. Checks the pseudo-headers.
    /// 4. Checks that `connection` and `host` were removed from the regular
    ///    header fields, while `te: trailers` was kept.
    #[test]
    fn ut_http2_build_headers_payload() {
        let request = build_request!(
            Request: {
            Method: "GET",
            Uri: "http://127.0.0.1:0/data",
            Version: "HTTP/2.0",
            Header: "te", "trailers",
            Header: "host", "127.0.0.1:0",
            Header: "connection", "keep-alive",
            Body: "Hi",
        }
        );
        let (flag, _) = build_headers_payload(request.part().clone(), false).unwrap();
        assert_eq!(flag.bits(), 0x4);
        let (flag, payload) = build_headers_payload(request.part().clone(), true).unwrap();
        assert_eq!(flag.bits(), 0x5);
        if let Payload::Headers(headers) = payload {
            let (pseudo, fields) = headers.parts();
            assert_eq!(pseudo.status(), None);
            assert_eq!(pseudo.scheme().unwrap(), "http");
            assert_eq!(pseudo.method().unwrap(), "GET");
            assert_eq!(pseudo.authority().unwrap(), "127.0.0.1:0");
            assert_eq!(pseudo.path().unwrap(), "/data");
            // Connection-specific fields are illegal in HTTP/2; `te: trailers` is
            // the one permitted exception.
            assert!(fields.get("connection").is_none());
            assert!(fields.get("te").is_some());
            // `host` is carried as the `:authority` pseudo-header instead.
            assert!(fields.get("host").is_none());
        } else {
            panic!("Unexpected frame type")
        }
    }

    /// UT to ensure that the response body (DATA frames) can be read to the end normally.
    ///
    /// # Brief
    /// 1. Creates three data frames: one larger than the buffer, one smaller than
    ///    the buffer, and a last one equal to the buffer size that carries END_STREAM.
    /// 2. Reads the response body from TextIo using a 10-byte buffer.
    /// 3. Reads until a read returns no bytes.
    /// 4. Checks that the total number of bytes read equals the bytes sent (140).
    #[cfg(feature = "ylong_base")]
    #[test]
    fn ut_http2_body_poll_read() {
        use std::net::{IpAddr, Ipv4Addr, SocketAddr};
        use std::pin::Pin;
        use std::sync::atomic::AtomicBool;
        use std::sync::Arc;

        use ylong_http::h2::{Data, Frame, FrameFlags};
        use ylong_runtime::futures::poll_fn;
        use ylong_runtime::io::{AsyncRead, ReadBuf};

        use crate::async_impl::conn::http2::TextIo;
        use crate::util::dispatcher::http2::Http2Conn;
        use crate::{ConnDetail, ConnProtocol};

        let (resp_tx, resp_rx) = ylong_runtime::sync::mpsc::bounded_channel(20);
        let (req_tx, _req_rx) = crate::runtime::unbounded_channel();
        let shutdown = Arc::new(AtomicBool::new(false));
        let detail = ConnDetail {
            protocol: ConnProtocol::Tcp,
            local: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080),
            peer: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 443),
            addr: "localhost".to_string(),
        };
        let mut conn: Http2Conn<()> = Http2Conn::new(20, shutdown, req_tx, detail);
        conn.receiver.set_receiver(resp_rx);
        let mut text_io = TextIo::new(conn);
        let data_1 = Frame::new(
            1,
            FrameFlags::new(0),
            Payload::Data(Data::new(vec![b'a'; 128])),
        );
        let data_2 = Frame::new(
            1,
            FrameFlags::new(0),
            Payload::Data(Data::new(vec![b'a'; 2])),
        );
        let data_3 = Frame::new(
            1,
            FrameFlags::new(1),
            Payload::Data(Data::new(vec![b'a'; 10])),
        );

        ylong_runtime::block_on(async {
            let _ = resp_tx
                .send(crate::util::dispatcher::http2::RespMessage::Output(data_1))
                .await;
            let _ = resp_tx
                .send(crate::util::dispatcher::http2::RespMessage::Output(data_2))
                .await;
            let _ = resp_tx
                .send(crate::util::dispatcher::http2::RespMessage::Output(data_3))
                .await;
        });

        ylong_runtime::block_on(async {
            let mut buf = [0_u8; 10];
            let mut output_vec = vec![];

            let mut size = buf.len();
            // The `output_vec.len() < 1024` bound guarantees the loop terminates
            // even if reads misbehave.
            while size != 0 && output_vec.len() < 1024 {
                let mut buffer = ReadBuf::new(buf.as_mut_slice());
                poll_fn(|cx| Pin::new(&mut text_io).poll_read(cx, &mut buffer))
                    .await
                    .unwrap();
                size = buffer.filled_len();
                output_vec.extend_from_slice(&buf[..size]);
            }
            assert_eq!(output_vec.len(), 140);
        })
    }
}
593