1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 // http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13
14 use core::{cmp, ffi, fmt, str};
15 use std::ffi::CString;
16 use std::io::{Read, Write};
17
18 use libc::{c_char, c_int, c_long, c_void};
19
20 use super::error::HandshakeError;
21 use super::{MidHandshakeSslStream, SslContext, SslErrorCode, SslStream};
22 use crate::c_openssl::check_ret;
23 use crate::c_openssl::ffi::bio::BIO;
24 use crate::c_openssl::ffi::ssl::{
25 SSL_connect, SSL_ctrl, SSL_get0_param, SSL_get_error, SSL_get_rbio, SSL_get_verify_result,
26 SSL_read, SSL_state_string_long, SSL_write,
27 };
28 use crate::c_openssl::foreign::ForeignRef;
29 use crate::c_openssl::x509::{
30 X509VerifyParamRef, X509VerifyResult, X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS,
31 };
32 use crate::util::c_openssl::check_ptr;
33 use crate::util::c_openssl::error::ErrorStack;
34 use crate::util::c_openssl::ffi::ssl::{SSL_free, SSL_new, SSL};
35 use crate::util::c_openssl::foreign::Foreign;
36
// Declares the `Ssl` and `SslRef` wrapper types for the C `SSL` struct through
// the crate's `foreign_type!` macro, naming `SSL_free` as the drop function.
// NOTE(review): presumably `Ssl` is the owning handle released with `SSL_free`
// and `SslRef` is its borrowed counterpart; confirm against the macro
// definition.
foreign_type!(
    type CStruct = SSL;
    fn drop = SSL_free;
    /// The main SSL/TLS structure.
    pub(crate) struct Ssl;
    pub(crate) struct SslRef;
);
44
45 impl Ssl {
new(ctx: &SslContext) -> Result<Ssl, ErrorStack>46 pub(crate) fn new(ctx: &SslContext) -> Result<Ssl, ErrorStack> {
47 unsafe {
48 let ptr = check_ptr(SSL_new(ctx.as_ptr()))?;
49 Ok(Ssl::from_ptr(ptr))
50 }
51 }
52
53 /// Client connect to Server.
54 /// only `sync` use.
55 #[cfg(feature = "sync")]
connect<S>(self, stream: S) -> Result<SslStream<S>, HandshakeError<S>> where S: Read + Write,56 pub(crate) fn connect<S>(self, stream: S) -> Result<SslStream<S>, HandshakeError<S>>
57 where
58 S: Read + Write,
59 {
60 let mut stream = SslStream::new_base(self, stream)?;
61 let ret = unsafe { SSL_connect(stream.ssl.as_ptr()) };
62 if ret > 0 {
63 Ok(stream)
64 } else {
65 let error = stream.get_error(ret);
66 match error.code {
67 SslErrorCode::WANT_READ | SslErrorCode::WANT_WRITE => {
68 Err(HandshakeError::WouldBlock(MidHandshakeSslStream {
69 _stream: stream,
70 error,
71 }))
72 }
73 _ => Err(HandshakeError::Failure(MidHandshakeSslStream {
74 _stream: stream,
75 error,
76 })),
77 }
78 }
79 }
80 }
81
82 impl SslRef {
get_error(&self, err: c_int) -> SslErrorCode83 pub(crate) fn get_error(&self, err: c_int) -> SslErrorCode {
84 unsafe { SslErrorCode::from_int(SSL_get_error(self.as_ptr(), err)) }
85 }
86
ssl_status(&self) -> &'static str87 fn ssl_status(&self) -> &'static str {
88 let status = unsafe {
89 let ptr = SSL_state_string_long(self.as_ptr());
90 ffi::CStr::from_ptr(ptr as *const _)
91 };
92 str::from_utf8(status.to_bytes()).unwrap_or_default()
93 }
94
verify_result(&self) -> X509VerifyResult95 pub(crate) fn verify_result(&self) -> X509VerifyResult {
96 unsafe { X509VerifyResult::from_raw(SSL_get_verify_result(self.as_ptr()) as c_int) }
97 }
98
get_raw_bio(&self) -> *mut BIO99 pub(crate) fn get_raw_bio(&self) -> *mut BIO {
100 unsafe { SSL_get_rbio(self.as_ptr()) }
101 }
102
read(&mut self, buf: &mut [u8]) -> c_int103 pub(crate) fn read(&mut self, buf: &mut [u8]) -> c_int {
104 let len = cmp::min(c_int::MAX as usize, buf.len()) as c_int;
105 unsafe { SSL_read(self.as_ptr(), buf.as_ptr() as *mut c_void, len) }
106 }
107
write(&mut self, buf: &[u8]) -> c_int108 pub(crate) fn write(&mut self, buf: &[u8]) -> c_int {
109 let len = cmp::min(c_int::MAX as usize, buf.len()) as c_int;
110 unsafe { SSL_write(self.as_ptr(), buf.as_ptr() as *const c_void, len) }
111 }
112
set_host_name_in_sni(&mut self, name: &str) -> Result<(), ErrorStack>113 pub(crate) fn set_host_name_in_sni(&mut self, name: &str) -> Result<(), ErrorStack> {
114 let name = match CString::new(name) {
115 Ok(name) => name,
116 Err(_) => return Err(ErrorStack::get()),
117 };
118 check_ret(
119 unsafe { ssl_set_tlsext_host_name(self.as_ptr(), name.as_ptr() as *mut _) } as c_int,
120 )
121 .map(|_| ())
122 }
123
param_mut(&mut self) -> &mut X509VerifyParamRef124 pub(crate) fn param_mut(&mut self) -> &mut X509VerifyParamRef {
125 unsafe { X509VerifyParamRef::from_ptr_mut(SSL_get0_param(self.as_ptr())) }
126 }
127
set_verify_hostname(&mut self, host_name: &str) -> Result<(), ErrorStack>128 pub(crate) fn set_verify_hostname(&mut self, host_name: &str) -> Result<(), ErrorStack> {
129 let param = self.param_mut();
130 param.set_hostflags(X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS);
131 match host_name.parse() {
132 Ok(ip) => param.set_ip(ip),
133 Err(_) => param.set_host(host_name),
134 }
135 }
136 }
137
138 impl fmt::Debug for SslRef {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result139 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
140 write!(
141 f,
142 "Ssl[state: {}, verify result: {}]",
143 &self.ssl_status(),
144 &self.verify_result()
145 )
146 }
147 }
148
// Command code that `ssl_set_tlsext_host_name` passes as the second argument
// of `SSL_ctrl`.
// NOTE(review): presumably mirrors OpenSSL's `SSL_CTRL_SET_TLSEXT_HOSTNAME`
// macro; confirm against the OpenSSL headers.
const SSL_CTRL_SET_TLSEXT_HOSTNAME: c_int = 0x37;
// Name-type value that `ssl_set_tlsext_host_name` passes (cast to `c_long`)
// as the third argument of `SSL_ctrl`.
// NOTE(review): presumably mirrors OpenSSL's `TLSEXT_NAMETYPE_host_name`
// macro; confirm against the OpenSSL headers.
const TLSEXT_NAMETYPE_HOST_NAME: c_int = 0x0;
151
ssl_set_tlsext_host_name(s: *mut SSL, name: *mut c_char) -> c_long152 unsafe fn ssl_set_tlsext_host_name(s: *mut SSL, name: *mut c_char) -> c_long {
153 SSL_ctrl(
154 s,
155 SSL_CTRL_SET_TLSEXT_HOSTNAME,
156 TLSEXT_NAMETYPE_HOST_NAME as c_long,
157 name as *mut c_void,
158 )
159 }
160