1 use super::super::BoxFuture; 2 use super::io::BoxedIo; 3 #[cfg(feature = "tls")] 4 use super::tls::TlsConnector; 5 use http::Uri; 6 use std::fmt; 7 use std::task::{Context, Poll}; 8 use tower::make::MakeConnection; 9 use tower_service::Service; 10 11 pub(crate) struct Connector<C> { 12 inner: C, 13 #[cfg(feature = "tls")] 14 tls: Option<TlsConnector>, 15 #[cfg(not(feature = "tls"))] 16 #[allow(dead_code)] 17 tls: Option<()>, 18 } 19 20 impl<C> Connector<C> { 21 #[cfg(not(feature = "tls"))] new(inner: C) -> Self22 pub(crate) fn new(inner: C) -> Self { 23 Self { inner, tls: None } 24 } 25 26 #[cfg(feature = "tls")] new(inner: C, tls: Option<TlsConnector>) -> Self27 pub(crate) fn new(inner: C, tls: Option<TlsConnector>) -> Self { 28 Self { inner, tls } 29 } 30 31 #[cfg(feature = "tls-roots-common")] tls_or_default(&self, scheme: Option<&str>, host: Option<&str>) -> Option<TlsConnector>32 fn tls_or_default(&self, scheme: Option<&str>, host: Option<&str>) -> Option<TlsConnector> { 33 if self.tls.is_some() { 34 return self.tls.clone(); 35 } 36 37 let host = match (scheme, host) { 38 (Some("https"), Some(host)) => host, 39 _ => return None, 40 }; 41 42 TlsConnector::new(None, None, host).ok() 43 } 44 } 45 46 impl<C> Service<Uri> for Connector<C> 47 where 48 C: MakeConnection<Uri>, 49 C::Connection: Unpin + Send + 'static, 50 C::Future: Send + 'static, 51 crate::Error: From<C::Error> + Send + 'static, 52 { 53 type Response = BoxedIo; 54 type Error = crate::Error; 55 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>; 56 poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>57 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { 58 MakeConnection::poll_ready(&mut self.inner, cx).map_err(Into::into) 59 } 60 call(&mut self, uri: Uri) -> Self::Future61 fn call(&mut self, uri: Uri) -> Self::Future { 62 #[cfg(all(feature = "tls", not(feature = "tls-roots-common")))] 63 let tls = self.tls.clone(); 64 65 #[cfg(feature = "tls-roots-common")] 66 let tls = self.tls_or_default(uri.scheme_str(), uri.host()); 67 68 #[cfg(feature = "tls")] 69 let is_https = uri.scheme_str() == Some("https"); 70 let connect = self.inner.make_connection(uri); 71 72 Box::pin(async move { 73 let io = connect.await?; 74 75 #[cfg(feature = "tls")] 76 { 77 if let Some(tls) = tls { 78 if is_https { 79 let conn = tls.connect(io).await?; 80 return Ok(BoxedIo::new(conn)); 81 } else { 82 return Ok(BoxedIo::new(io)); 83 } 84 } else if is_https { 85 return Err(HttpsUriWithoutTlsSupport(()).into()); 86 } 87 } 88 89 Ok(BoxedIo::new(io)) 90 }) 91 } 92 } 93 94 /// Error returned when trying to connect to an HTTPS endpoint without TLS enabled. 95 #[derive(Debug)] 96 pub(crate) struct HttpsUriWithoutTlsSupport(()); 97 98 impl fmt::Display for HttpsUriWithoutTlsSupport { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result99 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 100 write!(f, "Connecting to HTTPS without TLS enabled") 101 } 102 } 103 104 // std::error::Error only requires a type to impl Debug and Display 105 impl std::error::Error for HttpsUriWithoutTlsSupport {} 106