1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 // http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13
14 use core::pin::Pin;
15 use core::task::{Context, Poll};
16 use core::{future, ptr, slice};
17 use std::io::{self, Read, Write};
18
19 use crate::async_impl::ssl_stream::{check_io_to_poll, Wrapper};
20 use crate::runtime::{AsyncRead, AsyncWrite, ReadBuf};
21 use crate::util::c_openssl::error::ErrorStack;
22 use crate::util::c_openssl::ssl::{self, ShutdownResult, Ssl, SslErrorCode};
23
24 /// An asynchronous version of [`openssl::ssl::SslStream`].
25 #[derive(Debug)]
26 pub struct AsyncSslStream<S>(ssl::SslStream<Wrapper<S>>);
27
28 impl<S> AsyncSslStream<S> {
with_context<F, R>(self: Pin<&mut Self>, ctx: &mut Context<'_>, f: F) -> R where F: FnOnce(&mut ssl::SslStream<Wrapper<S>>) -> R,29 fn with_context<F, R>(self: Pin<&mut Self>, ctx: &mut Context<'_>, f: F) -> R
30 where
31 F: FnOnce(&mut ssl::SslStream<Wrapper<S>>) -> R,
32 {
33 // SAFETY: must guarantee that you will never move the data out of the
34 // mutable reference you receive.
35 let this = unsafe { self.get_unchecked_mut() };
36
37 // sets context, SslStream to R, reset 0.
38 this.0.get_mut().context = ctx as *mut _ as *mut ();
39 let r = f(&mut this.0);
40 this.0.get_mut().context = ptr::null_mut();
41 r
42 }
43
44 /// Returns a pinned mutable reference to the underlying stream.
get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S>45 fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> {
46 // SAFETY:
47 unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().0.get_mut().stream) }
48 }
49
negotiated_alpn_protocol(&self) -> Option<&[u8]>50 pub(crate) fn negotiated_alpn_protocol(&self) -> Option<&[u8]> {
51 self.0.ssl().negotiated_alpn_protocol()
52 }
53 }
54
55 impl<S> AsyncSslStream<S>
56 where
57 S: AsyncRead + AsyncWrite,
58 {
59 /// Like [`SslStream::new`](ssl::SslStream::new).
new( ssl: Ssl, stream: S, pinned_pubkey: Option<String>, ) -> Result<Self, ErrorStack>60 pub(crate) fn new(
61 ssl: Ssl,
62 stream: S,
63 pinned_pubkey: Option<String>,
64 ) -> Result<Self, ErrorStack> {
65 // This corresponds to `SSL_set_bio`.
66 ssl::SslStream::new_base(
67 ssl,
68 Wrapper {
69 stream,
70 context: ptr::null_mut(),
71 },
72 pinned_pubkey,
73 )
74 .map(AsyncSslStream)
75 }
76
77 /// Like [`SslStream::connect`](ssl::SslStream::connect).
poll_connect(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), ssl::SslError>>78 fn poll_connect(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), ssl::SslError>> {
79 self.with_context(cx, |s| check_result_to_poll(s.connect()))
80 }
81
82 /// A convenience method wrapping [`poll_connect`](Self::poll_connect).
connect(mut self: Pin<&mut Self>) -> Result<(), ssl::SslError>83 pub(crate) async fn connect(mut self: Pin<&mut Self>) -> Result<(), ssl::SslError> {
84 future::poll_fn(|cx| self.as_mut().poll_connect(cx)).await
85 }
86 }
87
88 impl<S> AsyncRead for AsyncSslStream<S>
89 where
90 S: AsyncRead + AsyncWrite,
91 {
92 // wrap read.
poll_read( self: Pin<&mut Self>, ctx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<io::Result<()>>93 fn poll_read(
94 self: Pin<&mut Self>,
95 ctx: &mut Context<'_>,
96 buf: &mut ReadBuf<'_>,
97 ) -> Poll<io::Result<()>> {
98 // set async func
99 self.with_context(ctx, |s| {
100 let slice = unsafe {
101 let buf = buf.unfilled_mut();
102 slice::from_raw_parts_mut(buf.as_mut_ptr().cast::<u8>(), buf.len())
103 };
104 match check_io_to_poll(s.read(slice))? {
105 Poll::Ready(len) => {
106 #[cfg(feature = "tokio_base")]
107 unsafe {
108 buf.assume_init(len);
109 }
110 #[cfg(feature = "ylong_base")]
111 buf.assume_init(len);
112
113 buf.advance(len);
114 Poll::Ready(Ok(()))
115 }
116 Poll::Pending => Poll::Pending,
117 }
118 })
119 }
120 }
121
122 impl<S> AsyncWrite for AsyncSslStream<S>
123 where
124 S: AsyncRead + AsyncWrite,
125 {
poll_write(self: Pin<&mut Self>, ctx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>>126 fn poll_write(self: Pin<&mut Self>, ctx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
127 self.with_context(ctx, |s| check_io_to_poll(s.write(buf)))
128 }
129
poll_flush(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>>130 fn poll_flush(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
131 self.with_context(ctx, |s| check_io_to_poll(s.flush()))
132 }
133
poll_shutdown(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>>134 fn poll_shutdown(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
135 // Shuts down the session.
136 match self.as_mut().with_context(ctx, |s| s.shutdown()) {
137 // Sends a close notify message to the peer, after which `ShutdownResult::Sent` is
138 // returned. Awaits the receipt of a close notify message from the peer,
139 // after which `ShutdownResult::Received` is returned.
140 Ok(ShutdownResult::Sent) | Ok(ShutdownResult::Received) => {}
141 // The SSL session has been closed.
142 Err(ref e) if e.code() == SslErrorCode::ZERO_RETURN => {}
143 // When the underlying BIO could not satisfy the needs of SSL_shutdown() to continue the
144 // handshake
145 Err(ref e)
146 if e.code() == SslErrorCode::WANT_READ || e.code() == SslErrorCode::WANT_WRITE =>
147 {
148 return Poll::Pending;
149 }
150 // Really error.
151 Err(e) => {
152 return Poll::Ready(Err(e
153 .into_io_error()
154 .unwrap_or_else(|e| io::Error::new(io::ErrorKind::Other, e))));
155 }
156 }
157 // Returns success when the I/O connection has completely shut down.
158 self.get_pin_mut().poll_shutdown(ctx)
159 }
160 }
161
162 /// Checks `ssl::Error`.
check_result_to_poll<T>(r: Result<T, ssl::SslError>) -> Poll<Result<T, ssl::SslError>>163 fn check_result_to_poll<T>(r: Result<T, ssl::SslError>) -> Poll<Result<T, ssl::SslError>> {
164 match r {
165 Ok(t) => Poll::Ready(Ok(t)),
166 Err(e) => match e.code() {
167 SslErrorCode::WANT_READ | SslErrorCode::WANT_WRITE => Poll::Pending,
168 _ => Poll::Ready(Err(e)),
169 },
170 }
171 }
172