• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 use core::pin::Pin;
15 use core::task::{Context, Poll};
16 use core::{future, ptr, slice};
17 use std::io::{self, Read, Write};
18 
19 use crate::async_impl::ssl_stream::{check_io_to_poll, Wrapper};
20 use crate::c_openssl::verify::PinsVerifyInfo;
21 use crate::runtime::{AsyncRead, AsyncWrite, ReadBuf};
22 use crate::util::c_openssl::error::ErrorStack;
23 use crate::util::c_openssl::ssl::{self, ShutdownResult, Ssl, SslErrorCode};
24 
25 /// An asynchronous version of [`openssl::ssl::SslStream`].
26 #[derive(Debug)]
27 pub struct AsyncSslStream<S>(ssl::SslStream<Wrapper<S>>);
28 
29 impl<S> AsyncSslStream<S> {
with_context<F, R>(self: Pin<&mut Self>, ctx: &mut Context<'_>, f: F) -> R where F: FnOnce(&mut ssl::SslStream<Wrapper<S>>) -> R,30     fn with_context<F, R>(self: Pin<&mut Self>, ctx: &mut Context<'_>, f: F) -> R
31     where
32         F: FnOnce(&mut ssl::SslStream<Wrapper<S>>) -> R,
33     {
34         // SAFETY: must guarantee that you will never move the data out of the
35         // mutable reference you receive.
36         let this = unsafe { self.get_unchecked_mut() };
37 
38         // sets context, SslStream to R, reset 0.
39         this.0.get_mut().context = ctx as *mut _ as *mut ();
40         let r = f(&mut this.0);
41         this.0.get_mut().context = ptr::null_mut();
42         r
43     }
44 
45     /// Returns a pinned mutable reference to the underlying stream.
get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S>46     fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> {
47         // SAFETY:
48         unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().0.get_mut().stream) }
49     }
50 
51     #[cfg(feature = "http2")]
negotiated_alpn_protocol(&self) -> Option<&[u8]>52     pub(crate) fn negotiated_alpn_protocol(&self) -> Option<&[u8]> {
53         self.0.ssl().negotiated_alpn_protocol()
54     }
55 }
56 
57 impl<S> AsyncSslStream<S>
58 where
59     S: AsyncRead + AsyncWrite,
60 {
61     /// Like [`SslStream::new`](ssl::SslStream::new).
new( ssl: Ssl, stream: S, pinned_pubkey: Option<PinsVerifyInfo>, ) -> Result<Self, ErrorStack>62     pub(crate) fn new(
63         ssl: Ssl,
64         stream: S,
65         pinned_pubkey: Option<PinsVerifyInfo>,
66     ) -> Result<Self, ErrorStack> {
67         // This corresponds to `SSL_set_bio`.
68         ssl::SslStream::new_base(
69             ssl,
70             Wrapper {
71                 stream,
72                 context: ptr::null_mut(),
73             },
74             pinned_pubkey,
75         )
76         .map(AsyncSslStream)
77     }
78 
79     /// Like [`SslStream::connect`](ssl::SslStream::connect).
poll_connect(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), ssl::SslError>>80     fn poll_connect(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), ssl::SslError>> {
81         self.with_context(cx, |s| check_result_to_poll(s.connect()))
82     }
83 
84     /// A convenience method wrapping [`poll_connect`](Self::poll_connect).
connect(mut self: Pin<&mut Self>) -> Result<(), ssl::SslError>85     pub(crate) async fn connect(mut self: Pin<&mut Self>) -> Result<(), ssl::SslError> {
86         future::poll_fn(|cx| self.as_mut().poll_connect(cx)).await
87     }
88 }
89 
90 impl<S> AsyncRead for AsyncSslStream<S>
91 where
92     S: AsyncRead + AsyncWrite,
93 {
94     // wrap read.
poll_read( self: Pin<&mut Self>, ctx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<io::Result<()>>95     fn poll_read(
96         self: Pin<&mut Self>,
97         ctx: &mut Context<'_>,
98         buf: &mut ReadBuf<'_>,
99     ) -> Poll<io::Result<()>> {
100         // set async func
101         self.with_context(ctx, |s| {
102             let slice = unsafe {
103                 let buf = buf.unfilled_mut();
104                 slice::from_raw_parts_mut(buf.as_mut_ptr().cast::<u8>(), buf.len())
105             };
106             match check_io_to_poll(s.read(slice))? {
107                 Poll::Ready(len) => {
108                     #[cfg(feature = "tokio_base")]
109                     unsafe {
110                         buf.assume_init(len);
111                     }
112                     #[cfg(feature = "ylong_base")]
113                     buf.assume_init(len);
114 
115                     buf.advance(len);
116                     Poll::Ready(Ok(()))
117                 }
118                 Poll::Pending => Poll::Pending,
119             }
120         })
121     }
122 }
123 
124 impl<S> AsyncWrite for AsyncSslStream<S>
125 where
126     S: AsyncRead + AsyncWrite,
127 {
poll_write(self: Pin<&mut Self>, ctx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>>128     fn poll_write(self: Pin<&mut Self>, ctx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
129         self.with_context(ctx, |s| check_io_to_poll(s.write(buf)))
130     }
131 
poll_flush(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>>132     fn poll_flush(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
133         self.with_context(ctx, |s| check_io_to_poll(s.flush()))
134     }
135 
poll_shutdown(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>>136     fn poll_shutdown(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
137         // Shuts down the session.
138         match self.as_mut().with_context(ctx, |s| s.shutdown()) {
139             // Sends a close notify message to the peer, after which `ShutdownResult::Sent` is
140             // returned. Awaits the receipt of a close notify message from the peer,
141             // after which `ShutdownResult::Received` is returned.
142             Ok(ShutdownResult::Sent) | Ok(ShutdownResult::Received) => {}
143             // The SSL session has been closed.
144             Err(ref e) if e.code() == SslErrorCode::ZERO_RETURN => {}
145             // When the underlying BIO could not satisfy the needs of SSL_shutdown() to continue the
146             // handshake
147             Err(ref e)
148                 if e.code() == SslErrorCode::WANT_READ || e.code() == SslErrorCode::WANT_WRITE =>
149             {
150                 return Poll::Pending;
151             }
152             // Really error.
153             Err(e) => {
154                 return Poll::Ready(Err(e
155                     .into_io_error()
156                     .unwrap_or_else(|e| io::Error::new(io::ErrorKind::Other, e))));
157             }
158         }
159         // Returns success when the I/O connection has completely shut down.
160         self.get_pin_mut().poll_shutdown(ctx)
161     }
162 }
163 
164 /// Checks `ssl::Error`.
check_result_to_poll<T>(r: Result<T, ssl::SslError>) -> Poll<Result<T, ssl::SslError>>165 fn check_result_to_poll<T>(r: Result<T, ssl::SslError>) -> Poll<Result<T, ssl::SslError>> {
166     match r {
167         Ok(t) => Poll::Ready(Ok(t)),
168         Err(e) => match e.code() {
169             SslErrorCode::WANT_READ | SslErrorCode::WANT_WRITE => Poll::Pending,
170             _ => Poll::Ready(Err(e)),
171         },
172     }
173 }
174