use std::future::Future; use std::io; use std::pin::Pin; use std::task::{Context, Poll}; use std::time::Duration; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::time::timeout; use tokio_io_timeout::TimeoutStream; use hyper::client::connect::{Connected, Connection}; use hyper::{service::Service, Uri}; mod stream; use stream::TimeoutConnectorStream; type BoxError = Box; /// A connector that enforces as connection timeout #[derive(Debug, Clone)] pub struct TimeoutConnector { /// A connector implementing the `Connect` trait connector: T, /// Amount of time to wait connecting connect_timeout: Option, /// Amount of time to wait reading response read_timeout: Option, /// Amount of time to wait writing request write_timeout: Option, } impl TimeoutConnector where T: Service + Send, T::Response: AsyncRead + AsyncWrite + Send + Unpin, T::Future: Send + 'static, T::Error: Into, { /// Construct a new TimeoutConnector with a given connector implementing the `Connect` trait pub fn new(connector: T) -> Self { TimeoutConnector { connector, connect_timeout: None, read_timeout: None, write_timeout: None, } } } impl Service for TimeoutConnector where T: Service + Send, T::Response: AsyncRead + AsyncWrite + Connection + Send + Unpin, T::Future: Send + 'static, T::Error: Into, { type Response = Pin>>; type Error = BoxError; #[allow(clippy::type_complexity)] type Future = Pin> + Send>>; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.connector.poll_ready(cx).map_err(Into::into) } fn call(&mut self, dst: Uri) -> Self::Future { let connect_timeout = self.connect_timeout; let read_timeout = self.read_timeout; let write_timeout = self.write_timeout; let connecting = self.connector.call(dst); let fut = async move { let stream = match connect_timeout { None => { let io = connecting.await.map_err(Into::into)?; TimeoutStream::new(io) } Some(connect_timeout) => { let timeout = timeout(connect_timeout, connecting); let connecting = timeout .await .map_err(|e| io::Error::new(io::ErrorKind::TimedOut, e))?; let io = connecting.map_err(Into::into)?; TimeoutStream::new(io) } }; let mut tm = TimeoutConnectorStream::new(stream); tm.set_read_timeout(read_timeout); tm.set_write_timeout(write_timeout); Ok(Box::pin(tm)) }; Box::pin(fut) } } impl TimeoutConnector { /// Set the timeout for connecting to a URL. /// /// Default is no timeout. #[inline] pub fn set_connect_timeout(&mut self, val: Option) { self.connect_timeout = val; } /// Set the timeout for the response. /// /// Default is no timeout. #[inline] pub fn set_read_timeout(&mut self, val: Option) { self.read_timeout = val; } /// Set the timeout for the request. /// /// Default is no timeout. #[inline] pub fn set_write_timeout(&mut self, val: Option) { self.write_timeout = val; } } impl Connection for TimeoutConnector where T: AsyncRead + AsyncWrite + Connection + Service + Send + Unpin, T::Response: AsyncRead + AsyncWrite + Send + Unpin, T::Future: Send + 'static, T::Error: Into, { fn connected(&self) -> Connected { self.connector.connected() } } #[cfg(test)] mod tests { use std::error::Error; use std::io; use std::time::Duration; use hyper::client::HttpConnector; use hyper::Client; use super::TimeoutConnector; #[tokio::test] async fn test_timeout_connector() { // 10.255.255.1 is a not a routable IP address let url = "http://10.255.255.1".parse().unwrap(); let http = HttpConnector::new(); let mut connector = TimeoutConnector::new(http); connector.set_connect_timeout(Some(Duration::from_millis(1))); let client = Client::builder().build::<_, hyper::Body>(connector); let res = client.get(url).await; match res { Ok(_) => panic!("Expected a timeout"), Err(e) => { if let Some(io_e) = e.source().unwrap().downcast_ref::() { assert_eq!(io_e.kind(), io::ErrorKind::TimedOut); } else { panic!("Expected timeout error"); } } } } #[tokio::test] async fn test_read_timeout() { let url = "http://example.com".parse().unwrap(); let http = HttpConnector::new(); let mut connector = TimeoutConnector::new(http); // A 1 ms read timeout should be so short that we trigger a timeout error connector.set_read_timeout(Some(Duration::from_millis(1))); let client = Client::builder().build::<_, hyper::Body>(connector); let res = client.get(url).await; match res { Ok(_) => panic!("Expected a timeout"), Err(e) => { if let Some(io_e) = e.source().unwrap().downcast_ref::() { assert_eq!(io_e.kind(), io::ErrorKind::TimedOut); } else { panic!("Expected timeout error"); } } } } }