• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
// Copyright (c) 2023 Huawei Device Co., Ltd.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
13 
14 use core::fmt;
15 use core::marker::PhantomData;
16 use core::mem::ManuallyDrop;
17 use std::io::{self, Read, Write};
18 use std::panic::resume_unwind;
19 use std::ptr;
20 
21 use libc::c_int;
22 
23 use super::{InternalError, Ssl, SslError, SslErrorCode, SslRef};
24 use crate::c_openssl::bio::{self, get_error, get_panic, get_stream_mut, get_stream_ref};
25 use crate::c_openssl::error::ErrorStack;
26 use crate::c_openssl::ffi::ssl::{SSL_connect, SSL_set_bio, SSL_shutdown};
27 use crate::c_openssl::foreign::Foreign;
28 use crate::c_openssl::verify::PinsVerifyInfo;
29 use crate::util::base64::encode;
30 use crate::util::c_openssl::bio::BioMethod;
31 use crate::util::c_openssl::error::VerifyError;
32 use crate::util::c_openssl::error::VerifyKind::PubKeyPinning;
33 use crate::util::c_openssl::ffi::ssl::SSL;
34 use crate::util::c_openssl::ffi::x509::{i2d_X509_PUBKEY, X509_free, X509_get_X509_PUBKEY, C_X509};
35 use crate::util::c_openssl::verify::sha256_digest;
36 
37 /// A TLS session over a stream.
38 pub struct SslStream<S> {
39     pub(crate) ssl: ManuallyDrop<Ssl>,
40     method: ManuallyDrop<BioMethod>,
41     pinned_pubkey: Option<PinsVerifyInfo>,
42     p: PhantomData<S>,
43 }
44 
45 impl<S> SslStream<S> {
get_error(&mut self, err: c_int) -> SslError46     pub(crate) fn get_error(&mut self, err: c_int) -> SslError {
47         self.check_panic();
48         let code = self.ssl.get_error(err);
49         let internal = match code {
50             SslErrorCode::SSL => {
51                 let e = ErrorStack::get();
52                 Some(InternalError::Ssl(e))
53             }
54             SslErrorCode::SYSCALL => {
55                 let error = ErrorStack::get();
56                 if error.errors().is_empty() {
57                     self.get_bio_error().map(InternalError::Io)
58                 } else {
59                     Some(InternalError::Ssl(error))
60                 }
61             }
62             SslErrorCode::WANT_WRITE | SslErrorCode::WANT_READ => {
63                 self.get_bio_error().map(InternalError::Io)
64             }
65             _ => None,
66         };
67         SslError { code, internal }
68     }
69 
check_panic(&mut self)70     fn check_panic(&mut self) {
71         if let Some(err) = unsafe { get_panic::<S>(self.ssl.get_raw_bio()) } {
72             resume_unwind(err)
73         }
74     }
75 
get_bio_error(&mut self) -> Option<io::Error>76     fn get_bio_error(&mut self) -> Option<io::Error> {
77         unsafe { get_error::<S>(self.ssl.get_raw_bio()) }
78     }
79 
get_ref(&self) -> &S80     pub(crate) fn get_ref(&self) -> &S {
81         unsafe {
82             let bio = self.ssl.get_raw_bio();
83             get_stream_ref(bio)
84         }
85     }
86 
get_mut(&mut self) -> &mut S87     pub(crate) fn get_mut(&mut self) -> &mut S {
88         unsafe {
89             let bio = self.ssl.get_raw_bio();
90             get_stream_mut(bio)
91         }
92     }
93 
ssl(&self) -> &SslRef94     pub(crate) fn ssl(&self) -> &SslRef {
95         &self.ssl
96     }
97 }
98 
99 impl<S> fmt::Debug for SslStream<S>
100 where
101     S: fmt::Debug,
102 {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result103     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
104         write!(f, "stream[{:?}], {:?}", &self.get_ref(), &self.ssl())
105     }
106 }
107 
108 impl<S> Drop for SslStream<S> {
drop(&mut self)109     fn drop(&mut self) {
110         unsafe {
111             ManuallyDrop::drop(&mut self.ssl);
112             ManuallyDrop::drop(&mut self.method);
113         }
114     }
115 }
116 
117 impl<S: Read + Write> SslStream<S> {
ssl_read(&mut self, buf: &[u8]) -> Result<usize, SslError>118     pub(crate) fn ssl_read(&mut self, buf: &[u8]) -> Result<usize, SslError> {
119         if buf.is_empty() {
120             return Ok(0);
121         }
122         let ret = self.ssl.read(buf);
123         if ret > 0 {
124             Ok(ret as usize)
125         } else {
126             Err(self.get_error(ret))
127         }
128     }
129 
ssl_write(&mut self, buf: &[u8]) -> Result<usize, SslError>130     pub(crate) fn ssl_write(&mut self, buf: &[u8]) -> Result<usize, SslError> {
131         if buf.is_empty() {
132             return Ok(0);
133         }
134         let ret = self.ssl.write(buf);
135         if ret > 0 {
136             Ok(ret as usize)
137         } else {
138             Err(self.get_error(ret))
139         }
140     }
141 
new_base( ssl: Ssl, stream: S, pinned_pubkey: Option<PinsVerifyInfo>, ) -> Result<Self, ErrorStack>142     pub(crate) fn new_base(
143         ssl: Ssl,
144         stream: S,
145         pinned_pubkey: Option<PinsVerifyInfo>,
146     ) -> Result<Self, ErrorStack> {
147         unsafe {
148             let (bio, method) = bio::new(stream)?;
149             SSL_set_bio(ssl.as_ptr(), bio, bio);
150 
151             Ok(SslStream {
152                 ssl: ManuallyDrop::new(ssl),
153                 method: ManuallyDrop::new(method),
154                 pinned_pubkey,
155                 p: PhantomData,
156             })
157         }
158     }
159 
connect(&mut self) -> Result<(), SslError>160     pub(crate) fn connect(&mut self) -> Result<(), SslError> {
161         let ret = unsafe { SSL_connect(self.ssl.as_ptr()) };
162         if ret <= 0 {
163             return Err(self.get_error(ret));
164         }
165 
166         if let Some(pins_info) = &self.pinned_pubkey {
167             if pins_info.is_root() {
168                 verify_server_root_cert(self.ssl.as_ptr(), pins_info.get_digest())?;
169             } else {
170                 verify_server_cert(self.ssl.as_ptr(), pins_info.get_digest())?;
171             }
172         }
173         Ok(())
174     }
175 
shutdown(&mut self) -> Result<ShutdownResult, SslError>176     pub(crate) fn shutdown(&mut self) -> Result<ShutdownResult, SslError> {
177         unsafe {
178             match SSL_shutdown(self.ssl.as_ptr()) {
179                 0 => Ok(ShutdownResult::Sent),
180                 1 => Ok(ShutdownResult::Received),
181                 n => Err(self.get_error(n)),
182             }
183         }
184     }
185 }
186 
187 impl<S: Read + Write> Read for SslStream<S> {
188     // ssl_read
read(&mut self, buf: &mut [u8]) -> io::Result<usize>189     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
190         loop {
191             match self.ssl_read(buf) {
192                 Ok(n) => return Ok(n),
193                 // The TLS/SSL peer has closed the connection for writing by sending
194                 // the close_notify alert. No more data can be read.
195                 // Does not necessarily indicate that the underlying transport has been closed.
196                 Err(ref e) if e.code == SslErrorCode::ZERO_RETURN => return Ok(0),
197                 // A non-recoverable, fatal error in the SSL library occurred, usually a protocol
198                 // error.
199                 Err(ref e) if e.code == SslErrorCode::SYSCALL && e.get_io_error().is_none() => {
200                     return Ok(0)
201                 }
202                 // When the last operation was a read operation from a nonblocking BIO.
203                 Err(ref e) if e.code == SslErrorCode::WANT_READ && e.get_io_error().is_none() => {}
204                 // Other error.
205                 Err(err) => {
206                     return Err(err
207                         .into_io_error()
208                         .unwrap_or_else(|err| io::Error::new(io::ErrorKind::Other, err)))
209                 }
210             };
211         }
212     }
213 }
214 
215 impl<S: Read + Write> Write for SslStream<S> {
216     // ssl_write
write(&mut self, buf: &[u8]) -> io::Result<usize>217     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
218         loop {
219             match self.ssl_write(buf) {
220                 Ok(n) => return Ok(n),
221                 // When the last operation was a read operation from a nonblocking BIO.
222                 Err(ref e) if e.code == SslErrorCode::WANT_READ && e.get_io_error().is_none() => {}
223                 Err(err) => {
224                     return Err(err
225                         .into_io_error()
226                         .unwrap_or_else(|err| io::Error::new(io::ErrorKind::Other, err)));
227                 }
228             }
229         }
230     }
231 
232     // S.flush()
flush(&mut self) -> io::Result<()>233     fn flush(&mut self) -> io::Result<()> {
234         self.get_mut().flush()
235     }
236 }
237 
238 /// An SSL stream midway through the handshake process.
239 #[derive(Debug)]
240 pub(crate) struct MidHandshakeSslStream<S> {
241     pub(crate) _stream: SslStream<S>,
242     pub(crate) error: SslError,
243 }
244 
245 impl<S> MidHandshakeSslStream<S> {
error(&self) -> &SslError246     pub(crate) fn error(&self) -> &SslError {
247         &self.error
248     }
249 }
250 
/// Outcome of a successful `SslStream::shutdown` call.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub(crate) enum ShutdownResult {
    /// `SSL_shutdown` returned 0.
    Sent,
    /// `SSL_shutdown` returned 1.
    Received,
}
256 
verify_server_root_cert(ssl: *const SSL, pinned_key: &str) -> Result<(), SslError>257 pub(crate) fn verify_server_root_cert(ssl: *const SSL, pinned_key: &str) -> Result<(), SslError> {
258     use crate::c_openssl::ffi::ssl::{SSL_get_peer_cert_chain, X509_chain_up_ref};
259     use crate::c_openssl::stack::Stack;
260     use crate::c_openssl::x509::X509;
261 
262     let cert_chain = unsafe { X509_chain_up_ref(SSL_get_peer_cert_chain(ssl)) };
263     if cert_chain.is_null() {
264         return Err(SslError {
265             code: SslErrorCode::SSL,
266             internal: Some(InternalError::Ssl(ErrorStack::get())),
267         });
268     }
269 
270     let cert_chain: Stack<X509> = Stack::from_ptr(cert_chain);
271     let root_certificate = cert_chain.into_iter().last().ok_or_else(|| SslError {
272         code: SslErrorCode::SSL,
273         internal: Some(InternalError::Ssl(ErrorStack::get())),
274     })?;
275 
276     verify_pinned_pubkey(pinned_key, root_certificate.as_ptr())
277 }
278 
279 // TODO The SSLError thrown here is meaningless and has no information.
verify_server_cert(ssl: *const SSL, pinned_key: &str) -> Result<(), SslError>280 pub(crate) fn verify_server_cert(ssl: *const SSL, pinned_key: &str) -> Result<(), SslError> {
281     #[cfg(feature = "c_openssl_3_0")]
282     use crate::util::c_openssl::ffi::ssl::SSL_get1_peer_certificate;
283     #[cfg(any(feature = "c_openssl_1_1", feature = "c_boringssl"))]
284     use crate::util::c_openssl::ffi::ssl::SSL_get_peer_certificate;
285 
286     let certificate = unsafe {
287         #[cfg(feature = "c_openssl_3_0")]
288         {
289             SSL_get1_peer_certificate(ssl)
290         }
291         #[cfg(any(feature = "c_openssl_1_1", feature = "c_boringssl"))]
292         {
293             SSL_get_peer_certificate(ssl)
294         }
295     };
296     if certificate.is_null() {
297         return Err(SslError {
298             code: SslErrorCode::SSL,
299             internal: Some(InternalError::Ssl(ErrorStack::get())),
300         });
301     }
302 
303     verify_pinned_pubkey(pinned_key, certificate)
304 }
305 
verify_pinned_pubkey(pinned_key: &str, certificate: *mut C_X509) -> Result<(), SslError>306 fn verify_pinned_pubkey(pinned_key: &str, certificate: *mut C_X509) -> Result<(), SslError> {
307     let pubkey = unsafe { X509_get_X509_PUBKEY(certificate) };
308     // Get the length of the serialized data
309     let buf_size = unsafe { i2d_X509_PUBKEY(pubkey, ptr::null_mut()) };
310 
311     if buf_size < 1 {
312         unsafe { X509_free(certificate) };
313         return Err(SslError {
314             code: SslErrorCode::SSL,
315             internal: Some(InternalError::Ssl(ErrorStack::get())),
316         });
317     }
318     let key = vec![0u8; buf_size as usize];
319     // The actual serialization
320     let serialized_data_size = unsafe { i2d_X509_PUBKEY(pubkey, &mut key.as_ptr()) };
321 
322     if buf_size != serialized_data_size || serialized_data_size <= 0 {
323         unsafe { X509_free(certificate) };
324         return Err(SslError {
325             code: SslErrorCode::SSL,
326             internal: Some(InternalError::Ssl(ErrorStack::get())),
327         });
328     }
329 
330     // sha256 length.
331     let mut digest = [0u8; 32];
332     unsafe { sha256_digest(key.as_slice(), serialized_data_size, &mut digest)? }
333 
334     compare_pinned_digest(&digest, pinned_key.as_bytes(), certificate)
335 }
336 
compare_pinned_digest( digest: &[u8], pinned_key: &[u8], certificate: *mut C_X509, ) -> Result<(), SslError>337 fn compare_pinned_digest(
338     digest: &[u8],
339     pinned_key: &[u8],
340     certificate: *mut C_X509,
341 ) -> Result<(), SslError> {
342     let base64_digest = encode(digest);
343     let mut user_bytes = pinned_key;
344 
345     let mut begin;
346     let mut end;
347     let prefix = b"sha256//";
348     let suffix = b";sha256//";
349     while !user_bytes.is_empty() {
350         begin = match user_bytes
351             .windows(prefix.len())
352             .position(|window| window == prefix)
353         {
354             None => {
355                 break;
356             }
357             Some(index) => index + 8,
358         };
359         end = match user_bytes
360             .windows(suffix.len())
361             .position(|window| window == suffix)
362         {
363             None => user_bytes.len(),
364             Some(index) => index,
365         };
366 
367         let bytes = &user_bytes[begin..end];
368         if bytes.eq(base64_digest.as_slice()) {
369             unsafe { X509_free(certificate) };
370             return Ok(());
371         }
372 
373         if end != user_bytes.len() {
374             user_bytes = &user_bytes[end + 1..];
375         } else {
376             user_bytes = &user_bytes[end..];
377         }
378     }
379 
380     unsafe { X509_free(certificate) };
381     Err(SslError {
382         code: SslErrorCode::SSL,
383         internal: Some(InternalError::User(VerifyError::from_msg(
384             PubKeyPinning,
385             "Pinned public key verification failed.",
386         ))),
387     })
388 }
389