• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 use std::cmp::min;
15 use std::future::Future;
16 use std::pin::Pin;
17 use std::task::{Context, Poll};
18 
19 use ylong_http::body::async_impl::Body;
20 use ylong_http::error::HttpError;
21 use ylong_http::h2;
22 use ylong_http::h2::{ErrorCode, Frame, FrameFlags, H2Error, Payload, PseudoHeaders};
23 use ylong_http::headers::Headers;
24 use ylong_http::request::uri::Scheme;
25 use ylong_http::request::{Request, RequestPart};
26 use ylong_http::response::status::StatusCode;
27 use ylong_http::response::{Response, ResponsePart};
28 
29 use crate::async_impl::client::Retryable;
30 use crate::async_impl::conn::HttpBody;
31 use crate::async_impl::StreamData;
32 use crate::error::{ErrorKind, HttpClientError};
33 use crate::util::dispatcher::http2::Http2Conn;
34 use crate::{AsyncRead, AsyncWrite, ReadBuf};
35 
// Raw flag byte with no bits set; passed to `FrameFlags::new` as the starting
// value before individual flags are switched on.
const UNUSED_FLAG: u8 = 0x0;
37 
request<S, T>( mut conn: Http2Conn<S>, request: &mut Request<T>, retryable: &mut Retryable, ) -> Result<Response<HttpBody>, HttpClientError> where T: Body, S: AsyncRead + AsyncWrite + Sync + Send + Unpin + 'static,38 pub(crate) async fn request<S, T>(
39     mut conn: Http2Conn<S>,
40     request: &mut Request<T>,
41     retryable: &mut Retryable,
42 ) -> Result<Response<HttpBody>, HttpClientError>
43 where
44     T: Body,
45     S: AsyncRead + AsyncWrite + Sync + Send + Unpin + 'static,
46 {
47     let part = request.part().clone();
48     let body = request.body_mut();
49 
50     // TODO Due to the reason of the Body structure, the use of the trailer is not
51     // implemented here for the time being, and it needs to be completed after the
52     // Body trait is provided to obtain the trailer interface
53     match build_data_frame(conn.id as usize, body).await? {
54         None => {
55             let headers = build_headers_frame(conn.id, part, true)
56                 .map_err(|e| HttpClientError::new_with_cause(ErrorKind::Request, Some(e)))?;
57             conn.send_frame_to_controller(headers).map_err(|e| {
58                 retryable.set_retry(true);
59                 HttpClientError::new_with_cause(ErrorKind::Request, Some(e))
60             })?;
61         }
62         Some(data) => {
63             let headers = build_headers_frame(conn.id, part, false)
64                 .map_err(|e| HttpClientError::new_with_cause(ErrorKind::Request, Some(e)))?;
65             conn.send_frame_to_controller(headers).map_err(|e| {
66                 retryable.set_retry(true);
67                 HttpClientError::new_with_cause(ErrorKind::Request, Some(e))
68             })?;
69             conn.send_frame_to_controller(data).map_err(|e| {
70                 retryable.set_retry(true);
71                 HttpClientError::new_with_cause(ErrorKind::Request, Some(e))
72             })?;
73         }
74     }
75     let frame = Pin::new(&mut conn.stream_info)
76         .await
77         .map_err(|e| HttpClientError::new_with_cause(ErrorKind::Request, Some(e)))?;
78     frame_2_response(conn, frame, retryable)
79 }
80 
frame_2_response<S>( conn: Http2Conn<S>, headers_frame: Frame, retryable: &mut Retryable, ) -> Result<Response<HttpBody>, HttpClientError> where S: AsyncRead + AsyncWrite + Sync + Send + Unpin + 'static,81 fn frame_2_response<S>(
82     conn: Http2Conn<S>,
83     headers_frame: Frame,
84     retryable: &mut Retryable,
85 ) -> Result<Response<HttpBody>, HttpClientError>
86 where
87     S: AsyncRead + AsyncWrite + Sync + Send + Unpin + 'static,
88 {
89     let part = match headers_frame.payload() {
90         Payload::Headers(headers) => {
91             let (pseudo, fields) = headers.parts();
92             let status_code = match pseudo.status() {
93                 Some(status) => StatusCode::from_bytes(status.as_bytes())
94                     .map_err(|e| HttpClientError::new_with_cause(ErrorKind::Request, Some(e)))?,
95                 None => {
96                     return Err(HttpClientError::new_with_cause(
97                         ErrorKind::Request,
98                         Some(HttpError::from(H2Error::StreamError(
99                             conn.id,
100                             ErrorCode::ProtocolError,
101                         ))),
102                     ));
103                 }
104             };
105             ResponsePart {
106                 version: ylong_http::version::Version::HTTP2,
107                 status: status_code,
108                 headers: fields.clone(),
109             }
110         }
111         Payload::RstStream(reset) => {
112             return Err(HttpClientError::new_with_cause(
113                 ErrorKind::Request,
114                 Some(HttpError::from(reset.error(conn.id).map_err(|e| {
115                     HttpClientError::new_with_cause(ErrorKind::Request, Some(e))
116                 })?)),
117             ));
118         }
119         Payload::Goaway(_) => {
120             // return Err(HttpClientError::from(ErrorKind::Resend));
121             retryable.set_retry(true);
122             return Err(HttpClientError::new_with_message(
123                 ErrorKind::Request,
124                 "GoAway",
125             ));
126         }
127         _ => {
128             return Err(HttpClientError::new_with_cause(
129                 ErrorKind::Request,
130                 Some(HttpError::from(H2Error::StreamError(
131                     conn.id,
132                     ErrorCode::ProtocolError,
133                 ))),
134             ));
135         }
136     };
137 
138     let body = {
139         if headers_frame.flags().is_end_stream() {
140             HttpBody::empty()
141         } else {
142             // TODO Can Content-Length in h2 be null?
143             let content_length = part
144                 .headers
145                 .get("Content-Length")
146                 .map(|v| v.to_str().unwrap_or(String::new()))
147                 .and_then(|s| s.parse::<usize>().ok());
148             match content_length {
149                 None => HttpBody::empty(),
150                 Some(0) => HttpBody::empty(),
151                 Some(size) => {
152                     let text_io = TextIo::new(conn);
153                     HttpBody::text(size, &[0u8; 0], Box::new(text_io))
154                 }
155             }
156         }
157     };
158     Ok(Response::from_raw_parts(part, body))
159 }
160 
build_data_frame<T: Body>( id: usize, body: &mut T, ) -> Result<Option<Frame>, HttpClientError>161 pub(crate) async fn build_data_frame<T: Body>(
162     id: usize,
163     body: &mut T,
164 ) -> Result<Option<Frame>, HttpClientError> {
165     let mut data_vec = vec![];
166     let mut buf = [0u8; 1024];
167     loop {
168         let size = body
169             .data(&mut buf)
170             .await
171             .map_err(|e| HttpClientError::new_with_cause(ErrorKind::Request, Some(e)))?;
172         if size == 0 {
173             break;
174         }
175         data_vec.extend_from_slice(&buf[..size]);
176     }
177     if data_vec.is_empty() {
178         Ok(None)
179     } else {
180         // TODO When the Body trait supports trailer, END_STREAM_FLAG needs to be
181         // modified
182         let mut flag = FrameFlags::new(UNUSED_FLAG);
183         flag.set_end_stream(true);
184         Ok(Some(Frame::new(
185             id,
186             flag,
187             Payload::Data(h2::Data::new(data_vec)),
188         )))
189     }
190 }
191 
build_headers_frame( id: u32, part: RequestPart, is_end_stream: bool, ) -> Result<Frame, HttpError>192 pub(crate) fn build_headers_frame(
193     id: u32,
194     part: RequestPart,
195     is_end_stream: bool,
196 ) -> Result<Frame, HttpError> {
197     check_connection_specific_headers(id, &part.headers)?;
198     let pseudo = build_pseudo_headers(&part);
199     let mut header_part = h2::Parts::new();
200     header_part.set_header_lines(part.headers);
201     header_part.set_pseudo(pseudo);
202     let headers_payload = h2::Headers::new(header_part);
203 
204     let mut flag = FrameFlags::new(UNUSED_FLAG);
205     flag.set_end_headers(true);
206     if is_end_stream {
207         flag.set_end_stream(true);
208     }
209     Ok(Frame::new(
210         id as usize,
211         flag,
212         Payload::Headers(headers_payload),
213     ))
214 }
215 
216 // Illegal headers validation in http2.
217 // [`Connection-Specific Headers`] implementation.
218 //
219 // [`Connection-Specific Headers`]: https://www.rfc-editor.org/rfc/rfc9113.html#name-connection-specific-header-
check_connection_specific_headers(id: u32, headers: &Headers) -> Result<(), HttpError>220 fn check_connection_specific_headers(id: u32, headers: &Headers) -> Result<(), HttpError> {
221     const CONNECTION_SPECIFIC_HEADERS: &[&str; 5] = &[
222         "connection",
223         "keep-alive",
224         "proxy-connection",
225         "upgrade",
226         "transfer-encoding",
227     ];
228     for specific_header in CONNECTION_SPECIFIC_HEADERS.iter() {
229         if headers.get(*specific_header).is_some() {
230             return Err(H2Error::StreamError(id, ErrorCode::ProtocolError).into());
231         }
232     }
233     if let Some(te_value) = headers.get("te") {
234         if te_value.to_str()? != "trailers" {
235             return Err(H2Error::StreamError(id, ErrorCode::ProtocolError).into());
236         }
237     }
238     Ok(())
239 }
240 
build_pseudo_headers(request_part: &RequestPart) -> PseudoHeaders241 fn build_pseudo_headers(request_part: &RequestPart) -> PseudoHeaders {
242     let mut pseudo = PseudoHeaders::default();
243     match request_part.uri.scheme() {
244         Some(scheme) => {
245             pseudo.set_scheme(Some(String::from(scheme.as_str())));
246         }
247         None => pseudo.set_scheme(Some(String::from(Scheme::HTTP.as_str()))),
248     }
249     pseudo.set_method(Some(String::from(request_part.method.as_str())));
250     pseudo.set_path(
251         request_part
252             .uri
253             .path_and_query()
254             .or_else(|| Some(String::from("/"))),
255     );
256     // TODO Validity verification is required, for example: `Authority` must be
257     // consistent with the `Host` header
258     pseudo.set_authority(request_part.uri.authority().map(|auth| auth.to_string()));
259     pseudo
260 }
261 
// Reader over the frames of one HTTP/2 response stream; its `AsyncRead`
// implementation copies DATA frame payloads into the caller's buffer.
struct TextIo<S> {
    // Connection handle whose `stream_info` is polled for incoming frames.
    pub(crate) handle: Http2Conn<S>,
    // Read position inside the DATA frame kept in `remain`.
    pub(crate) offset: usize,
    // Frame kept between reads: a DATA frame that did not fit into the
    // caller's buffer, or a HEADERS frame received after the body.
    pub(crate) remain: Option<Frame>,
    // Set once a DATA frame carrying END_STREAM has been read completely;
    // later reads return immediately without filling anything.
    pub(crate) is_closed: bool,
}
268 
impl<S> TextIo<S> {
    // Creates a reader over `handle` with no stashed frame, a zero offset and
    // the closed flag cleared.
    pub(crate) fn new(handle: Http2Conn<S>) -> Self {
        Self {
            handle,
            offset: 0,
            remain: None,
            is_closed: false,
        }
    }
}
279 
impl<S: AsyncRead + AsyncWrite + Unpin + Sync + Send + 'static> StreamData for TextIo<S> {
    // NOTE(review): not implemented yet. `todo!()` panics if this method is
    // ever invoked, so confirm that no caller shuts down an HTTP/2 body
    // through `StreamData` before relying on it.
    fn shutdown(&self) {
        todo!()
    }
}
285 
286 impl<S: AsyncRead + AsyncWrite + Unpin + Sync + Send + 'static> AsyncRead for TextIo<S> {
poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<std::io::Result<()>>287     fn poll_read(
288         self: Pin<&mut Self>,
289         cx: &mut Context<'_>,
290         buf: &mut ReadBuf<'_>,
291     ) -> Poll<std::io::Result<()>> {
292         let text_io = self.get_mut();
293 
294         if buf.remaining() == 0 || text_io.is_closed {
295             return Poll::Ready(Ok(()));
296         }
297         while buf.remaining() != 0 {
298             if let Some(frame) = &text_io.remain {
299                 match frame.payload() {
300                     Payload::Headers(_) => {
301                         break;
302                     }
303                     Payload::Data(data) => {
304                         let data = data.data();
305                         let unfilled_len = buf.remaining();
306                         let data_len = data.len() - text_io.offset;
307                         let fill_len = min(unfilled_len, data_len);
308                         if unfilled_len < data_len {
309                             buf.put_slice(&data[text_io.offset..text_io.offset + fill_len]);
310                             text_io.offset += fill_len;
311                             break;
312                         } else {
313                             buf.put_slice(&data[text_io.offset..text_io.offset + fill_len]);
314                             text_io.offset += fill_len;
315                             if frame.flags().is_end_stream() {
316                                 text_io.is_closed = true;
317                                 break;
318                             }
319                         }
320                     }
321                     _ => {
322                         return Poll::Ready(Err(std::io::Error::new(
323                             std::io::ErrorKind::Other,
324                             HttpError::from(H2Error::ConnectionError(ErrorCode::ProtocolError)),
325                         )))
326                     }
327                 }
328             }
329 
330             let poll_result = Pin::new(&mut text_io.handle.stream_info)
331                 .poll(cx)
332                 .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
333 
334             // TODO Added the frame type.
335             match poll_result {
336                 Poll::Ready(frame) => match frame.payload() {
337                     Payload::Headers(_) => {
338                         text_io.remain = Some(frame);
339                         text_io.offset = 0;
340                         break;
341                     }
342                     Payload::Data(data) => {
343                         let data = data.data();
344                         let unfilled_len = buf.remaining();
345                         let data_len = data.len();
346                         let fill_len = min(data_len, unfilled_len);
347                         if unfilled_len < data_len {
348                             buf.put_slice(&data[..fill_len]);
349                             text_io.offset += fill_len;
350                             text_io.remain = Some(frame);
351                             break;
352                         } else {
353                             buf.put_slice(&data[..fill_len]);
354                             if frame.flags().is_end_stream() {
355                                 text_io.is_closed = true;
356                                 break;
357                             }
358                         }
359                     }
360                     Payload::RstStream(_) => {
361                         return Poll::Ready(Err(std::io::Error::new(
362                             std::io::ErrorKind::Other,
363                             HttpError::from(H2Error::ConnectionError(ErrorCode::ProtocolError)),
364                         )))
365                     }
366                     _ => {
367                         return Poll::Ready(Err(std::io::Error::new(
368                             std::io::ErrorKind::Other,
369                             HttpError::from(H2Error::ConnectionError(ErrorCode::ProtocolError)),
370                         )))
371                     }
372                 },
373                 Poll::Pending => {
374                     return Poll::Pending;
375                 }
376             }
377         }
378         Poll::Ready(Ok(()))
379     }
380 }
381 
#[cfg(feature = "http2")]
#[cfg(test)]
mod ut_http2 {
    use ylong_http::body::TextBody;
    use ylong_http::h2::{ErrorCode, H2Error, Payload};
    use ylong_http::request::RequestBuilder;

    use crate::async_impl::conn::http2::{build_data_frame, build_headers_frame};

    // Builds a `Request<TextBody<..>>` from a declarative description:
    // method, uri, version, any number of `Header: name, value,` lines and a
    // string body. Panics if the builder rejects the request.
    macro_rules! build_request {
        (
            Request: {
                Method: $method: expr,
                Uri: $uri:expr,
                Version: $version: expr,
                $(
                    Header: $req_n: expr, $req_v: expr,
                )*
                Body: $req_body: expr,
            }
        ) => {
            RequestBuilder::new()
                .method($method)
                .url($uri)
                .version($version)
                $(.header($req_n, $req_v))*
                .body(TextBody::from_bytes($req_body.as_bytes()))
                .expect("Request build failed")
        }
    }

    // UT for `build_headers_frame`.
    //
    // 1. Builds a request whose only special header is `te: trailers`.
    // 2. Checks the flag bits without (0x4) and with (0x5) END_STREAM.
    // 3. Checks the stream id and the pseudo headers derived from the URI.
    // 4. Builds a request with an `upgrade` header and checks that it is
    //    rejected with a stream-level protocol error.
    #[test]
    fn ut_http2_build_headers_frame() {
        let request = build_request!(
            Request: {
            Method: "GET",
            Uri: "http://127.0.0.1:0/data",
            Version: "HTTP/2.0",
            Header: "te", "trailers",
            Header: "host", "127.0.0.1:0",
            Body: "Hi",
        }
        );
        let frame = build_headers_frame(1, request.part().clone(), false)
            .expect("headers frame build failed");
        assert_eq!(frame.flags().bits(), 0x4);
        let frame = build_headers_frame(1, request.part().clone(), true)
            .expect("headers frame build failed");
        assert_eq!(frame.stream_id(), 1);
        assert_eq!(frame.flags().bits(), 0x5);
        if let Payload::Headers(headers) = frame.payload() {
            let (pseudo, _headers) = headers.parts();
            // `:status` only exists on responses.
            assert_eq!(pseudo.status(), None);
            assert_eq!(pseudo.scheme().unwrap(), "http");
            assert_eq!(pseudo.method().unwrap(), "GET");
            assert_eq!(pseudo.authority().unwrap(), "127.0.0.1:0");
            assert_eq!(pseudo.path().unwrap(), "/data")
        } else {
            panic!("Unexpected frame type")
        }
        let request = build_request!(
            Request: {
            Method: "GET",
            Uri: "http://127.0.0.1:0/data",
            Version: "HTTP/2.0",
            Header: "upgrade", "h2",
            Header: "host", "127.0.0.1:0",
            Body: "Hi",
        }
        );
        let frame = build_headers_frame(1, request.part().clone(), true);
        assert_eq!(
            frame.err(),
            Some(H2Error::StreamError(1, ErrorCode::ProtocolError).into())
        );
    }
}
459