• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 #[cfg(feature = "async")]
15 mod async_utils;
16 
17 #[cfg(feature = "sync")]
18 mod sync_utils;
19 
20 use tokio::runtime::Runtime;
21 
/// Defines, in the calling module, the handle type returned by `start_http_server!`.
///
/// * `HTTPS;` defines `TlsHandle`, which only carries the listening port.
/// * `HTTP;` defines `HttpHandle` and also brings `tokio::sync::mpsc::Receiver` and
///   `tokio::sync::mpsc::Sender` into scope at the call site.
#[macro_export]
macro_rules! define_service_handle {
    (
        HTTPS;
    ) => {
        /// Handle of a TLS test server.
        pub struct TlsHandle {
            /// Port the server listens on.
            pub port: u16,
        }
    };
    (
        HTTP;
    ) => {
        use tokio::sync::mpsc::Receiver;
        use tokio::sync::mpsc::Sender;

        /// Handle of a plain-HTTP test server.
        pub struct HttpHandle {
            /// Port the server listens on.
            pub port: u16,

            /// Server -> client: signalled once the server is up and running.
            pub server_start: Receiver<()>,

            /// Client -> server: signals that the server should shut down.
            pub client_shutdown: Sender<()>,

            /// Server -> client: signalled once the server has shut down.
            pub server_shutdown: Receiver<()>,
        }
    };
}
50 
/// Starts `ServerNum` test servers and pushes their handles into `Handles`.
///
/// Supported forms:
///
/// * `HTTPS;` with the default key/certificate `tests/file/key.pem` and
///   `tests/file/cert.pem`.
/// * `HTTPS;` with explicit `ServeKeyPath` / `ServeCrtPath`. These accept any
///   expression that `std::path::PathBuf::from` accepts (identifiers still work).
/// * `HTTP;`, which additionally waits until each server has signalled on its
///   `server_start` channel before its handle is pushed.
///
/// Each server is created from a task spawned on `Runtime`. The calling thread
/// blocks on that task (`block_on`) and then receives the handle over a
/// `std::sync::mpsc` channel. `Handles` must be an expression with a `push` method
/// (e.g. a `Vec`).
#[macro_export]
macro_rules! start_server {
    (
        HTTPS;
        ServerNum: $server_num: expr,
        Runtime: $runtime: expr,
        Handles: $handle_vec: expr,
        ServeFnName: $service_fn: ident,
    ) => {{
        // Same as the explicit-path form, with the repository's default test certificate.
        start_server!(
            HTTPS;
            ServerNum: $server_num,
            Runtime: $runtime,
            Handles: $handle_vec,
            ServeFnName: $service_fn,
            ServeKeyPath: "tests/file/key.pem",
            ServeCrtPath: "tests/file/cert.pem",
        );
    }};
    (
        HTTPS;
        ServerNum: $server_num: expr,
        Runtime: $runtime: expr,
        Handles: $handle_vec: expr,
        ServeFnName: $service_fn: ident,
        ServeKeyPath: $server_key_path: expr,
        ServeCrtPath: $server_crt_path: expr,
    ) => {{
        let key_path = std::path::PathBuf::from($server_key_path);
        let cert_path = std::path::PathBuf::from($server_crt_path);
        for _i in 0..$server_num {
            let (tx, rx) = std::sync::mpsc::channel();
            // Each spawned task needs its own owned copy of the paths.
            let key_path = key_path.clone();
            let cert_path = cert_path.clone();
            let server_handle = $runtime.spawn(async move {
                let handle = start_http_server!(
                    HTTPS;
                    $service_fn,
                    key_path,
                    cert_path
                );
                tx.send(handle)
                    .expect("Failed to send the handle to the test thread.");
            });
            $runtime
                .block_on(server_handle)
                .expect("Runtime start server coroutine failed");
            let handle = rx
                .recv()
                .expect("Handle send channel (Server-Half) be closed unexpectedly");
            $handle_vec.push(handle);
        }
    }};
    (
        HTTP;
        ServerNum: $server_num: expr,
        Runtime: $runtime: expr,
        Handles: $handle_vec: expr,
        ServeFnName: $service_fn: ident,
    ) => {{
        for _i in 0..$server_num {
            let (tx, rx) = std::sync::mpsc::channel();
            let server_handle = $runtime.spawn(async move {
                let mut handle = start_http_server!(
                    HTTP;
                    $service_fn
                );
                // Do not hand the handle out before the server reports it is running.
                handle
                    .server_start
                    .recv()
                    .await
                    .expect("Start channel (Server-Half) be closed unexpectedly");
                tx.send(handle)
                    .expect("Failed to send the handle to the test thread.");
            });
            $runtime
                .block_on(server_handle)
                .expect("Runtime start server coroutine failed");
            let handle = rx
                .recv()
                .expect("Handle send channel (Server-Half) be closed unexpectedly");
            $handle_vec.push(handle);
        }
    }};
}
152 
/// Spawns one test server on the current Tokio runtime and evaluates to its handle.
///
/// Forms:
///
/// * `HTTP; server_fn` runs a hyper server on an OS-assigned port of `127.0.0.1`
///   and evaluates to an `HttpHandle`.
/// * `HTTP; Ipv6; server_fn` does the same, bound to `[::1]` instead.
/// * `HTTPS; service_fn, key_path, cert_path` runs an OpenSSL-backed TLS server on
///   an OS-assigned port of `127.0.0.1` and evaluates to a `TlsHandle`. It accepts a
///   single TCP connection, completes the TLS handshake and serves HTTP/1.1 on it.
///   No further connections are accepted.
///
/// For the HTTP forms, the server task sends on `server_start` from inside its
/// graceful-shutdown future. It then stops after a message arrives on
/// `client_shutdown` and finally sends on `server_shutdown`.
///
/// `HttpHandle` / `TlsHandle` must be in scope at the call site (see
/// `define_service_handle!`). Every form calls `tokio::spawn`, so it must be expanded
/// within a Tokio runtime. The `HTTPS` form also uses `.await` and therefore has to
/// be expanded inside an async context.
#[macro_export]
macro_rules! start_http_server {
    (
        HTTP;
        $server_fn: ident
    ) => {
        start_http_server!(@serve_http "127.0.0.1:0"; $server_fn)
    };
    (
        HTTP;
        Ipv6;
        $server_fn: ident
    ) => {
        // IPv6 socket addresses need brackets: "[::1]:0" parses directly as a
        // `SocketAddr` without going through name resolution.
        start_http_server!(@serve_http "[::1]:0"; $server_fn)
    };
    (
        HTTPS;
        $service_fn: ident,
        $server_key_path: expr,
        $server_cert_path: expr
    ) => {{
        // Port 0 lets the OS pick a free port atomically (as the HTTP forms do).
        // Probing ports one by one is racy and would never terminate if every
        // bind attempt failed.
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
            .await
            .expect("server bind port failed !");
        let port = listener
            .local_addr()
            .expect("get server local address failed!")
            .port();

        tokio::spawn(async move {
            let mut acceptor = openssl::ssl::SslAcceptor::mozilla_intermediate(openssl::ssl::SslMethod::tls())
                .expect("SslAcceptorBuilder error");
            acceptor
                .set_session_id_context(b"test")
                .expect("Set session id error");
            acceptor
                .set_private_key_file($server_key_path, openssl::ssl::SslFiletype::PEM)
                .expect("Set private key error");
            acceptor
                .set_certificate_chain_file($server_cert_path)
                .expect("Set cert error");
            // Only HTTP/1.1 is offered and selected via ALPN.
            acceptor
                .set_alpn_protos(b"\x08http/1.1")
                .expect("Set alpn protos error");
            acceptor.set_alpn_select_callback(|_, client| {
                openssl::ssl::select_next_proto(b"\x08http/1.1", client).ok_or(openssl::ssl::AlpnError::NOACK)
            });

            let acceptor = acceptor.build();

            // Only the first incoming connection is accepted and served.
            let (stream, _) = listener.accept().await.expect("TCP listener accept error");
            let ssl = openssl::ssl::Ssl::new(acceptor.context()).expect("Ssl Error");
            let mut stream = tokio_openssl::SslStream::new(ssl, stream).expect("SslStream Error");
            core::pin::Pin::new(&mut stream)
                .accept()
                .await
                .expect("SSL handshake error");

            // The connection result becomes the task output; the JoinHandle is not kept.
            hyper::server::conn::Http::new()
                .http1_only(true)
                .http1_keep_alive(true)
                .serve_connection(stream, hyper::service::service_fn($service_fn))
                .await
        });

        TlsHandle { port }
    }};
    // Internal rule shared by the HTTP forms: binds `$bind_addr` (port 0 = OS-assigned)
    // and runs a hyper server with graceful shutdown driven by the handle's channels.
    (
        @serve_http $bind_addr: expr;
        $server_fn: ident
    ) => {{
        use hyper::service::{make_service_fn, service_fn};
        use std::convert::Infallible;
        use tokio::sync::mpsc::channel;

        let (start_tx, start_rx) = channel::<()>(1);
        let (client_tx, mut client_rx) = channel::<()>(1);
        let (server_tx, server_rx) = channel::<()>(1);

        let tcp_listener = std::net::TcpListener::bind($bind_addr).expect("server bind port failed !");
        let addr = tcp_listener.local_addr().expect("get server local address failed!");
        let port = addr.port();
        let server = hyper::Server::from_tcp(tcp_listener).expect("build hyper server from tcp listener failed !");

        tokio::spawn(async move {
            let make_svc =
                make_service_fn(|_conn| async { Ok::<_, Infallible>(service_fn($server_fn)) });
            server
                .serve(make_svc)
                .with_graceful_shutdown(async {
                    // Report "started", then keep serving until the client asks to stop.
                    start_tx
                        .send(())
                        .await
                        .expect("Start channel (Client-Half) be closed unexpectedly");
                    client_rx
                        .recv()
                        .await
                        .expect("Client channel (Client-Half) be closed unexpectedly");
                })
                .await
                .expect("Start server failed");
            server_tx
                .send(())
                .await
                .expect("Server channel (Client-Half) be closed unexpectedly");
        });

        HttpHandle {
            port,
            server_start: start_rx,
            client_shutdown: client_tx,
            server_shutdown: server_rx,
        }
    }};
}
308 
/// Builds a synchronous `ylong_http` request aimed at `Host:Port`.
///
/// `Method`, `Host` and `Port` are required. Zero or more `Header: name, value,`
/// entries are applied in the order given. `Body` is attached as a
/// `ylong_http::body::TextBody` built from its bytes.
///
/// Panics with "Request build failed" if the builder rejects the input.
#[macro_export]
#[cfg(feature = "sync")]
macro_rules! ylong_request {
    (
        Request: {
            Method: $req_method: expr,
            Host: $req_host: expr,
            Port: $req_port: expr,
            $(
                Header: $header_name: expr, $header_value: expr,
            )*
            Body: $payload: expr,
        },
    ) => {
        ylong_http::request::RequestBuilder::new()
            .method($req_method)
            .url(format!("{}:{}", $req_host, $req_port).as_str())
            $(.header($header_name, $header_value))*
            .body(ylong_http::body::TextBody::from_bytes($payload.as_bytes()))
            .expect("Request build failed")
    };
}
332 
/// Builds an asynchronous `ylong_http_client` request aimed at `Host:Port`.
///
/// `Method`, `Host` and `Port` are required. Zero or more `Header: name, value,`
/// entries are applied in the order given. `Body` is attached via
/// `ylong_http_client::async_impl::Body::slice` over its bytes.
///
/// Panics with "Request build failed" if the builder rejects the input.
#[macro_export]
#[cfg(feature = "async")]
macro_rules! ylong_request {
    (
        Request: {
            Method: $req_method: expr,
            Host: $req_host: expr,
            Port: $req_port: expr,
            $(
                Header: $header_name: expr, $header_value: expr,
            )*
            Body: $payload: expr,
        },
    ) => {
        ylong_http_client::async_impl::RequestBuilder::new()
            .method($req_method)
            .url(format!("{}:{}", $req_host, $req_port).as_str())
            $(.header($header_name, $header_value))*
            .body(ylong_http_client::async_impl::Body::slice($payload.as_bytes()))
            .expect("Request build failed")
    };
}
356 
/// Defines `async fn $server_fn_name(request: hyper::Request<hyper::Body>)`.
///
/// The function returns `Result<hyper::Response<hyper::Body>, Infallible>`. It is a
/// hyper service that validates each incoming request against the expectations
/// listed for its method and replies with the configured response.
///
/// For every `Request` / `Response` pair, a request whose method equals `Method`:
/// * must have URI `/`;
/// * must have `format!("{:?}", request.version())` equal to the `Version` given in
///   the `Response` section. It is compared with the *request*; the response itself
///   is always sent as HTTP/1.1;
/// * must carry every listed request header with exactly the given value;
/// * must have a body equal to `Body`.
///
/// It is answered with the given `Status`, response headers and body. Any failed
/// check panics with a descriptive message. A request whose method is not listed
/// panics with "Unrecognized METHOD !".
///
/// The `ASYNC` and `SYNC` forms generate the same function; both are accepted so
/// existing call sites keep working.
#[macro_export]
macro_rules! set_server_fn {
    (
        ASYNC;
        $server_fn_name: ident,
        $(Request: {
            Method: $method: expr,
            $(
                Header: $req_n: expr, $req_v: expr,
            )*
            Body: $req_body: expr,
        },
        Response: {
            Status: $status: expr,
            Version: $version: expr,
            $(
                Header: $resp_n: expr, $resp_v: expr,
            )*
            Body: $resp_body: expr,
        },)*
    ) => {
        async fn $server_fn_name(
            request: hyper::Request<hyper::Body>,
        ) -> Result<hyper::Response<hyper::Body>, std::convert::Infallible> {
            match request.method().as_str() {
                // TODO: each `Method` literal becomes a match pattern, so two requests that
                // share a Method produce duplicate arms; only the first of them can run.
                $(
                    $method => {
                        assert_eq!($method, request.method().as_str(), "Assert request method failed");
                        assert_eq!(
                            "/",
                            request.uri().to_string(),
                            "Assert request uri failed",
                        );
                        assert_eq!(
                            $version,
                            format!("{:?}", request.version()),
                            "Assert request version failed",
                        );
                        // Panic messages are only formatted on the failure path.
                        $(assert_eq!(
                            $req_v,
                            request
                                .headers()
                                .get($req_n)
                                .unwrap_or_else(|| panic!("Get request header \"{}\" failed", $req_n))
                                .to_str()
                                .unwrap_or_else(|_| panic!("Convert request header \"{}\" into string failed", $req_n)),
                            "Assert request header {} failed", $req_n,
                        );)*
                        let body = hyper::body::to_bytes(request.into_body()).await
                            .expect("Get request body failed");
                        assert_eq!($req_body.as_bytes(), body, "Assert request body failed");
                        Ok(
                            hyper::Response::builder()
                                .version(hyper::Version::HTTP_11)
                                .status($status)
                                $(.header($resp_n, $resp_v))*
                                .body($resp_body.into())
                                .expect("Build response failed")
                        )
                    },
                )*
                _ => {panic!("Unrecognized METHOD !");},
            }
        }
    };
    (
        SYNC;
        $server_fn_name: ident,
        $(Request: {
            Method: $method: expr,
            $(
                Header: $req_n: expr, $req_v: expr,
            )*
            Body: $req_body: expr,
        },
        Response: {
            Status: $status: expr,
            Version: $version: expr,
            $(
                Header: $resp_n: expr, $resp_v: expr,
            )*
            Body: $resp_body: expr,
        },)*
    ) => {
        // The server side is identical for sync and async clients.
        set_server_fn!(
            ASYNC;
            $server_fn_name,
            $(Request: {
                Method: $method,
                $(
                    Header: $req_n, $req_v,
                )*
                Body: $req_body,
            },
            Response: {
                Status: $status,
                Version: $version,
                $(
                    Header: $resp_n, $resp_v,
                )*
                Body: $resp_body,
            },)*
        );
    };
}
487 
/// Shuts down an HTTP test server and waits until it reports that it has stopped.
///
/// Sends `()` on the handle's `client_shutdown` channel, then awaits a message on its
/// `server_shutdown` channel. Uses `.await`, so it must be expanded inside an async
/// context. Panics if either channel has been closed.
#[macro_export]
macro_rules! ensure_server_shutdown {
    (ServerHandle: $server_handle:expr) => {{
        // Ask the server to stop...
        $server_handle.client_shutdown.send(()).await
            .expect("Client channel (Server-Half) be closed unexpectedly");
        // ...and block this task until it confirms.
        $server_handle.server_shutdown.recv().await
            .expect("Server channel (Server-Half) be closed unexpectedly");
    }};
}
503 
init_test_work_runtime(thread_num: usize) -> Runtime504 pub fn init_test_work_runtime(thread_num: usize) -> Runtime {
505     tokio::runtime::Builder::new_multi_thread()
506         .worker_threads(thread_num)
507         .enable_all()
508         .build()
509         .expect("Build runtime failed.")
510 }
511