• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 #[cfg(feature = "async")]
15 mod async_utils;
16 
17 #[cfg(feature = "sync")]
18 mod sync_utils;
19 
20 use tokio::runtime::Runtime;
21 
/// Declares, in the invoking module, the handle type that `start_http_server!` builds.
///
/// * `HTTP;`  — declares `HttpHandle` and brings `tokio::sync::mpsc::{Receiver, Sender}`
///   into the invoking module's scope.
/// * `HTTPS;` — declares `TlsHandle`, which only records the bound port.
macro_rules! define_service_handle {
    (HTTP;) => {
        use tokio::sync::mpsc::Receiver;
        use tokio::sync::mpsc::Sender;

        /// Handle to a running plain-HTTP test server.
        pub struct HttpHandle {
            /// Local port the server is listening on.
            pub port: u16,
            /// Server -> client: signalled once the server is up and running.
            pub server_start: Receiver<()>,
            /// Client -> server: signalled when the client is ready for the server to stop.
            pub client_shutdown: Sender<()>,
            /// Server -> client: signalled once the server has shut down.
            pub server_shutdown: Receiver<()>,
        }
    };
    (HTTPS;) => {
        /// Handle to a running TLS test server.
        pub struct TlsHandle {
            /// Local port the server is listening on.
            pub port: u16,
        }
    };
}
49 
/// Starts `ServerNum` local test servers on `Runtime` and pushes their handles into `Handles`.
///
/// For every server, a start-up task is spawned on the runtime. The calling thread blocks
/// until that task completes, then receives the handle over a `std::sync::mpsc` channel and
/// pushes it into `Handles`.
///
/// * `HTTPS` without paths: uses `tests/file/key.pem` and `tests/file/cert.pem`.
/// * `HTTPS` with `ServeKeyPath` / `ServeCrtPath`: uses the given paths. Any expression
///   accepted by `std::path::PathBuf::from` works, including a plain variable name.
/// * `HTTP`: additionally waits for the server's "started" signal before the handle is
///   handed back.
///
/// `Runtime` is evaluated once and borrowed for the whole expansion.
///
/// # Panics
///
/// Panics if a start-up task fails or its handle cannot be delivered to the calling thread.
#[macro_export]
macro_rules! start_server {
    (
        HTTPS;
        ServerNum: $server_num: expr,
        Runtime: $runtime: expr,
        Handles: $handle_vec: expr,
        ServeFnName: $service_fn: ident,
    ) => {
        // Identical to the explicit-path arm, using the default test key and certificate.
        $crate::start_server!(
            HTTPS;
            ServerNum: $server_num,
            Runtime: $runtime,
            Handles: $handle_vec,
            ServeFnName: $service_fn,
            ServeKeyPath: "tests/file/key.pem",
            ServeCrtPath: "tests/file/cert.pem",
        )
    };
    (
        HTTPS;
        ServerNum: $server_num: expr,
        Runtime: $runtime: expr,
        Handles: $handle_vec: expr,
        ServeFnName: $service_fn: ident,
        ServeKeyPath: $server_key_path: expr,
        ServeCrtPath: $server_crt_path: expr,
    ) => {{
        // Evaluate the runtime expression once instead of on every spawn/block_on.
        let runtime = &$runtime;
        let key_path = std::path::PathBuf::from($server_key_path);
        let cert_path = std::path::PathBuf::from($server_crt_path);
        for _ in 0..$server_num {
            let (tx, rx) = std::sync::mpsc::channel();
            // Each server task takes ownership of its own copy of the paths.
            let key_path = key_path.clone();
            let cert_path = cert_path.clone();
            let server_handle = runtime.spawn(async move {
                let handle = start_http_server!(
                    HTTPS;
                    $service_fn,
                    key_path,
                    cert_path
                );
                tx.send(handle)
                    .expect("Failed to send the handle to the test thread.");
            });
            runtime
                .block_on(server_handle)
                .expect("Runtime start server coroutine failed");
            let handle = rx
                .recv()
                .expect("Handle send channel (Server-Half) be closed unexpectedly");
            $handle_vec.push(handle);
        }
    }};
    (
        HTTP;
        ServerNum: $server_num: expr,
        Runtime: $runtime: expr,
        Handles: $handle_vec: expr,
        ServeFnName: $service_fn: ident,
    ) => {{
        // Evaluate the runtime expression once instead of on every spawn/block_on.
        let runtime = &$runtime;
        for _ in 0..$server_num {
            let (tx, rx) = std::sync::mpsc::channel();
            let server_handle = runtime.spawn(async move {
                let mut handle = start_http_server!(
                    HTTP;
                    $service_fn
                );
                // Do not hand the handle out until the server reports it is running.
                handle
                    .server_start
                    .recv()
                    .await
                    .expect("Start channel (Server-Half) be closed unexpectedly");
                tx.send(handle)
                    .expect("Failed to send the handle to the test thread.");
            });
            runtime
                .block_on(server_handle)
                .expect("Runtime start server coroutine failed");
            let handle = rx
                .recv()
                .expect("Handle send channel (Server-Half) be closed unexpectedly");
            $handle_vec.push(handle);
        }
    }};
}
151 
/// Starts a single local test server and evaluates to a handle describing it.
///
/// * `HTTP; serve_fn`
///   - Binds a `std::net::TcpListener` on `127.0.0.1:0` and serves every connection with
///     `serve_fn` through `hyper::Server`.
///   - Evaluates to an `HttpHandle` whose channels coordinate start-up and graceful shutdown.
/// * `HTTPS; serve_fn, key_path, cert_path`
///   - Binds a `tokio::net::TcpListener` on `127.0.0.1:0` and configures an OpenSSL acceptor
///     from the PEM private key and certificate chain.
///   - Accepts exactly one connection and serves it over HTTP/1.1 with `serve_fn`.
///   - Evaluates to a `TlsHandle` carrying the bound port. Must be expanded inside an
///     `async` context because it awaits the bind.
///
/// Both arms call `tokio::spawn`, so they must be expanded inside a Tokio runtime context.
/// The handle types must be in scope at the expansion site.
#[macro_export]
macro_rules! start_http_server {
    (
        HTTP;
        $server_fn: ident
    ) => {{
        use hyper::service::{make_service_fn, service_fn};
        use std::convert::Infallible;
        use tokio::sync::mpsc::channel;

        // Signalling channels: server started, client requests shutdown, server has stopped.
        let (start_tx, start_rx) = channel::<()>(1);
        let (client_tx, mut client_rx) = channel::<()>(1);
        let (server_tx, server_rx) = channel::<()>(1);

        // Port 0 lets the OS choose a free port; the chosen port is reported in the handle.
        let tcp_listener = std::net::TcpListener::bind("127.0.0.1:0").expect("server bind port failed !");
        let addr = tcp_listener.local_addr().expect("get server local address failed!");
        let port = addr.port();
        let server = hyper::Server::from_tcp(tcp_listener).expect("build hyper server from tcp listener failed !");

        tokio::spawn(async move {
            let make_svc =
                make_service_fn(|_conn| async { Ok::<_, Infallible>(service_fn($server_fn)) });
            server
                .serve(make_svc)
                .with_graceful_shutdown(async {
                    // The shutdown future is presumably first polled once the server is being
                    // driven, so its first step doubles as the "server started" notification.
                    start_tx
                        .send(())
                        .await
                        .expect("Start channel (Client-Half) be closed unexpectedly");
                    client_rx
                        .recv()
                        .await
                        .expect("Client channel (Client-Half) be closed unexpectedly");
                })
                .await
                .expect("Start server failed");
            server_tx
                .send(())
                .await
                .expect("Server channel (Client-Half) be closed unexpectedly");
        });

        HttpHandle {
            port,
            server_start: start_rx,
            client_shutdown: client_tx,
            server_shutdown: server_rx,
        }
    }};
    (
        HTTPS;
        $service_fn: ident,
        $server_key_path: expr,
        $server_cert_path: expr
    ) => {{
        // Let the OS pick a free port, the same strategy as the HTTP arm. The previous
        // upward probe from port 10000 retried on every bind error. It could spin forever
        // when no port was available, and it never tried port 65535.
        let listener = tokio::net::TcpListener::bind(std::net::SocketAddr::from(([127, 0, 0, 1], 0)))
            .await
            .expect("server bind port failed !");
        let port = listener
            .local_addr()
            .expect("get server local address failed!")
            .port();

        tokio::spawn(async move {
            let mut acceptor = openssl::ssl::SslAcceptor::mozilla_intermediate(openssl::ssl::SslMethod::tls())
                .expect("SslAcceptorBuilder error");
            acceptor
                .set_session_id_context(b"test")
                .expect("Set session id error");
            acceptor
                .set_private_key_file($server_key_path, openssl::ssl::SslFiletype::PEM)
                .expect("Set private key error");
            acceptor
                .set_certificate_chain_file($server_cert_path)
                .expect("Set cert error");
            // Advertise and select only HTTP/1.1 via ALPN.
            acceptor.set_alpn_protos(b"\x08http/1.1").unwrap();
            acceptor.set_alpn_select_callback(|_, client| {
                openssl::ssl::select_next_proto(b"\x08http/1.1", client).ok_or(openssl::ssl::AlpnError::NOACK)
            });

            let acceptor = acceptor.build();

            // Only a single connection is accepted and served by this task.
            let (stream, _) = listener.accept().await.expect("TCP listener accept error");
            let ssl = openssl::ssl::Ssl::new(acceptor.context()).expect("Ssl Error");
            let mut stream = tokio_openssl::SslStream::new(ssl, stream).expect("SslStream Error");
            core::pin::Pin::new(&mut stream).accept().await.unwrap(); // SSL negotiation finished successfully

            // The connection result becomes the task's output. The JoinHandle is not kept,
            // so it is not observed.
            hyper::server::conn::Http::new()
                .http1_only(true)
                .http1_keep_alive(true)
                .serve_connection(stream, hyper::service::service_fn($service_fn))
                .await
        });

        TlsHandle {
            port,
        }
    }};
}
259 
/// Creates a `Request` for the blocking (`sync`) client tests.
///
/// Expands to a `ylong_http::request::RequestBuilder` chain:
/// * the URL is `"{Host}:{Port}"`;
/// * every `Header` pair is added in the order given;
/// * `Body` (anything with an `as_bytes()` method) becomes a `ylong_http::body::TextBody`.
///
/// NOTE(review): the `async` variant of this file also exports a macro named
/// `ylong_request`. Enabling both features at once would presumably clash — confirm.
///
/// # Panics
///
/// Panics with `"Request build failed"` if the builder rejects the request.
#[macro_export]
#[cfg(feature = "sync")]
macro_rules! ylong_request {
    (
        Request: {
            Method: $verb: expr,
            Host: $host_part: expr,
            Port: $port_part: expr,
            $(
                Header: $hdr_name: expr, $hdr_value: expr,
            )*
            Body: $payload: expr,
        },
    ) => {
        ylong_http::request::RequestBuilder::new()
            .method($verb)
            .url(format!("{}:{}", $host_part, $port_part).as_str())
            $(.header($hdr_name, $hdr_value))*
            .body(ylong_http::body::TextBody::from_bytes($payload.as_bytes()))
            .expect("Request build failed")
    };
}
283 
/// Creates a `Request` for the asynchronous (`async`) client tests.
///
/// Expands to a `ylong_http_client::async_impl::RequestBuilder` chain:
/// * the URL is `"{Host}:{Port}"`;
/// * every `Header` pair is added in the order given;
/// * `Body` (anything with an `as_bytes()` method) is wrapped with
///   `ylong_http_client::async_impl::Body::slice`.
///
/// NOTE(review): the `sync` variant of this file also exports a macro named
/// `ylong_request`. Enabling both features at once would presumably clash — confirm.
///
/// # Panics
///
/// Panics with `"Request build failed"` if the builder rejects the request.
#[macro_export]
#[cfg(feature = "async")]
macro_rules! ylong_request {
    (
        Request: {
            Method: $verb: expr,
            Host: $host_part: expr,
            Port: $port_part: expr,
            $(
                Header: $hdr_name: expr, $hdr_value: expr,
            )*
            Body: $payload: expr,
        },
    ) => {
        ylong_http_client::async_impl::RequestBuilder::new()
            .method($verb)
            .url(format!("{}:{}", $host_part, $port_part).as_str())
            $(.header($hdr_name, $hdr_value))*
            .body(ylong_http_client::async_impl::Body::slice($payload.as_bytes()))
            .expect("Request build failed")
    };
}
307 
/// Sets server async function.
///
/// Defines `async fn $server_fn_name(hyper::Request<hyper::Body>)`, returning
/// `Result<hyper::Response<hyper::Body>, std::convert::Infallible>`, for use with
/// `hyper::service::service_fn`.
///
/// For each `Request`/`Response` pair, the generated handler:
/// 1. dispatches on the request method; `Method` is used as a match pattern, so it must be
///    a string literal such as `"GET"`;
/// 2. asserts that:
///    * the request URI is `"/"`;
///    * the request version, formatted with `{:?}`, equals `Version`;
///    * every listed request `Header` has the given value;
///    * the request body equals `Body`;
/// 3. replies with an HTTP/1.1 response carrying `Status`, the listed response headers and
///    the response `Body`.
///
/// Requests whose method is not listed panic with `"Unrecognized METHOD !"`.
///
/// The `ASYNC` and `SYNC` arms generate identical handlers.
///
/// NOTE(review): `Version` sits in the `Response` section but is compared against the
/// *request* version. The response itself is always built as HTTP/1.1.
///
/// # Panics
///
/// The generated handler panics (failing the test) on any assertion mismatch above.
#[macro_export]
macro_rules! set_server_fn {
    (
        ASYNC;
        $server_fn_name: ident,
        $(Request: {
            Method: $method: expr,
            $(
                Header: $req_n: expr, $req_v: expr,
            )*
            Body: $req_body: expr,
        },
        Response: {
            Status: $status: expr,
            Version: $version: expr,
            $(
                Header: $resp_n: expr, $resp_v: expr,
            )*
            Body: $resp_body: expr,
        },)*
    ) => {
        async fn $server_fn_name(request: hyper::Request<hyper::Body>) -> Result<hyper::Response<hyper::Body>, std::convert::Infallible> {
            match request.method().as_str() {
                // Each method may appear in at most one pair. A repeated method produces a
                // second, unreachable match arm, which is rejected wherever warnings are denied.
                $(
                    $method => {
                        assert_eq!($method, request.method().as_str(), "Assert request method failed");
                        assert_eq!(
                            "/",
                            request.uri().to_string(),
                            "Assert request uri failed",
                        );
                        assert_eq!(
                            $version,
                            format!("{:?}", request.version()),
                            "Assert request version failed",
                        );
                        // Panic messages are only formatted on failure (no per-call allocation).
                        $(assert_eq!(
                            $req_v,
                            request
                                .headers()
                                .get($req_n)
                                .unwrap_or_else(|| panic!("Get request header \"{}\" failed", $req_n))
                                .to_str()
                                .unwrap_or_else(|e| panic!("Convert request header \"{}\" into string failed: {:?}", $req_n, e)),
                            "Assert request header {} failed", $req_n,
                        );)*
                        let body = hyper::body::to_bytes(request.into_body()).await
                            .expect("Get request body failed");
                        assert_eq!($req_body.as_bytes(), body, "Assert request body failed");
                        Ok(
                            hyper::Response::builder()
                                .version(hyper::Version::HTTP_11)
                                .status($status)
                                $(.header($resp_n, $resp_v))*
                                .body($resp_body.into())
                                .expect("Build response failed")
                        )
                    },
                )*
                _ => {panic!("Unrecognized METHOD !");},
            }
        }

    };
    (
        SYNC;
        $server_fn_name: ident,
        $(Request: {
            Method: $method: expr,
            $(
                Header: $req_n: expr, $req_v: expr,
            )*
            Body: $req_body: expr,
        },
        Response: {
            Status: $status: expr,
            Version: $version: expr,
            $(
                Header: $resp_n: expr, $resp_v: expr,
            )*
            Body: $resp_body: expr,
        },)*
    ) => {
        async fn $server_fn_name(request: hyper::Request<hyper::Body>) -> Result<hyper::Response<hyper::Body>, std::convert::Infallible> {
            match request.method().as_str() {
                // Each method may appear in at most one pair. A repeated method produces a
                // second, unreachable match arm, which is rejected wherever warnings are denied.
                $(
                    $method => {
                        assert_eq!($method, request.method().as_str(), "Assert request method failed");
                        assert_eq!(
                            "/",
                            request.uri().to_string(),
                            "Assert request uri failed",
                        );
                        assert_eq!(
                            $version,
                            format!("{:?}", request.version()),
                            "Assert request version failed",
                        );
                        // Panic messages are only formatted on failure (no per-call allocation).
                        $(assert_eq!(
                            $req_v,
                            request
                                .headers()
                                .get($req_n)
                                .unwrap_or_else(|| panic!("Get request header \"{}\" failed", $req_n))
                                .to_str()
                                .unwrap_or_else(|e| panic!("Convert request header \"{}\" into string failed: {:?}", $req_n, e)),
                            "Assert request header {} failed", $req_n,
                        );)*
                        let body = hyper::body::to_bytes(request.into_body()).await
                            .expect("Get request body failed");
                        assert_eq!($req_body.as_bytes(), body, "Assert request body failed");
                        Ok(
                            hyper::Response::builder()
                                .version(hyper::Version::HTTP_11)
                                .status($status)
                                $(.header($resp_n, $resp_v))*
                                .body($resp_body.into())
                                .expect("Build response failed")
                        )
                    },
                )*
                _ => {panic!("Unrecognized METHOD !");},
            }
        }

    };
}
438 
/// Gracefully stops an HTTP test server and waits until it has finished.
///
/// First signals `client_shutdown` on the handle, then awaits the confirmation on
/// `server_shutdown`. Expands to two `.await` statements, so it must be used inside an
/// `async` context.
///
/// NOTE(review): `ServerHandle` is expanded twice. Pass a place expression such as a local
/// variable; an expression with side effects (e.g. `handles.pop().unwrap()`) would run twice.
///
/// # Panics
///
/// Panics if either channel has been closed.
#[macro_export]
macro_rules! ensure_server_shutdown {
    (ServerHandle: $server:expr) => {
        $server.client_shutdown.send(()).await
            .expect("Client channel (Server-Half) be closed unexpectedly");
        $server.server_shutdown.recv().await
            .expect("Server channel (Server-Half) be closed unexpectedly");
    };
}
454 
init_test_work_runtime(thread_num: usize) -> Runtime455 pub fn init_test_work_runtime(thread_num: usize) -> Runtime {
456     tokio::runtime::Builder::new_multi_thread()
457         .worker_threads(thread_num)
458         .enable_all()
459         .build()
460         .expect("Build runtime failed.")
461 }
462