• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 //! Frame recv coroutine.
15 
16 use std::future::Future;
17 use std::pin::Pin;
18 use std::sync::{Arc, Mutex};
19 use std::task::{Context, Poll};
20 
21 use ylong_http::h2::{
22     ErrorCode, Frame, FrameDecoder, FrameKind, FramesIntoIter, H2Error, Payload, Setting,
23 };
24 
25 use crate::runtime::{AsyncRead, BoundedSender, ReadBuf, ReadHalf, SendError};
26 use crate::util::dispatcher::http2::{
27     DispatchErrorKind, OutputMessage, SettingsState, SettingsSync,
28 };
29 
30 pub(crate) type OutputSendFut =
31     Pin<Box<dyn Future<Output = Result<(), SendError<OutputMessage>>> + Send + Sync>>;
32 
33 #[derive(Copy, Clone)]
34 enum DecodeState {
35     Read,
36     Send,
37     Exit(DispatchErrorKind),
38 }
39 
40 pub(crate) struct RecvData<S> {
41     decoder: FrameDecoder,
42     settings: Arc<Mutex<SettingsSync>>,
43     reader: ReadHalf<S>,
44     state: DecodeState,
45     next_state: DecodeState,
46     resp_tx: BoundedSender<OutputMessage>,
47     curr_message: Option<OutputSendFut>,
48     pending_iter: Option<FramesIntoIter>,
49 }
50 
51 impl<S: AsyncRead + Unpin + Sync + Send + 'static> Future for RecvData<S> {
52     type Output = Result<(), DispatchErrorKind>;
53 
poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>54     fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
55         let receiver = self.get_mut();
56         receiver.poll_read_frame(cx)
57     }
58 }
59 
60 impl<S: AsyncRead + Unpin + Sync + Send + 'static> RecvData<S> {
new( decoder: FrameDecoder, settings: Arc<Mutex<SettingsSync>>, reader: ReadHalf<S>, resp_tx: BoundedSender<OutputMessage>, ) -> Self61     pub(crate) fn new(
62         decoder: FrameDecoder,
63         settings: Arc<Mutex<SettingsSync>>,
64         reader: ReadHalf<S>,
65         resp_tx: BoundedSender<OutputMessage>,
66     ) -> Self {
67         Self {
68             decoder,
69             settings,
70             reader,
71             state: DecodeState::Read,
72             next_state: DecodeState::Read,
73             resp_tx,
74             curr_message: None,
75             pending_iter: None,
76         }
77     }
78 
poll_read_frame(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), DispatchErrorKind>>79     fn poll_read_frame(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), DispatchErrorKind>> {
80         let mut buf = [0u8; 1024];
81         loop {
82             match self.state {
83                 DecodeState::Read => {
84                     let mut read_buf = ReadBuf::new(&mut buf);
85                     match Pin::new(&mut self.reader).poll_read(cx, &mut read_buf) {
86                         Poll::Ready(Err(e)) => {
87                             return self.transmit_error(cx, e.into());
88                         }
89                         Poll::Ready(Ok(())) => {}
90                         Poll::Pending => {
91                             return Poll::Pending;
92                         }
93                     }
94                     let read = read_buf.filled().len();
95                     if read == 0 {
96                         let _ = self.transmit_message(
97                             cx,
98                             OutputMessage::OutputExit(DispatchErrorKind::Disconnect),
99                         );
100                         self.state = DecodeState::Send;
101                         return Poll::Pending;
102                     }
103 
104                     match self.decoder.decode(&buf[..read]) {
105                         Ok(frames) => match self.poll_iterator_frames(cx, frames.into_iter()) {
106                             Poll::Ready(Ok(_)) => {}
107                             Poll::Ready(Err(e)) => {
108                                 return Poll::Ready(Err(e));
109                             }
110                             Poll::Pending => {
111                                 self.next_state = DecodeState::Read;
112                             }
113                         },
114                         Err(e) => {
115                             match self.transmit_message(cx, OutputMessage::OutputExit(e.into())) {
116                                 Poll::Ready(Err(_)) => {
117                                     return Poll::Ready(Err(DispatchErrorKind::ChannelClosed))
118                                 }
119                                 Poll::Ready(Ok(_)) => {}
120                                 Poll::Pending => {
121                                     self.next_state = DecodeState::Read;
122                                     return Poll::Pending;
123                                 }
124                             }
125                         }
126                     }
127                 }
128                 DecodeState::Send => {
129                     match self.poll_blocked_task(cx) {
130                         Poll::Ready(Ok(_)) => {
131                             self.state = self.next_state;
132                             // Reset next state.
133                             self.next_state = DecodeState::Read;
134                         }
135                         Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
136                         Poll::Pending => return Poll::Pending,
137                     }
138                 }
139                 DecodeState::Exit(e) => {
140                     return Poll::Ready(Err(e));
141                 }
142             }
143         }
144     }
145 
poll_blocked_task(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), DispatchErrorKind>>146     fn poll_blocked_task(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), DispatchErrorKind>> {
147         if let Some(mut task) = self.curr_message.take() {
148             match task.as_mut().poll(cx) {
149                 Poll::Ready(Ok(_)) => {}
150                 Poll::Ready(Err(_)) => {
151                     return Poll::Ready(Err(DispatchErrorKind::ChannelClosed));
152                 }
153                 Poll::Pending => {
154                     self.curr_message = Some(task);
155                     return Poll::Pending;
156                 }
157             }
158         }
159 
160         if let Some(iter) = self.pending_iter.take() {
161             return self.poll_iterator_frames(cx, iter);
162         }
163         Poll::Ready(Ok(()))
164     }
165 
poll_iterator_frames( &mut self, cx: &mut Context<'_>, mut iter: FramesIntoIter, ) -> Poll<Result<(), DispatchErrorKind>>166     fn poll_iterator_frames(
167         &mut self,
168         cx: &mut Context<'_>,
169         mut iter: FramesIntoIter,
170     ) -> Poll<Result<(), DispatchErrorKind>> {
171         while let Some(kind) = iter.next() {
172             match kind {
173                 FrameKind::Complete(frame) => {
174                     // TODO Whether to continue processing the remaining frames after connection
175                     // error occurs in the Settings frame.
176                     let message = if let Err(e) = self.update_settings(&frame) {
177                         OutputMessage::OutputExit(DispatchErrorKind::H2(e))
178                     } else {
179                         OutputMessage::Output(frame)
180                     };
181 
182                     match self.transmit_message(cx, message) {
183                         Poll::Ready(Ok(_)) => {}
184                         Poll::Ready(Err(e)) => {
185                             return Poll::Ready(Err(e));
186                         }
187                         Poll::Pending => {
188                             self.pending_iter = Some(iter);
189                             return Poll::Pending;
190                         }
191                     }
192                 }
193                 FrameKind::Partial => {}
194             }
195         }
196         Poll::Ready(Ok(()))
197     }
198 
transmit_error( &mut self, cx: &mut Context<'_>, exit_err: DispatchErrorKind, ) -> Poll<Result<(), DispatchErrorKind>>199     fn transmit_error(
200         &mut self,
201         cx: &mut Context<'_>,
202         exit_err: DispatchErrorKind,
203     ) -> Poll<Result<(), DispatchErrorKind>> {
204         match self.transmit_message(cx, OutputMessage::OutputExit(exit_err)) {
205             Poll::Ready(_) => Poll::Ready(Err(exit_err)),
206             Poll::Pending => {
207                 self.next_state = DecodeState::Exit(exit_err);
208                 Poll::Pending
209             }
210         }
211     }
212 
transmit_message( &mut self, cx: &mut Context<'_>, message: OutputMessage, ) -> Poll<Result<(), DispatchErrorKind>>213     fn transmit_message(
214         &mut self,
215         cx: &mut Context<'_>,
216         message: OutputMessage,
217     ) -> Poll<Result<(), DispatchErrorKind>> {
218         let mut task = {
219             let sender = self.resp_tx.clone();
220             let ft = async move { sender.send(message).await };
221             Box::pin(ft)
222         };
223 
224         match task.as_mut().poll(cx) {
225             Poll::Ready(Ok(_)) => Poll::Ready(Ok(())),
226             // The current coroutine sending the request exited prematurely.
227             Poll::Ready(Err(_)) => Poll::Ready(Err(DispatchErrorKind::ChannelClosed)),
228             Poll::Pending => {
229                 self.state = DecodeState::Send;
230                 self.curr_message = Some(task);
231                 Poll::Pending
232             }
233         }
234     }
235 
update_settings(&mut self, frame: &Frame) -> Result<(), H2Error>236     fn update_settings(&mut self, frame: &Frame) -> Result<(), H2Error> {
237         if let Payload::Settings(_settings) = frame.payload() {
238             if frame.flags().is_ack() {
239                 self.update_decoder_settings()?;
240             }
241         }
242         Ok(())
243     }
244 
update_decoder_settings(&mut self) -> Result<(), H2Error>245     fn update_decoder_settings(&mut self) -> Result<(), H2Error> {
246         let connection = self.settings.lock().unwrap();
247         match &connection.settings {
248             SettingsState::Acknowledging(settings) => {
249                 for setting in settings.get_settings() {
250                     if let Setting::MaxHeaderListSize(size) = setting {
251                         self.decoder.set_max_header_list_size(*size as usize);
252                     }
253                     if let Setting::MaxFrameSize(size) = setting {
254                         self.decoder.set_max_frame_size(*size)?;
255                     }
256                 }
257                 Ok(())
258             }
259             SettingsState::Synced => Err(H2Error::ConnectionError(ErrorCode::ConnectError)),
260         }
261     }
262 }
263