1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 // http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13
14 use std::cmp::min;
15 use std::future::Future;
16 use std::pin::Pin;
17 use std::task::{Context, Poll};
18
19 use ylong_http::body::async_impl::Body;
20 use ylong_http::error::HttpError;
21 use ylong_http::h2;
22 use ylong_http::h2::{ErrorCode, Frame, FrameFlags, H2Error, Payload, PseudoHeaders};
23 use ylong_http::headers::Headers;
24 use ylong_http::request::uri::Scheme;
25 use ylong_http::request::{Request, RequestPart};
26 use ylong_http::response::status::StatusCode;
27 use ylong_http::response::{Response, ResponsePart};
28
29 use crate::async_impl::client::Retryable;
30 use crate::async_impl::conn::HttpBody;
31 use crate::async_impl::StreamData;
32 use crate::error::{ErrorKind, HttpClientError};
33 use crate::util::dispatcher::http2::Http2Conn;
34 use crate::{AsyncRead, AsyncWrite, ReadBuf};
35
/// Starting value for every `FrameFlags` built in this module: no flag bit set.
const UNUSED_FLAG: u8 = 0;
37
request<S, T>( mut conn: Http2Conn<S>, request: &mut Request<T>, retryable: &mut Retryable, ) -> Result<Response<HttpBody>, HttpClientError> where T: Body, S: AsyncRead + AsyncWrite + Sync + Send + Unpin + 'static,38 pub(crate) async fn request<S, T>(
39 mut conn: Http2Conn<S>,
40 request: &mut Request<T>,
41 retryable: &mut Retryable,
42 ) -> Result<Response<HttpBody>, HttpClientError>
43 where
44 T: Body,
45 S: AsyncRead + AsyncWrite + Sync + Send + Unpin + 'static,
46 {
47 let part = request.part().clone();
48 let body = request.body_mut();
49
50 // TODO Due to the reason of the Body structure, the use of the trailer is not
51 // implemented here for the time being, and it needs to be completed after the
52 // Body trait is provided to obtain the trailer interface
53 match build_data_frame(conn.id as usize, body).await? {
54 None => {
55 let headers = build_headers_frame(conn.id, part, true)
56 .map_err(|e| HttpClientError::new_with_cause(ErrorKind::Request, Some(e)))?;
57 conn.send_frame_to_controller(headers).map_err(|e| {
58 retryable.set_retry(true);
59 HttpClientError::new_with_cause(ErrorKind::Request, Some(e))
60 })?;
61 }
62 Some(data) => {
63 let headers = build_headers_frame(conn.id, part, false)
64 .map_err(|e| HttpClientError::new_with_cause(ErrorKind::Request, Some(e)))?;
65 conn.send_frame_to_controller(headers).map_err(|e| {
66 retryable.set_retry(true);
67 HttpClientError::new_with_cause(ErrorKind::Request, Some(e))
68 })?;
69 conn.send_frame_to_controller(data).map_err(|e| {
70 retryable.set_retry(true);
71 HttpClientError::new_with_cause(ErrorKind::Request, Some(e))
72 })?;
73 }
74 }
75 let frame = Pin::new(&mut conn.stream_info)
76 .await
77 .map_err(|e| HttpClientError::new_with_cause(ErrorKind::Request, Some(e)))?;
78 frame_2_response(conn, frame, retryable)
79 }
80
frame_2_response<S>( conn: Http2Conn<S>, headers_frame: Frame, retryable: &mut Retryable, ) -> Result<Response<HttpBody>, HttpClientError> where S: AsyncRead + AsyncWrite + Sync + Send + Unpin + 'static,81 fn frame_2_response<S>(
82 conn: Http2Conn<S>,
83 headers_frame: Frame,
84 retryable: &mut Retryable,
85 ) -> Result<Response<HttpBody>, HttpClientError>
86 where
87 S: AsyncRead + AsyncWrite + Sync + Send + Unpin + 'static,
88 {
89 let part = match headers_frame.payload() {
90 Payload::Headers(headers) => {
91 let (pseudo, fields) = headers.parts();
92 let status_code = match pseudo.status() {
93 Some(status) => StatusCode::from_bytes(status.as_bytes())
94 .map_err(|e| HttpClientError::new_with_cause(ErrorKind::Request, Some(e)))?,
95 None => {
96 return Err(HttpClientError::new_with_cause(
97 ErrorKind::Request,
98 Some(HttpError::from(H2Error::StreamError(
99 conn.id,
100 ErrorCode::ProtocolError,
101 ))),
102 ));
103 }
104 };
105 ResponsePart {
106 version: ylong_http::version::Version::HTTP2,
107 status: status_code,
108 headers: fields.clone(),
109 }
110 }
111 Payload::RstStream(reset) => {
112 return Err(HttpClientError::new_with_cause(
113 ErrorKind::Request,
114 Some(HttpError::from(reset.error(conn.id).map_err(|e| {
115 HttpClientError::new_with_cause(ErrorKind::Request, Some(e))
116 })?)),
117 ));
118 }
119 Payload::Goaway(_) => {
120 // return Err(HttpClientError::from(ErrorKind::Resend));
121 retryable.set_retry(true);
122 return Err(HttpClientError::new_with_message(
123 ErrorKind::Request,
124 "GoAway",
125 ));
126 }
127 _ => {
128 return Err(HttpClientError::new_with_cause(
129 ErrorKind::Request,
130 Some(HttpError::from(H2Error::StreamError(
131 conn.id,
132 ErrorCode::ProtocolError,
133 ))),
134 ));
135 }
136 };
137
138 let body = {
139 if headers_frame.flags().is_end_stream() {
140 HttpBody::empty()
141 } else {
142 // TODO Can Content-Length in h2 be null?
143 let content_length = part
144 .headers
145 .get("Content-Length")
146 .map(|v| v.to_str().unwrap_or(String::new()))
147 .and_then(|s| s.parse::<usize>().ok());
148 match content_length {
149 None => HttpBody::empty(),
150 Some(0) => HttpBody::empty(),
151 Some(size) => {
152 let text_io = TextIo::new(conn);
153 HttpBody::text(size, &[0u8; 0], Box::new(text_io))
154 }
155 }
156 }
157 };
158 Ok(Response::from_raw_parts(part, body))
159 }
160
build_data_frame<T: Body>( id: usize, body: &mut T, ) -> Result<Option<Frame>, HttpClientError>161 pub(crate) async fn build_data_frame<T: Body>(
162 id: usize,
163 body: &mut T,
164 ) -> Result<Option<Frame>, HttpClientError> {
165 let mut data_vec = vec![];
166 let mut buf = [0u8; 1024];
167 loop {
168 let size = body
169 .data(&mut buf)
170 .await
171 .map_err(|e| HttpClientError::new_with_cause(ErrorKind::Request, Some(e)))?;
172 if size == 0 {
173 break;
174 }
175 data_vec.extend_from_slice(&buf[..size]);
176 }
177 if data_vec.is_empty() {
178 Ok(None)
179 } else {
180 // TODO When the Body trait supports trailer, END_STREAM_FLAG needs to be
181 // modified
182 let mut flag = FrameFlags::new(UNUSED_FLAG);
183 flag.set_end_stream(true);
184 Ok(Some(Frame::new(
185 id,
186 flag,
187 Payload::Data(h2::Data::new(data_vec)),
188 )))
189 }
190 }
191
build_headers_frame( id: u32, part: RequestPart, is_end_stream: bool, ) -> Result<Frame, HttpError>192 pub(crate) fn build_headers_frame(
193 id: u32,
194 part: RequestPart,
195 is_end_stream: bool,
196 ) -> Result<Frame, HttpError> {
197 check_connection_specific_headers(id, &part.headers)?;
198 let pseudo = build_pseudo_headers(&part);
199 let mut header_part = h2::Parts::new();
200 header_part.set_header_lines(part.headers);
201 header_part.set_pseudo(pseudo);
202 let headers_payload = h2::Headers::new(header_part);
203
204 let mut flag = FrameFlags::new(UNUSED_FLAG);
205 flag.set_end_headers(true);
206 if is_end_stream {
207 flag.set_end_stream(true);
208 }
209 Ok(Frame::new(
210 id as usize,
211 flag,
212 Payload::Headers(headers_payload),
213 ))
214 }
215
216 // Illegal headers validation in http2.
217 // [`Connection-Specific Headers`] implementation.
218 //
219 // [`Connection-Specific Headers`]: https://www.rfc-editor.org/rfc/rfc9113.html#name-connection-specific-header-
check_connection_specific_headers(id: u32, headers: &Headers) -> Result<(), HttpError>220 fn check_connection_specific_headers(id: u32, headers: &Headers) -> Result<(), HttpError> {
221 const CONNECTION_SPECIFIC_HEADERS: &[&str; 5] = &[
222 "connection",
223 "keep-alive",
224 "proxy-connection",
225 "upgrade",
226 "transfer-encoding",
227 ];
228 for specific_header in CONNECTION_SPECIFIC_HEADERS.iter() {
229 if headers.get(*specific_header).is_some() {
230 return Err(H2Error::StreamError(id, ErrorCode::ProtocolError).into());
231 }
232 }
233 if let Some(te_value) = headers.get("te") {
234 if te_value.to_str()? != "trailers" {
235 return Err(H2Error::StreamError(id, ErrorCode::ProtocolError).into());
236 }
237 }
238 Ok(())
239 }
240
build_pseudo_headers(request_part: &RequestPart) -> PseudoHeaders241 fn build_pseudo_headers(request_part: &RequestPart) -> PseudoHeaders {
242 let mut pseudo = PseudoHeaders::default();
243 match request_part.uri.scheme() {
244 Some(scheme) => {
245 pseudo.set_scheme(Some(String::from(scheme.as_str())));
246 }
247 None => pseudo.set_scheme(Some(String::from(Scheme::HTTP.as_str()))),
248 }
249 pseudo.set_method(Some(String::from(request_part.method.as_str())));
250 pseudo.set_path(
251 request_part
252 .uri
253 .path_and_query()
254 .or_else(|| Some(String::from("/"))),
255 );
256 // TODO Validity verification is required, for example: `Authority` must be
257 // consistent with the `Host` header
258 pseudo.set_authority(request_part.uri.authority().map(|auth| auth.to_string()));
259 pseudo
260 }
261
262 struct TextIo<S> {
263 pub(crate) handle: Http2Conn<S>,
264 pub(crate) offset: usize,
265 pub(crate) remain: Option<Frame>,
266 pub(crate) is_closed: bool,
267 }
268
269 impl<S> TextIo<S> {
new(handle: Http2Conn<S>) -> Self270 pub(crate) fn new(handle: Http2Conn<S>) -> Self {
271 Self {
272 handle,
273 offset: 0,
274 remain: None,
275 is_closed: false,
276 }
277 }
278 }
279
280 impl<S: AsyncRead + AsyncWrite + Unpin + Sync + Send + 'static> StreamData for TextIo<S> {
shutdown(&self)281 fn shutdown(&self) {
282 todo!()
283 }
284 }
285
286 impl<S: AsyncRead + AsyncWrite + Unpin + Sync + Send + 'static> AsyncRead for TextIo<S> {
poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<std::io::Result<()>>287 fn poll_read(
288 self: Pin<&mut Self>,
289 cx: &mut Context<'_>,
290 buf: &mut ReadBuf<'_>,
291 ) -> Poll<std::io::Result<()>> {
292 let text_io = self.get_mut();
293
294 if buf.remaining() == 0 || text_io.is_closed {
295 return Poll::Ready(Ok(()));
296 }
297 while buf.remaining() != 0 {
298 if let Some(frame) = &text_io.remain {
299 match frame.payload() {
300 Payload::Headers(_) => {
301 break;
302 }
303 Payload::Data(data) => {
304 let data = data.data();
305 let unfilled_len = buf.remaining();
306 let data_len = data.len() - text_io.offset;
307 let fill_len = min(unfilled_len, data_len);
308 if unfilled_len < data_len {
309 buf.put_slice(&data[text_io.offset..text_io.offset + fill_len]);
310 text_io.offset += fill_len;
311 break;
312 } else {
313 buf.put_slice(&data[text_io.offset..text_io.offset + fill_len]);
314 text_io.offset += fill_len;
315 if frame.flags().is_end_stream() {
316 text_io.is_closed = true;
317 break;
318 }
319 }
320 }
321 _ => {
322 return Poll::Ready(Err(std::io::Error::new(
323 std::io::ErrorKind::Other,
324 HttpError::from(H2Error::ConnectionError(ErrorCode::ProtocolError)),
325 )))
326 }
327 }
328 }
329
330 let poll_result = Pin::new(&mut text_io.handle.stream_info)
331 .poll(cx)
332 .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
333
334 // TODO Added the frame type.
335 match poll_result {
336 Poll::Ready(frame) => match frame.payload() {
337 Payload::Headers(_) => {
338 text_io.remain = Some(frame);
339 text_io.offset = 0;
340 break;
341 }
342 Payload::Data(data) => {
343 let data = data.data();
344 let unfilled_len = buf.remaining();
345 let data_len = data.len();
346 let fill_len = min(data_len, unfilled_len);
347 if unfilled_len < data_len {
348 buf.put_slice(&data[..fill_len]);
349 text_io.offset += fill_len;
350 text_io.remain = Some(frame);
351 break;
352 } else {
353 buf.put_slice(&data[..fill_len]);
354 if frame.flags().is_end_stream() {
355 text_io.is_closed = true;
356 break;
357 }
358 }
359 }
360 Payload::RstStream(_) => {
361 return Poll::Ready(Err(std::io::Error::new(
362 std::io::ErrorKind::Other,
363 HttpError::from(H2Error::ConnectionError(ErrorCode::ProtocolError)),
364 )))
365 }
366 _ => {
367 return Poll::Ready(Err(std::io::Error::new(
368 std::io::ErrorKind::Other,
369 HttpError::from(H2Error::ConnectionError(ErrorCode::ProtocolError)),
370 )))
371 }
372 },
373 Poll::Pending => {
374 return Poll::Pending;
375 }
376 }
377 }
378 Poll::Ready(Ok(()))
379 }
380 }
381
#[cfg(feature = "http2")]
#[cfg(test)]
mod ut_http2 {
    use ylong_http::body::TextBody;
    use ylong_http::h2::{ErrorCode, H2Error, Payload};
    use ylong_http::request::RequestBuilder;

    use crate::async_impl::conn::http2::build_headers_frame;

    macro_rules! build_request {
        (
            Request: {
                Method: $method: expr,
                Uri: $uri:expr,
                Version: $version: expr,
                $(
                    Header: $req_n: expr, $req_v: expr,
                )*
                Body: $req_body: expr,
            }
        ) => {
            RequestBuilder::new()
                .method($method)
                .url($uri)
                .version($version)
                $(.header($req_n, $req_v))*
                .body(TextBody::from_bytes($req_body.as_bytes()))
                .expect("Request build failed")
        }
    }

    /// UT test cases for `build_headers_frame`.
    ///
    /// # Brief
    /// 1. Builds a HEADERS frame with and without END_STREAM and checks the
    ///    flags.
    /// 2. Checks the generated pseudo-headers.
    /// 3. Checks that a connection-specific header (`upgrade`) is rejected.
    /// 4. Checks that a `te` value other than `trailers` is rejected.
    #[test]
    fn ut_http2_build_headers_frame() {
        let request = build_request!(
            Request: {
                Method: "GET",
                Uri: "http://127.0.0.1:0/data",
                Version: "HTTP/2.0",
                Header: "te", "trailers",
                Header: "host", "127.0.0.1:0",
                Body: "Hi",
            }
        );
        let frame = build_headers_frame(1, request.part().clone(), false)
            .expect("headers frame build failed");
        assert_eq!(frame.flags().bits(), 0x4);
        let frame = build_headers_frame(1, request.part().clone(), true)
            .expect("headers frame build failed");
        assert_eq!(frame.stream_id(), 1);
        assert_eq!(frame.flags().bits(), 0x5);
        if let Payload::Headers(headers) = frame.payload() {
            let (pseudo, _headers) = headers.parts();
            assert_eq!(pseudo.status(), None);
            assert_eq!(pseudo.scheme().unwrap(), "http");
            assert_eq!(pseudo.method().unwrap(), "GET");
            assert_eq!(pseudo.authority().unwrap(), "127.0.0.1:0");
            assert_eq!(pseudo.path().unwrap(), "/data")
        } else {
            panic!("Unexpected frame type")
        }

        let request = build_request!(
            Request: {
                Method: "GET",
                Uri: "http://127.0.0.1:0/data",
                Version: "HTTP/2.0",
                Header: "upgrade", "h2",
                Header: "host", "127.0.0.1:0",
                Body: "Hi",
            }
        );
        let frame = build_headers_frame(1, request.part().clone(), true);
        assert_eq!(
            frame.err(),
            Some(H2Error::StreamError(1, ErrorCode::ProtocolError).into())
        );

        let request = build_request!(
            Request: {
                Method: "GET",
                Uri: "http://127.0.0.1:0/data",
                Version: "HTTP/2.0",
                Header: "te", "gzip",
                Header: "host", "127.0.0.1:0",
                Body: "Hi",
            }
        );
        let frame = build_headers_frame(1, request.part().clone(), true);
        assert_eq!(
            frame.err(),
            Some(H2Error::StreamError(1, ErrorCode::ProtocolError).into())
        );
    }
}
459