• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 use crate::util::ConnInfo;
15 use crate::{ConnDetail, TimeGroup};
16 
/// Common interface of the protocol-specific connection dispatchers.
///
/// A dispatcher hands out handles that are used to perform requests on the
/// connection it manages and exposes the connection's state flags.
pub(crate) trait Dispatcher {
    /// Handle type produced by [`Dispatcher::dispatch`].
    type Handle;

    /// Tries to obtain a handle; `None` means no handle can be handed out
    /// right now.
    fn dispatch(&self) -> Option<Self::Handle>;

    /// Reports whether the connection has been flagged as shut down.
    fn is_shutdown(&self) -> bool;

    /// Reports whether the connection has been flagged as going away.
    #[allow(dead_code)]
    fn is_goaway(&self) -> bool;
}
27 
28 pub(crate) enum ConnDispatcher<S> {
29     #[cfg(feature = "http1_1")]
30     Http1(http1::Http1Dispatcher<S>),
31 
32     #[cfg(feature = "http2")]
33     Http2(http2::Http2Dispatcher<S>),
34 
35     #[cfg(feature = "http3")]
36     Http3(http3::Http3Dispatcher<S>),
37 }
38 
39 impl<S> Dispatcher for ConnDispatcher<S> {
40     type Handle = Conn<S>;
41 
dispatch(&self) -> Option<Self::Handle>42     fn dispatch(&self) -> Option<Self::Handle> {
43         match self {
44             #[cfg(feature = "http1_1")]
45             Self::Http1(h1) => h1.dispatch().map(Conn::Http1),
46 
47             #[cfg(feature = "http2")]
48             Self::Http2(h2) => h2.dispatch().map(Conn::Http2),
49 
50             #[cfg(feature = "http3")]
51             Self::Http3(h3) => h3.dispatch().map(Conn::Http3),
52         }
53     }
54 
is_shutdown(&self) -> bool55     fn is_shutdown(&self) -> bool {
56         match self {
57             #[cfg(feature = "http1_1")]
58             Self::Http1(h1) => h1.is_shutdown(),
59 
60             #[cfg(feature = "http2")]
61             Self::Http2(h2) => h2.is_shutdown(),
62 
63             #[cfg(feature = "http3")]
64             Self::Http3(h3) => h3.is_shutdown(),
65         }
66     }
67 
is_goaway(&self) -> bool68     fn is_goaway(&self) -> bool {
69         match self {
70             #[cfg(feature = "http1_1")]
71             Self::Http1(h1) => h1.is_goaway(),
72 
73             #[cfg(feature = "http2")]
74             Self::Http2(h2) => h2.is_goaway(),
75 
76             #[cfg(feature = "http3")]
77             Self::Http3(h3) => h3.is_goaway(),
78         }
79     }
80 }
81 
82 pub(crate) enum Conn<S> {
83     #[cfg(feature = "http1_1")]
84     Http1(http1::Http1Conn<S>),
85 
86     #[cfg(feature = "http2")]
87     Http2(http2::Http2Conn<S>),
88 
89     #[cfg(feature = "http3")]
90     Http3(http3::Http3Conn<S>),
91 }
92 
93 impl<S: ConnInfo> Conn<S> {
get_detail(&mut self) -> ConnDetail94     pub(crate) fn get_detail(&mut self) -> ConnDetail {
95         match self {
96             #[cfg(feature = "http1_1")]
97             Conn::Http1(io) => io.raw_mut().conn_data().detail(),
98             #[cfg(feature = "http2")]
99             Conn::Http2(io) => io.detail.clone(),
100             #[cfg(feature = "http3")]
101             Conn::Http3(io) => io.detail.clone(),
102         }
103     }
104 }
105 
106 pub(crate) struct TimeInfoConn<S> {
107     conn: Conn<S>,
108     time_group: TimeGroup,
109 }
110 
111 impl<S> TimeInfoConn<S> {
new(conn: Conn<S>, time_group: TimeGroup) -> Self112     pub(crate) fn new(conn: Conn<S>, time_group: TimeGroup) -> Self {
113         Self { conn, time_group }
114     }
115 
time_group_mut(&mut self) -> &mut TimeGroup116     pub(crate) fn time_group_mut(&mut self) -> &mut TimeGroup {
117         &mut self.time_group
118     }
119 
time_group(&mut self) -> &TimeGroup120     pub(crate) fn time_group(&mut self) -> &TimeGroup {
121         &self.time_group
122     }
123 
connection(self) -> Conn<S>124     pub(crate) fn connection(self) -> Conn<S> {
125         self.conn
126     }
127 }
128 
129 #[cfg(feature = "http1_1")]
130 pub(crate) mod http1 {
131     use std::cell::UnsafeCell;
132     use std::sync::atomic::{AtomicBool, Ordering};
133     use std::sync::Arc;
134 
135     use super::{ConnDispatcher, Dispatcher};
136     use crate::runtime::Semaphore;
137     #[cfg(feature = "tokio_base")]
138     use crate::runtime::SemaphorePermit;
139     use crate::util::progress::SpeedController;
140 
141     impl<S> ConnDispatcher<S> {
http1(io: S) -> Self142         pub(crate) fn http1(io: S) -> Self {
143             Self::Http1(Http1Dispatcher::new(io))
144         }
145     }
146 
147     /// HTTP1-based connection manager, which can dispatch connections to other
148     /// threads according to HTTP1 syntax.
149     pub(crate) struct Http1Dispatcher<S> {
150         inner: Arc<Inner<S>>,
151     }
152 
153     pub(crate) struct Inner<S> {
154         pub(crate) io: UnsafeCell<S>,
155         // `occupied` indicates that the connection is occupied. Only one coroutine
156         // can get the handle at the same time. Once the handle is fetched, the flag
157         // position is true.
158         pub(crate) occupied: AtomicBool,
159         // `shutdown` indicates that the connection need to be shut down.
160         pub(crate) shutdown: AtomicBool,
161     }
162 
163     unsafe impl<S> Sync for Inner<S> {}
164 
165     impl<S> Http1Dispatcher<S> {
new(io: S) -> Self166         pub(crate) fn new(io: S) -> Self {
167             Self {
168                 inner: Arc::new(Inner {
169                     io: UnsafeCell::new(io),
170                     occupied: AtomicBool::new(false),
171                     shutdown: AtomicBool::new(false),
172                 }),
173             }
174         }
175     }
176 
177     impl<S> Dispatcher for Http1Dispatcher<S> {
178         type Handle = Http1Conn<S>;
179 
dispatch(&self) -> Option<Self::Handle>180         fn dispatch(&self) -> Option<Self::Handle> {
181             self.inner
182                 .occupied
183                 .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
184                 .ok()
185                 .map(|_| Http1Conn::from_inner(self.inner.clone()))
186         }
187 
is_shutdown(&self) -> bool188         fn is_shutdown(&self) -> bool {
189             self.inner.shutdown.load(Ordering::Relaxed)
190         }
191 
is_goaway(&self) -> bool192         fn is_goaway(&self) -> bool {
193             false
194         }
195     }
196 
197     /// Handle returned to other threads for I/O operations.
198     pub(crate) struct Http1Conn<S> {
199         pub(crate) speed_controller: SpeedController,
200         pub(crate) sem: Option<WrappedSemPermit>,
201         pub(crate) inner: Arc<Inner<S>>,
202     }
203 
204     impl<S> Http1Conn<S> {
from_inner(inner: Arc<Inner<S>>) -> Self205         pub(crate) fn from_inner(inner: Arc<Inner<S>>) -> Self {
206             Self {
207                 speed_controller: SpeedController::none(),
208                 sem: None,
209                 inner,
210             }
211         }
212 
occupy_sem(&mut self, sem: WrappedSemPermit)213         pub(crate) fn occupy_sem(&mut self, sem: WrappedSemPermit) {
214             self.sem = Some(sem);
215         }
216 
raw_mut(&mut self) -> &mut S217         pub(crate) fn raw_mut(&mut self) -> &mut S {
218             // SAFETY: In the case of `HTTP1`, only one coroutine gets the handle
219             // at the same time.
220             unsafe { &mut *self.inner.io.get() }
221         }
222 
shutdown(&self)223         pub(crate) fn shutdown(&self) {
224             self.inner.shutdown.store(true, Ordering::Release);
225         }
226 
cancel_guard(&self) -> CancelGuard<S>227         pub(crate) fn cancel_guard(&self) -> CancelGuard<S> {
228             CancelGuard {
229                 inner: self.inner.clone(),
230                 running: true,
231             }
232         }
233     }
234 
235     impl<S> Drop for Http1Conn<S> {
drop(&mut self)236         fn drop(&mut self) {
237             self.inner.occupied.store(false, Ordering::Release)
238         }
239     }
240 
241     /// Http1 cancel guard
242     pub(crate) struct CancelGuard<S> {
243         inner: Arc<Inner<S>>,
244         /// Default true
245         running: bool,
246     }
247 
248     impl<S> CancelGuard<S> {
normal_end(&mut self)249         pub(crate) fn normal_end(&mut self) {
250             self.running = false
251         }
252     }
253 
254     impl<S> Drop for CancelGuard<S> {
drop(&mut self)255         fn drop(&mut self) {
256             // When a drop occurs, if running is still true, it means a cancel has occurred,
257             // and the IO needs to be shutdown to prevent the reuse of dirty data
258             if self.running {
259                 self.inner.shutdown.store(true, Ordering::Release);
260             }
261         }
262     }
263 
264     pub(crate) struct WrappedSemaphore {
265         sem: Arc<Semaphore>,
266     }
267 
268     impl WrappedSemaphore {
new(permits: usize) -> Self269         pub(crate) fn new(permits: usize) -> Self {
270             Self {
271                 #[cfg(feature = "tokio_base")]
272                 sem: Arc::new(tokio::sync::Semaphore::new(permits)),
273                 #[cfg(feature = "ylong_base")]
274                 sem: Arc::new(ylong_runtime::sync::Semaphore::new(permits).unwrap()),
275             }
276         }
277 
acquire(&self) -> WrappedSemPermit278         pub(crate) async fn acquire(&self) -> WrappedSemPermit {
279             #[cfg(feature = "ylong_base")]
280             {
281                 let semaphore = self.sem.clone();
282                 let _permit = semaphore.acquire().await.unwrap();
283                 WrappedSemPermit { sem: semaphore }
284             }
285 
286             #[cfg(feature = "tokio_base")]
287             {
288                 let permit = self.sem.clone().acquire_owned().await.unwrap();
289                 WrappedSemPermit { permit }
290             }
291         }
292     }
293 
294     impl Clone for WrappedSemaphore {
clone(&self) -> Self295         fn clone(&self) -> Self {
296             Self {
297                 sem: self.sem.clone(),
298             }
299         }
300     }
301 
302     pub(crate) struct WrappedSemPermit {
303         #[cfg(feature = "ylong_base")]
304         pub(crate) sem: Arc<Semaphore>,
305         #[cfg(feature = "tokio_base")]
306         #[allow(dead_code)]
307         pub(crate) permit: SemaphorePermit,
308     }
309 
310     #[cfg(feature = "ylong_base")]
311     impl Drop for WrappedSemPermit {
drop(&mut self)312         fn drop(&mut self) {
313             self.sem.release();
314         }
315     }
316 }
317 
318 #[cfg(feature = "http2")]
319 pub(crate) mod http2 {
320     use std::collections::HashMap;
321     use std::future::Future;
322     use std::marker::PhantomData;
323     use std::pin::Pin;
324     use std::sync::atomic::{AtomicBool, Ordering};
325     use std::sync::{Arc, Mutex};
326     use std::task::{Context, Poll};
327 
328     use ylong_http::error::HttpError;
329     use ylong_http::h2::{
330         ErrorCode, Frame, FrameDecoder, FrameEncoder, FrameFlags, Goaway, H2Error, Payload,
331         RstStream, Settings, SettingsBuilder, StreamId,
332     };
333 
334     use crate::runtime::{
335         bounded_channel, unbounded_channel, AsyncRead, AsyncWrite, AsyncWriteExt, BoundedReceiver,
336         BoundedSender, SendError, UnboundedReceiver, UnboundedSender, WriteHalf,
337     };
338     use crate::util::config::H2Config;
339     use crate::util::dispatcher::{ConnDispatcher, Dispatcher};
340     use crate::util::h2::{
341         ConnManager, FlowControl, H2StreamState, RecvData, RequestWrapper, SendData,
342         StreamEndState, Streams,
343     };
344     use crate::util::progress::SpeedController;
345     use crate::ErrorKind::Request;
346     use crate::{ConnDetail, ErrorKind, HttpClientError};
347 
348     const DEFAULT_MAX_FRAME_SIZE: usize = 2 << 13;
349     const DEFAULT_WINDOW_SIZE: u32 = 65535;
350 
351     pub(crate) type ManagerSendFut =
352         Pin<Box<dyn Future<Output = Result<(), SendError<RespMessage>>> + Send + Sync>>;
353 
354     pub(crate) enum RespMessage {
355         Output(Frame),
356         OutputExit(DispatchErrorKind),
357     }
358 
359     pub(crate) enum OutputMessage {
360         Output(Frame),
361         OutputExit(DispatchErrorKind),
362     }
363 
364     pub(crate) struct ReqMessage {
365         pub(crate) sender: BoundedSender<RespMessage>,
366         pub(crate) request: RequestWrapper,
367     }
368 
369     #[derive(Debug, Eq, PartialEq, Copy, Clone)]
370     pub(crate) enum DispatchErrorKind {
371         H2(H2Error),
372         Io(std::io::ErrorKind),
373         ChannelClosed,
374         Disconnect,
375     }
376 
377     // HTTP2-based connection manager, which can dispatch connections to other
378     // threads according to HTTP2 syntax.
379     pub(crate) struct Http2Dispatcher<S> {
380         pub(crate) detail: ConnDetail,
381         pub(crate) allowed_cache: usize,
382         pub(crate) sender: UnboundedSender<ReqMessage>,
383         pub(crate) io_shutdown: Arc<AtomicBool>,
384         pub(crate) io_goaway: Arc<AtomicBool>,
385         pub(crate) handles: Vec<crate::runtime::JoinHandle<()>>,
386         pub(crate) _mark: PhantomData<S>,
387     }
388 
389     pub(crate) struct Http2Conn<S> {
390         pub(crate) speed_controller: SpeedController,
391         pub(crate) allow_cached_frames: usize,
392         // Sends frame to StreamController
393         pub(crate) sender: UnboundedSender<ReqMessage>,
394         pub(crate) receiver: RespReceiver,
395         pub(crate) io_shutdown: Arc<AtomicBool>,
396         pub(crate) detail: ConnDetail,
397         pub(crate) _mark: PhantomData<S>,
398     }
399 
400     pub(crate) struct StreamController {
401         // The connection close flag organizes new stream commits to the current connection when
402         // closed.
403         pub(crate) io_shutdown: Arc<AtomicBool>,
404         pub(crate) io_goaway: Arc<AtomicBool>,
405         // The senders of all connected stream channels of response.
406         pub(crate) senders: HashMap<StreamId, BoundedSender<RespMessage>>,
407         pub(crate) curr_message: HashMap<StreamId, ManagerSendFut>,
408         // Stream information on the connection.
409         pub(crate) streams: Streams,
410         // Received GO_AWAY frame.
411         pub(crate) go_away_error_code: Option<u32>,
412         // The last GO_AWAY frame sent by the client.
413         pub(crate) go_away_sync: GoAwaySync,
414     }
415 
416     #[derive(Default)]
417     pub(crate) struct GoAwaySync {
418         pub(crate) going_away: Option<Goaway>,
419     }
420 
421     #[derive(Default)]
422     pub(crate) struct SettingsSync {
423         pub(crate) settings: SettingsState,
424     }
425 
426     #[derive(Default, Clone)]
427     pub(crate) enum SettingsState {
428         Acknowledging(Settings),
429         #[default]
430         Synced,
431     }
432 
433     #[derive(Default)]
434     pub(crate) struct RespReceiver {
435         receiver: Option<BoundedReceiver<RespMessage>>,
436     }
437 
438     impl<S> ConnDispatcher<S>
439     where
440         S: AsyncRead + AsyncWrite + Sync + Send + Unpin + 'static,
441     {
http2(detail: ConnDetail, config: H2Config, io: S) -> Self442         pub(crate) fn http2(detail: ConnDetail, config: H2Config, io: S) -> Self {
443             Self::Http2(Http2Dispatcher::new(detail, config, io))
444         }
445     }
446 
447     impl<S> Http2Dispatcher<S>
448     where
449         S: AsyncRead + AsyncWrite + Sync + Send + Unpin + 'static,
450     {
new(detail: ConnDetail, config: H2Config, io: S) -> Self451         pub(crate) fn new(detail: ConnDetail, config: H2Config, io: S) -> Self {
452             let mut flow = FlowControl::new(DEFAULT_WINDOW_SIZE, DEFAULT_WINDOW_SIZE);
453             flow.setup_recv_window(config.conn_window_size());
454 
455             let streams = Streams::new(config.stream_window_size(), DEFAULT_WINDOW_SIZE, flow);
456             let shutdown_flag = Arc::new(AtomicBool::new(false));
457             let goaway_flag = Arc::new(AtomicBool::new(false));
458             let mut controller =
459                 StreamController::new(streams, shutdown_flag.clone(), goaway_flag.clone());
460 
461             let (input_tx, input_rx) = unbounded_channel();
462             let (req_tx, req_rx) = unbounded_channel();
463 
464             let settings = create_initial_settings(&config);
465 
466             // Error is not possible, so it is not handled for the time
467             // being.
468             let mut handles = Vec::with_capacity(3);
469             // send initial settings and update conn recv window
470             if input_tx.send(settings).is_ok()
471                 && controller
472                     .streams
473                     .release_conn_recv_window(0, &input_tx)
474                     .is_ok()
475             {
476                 Self::launch(
477                     config.allowed_cache_frame_size(),
478                     config.use_huffman_coding(),
479                     controller,
480                     (input_tx, input_rx),
481                     req_rx,
482                     &mut handles,
483                     io,
484                 );
485             }
486             Self {
487                 detail,
488                 allowed_cache: config.allowed_cache_frame_size(),
489                 sender: req_tx,
490                 io_shutdown: shutdown_flag,
491                 io_goaway: goaway_flag,
492                 handles,
493                 _mark: PhantomData,
494             }
495         }
496 
launch( allow_num: usize, use_huffman: bool, controller: StreamController, input_channel: (UnboundedSender<Frame>, UnboundedReceiver<Frame>), req_rx: UnboundedReceiver<ReqMessage>, handles: &mut Vec<crate::runtime::JoinHandle<()>>, io: S, )497         fn launch(
498             allow_num: usize,
499             use_huffman: bool,
500             controller: StreamController,
501             input_channel: (UnboundedSender<Frame>, UnboundedReceiver<Frame>),
502             req_rx: UnboundedReceiver<ReqMessage>,
503             handles: &mut Vec<crate::runtime::JoinHandle<()>>,
504             io: S,
505         ) {
506             let (resp_tx, resp_rx) = bounded_channel(allow_num);
507             let (read, write) = crate::runtime::split(io);
508             let settings_sync = Arc::new(Mutex::new(SettingsSync::default()));
509             let send_settings_sync = settings_sync.clone();
510             let send = crate::runtime::spawn(async move {
511                 let mut writer = write;
512                 if async_send_preface(&mut writer).await.is_ok() {
513                     let encoder = FrameEncoder::new(DEFAULT_MAX_FRAME_SIZE, use_huffman);
514                     let mut send =
515                         SendData::new(encoder, send_settings_sync, writer, input_channel.1);
516                     let _ = Pin::new(&mut send).await;
517                 }
518             });
519             handles.push(send);
520 
521             let recv_settings_sync = settings_sync.clone();
522             let recv = crate::runtime::spawn(async move {
523                 let decoder = FrameDecoder::new();
524                 let mut recv = RecvData::new(decoder, recv_settings_sync, read, resp_tx);
525                 let _ = Pin::new(&mut recv).await;
526             });
527             handles.push(recv);
528 
529             let manager = crate::runtime::spawn(async move {
530                 let mut conn_manager =
531                     ConnManager::new(settings_sync, input_channel.0, resp_rx, req_rx, controller);
532                 let _ = Pin::new(&mut conn_manager).await;
533             });
534             handles.push(manager);
535         }
536     }
537 
538     impl<S> Dispatcher for Http2Dispatcher<S> {
539         type Handle = Http2Conn<S>;
540 
dispatch(&self) -> Option<Self::Handle>541         fn dispatch(&self) -> Option<Self::Handle> {
542             let sender = self.sender.clone();
543             let handle = Http2Conn::new(
544                 self.allowed_cache,
545                 self.io_shutdown.clone(),
546                 sender,
547                 self.detail.clone(),
548             );
549             Some(handle)
550         }
551 
is_shutdown(&self) -> bool552         fn is_shutdown(&self) -> bool {
553             self.io_shutdown.load(Ordering::Relaxed)
554         }
555 
is_goaway(&self) -> bool556         fn is_goaway(&self) -> bool {
557             self.io_goaway.load(Ordering::Relaxed)
558         }
559     }
560 
561     impl<S> Drop for Http2Dispatcher<S> {
drop(&mut self)562         fn drop(&mut self) {
563             for handle in &self.handles {
564                 #[cfg(feature = "ylong_base")]
565                 handle.cancel();
566                 #[cfg(feature = "tokio_base")]
567                 handle.abort();
568             }
569         }
570     }
571 
572     impl<S> Http2Conn<S> {
new( allow_cached_num: usize, io_shutdown: Arc<AtomicBool>, sender: UnboundedSender<ReqMessage>, detail: ConnDetail, ) -> Self573         pub(crate) fn new(
574             allow_cached_num: usize,
575             io_shutdown: Arc<AtomicBool>,
576             sender: UnboundedSender<ReqMessage>,
577             detail: ConnDetail,
578         ) -> Self {
579             Self {
580                 speed_controller: SpeedController::none(),
581                 allow_cached_frames: allow_cached_num,
582                 sender,
583                 receiver: RespReceiver::default(),
584                 io_shutdown,
585                 detail,
586                 _mark: PhantomData,
587             }
588         }
589 
send_frame_to_controller( &mut self, request: RequestWrapper, ) -> Result<(), HttpClientError>590         pub(crate) fn send_frame_to_controller(
591             &mut self,
592             request: RequestWrapper,
593         ) -> Result<(), HttpClientError> {
594             let (tx, rx) = bounded_channel::<RespMessage>(self.allow_cached_frames);
595             self.receiver.set_receiver(rx);
596             self.sender
597                 .send(ReqMessage {
598                     sender: tx,
599                     request,
600                 })
601                 .map_err(|_| {
602                     HttpClientError::from_str(ErrorKind::Request, "Request Sender Closed !")
603                 })
604         }
605     }
606 
607     impl StreamController {
new( streams: Streams, shutdown: Arc<AtomicBool>, goaway: Arc<AtomicBool>, ) -> Self608         pub(crate) fn new(
609             streams: Streams,
610             shutdown: Arc<AtomicBool>,
611             goaway: Arc<AtomicBool>,
612         ) -> Self {
613             Self {
614                 io_shutdown: shutdown,
615                 io_goaway: goaway,
616                 senders: HashMap::new(),
617                 curr_message: HashMap::new(),
618                 streams,
619                 go_away_error_code: None,
620                 go_away_sync: GoAwaySync::default(),
621             }
622         }
623 
shutdown(&self)624         pub(crate) fn shutdown(&self) {
625             self.io_shutdown.store(true, Ordering::Release);
626         }
627 
goaway(&self)628         pub(crate) fn goaway(&self) {
629             self.io_goaway.store(true, Ordering::Release);
630         }
631 
get_unsent_streams( &mut self, last_stream_id: StreamId, ) -> Result<Vec<StreamId>, H2Error>632         pub(crate) fn get_unsent_streams(
633             &mut self,
634             last_stream_id: StreamId,
635         ) -> Result<Vec<StreamId>, H2Error> {
636             // The last-stream-id in the subsequent GO_AWAY frame
637             // cannot be greater than the last-stream-id in the previous GO_AWAY frame.
638             if self.streams.max_send_id < last_stream_id {
639                 return Err(H2Error::ConnectionError(ErrorCode::ProtocolError));
640             }
641             self.streams.max_send_id = last_stream_id;
642             Ok(self.streams.get_unset_streams(last_stream_id))
643         }
644 
send_message_to_stream( &mut self, cx: &mut Context<'_>, stream_id: StreamId, message: RespMessage, ) -> Poll<Result<(), H2Error>>645         pub(crate) fn send_message_to_stream(
646             &mut self,
647             cx: &mut Context<'_>,
648             stream_id: StreamId,
649             message: RespMessage,
650         ) -> Poll<Result<(), H2Error>> {
651             if let Some(sender) = self.senders.get(&stream_id) {
652                 // If the client coroutine has exited, this frame is skipped.
653                 let mut tx = {
654                     let sender = sender.clone();
655                     let ft = async move { sender.send(message).await };
656                     Box::pin(ft)
657                 };
658 
659                 match tx.as_mut().poll(cx) {
660                     Poll::Ready(Ok(_)) => Poll::Ready(Ok(())),
661                     // The current coroutine sending the request exited prematurely.
662                     Poll::Ready(Err(_)) => {
663                         self.senders.remove(&stream_id);
664                         Poll::Ready(Err(H2Error::StreamError(stream_id, ErrorCode::NoError)))
665                     }
666                     Poll::Pending => {
667                         self.curr_message.insert(stream_id, tx);
668                         Poll::Pending
669                     }
670                 }
671             } else {
672                 Poll::Ready(Err(H2Error::StreamError(stream_id, ErrorCode::NoError)))
673             }
674         }
675 
poll_blocked_message( &mut self, cx: &mut Context<'_>, input_tx: &UnboundedSender<Frame>, ) -> Poll<()>676         pub(crate) fn poll_blocked_message(
677             &mut self,
678             cx: &mut Context<'_>,
679             input_tx: &UnboundedSender<Frame>,
680         ) -> Poll<()> {
681             let keys: Vec<StreamId> = self.curr_message.keys().cloned().collect();
682             let mut blocked = false;
683 
684             for key in keys {
685                 if let Some(mut task) = self.curr_message.remove(&key) {
686                     match task.as_mut().poll(cx) {
687                         Poll::Ready(Ok(_)) => {}
688                         // The current coroutine sending the request exited prematurely.
689                         Poll::Ready(Err(_)) => {
690                             self.senders.remove(&key);
691                             if let Some(state) = self.streams.stream_state(key) {
692                                 if !matches!(state, H2StreamState::Closed(_)) {
693                                     if let StreamEndState::OK = self.streams.send_local_reset(key) {
694                                         let rest_payload =
695                                             RstStream::new(ErrorCode::NoError.into_code());
696                                         let frame = Frame::new(
697                                             key,
698                                             FrameFlags::empty(),
699                                             Payload::RstStream(rest_payload),
700                                         );
701                                         // ignore the send error occurs here in order to finish all
702                                         // tasks.
703                                         let _ = input_tx.send(frame);
704                                     }
705                                 }
706                             }
707                         }
708                         Poll::Pending => {
709                             self.curr_message.insert(key, task);
710                             blocked = true;
711                         }
712                     }
713                 }
714             }
715             if blocked {
716                 Poll::Pending
717             } else {
718                 Poll::Ready(())
719             }
720         }
721     }
722 
723     impl RespReceiver {
set_receiver(&mut self, receiver: BoundedReceiver<RespMessage>)724         pub(crate) fn set_receiver(&mut self, receiver: BoundedReceiver<RespMessage>) {
725             self.receiver = Some(receiver);
726         }
727 
recv(&mut self) -> Result<Frame, HttpClientError>728         pub(crate) async fn recv(&mut self) -> Result<Frame, HttpClientError> {
729             match self.receiver {
730                 Some(ref mut receiver) => {
731                     #[cfg(feature = "tokio_base")]
732                     match receiver.recv().await {
733                         None => err_from_msg!(Request, "Response Sender Closed !"),
734                         Some(message) => match message {
735                             RespMessage::Output(frame) => Ok(frame),
736                             RespMessage::OutputExit(e) => Err(dispatch_client_error(e)),
737                         },
738                     }
739 
740                     #[cfg(feature = "ylong_base")]
741                     match receiver.recv().await {
742                         Err(err) => Err(HttpClientError::from_error(ErrorKind::Request, err)),
743                         Ok(message) => match message {
744                             RespMessage::Output(frame) => Ok(frame),
745                             RespMessage::OutputExit(e) => Err(dispatch_client_error(e)),
746                         },
747                     }
748                 }
749                 // this will not happen.
750                 None => Err(HttpClientError::from_str(
751                     ErrorKind::Request,
752                     "Invalid Frame Receiver !",
753                 )),
754             }
755         }
756 
poll_recv( &mut self, cx: &mut Context<'_>, ) -> Poll<Result<Frame, HttpClientError>>757         pub(crate) fn poll_recv(
758             &mut self,
759             cx: &mut Context<'_>,
760         ) -> Poll<Result<Frame, HttpClientError>> {
761             if let Some(ref mut receiver) = self.receiver {
762                 #[cfg(feature = "tokio_base")]
763                 match receiver.poll_recv(cx) {
764                     Poll::Ready(None) => {
765                         Poll::Ready(err_from_msg!(Request, "Response Sender Closed !"))
766                     }
767                     Poll::Ready(Some(message)) => match message {
768                         RespMessage::Output(frame) => Poll::Ready(Ok(frame)),
769                         RespMessage::OutputExit(e) => Poll::Ready(Err(dispatch_client_error(e))),
770                     },
771                     Poll::Pending => Poll::Pending,
772                 }
773 
774                 #[cfg(feature = "ylong_base")]
775                 match receiver.poll_recv(cx) {
776                     Poll::Ready(Err(e)) => {
777                         Poll::Ready(Err(HttpClientError::from_error(ErrorKind::Request, e)))
778                     }
779                     Poll::Ready(Ok(message)) => match message {
780                         RespMessage::Output(frame) => Poll::Ready(Ok(frame)),
781                         RespMessage::OutputExit(e) => Poll::Ready(Err(dispatch_client_error(e))),
782                     },
783                     Poll::Pending => Poll::Pending,
784                 }
785             } else {
786                 Poll::Ready(err_from_msg!(Request, "Invalid Frame Receiver !"))
787             }
788         }
789     }
790 
async_send_preface<S>(writer: &mut WriteHalf<S>) -> Result<(), DispatchErrorKind> where S: AsyncWrite + Unpin,791     async fn async_send_preface<S>(writer: &mut WriteHalf<S>) -> Result<(), DispatchErrorKind>
792     where
793         S: AsyncWrite + Unpin,
794     {
795         const PREFACE: &[u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n";
796         writer
797             .write_all(PREFACE)
798             .await
799             .map_err(|e| DispatchErrorKind::Io(e.kind()))
800     }
801 
create_initial_settings(config: &H2Config) -> Frame802     pub(crate) fn create_initial_settings(config: &H2Config) -> Frame {
803         let settings = SettingsBuilder::new()
804             .max_header_list_size(config.max_header_list_size())
805             .max_frame_size(config.max_frame_size())
806             .header_table_size(config.header_table_size())
807             .enable_push(config.enable_push())
808             .initial_window_size(config.stream_window_size())
809             .build();
810 
811         Frame::new(0, FrameFlags::new(0), Payload::Settings(settings))
812     }
813 
814     impl From<std::io::Error> for DispatchErrorKind {
from(value: std::io::Error) -> Self815         fn from(value: std::io::Error) -> Self {
816             DispatchErrorKind::Io(value.kind())
817         }
818     }
819 
820     impl From<H2Error> for DispatchErrorKind {
from(err: H2Error) -> Self821         fn from(err: H2Error) -> Self {
822             DispatchErrorKind::H2(err)
823         }
824     }
825 
dispatch_client_error(dispatch_error: DispatchErrorKind) -> HttpClientError826     pub(crate) fn dispatch_client_error(dispatch_error: DispatchErrorKind) -> HttpClientError {
827         match dispatch_error {
828             DispatchErrorKind::H2(e) => HttpClientError::from_error(Request, HttpError::from(e)),
829             DispatchErrorKind::Io(e) => {
830                 HttpClientError::from_io_error(Request, std::io::Error::from(e))
831             }
832             DispatchErrorKind::ChannelClosed => {
833                 HttpClientError::from_str(Request, "Coroutine channel closed.")
834             }
835             DispatchErrorKind::Disconnect => {
836                 HttpClientError::from_str(Request, "remote peer closed.")
837             }
838         }
839     }
840 }
841 
842 #[cfg(feature = "http3")]
843 pub(crate) mod http3 {
844     use std::marker::PhantomData;
845     use std::pin::Pin;
846     use std::sync::atomic::{AtomicBool, Ordering};
847     use std::sync::{Arc, Mutex};
848 
849     use ylong_http::error::HttpError;
850     use ylong_http::h3::{Frame, FrameDecoder, H3Error};
851 
852     use crate::async_impl::QuicConn;
853     use crate::runtime::{
854         bounded_channel, unbounded_channel, AsyncRead, AsyncWrite, BoundedReceiver, BoundedSender,
855         UnboundedSender,
856     };
857     use crate::util::config::H3Config;
858     use crate::util::data_ref::BodyDataRef;
859     use crate::util::dispatcher::{ConnDispatcher, Dispatcher};
860     use crate::util::h3::io_manager::IOManager;
861     use crate::util::h3::stream_manager::StreamManager;
862     use crate::util::progress::SpeedController;
863     use crate::ErrorKind::Request;
864     use crate::{ConnDetail, ConnInfo, ErrorKind, HttpClientError};
865 
    /// Dispatcher of a single HTTP/3 connection.
    ///
    /// It owns the background tasks spawned in `Http3Dispatcher::new` and hands
    /// out `Http3Conn` handles that all submit requests through the same
    /// request channel.
    pub(crate) struct Http3Dispatcher<S> {
        /// Connection detail, cloned into every `Http3Conn` that is dispatched.
        pub(crate) detail: ConnDetail,
        /// Sending side of the request channel; its receiving side is handed to
        /// the `StreamManager` on construction.
        pub(crate) req_tx: UnboundedSender<ReqMessage>,
        /// Join handles of the spawned background tasks; they are aborted
        /// (tokio) or cancelled (ylong) when the dispatcher is dropped.
        pub(crate) handles: Vec<crate::runtime::JoinHandle<()>>,
        /// Marker for the IO type `S`; the IO object itself is moved into the
        /// spawned IO task and is not stored here.
        pub(crate) _mark: PhantomData<S>,
        /// Flag shared with the `StreamManager`, read by `is_shutdown`.
        pub(crate) io_shutdown: Arc<AtomicBool>,
        /// Flag shared with the `StreamManager`, read by `is_goaway`.
        pub(crate) io_goaway: Arc<AtomicBool>,
    }
874 
    /// Handle used to send one request over an HTTP/3 connection and to read
    /// the response messages that belong to it.
    pub(crate) struct Http3Conn<S> {
        /// Speed controller of this handle; `Http3Conn::new` sets it to
        /// `SpeedController::none()`.
        pub(crate) speed_controller: SpeedController,
        /// Sending side of the dispatcher's request channel.
        pub(crate) sender: UnboundedSender<ReqMessage>,
        /// Receiving side of this handle's bounded response channel.
        pub(crate) resp_receiver: BoundedReceiver<RespMessage>,
        /// Sending side of the response channel; a clone is attached to every
        /// `ReqMessage` as `frame_tx`.
        pub(crate) resp_sender: BoundedSender<RespMessage>,
        /// Shutdown flag shared with the dispatcher that created this handle.
        pub(crate) io_shutdown: Arc<AtomicBool>,
        /// Connection detail copied from the dispatcher.
        pub(crate) detail: ConnDetail,
        /// Marker for the IO type `S`; no value of `S` is stored.
        pub(crate) _mark: PhantomData<S>,
    }
884 
    /// A request as it is handed to the dispatcher: a header frame plus a
    /// reference to the body data.
    pub(crate) struct RequestWrapper {
        /// Frame holding the request header.
        pub(crate) header: Frame,
        /// Reference to the request body data.
        pub(crate) data: BodyDataRef,
    }
889 
    /// Errors that can be reported through the HTTP/3 dispatcher channels.
    ///
    /// `dispatch_client_error` converts each variant into an `HttpClientError`.
    #[derive(Debug, Clone)]
    pub(crate) enum DispatchErrorKind {
        /// An HTTP/3 protocol error.
        H3(H3Error),
        /// An I/O error, reduced to its `std::io::ErrorKind`.
        Io(std::io::ErrorKind),
        /// An error reported by `quiche`.
        Quic(quiche::Error),
        /// Reported to the caller as "Coroutine channel closed.".
        ChannelClosed,
        /// Reported to the caller as "stream finished.".
        StreamFinished,
        // todo: retry?
        /// Reported to the caller as "received remote goaway.".
        GoawayReceived,
        /// Reported to the caller as "remote peer closed.".
        Disconnect,
    }
901 
    /// Message delivered to an `Http3Conn` through its response channel.
    pub(crate) enum RespMessage {
        /// A frame for the caller; `recv_resp` returns it as `Ok(frame)`.
        Output(Frame),
        /// A dispatcher error; `recv_resp` converts it with
        /// `dispatch_client_error` and returns it as `Err`.
        OutputExit(DispatchErrorKind),
    }
906 
    /// Message sent from an `Http3Conn` through the dispatcher's request channel.
    pub(crate) struct ReqMessage {
        /// The request to be sent.
        pub(crate) request: RequestWrapper,
        /// Sender on which the response messages for this request are expected;
        /// it is a clone of the submitting handle's `resp_sender`.
        pub(crate) frame_tx: BoundedSender<RespMessage>,
    }
911 
912     impl<S> Http3Dispatcher<S>
913     where
914         S: AsyncRead + AsyncWrite + ConnInfo + Sync + Send + Unpin + 'static,
915     {
new( detail: ConnDetail, config: H3Config, io: S, quic_connection: QuicConn, ) -> Self916         pub(crate) fn new(
917             detail: ConnDetail,
918             config: H3Config,
919             io: S,
920             quic_connection: QuicConn,
921         ) -> Self {
922             let (req_tx, req_rx) = unbounded_channel();
923             let (io_manager_tx, io_manager_rx) = unbounded_channel();
924             let (stream_manager_tx, stream_manager_rx) = unbounded_channel();
925             let mut handles = Vec::with_capacity(2);
926             let conn = Arc::new(Mutex::new(quic_connection));
927             let io_shutdown = Arc::new(AtomicBool::new(false));
928             let io_goaway = Arc::new(AtomicBool::new(false));
929             let mut stream_manager = StreamManager::new(
930                 conn.clone(),
931                 io_manager_tx,
932                 stream_manager_rx,
933                 req_rx,
934                 FrameDecoder::new(
935                     config.qpack_blocked_streams() as usize,
936                     config.qpack_max_table_capacity() as usize,
937                 ),
938                 io_shutdown.clone(),
939                 io_goaway.clone(),
940             );
941             let stream_handle = crate::runtime::spawn(async move {
942                 if stream_manager.init(config).is_err() {
943                     return;
944                 }
945                 let _ = Pin::new(&mut stream_manager).await;
946             });
947             handles.push(stream_handle);
948 
949             let io_handle = crate::runtime::spawn(async move {
950                 let mut io_manager = IOManager::new(io, conn, io_manager_rx, stream_manager_tx);
951                 let _ = Pin::new(&mut io_manager).await;
952             });
953             handles.push(io_handle);
954             // read_rx gets readable stream ids and writable client channels, then read
955             // stream and send to the corresponding channel
956             Self {
957                 detail,
958                 req_tx,
959                 handles,
960                 _mark: PhantomData,
961                 io_shutdown,
962                 io_goaway,
963             }
964         }
965     }
966 
967     impl<S> Http3Conn<S> {
new( detail: ConnDetail, sender: UnboundedSender<ReqMessage>, io_shutdown: Arc<AtomicBool>, ) -> Self968         pub(crate) fn new(
969             detail: ConnDetail,
970             sender: UnboundedSender<ReqMessage>,
971             io_shutdown: Arc<AtomicBool>,
972         ) -> Self {
973             const CHANNEL_SIZE: usize = 3;
974             let (resp_sender, resp_receiver) = bounded_channel(CHANNEL_SIZE);
975             Self {
976                 speed_controller: SpeedController::none(),
977                 sender,
978                 resp_sender,
979                 resp_receiver,
980                 _mark: PhantomData,
981                 io_shutdown,
982                 detail,
983             }
984         }
985 
send_frame_to_reader( &mut self, request: RequestWrapper, ) -> Result<(), HttpClientError>986         pub(crate) fn send_frame_to_reader(
987             &mut self,
988             request: RequestWrapper,
989         ) -> Result<(), HttpClientError> {
990             self.sender
991                 .send(ReqMessage {
992                     request,
993                     frame_tx: self.resp_sender.clone(),
994                 })
995                 .map_err(|_| {
996                     HttpClientError::from_str(ErrorKind::Request, "Request Sender Closed !")
997                 })
998         }
999 
recv_resp(&mut self) -> Result<Frame, HttpClientError>1000         pub(crate) async fn recv_resp(&mut self) -> Result<Frame, HttpClientError> {
1001             #[cfg(feature = "tokio_base")]
1002             match self.resp_receiver.recv().await {
1003                 None => err_from_msg!(Request, "Response Sender Closed !"),
1004                 Some(message) => match message {
1005                     RespMessage::Output(frame) => Ok(frame),
1006                     RespMessage::OutputExit(e) => Err(dispatch_client_error(e)),
1007                 },
1008             }
1009 
1010             #[cfg(feature = "ylong_base")]
1011             match self.resp_receiver.recv().await {
1012                 Err(err) => Err(HttpClientError::from_error(ErrorKind::Request, err)),
1013                 Ok(message) => match message {
1014                     RespMessage::Output(frame) => Ok(frame),
1015                     RespMessage::OutputExit(e) => Err(dispatch_client_error(e)),
1016                 },
1017             }
1018         }
1019     }
1020 
1021     impl<S> ConnDispatcher<S>
1022     where
1023         S: AsyncRead + AsyncWrite + ConnInfo + Sync + Send + Unpin + 'static,
1024     {
http3( detail: ConnDetail, config: H3Config, io: S, quic_connection: QuicConn, ) -> Self1025         pub(crate) fn http3(
1026             detail: ConnDetail,
1027             config: H3Config,
1028             io: S,
1029             quic_connection: QuicConn,
1030         ) -> Self {
1031             Self::Http3(Http3Dispatcher::new(detail, config, io, quic_connection))
1032         }
1033     }
1034 
1035     impl<S> Dispatcher for Http3Dispatcher<S> {
1036         type Handle = Http3Conn<S>;
1037 
dispatch(&self) -> Option<Self::Handle>1038         fn dispatch(&self) -> Option<Self::Handle> {
1039             let sender = self.req_tx.clone();
1040             Some(Http3Conn::new(
1041                 self.detail.clone(),
1042                 sender,
1043                 self.io_shutdown.clone(),
1044             ))
1045         }
1046 
is_shutdown(&self) -> bool1047         fn is_shutdown(&self) -> bool {
1048             self.io_shutdown.load(Ordering::Relaxed)
1049         }
1050 
is_goaway(&self) -> bool1051         fn is_goaway(&self) -> bool {
1052             self.io_goaway.load(Ordering::Relaxed)
1053         }
1054     }
1055 
1056     impl<S> Drop for Http3Dispatcher<S> {
drop(&mut self)1057         fn drop(&mut self) {
1058             for handle in &self.handles {
1059                 #[cfg(feature = "tokio_base")]
1060                 handle.abort();
1061                 #[cfg(feature = "ylong_base")]
1062                 handle.cancel();
1063             }
1064         }
1065     }
1066 
1067     impl From<std::io::Error> for DispatchErrorKind {
from(value: std::io::Error) -> Self1068         fn from(value: std::io::Error) -> Self {
1069             DispatchErrorKind::Io(value.kind())
1070         }
1071     }
1072 
1073     impl From<H3Error> for DispatchErrorKind {
from(err: H3Error) -> Self1074         fn from(err: H3Error) -> Self {
1075             DispatchErrorKind::H3(err)
1076         }
1077     }
1078 
1079     impl From<quiche::Error> for DispatchErrorKind {
from(value: quiche::Error) -> Self1080         fn from(value: quiche::Error) -> Self {
1081             DispatchErrorKind::Quic(value)
1082         }
1083     }
1084 
dispatch_client_error(dispatch_error: DispatchErrorKind) -> HttpClientError1085     pub(crate) fn dispatch_client_error(dispatch_error: DispatchErrorKind) -> HttpClientError {
1086         match dispatch_error {
1087             DispatchErrorKind::H3(e) => HttpClientError::from_error(Request, HttpError::from(e)),
1088             DispatchErrorKind::Io(e) => {
1089                 HttpClientError::from_io_error(Request, std::io::Error::from(e))
1090             }
1091             DispatchErrorKind::ChannelClosed => {
1092                 HttpClientError::from_str(Request, "Coroutine channel closed.")
1093             }
1094             DispatchErrorKind::Quic(e) => HttpClientError::from_error(Request, e),
1095             DispatchErrorKind::GoawayReceived => {
1096                 HttpClientError::from_str(Request, "received remote goaway.")
1097             }
1098             DispatchErrorKind::StreamFinished => {
1099                 HttpClientError::from_str(Request, "stream finished.")
1100             }
1101             DispatchErrorKind::Disconnect => {
1102                 HttpClientError::from_str(Request, "remote peer closed.")
1103             }
1104         }
1105     }
1106 }
1107 
// The test builds its dispatcher with `ConnDispatcher::http1`, and the `Http1`
// variant only exists when the `http1_1` feature is enabled, so the module is
// gated on that feature as well to keep test builds without it compiling.
#[cfg(all(test, feature = "http1_1"))]
mod ut_dispatch {
    use crate::dispatcher::{ConnDispatcher, Dispatcher};

    /// UT test cases for `ConnDispatcher::is_shutdown` and
    /// `ConnDispatcher::dispatch`.
    ///
    /// # Brief
    /// 1. Creates a `ConnDispatcher` with `ConnDispatcher::http1`.
    /// 2. Calls `ConnDispatcher::is_shutdown` and checks that the result is
    ///    false.
    /// 3. Calls `ConnDispatcher::dispatch` and checks that a handle is
    ///    returned.
    #[test]
    fn ut_is_shutdown() {
        let conn = ConnDispatcher::http1(b"Data");
        let res = conn.is_shutdown();
        assert!(!res);
        let res = conn.dispatch();
        assert!(res.is_some());
    }
}
1128