1 // Copyright (c) 2023 Huawei Device Co., Ltd. 2 // Licensed under the Apache License, Version 2.0 (the "License"); 3 // you may not use this file except in compliance with the License. 4 // You may obtain a copy of the License at 5 // 6 // http://www.apache.org/licenses/LICENSE-2.0 7 // 8 // Unless required by applicable law or agreed to in writing, software 9 // distributed under the License is distributed on an "AS IS" BASIS, 10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 // See the License for the specific language governing permissions and 12 // limitations under the License. 13 14 use core::fmt; 15 use core::marker::PhantomData; 16 use core::mem::ManuallyDrop; 17 use std::io::{self, Read, Write}; 18 use std::panic::resume_unwind; 19 20 use libc::c_int; 21 22 use super::{InternalError, Ssl, SslError, SslErrorCode, SslRef}; 23 use crate::c_openssl::bio::{self, get_error, get_panic, get_stream_mut, get_stream_ref}; 24 use crate::c_openssl::error::ErrorStack; 25 use crate::c_openssl::ffi::ssl::{SSL_connect, SSL_set_bio, SSL_shutdown}; 26 use crate::c_openssl::foreign::Foreign; 27 use crate::util::c_openssl::bio::BioMethod; 28 29 /// A TLS session over a stream. 30 pub struct SslStream<S> { 31 pub(crate) ssl: ManuallyDrop<Ssl>, 32 method: ManuallyDrop<BioMethod>, 33 p: PhantomData<S>, 34 } 35 36 impl<S> SslStream<S> { get_error(&mut self, err: c_int) -> SslError37 pub(crate) fn get_error(&mut self, err: c_int) -> SslError { 38 self.check_panic(); 39 let code = self.ssl.get_error(err); 40 let internal = match code { 41 SslErrorCode::SSL => { 42 let e = ErrorStack::get(); 43 Some(InternalError::Ssl(e)) 44 } 45 SslErrorCode::SYSCALL => { 46 let error = ErrorStack::get(); 47 if error.errors().is_empty() { 48 self.get_bio_error().map(InternalError::Io) 49 } else { 50 Some(InternalError::Ssl(error)) 51 } 52 } 53 SslErrorCode::WANT_WRITE | SslErrorCode::WANT_READ => { 54 self.get_bio_error().map(InternalError::Io) 55 } 56 _ => None, 57 }; 58 SslError { code, internal } 59 } 60 check_panic(&mut self)61 fn check_panic(&mut self) { 62 if let Some(err) = unsafe { get_panic::<S>(self.ssl.get_raw_bio()) } { 63 resume_unwind(err) 64 } 65 } 66 get_bio_error(&mut self) -> Option<io::Error>67 fn get_bio_error(&mut self) -> Option<io::Error> { 68 unsafe { get_error::<S>(self.ssl.get_raw_bio()) } 69 } 70 get_ref(&self) -> &S71 pub(crate) fn get_ref(&self) -> &S { 72 unsafe { 73 let bio = self.ssl.get_raw_bio(); 74 get_stream_ref(bio) 75 } 76 } 77 get_mut(&mut self) -> &mut S78 pub(crate) fn get_mut(&mut self) -> &mut S { 79 unsafe { 80 let bio = self.ssl.get_raw_bio(); 81 get_stream_mut(bio) 82 } 83 } 84 ssl(&self) -> &SslRef85 pub(crate) fn ssl(&self) -> &SslRef { 86 &self.ssl 87 } 88 } 89 90 impl<S> fmt::Debug for SslStream<S> 91 where 92 S: fmt::Debug, 93 { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result94 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 95 write!(f, "stream[{:?}], {:?}", &self.get_ref(), &self.ssl()) 96 } 97 } 98 99 impl<S> Drop for SslStream<S> { drop(&mut self)100 fn drop(&mut self) { 101 unsafe { 102 ManuallyDrop::drop(&mut self.ssl); 103 ManuallyDrop::drop(&mut self.method); 104 } 105 } 106 } 107 108 impl<S: Read + Write> SslStream<S> { ssl_read(&mut self, buf: &mut [u8]) -> Result<usize, SslError>109 pub(crate) fn ssl_read(&mut self, buf: &mut [u8]) -> Result<usize, SslError> { 110 if buf.is_empty() { 111 return Ok(0); 112 } 113 let ret = self.ssl.read(buf); 114 if ret > 0 { 115 Ok(ret as usize) 116 } else { 117 Err(self.get_error(ret)) 118 } 119 } 120 ssl_write(&mut self, buf: &[u8]) -> Result<usize, SslError>121 pub(crate) fn ssl_write(&mut self, buf: &[u8]) -> Result<usize, SslError> { 122 if buf.is_empty() { 123 return Ok(0); 124 } 125 let ret = self.ssl.write(buf); 126 if ret > 0 { 127 Ok(ret as usize) 128 } else { 129 Err(self.get_error(ret)) 130 } 131 } 132 new_base(ssl: Ssl, stream: S) -> Result<Self, ErrorStack>133 pub(crate) fn new_base(ssl: Ssl, stream: S) -> Result<Self, ErrorStack> { 134 unsafe { 135 let (bio, method) = bio::new(stream)?; 136 SSL_set_bio(ssl.as_ptr(), bio, bio); 137 138 Ok(SslStream { 139 ssl: ManuallyDrop::new(ssl), 140 method: ManuallyDrop::new(method), 141 p: PhantomData, 142 }) 143 } 144 } 145 connect(&mut self) -> Result<(), SslError>146 pub(crate) fn connect(&mut self) -> Result<(), SslError> { 147 let ret = unsafe { SSL_connect(self.ssl.as_ptr()) }; 148 if ret > 0 { 149 Ok(()) 150 } else { 151 Err(self.get_error(ret)) 152 } 153 } 154 shutdown(&mut self) -> Result<ShutdownResult, SslError>155 pub(crate) fn shutdown(&mut self) -> Result<ShutdownResult, SslError> { 156 unsafe { 157 match SSL_shutdown(self.ssl.as_ptr()) { 158 0 => Ok(ShutdownResult::Sent), 159 1 => Ok(ShutdownResult::Received), 160 n => Err(self.get_error(n)), 161 } 162 } 163 } 164 } 165 166 impl<S: Read + Write> Read for SslStream<S> { 167 // ssl_read read(&mut self, buf: &mut [u8]) -> io::Result<usize>168 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { 169 loop { 170 match self.ssl_read(buf) { 171 Ok(n) => return Ok(n), 172 // The TLS/SSL peer has closed the connection for writing by sending 173 // the close_notify alert. No more data can be read. 174 // Does not necessarily indicate that the underlying transport has been closed. 175 Err(ref e) if e.code == SslErrorCode::ZERO_RETURN => return Ok(0), 176 // A non-recoverable, fatal error in the SSL library occurred, usually a protocol 177 // error. 178 Err(ref e) if e.code == SslErrorCode::SYSCALL && e.get_io_error().is_none() => { 179 return Ok(0) 180 } 181 // When the last operation was a read operation from a nonblocking BIO. 182 Err(ref e) if e.code == SslErrorCode::WANT_READ && e.get_io_error().is_none() => {} 183 // Other error. 184 Err(err) => { 185 return Err(err 186 .into_io_error() 187 .unwrap_or_else(|err| io::Error::new(io::ErrorKind::Other, err))) 188 } 189 }; 190 } 191 } 192 } 193 194 impl<S: Read + Write> Write for SslStream<S> { 195 // ssl_write write(&mut self, buf: &[u8]) -> io::Result<usize>196 fn write(&mut self, buf: &[u8]) -> io::Result<usize> { 197 loop { 198 match self.ssl_write(buf) { 199 Ok(n) => return Ok(n), 200 // When the last operation was a read operation from a nonblocking BIO. 201 Err(ref e) if e.code == SslErrorCode::WANT_READ && e.get_io_error().is_none() => {} 202 Err(err) => { 203 return Err(err 204 .into_io_error() 205 .unwrap_or_else(|err| io::Error::new(io::ErrorKind::Other, err))); 206 } 207 } 208 } 209 } 210 211 // S.flush() flush(&mut self) -> io::Result<()>212 fn flush(&mut self) -> io::Result<()> { 213 self.get_mut().flush() 214 } 215 } 216 217 /// An SSL stream midway through the handshake process. 218 #[derive(Debug)] 219 pub(crate) struct MidHandshakeSslStream<S> { 220 pub(crate) _stream: SslStream<S>, 221 pub(crate) error: SslError, 222 } 223 224 impl<S> MidHandshakeSslStream<S> { error(&self) -> &SslError225 pub(crate) fn error(&self) -> &SslError { 226 &self.error 227 } 228 } 229 230 #[derive(Copy, Clone, Debug, PartialEq, Eq)] 231 pub(crate) enum ShutdownResult { 232 Sent, 233 Received, 234 } 235