1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 // http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13
14 use core::fmt;
15 use core::marker::PhantomData;
16 use core::mem::ManuallyDrop;
17 use std::io::{self, Read, Write};
18 use std::panic::resume_unwind;
19 use std::ptr;
20
21 use libc::c_int;
22
23 use super::{InternalError, Ssl, SslError, SslErrorCode, SslRef};
24 use crate::c_openssl::bio::{self, get_error, get_panic, get_stream_mut, get_stream_ref};
25 use crate::c_openssl::error::ErrorStack;
26 use crate::c_openssl::ffi::ssl::{SSL_connect, SSL_set_bio, SSL_shutdown};
27 use crate::c_openssl::foreign::Foreign;
28 use crate::c_openssl::verify::PinsVerifyInfo;
29 use crate::util::base64::encode;
30 use crate::util::c_openssl::bio::BioMethod;
31 use crate::util::c_openssl::error::VerifyError;
32 use crate::util::c_openssl::error::VerifyKind::PubKeyPinning;
33 use crate::util::c_openssl::ffi::ssl::SSL;
34 use crate::util::c_openssl::ffi::x509::{i2d_X509_PUBKEY, X509_free, X509_get_X509_PUBKEY, C_X509};
35 use crate::util::c_openssl::verify::sha256_digest;
36
37 /// A TLS session over a stream.
38 pub struct SslStream<S> {
39 pub(crate) ssl: ManuallyDrop<Ssl>,
40 method: ManuallyDrop<BioMethod>,
41 pinned_pubkey: Option<PinsVerifyInfo>,
42 p: PhantomData<S>,
43 }
44
45 impl<S> SslStream<S> {
get_error(&mut self, err: c_int) -> SslError46 pub(crate) fn get_error(&mut self, err: c_int) -> SslError {
47 self.check_panic();
48 let code = self.ssl.get_error(err);
49 let internal = match code {
50 SslErrorCode::SSL => {
51 let e = ErrorStack::get();
52 Some(InternalError::Ssl(e))
53 }
54 SslErrorCode::SYSCALL => {
55 let error = ErrorStack::get();
56 if error.errors().is_empty() {
57 self.get_bio_error().map(InternalError::Io)
58 } else {
59 Some(InternalError::Ssl(error))
60 }
61 }
62 SslErrorCode::WANT_WRITE | SslErrorCode::WANT_READ => {
63 self.get_bio_error().map(InternalError::Io)
64 }
65 _ => None,
66 };
67 SslError { code, internal }
68 }
69
check_panic(&mut self)70 fn check_panic(&mut self) {
71 if let Some(err) = unsafe { get_panic::<S>(self.ssl.get_raw_bio()) } {
72 resume_unwind(err)
73 }
74 }
75
get_bio_error(&mut self) -> Option<io::Error>76 fn get_bio_error(&mut self) -> Option<io::Error> {
77 unsafe { get_error::<S>(self.ssl.get_raw_bio()) }
78 }
79
get_ref(&self) -> &S80 pub(crate) fn get_ref(&self) -> &S {
81 unsafe {
82 let bio = self.ssl.get_raw_bio();
83 get_stream_ref(bio)
84 }
85 }
86
get_mut(&mut self) -> &mut S87 pub(crate) fn get_mut(&mut self) -> &mut S {
88 unsafe {
89 let bio = self.ssl.get_raw_bio();
90 get_stream_mut(bio)
91 }
92 }
93
ssl(&self) -> &SslRef94 pub(crate) fn ssl(&self) -> &SslRef {
95 &self.ssl
96 }
97 }
98
99 impl<S> fmt::Debug for SslStream<S>
100 where
101 S: fmt::Debug,
102 {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result103 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
104 write!(f, "stream[{:?}], {:?}", &self.get_ref(), &self.ssl())
105 }
106 }
107
108 impl<S> Drop for SslStream<S> {
drop(&mut self)109 fn drop(&mut self) {
110 unsafe {
111 ManuallyDrop::drop(&mut self.ssl);
112 ManuallyDrop::drop(&mut self.method);
113 }
114 }
115 }
116
117 impl<S: Read + Write> SslStream<S> {
ssl_read(&mut self, buf: &[u8]) -> Result<usize, SslError>118 pub(crate) fn ssl_read(&mut self, buf: &[u8]) -> Result<usize, SslError> {
119 if buf.is_empty() {
120 return Ok(0);
121 }
122 let ret = self.ssl.read(buf);
123 if ret > 0 {
124 Ok(ret as usize)
125 } else {
126 Err(self.get_error(ret))
127 }
128 }
129
ssl_write(&mut self, buf: &[u8]) -> Result<usize, SslError>130 pub(crate) fn ssl_write(&mut self, buf: &[u8]) -> Result<usize, SslError> {
131 if buf.is_empty() {
132 return Ok(0);
133 }
134 let ret = self.ssl.write(buf);
135 if ret > 0 {
136 Ok(ret as usize)
137 } else {
138 Err(self.get_error(ret))
139 }
140 }
141
new_base( ssl: Ssl, stream: S, pinned_pubkey: Option<PinsVerifyInfo>, ) -> Result<Self, ErrorStack>142 pub(crate) fn new_base(
143 ssl: Ssl,
144 stream: S,
145 pinned_pubkey: Option<PinsVerifyInfo>,
146 ) -> Result<Self, ErrorStack> {
147 unsafe {
148 let (bio, method) = bio::new(stream)?;
149 SSL_set_bio(ssl.as_ptr(), bio, bio);
150
151 Ok(SslStream {
152 ssl: ManuallyDrop::new(ssl),
153 method: ManuallyDrop::new(method),
154 pinned_pubkey,
155 p: PhantomData,
156 })
157 }
158 }
159
connect(&mut self) -> Result<(), SslError>160 pub(crate) fn connect(&mut self) -> Result<(), SslError> {
161 let ret = unsafe { SSL_connect(self.ssl.as_ptr()) };
162 if ret <= 0 {
163 return Err(self.get_error(ret));
164 }
165
166 if let Some(pins_info) = &self.pinned_pubkey {
167 if pins_info.is_root() {
168 verify_server_root_cert(self.ssl.as_ptr(), pins_info.get_digest())?;
169 } else {
170 verify_server_cert(self.ssl.as_ptr(), pins_info.get_digest())?;
171 }
172 }
173 Ok(())
174 }
175
shutdown(&mut self) -> Result<ShutdownResult, SslError>176 pub(crate) fn shutdown(&mut self) -> Result<ShutdownResult, SslError> {
177 unsafe {
178 match SSL_shutdown(self.ssl.as_ptr()) {
179 0 => Ok(ShutdownResult::Sent),
180 1 => Ok(ShutdownResult::Received),
181 n => Err(self.get_error(n)),
182 }
183 }
184 }
185 }
186
187 impl<S: Read + Write> Read for SslStream<S> {
188 // ssl_read
read(&mut self, buf: &mut [u8]) -> io::Result<usize>189 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
190 loop {
191 match self.ssl_read(buf) {
192 Ok(n) => return Ok(n),
193 // The TLS/SSL peer has closed the connection for writing by sending
194 // the close_notify alert. No more data can be read.
195 // Does not necessarily indicate that the underlying transport has been closed.
196 Err(ref e) if e.code == SslErrorCode::ZERO_RETURN => return Ok(0),
197 // A non-recoverable, fatal error in the SSL library occurred, usually a protocol
198 // error.
199 Err(ref e) if e.code == SslErrorCode::SYSCALL && e.get_io_error().is_none() => {
200 return Ok(0)
201 }
202 // When the last operation was a read operation from a nonblocking BIO.
203 Err(ref e) if e.code == SslErrorCode::WANT_READ && e.get_io_error().is_none() => {}
204 // Other error.
205 Err(err) => {
206 return Err(err
207 .into_io_error()
208 .unwrap_or_else(|err| io::Error::new(io::ErrorKind::Other, err)))
209 }
210 };
211 }
212 }
213 }
214
215 impl<S: Read + Write> Write for SslStream<S> {
216 // ssl_write
write(&mut self, buf: &[u8]) -> io::Result<usize>217 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
218 loop {
219 match self.ssl_write(buf) {
220 Ok(n) => return Ok(n),
221 // When the last operation was a read operation from a nonblocking BIO.
222 Err(ref e) if e.code == SslErrorCode::WANT_READ && e.get_io_error().is_none() => {}
223 Err(err) => {
224 return Err(err
225 .into_io_error()
226 .unwrap_or_else(|err| io::Error::new(io::ErrorKind::Other, err)));
227 }
228 }
229 }
230 }
231
232 // S.flush()
flush(&mut self) -> io::Result<()>233 fn flush(&mut self) -> io::Result<()> {
234 self.get_mut().flush()
235 }
236 }
237
238 /// An SSL stream midway through the handshake process.
239 #[derive(Debug)]
240 pub(crate) struct MidHandshakeSslStream<S> {
241 pub(crate) _stream: SslStream<S>,
242 pub(crate) error: SslError,
243 }
244
245 impl<S> MidHandshakeSslStream<S> {
error(&self) -> &SslError246 pub(crate) fn error(&self) -> &SslError {
247 &self.error
248 }
249 }
250
/// Outcome of a successful `SSL_shutdown` call.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum ShutdownResult {
    /// `SSL_shutdown` returned 0: our close_notify was sent, and the peer's has not been
    /// received yet.
    Sent,
    /// `SSL_shutdown` returned 1: the peer's close_notify has been received.
    Received,
}
256
verify_server_root_cert(ssl: *const SSL, pinned_key: &str) -> Result<(), SslError>257 pub(crate) fn verify_server_root_cert(ssl: *const SSL, pinned_key: &str) -> Result<(), SslError> {
258 use crate::c_openssl::ffi::ssl::{SSL_get_peer_cert_chain, X509_chain_up_ref};
259 use crate::c_openssl::stack::Stack;
260 use crate::c_openssl::x509::X509;
261
262 let cert_chain = unsafe { X509_chain_up_ref(SSL_get_peer_cert_chain(ssl)) };
263 if cert_chain.is_null() {
264 return Err(SslError {
265 code: SslErrorCode::SSL,
266 internal: Some(InternalError::Ssl(ErrorStack::get())),
267 });
268 }
269
270 let cert_chain: Stack<X509> = Stack::from_ptr(cert_chain);
271 let root_certificate = cert_chain.into_iter().last().ok_or_else(|| SslError {
272 code: SslErrorCode::SSL,
273 internal: Some(InternalError::Ssl(ErrorStack::get())),
274 })?;
275
276 verify_pinned_pubkey(pinned_key, root_certificate.as_ptr())
277 }
278
279 // TODO The SSLError thrown here is meaningless and has no information.
verify_server_cert(ssl: *const SSL, pinned_key: &str) -> Result<(), SslError>280 pub(crate) fn verify_server_cert(ssl: *const SSL, pinned_key: &str) -> Result<(), SslError> {
281 #[cfg(feature = "c_openssl_3_0")]
282 use crate::util::c_openssl::ffi::ssl::SSL_get1_peer_certificate;
283 #[cfg(any(feature = "c_openssl_1_1", feature = "c_boringssl"))]
284 use crate::util::c_openssl::ffi::ssl::SSL_get_peer_certificate;
285
286 let certificate = unsafe {
287 #[cfg(feature = "c_openssl_3_0")]
288 {
289 SSL_get1_peer_certificate(ssl)
290 }
291 #[cfg(any(feature = "c_openssl_1_1", feature = "c_boringssl"))]
292 {
293 SSL_get_peer_certificate(ssl)
294 }
295 };
296 if certificate.is_null() {
297 return Err(SslError {
298 code: SslErrorCode::SSL,
299 internal: Some(InternalError::Ssl(ErrorStack::get())),
300 });
301 }
302
303 verify_pinned_pubkey(pinned_key, certificate)
304 }
305
verify_pinned_pubkey(pinned_key: &str, certificate: *mut C_X509) -> Result<(), SslError>306 fn verify_pinned_pubkey(pinned_key: &str, certificate: *mut C_X509) -> Result<(), SslError> {
307 let pubkey = unsafe { X509_get_X509_PUBKEY(certificate) };
308 // Get the length of the serialized data
309 let buf_size = unsafe { i2d_X509_PUBKEY(pubkey, ptr::null_mut()) };
310
311 if buf_size < 1 {
312 unsafe { X509_free(certificate) };
313 return Err(SslError {
314 code: SslErrorCode::SSL,
315 internal: Some(InternalError::Ssl(ErrorStack::get())),
316 });
317 }
318 let key = vec![0u8; buf_size as usize];
319 // The actual serialization
320 let serialized_data_size = unsafe { i2d_X509_PUBKEY(pubkey, &mut key.as_ptr()) };
321
322 if buf_size != serialized_data_size || serialized_data_size <= 0 {
323 unsafe { X509_free(certificate) };
324 return Err(SslError {
325 code: SslErrorCode::SSL,
326 internal: Some(InternalError::Ssl(ErrorStack::get())),
327 });
328 }
329
330 // sha256 length.
331 let mut digest = [0u8; 32];
332 unsafe { sha256_digest(key.as_slice(), serialized_data_size, &mut digest)? }
333
334 compare_pinned_digest(&digest, pinned_key.as_bytes(), certificate)
335 }
336
compare_pinned_digest( digest: &[u8], pinned_key: &[u8], certificate: *mut C_X509, ) -> Result<(), SslError>337 fn compare_pinned_digest(
338 digest: &[u8],
339 pinned_key: &[u8],
340 certificate: *mut C_X509,
341 ) -> Result<(), SslError> {
342 let base64_digest = encode(digest);
343 let mut user_bytes = pinned_key;
344
345 let mut begin;
346 let mut end;
347 let prefix = b"sha256//";
348 let suffix = b";sha256//";
349 while !user_bytes.is_empty() {
350 begin = match user_bytes
351 .windows(prefix.len())
352 .position(|window| window == prefix)
353 {
354 None => {
355 break;
356 }
357 Some(index) => index + 8,
358 };
359 end = match user_bytes
360 .windows(suffix.len())
361 .position(|window| window == suffix)
362 {
363 None => user_bytes.len(),
364 Some(index) => index,
365 };
366
367 let bytes = &user_bytes[begin..end];
368 if bytes.eq(base64_digest.as_slice()) {
369 unsafe { X509_free(certificate) };
370 return Ok(());
371 }
372
373 if end != user_bytes.len() {
374 user_bytes = &user_bytes[end + 1..];
375 } else {
376 user_bytes = &user_bytes[end..];
377 }
378 }
379
380 unsafe { X509_free(certificate) };
381 Err(SslError {
382 code: SslErrorCode::SSL,
383 internal: Some(InternalError::User(VerifyError::from_msg(
384 PubKeyPinning,
385 "Pinned public key verification failed.",
386 ))),
387 })
388 }
389