1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 // http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13
14 use std::sync::mpsc::Receiver;
15
16 #[cfg(feature = "async")]
17 mod async_utils;
18
19 #[cfg(feature = "sync")]
20 mod sync_utils;
21
/// Handle to one test server started by `start_tcp_server!`.
#[derive(Debug)]
pub struct TcpHandle {
    /// Address the server listens on, as produced by `SocketAddr::to_string`
    /// (e.g. `127.0.0.1:PORT`).
    pub addr: String,

    /// Receives `()` from the server once it has served every scripted
    /// request, so a test can wait for the server to finish before exiting.
    pub server_shutdown: Receiver<()>,
}
28
/// Renders a single header line the way the test servers expect to see it
/// on the wire: ASCII-lowercased name, `:` with no following space, the
/// value unchanged, and a terminating CRLF.
///
/// For example `("Content-Length", "5")` becomes `"content-length:5\r\n"`.
pub fn format_header_str(key: &str, value: &str) -> String {
    let mut line = key.to_ascii_lowercase();
    line.push(':');
    line.push_str(value);
    line.push_str("\r\n");
    line
}
32
/// Starts `ServerNum` local HTTP/1.1 test servers and pushes a `TcpHandle`
/// for each of them onto `Handles`.
///
/// Every server binds `127.0.0.1:0`, accepts a single connection and then
/// handles the scripted `Request`/`Response` pairs on it, in order. For each
/// pair it reads one request and asserts that:
/// - it starts with the expected Request-Line. `SYNC` servers always expect
///   the absolute form `METHOD http://ADDRPATH HTTP/1.1`; `ASYNC` servers
///   expect that form only when `Proxy` is true, and `METHOD PATH HTTP/1.1`
///   otherwise;
/// - its header section contains `accept:*/*`, `host:ADDR` and every
///   scripted `Header` (each formatted by `format_header_str`; header order
///   is not checked);
/// - it ends with the scripted `Body` and has exactly the expected length.
///
/// It then replies with `VERSION STATUS OK` (the reason phrase is always
/// `OK`), the scripted response headers, an empty line and the response
/// body. After all pairs are served, `()` is sent on the handle's
/// `server_shutdown` channel.
///
/// `ASYNC` servers run as `ylong_runtime` tasks; `SYNC` servers run on OS
/// threads with 10-second read/write timeouts.
///
/// The expansion names `TcpHandle` and `format_header_str` unqualified, so
/// both must be in scope at the call site.
///
/// # Panics
///
/// Any mismatch between the received request and the script panics inside
/// the server task/thread. Failures to bind, accept or use the channels also
/// panic.
#[macro_export]
macro_rules! start_tcp_server {
    (
        ASYNC;
        Proxy: $proxy: expr,
        ServerNum: $server_num: expr,
        Handles: $handle_vec: expr,
        $(Request: {
            Method: $method: expr,
            Path: $path: expr,
            $(
                Header: $req_n: expr, $req_v: expr,
            )*
            Body: $req_body: expr,
        },
        Response: {
            Status: $status: expr,
            Version: $version: expr,
            $(
                Header: $resp_n: expr, $resp_v: expr,
            )*
            Body: $resp_body: expr,
        },)*

    ) => {{
        use std::sync::mpsc::channel;
        use ylong_runtime::io::{AsyncReadExt, AsyncWriteExt};
        use ylong_runtime::net::TcpListener;

        for _i in 0..$server_num {
            // The server task sends `()` on `shutdown_tx` once every scripted
            // pair has been served; the test waits on `shutdown_rx`.
            let (shutdown_tx, shutdown_rx) = channel();
            // Carries the `TcpHandle` out of the task once the listener is bound.
            let (handle_tx, handle_rx) = channel();

            ylong_runtime::spawn(async move {
                let server = TcpListener::bind("127.0.0.1:0")
                    .await
                    .expect("server is failed to bind a address !");
                let addr = server.local_addr().expect("failed to get server address !");
                let handle = TcpHandle {
                    addr: addr.to_string(),
                    server_shutdown: shutdown_rx,
                };
                handle_tx.send(handle).expect("send TcpHandle out coroutine failed !");

                let (mut stream, _client) = server.accept().await.expect("failed to build a tcp stream");

                $(
                    {
                        let crlf = "\r\n";
                        let mut buf = [0u8; 4096];
                        let mut request: Vec<u8> = Vec::new();

                        // A single `read` may return only part of the request, or the
                        // headers plus just part of the body, so keep reading until the
                        // blank line that ends the header section has arrived...
                        let header_end = loop {
                            if let Some(pos) = request.windows(4).position(|w| w == b"\r\n\r\n") {
                                break pos + 4;
                            }
                            let size = stream.read(&mut buf).await.expect("tcp stream read error !");
                            assert_ne!(size, 0, "connection closed before the request header was complete !");
                            request.extend_from_slice(&buf[..size]);
                        };
                        // ...and then until the whole expected body is buffered.
                        while request.len() < header_end + $req_body.len() {
                            let size = stream.read(&mut buf).await.expect("tcp stream read error2 !");
                            assert_ne!(size, 0, "connection closed before the request body was complete !");
                            request.extend_from_slice(&buf[..size]);
                        }
                        // Header checks only look at the header section, never the body.
                        let request_str = String::from_utf8_lossy(&request[..header_end]);

                        let request_line = if $proxy {
                            format!("{} http://{}{} {}{}", $method, addr, $path, "HTTP/1.1", crlf)
                        } else {
                            format!("{} {} {}{}", $method, $path, "HTTP/1.1", crlf)
                        };
                        assert!(request.starts_with(request_line.as_bytes()), "Incorrect Request-Line!");
                        let mut length = request_line.len();

                        let accept = format_header_str("accept", "*/*");
                        assert!(request_str.contains(accept.as_str()), "Incorrect accept header!");
                        length += accept.len();

                        let host = format_header_str("host", addr.to_string().as_str());
                        assert!(request_str.contains(host.as_str()), "Incorrect host header!");
                        length += host.len();

                        $(
                            let header_str = format_header_str($req_n, $req_v);
                            assert!(request_str.contains(header_str.as_str()), "Incorrect {} header!", $req_n);
                            length += header_str.len();
                        )*

                        // Blank line terminating the headers, then the body.
                        length += crlf.len();
                        length += $req_body.len();

                        assert_eq!(request.len(), length, "Incorrect total request bytes !");
                        assert_eq!(&request[header_end..], $req_body.as_bytes());

                        let mut resp_str = format!("{} {} OK\r\n", $version, $status);
                        $(
                            resp_str.push_str(format_header_str($resp_n, $resp_v).as_str());
                        )*
                        resp_str.push_str(crlf);
                        resp_str.push_str($resp_body);

                        stream.write_all(resp_str.as_bytes()).await.expect("server write response failed");
                    }
                )*
                shutdown_tx.send(()).expect("server send order failed !");
            });

            // Block until the task has bound its listener and reported the address.
            let handle = handle_rx.recv().expect("recv server handle failed !");
            $handle_vec.push(handle);
        }
    }};

    (
        SYNC;
        ServerNum: $server_num: expr,
        Handles: $handle_vec: expr,
        $(Request: {
            Method: $method: expr,
            Path: $path: expr,
            $(
                Header: $req_n: expr, $req_v: expr,
            )*
            Body: $req_body: expr,
        },
        Response: {
            Status: $status: expr,
            Version: $version: expr,
            $(
                Header: $resp_n: expr, $resp_v: expr,
            )*
            Body: $resp_body: expr,
        },)*

    ) => {{
        use std::io::{Read, Write};
        use std::net::TcpListener;
        use std::sync::mpsc::channel;
        use std::time::Duration;

        for _i in 0..$server_num {
            let server = TcpListener::bind("127.0.0.1:0").expect("server is failed to bind a address !");
            let addr = server.local_addr().expect("failed to get server address !");
            // The server thread sends `()` on `shutdown_tx` once every scripted
            // pair has been served; the test waits on `shutdown_rx`.
            let (shutdown_tx, shutdown_rx) = channel();

            std::thread::spawn(move || {
                let (mut stream, _client) = server.accept().expect("failed to build a tcp stream");
                // Timeouts turn a hung client into a read/write error instead of a stuck test.
                stream
                    .set_read_timeout(Some(Duration::from_secs(10)))
                    .expect("tcp stream set read time out error !");
                stream
                    .set_write_timeout(Some(Duration::from_secs(10)))
                    .expect("tcp stream set write time out error !");

                $(
                    {
                        let crlf = "\r\n";
                        let mut buf = [0u8; 4096];
                        let mut request: Vec<u8> = Vec::new();

                        // A single `read` may return only part of the request, or the
                        // headers plus just part of the body, so keep reading until the
                        // blank line that ends the header section has arrived...
                        let header_end = loop {
                            if let Some(pos) = request.windows(4).position(|w| w == b"\r\n\r\n") {
                                break pos + 4;
                            }
                            let size = stream.read(&mut buf).expect("tcp stream read error !");
                            assert_ne!(size, 0, "connection closed before the request header was complete !");
                            request.extend_from_slice(&buf[..size]);
                        };
                        // ...and then until the whole expected body is buffered.
                        while request.len() < header_end + $req_body.len() {
                            let size = stream.read(&mut buf).expect("tcp stream read error2 !");
                            assert_ne!(size, 0, "connection closed before the request body was complete !");
                            request.extend_from_slice(&buf[..size]);
                        }
                        // Header checks only look at the header section, never the body.
                        let request_str = String::from_utf8_lossy(&request[..header_end]);

                        // The sync server always expects the absolute-form request target.
                        let request_line = format!("{} http://{}{} {}{}", $method, addr, $path, "HTTP/1.1", crlf);
                        assert!(request.starts_with(request_line.as_bytes()), "Incorrect Request-Line!");
                        let mut length = request_line.len();

                        let accept = format_header_str("accept", "*/*");
                        assert!(request_str.contains(accept.as_str()), "Incorrect accept header!");
                        length += accept.len();

                        let host = format_header_str("host", addr.to_string().as_str());
                        assert!(request_str.contains(host.as_str()), "Incorrect host header!");
                        length += host.len();

                        $(
                            let header_str = format_header_str($req_n, $req_v);
                            assert!(request_str.contains(header_str.as_str()), "Incorrect {} header!", $req_n);
                            length += header_str.len();
                        )*

                        // Blank line terminating the headers, then the body.
                        length += crlf.len();
                        length += $req_body.len();

                        assert_eq!(request.len(), length, "Incorrect total request bytes !");
                        assert_eq!(&request[header_end..], $req_body.as_bytes());

                        let mut resp_str = format!("{} {} OK\r\n", $version, $status);
                        $(
                            resp_str.push_str(format_header_str($resp_n, $resp_v).as_str());
                        )*
                        resp_str.push_str(crlf);
                        resp_str.push_str($resp_body);

                        stream.write_all(resp_str.as_bytes()).expect("server write response failed");
                    }
                )*
                shutdown_tx.send(()).expect("server send order failed !");
            });

            $handle_vec.push(TcpHandle {
                addr: addr.to_string(),
                server_shutdown: shutdown_rx,
            });
        }
    }}
}
241
/// Creates a `ylong_http` `Request` for the test client.
///
/// The URL is `http://{Addr}{Path}`. Every `Header` pair is added in the
/// order given. `Body` is passed to `TextBody::from_bytes` via its
/// `as_bytes()`. Panics with "Request build failed" if the builder returns
/// an error.
#[macro_export]
macro_rules! build_client_request {
    (
        Request: {
            Method: $method: expr,
            Path: $path: expr,
            Addr: $addr: expr,
            $(
                Header: $req_n: expr, $req_v: expr,
            )*
            Body: $req_body: expr,
        },
    ) => {{
        use ylong_http::body::TextBody;
        use ylong_http::request::RequestBuilder;

        // Kept as a single expression so temporaries passed as `Body`/`Addr`
        // live for the whole builder chain.
        RequestBuilder::new()
            .method($method)
            .url(format!("http://{}{}", $addr, $path).as_str())
            $(.header($req_n, $req_v))*
            .body(TextBody::from_bytes($req_body.as_bytes()))
            .expect("Request build failed")
    }};
}
264