1 use std::{
2 io::Cursor,
3 {fmt, sync::Arc},
4 };
5
6 use rustls_pki_types::{CertificateDer, PrivateKeyDer, ServerName};
7 use tokio::io::{AsyncRead, AsyncWrite};
8 use tokio_rustls::{
9 rustls::{server::WebPkiClientVerifier, ClientConfig, RootCertStore, ServerConfig},
10 TlsAcceptor as RustlsAcceptor, TlsConnector as RustlsConnector,
11 };
12
13 use super::io::BoxedIo;
14 use crate::transport::{
15 server::{Connected, TlsStream},
16 Certificate, Identity,
17 };
18
/// ALPN protocol identifier for HTTP/2 (`h2`) as raw bytes.
///
/// Pushed onto the `alpn_protocols` list of the rustls configs built in this
/// module and compared against the protocol negotiated on the client side.
const ALPN_H2: &[u8] = b"h2";
21
/// Errors raised by the TLS helpers in this module.
#[derive(Debug)]
enum TlsError {
    /// The handshake completed but the negotiated ALPN protocol was not `h2`.
    H2NotNegotiated,
    /// A PEM certificate could not be parsed or could not be added to a root store.
    CertificateParseError,
    /// No usable private key could be read from the supplied PEM data.
    PrivateKeyParseError,
}
28
29 #[derive(Clone)]
30 pub(crate) struct TlsConnector {
31 config: Arc<ClientConfig>,
32 domain: Arc<ServerName<'static>>,
33 }
34
35 impl TlsConnector {
new( ca_cert: Option<Certificate>, identity: Option<Identity>, domain: &str, ) -> Result<Self, crate::Error>36 pub(crate) fn new(
37 ca_cert: Option<Certificate>,
38 identity: Option<Identity>,
39 domain: &str,
40 ) -> Result<Self, crate::Error> {
41 let builder = ClientConfig::builder();
42 let mut roots = RootCertStore::empty();
43
44 #[cfg(feature = "tls-roots")]
45 roots.add_parsable_certificates(rustls_native_certs::load_native_certs()?);
46
47 #[cfg(feature = "tls-webpki-roots")]
48 roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
49
50 if let Some(cert) = ca_cert {
51 add_certs_from_pem(&mut Cursor::new(cert), &mut roots)?;
52 }
53
54 let builder = builder.with_root_certificates(roots);
55 let mut config = match identity {
56 Some(identity) => {
57 let (client_cert, client_key) = load_identity(identity)?;
58 builder.with_client_auth_cert(client_cert, client_key)?
59 }
60 None => builder.with_no_client_auth(),
61 };
62
63 config.alpn_protocols.push(ALPN_H2.into());
64 Ok(Self {
65 config: Arc::new(config),
66 domain: Arc::new(ServerName::try_from(domain)?.to_owned()),
67 })
68 }
69
connect<I>(&self, io: I) -> Result<BoxedIo, crate::Error> where I: AsyncRead + AsyncWrite + Send + Unpin + 'static,70 pub(crate) async fn connect<I>(&self, io: I) -> Result<BoxedIo, crate::Error>
71 where
72 I: AsyncRead + AsyncWrite + Send + Unpin + 'static,
73 {
74 let io = RustlsConnector::from(self.config.clone())
75 .connect(self.domain.as_ref().to_owned(), io)
76 .await?;
77
78 let (_, session) = io.get_ref();
79 if session.alpn_protocol() != Some(ALPN_H2) {
80 return Err(TlsError::H2NotNegotiated)?;
81 }
82
83 Ok(BoxedIo::new(io))
84 }
85 }
86
87 impl fmt::Debug for TlsConnector {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result88 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
89 f.debug_struct("TlsConnector").finish()
90 }
91 }
92
93 #[derive(Clone)]
94 pub(crate) struct TlsAcceptor {
95 inner: Arc<ServerConfig>,
96 }
97
98 impl TlsAcceptor {
new( identity: Identity, client_ca_root: Option<Certificate>, client_auth_optional: bool, ) -> Result<Self, crate::Error>99 pub(crate) fn new(
100 identity: Identity,
101 client_ca_root: Option<Certificate>,
102 client_auth_optional: bool,
103 ) -> Result<Self, crate::Error> {
104 let builder = ServerConfig::builder();
105
106 let builder = match client_ca_root {
107 None => builder.with_no_client_auth(),
108 Some(cert) => {
109 let mut roots = RootCertStore::empty();
110 add_certs_from_pem(&mut Cursor::new(cert), &mut roots)?;
111 let verifier = if client_auth_optional {
112 WebPkiClientVerifier::builder(roots.into()).allow_unauthenticated()
113 } else {
114 WebPkiClientVerifier::builder(roots.into())
115 }
116 .build()?;
117 builder.with_client_cert_verifier(verifier)
118 }
119 };
120
121 let (cert, key) = load_identity(identity)?;
122 let mut config = builder.with_single_cert(cert, key)?;
123
124 config.alpn_protocols.push(ALPN_H2.into());
125 Ok(Self {
126 inner: Arc::new(config),
127 })
128 }
129
accept<IO>(&self, io: IO) -> Result<TlsStream<IO>, crate::Error> where IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,130 pub(crate) async fn accept<IO>(&self, io: IO) -> Result<TlsStream<IO>, crate::Error>
131 where
132 IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
133 {
134 let acceptor = RustlsAcceptor::from(self.inner.clone());
135 acceptor.accept(io).await.map_err(Into::into)
136 }
137 }
138
139 impl fmt::Debug for TlsAcceptor {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result140 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
141 f.debug_struct("TlsAcceptor").finish()
142 }
143 }
144
145 impl fmt::Display for TlsError {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result146 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
147 match self {
148 TlsError::H2NotNegotiated => write!(f, "HTTP/2 was not negotiated."),
149 TlsError::CertificateParseError => write!(f, "Error parsing TLS certificate."),
150 TlsError::PrivateKeyParseError => write!(
151 f,
152 "Error parsing TLS private key - no RSA or PKCS8-encoded keys found."
153 ),
154 }
155 }
156 }
157
158 impl std::error::Error for TlsError {}
159
load_identity( identity: Identity, ) -> Result<(Vec<CertificateDer<'static>>, PrivateKeyDer<'static>), TlsError>160 fn load_identity(
161 identity: Identity,
162 ) -> Result<(Vec<CertificateDer<'static>>, PrivateKeyDer<'static>), TlsError> {
163 let cert = rustls_pemfile::certs(&mut Cursor::new(identity.cert))
164 .collect::<Result<Vec<_>, _>>()
165 .map_err(|_| TlsError::CertificateParseError)?;
166
167 let Ok(Some(key)) = rustls_pemfile::private_key(&mut Cursor::new(identity.key)) else {
168 return Err(TlsError::PrivateKeyParseError);
169 };
170
171 Ok((cert, key))
172 }
173
add_certs_from_pem( mut certs: &mut dyn std::io::BufRead, roots: &mut RootCertStore, ) -> Result<(), crate::Error>174 fn add_certs_from_pem(
175 mut certs: &mut dyn std::io::BufRead,
176 roots: &mut RootCertStore,
177 ) -> Result<(), crate::Error> {
178 for cert in rustls_pemfile::certs(&mut certs).collect::<Result<Vec<_>, _>>()? {
179 roots
180 .add(cert)
181 .map_err(|_| TlsError::CertificateParseError)?;
182 }
183
184 Ok(())
185 }
186