1 use super::{BoxError, HttpBody}; 2 use bytes::Bytes; 3 use http::{ 4 header::{HeaderName, HeaderValue}, 5 Request, StatusCode, 6 }; 7 use hyper::{Body, Server}; 8 use std::net::{SocketAddr, TcpListener}; 9 use tower::make::Shared; 10 use tower_service::Service; 11 12 pub(crate) struct TestClient { 13 client: reqwest::Client, 14 addr: SocketAddr, 15 } 16 17 impl TestClient { new<S, ResBody>(svc: S) -> Self where S: Service<Request<Body>, Response = http::Response<ResBody>> + Clone + Send + 'static, ResBody: HttpBody + Send + 'static, ResBody::Data: Send, ResBody::Error: Into<BoxError>, S::Future: Send, S::Error: Into<BoxError>,18 pub(crate) fn new<S, ResBody>(svc: S) -> Self 19 where 20 S: Service<Request<Body>, Response = http::Response<ResBody>> + Clone + Send + 'static, 21 ResBody: HttpBody + Send + 'static, 22 ResBody::Data: Send, 23 ResBody::Error: Into<BoxError>, 24 S::Future: Send, 25 S::Error: Into<BoxError>, 26 { 27 let listener = TcpListener::bind("127.0.0.1:0").expect("Could not bind ephemeral socket"); 28 let addr = listener.local_addr().unwrap(); 29 println!("Listening on {addr}"); 30 31 tokio::spawn(async move { 32 let server = Server::from_tcp(listener).unwrap().serve(Shared::new(svc)); 33 server.await.expect("server error"); 34 }); 35 36 let client = reqwest::Client::builder() 37 .redirect(reqwest::redirect::Policy::none()) 38 .build() 39 .unwrap(); 40 41 TestClient { client, addr } 42 } 43 get(&self, url: &str) -> RequestBuilder44 pub(crate) fn get(&self, url: &str) -> RequestBuilder { 45 RequestBuilder { 46 builder: self.client.get(format!("http://{}{}", self.addr, url)), 47 } 48 } 49 head(&self, url: &str) -> RequestBuilder50 pub(crate) fn head(&self, url: &str) -> RequestBuilder { 51 RequestBuilder { 52 builder: self.client.head(format!("http://{}{}", self.addr, url)), 53 } 54 } 55 post(&self, url: &str) -> RequestBuilder56 pub(crate) fn post(&self, url: &str) -> RequestBuilder { 57 RequestBuilder { 58 builder: self.client.post(format!("http://{}{}", self.addr, url)), 59 } 60 } 61 62 #[allow(dead_code)] put(&self, url: &str) -> RequestBuilder63 pub(crate) fn put(&self, url: &str) -> RequestBuilder { 64 RequestBuilder { 65 builder: self.client.put(format!("http://{}{}", self.addr, url)), 66 } 67 } 68 69 #[allow(dead_code)] patch(&self, url: &str) -> RequestBuilder70 pub(crate) fn patch(&self, url: &str) -> RequestBuilder { 71 RequestBuilder { 72 builder: self.client.patch(format!("http://{}{}", self.addr, url)), 73 } 74 } 75 } 76 77 pub(crate) struct RequestBuilder { 78 builder: reqwest::RequestBuilder, 79 } 80 81 impl RequestBuilder { send(self) -> TestResponse82 pub(crate) async fn send(self) -> TestResponse { 83 TestResponse { 84 response: self.builder.send().await.unwrap(), 85 } 86 } 87 body(mut self, body: impl Into<reqwest::Body>) -> Self88 pub(crate) fn body(mut self, body: impl Into<reqwest::Body>) -> Self { 89 self.builder = self.builder.body(body); 90 self 91 } 92 json<T>(mut self, json: &T) -> Self where T: serde::Serialize,93 pub(crate) fn json<T>(mut self, json: &T) -> Self 94 where 95 T: serde::Serialize, 96 { 97 self.builder = self.builder.json(json); 98 self 99 } 100 header<K, V>(mut self, key: K, value: V) -> Self where HeaderName: TryFrom<K>, <HeaderName as TryFrom<K>>::Error: Into<http::Error>, HeaderValue: TryFrom<V>, <HeaderValue as TryFrom<V>>::Error: Into<http::Error>,101 pub(crate) fn header<K, V>(mut self, key: K, value: V) -> Self 102 where 103 HeaderName: TryFrom<K>, 104 <HeaderName as TryFrom<K>>::Error: Into<http::Error>, 105 HeaderValue: TryFrom<V>, 106 <HeaderValue as TryFrom<V>>::Error: Into<http::Error>, 107 { 108 self.builder = self.builder.header(key, value); 109 self 110 } 111 112 #[allow(dead_code)] multipart(mut self, form: reqwest::multipart::Form) -> Self113 pub(crate) fn multipart(mut self, form: reqwest::multipart::Form) -> Self { 114 self.builder = self.builder.multipart(form); 115 self 116 } 117 } 118 119 #[derive(Debug)] 120 pub(crate) struct TestResponse { 121 response: reqwest::Response, 122 } 123 124 impl TestResponse { 125 #[allow(dead_code)] bytes(self) -> Bytes126 pub(crate) async fn bytes(self) -> Bytes { 127 self.response.bytes().await.unwrap() 128 } 129 text(self) -> String130 pub(crate) async fn text(self) -> String { 131 self.response.text().await.unwrap() 132 } 133 134 #[allow(dead_code)] json<T>(self) -> T where T: serde::de::DeserializeOwned,135 pub(crate) async fn json<T>(self) -> T 136 where 137 T: serde::de::DeserializeOwned, 138 { 139 self.response.json().await.unwrap() 140 } 141 status(&self) -> StatusCode142 pub(crate) fn status(&self) -> StatusCode { 143 self.response.status() 144 } 145 headers(&self) -> &http::HeaderMap146 pub(crate) fn headers(&self) -> &http::HeaderMap { 147 self.response.headers() 148 } 149 chunk(&mut self) -> Option<Bytes>150 pub(crate) async fn chunk(&mut self) -> Option<Bytes> { 151 self.response.chunk().await.unwrap() 152 } 153 chunk_text(&mut self) -> Option<String>154 pub(crate) async fn chunk_text(&mut self) -> Option<String> { 155 let chunk = self.chunk().await?; 156 Some(String::from_utf8(chunk.to_vec()).unwrap()) 157 } 158 } 159