• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 use std::cmp::min;
15 use std::mem::take;
16 use std::ops::Deref;
17 use std::pin::Pin;
18 use std::sync::atomic::Ordering;
19 use std::task::{Context, Poll};
20 use std::time::Instant;
21 
22 use ylong_http::error::HttpError;
23 use ylong_http::h2;
24 use ylong_http::h2::{ErrorCode, Frame, FrameFlags, H2Error, Payload, PseudoHeaders, StreamId};
25 use ylong_http::headers::Headers;
26 use ylong_http::request::uri::Scheme;
27 use ylong_http::request::RequestPart;
28 use ylong_http::response::status::StatusCode;
29 use ylong_http::response::ResponsePart;
30 
31 use crate::async_impl::conn::StreamData;
32 use crate::async_impl::request::Message;
33 use crate::async_impl::{HttpBody, Response};
34 use crate::error::{ErrorKind, HttpClientError};
35 use crate::runtime::{AsyncRead, ReadBuf};
36 use crate::util::data_ref::BodyDataRef;
37 use crate::util::dispatcher::http2::Http2Conn;
38 use crate::util::h2::RequestWrapper;
39 use crate::util::normalizer::BodyLengthParser;
40 
/// Raw flag byte with no bit set; used as the starting value when building
/// the `FrameFlags` of an outgoing HEADERS frame.
const UNUSED_FLAG: u8 = 0;
42 
request<S>( mut conn: Http2Conn<S>, mut message: Message, ) -> Result<Response, HttpClientError> where S: Sync + Send + Unpin + 'static,43 pub(crate) async fn request<S>(
44     mut conn: Http2Conn<S>,
45     mut message: Message,
46 ) -> Result<Response, HttpClientError>
47 where
48     S: Sync + Send + Unpin + 'static,
49 {
50     message
51         .interceptor
52         .intercept_request(message.request.ref_mut())?;
53     let part = message.request.ref_mut().part().clone();
54 
55     // TODO Implement trailer.
56     let is_end_stream = message.request.ref_mut().body().is_empty();
57     let (flag, payload) = build_headers_payload(part, is_end_stream)
58         .map_err(|e| HttpClientError::from_error(ErrorKind::Request, e))?;
59     let data = BodyDataRef::new(message.request.clone());
60     let stream = RequestWrapper {
61         flag,
62         payload,
63         data,
64     };
65     message
66         .request
67         .ref_mut()
68         .time_group_mut()
69         .set_transfer_start(Instant::now());
70     conn.send_frame_to_controller(stream)?;
71     let frame = conn.receiver.recv().await?;
72     message
73         .request
74         .ref_mut()
75         .time_group_mut()
76         .set_transfer_end(Instant::now());
77     frame_2_response(conn, frame, message)
78 }
79 
frame_2_response<S>( conn: Http2Conn<S>, headers_frame: Frame, mut message: Message, ) -> Result<Response, HttpClientError> where S: Sync + Send + Unpin + 'static,80 fn frame_2_response<S>(
81     conn: Http2Conn<S>,
82     headers_frame: Frame,
83     mut message: Message,
84 ) -> Result<Response, HttpClientError>
85 where
86     S: Sync + Send + Unpin + 'static,
87 {
88     let part = match headers_frame.payload() {
89         Payload::Headers(headers) => {
90             let (pseudo, fields) = headers.parts();
91             let status_code = match pseudo.status() {
92                 Some(status) => StatusCode::from_bytes(status.as_bytes())
93                     .map_err(|e| HttpClientError::from_error(ErrorKind::Request, e))?,
94                 None => {
95                     return Err(build_client_error(
96                         headers_frame.stream_id(),
97                         ErrorCode::ProtocolError,
98                     ));
99                 }
100             };
101             ResponsePart {
102                 version: ylong_http::version::Version::HTTP2,
103                 status: status_code,
104                 headers: fields.clone(),
105             }
106         }
107         Payload::RstStream(reset) => {
108             return Err(build_client_error(
109                 headers_frame.stream_id(),
110                 ErrorCode::try_from(reset.error_code()).unwrap_or(ErrorCode::ProtocolError),
111             ));
112         }
113         _ => {
114             return Err(build_client_error(
115                 headers_frame.stream_id(),
116                 ErrorCode::ProtocolError,
117             ));
118         }
119     };
120 
121     let text_io = TextIo::new(conn);
122     let length = match BodyLengthParser::new(message.request.ref_mut().method(), &part).parse() {
123         Ok(length) => length,
124         Err(e) => {
125             return Err(e);
126         }
127     };
128     let time_group = take(message.request.ref_mut().time_group_mut());
129     let body = HttpBody::new(message.interceptor, length, Box::new(text_io), &[0u8; 0])?;
130 
131     let mut response = Response::new(ylong_http::response::Response::from_raw_parts(part, body));
132     response.set_time_group(time_group);
133     Ok(response)
134 }
135 
build_headers_payload( mut part: RequestPart, is_end_stream: bool, ) -> Result<(FrameFlags, Payload), HttpError>136 pub(crate) fn build_headers_payload(
137     mut part: RequestPart,
138     is_end_stream: bool,
139 ) -> Result<(FrameFlags, Payload), HttpError> {
140     remove_connection_specific_headers(&mut part.headers)?;
141     let pseudo = build_pseudo_headers(&mut part)?;
142     let mut header_part = h2::Parts::new();
143     header_part.set_header_lines(part.headers);
144     header_part.set_pseudo(pseudo);
145     let headers_payload = h2::Headers::new(header_part);
146 
147     let mut flag = FrameFlags::new(UNUSED_FLAG);
148     flag.set_end_headers(true);
149     if is_end_stream {
150         flag.set_end_stream(true);
151     }
152     Ok((flag, Payload::Headers(headers_payload)))
153 }
154 
155 // Illegal headers validation in http2.
156 // [`Connection-Specific Headers`] implementation.
157 //
158 // [`Connection-Specific Headers`]: https://www.rfc-editor.org/rfc/rfc9113.html#name-connection-specific-header-
remove_connection_specific_headers(headers: &mut Headers) -> Result<(), HttpError>159 fn remove_connection_specific_headers(headers: &mut Headers) -> Result<(), HttpError> {
160     const CONNECTION_SPECIFIC_HEADERS: &[&str; 5] = &[
161         "connection",
162         "keep-alive",
163         "proxy-connection",
164         "upgrade",
165         "transfer-encoding",
166     ];
167     for specific_header in CONNECTION_SPECIFIC_HEADERS.iter() {
168         headers.remove(*specific_header);
169     }
170 
171     if let Some(te_ref) = headers.get("te") {
172         let te = te_ref.to_string()?;
173         if te.as_str() != "trailers" {
174             headers.remove("te");
175         }
176     }
177     Ok(())
178 }
179 
build_pseudo_headers(request_part: &mut RequestPart) -> Result<PseudoHeaders, HttpError>180 fn build_pseudo_headers(request_part: &mut RequestPart) -> Result<PseudoHeaders, HttpError> {
181     let mut pseudo = PseudoHeaders::default();
182     match request_part.uri.scheme() {
183         Some(scheme) => {
184             pseudo.set_scheme(Some(String::from(scheme.as_str())));
185         }
186         None => pseudo.set_scheme(Some(String::from(Scheme::HTTP.as_str()))),
187     }
188     pseudo.set_method(Some(String::from(request_part.method.as_str())));
189     pseudo.set_path(
190         request_part
191             .uri
192             .path_and_query()
193             .or_else(|| Some(String::from("/"))),
194     );
195     let host = request_part
196         .headers
197         .remove("host")
198         .and_then(|auth| auth.to_string().ok());
199     pseudo.set_authority(host);
200     Ok(pseudo)
201 }
202 
build_client_error(id: StreamId, code: ErrorCode) -> HttpClientError203 fn build_client_error(id: StreamId, code: ErrorCode) -> HttpClientError {
204     HttpClientError::from_error(
205         ErrorKind::Request,
206         HttpError::from(H2Error::StreamError(id, code)),
207     )
208 }
209 
210 struct TextIo<S> {
211     pub(crate) handle: Http2Conn<S>,
212     pub(crate) offset: usize,
213     pub(crate) remain: Option<Frame>,
214     pub(crate) is_closed: bool,
215 }
216 
217 struct HttpReadBuf<'a, 'b> {
218     buf: &'a mut ReadBuf<'b>,
219 }
220 
221 impl<'a, 'b> HttpReadBuf<'a, 'b> {
append_slice(&mut self, buf: &[u8])222     pub(crate) fn append_slice(&mut self, buf: &[u8]) {
223         #[cfg(feature = "ylong_base")]
224         self.buf.append(buf);
225 
226         #[cfg(feature = "tokio_base")]
227         self.buf.put_slice(buf);
228     }
229 }
230 
231 impl<'a, 'b> Deref for HttpReadBuf<'a, 'b> {
232     type Target = ReadBuf<'b>;
233 
deref(&self) -> &Self::Target234     fn deref(&self) -> &Self::Target {
235         self.buf
236     }
237 }
238 
239 impl<S> TextIo<S>
240 where
241     S: Sync + Send + Unpin + 'static,
242 {
new(handle: Http2Conn<S>) -> Self243     pub(crate) fn new(handle: Http2Conn<S>) -> Self {
244         Self {
245             handle,
246             offset: 0,
247             remain: None,
248             is_closed: false,
249         }
250     }
251 
match_channel_message( poll_result: Poll<Frame>, text_io: &mut TextIo<S>, buf: &mut HttpReadBuf, ) -> Option<Poll<std::io::Result<()>>>252     fn match_channel_message(
253         poll_result: Poll<Frame>,
254         text_io: &mut TextIo<S>,
255         buf: &mut HttpReadBuf,
256     ) -> Option<Poll<std::io::Result<()>>> {
257         match poll_result {
258             Poll::Ready(frame) => match frame.payload() {
259                 Payload::Headers(_) => {
260                     text_io.remain = Some(frame);
261                     text_io.offset = 0;
262                     Some(Poll::Ready(Ok(())))
263                 }
264                 Payload::Data(data) => {
265                     let data = data.data();
266                     let unfilled_len = buf.remaining();
267                     let data_len = data.len();
268                     let fill_len = min(data_len, unfilled_len);
269                     if unfilled_len < data_len {
270                         buf.append_slice(&data[..fill_len]);
271                         text_io.offset += fill_len;
272                         text_io.remain = Some(frame);
273                         Some(Poll::Ready(Ok(())))
274                     } else {
275                         buf.append_slice(&data[..fill_len]);
276                         Self::end_read(text_io, frame.flags().is_end_stream(), data_len)
277                     }
278                 }
279                 Payload::RstStream(reset) => {
280                     if reset.is_no_error() {
281                         text_io.is_closed = true;
282                         Some(Poll::Ready(Ok(())))
283                     } else {
284                         Some(Poll::Ready(Err(std::io::Error::new(
285                             std::io::ErrorKind::Other,
286                             HttpError::from(H2Error::ConnectionError(ErrorCode::ProtocolError)),
287                         ))))
288                     }
289                 }
290                 _ => Some(Poll::Ready(Err(std::io::Error::new(
291                     std::io::ErrorKind::Other,
292                     HttpError::from(H2Error::ConnectionError(ErrorCode::ProtocolError)),
293                 )))),
294             },
295             Poll::Pending => Some(Poll::Pending),
296         }
297     }
298 
end_read( text_io: &mut TextIo<S>, end_stream: bool, data_len: usize, ) -> Option<Poll<std::io::Result<()>>>299     fn end_read(
300         text_io: &mut TextIo<S>,
301         end_stream: bool,
302         data_len: usize,
303     ) -> Option<Poll<std::io::Result<()>>> {
304         text_io.offset = 0;
305         text_io.remain = None;
306         if end_stream {
307             text_io.is_closed = true;
308             Some(Poll::Ready(Ok(())))
309         } else if data_len == 0 {
310             // no data read and is not end stream.
311             None
312         } else {
313             Some(Poll::Ready(Ok(())))
314         }
315     }
316 
read_remaining_data( text_io: &mut TextIo<S>, buf: &mut HttpReadBuf, ) -> Option<Poll<std::io::Result<()>>>317     fn read_remaining_data(
318         text_io: &mut TextIo<S>,
319         buf: &mut HttpReadBuf,
320     ) -> Option<Poll<std::io::Result<()>>> {
321         if let Some(frame) = &text_io.remain {
322             return match frame.payload() {
323                 Payload::Headers(_) => Some(Poll::Ready(Ok(()))),
324                 Payload::Data(data) => {
325                     let data = data.data();
326                     let unfilled_len = buf.remaining();
327                     let data_len = data.len() - text_io.offset;
328                     let fill_len = min(unfilled_len, data_len);
329                     // The peripheral function already ensures that the remaing of buf will not be
330                     // 0.
331                     if unfilled_len < data_len {
332                         buf.append_slice(&data[text_io.offset..text_io.offset + fill_len]);
333                         text_io.offset += fill_len;
334                         Some(Poll::Ready(Ok(())))
335                     } else {
336                         buf.append_slice(&data[text_io.offset..text_io.offset + fill_len]);
337                         Self::end_read(text_io, frame.flags().is_end_stream(), data_len)
338                     }
339                 }
340                 _ => Some(Poll::Ready(Err(std::io::Error::new(
341                     std::io::ErrorKind::Other,
342                     HttpError::from(H2Error::ConnectionError(ErrorCode::ProtocolError)),
343                 )))),
344             };
345         }
346         None
347     }
348 }
349 
350 impl<S: Sync + Send + Unpin + 'static> StreamData for TextIo<S> {
shutdown(&self)351     fn shutdown(&self) {
352         self.handle.io_shutdown.store(true, Ordering::Release);
353     }
354 
is_stream_closable(&self) -> bool355     fn is_stream_closable(&self) -> bool {
356         self.is_closed
357     }
358 }
359 
360 impl<S: Sync + Send + Unpin + 'static> AsyncRead for TextIo<S> {
poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<std::io::Result<()>>361     fn poll_read(
362         self: Pin<&mut Self>,
363         cx: &mut Context<'_>,
364         buf: &mut ReadBuf<'_>,
365     ) -> Poll<std::io::Result<()>> {
366         let text_io = self.get_mut();
367         let mut buf = HttpReadBuf { buf };
368 
369         if buf.remaining() == 0 || text_io.is_closed {
370             return Poll::Ready(Ok(()));
371         }
372         while buf.remaining() != 0 {
373             if let Some(result) = Self::read_remaining_data(text_io, &mut buf) {
374                 return result;
375             }
376 
377             let poll_result = text_io
378                 .handle
379                 .receiver
380                 .poll_recv(cx)
381                 .map_err(|_e| std::io::Error::from(std::io::ErrorKind::Other))?;
382 
383             if let Some(result) = Self::match_channel_message(poll_result, text_io, &mut buf) {
384                 return result;
385             }
386         }
387         Poll::Ready(Ok(()))
388     }
389 }
390 
#[cfg(feature = "http2")]
#[cfg(test)]
mod ut_http2 {
    use ylong_http::body::TextBody;
    use ylong_http::h2::Payload;
    use ylong_http::request::RequestBuilder;

    use crate::async_impl::conn::http2::build_headers_payload;

    // Builds a request with a text body from a declarative description.
    macro_rules! build_request {
        (
            Request: {
                Method: $method: expr,
                Uri: $uri:expr,
                Version: $version: expr,
                $(
                    Header: $req_n: expr, $req_v: expr,
                )*
                Body: $req_body: expr,
            }
        ) => {
            RequestBuilder::new()
                .method($method)
                .url($uri)
                .version($version)
                $(.header($req_n, $req_v))*
                .body(TextBody::from_bytes($req_body.as_bytes()))
                .expect("Request build failed")
        }
    }

    /// UT for `build_headers_payload`.
    ///
    /// # Brief
    /// 1. Builds a GET request carrying `te` and `host` headers.
    /// 2. Builds the HEADERS payload without and with END_STREAM.
    /// 3. Checks the flag bits and the generated pseudo headers.
    #[test]
    fn ut_http2_build_headers_payload() {
        let request = build_request!(
            Request: {
            Method: "GET",
            Uri: "http://127.0.0.1:0/data",
            Version: "HTTP/2.0",
            Header: "te", "trailers",
            Header: "host", "127.0.0.1:0",
            Body: "Hi",
        }
        );

        // END_HEADERS only.
        let (flag, _) = build_headers_payload(request.part().clone(), false).unwrap();
        assert_eq!(flag.bits(), 0x4);

        // END_HEADERS | END_STREAM.
        let (flag, payload) = build_headers_payload(request.part().clone(), true).unwrap();
        assert_eq!(flag.bits(), 0x5);

        match payload {
            Payload::Headers(headers) => {
                let (pseudo, _headers) = headers.parts();
                assert_eq!(pseudo.status(), None);
                assert_eq!(pseudo.scheme().unwrap(), "http");
                assert_eq!(pseudo.method().unwrap(), "GET");
                assert_eq!(pseudo.authority().unwrap(), "127.0.0.1:0");
                assert_eq!(pseudo.path().unwrap(), "/data");
            }
            _ => panic!("Unexpected frame type"),
        }
    }

    /// UT for ensure that the response body(data frame) can read ends normally.
    ///
    /// # Brief
    /// 1. Creates three data frames, one greater than buf, one less than buf,
    ///    and the last one equal to and finished with buf.
    /// 2. The response body data is read from TextIo using a buf of 10 bytes.
    /// 3. The body is all read, and the size is the same as the default.
    /// 4. Checks that result.
    #[cfg(feature = "ylong_base")]
    #[test]
    fn ut_http2_body_poll_read() {
        use std::net::{IpAddr, Ipv4Addr, SocketAddr};
        use std::pin::Pin;
        use std::sync::atomic::AtomicBool;
        use std::sync::Arc;

        use ylong_http::h2::{Data, Frame, FrameFlags};
        use ylong_runtime::futures::poll_fn;
        use ylong_runtime::io::{AsyncRead, ReadBuf};

        use crate::async_impl::conn::http2::TextIo;
        use crate::util::dispatcher::http2::{Http2Conn, RespMessage};
        use crate::{ConnDetail, ConnProtocol};

        let (resp_tx, resp_rx) = ylong_runtime::sync::mpsc::bounded_channel(20);
        let (req_tx, _req_rx) = crate::runtime::unbounded_channel();
        let shutdown = Arc::new(AtomicBool::new(false));
        let loopback = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
        let detail = ConnDetail {
            protocol: ConnProtocol::Tcp,
            local: SocketAddr::new(loopback, 8080),
            peer: SocketAddr::new(loopback, 443),
            addr: "localhost".to_string(),
        };
        let mut conn: Http2Conn<()> = Http2Conn::new(20, shutdown, req_tx, detail);
        conn.receiver.set_receiver(resp_rx);
        let mut text_io = TextIo::new(conn);

        // Builds a DATA frame on stream 1 holding `len` bytes of `b'a'`.
        let data_frame = |flags: u8, len: usize| {
            Frame::new(
                1,
                FrameFlags::new(flags),
                Payload::Data(Data::new(vec![b'a'; len])),
            )
        };
        // Larger than the read buffer, smaller than it, and exactly its size
        // with END_STREAM (flag bit 0x1) set.
        let frames = vec![data_frame(0, 128), data_frame(0, 2), data_frame(1, 10)];

        ylong_runtime::block_on(async {
            for frame in frames {
                let _ = resp_tx.send(RespMessage::Output(frame)).await;
            }
        });

        ylong_runtime::block_on(async {
            let mut buf = [0_u8; 10];
            let mut output_vec = vec![];

            loop {
                let mut buffer = ReadBuf::new(buf.as_mut_slice());
                poll_fn(|cx| Pin::new(&mut text_io).poll_read(cx, &mut buffer))
                    .await
                    .unwrap();
                let size = buffer.filled_len();
                output_vec.extend_from_slice(&buf[..size]);
                // `output_vec < 1024` in order to be able to exit normally in
                // case of an exception.
                if size == 0 || output_vec.len() >= 1024 {
                    break;
                }
            }
            assert_eq!(output_vec.len(), 140);
        })
    }
}
533