• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use super::{BoxError, HttpBody};
2 use bytes::Bytes;
3 use http::{
4     header::{HeaderName, HeaderValue},
5     Request, StatusCode,
6 };
7 use hyper::{Body, Server};
8 use std::net::{SocketAddr, TcpListener};
9 use tower::make::Shared;
10 use tower_service::Service;
11 
12 pub(crate) struct TestClient {
13     client: reqwest::Client,
14     addr: SocketAddr,
15 }
16 
17 impl TestClient {
new<S, ResBody>(svc: S) -> Self where S: Service<Request<Body>, Response = http::Response<ResBody>> + Clone + Send + 'static, ResBody: HttpBody + Send + 'static, ResBody::Data: Send, ResBody::Error: Into<BoxError>, S::Future: Send, S::Error: Into<BoxError>,18     pub(crate) fn new<S, ResBody>(svc: S) -> Self
19     where
20         S: Service<Request<Body>, Response = http::Response<ResBody>> + Clone + Send + 'static,
21         ResBody: HttpBody + Send + 'static,
22         ResBody::Data: Send,
23         ResBody::Error: Into<BoxError>,
24         S::Future: Send,
25         S::Error: Into<BoxError>,
26     {
27         let listener = TcpListener::bind("127.0.0.1:0").expect("Could not bind ephemeral socket");
28         let addr = listener.local_addr().unwrap();
29         println!("Listening on {addr}");
30 
31         tokio::spawn(async move {
32             let server = Server::from_tcp(listener).unwrap().serve(Shared::new(svc));
33             server.await.expect("server error");
34         });
35 
36         let client = reqwest::Client::builder()
37             .redirect(reqwest::redirect::Policy::none())
38             .build()
39             .unwrap();
40 
41         TestClient { client, addr }
42     }
43 
get(&self, url: &str) -> RequestBuilder44     pub(crate) fn get(&self, url: &str) -> RequestBuilder {
45         RequestBuilder {
46             builder: self.client.get(format!("http://{}{}", self.addr, url)),
47         }
48     }
49 
head(&self, url: &str) -> RequestBuilder50     pub(crate) fn head(&self, url: &str) -> RequestBuilder {
51         RequestBuilder {
52             builder: self.client.head(format!("http://{}{}", self.addr, url)),
53         }
54     }
55 
post(&self, url: &str) -> RequestBuilder56     pub(crate) fn post(&self, url: &str) -> RequestBuilder {
57         RequestBuilder {
58             builder: self.client.post(format!("http://{}{}", self.addr, url)),
59         }
60     }
61 
62     #[allow(dead_code)]
put(&self, url: &str) -> RequestBuilder63     pub(crate) fn put(&self, url: &str) -> RequestBuilder {
64         RequestBuilder {
65             builder: self.client.put(format!("http://{}{}", self.addr, url)),
66         }
67     }
68 
69     #[allow(dead_code)]
patch(&self, url: &str) -> RequestBuilder70     pub(crate) fn patch(&self, url: &str) -> RequestBuilder {
71         RequestBuilder {
72             builder: self.client.patch(format!("http://{}{}", self.addr, url)),
73         }
74     }
75 }
76 
77 pub(crate) struct RequestBuilder {
78     builder: reqwest::RequestBuilder,
79 }
80 
81 impl RequestBuilder {
send(self) -> TestResponse82     pub(crate) async fn send(self) -> TestResponse {
83         TestResponse {
84             response: self.builder.send().await.unwrap(),
85         }
86     }
87 
body(mut self, body: impl Into<reqwest::Body>) -> Self88     pub(crate) fn body(mut self, body: impl Into<reqwest::Body>) -> Self {
89         self.builder = self.builder.body(body);
90         self
91     }
92 
json<T>(mut self, json: &T) -> Self where T: serde::Serialize,93     pub(crate) fn json<T>(mut self, json: &T) -> Self
94     where
95         T: serde::Serialize,
96     {
97         self.builder = self.builder.json(json);
98         self
99     }
100 
header<K, V>(mut self, key: K, value: V) -> Self where HeaderName: TryFrom<K>, <HeaderName as TryFrom<K>>::Error: Into<http::Error>, HeaderValue: TryFrom<V>, <HeaderValue as TryFrom<V>>::Error: Into<http::Error>,101     pub(crate) fn header<K, V>(mut self, key: K, value: V) -> Self
102     where
103         HeaderName: TryFrom<K>,
104         <HeaderName as TryFrom<K>>::Error: Into<http::Error>,
105         HeaderValue: TryFrom<V>,
106         <HeaderValue as TryFrom<V>>::Error: Into<http::Error>,
107     {
108         self.builder = self.builder.header(key, value);
109         self
110     }
111 
112     #[allow(dead_code)]
multipart(mut self, form: reqwest::multipart::Form) -> Self113     pub(crate) fn multipart(mut self, form: reqwest::multipart::Form) -> Self {
114         self.builder = self.builder.multipart(form);
115         self
116     }
117 }
118 
119 #[derive(Debug)]
120 pub(crate) struct TestResponse {
121     response: reqwest::Response,
122 }
123 
124 impl TestResponse {
125     #[allow(dead_code)]
bytes(self) -> Bytes126     pub(crate) async fn bytes(self) -> Bytes {
127         self.response.bytes().await.unwrap()
128     }
129 
text(self) -> String130     pub(crate) async fn text(self) -> String {
131         self.response.text().await.unwrap()
132     }
133 
134     #[allow(dead_code)]
json<T>(self) -> T where T: serde::de::DeserializeOwned,135     pub(crate) async fn json<T>(self) -> T
136     where
137         T: serde::de::DeserializeOwned,
138     {
139         self.response.json().await.unwrap()
140     }
141 
status(&self) -> StatusCode142     pub(crate) fn status(&self) -> StatusCode {
143         self.response.status()
144     }
145 
headers(&self) -> &http::HeaderMap146     pub(crate) fn headers(&self) -> &http::HeaderMap {
147         self.response.headers()
148     }
149 
chunk(&mut self) -> Option<Bytes>150     pub(crate) async fn chunk(&mut self) -> Option<Bytes> {
151         self.response.chunk().await.unwrap()
152     }
153 
chunk_text(&mut self) -> Option<String>154     pub(crate) async fn chunk_text(&mut self) -> Option<String> {
155         let chunk = self.chunk().await?;
156         Some(String::from_utf8(chunk.to_vec()).unwrap())
157     }
158 }
159