1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 // http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13
14 #[cfg(feature = "async")]
15 mod async_utils;
16
17 #[cfg(feature = "sync")]
18 mod sync_utils;
19
20 use tokio::runtime::Runtime;
21
/// Declares the handle type a test keeps for a locally started server.
///
/// * `define_service_handle!(HTTPS;)` declares `TlsHandle`, which only carries
///   the server port.
/// * `define_service_handle!(HTTP;)` declares `HttpHandle`, which carries the
///   server port plus three tokio mpsc channel ends. This arm also imports
///   `tokio::sync::mpsc::{Receiver, Sender}` at the invocation site.
macro_rules! define_service_handle {
    (HTTPS;) => {
        /// Handle of a TLS test server.
        pub struct TlsHandle {
            /// Port of the server.
            pub port: u16,
        }
    };
    (HTTP;) => {
        use tokio::sync::mpsc::{Receiver, Sender};

        /// Handle of a plain HTTP test server.
        pub struct HttpHandle {
            /// Port of the server.
            pub port: u16,

            /// Receiving end on which the server notifies the client that it
            /// is up and running.
            pub server_start: Receiver<()>,

            /// Sending end with which the client notifies the server that it
            /// is ready to shut down.
            pub client_shutdown: Sender<()>,

            /// Receiving end on which the server notifies the client that it
            /// has shut down.
            pub server_shutdown: Receiver<()>,
        }
    };
}
49
/// Starts `ServerNum` local test servers on `Runtime` and pushes one handle per
/// server onto `Handles`.
///
/// For every server a task is spawned on the runtime. The task creates the
/// server through `start_http_server!` and passes the resulting handle back to
/// the calling thread over a `std::sync::mpsc` channel; the calling thread
/// blocks on the task before reading the handle. The `HTTP` arm additionally
/// waits inside the task for the server's start notification before the handle
/// is passed back.
///
/// # Panics
///
/// Panics if the spawned task fails, if the start notification channel is
/// closed (`HTTP` only), or if the handle cannot be sent or received.
#[macro_export]
macro_rules! start_server {
    (
        HTTP;
        ServerNum: $server_num: expr,
        Runtime: $runtime: expr,
        Handles: $handle_vec: expr,
        ServeFnName: $service_fn: ident,
    ) => {{
        for _ in 0..$server_num {
            let (handle_tx, handle_rx) = std::sync::mpsc::channel();
            let startup_task = $runtime.spawn(async move {
                let mut http_handle = start_http_server!(
                    HTTP;
                    $service_fn
                );
                // Do not hand the handle over before the server reported that
                // it is running.
                let started = http_handle.server_start.recv().await;
                started.expect("Start channel (Server-Half) be closed unexpectedly");
                let delivery = handle_tx.send(http_handle);
                delivery.expect("Failed to send the handle to the test thread.");
            });
            let joined = $runtime.block_on(startup_task);
            joined.expect("Runtime start server coroutine failed");
            let ready_handle = handle_rx
                .recv()
                .expect("Handle send channel (Server-Half) be closed unexpectedly");
            $handle_vec.push(ready_handle);
        }
    }};
    (
        HTTPS;
        ServerNum: $server_num: expr,
        Runtime: $runtime: expr,
        Handles: $handle_vec: expr,
        ServeFnName: $service_fn: ident,
    ) => {{
        for _ in 0..$server_num {
            let (handle_tx, handle_rx) = std::sync::mpsc::channel();
            let startup_task = $runtime.spawn(async move {
                let tls_handle = start_http_server!(
                    HTTPS;
                    $service_fn
                );
                let delivery = handle_tx.send(tls_handle);
                delivery.expect("Failed to send the handle to the test thread.");
            });
            let joined = $runtime.block_on(startup_task);
            joined.expect("Runtime start server coroutine failed");
            let ready_handle = handle_rx
                .recv()
                .expect("Handle send channel (Server-Half) be closed unexpectedly");
            $handle_vec.push(ready_handle);
        }
    }};
}
110
/// Starts one local test server and evaluates to its handle.
///
/// # `HTTP; $server_fn`
///
/// Binds a `std::net::TcpListener` to `127.0.0.1:0`, builds a hyper server on
/// it and spawns a task (`tokio::spawn`) that serves `$server_fn` with graceful
/// shutdown. Three channels of capacity 1 connect the task and the returned
/// `HttpHandle`:
///
/// * the task sends on the start channel once the shutdown future is first
///   polled (`server_start` in the handle);
/// * the task then waits for a message sent through `client_shutdown`;
/// * after the server future completes, the task sends on the channel read
///   through `server_shutdown`.
///
/// # `HTTPS; $service_fn`
///
/// Must be used in an async context (the expansion awaits). Binds a
/// `tokio::net::TcpListener` to `127.0.0.1:0` and spawns a task that builds an
/// OpenSSL acceptor from `tests/file/key.pem` and `tests/file/cert.pem`,
/// accepts a single TCP connection, performs the TLS handshake and serves
/// `$service_fn` over it as HTTP/1 with keep-alive. Evaluates to a `TlsHandle`.
///
/// # Panics
///
/// Both arms panic if the listener cannot be bound or its local address cannot
/// be read. The spawned tasks panic on any channel, TLS setup, accept or
/// handshake failure.
#[macro_export]
macro_rules! start_http_server {
    (
        HTTP;
        $server_fn: ident
    ) => {{
        use hyper::service::{make_service_fn, service_fn};
        use std::convert::Infallible;
        use tokio::sync::mpsc::channel;

        let (start_tx, start_rx) = channel::<()>(1);
        let (client_tx, mut client_rx) = channel::<()>(1);
        let (server_tx, server_rx) = channel::<()>(1);

        // Port 0 lets the operating system pick a free port; the real port is
        // read back from the listener.
        let tcp_listener = std::net::TcpListener::bind("127.0.0.1:0").expect("server bind port failed !");
        let addr = tcp_listener.local_addr().expect("get server local address failed!");
        let port = addr.port();
        let server = hyper::Server::from_tcp(tcp_listener).expect("build hyper server from tcp listener failed !");

        tokio::spawn(async move {
            let make_svc =
                make_service_fn(|_conn| async { Ok::<_, Infallible>(service_fn($server_fn)) });
            server
                .serve(make_svc)
                .with_graceful_shutdown(async {
                    // Announce the start first, then keep the server alive
                    // until the client asks for shutdown.
                    start_tx
                        .send(())
                        .await
                        .expect("Start channel (Client-Half) be closed unexpectedly");
                    client_rx
                        .recv()
                        .await
                        .expect("Client channel (Client-Half) be closed unexpectedly");
                })
                .await
                .expect("Start server failed");
            // Tell the client that the server has finished.
            server_tx
                .send(())
                .await
                .expect("Server channel (Client-Half) be closed unexpectedly");
        });

        HttpHandle {
            port,
            server_start: start_rx,
            client_shutdown: client_tx,
            server_shutdown: server_rx,
        }
    }};
    (
        HTTPS;
        $service_fn: ident
    ) => {{
        // Bind to port 0 so the operating system assigns a free port, exactly
        // like the HTTP arm. Scanning ports by hand and retrying on every
        // error would spin forever when binding fails for a reason other than
        // "address in use".
        let listener = tokio::net::TcpListener::bind(std::net::SocketAddr::from(([127, 0, 0, 1], 0)))
            .await
            .expect("server bind port failed !");
        let port = listener
            .local_addr()
            .expect("get server local address failed!")
            .port();

        tokio::spawn(async move {
            let mut acceptor = openssl::ssl::SslAcceptor::mozilla_intermediate(openssl::ssl::SslMethod::tls())
                .expect("SslAcceptorBuilder error");
            acceptor
                .set_session_id_context(b"test")
                .expect("Set session id error");
            acceptor
                .set_private_key_file("tests/file/key.pem", openssl::ssl::SslFiletype::PEM)
                .expect("Set private key error");
            acceptor
                .set_certificate_chain_file("tests/file/cert.pem")
                .expect("Set cert error");
            let acceptor = acceptor.build();

            // Only one connection is accepted and served by this task.
            let (stream, _) = listener.accept().await.expect("TCP listener accpet error");
            let ssl = openssl::ssl::Ssl::new(acceptor.context()).expect("Ssl Error");
            let mut stream = tokio_openssl::SslStream::new(ssl, stream).expect("SslStream Error");
            // Server side of the TLS handshake; panics if negotiation fails.
            core::pin::Pin::new(&mut stream).accept().await.unwrap();

            hyper::server::conn::Http::new()
                .http1_only(true)
                .http1_keep_alive(true)
                .serve_connection(stream, hyper::service::service_fn($service_fn))
                .await
        });

        TlsHandle {
            port,
        }
    }};
}
211
/// Creates a `Request`.
///
/// Builds the request with `ylong_http::request::RequestBuilder`: the URL is
/// `Host` and `Port` joined by `:`, every `Header` pair is added in order, and
/// the body is a `TextBody` over the bytes of `Body`.
///
/// # Panics
///
/// Panics with "Request build failed" if the builder returns an error.
#[macro_export]
macro_rules! ylong_request {
    (
        Request: {
            Method: $method: expr,
            Host: $host: expr,
            Port: $port: expr,
            $(
                Header: $req_n: expr, $req_v: expr,
            )*
            Body: $req_body: expr,
        },
    ) => {
        ylong_http::request::RequestBuilder::new()
            .method($method)
            .url(format!("{host}:{port}", host = $host, port = $port).as_str())
            $(.header($req_n, $req_v))*
            .body(ylong_http::body::TextBody::from_bytes($req_body.as_bytes()))
            .expect("Request build failed")
    };
}
234
/// Sets server async function.
///
/// Defines `async fn $server_fn_name(request: hyper::Request<hyper::Body>)`
/// which dispatches on the request method. For every `Request`/`Response` pair
/// the generated arm asserts the method, the request target, the version, the
/// listed request headers and the body, then answers with an HTTP/1.1 response
/// built from `Status`, the response headers and the response `Body`.
///
/// * `ASYNC;` expects the request URI to be exactly `"/"`.
/// * `SYNC;` takes an extra `Host` entry and expects it to equal
///   `scheme://host` of the request URI.
///
/// `Version` (given in the `Response` section) is compared against the `Debug`
/// rendering of the request version.
///
/// # Panics
///
/// The generated function panics when an assertion fails, when a listed header
/// is missing or cannot be converted to a string, when the request body cannot
/// be read, when the response cannot be built, or when the request method
/// matches none of the `Request` entries.
#[macro_export]
macro_rules! set_server_fn {
    (
        ASYNC;
        $server_fn_name: ident,
        $(Request: {
            Method: $method: expr,
            $(
                Header: $req_n: expr, $req_v: expr,
            )*
            Body: $req_body: expr,
        },
        Response: {
            Status: $status: expr,
            Version: $version: expr,
            $(
                Header: $resp_n: expr, $resp_v: expr,
            )*
            Body: $resp_body: expr,
        },)*
    ) => {
        async fn $server_fn_name(
            request: hyper::Request<hyper::Body>,
        ) -> Result<hyper::Response<hyper::Body>, std::convert::Infallible> {
            match request.method().as_str() {
                // TODO: two requests with the same Method generate two identical match arms.
                $(
                    $method => {
                        assert_eq!($method, request.method().as_str(), "Assert request method failed");

                        assert_eq!(
                            "/",
                            request.uri().to_string(),
                            "Assert request uri failed",
                        );
                        assert_eq!(
                            $version,
                            format!("{:?}", request.version()),
                            "Assert request version failed",
                        );
                        // The panic messages are only formatted on failure.
                        $(assert_eq!(
                            $req_v,
                            request
                                .headers()
                                .get($req_n)
                                .unwrap_or_else(|| panic!("Get request header \"{}\" failed", $req_n))
                                .to_str()
                                .unwrap_or_else(|err| panic!(
                                    "Convert request header \"{}\" into string failed: {:?}",
                                    $req_n, err
                                )),
                            "Assert request header {} failed", $req_n,
                        );)*
                        let body = hyper::body::to_bytes(request.into_body()).await
                            .expect("Get request body failed");
                        assert_eq!($req_body.as_bytes(), body, "Assert request body failed");
                        Ok(
                            hyper::Response::builder()
                                .version(hyper::Version::HTTP_11)
                                .status($status)
                                $(.header($resp_n, $resp_v))*
                                .body($resp_body.into())
                                .expect("Build response failed")
                        )
                    },
                )*
                _ => {panic!("Unrecognized METHOD !");},
            }
        }

    };
    (
        SYNC;
        $server_fn_name: ident,
        $(Request: {
            Method: $method: expr,
            Host: $host: expr,
            $(
                Header: $req_n: expr, $req_v: expr,
            )*
            Body: $req_body: expr,
        },
        Response: {
            Status: $status: expr,
            Version: $version: expr,
            $(
                Header: $resp_n: expr, $resp_v: expr,
            )*
            Body: $resp_body: expr,
        },)*
    ) => {
        async fn $server_fn_name(
            request: hyper::Request<hyper::Body>,
        ) -> Result<hyper::Response<hyper::Body>, std::convert::Infallible> {
            match request.method().as_str() {
                // TODO: two requests with the same Method generate two identical match arms.
                $(
                    $method => {
                        assert_eq!($method, request.method().as_str(), "Assert request method failed");

                        assert_eq!(
                            $host,
                            format!("{}://{}", request.uri().scheme().expect("assert uri scheme failed !").as_str(), request.uri().host().expect("assert uri host failed !")),
                            "Assert request host failed",
                        );
                        assert_eq!(
                            $version,
                            format!("{:?}", request.version()),
                            "Assert request version failed",
                        );
                        // The panic messages are only formatted on failure.
                        $(assert_eq!(
                            $req_v,
                            request
                                .headers()
                                .get($req_n)
                                .unwrap_or_else(|| panic!("Get request header \"{}\" failed", $req_n))
                                .to_str()
                                .unwrap_or_else(|err| panic!(
                                    "Convert request header \"{}\" into string failed: {:?}",
                                    $req_n, err
                                )),
                            "Assert request header {} failed", $req_n,
                        );)*
                        let body = hyper::body::to_bytes(request.into_body()).await
                            .expect("Get request body failed");
                        assert_eq!($req_body.as_bytes(), body, "Assert request body failed");
                        Ok(
                            hyper::Response::builder()
                                .version(hyper::Version::HTTP_11)
                                .status($status)
                                $(.header($resp_n, $resp_v))*
                                .body($resp_body.into())
                                .expect("Build response failed")
                        )
                    },
                )*
                _ => {panic!("Unrecognized METHOD !");},
            }
        }

    };
}
368
/// Asks the server behind `ServerHandle` to shut down and waits until it
/// reports that it has done so.
///
/// Sends `()` on the handle's `client_shutdown` channel, then awaits one
/// message on its `server_shutdown` channel. Must be used in an async context
/// because the expansion awaits.
///
/// # Panics
///
/// Panics if either channel is closed.
#[macro_export]
macro_rules! ensure_server_shutdown {
    (ServerHandle: $handle: expr) => {
        let shutdown_request = $handle.client_shutdown.send(()).await;
        shutdown_request.expect("Client channel (Server-Half) be closed unexpectedly");
        let shutdown_ack = $handle.server_shutdown.recv().await;
        shutdown_ack.expect("Server channel (Server-Half) be closed unexpectedly");
    };
}
384
init_test_work_runtime(thread_num: usize) -> Runtime385 pub fn init_test_work_runtime(thread_num: usize) -> Runtime {
386 tokio::runtime::Builder::new_multi_thread()
387 .worker_threads(thread_num)
388 .enable_all()
389 .build()
390 .expect("Build runtime failed.")
391 }
392