• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 #[cfg(feature = "async")]
15 mod async_utils;
16 
17 #[cfg(feature = "sync")]
18 mod sync_utils;
19 
20 use tokio::runtime::Runtime;
21 
/// Defines the handle types returned by `start_http_server!`.
///
/// * `HTTP;` defines `HttpHandle` and also imports `tokio::sync::mpsc::{Receiver, Sender}` into
///   the invoking module. The handle holds the server's port plus the three channels used to
///   coordinate start-up and graceful shutdown with the server task.
/// * `HTTPS;` defines `TlsHandle`, which only carries the server's port.
macro_rules! define_service_handle {
    (
        HTTP;
    ) => {
        use tokio::sync::mpsc::{Receiver, Sender};

        // Handle to a plain-HTTP test server; built by `start_http_server!(HTTP; ..)`.
        pub struct HttpHandle {
            // Port the server is bound to on 127.0.0.1.
            pub port: u16,

            // This channel allows the server to notify the client when it is up and running.
            pub server_start: Receiver<()>,

            // This channel allows the client to notify the server when it is ready to shut down.
            pub client_shutdown: Sender<()>,

            // This channel allows the server to notify the client when it has shut down.
            pub server_shutdown: Receiver<()>,
        }
    };
    (
        HTTPS;
    ) => {
        // Handle to a TLS test server; built by `start_http_server!(HTTPS; ..)`.
        pub struct TlsHandle {
            // Port the TLS server is listening on at 127.0.0.1.
            pub port: u16,
        }
    };
}
49 
/// Starts `ServerNum` test servers on `Runtime` and appends their handles to `Handles`.
///
/// * `HTTPS;` creates each server with `start_http_server!(HTTPS; ServeFnName)` and pushes the
///   resulting `TlsHandle`.
/// * `HTTP;` creates each server with `start_http_server!(HTTP; ServeFnName)`. It waits for the
///   server's `server_start` notification before pushing the resulting `HttpHandle`.
///
/// `Runtime` is evaluated once and is only used through `spawn` and `block_on`, so either a
/// `tokio::runtime::Runtime` or a reference to one works. Each server is started with
/// `block_on`, so the macro is meant to be invoked from synchronous test code. `Handles` is
/// evaluated once per server and must support `push`.
///
/// # Panics
///
/// Panics with "Runtime start server coroutine failed" if a server-start task fails, for
/// example because it panicked while binding its listener.
#[macro_export]
macro_rules! start_server {
    (
        HTTPS;
        ServerNum: $server_num: expr,
        Runtime: $runtime: expr,
        Handles: $handle_vec: expr,
        ServeFnName: $service_fn: ident,
    ) => {{
        // Evaluate the runtime expression once instead of twice per server.
        let runtime = &$runtime;
        for _ in 0..$server_num {
            // `start_http_server!` must run inside the runtime, because it binds a Tokio
            // listener and spawns the serving task. The task's output is the handle itself, so
            // it comes back through the `JoinHandle`; no extra channel is needed.
            let server_task = runtime.spawn(async move {
                start_http_server!(
                    HTTPS;
                    $service_fn
                )
            });
            let handle = runtime
                .block_on(server_task)
                .expect("Runtime start server coroutine failed");
            $handle_vec.push(handle);
        }
    }};
    (
        HTTP;
        ServerNum: $server_num: expr,
        Runtime: $runtime: expr,
        Handles: $handle_vec: expr,
        ServeFnName: $service_fn: ident,
    ) => {{
        // Evaluate the runtime expression once instead of twice per server.
        let runtime = &$runtime;
        for _ in 0..$server_num {
            let server_task = runtime.spawn(async move {
                let mut handle = start_http_server!(
                    HTTP;
                    $service_fn
                );
                // Only hand the handle out once the server reports that it is running.
                handle
                    .server_start
                    .recv()
                    .await
                    .expect("Start channel (Server-Half) be closed unexpectedly");
                handle
            });
            let handle = runtime
                .block_on(server_task)
                .expect("Runtime start server coroutine failed");
            $handle_vec.push(handle);
        }
    }};
}
110 
/// Creates a test server for the given service function on a free `127.0.0.1` port and
/// returns its handle.
///
/// Both arms call `tokio::spawn`, so they must run within a Tokio runtime. The `HTTPS` arm
/// also uses `.await` and must therefore be expanded inside an async block.
///
/// * `HTTP;` serves every connection with hyper and returns an `HttpHandle`, so the caller
///   must have expanded `define_service_handle!(HTTP;)`.
///   - The server sends on `server_start` from inside its graceful-shutdown signal future.
///   - It stops accepting once a message arrives on `client_shutdown`.
///   - It sends on `server_shutdown` after shutdown completes.
/// * `HTTPS;` accepts a single connection and performs the TLS handshake. The key and
///   certificate are read from the relative paths `tests/file/key.pem` and
///   `tests/file/cert.pem`. It then serves that connection with HTTP/1.1 and returns a
///   `TlsHandle` (requires `define_service_handle!(HTTPS;)`).
///
/// # Panics
///
/// Panics if the listener cannot be bound, or, inside the spawned task, if the TLS setup or
/// handshake fails.
#[macro_export]
macro_rules! start_http_server {
    (
        HTTP;
        $server_fn: ident
    ) => {{
        use hyper::service::{make_service_fn, service_fn};
        use std::convert::Infallible;
        use tokio::sync::mpsc::channel;

        // Server -> client: the serve loop is running.
        let (start_tx, start_rx) = channel::<()>(1);
        // Client -> server: begin graceful shutdown.
        let (client_tx, mut client_rx) = channel::<()>(1);
        // Server -> client: graceful shutdown has completed.
        let (server_tx, server_rx) = channel::<()>(1);

        // Port 0 lets the OS pick a free port; read it back for the handle.
        let tcp_listener = std::net::TcpListener::bind("127.0.0.1:0").expect("server bind port failed !");
        let addr = tcp_listener.local_addr().expect("get server local address failed!");
        let port = addr.port();
        let server = hyper::Server::from_tcp(tcp_listener).expect("build hyper server from tcp listener failed !");

        tokio::spawn(async move {
            let make_svc =
                make_service_fn(|_conn| async { Ok::<_, Infallible>(service_fn($server_fn)) });
            server
                .serve(make_svc)
                .with_graceful_shutdown(async {
                    // The signal future only runs once the server is polled, so sending here
                    // reports start-up. It then waits for the client's shutdown request.
                    start_tx
                        .send(())
                        .await
                        .expect("Start channel (Client-Half) be closed unexpectedly");
                    client_rx
                        .recv()
                        .await
                        .expect("Client channel (Client-Half) be closed unexpectedly");
                })
                .await
                .expect("Start server failed");
            server_tx
                .send(())
                .await
                .expect("Server channel (Client-Half) be closed unexpectedly");
        });

        HttpHandle {
            port,
            server_start: start_rx,
            client_shutdown: client_tx,
            server_shutdown: server_rx,
        }
    }};
    (
        HTTPS;
        $service_fn: ident
    ) => {{
        // Let the OS choose a free port, as the `HTTP` arm does. The previous approach probed
        // upward from 10000; it never tried 65535 and could spin forever if binding kept failing.
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
            .await
            .expect("server bind port failed !");
        let port = listener
            .local_addr()
            .expect("get server local address failed!")
            .port();

        tokio::spawn(async move {
            let mut acceptor = openssl::ssl::SslAcceptor::mozilla_intermediate(openssl::ssl::SslMethod::tls())
                .expect("SslAcceptorBuilder error");
            acceptor
                .set_session_id_context(b"test")
                .expect("Set session id error");
            acceptor
                .set_private_key_file("tests/file/key.pem", openssl::ssl::SslFiletype::PEM)
                .expect("Set private key error");
            acceptor
                .set_certificate_chain_file("tests/file/cert.pem")
                .expect("Set cert error");
            let acceptor = acceptor.build();

            // Only the first incoming connection is accepted and served.
            let (stream, _) = listener.accept().await.expect("TCP listener accept error");
            let ssl = openssl::ssl::Ssl::new(acceptor.context()).expect("Ssl Error");
            let mut stream = tokio_openssl::SslStream::new(ssl, stream).expect("SslStream Error");
            // Finish the TLS handshake before handing the stream to hyper.
            core::pin::Pin::new(&mut stream)
                .accept()
                .await
                .expect("SSL handshake failed");

            // The spawned task's JoinHandle is dropped, so the connection's `Result` is
            // discarded: serving errors are not reported.
            hyper::server::conn::Http::new()
                .http1_only(true)
                .http1_keep_alive(true)
                .serve_connection(stream, hyper::service::service_fn($service_fn))
                .await
        });

        TlsHandle {
            port,
        }
    }};
}
211 
/// Creates a `ylong_http` request for the test client.
///
/// Expands to a `RequestBuilder` chain that sets the following, then returns the built request:
/// * the given `Method`;
/// * the URL formatted as `"{Host}:{Port}"`;
/// * every listed `Header`, in order;
/// * a `TextBody` built from `Body.as_bytes()`.
///
/// # Panics
///
/// Panics with "Request build failed" if building the request fails.
#[macro_export]
macro_rules! ylong_request {
    (
        Request: {
            Method: $method: expr,
            Host: $host: expr,
            Port: $port: expr,
            $(
                Header: $req_n: expr, $req_v: expr,
            )*
            Body: $req_body: expr,
        },
    ) => {
        ylong_http::request::RequestBuilder::new()
            .method($method)
            // `Host` is expected to carry the scheme, e.g. "http://127.0.0.1" — TODO confirm
            // against callers.
            .url(format!("{}:{}", $host, $port).as_str())
            $(.header($req_n, $req_v))*
            .body(ylong_http::body::TextBody::from_bytes($req_body.as_bytes()))
            .expect("Request build failed")
    };
}
234 
/// Defines `async fn $server_fn_name(request: hyper::Request<hyper::Body>)`, a hyper service.
///
/// The generated function checks each incoming request against the `Request` spec for its
/// method and answers with the paired `Response`. It dispatches on the request method. For the
/// matching spec it asserts:
/// * the request target;
/// * the HTTP version: the spec's `Version` is compared with
///   `format!("{:?}", request.version())`;
/// * every listed request header;
/// * the request body.
///
/// It then builds an HTTP/1.1 response with the spec's `Status`, headers and body.
///
/// * `ASYNC;` requires the request target to be exactly `/`.
/// * `SYNC;` requires `"{scheme}://{host}"` of the request URI to equal the spec's `Host`.
///
/// Each `Method` should appear at most once. One match arm is generated per spec, so a repeated
/// method produces an unreachable arm and only its first spec is used.
///
/// # Panics
///
/// The generated function panics when any assertion fails, when the request's method matches no
/// spec, or when the response cannot be built.
#[macro_export]
macro_rules! set_server_fn {
    (
        ASYNC;
        $server_fn_name: ident,
        $(Request: {
            Method: $method: expr,
            $(
                Header: $req_n: expr, $req_v: expr,
            )*
            Body: $req_body: expr,
        },
        Response: {
            Status: $status: expr,
            Version: $version: expr,
            $(
                Header: $resp_n: expr, $resp_v: expr,
            )*
            Body: $resp_body: expr,
        },)*
    ) => {
        async fn $server_fn_name(request: hyper::Request<hyper::Body>) -> Result<hyper::Response<hyper::Body>, std::convert::Infallible> {
            // One arm per spec, selected by method; the match itself checks the method.
            match request.method().as_str() {
                $(
                    $method => {
                        assert_eq!(
                            "/",
                            request.uri().to_string(),
                            "Assert request uri failed",
                        );
                        assert_eq!(
                            $version,
                            format!("{:?}", request.version()),
                            "Assert request version failed",
                        );
                        // Build panic messages lazily, so nothing is formatted unless the check fails.
                        $(assert_eq!(
                            $req_v,
                            request
                                .headers()
                                .get($req_n)
                                .unwrap_or_else(|| panic!("Get request header \"{}\" failed", $req_n))
                                .to_str()
                                .unwrap_or_else(|e| panic!("Convert request header \"{}\" into string failed: {:?}", $req_n, e)),
                            "Assert request header {} failed", $req_n,
                        );)*
                        let body = hyper::body::to_bytes(request.into_body()).await
                            .expect("Get request body failed");
                        assert_eq!($req_body.as_bytes(), body, "Assert request body failed");
                        Ok(
                            hyper::Response::builder()
                                .version(hyper::Version::HTTP_11)
                                .status($status)
                                $(.header($resp_n, $resp_v))*
                                .body($resp_body.into())
                                .expect("Build response failed")
                        )
                    },
                )*
                _ => {panic!("Unrecognized METHOD !");},
            }
        }

    };
    (
        SYNC;
        $server_fn_name: ident,
        $(Request: {
            Method: $method: expr,
            Host: $host: expr,
            $(
                Header: $req_n: expr, $req_v: expr,
            )*
            Body: $req_body: expr,
        },
        Response: {
            Status: $status: expr,
            Version: $version: expr,
            $(
                Header: $resp_n: expr, $resp_v: expr,
            )*
            Body: $resp_body: expr,
        },)*
    ) => {
        async fn $server_fn_name(request: hyper::Request<hyper::Body>) -> Result<hyper::Response<hyper::Body>, std::convert::Infallible> {
            // One arm per spec, selected by method; the match itself checks the method.
            match request.method().as_str() {
                $(
                    $method => {
                        assert_eq!(
                            $host,
                            format!("{}://{}", request.uri().scheme().expect("assert uri scheme failed !").as_str(), request.uri().host().expect("assert uri host failed !")),
                            "Assert request host failed",
                        );
                        assert_eq!(
                            $version,
                            format!("{:?}", request.version()),
                            "Assert request version failed",
                        );
                        // Build panic messages lazily, so nothing is formatted unless the check fails.
                        $(assert_eq!(
                            $req_v,
                            request
                                .headers()
                                .get($req_n)
                                .unwrap_or_else(|| panic!("Get request header \"{}\" failed", $req_n))
                                .to_str()
                                .unwrap_or_else(|e| panic!("Convert request header \"{}\" into string failed: {:?}", $req_n, e)),
                            "Assert request header {} failed", $req_n,
                        );)*
                        let body = hyper::body::to_bytes(request.into_body()).await
                            .expect("Get request body failed");
                        assert_eq!($req_body.as_bytes(), body, "Assert request body failed");
                        Ok(
                            hyper::Response::builder()
                                .version(hyper::Version::HTTP_11)
                                .status($status)
                                $(.header($resp_n, $resp_v))*
                                .body($resp_body.into())
                                .expect("Build response failed")
                        )
                    },
                )*
                _ => {panic!("Unrecognized METHOD !");},
            }
        }

    };
}
368 
/// Asks the server behind an `HttpHandle` to shut down and waits until it reports that it has.
///
/// The macro sends on the handle's `client_shutdown` channel, then awaits `server_shutdown`, so
/// it must be expanded inside an async context. `ServerHandle` is evaluated twice and must allow
/// mutable access to its `server_shutdown` receiver. Pass a place expression such as a
/// (mutable) local binding, not an expression with side effects.
///
/// # Panics
///
/// Panics if either channel turns out to be closed.
#[macro_export]
macro_rules! ensure_server_shutdown {
    (ServerHandle: $server_handle:expr) => {{
        // Trigger the server's graceful-shutdown signal ...
        $server_handle.client_shutdown.send(()).await.expect("Client channel (Server-Half) be closed unexpectedly");
        // ... and wait for the serving task to confirm it has stopped.
        $server_handle.server_shutdown.recv().await.expect("Server channel (Server-Half) be closed unexpectedly");
    }};
}
384 
init_test_work_runtime(thread_num: usize) -> Runtime385 pub fn init_test_work_runtime(thread_num: usize) -> Runtime {
386     tokio::runtime::Builder::new_multi_thread()
387         .worker_threads(thread_num)
388         .enable_all()
389         .build()
390         .expect("Build runtime failed.")
391 }
392