• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use hyper::server::conn::AddrStream;
2 use std::net::SocketAddr;
3 use tokio::net::TcpStream;
4 
5 #[cfg(feature = "tls")]
6 use crate::transport::Certificate;
7 #[cfg(feature = "tls")]
8 use std::sync::Arc;
9 #[cfg(feature = "tls")]
10 use tokio_rustls::server::TlsStream;
11 
/// Trait that connected IO resources implement and use to produce info about the connection.
///
/// The goal for this trait is to allow users to implement
/// custom IO types that can still provide the same connection
/// metadata.
///
/// # Example
///
/// The `ConnectInfo` returned will be accessible through [request extensions][ext]:
///
/// ```
/// use tonic::{Request, transport::server::Connected};
///
/// // A `Stream` that yields connections
/// struct MyConnector {}
///
/// // Return metadata about the connection as `MyConnectInfo`
/// impl Connected for MyConnector {
///     type ConnectInfo = MyConnectInfo;
///
///     fn connect_info(&self) -> Self::ConnectInfo {
///         MyConnectInfo {}
///     }
/// }
///
/// #[derive(Clone)]
/// struct MyConnectInfo {
///     // Metadata about your connection
/// }
///
/// // The connect info can be accessed through request extensions:
/// # fn foo(request: Request<()>) {
/// let connect_info: &MyConnectInfo = request
///     .extensions()
///     .get::<MyConnectInfo>()
///     .expect("bug in tonic");
/// # }
/// ```
///
/// [ext]: crate::Request::extensions
pub trait Connected {
    /// The connection info type this IO resource generates.
    ///
    /// The `Clone + Send + Sync + 'static` bounds are all necessary so that the value can be
    /// set as a request extension.
    type ConnectInfo: Clone + Send + Sync + 'static;

    /// Create a value holding information about the connection.
    fn connect_info(&self) -> Self::ConnectInfo;
}
60 
/// Connection info for standard TCP streams.
///
/// This type will be accessible through [request extensions][ext] if you're using the default
/// non-TLS connector.
///
/// See [`Connected`] for more details.
///
/// [ext]: crate::Request::extensions
#[derive(Debug, Clone)]
pub struct TcpConnectInfo {
    /// The local address of this connection, or `None` if it is not available.
    pub local_addr: Option<SocketAddr>,
    /// The remote (peer) address of this connection, or `None` if it is not available.
    pub remote_addr: Option<SocketAddr>,
}

impl TcpConnectInfo {
    /// Return the local address the IO resource is connected to.
    ///
    /// Returns a copy of the `local_addr` field.
    pub fn local_addr(&self) -> Option<SocketAddr> {
        self.local_addr
    }

    /// Return the remote (peer) address the IO resource is connected to.
    ///
    /// Returns a copy of the `remote_addr` field.
    pub fn remote_addr(&self) -> Option<SocketAddr> {
        self.remote_addr
    }
}
88 
89 impl Connected for AddrStream {
90     type ConnectInfo = TcpConnectInfo;
91 
connect_info(&self) -> Self::ConnectInfo92     fn connect_info(&self) -> Self::ConnectInfo {
93         TcpConnectInfo {
94             local_addr: Some(self.local_addr()),
95             remote_addr: Some(self.remote_addr()),
96         }
97     }
98 }
99 
100 impl Connected for TcpStream {
101     type ConnectInfo = TcpConnectInfo;
102 
connect_info(&self) -> Self::ConnectInfo103     fn connect_info(&self) -> Self::ConnectInfo {
104         TcpConnectInfo {
105             local_addr: self.local_addr().ok(),
106             remote_addr: self.peer_addr().ok(),
107         }
108     }
109 }
110 
111 impl Connected for tokio::io::DuplexStream {
112     type ConnectInfo = ();
113 
connect_info(&self) -> Self::ConnectInfo114     fn connect_info(&self) -> Self::ConnectInfo {}
115 }
116 
#[cfg(feature = "tls")]
impl<T> Connected for TlsStream<T>
where
    T: Connected,
{
    type ConnectInfo = TlsConnectInfo<T::ConnectInfo>;

    /// Combine the wrapped IO resource's connection info with the peer's TLS certificates.
    fn connect_info(&self) -> Self::ConnectInfo {
        let (io, session) = self.get_ref();
        let inner = io.connect_info();

        // `certs` stays `None` when the session reports no peer certificates.
        // NOTE(review): rustls presumably hands these out DER-encoded while the constructor
        // used here is named `from_pem` — confirm that `Certificate` only wraps the bytes.
        let certs = session.peer_certificates().map(|chain| {
            Arc::new(
                chain
                    .iter()
                    .map(Certificate::from_pem)
                    .collect::<Vec<_>>(),
            )
        });

        TlsConnectInfo { inner, certs }
    }
}
138 
/// Connection info for TLS streams.
///
/// This type will be accessible through [request extensions][ext] if you're using a TLS connector.
///
/// See [`Connected`] for more details.
///
/// [ext]: crate::Request::extensions
#[cfg(feature = "tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
#[derive(Debug, Clone)]
pub struct TlsConnectInfo<T> {
    // Connection info of the underlying (wrapped) IO resource.
    inner: T,
    // Peer certificates, when present. Held behind an `Arc` so that cloning this struct
    // shares the list instead of copying it.
    certs: Option<Arc<Vec<Certificate>>>,
}
153 
#[cfg(feature = "tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
impl<T> TlsConnectInfo<T> {
    /// Get a reference to the underlying connection info.
    pub fn get_ref(&self) -> &T {
        &self.inner
    }

    /// Get a mutable reference to the underlying connection info.
    pub fn get_mut(&mut self) -> &mut T {
        &mut self.inner
    }

    /// Return the set of connected peer TLS certificates.
    ///
    /// Returns `None` when no certificates were stored. Otherwise the returned `Arc` shares
    /// the certificate list with this value; the list itself is not copied.
    pub fn peer_certs(&self) -> Option<Arc<Vec<Certificate>>> {
        self.certs.clone()
    }
}
172