1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 // http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13
14 #[cfg(feature = "async")]
15 mod async_utils;
16
17 #[cfg(feature = "sync")]
18 mod sync_utils;
19
20 use tokio::runtime::Runtime;
21
/// Declares the handle type that tests use to reach a locally started server.
///
/// * `HTTPS;` declares `TlsHandle`, which only records the listening port.
/// * `HTTP;` declares `HttpHandle`, which records the listening port plus the three channels
///   used to coordinate start-up and shutdown. This arm also brings
///   `tokio::sync::mpsc::{Receiver, Sender}` into the invoking scope.
macro_rules! define_service_handle {
    (HTTPS;) => {
        /// Handle to a TLS test server.
        pub struct TlsHandle {
            /// Local port the server listens on.
            pub port: u16,
        }
    };
    (HTTP;) => {
        use tokio::sync::mpsc::{Receiver, Sender};

        /// Handle to a plain-HTTP test server.
        pub struct HttpHandle {
            /// Local port the server listens on.
            pub port: u16,
            /// Lets the server tell the client that it is up and running.
            pub server_start: Receiver<()>,
            /// Lets the client tell the server that it is ready to shut down.
            pub client_shutdown: Sender<()>,
            /// Lets the server tell the client that it has shut down.
            pub server_shutdown: Receiver<()>,
        }
    };
}
49
/// Starts `ServerNum` test servers on `Runtime` and pushes each returned handle into `Handles`.
///
/// Every server is created by `start_http_server!` inside a task spawned on the runtime. This
/// macro blocks on that task before receiving the handle, so the servers are started one after
/// another.
///
/// Arms:
/// * `HTTPS;` without paths uses the default key/certificate `tests/file/key.pem` and
///   `tests/file/cert.pem`.
/// * `HTTPS;` with `ServeKeyPath` / `ServeCrtPath` uses the given paths. Any expression accepted
///   by `std::path::PathBuf::from` works (string literal, `&str`, `String`, `PathBuf`, ...).
/// * `HTTP;` additionally waits on the handle's `server_start` channel before handing it back.
///
/// # Panics
/// Panics if the spawned start-up task fails, if the server's start channel closes, or if the
/// handle cannot be passed back to the calling thread.
#[macro_export]
macro_rules! start_server {
    (
        HTTPS;
        ServerNum: $server_num: expr,
        Runtime: $runtime: expr,
        Handles: $handle_vec: expr,
        ServeFnName: $service_fn: ident,
    ) => {{
        // The default arm is just the explicit-path arm with the repository's test key/cert.
        start_server!(
            HTTPS;
            ServerNum: $server_num,
            Runtime: $runtime,
            Handles: $handle_vec,
            ServeFnName: $service_fn,
            ServeKeyPath: "tests/file/key.pem",
            ServeCrtPath: "tests/file/cert.pem",
        )
    }};
    (
        HTTPS;
        ServerNum: $server_num: expr,
        Runtime: $runtime: expr,
        Handles: $handle_vec: expr,
        ServeFnName: $service_fn: ident,
        ServeKeyPath: $server_key_path: expr,
        ServeCrtPath: $server_crt_path: expr,
    ) => {{
        // Evaluate the runtime expression once and reuse it for every server.
        let runtime = &$runtime;
        let key_path = std::path::PathBuf::from($server_key_path);
        let cert_path = std::path::PathBuf::from($server_crt_path);
        for _ in 0..$server_num {
            let (tx, rx) = std::sync::mpsc::channel();
            // Each spawned task takes ownership of its own copy of the paths.
            let key_path = key_path.clone();
            let cert_path = cert_path.clone();
            let server_handle = runtime.spawn(async move {
                let handle = start_http_server!(HTTPS; $service_fn, key_path, cert_path);
                tx.send(handle)
                    .expect("Failed to send the handle to the test thread.");
            });
            runtime
                .block_on(server_handle)
                .expect("Runtime start server coroutine failed");
            let handle = rx
                .recv()
                .expect("Handle send channel (Server-Half) be closed unexpectedly");
            $handle_vec.push(handle);
        }
    }};
    (
        HTTP;
        ServerNum: $server_num: expr,
        Runtime: $runtime: expr,
        Handles: $handle_vec: expr,
        ServeFnName: $service_fn: ident,
    ) => {{
        // Evaluate the runtime expression once and reuse it for every server.
        let runtime = &$runtime;
        for _ in 0..$server_num {
            let (tx, rx) = std::sync::mpsc::channel();
            let server_handle = runtime.spawn(async move {
                let mut handle = start_http_server!(HTTP; $service_fn);
                // Only hand the handle back once the server reports that it is running.
                handle
                    .server_start
                    .recv()
                    .await
                    .expect("Start channel (Server-Half) be closed unexpectedly");
                tx.send(handle)
                    .expect("Failed to send the handle to the test thread.");
            });
            runtime
                .block_on(server_handle)
                .expect("Runtime start server coroutine failed");
            let handle = rx
                .recv()
                .expect("Handle send channel (Server-Half) be closed unexpectedly");
            $handle_vec.push(handle);
        }
    }};
}
151
/// Starts one local test server and returns its handle.
///
/// Both arms call `tokio::spawn`, so they must run inside a Tokio runtime. The `HTTPS;` arm also
/// `.await`s, so it must be used inside an async context.
///
/// * `HTTP; service_fn` binds `127.0.0.1` on an OS-assigned port and serves connections with
///   hyper using `service_fn`. It returns an `HttpHandle`. The graceful-shutdown future first
///   sends on the `server_start` channel and then waits for a message on `client_shutdown`.
///   After the server stops, it sends on `server_shutdown`.
/// * `HTTPS; service_fn, key_path, cert_path` binds `127.0.0.1` on an OS-assigned port and
///   accepts exactly one TCP connection. It performs the TLS handshake with the given PEM
///   private key and certificate chain (ALPN selects only `http/1.1`), then serves that
///   connection as HTTP/1.1 with keep-alive. It returns a `TlsHandle`.
///
/// # Panics
/// Panics (in the caller or in the spawned task) if binding, TLS setup, the handshake, or any
/// coordination channel fails.
#[macro_export]
macro_rules! start_http_server {
    (
        HTTP;
        $server_fn: ident
    ) => {{
        use hyper::service::{make_service_fn, service_fn};
        use std::convert::Infallible;
        use tokio::sync::mpsc::channel;

        let (start_tx, start_rx) = channel::<()>(1);
        let (client_tx, mut client_rx) = channel::<()>(1);
        let (server_tx, server_rx) = channel::<()>(1);

        // Port 0 lets the OS pick a free port, so concurrently running tests do not collide.
        let tcp_listener =
            std::net::TcpListener::bind("127.0.0.1:0").expect("server bind port failed !");
        let addr = tcp_listener
            .local_addr()
            .expect("get server local address failed!");
        let port = addr.port();
        let server = hyper::Server::from_tcp(tcp_listener)
            .expect("build hyper server from tcp listener failed !");

        tokio::spawn(async move {
            let make_svc =
                make_service_fn(|_conn| async { Ok::<_, Infallible>(service_fn($server_fn)) });
            server
                .serve(make_svc)
                .with_graceful_shutdown(async {
                    // Report start-up, then wait for the client's shutdown request.
                    start_tx
                        .send(())
                        .await
                        .expect("Start channel (Client-Half) be closed unexpectedly");
                    client_rx
                        .recv()
                        .await
                        .expect("Client channel (Client-Half) be closed unexpectedly");
                })
                .await
                .expect("Start server failed");
            // Tell the client the server has finished shutting down.
            server_tx
                .send(())
                .await
                .expect("Server channel (Client-Half) be closed unexpectedly");
        });

        HttpHandle {
            port,
            server_start: start_rx,
            client_shutdown: client_tx,
            server_shutdown: server_rx,
        }
    }};
    (
        HTTPS;
        $service_fn: ident,
        $server_key_path: expr,
        $server_cert_path: expr
    ) => {{
        // Let the OS assign a free port, as the HTTP arm does, instead of probing fixed ports:
        // probing races with other processes and never terminates if binding keeps failing.
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
            .await
            .expect("server bind port failed !");
        let port = listener
            .local_addr()
            .expect("get server local address failed!")
            .port();

        tokio::spawn(async move {
            let mut acceptor =
                openssl::ssl::SslAcceptor::mozilla_intermediate(openssl::ssl::SslMethod::tls())
                    .expect("SslAcceptorBuilder error");
            acceptor
                .set_session_id_context(b"test")
                .expect("Set session id error");
            acceptor
                .set_private_key_file($server_key_path, openssl::ssl::SslFiletype::PEM)
                .expect("Set private key error");
            acceptor
                .set_certificate_chain_file($server_cert_path)
                .expect("Set cert error");
            // ALPN: only `http/1.1` is ever selected from the client's offer.
            acceptor
                .set_alpn_protos(b"\x08http/1.1")
                .expect("Set ALPN protos error");
            acceptor.set_alpn_select_callback(|_, client| {
                openssl::ssl::select_next_proto(b"\x08http/1.1", client)
                    .ok_or(openssl::ssl::AlpnError::NOACK)
            });
            let acceptor = acceptor.build();

            // This task accepts and serves a single connection only.
            let (stream, _) = listener.accept().await.expect("TCP listener accept error");
            let ssl = openssl::ssl::Ssl::new(acceptor.context()).expect("Ssl Error");
            let mut stream = tokio_openssl::SslStream::new(ssl, stream).expect("SslStream Error");
            core::pin::Pin::new(&mut stream)
                .accept()
                .await
                .expect("TLS handshake failed");

            hyper::server::conn::Http::new()
                .http1_only(true)
                .http1_keep_alive(true)
                .serve_connection(stream, hyper::service::service_fn($service_fn))
                .await
        });

        TlsHandle { port }
    }};
}
259
/// Creates a `Request`.
///
/// Builds a `ylong_http` request for the URL `"{Host}:{Port}"` with the given method, each
/// listed header in order, and `Body` (anything exposing `as_bytes()`) wrapped in a
/// `TextBody`.
///
/// # Panics
/// Panics with "Request build failed" if the builder rejects the request.
#[macro_export]
#[cfg(feature = "sync")]
macro_rules! ylong_request {
    (
        Request: {
            Method: $method: expr,
            Host: $host: expr,
            Port: $port: expr,
            $(Header: $req_n: expr, $req_v: expr,)*
            Body: $req_body: expr,
        },
    ) => {
        ylong_http::request::RequestBuilder::new()
            .method($method)
            .url(format!("{host}:{port}", host = $host, port = $port).as_str())
            $(.header($req_n, $req_v))*
            .body(ylong_http::body::TextBody::from_bytes($req_body.as_bytes()))
            .expect("Request build failed")
    };
}
283
/// Creates a `Request`.
///
/// Builds an async `ylong_http_client` request for the URL `"{Host}:{Port}"` with the given
/// method, each listed header in order, and `Body` (anything exposing `as_bytes()`) as a
/// sliced body.
///
/// # Panics
/// Panics with "Request build failed" if the builder rejects the request.
#[macro_export]
#[cfg(feature = "async")]
macro_rules! ylong_request {
    (
        Request: {
            Method: $method: expr,
            Host: $host: expr,
            Port: $port: expr,
            $(Header: $req_n: expr, $req_v: expr,)*
            Body: $req_body: expr,
        },
    ) => {
        ylong_http_client::async_impl::RequestBuilder::new()
            .method($method)
            .url(format!("{host}:{port}", host = $host, port = $port).as_str())
            $(.header($req_n, $req_v))*
            .body(ylong_http_client::async_impl::Body::slice($req_body.as_bytes()))
            .expect("Request build failed")
    };
}
307
/// Sets server async function.
///
/// Generates `async fn $server_fn_name(request: hyper::Request<hyper::Body>) ->
/// Result<hyper::Response<hyper::Body>, std::convert::Infallible>`.
///
/// Each `Request`/`Response` pair becomes a `match` arm keyed on the request method. That arm:
/// * asserts the request's method;
/// * asserts that its URI renders as `/`;
/// * asserts its version (compared through its `Debug` form against `Version`);
/// * asserts each listed header and the request body.
///
/// It then replies with an HTTP/1.1 response carrying the given status, headers and body. A
/// request with any other method panics.
///
/// `ASYNC;` and `SYNC;` generate the same function and are both kept for existing callers.
///
/// Note: two pairs with the same `Method` produce duplicate match arms, so the later one is
/// unreachable.
#[macro_export]
macro_rules! set_server_fn {
    (ASYNC; $($input: tt)*) => {
        set_server_fn!(@impl $($input)*);
    };
    (SYNC; $($input: tt)*) => {
        set_server_fn!(@impl $($input)*);
    };
    // Internal arm shared by `ASYNC;` and `SYNC;`.
    (
        @impl
        $server_fn_name: ident,
        $(Request: {
            Method: $method: expr,
            $(
                Header: $req_n: expr, $req_v: expr,
            )*
            Body: $req_body: expr,
        },
        Response: {
            Status: $status: expr,
            Version: $version: expr,
            $(
                Header: $resp_n: expr, $resp_v: expr,
            )*
            Body: $resp_body: expr,
        },)*
    ) => {
        async fn $server_fn_name(
            request: hyper::Request<hyper::Body>,
        ) -> Result<hyper::Response<hyper::Body>, std::convert::Infallible> {
            match request.method().as_str() {
                $(
                    $method => {
                        assert_eq!($method, request.method().as_str(), "Assert request method failed");
                        assert_eq!(
                            "/",
                            request.uri().to_string(),
                            "Assert request uri failed",
                        );
                        assert_eq!(
                            $version,
                            format!("{:?}", request.version()),
                            "Assert request version failed",
                        );
                        // Panic messages are only formatted when a lookup actually fails.
                        $(assert_eq!(
                            $req_v,
                            request
                                .headers()
                                .get($req_n)
                                .unwrap_or_else(|| panic!("Get request header \"{}\" failed", $req_n))
                                .to_str()
                                .unwrap_or_else(|e| panic!(
                                    "Convert request header \"{}\" into string failed: {:?}",
                                    $req_n, e
                                )),
                            "Assert request header {} failed", $req_n,
                        );)*
                        let body = hyper::body::to_bytes(request.into_body())
                            .await
                            .expect("Get request body failed");
                        assert_eq!($req_body.as_bytes(), body, "Assert request body failed");
                        Ok(
                            hyper::Response::builder()
                                .version(hyper::Version::HTTP_11)
                                .status($status)
                                $(.header($resp_n, $resp_v))*
                                .body($resp_body.into())
                                .expect("Build response failed")
                        )
                    },
                )*
                _ => {panic!("Unrecognized METHOD !");},
            }
        }
    };
}
438
/// Asks the server behind `ServerHandle` to shut down and waits for its confirmation.
///
/// First sends `()` on the handle's `client_shutdown` channel, then waits for a message on its
/// `server_shutdown` channel. Expands to statements containing `.await`, so it must be used in
/// statement position inside an async context.
///
/// # Panics
/// Panics if either channel has been closed by the other side.
#[macro_export]
macro_rules! ensure_server_shutdown {
    (ServerHandle: $handle:expr) => {
        let shutdown_sent = $handle.client_shutdown.send(()).await;
        shutdown_sent.expect("Client channel (Server-Half) be closed unexpectedly");
        let shutdown_ack = $handle.server_shutdown.recv().await;
        shutdown_ack.expect("Server channel (Server-Half) be closed unexpectedly");
    };
}
454
init_test_work_runtime(thread_num: usize) -> Runtime455 pub fn init_test_work_runtime(thread_num: usize) -> Runtime {
456 tokio::runtime::Builder::new_multi_thread()
457 .worker_threads(thread_num)
458 .enable_all()
459 .build()
460 .expect("Build runtime failed.")
461 }
462