• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use super::super::BoxFuture;
2 use super::io::BoxedIo;
3 #[cfg(feature = "tls")]
4 use super::tls::TlsConnector;
5 use http::Uri;
6 use std::fmt;
7 use std::task::{Context, Poll};
8 use tower::make::MakeConnection;
9 use tower_service::Service;
10 
/// Wraps an inner connection maker and, when the `tls` feature is enabled,
/// an optional TLS connector used for `https` URIs.
pub(crate) struct Connector<C> {
    /// The wrapped connection maker that opens the raw transport.
    inner: C,
    /// TLS connector; `None` means no explicit TLS configuration was given.
    #[cfg(feature = "tls")]
    tls: Option<TlsConnector>,
    /// Placeholder that keeps the field present when the `tls` feature is off.
    // Only ever written by `new`, never read, hence the lint exemption.
    #[allow(dead_code)]
    #[cfg(not(feature = "tls"))]
    tls: Option<()>,
}
19 
20 impl<C> Connector<C> {
21     #[cfg(not(feature = "tls"))]
new(inner: C) -> Self22     pub(crate) fn new(inner: C) -> Self {
23         Self { inner, tls: None }
24     }
25 
26     #[cfg(feature = "tls")]
new(inner: C, tls: Option<TlsConnector>) -> Self27     pub(crate) fn new(inner: C, tls: Option<TlsConnector>) -> Self {
28         Self { inner, tls }
29     }
30 
31     #[cfg(feature = "tls-roots-common")]
tls_or_default(&self, scheme: Option<&str>, host: Option<&str>) -> Option<TlsConnector>32     fn tls_or_default(&self, scheme: Option<&str>, host: Option<&str>) -> Option<TlsConnector> {
33         if self.tls.is_some() {
34             return self.tls.clone();
35         }
36 
37         let host = match (scheme, host) {
38             (Some("https"), Some(host)) => host,
39             _ => return None,
40         };
41 
42         TlsConnector::new(None, None, host).ok()
43     }
44 }
45 
46 impl<C> Service<Uri> for Connector<C>
47 where
48     C: MakeConnection<Uri>,
49     C::Connection: Unpin + Send + 'static,
50     C::Future: Send + 'static,
51     crate::Error: From<C::Error> + Send + 'static,
52 {
53     type Response = BoxedIo;
54     type Error = crate::Error;
55     type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
56 
poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>57     fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
58         MakeConnection::poll_ready(&mut self.inner, cx).map_err(Into::into)
59     }
60 
call(&mut self, uri: Uri) -> Self::Future61     fn call(&mut self, uri: Uri) -> Self::Future {
62         #[cfg(all(feature = "tls", not(feature = "tls-roots-common")))]
63         let tls = self.tls.clone();
64 
65         #[cfg(feature = "tls-roots-common")]
66         let tls = self.tls_or_default(uri.scheme_str(), uri.host());
67 
68         #[cfg(feature = "tls")]
69         let is_https = uri.scheme_str() == Some("https");
70         let connect = self.inner.make_connection(uri);
71 
72         Box::pin(async move {
73             let io = connect.await?;
74 
75             #[cfg(feature = "tls")]
76             {
77                 if let Some(tls) = tls {
78                     if is_https {
79                         let conn = tls.connect(io).await?;
80                         return Ok(BoxedIo::new(conn));
81                     } else {
82                         return Ok(BoxedIo::new(io));
83                     }
84                 } else if is_https {
85                     return Err(HttpsUriWithoutTlsSupport(()).into());
86                 }
87             }
88 
89             Ok(BoxedIo::new(io))
90         })
91     }
92 }
93 
/// Error produced when an HTTPS endpoint is requested but no TLS connector is
/// available to secure the connection.
///
/// The private unit field keeps construction restricted to this module.
#[derive(Debug)]
pub(crate) struct HttpsUriWithoutTlsSupport(());

impl fmt::Display for HttpsUriWithoutTlsSupport {
    /// Writes the fixed, human-readable description of this error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Connecting to HTTPS without TLS enabled")
    }
}

// `std::error::Error` needs nothing beyond the `Debug` and `Display` impls above.
impl std::error::Error for HttpsUriWithoutTlsSupport {}
106