• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 use core::fmt;
15 use core::marker::PhantomData;
16 use core::mem::ManuallyDrop;
17 use std::io::{self, Read, Write};
18 use std::panic::resume_unwind;
19 
20 use libc::c_int;
21 
22 use super::{InternalError, Ssl, SslError, SslErrorCode, SslRef};
23 use crate::c_openssl::bio::{self, get_error, get_panic, get_stream_mut, get_stream_ref};
24 use crate::c_openssl::error::ErrorStack;
25 use crate::c_openssl::ffi::ssl::{SSL_connect, SSL_set_bio, SSL_shutdown};
26 use crate::c_openssl::foreign::Foreign;
27 use crate::util::c_openssl::bio::BioMethod;
28 
29 /// A TLS session over a stream.
30 pub struct SslStream<S> {
31     pub(crate) ssl: ManuallyDrop<Ssl>,
32     method: ManuallyDrop<BioMethod>,
33     p: PhantomData<S>,
34 }
35 
36 impl<S> SslStream<S> {
get_error(&mut self, err: c_int) -> SslError37     pub(crate) fn get_error(&mut self, err: c_int) -> SslError {
38         self.check_panic();
39         let code = self.ssl.get_error(err);
40         let internal = match code {
41             SslErrorCode::SSL => {
42                 let e = ErrorStack::get();
43                 Some(InternalError::Ssl(e))
44             }
45             SslErrorCode::SYSCALL => {
46                 let error = ErrorStack::get();
47                 if error.errors().is_empty() {
48                     self.get_bio_error().map(InternalError::Io)
49                 } else {
50                     Some(InternalError::Ssl(error))
51                 }
52             }
53             SslErrorCode::WANT_WRITE | SslErrorCode::WANT_READ => {
54                 self.get_bio_error().map(InternalError::Io)
55             }
56             _ => None,
57         };
58         SslError { code, internal }
59     }
60 
check_panic(&mut self)61     fn check_panic(&mut self) {
62         if let Some(err) = unsafe { get_panic::<S>(self.ssl.get_raw_bio()) } {
63             resume_unwind(err)
64         }
65     }
66 
get_bio_error(&mut self) -> Option<io::Error>67     fn get_bio_error(&mut self) -> Option<io::Error> {
68         unsafe { get_error::<S>(self.ssl.get_raw_bio()) }
69     }
70 
get_ref(&self) -> &S71     pub(crate) fn get_ref(&self) -> &S {
72         unsafe {
73             let bio = self.ssl.get_raw_bio();
74             get_stream_ref(bio)
75         }
76     }
77 
get_mut(&mut self) -> &mut S78     pub(crate) fn get_mut(&mut self) -> &mut S {
79         unsafe {
80             let bio = self.ssl.get_raw_bio();
81             get_stream_mut(bio)
82         }
83     }
84 
ssl(&self) -> &SslRef85     pub(crate) fn ssl(&self) -> &SslRef {
86         &self.ssl
87     }
88 }
89 
90 impl<S> fmt::Debug for SslStream<S>
91 where
92     S: fmt::Debug,
93 {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result94     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
95         write!(f, "stream[{:?}], {:?}", &self.get_ref(), &self.ssl())
96     }
97 }
98 
99 impl<S> Drop for SslStream<S> {
drop(&mut self)100     fn drop(&mut self) {
101         unsafe {
102             ManuallyDrop::drop(&mut self.ssl);
103             ManuallyDrop::drop(&mut self.method);
104         }
105     }
106 }
107 
108 impl<S: Read + Write> SslStream<S> {
ssl_read(&mut self, buf: &mut [u8]) -> Result<usize, SslError>109     pub(crate) fn ssl_read(&mut self, buf: &mut [u8]) -> Result<usize, SslError> {
110         if buf.is_empty() {
111             return Ok(0);
112         }
113         let ret = self.ssl.read(buf);
114         if ret > 0 {
115             Ok(ret as usize)
116         } else {
117             Err(self.get_error(ret))
118         }
119     }
120 
ssl_write(&mut self, buf: &[u8]) -> Result<usize, SslError>121     pub(crate) fn ssl_write(&mut self, buf: &[u8]) -> Result<usize, SslError> {
122         if buf.is_empty() {
123             return Ok(0);
124         }
125         let ret = self.ssl.write(buf);
126         if ret > 0 {
127             Ok(ret as usize)
128         } else {
129             Err(self.get_error(ret))
130         }
131     }
132 
new_base(ssl: Ssl, stream: S) -> Result<Self, ErrorStack>133     pub(crate) fn new_base(ssl: Ssl, stream: S) -> Result<Self, ErrorStack> {
134         unsafe {
135             let (bio, method) = bio::new(stream)?;
136             SSL_set_bio(ssl.as_ptr(), bio, bio);
137 
138             Ok(SslStream {
139                 ssl: ManuallyDrop::new(ssl),
140                 method: ManuallyDrop::new(method),
141                 p: PhantomData,
142             })
143         }
144     }
145 
connect(&mut self) -> Result<(), SslError>146     pub(crate) fn connect(&mut self) -> Result<(), SslError> {
147         let ret = unsafe { SSL_connect(self.ssl.as_ptr()) };
148         if ret > 0 {
149             Ok(())
150         } else {
151             Err(self.get_error(ret))
152         }
153     }
154 
shutdown(&mut self) -> Result<ShutdownResult, SslError>155     pub(crate) fn shutdown(&mut self) -> Result<ShutdownResult, SslError> {
156         unsafe {
157             match SSL_shutdown(self.ssl.as_ptr()) {
158                 0 => Ok(ShutdownResult::Sent),
159                 1 => Ok(ShutdownResult::Received),
160                 n => Err(self.get_error(n)),
161             }
162         }
163     }
164 }
165 
166 impl<S: Read + Write> Read for SslStream<S> {
167     // ssl_read
read(&mut self, buf: &mut [u8]) -> io::Result<usize>168     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
169         loop {
170             match self.ssl_read(buf) {
171                 Ok(n) => return Ok(n),
172                 // The TLS/SSL peer has closed the connection for writing by sending
173                 // the close_notify alert. No more data can be read.
174                 // Does not necessarily indicate that the underlying transport has been closed.
175                 Err(ref e) if e.code == SslErrorCode::ZERO_RETURN => return Ok(0),
176                 // A non-recoverable, fatal error in the SSL library occurred, usually a protocol
177                 // error.
178                 Err(ref e) if e.code == SslErrorCode::SYSCALL && e.get_io_error().is_none() => {
179                     return Ok(0)
180                 }
181                 // When the last operation was a read operation from a nonblocking BIO.
182                 Err(ref e) if e.code == SslErrorCode::WANT_READ && e.get_io_error().is_none() => {}
183                 // Other error.
184                 Err(err) => {
185                     return Err(err
186                         .into_io_error()
187                         .unwrap_or_else(|err| io::Error::new(io::ErrorKind::Other, err)))
188                 }
189             };
190         }
191     }
192 }
193 
194 impl<S: Read + Write> Write for SslStream<S> {
195     // ssl_write
write(&mut self, buf: &[u8]) -> io::Result<usize>196     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
197         loop {
198             match self.ssl_write(buf) {
199                 Ok(n) => return Ok(n),
200                 // When the last operation was a read operation from a nonblocking BIO.
201                 Err(ref e) if e.code == SslErrorCode::WANT_READ && e.get_io_error().is_none() => {}
202                 Err(err) => {
203                     return Err(err
204                         .into_io_error()
205                         .unwrap_or_else(|err| io::Error::new(io::ErrorKind::Other, err)));
206                 }
207             }
208         }
209     }
210 
211     // S.flush()
flush(&mut self) -> io::Result<()>212     fn flush(&mut self) -> io::Result<()> {
213         self.get_mut().flush()
214     }
215 }
216 
217 /// An SSL stream midway through the handshake process.
218 #[derive(Debug)]
219 pub(crate) struct MidHandshakeSslStream<S> {
220     pub(crate) _stream: SslStream<S>,
221     pub(crate) error: SslError,
222 }
223 
224 impl<S> MidHandshakeSslStream<S> {
error(&self) -> &SslError225     pub(crate) fn error(&self) -> &SslError {
226         &self.error
227     }
228 }
229 
/// Outcome of a successful `SslStream::shutdown` call.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum ShutdownResult {
    /// `SSL_shutdown` returned `0`.
    Sent,
    /// `SSL_shutdown` returned `1`.
    Received,
}
235