1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 // http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13
14 #[cfg(feature = "async")]
15 mod async_utils;
16
17 #[cfg(feature = "sync")]
18 mod sync_utils;
19
20 use tokio::runtime::Runtime;
21
/// Declares the handle type that `start_http_server!` evaluates to.
///
/// * `HTTP;` brings tokio's mpsc `Receiver`/`Sender` into scope and declares
///   `HttpHandle`, which carries the port plus the three start/shutdown
///   handshake channels.
/// * `HTTPS;` declares `TlsHandle`, which only records the bound port.
#[macro_export]
macro_rules! define_service_handle {
    (HTTP;) => {
        use tokio::sync::mpsc::{Receiver, Sender};

        pub struct HttpHandle {
            /// Local port the server is listening on.
            pub port: u16,

            /// Signalled by the server once it is up and running.
            pub server_start: Receiver<()>,

            /// Used by the client to tell the server it may shut down.
            pub client_shutdown: Sender<()>,

            /// Signalled by the server once it has shut down.
            pub server_shutdown: Receiver<()>,
        }
    };
    (HTTPS;) => {
        pub struct TlsHandle {
            /// Local port the server is listening on.
            pub port: u16,
        }
    };
}
50
/// Starts `ServerNum` test servers on `Runtime` and pushes each handle into
/// `Handles`.
///
/// Arms:
/// * `HTTPS; ... ServeFnName: f,`: TLS servers using the default key and
///   certificate at `tests/file/key.pem` and `tests/file/cert.pem`.
/// * `HTTPS; ... ServeFnName: f, ServeKeyPath: k, ServeCrtPath: c,`: TLS
///   servers using the given key and certificate. `k` and `c` may be any
///   expression accepted by `std::path::PathBuf::from`, e.g. an identifier
///   or a string literal.
/// * `HTTP; ... ServeFnName: f,`: plain HTTP servers. The handle is pushed
///   only after the server has signalled on its `server_start` channel.
///
/// Each server is created inside a task spawned on `Runtime`. The current
/// thread waits for that task with `Runtime::block_on`, so (per tokio) this
/// macro must not be expanded from within an async execution context.
///
/// # Panics
/// Panics if a start-up task fails or, for `HTTP`, if the start channel is
/// closed before the server reports that it is running.
#[macro_export]
macro_rules! start_server {
    (
        HTTPS;
        ServerNum: $server_num: expr,
        Runtime: $runtime: expr,
        Handles: $handle_vec: expr,
        ServeFnName: $service_fn: ident,
    ) => {
        // Same as the explicit-path arm, using the repository's test key/cert.
        $crate::start_server!(
            HTTPS;
            ServerNum: $server_num,
            Runtime: $runtime,
            Handles: $handle_vec,
            ServeFnName: $service_fn,
            ServeKeyPath: "tests/file/key.pem",
            ServeCrtPath: "tests/file/cert.pem",
        )
    };
    (
        HTTPS;
        ServerNum: $server_num: expr,
        Runtime: $runtime: expr,
        Handles: $handle_vec: expr,
        ServeFnName: $service_fn: ident,
        ServeKeyPath: $server_key_path: expr,
        ServeCrtPath: $server_crt_path: expr,
    ) => {{
        let key_path = std::path::PathBuf::from($server_key_path);
        let cert_path = std::path::PathBuf::from($server_crt_path);
        for _ in 0..$server_num {
            // Each task takes ownership of its own copy of the paths.
            let key_path = key_path.clone();
            let cert_path = cert_path.clone();
            let server_task = $runtime.spawn(async move {
                $crate::start_http_server!(
                    HTTPS;
                    $service_fn,
                    key_path,
                    cert_path
                )
            });
            // The task's output is the handle itself; no side channel is needed.
            let handle = $runtime
                .block_on(server_task)
                .expect("Runtime start server coroutine failed");
            $handle_vec.push(handle);
        }
    }};
    (
        HTTP;
        ServerNum: $server_num: expr,
        Runtime: $runtime: expr,
        Handles: $handle_vec: expr,
        ServeFnName: $service_fn: ident,
    ) => {{
        for _ in 0..$server_num {
            let server_task = $runtime.spawn(async move {
                let mut handle = $crate::start_http_server!(
                    HTTP;
                    $service_fn
                );
                // Do not hand the handle out before the server is actually serving.
                handle
                    .server_start
                    .recv()
                    .await
                    .expect("Start channel (Server-Half) be closed unexpectedly");
                handle
            });
            let handle = $runtime
                .block_on(server_task)
                .expect("Runtime start server coroutine failed");
            $handle_vec.push(handle);
        }
    }};
}
152
/// Starts a single test server on a task spawned with `tokio::spawn` and
/// evaluates to its handle.
///
/// The expansion calls `tokio::spawn`, so it must run inside a tokio runtime
/// context. The `HTTPS` arm also `.await`s while binding, so it must be
/// expanded inside an async block or async function.
///
/// * `HTTP; serve_fn`: hyper server on `127.0.0.1` with an OS-assigned port.
///   Evaluates to an `HttpHandle` whose channels drive the start/shutdown
///   handshake.
/// * `HTTP; Ipv6; serve_fn`: the same server on the IPv6 loopback `[::1]`.
/// * `HTTPS; serve_fn, key_path, cert_path`: OpenSSL-backed TLS server on
///   `127.0.0.1` with an OS-assigned port. Evaluates to a `TlsHandle`. It
///   accepts and serves exactly one connection (HTTP/1.1 only, ALPN
///   `http/1.1`), after which its task ends.
///
/// `HttpHandle` and `TlsHandle` are resolved at the call site; they are
/// declared by `define_service_handle!`.
#[macro_export]
macro_rules! start_http_server {
    (
        HTTP;
        $server_fn: ident
    ) => {
        $crate::start_http_server!(@http "127.0.0.1:0", $server_fn)
    };
    (
        HTTP;
        Ipv6;
        $server_fn: ident
    ) => {
        // Bracketed form is the canonical IPv6 socket address and parses
        // directly as a `SocketAddr`.
        $crate::start_http_server!(@http "[::1]:0", $server_fn)
    };
    (
        HTTPS;
        $service_fn: ident,
        $server_key_path: expr,
        $server_cert_path: expr
    ) => {{
        // Port 0 lets the OS assign a free port (as the HTTP arms do) instead of
        // probing candidate ports one by one with no upper bound on retries.
        let listener =
            tokio::net::TcpListener::bind(std::net::SocketAddr::from(([127, 0, 0, 1], 0)))
                .await
                .expect("server bind port failed !");
        let port = listener
            .local_addr()
            .expect("get server local address failed!")
            .port();

        tokio::spawn(async move {
            let mut acceptor = openssl::ssl::SslAcceptor::mozilla_intermediate(
                openssl::ssl::SslMethod::tls(),
            )
            .expect("SslAcceptorBuilder error");
            acceptor
                .set_session_id_context(b"test")
                .expect("Set session id error");
            acceptor
                .set_private_key_file($server_key_path, openssl::ssl::SslFiletype::PEM)
                .expect("Set private key error");
            acceptor
                .set_certificate_chain_file($server_cert_path)
                .expect("Set cert error");
            // Advertise and select only HTTP/1.1 (length-prefixed ALPN wire format).
            acceptor
                .set_alpn_protos(b"\x08http/1.1")
                .expect("Set ALPN protocols error");
            acceptor.set_alpn_select_callback(|_, client| {
                openssl::ssl::select_next_proto(b"\x08http/1.1", client)
                    .ok_or(openssl::ssl::AlpnError::NOACK)
            });
            let acceptor = acceptor.build();

            // Only one connection is accepted and served.
            let (stream, _) = listener.accept().await.expect("TCP listener accept error");
            let ssl = openssl::ssl::Ssl::new(acceptor.context()).expect("Ssl Error");
            let mut stream = tokio_openssl::SslStream::new(ssl, stream).expect("SslStream Error");
            core::pin::Pin::new(&mut stream)
                .accept()
                .await
                .expect("SSL handshake failed");

            hyper::server::conn::Http::new()
                .http1_only(true)
                .http1_keep_alive(true)
                .serve_connection(stream, hyper::service::service_fn($service_fn))
                .await
        });

        TlsHandle { port }
    }};
    // Internal: plain-HTTP server bound to `$bind_addr` (shared by both HTTP arms).
    (
        @http $bind_addr: expr, $server_fn: ident
    ) => {{
        use hyper::service::{make_service_fn, service_fn};
        use std::convert::Infallible;
        use tokio::sync::mpsc::channel;

        // server -> client: "running"; client -> server: "shut down"; server -> client: "stopped".
        let (start_tx, start_rx) = channel::<()>(1);
        let (client_tx, mut client_rx) = channel::<()>(1);
        let (server_tx, server_rx) = channel::<()>(1);

        let tcp_listener =
            std::net::TcpListener::bind($bind_addr).expect("server bind port failed !");
        let port = tcp_listener
            .local_addr()
            .expect("get server local address failed!")
            .port();
        let server = hyper::Server::from_tcp(tcp_listener)
            .expect("build hyper server from tcp listener failed !");

        tokio::spawn(async move {
            let make_svc =
                make_service_fn(|_conn| async { Ok::<_, Infallible>(service_fn($server_fn)) });
            server
                .serve(make_svc)
                .with_graceful_shutdown(async {
                    // The shutdown future is first polled once the server is
                    // serving, so announce start-up here, then wait for the client.
                    start_tx
                        .send(())
                        .await
                        .expect("Start channel (Client-Half) be closed unexpectedly");
                    client_rx
                        .recv()
                        .await
                        .expect("Client channel (Client-Half) be closed unexpectedly");
                })
                .await
                .expect("Start server failed");
            server_tx
                .send(())
                .await
                .expect("Server channel (Client-Half) be closed unexpectedly");
        });

        HttpHandle {
            port,
            server_start: start_rx,
            client_shutdown: client_tx,
            server_shutdown: server_rx,
        }
    }};
}
308
/// Builds a synchronous `ylong_http` request for tests.
///
/// The URL is `"{Host}:{Port}"`. Every `Header:` pair is appended in the order
/// given, and `Body:` is sent as a `TextBody` over the body's bytes. The final
/// build result is unwrapped with `expect("Request build failed")`.
///
/// NOTE(review): another exported `ylong_request!` is defined under
/// `feature = "async"`; confirm the two features are never enabled together.
#[macro_export]
#[cfg(feature = "sync")]
macro_rules! ylong_request {
    (
        Request: {
            Method: $m: expr,
            Host: $h: expr,
            Port: $p: expr,
            $(
                Header: $name: expr, $value: expr,
            )*
            Body: $body: expr,
        },
    ) => {
        ylong_http::request::RequestBuilder::new()
            .method($m)
            .url(format!("{}:{}", $h, $p).as_str())
            $(.header($name, $value))*
            .body(ylong_http::body::TextBody::from_bytes($body.as_bytes()))
            .expect("Request build failed")
    };
}
332
/// Builds an asynchronous `ylong_http_client::async_impl` request for tests.
///
/// The URL is `"{Host}:{Port}"`. Every `Header:` pair is appended in the order
/// given, and `Body:` is passed as `async_impl::Body::slice` over the body's
/// bytes. The final build result is unwrapped with
/// `expect("Request build failed")`.
///
/// NOTE(review): another exported `ylong_request!` is defined under
/// `feature = "sync"`; confirm the two features are never enabled together.
#[macro_export]
#[cfg(feature = "async")]
macro_rules! ylong_request {
    (
        Request: {
            Method: $m: expr,
            Host: $h: expr,
            Port: $p: expr,
            $(
                Header: $name: expr, $value: expr,
            )*
            Body: $body: expr,
        },
    ) => {
        ylong_http_client::async_impl::RequestBuilder::new()
            .method($m)
            .url(format!("{}:{}", $h, $p).as_str())
            $(.header($name, $value))*
            .body(ylong_http_client::async_impl::Body::slice($body.as_bytes()))
            .expect("Request build failed")
    };
}
356
/// Defines the hyper service function `$server_fn_name` used by the test servers.
///
/// The generated `async fn` takes a `hyper::Request<hyper::Body>` and returns
/// `Result<hyper::Response<hyper::Body>, std::convert::Infallible>`. Every
/// `Request: {..}, Response: {..},` pair becomes one `match` arm keyed on the
/// request method. Inside that arm the function asserts that:
///
/// * the request URI is `/`,
/// * `format!("{:?}", request.version())` equals `Version`. `Version` is
///   written under `Response:` but is checked against the *request*; the
///   response itself is always HTTP/1.1,
/// * every `Header:` listed under `Request:` is present with the given value,
/// * the request body equals `Body`.
///
/// It then replies with `Status`, the response headers and the response body.
/// A method that matches no arm panics with "Unrecognized METHOD !".
///
/// `ASYNC;` and `SYNC;` generate the same function; `SYNC;` forwards to `ASYNC;`.
///
/// TODO: pairs that share the same `Method` expand to duplicate match arms, so
/// only the first of them can ever be reached.
#[macro_export]
macro_rules! set_server_fn {
    (
        ASYNC;
        $server_fn_name: ident,
        $(Request: {
            Method: $method: expr,
            $(
                Header: $req_n: expr, $req_v: expr,
            )*
            Body: $req_body: expr,
        },
        Response: {
            Status: $status: expr,
            Version: $version: expr,
            $(
                Header: $resp_n: expr, $resp_v: expr,
            )*
            Body: $resp_body: expr,
        },)*
    ) => {
        async fn $server_fn_name(
            request: hyper::Request<hyper::Body>,
        ) -> Result<hyper::Response<hyper::Body>, std::convert::Infallible> {
            match request.method().as_str() {
                $(
                    $method => {
                        assert_eq!($method, request.method().as_str(), "Assert request method failed");
                        assert_eq!(
                            "/",
                            request.uri().to_string(),
                            "Assert request uri failed",
                        );
                        assert_eq!(
                            $version,
                            format!("{:?}", request.version()),
                            "Assert request version failed",
                        );
                        // Panic messages are only formatted on failure.
                        $(assert_eq!(
                            $req_v,
                            request
                                .headers()
                                .get($req_n)
                                .unwrap_or_else(|| panic!("Get request header \"{}\" failed", $req_n))
                                .to_str()
                                .unwrap_or_else(|e| panic!(
                                    "Convert request header \"{}\" into string failed: {:?}",
                                    $req_n, e
                                )),
                            "Assert request header {} failed", $req_n,
                        );)*
                        let body = hyper::body::to_bytes(request.into_body())
                            .await
                            .expect("Get request body failed");
                        assert_eq!($req_body.as_bytes(), body, "Assert request body failed");
                        Ok(
                            hyper::Response::builder()
                                .version(hyper::Version::HTTP_11)
                                .status($status)
                                $(.header($resp_n, $resp_v))*
                                .body($resp_body.into())
                                .expect("Build response failed")
                        )
                    },
                )*
                _ => {
                    panic!("Unrecognized METHOD !");
                },
            }
        }
    };
    (
        SYNC;
        $server_fn_name: ident,
        $(Request: {
            Method: $method: expr,
            $(
                Header: $req_n: expr, $req_v: expr,
            )*
            Body: $req_body: expr,
        },
        Response: {
            Status: $status: expr,
            Version: $version: expr,
            $(
                Header: $resp_n: expr, $resp_v: expr,
            )*
            Body: $resp_body: expr,
        },)*
    ) => {
        // The generated server function is identical for sync and async clients.
        $crate::set_server_fn!(
            ASYNC;
            $server_fn_name,
            $(Request: {
                Method: $method,
                $(
                    Header: $req_n, $req_v,
                )*
                Body: $req_body,
            },
            Response: {
                Status: $status,
                Version: $version,
                $(
                    Header: $resp_n, $resp_v,
                )*
                Body: $resp_body,
            },)*
        );
    };
}
487
/// Runs the client side of an `HttpHandle` shutdown handshake.
///
/// Sends `()` on `client_shutdown`, then awaits `server_shutdown` until the
/// server reports that it has stopped. Expands to two `.await` statements, so
/// it must be used in statement position inside an async context. `$handle`
/// must allow `recv` to be called on its `server_shutdown` field; for tokio's
/// mpsc receiver that requires mutable access.
#[macro_export]
macro_rules! ensure_server_shutdown {
    (ServerHandle: $handle: expr) => {
        // Ask the server to stop ...
        $handle.client_shutdown.send(()).await
            .expect("Client channel (Server-Half) be closed unexpectedly");
        // ... and block until it confirms that it has.
        $handle.server_shutdown.recv().await
            .expect("Server channel (Server-Half) be closed unexpectedly");
    };
}
503
init_test_work_runtime(thread_num: usize) -> Runtime504 pub fn init_test_work_runtime(thread_num: usize) -> Runtime {
505 tokio::runtime::Builder::new_multi_thread()
506 .worker_threads(thread_num)
507 .enable_all()
508 .build()
509 .expect("Build runtime failed.")
510 }
511