• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use std::{
2     io::Cursor,
3     {fmt, sync::Arc},
4 };
5 
6 use rustls_pki_types::{CertificateDer, PrivateKeyDer, ServerName};
7 use tokio::io::{AsyncRead, AsyncWrite};
8 use tokio_rustls::{
9     rustls::{server::WebPkiClientVerifier, ClientConfig, RootCertStore, ServerConfig},
10     TlsAcceptor as RustlsAcceptor, TlsConnector as RustlsConnector,
11 };
12 
13 use super::io::BoxedIo;
14 use crate::transport::{
15     server::{Connected, TlsStream},
16     Certificate, Identity,
17 };
18 
/// ALPN protocol identifier for HTTP/2, spelled as the raw bytes that are
/// pushed into rustls' `alpn_protocols` and compared against the negotiated protocol.
const ALPN_H2: &[u8] = &[b'h', b'2'];
21 
/// Errors produced by this module's TLS setup and post-handshake checks.
#[derive(Debug)]
enum TlsError {
    /// The TLS session completed without selecting the `h2` ALPN protocol.
    H2NotNegotiated,
    /// A PEM certificate could not be read, or was rejected by a root store.
    CertificateParseError,
    /// No usable private key could be read from an identity's key PEM.
    PrivateKeyParseError,
}
28 
29 #[derive(Clone)]
30 pub(crate) struct TlsConnector {
31     config: Arc<ClientConfig>,
32     domain: Arc<ServerName<'static>>,
33 }
34 
35 impl TlsConnector {
new( ca_cert: Option<Certificate>, identity: Option<Identity>, domain: &str, ) -> Result<Self, crate::Error>36     pub(crate) fn new(
37         ca_cert: Option<Certificate>,
38         identity: Option<Identity>,
39         domain: &str,
40     ) -> Result<Self, crate::Error> {
41         let builder = ClientConfig::builder();
42         let mut roots = RootCertStore::empty();
43 
44         #[cfg(feature = "tls-roots")]
45         roots.add_parsable_certificates(rustls_native_certs::load_native_certs()?);
46 
47         #[cfg(feature = "tls-webpki-roots")]
48         roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
49 
50         if let Some(cert) = ca_cert {
51             add_certs_from_pem(&mut Cursor::new(cert), &mut roots)?;
52         }
53 
54         let builder = builder.with_root_certificates(roots);
55         let mut config = match identity {
56             Some(identity) => {
57                 let (client_cert, client_key) = load_identity(identity)?;
58                 builder.with_client_auth_cert(client_cert, client_key)?
59             }
60             None => builder.with_no_client_auth(),
61         };
62 
63         config.alpn_protocols.push(ALPN_H2.into());
64         Ok(Self {
65             config: Arc::new(config),
66             domain: Arc::new(ServerName::try_from(domain)?.to_owned()),
67         })
68     }
69 
connect<I>(&self, io: I) -> Result<BoxedIo, crate::Error> where I: AsyncRead + AsyncWrite + Send + Unpin + 'static,70     pub(crate) async fn connect<I>(&self, io: I) -> Result<BoxedIo, crate::Error>
71     where
72         I: AsyncRead + AsyncWrite + Send + Unpin + 'static,
73     {
74         let io = RustlsConnector::from(self.config.clone())
75             .connect(self.domain.as_ref().to_owned(), io)
76             .await?;
77 
78         let (_, session) = io.get_ref();
79         if session.alpn_protocol() != Some(ALPN_H2) {
80             return Err(TlsError::H2NotNegotiated)?;
81         }
82 
83         Ok(BoxedIo::new(io))
84     }
85 }
86 
87 impl fmt::Debug for TlsConnector {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result88     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
89         f.debug_struct("TlsConnector").finish()
90     }
91 }
92 
93 #[derive(Clone)]
94 pub(crate) struct TlsAcceptor {
95     inner: Arc<ServerConfig>,
96 }
97 
98 impl TlsAcceptor {
new( identity: Identity, client_ca_root: Option<Certificate>, client_auth_optional: bool, ) -> Result<Self, crate::Error>99     pub(crate) fn new(
100         identity: Identity,
101         client_ca_root: Option<Certificate>,
102         client_auth_optional: bool,
103     ) -> Result<Self, crate::Error> {
104         let builder = ServerConfig::builder();
105 
106         let builder = match client_ca_root {
107             None => builder.with_no_client_auth(),
108             Some(cert) => {
109                 let mut roots = RootCertStore::empty();
110                 add_certs_from_pem(&mut Cursor::new(cert), &mut roots)?;
111                 let verifier = if client_auth_optional {
112                     WebPkiClientVerifier::builder(roots.into()).allow_unauthenticated()
113                 } else {
114                     WebPkiClientVerifier::builder(roots.into())
115                 }
116                 .build()?;
117                 builder.with_client_cert_verifier(verifier)
118             }
119         };
120 
121         let (cert, key) = load_identity(identity)?;
122         let mut config = builder.with_single_cert(cert, key)?;
123 
124         config.alpn_protocols.push(ALPN_H2.into());
125         Ok(Self {
126             inner: Arc::new(config),
127         })
128     }
129 
accept<IO>(&self, io: IO) -> Result<TlsStream<IO>, crate::Error> where IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,130     pub(crate) async fn accept<IO>(&self, io: IO) -> Result<TlsStream<IO>, crate::Error>
131     where
132         IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
133     {
134         let acceptor = RustlsAcceptor::from(self.inner.clone());
135         acceptor.accept(io).await.map_err(Into::into)
136     }
137 }
138 
139 impl fmt::Debug for TlsAcceptor {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result140     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
141         f.debug_struct("TlsAcceptor").finish()
142     }
143 }
144 
145 impl fmt::Display for TlsError {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result146     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
147         match self {
148             TlsError::H2NotNegotiated => write!(f, "HTTP/2 was not negotiated."),
149             TlsError::CertificateParseError => write!(f, "Error parsing TLS certificate."),
150             TlsError::PrivateKeyParseError => write!(
151                 f,
152                 "Error parsing TLS private key - no RSA or PKCS8-encoded keys found."
153             ),
154         }
155     }
156 }
157 
158 impl std::error::Error for TlsError {}
159 
load_identity( identity: Identity, ) -> Result<(Vec<CertificateDer<'static>>, PrivateKeyDer<'static>), TlsError>160 fn load_identity(
161     identity: Identity,
162 ) -> Result<(Vec<CertificateDer<'static>>, PrivateKeyDer<'static>), TlsError> {
163     let cert = rustls_pemfile::certs(&mut Cursor::new(identity.cert))
164         .collect::<Result<Vec<_>, _>>()
165         .map_err(|_| TlsError::CertificateParseError)?;
166 
167     let Ok(Some(key)) = rustls_pemfile::private_key(&mut Cursor::new(identity.key)) else {
168         return Err(TlsError::PrivateKeyParseError);
169     };
170 
171     Ok((cert, key))
172 }
173 
add_certs_from_pem( mut certs: &mut dyn std::io::BufRead, roots: &mut RootCertStore, ) -> Result<(), crate::Error>174 fn add_certs_from_pem(
175     mut certs: &mut dyn std::io::BufRead,
176     roots: &mut RootCertStore,
177 ) -> Result<(), crate::Error> {
178     for cert in rustls_pemfile::certs(&mut certs).collect::<Result<Vec<_>, _>>()? {
179         roots
180             .add(cert)
181             .map_err(|_| TlsError::CertificateParseError)?;
182     }
183 
184     Ok(())
185 }
186