• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 use std::sync::mpsc::Receiver;
15 
16 #[cfg(feature = "async")]
17 mod async_utils;
18 
19 #[cfg(feature = "sync")]
20 mod sync_utils;
21 
/// Handle to a mock TCP server started by `start_tcp_server!`.
///
/// Derives `Debug` so that a handle can be printed in assertion messages and
/// other diagnostics.
#[derive(Debug)]
pub struct TcpHandle {
    /// Socket address the server is bound to, rendered as a string.
    pub addr: String,

    /// Receiving end of the channel on which the server sends `()` once it
    /// has shut down, allowing the client side to wait for it.
    pub server_shutdown: Receiver<()>,
}
28 
/// Renders a single header line as `<key>:<value>\r\n`.
///
/// ASCII letters in `key` are lower-cased; `value` is emitted unchanged and
/// no space is inserted after the colon.
pub fn format_header_str(key: &str, value: &str) -> String {
    // key + ':' + value + "\r\n"
    let mut line = String::with_capacity(key.len() + value.len() + 3);
    line.push_str(&key.to_ascii_lowercase());
    line.push(':');
    line.push_str(value);
    line.push_str("\r\n");
    line
}
32 
/// Starts `ServerNum` mock HTTP servers on `127.0.0.1` with OS-assigned ports
/// and pushes one `TcpHandle` per server into `Handles`.
///
/// Each server accepts exactly one connection and then, for every
/// `Request`/`Response` pair in order:
///
/// 1. reads the request into a 4096-byte buffer;
/// 2. asserts the request line, the `accept:*/*` header, the `host` header
///    and every listed `Header`;
/// 3. asserts the total number of bytes received, issuing one extra read when
///    the first read was shorter than expected (the extra read must equal the
///    expected body);
/// 4. writes the response: a `<Version> <Status> OK` status line, the listed
///    headers, a blank line and the body.
///
/// After all pairs have been served, `()` is sent on the handle's
/// `server_shutdown` channel.
///
/// Two arms are available:
///
/// * `ASYNC;` runs each server as a `ylong_runtime` task. `Proxy` selects
///   whether the expected request target is the absolute form
///   (`http://<addr><path>`) or just the path.
/// * `SYNC;` runs each server on a `std::thread` with 10 second read/write
///   timeouts and always expects the absolute-form request target.
///
/// `TcpHandle` and `format_header_str` must be in scope where the macro is
/// invoked. Any failed expectation panics inside the server task/thread.
#[macro_export]
macro_rules! start_tcp_server {
    (
        ASYNC;
        Proxy: $proxy: expr,
        ServerNum: $server_num: expr,
        Handles: $handle_vec: expr,
        $(Request: {
            Method: $method: expr,
            Path: $path: expr,
            $(
                Header: $req_n: expr, $req_v: expr,
            )*
            Body: $req_body: expr,
        },
        Response: {
            Status: $status: expr,
            Version: $version: expr,
            $(
                Header: $resp_n: expr, $resp_v: expr,
            )*
            Body: $resp_body: expr,
        },)*

    ) => {{
        use std::sync::mpsc::channel;
        use ylong_runtime::io::{AsyncReadExt, AsyncWriteExt};
        use ylong_runtime::net::TcpListener;

        for _ in 0..$server_num {
            // Server -> caller: signals that the server task has finished.
            let (shutdown_tx, shutdown_rx) = channel();
            // Server -> caller: hands the handle out once the port is known.
            let (handle_tx, handle_rx) = channel();

            ylong_runtime::spawn(async move {
                let listener = TcpListener::bind("127.0.0.1:0")
                    .await
                    .expect("server is failed to bind a address !");
                let addr = listener.local_addr().expect("failed to get server address !");
                handle_tx
                    .send(TcpHandle {
                        addr: addr.to_string(),
                        server_shutdown: shutdown_rx,
                    })
                    .expect("send TcpHandle out coroutine failed !");

                let (mut stream, _peer) = listener.accept().await.expect("failed to build a tcp stream");

                $(
                {
                    let mut buf = [0u8; 4096];
                    let first_read = stream.read(&mut buf).await.expect("tcp stream read error !");
                    let request_text = String::from_utf8_lossy(&buf[..first_read]);
                    let crlf = "\r\n";

                    // Proxied requests carry the absolute URI in the request line.
                    let request_line = if $proxy {
                        format!("{} http://{}{} {}{}", $method, addr, $path, "HTTP/1.1", crlf)
                    } else {
                        format!("{} {} {}{}", $method, $path, "HTTP/1.1", crlf)
                    };
                    assert!(buf[..first_read].starts_with(request_line.as_bytes()), "Incorrect Request-Line!");
                    let mut expected_len = request_line.len();

                    let accept = format_header_str("accept", "*/*");
                    assert!(request_text.contains(accept.as_str()), "Incorrect accept header!");
                    expected_len += accept.len();

                    let host = format_header_str("host", &addr.to_string());
                    assert!(request_text.contains(host.as_str()), "Incorrect host header!");
                    expected_len += host.len();

                    $(
                    let header_line = format_header_str($req_n, $req_v);
                    assert!(request_text.contains(header_line.as_str()), "Incorrect {} header!", $req_n);
                    expected_len += header_line.len();
                    )*

                    // Blank line terminating the header section, then the body.
                    expected_len += crlf.len();
                    expected_len += $req_body.len();

                    // A short first read means the body arrives in a second read.
                    let received = if expected_len > first_read {
                        let second_read = stream.read(&mut buf).await.expect("tcp stream read error2 !");
                        assert_eq!(&buf[..second_read], $req_body.as_bytes());
                        first_read + second_read
                    } else {
                        first_read
                    };
                    assert_eq!(received, expected_len, "Incorrect total request bytes !");

                    let mut response = format!("{} {} OK\r\n", $version, $status);
                    $(
                    response.push_str(format_header_str($resp_n, $resp_v).as_str());
                    )*
                    response.push_str(crlf);
                    response.push_str($resp_body);

                    stream.write_all(response.as_bytes()).await.expect("server write response failed");
                }
                )*
                shutdown_tx.send(()).expect("server send order failed !");
            });

            // Block until the task has bound its port and published the handle.
            let handle = handle_rx.recv().expect("recv server handle failed !");
            $handle_vec.push(handle);
        }
    }};

    (
        SYNC;
        ServerNum: $server_num: expr,
        Handles: $handle_vec: expr,
        $(Request: {
            Method: $method: expr,
            Path: $path: expr,
            $(
                Header: $req_n: expr, $req_v: expr,
            )*
            Body: $req_body: expr,
        },
        Response: {
            Status: $status: expr,
            Version: $version: expr,
            $(
                Header: $resp_n: expr, $resp_v: expr,
            )*
            Body: $resp_body: expr,
        },)*

    ) => {{
        use std::io::{Read, Write};
        use std::net::TcpListener;
        use std::sync::mpsc::channel;
        use std::time::Duration;

        for _ in 0..$server_num {
            let listener = TcpListener::bind("127.0.0.1:0").expect("server is failed to bind a address !");
            let addr = listener.local_addr().expect("failed to get server address !");
            // Server -> caller: signals that the server thread has finished.
            let (shutdown_tx, shutdown_rx) = channel();

            std::thread::spawn(move || {
                let (mut stream, _peer) = listener.accept().expect("failed to build a tcp stream");
                let io_timeout = Some(Duration::from_secs(10));
                stream.set_read_timeout(io_timeout).expect("tcp stream set read time out error !");
                stream.set_write_timeout(io_timeout).expect("tcp stream set write time out error !");

                $(
                {
                    let mut buf = [0u8; 4096];
                    let first_read = stream.read(&mut buf).expect("tcp stream read error !");
                    let request_text = String::from_utf8_lossy(&buf[..first_read]);
                    let crlf = "\r\n";

                    let request_line = format!("{} http://{}{} {}{}", $method, addr, $path, "HTTP/1.1", crlf);
                    assert!(buf[..first_read].starts_with(request_line.as_bytes()), "Incorrect Request-Line!");
                    let mut expected_len = request_line.len();

                    let accept = format_header_str("accept", "*/*");
                    assert!(request_text.contains(accept.as_str()), "Incorrect accept header!");
                    expected_len += accept.len();

                    let host = format_header_str("host", &addr.to_string());
                    assert!(request_text.contains(host.as_str()), "Incorrect host header!");
                    expected_len += host.len();

                    $(
                    let header_line = format_header_str($req_n, $req_v);
                    assert!(request_text.contains(header_line.as_str()), "Incorrect {} header!", $req_n);
                    expected_len += header_line.len();
                    )*

                    // Blank line terminating the header section, then the body.
                    expected_len += crlf.len();
                    expected_len += $req_body.len();

                    // A short first read means the body arrives in a second read.
                    let received = if expected_len > first_read {
                        let second_read = stream.read(&mut buf).expect("tcp stream read error2 !");
                        assert_eq!(&buf[..second_read], $req_body.as_bytes());
                        first_read + second_read
                    } else {
                        first_read
                    };
                    assert_eq!(received, expected_len, "Incorrect total request bytes !");

                    let mut response = format!("{} {} OK\r\n", $version, $status);
                    $(
                    response.push_str(format_header_str($resp_n, $resp_v).as_str());
                    )*
                    response.push_str(crlf);
                    response.push_str($resp_body);

                    stream.write_all(response.as_bytes()).expect("server write response failed");
                }
                )*
                shutdown_tx.send(()).expect("server send order failed !");
            });

            let handle = TcpHandle {
                addr: addr.to_string(),
                server_shutdown: shutdown_rx,
            };
            $handle_vec.push(handle);
        }
    }}
}
241 
/// Builds a `ylong_http` request for the test client.
///
/// The URL is `http://<Addr><Path>`. Every `Header` pair is added in the
/// order given, and `Body` is wrapped in a `TextBody` created from its bytes.
/// Panics with "Request build failed" if the builder returns an error.
#[macro_export]
macro_rules! build_client_request {
    (
        Request: {
            Method: $method: expr,
            Path: $path: expr,
            Addr: $addr: expr,
            $(
                Header: $req_n: expr, $req_v: expr,
            )*
            Body: $req_body: expr,
        },
    ) => {{
        let builder = ylong_http::request::RequestBuilder::new().method($method);
        let url = format!("http://{}{}", $addr, $path);
        builder
            .url(url.as_str())
            $(.header($req_n, $req_v))*
            .body(ylong_http::body::TextBody::from_bytes($req_body.as_bytes()))
            .expect("Request build failed")
    }};
}
264