1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 // http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13
14 use std::sync::mpsc::Receiver;
15
16 #[cfg(feature = "async")]
17 mod async_utils;
18
19 #[cfg(feature = "sync")]
20 mod sync_utils;
21
/// Connection details for one test server started by `start_tcp_server!`.
#[derive(Debug)]
pub struct TcpHandle {
    /// Address the server listens on, in `SocketAddr` display form
    /// (e.g. `127.0.0.1:PORT`).
    pub addr: String,

    /// This channel allows the server to notify the client when it has shut
    /// down: the server sends `()` after it has answered every scripted request.
    pub server_shutdown: Receiver<()>,
}
28
/// Formats one HTTP header line as `"{key}:{value}\r\n"`, lower-casing the key.
///
/// Only ASCII letters in `key` are lower-cased. `value` is copied verbatim, and
/// no space is inserted after the colon. The test servers in this file use this
/// to build the exact header text they expect to find in a request, and to
/// render the headers of the responses they write back.
pub fn format_header_str(key: &str, value: &str) -> String {
    format!("{}:{}\r\n", key.to_ascii_lowercase(), value)
}
32
/// Starts `ServerNum` local HTTP/1.1 test servers. Each one answers a scripted
/// sequence of `Request`/`Response` pairs on a single accepted connection.
///
/// Each server binds `127.0.0.1:0` (an OS-assigned port) and pushes a
/// `TcpHandle` for itself into `Handles`. For every `Request`/`Response` pair,
/// in order, the server:
///
/// 1. reads until the header-section terminator (`\r\n\r\n`) has arrived,
///    however many reads that takes;
/// 2. asserts the request line, an `accept:*/*` header, a `host:<addr>` header
///    and every listed request `Header` (all rendered by `format_header_str`);
/// 3. reads until the whole request `Body` has arrived and asserts it matches;
/// 4. asserts the total request size equals the sum of the expected parts, so
///    unexpected extra headers are rejected;
/// 5. writes `"{Version} {Status} OK\r\n"`, the listed response headers, an
///    empty line and the response `Body`. The reason phrase is always `OK`.
///
/// After the last pair the server sends `()` on the channel whose receiver is
/// stored in the handle's `server_shutdown` field.
///
/// * `ASYNC;` spawns each server as a `ylong_runtime` task and expects an
///   origin-form request line: `METHOD PATH HTTP/1.1`.
/// * `SYNC;` runs each server on a `std::thread` with 10-second read/write
///   timeouts and expects an absolute-form request line:
///   `METHOD http://ADDRPATH HTTP/1.1` (address and path concatenated).
///
/// The expansion names `TcpHandle` and `format_header_str` unqualified, so
/// both must be in scope wherever the macro is invoked.
///
/// # Panics
///
/// I/O errors and request mismatches while serving panic inside the server
/// task/thread, not at the call site. In that case the shutdown `()` is never
/// sent.
#[macro_export]
macro_rules! start_tcp_server {
    (
        ASYNC;
        ServerNum: $server_num: expr,
        Handles: $handle_vec: expr,
        $(Request: {
            Method: $method: expr,
            Path: $path: expr,
            $(
                Header: $req_n: expr, $req_v: expr,
            )*
            Body: $req_body: expr,
        },
        Response: {
            Status: $status: expr,
            Version: $version: expr,
            $(
                Header: $resp_n: expr, $resp_v: expr,
            )*
            Body: $resp_body: expr,
        },)*

    ) => {{
        use std::sync::mpsc::channel;
        use ylong_runtime::io::{AsyncReadExt, AsyncWriteExt};
        use ylong_runtime::net::TcpListener;

        for _i in 0..$server_num {
            // Signals the test once this server has answered every request.
            let (shutdown_tx, shutdown_rx) = channel();
            // Carries the handle (whose address is only known inside the task)
            // back out to this thread.
            let (handle_tx, handle_rx) = channel();

            ylong_runtime::spawn(async move {
                let server = TcpListener::bind("127.0.0.1:0")
                    .await
                    .expect("server is failed to bind a address !");
                let addr = server.local_addr().expect("failed to get server address !");
                let handle = TcpHandle {
                    addr: addr.to_string(),
                    server_shutdown: shutdown_rx,
                };
                handle_tx.send(handle).expect("send TcpHandle out coroutine failed !");

                let (mut stream, _client) = server
                    .accept()
                    .await
                    .expect("failed to build a tcp stream");

                $(
                    {
                        let crlf = "\r\n";
                        let mut chunk = [0u8; 4096];
                        let mut request: Vec<u8> = Vec::new();

                        // TCP may deliver the request in arbitrary pieces, so
                        // keep reading until the blank line ending the headers
                        // has been seen.
                        let header_end = loop {
                            if let Some(pos) = request.windows(4).position(|w| w == b"\r\n\r\n") {
                                break pos + 4;
                            }
                            let size = stream.read(&mut chunk).await.expect("tcp stream read error !");
                            assert!(size != 0, "connection closed before the request headers were complete !");
                            request.extend_from_slice(&chunk[..size]);
                        };
                        // Header checks only look at the header section, never at body bytes.
                        let head = String::from_utf8_lossy(&request[..header_end]).into_owned();
                        let mut length = 0;

                        let request_line = format!("{} {} {}{}", $method, $path, "HTTP/1.1", crlf);
                        assert!(request.starts_with(request_line.as_bytes()), "Incorrect Request-Line!");
                        length += request_line.len();

                        let accept = format_header_str("accept", "*/*");
                        assert!(head.contains(accept.as_str()), "Incorrect accept header!");
                        length += accept.len();

                        let host = format_header_str("host", &addr.to_string());
                        assert!(head.contains(host.as_str()), "Incorrect host header!");
                        length += host.len();

                        $(
                            let header_str = format_header_str($req_n, $req_v);
                            assert!(head.contains(header_str.as_str()), "Incorrect {} header!", $req_n);
                            length += header_str.len();
                        )*

                        let body_len = $req_body.len();
                        length += crlf.len();
                        length += body_len;

                        // The body may still be in flight; read until all of it has arrived.
                        while request.len() < header_end + body_len {
                            let size = stream.read(&mut chunk).await.expect("tcp stream read error2 !");
                            assert!(size != 0, "connection closed before the request body was complete !");
                            request.extend_from_slice(&chunk[..size]);
                        }
                        assert_eq!(&request[header_end..], $req_body.as_bytes(), "Incorrect request body !");
                        assert_eq!(request.len(), length, "Incorrect total request bytes !");

                        let mut resp_str = format!("{} {} OK\r\n", $version, $status);
                        $(
                            let header = format_header_str($resp_n, $resp_v);
                            resp_str.push_str(header.as_str());
                        )*
                        resp_str.push_str(crlf);
                        resp_str.push_str($resp_body);

                        stream
                            .write_all(resp_str.as_bytes())
                            .await
                            .expect("server write response failed");
                    }
                )*
                shutdown_tx.send(()).expect("server send order failed !");
            });

            let handle = handle_rx.recv().expect("recv server handle failed !");
            $handle_vec.push(handle);
        }
    }};

    (
        SYNC;
        ServerNum: $server_num: expr,
        Handles: $handle_vec: expr,
        $(Request: {
            Method: $method: expr,
            Path: $path: expr,
            $(
                Header: $req_n: expr, $req_v: expr,
            )*
            Body: $req_body: expr,
        },
        Response: {
            Status: $status: expr,
            Version: $version: expr,
            $(
                Header: $resp_n: expr, $resp_v: expr,
            )*
            Body: $resp_body: expr,
        },)*

    ) => {{
        use std::io::{Read, Write};
        use std::net::TcpListener;
        use std::sync::mpsc::channel;
        use std::time::Duration;

        for _i in 0..$server_num {
            let server = TcpListener::bind("127.0.0.1:0").expect("server is failed to bind a address !");
            let addr = server.local_addr().expect("failed to get server address !");
            // Signals the test once this server has answered every request.
            let (shutdown_tx, shutdown_rx) = channel();

            std::thread::spawn(move || {
                let (mut stream, _client) = server.accept().expect("failed to build a tcp stream");
                stream
                    .set_read_timeout(Some(Duration::from_secs(10)))
                    .expect("tcp stream set read time out error !");
                stream
                    .set_write_timeout(Some(Duration::from_secs(10)))
                    .expect("tcp stream set write time out error !");

                $(
                    {
                        let crlf = "\r\n";
                        let mut chunk = [0u8; 4096];
                        let mut request: Vec<u8> = Vec::new();

                        // TCP may deliver the request in arbitrary pieces, so
                        // keep reading until the blank line ending the headers
                        // has been seen.
                        let header_end = loop {
                            if let Some(pos) = request.windows(4).position(|w| w == b"\r\n\r\n") {
                                break pos + 4;
                            }
                            let size = stream.read(&mut chunk).expect("tcp stream read error !");
                            assert!(size != 0, "connection closed before the request headers were complete !");
                            request.extend_from_slice(&chunk[..size]);
                        };
                        // Header checks only look at the header section, never at body bytes.
                        let head = String::from_utf8_lossy(&request[..header_end]).into_owned();
                        let mut length = 0;

                        // The sync client is expected to send the absolute-form request target.
                        let request_line = format!("{} http://{}{} {}{}", $method, addr, $path, "HTTP/1.1", crlf);
                        assert!(request.starts_with(request_line.as_bytes()), "Incorrect Request-Line!");
                        length += request_line.len();

                        let accept = format_header_str("accept", "*/*");
                        assert!(head.contains(accept.as_str()), "Incorrect accept header!");
                        length += accept.len();

                        let host = format_header_str("host", &addr.to_string());
                        assert!(head.contains(host.as_str()), "Incorrect host header!");
                        length += host.len();

                        $(
                            let header_str = format_header_str($req_n, $req_v);
                            assert!(head.contains(header_str.as_str()), "Incorrect {} header!", $req_n);
                            length += header_str.len();
                        )*

                        let body_len = $req_body.len();
                        length += crlf.len();
                        length += body_len;

                        // The body may still be in flight; read until all of it has arrived.
                        while request.len() < header_end + body_len {
                            let size = stream.read(&mut chunk).expect("tcp stream read error2 !");
                            assert!(size != 0, "connection closed before the request body was complete !");
                            request.extend_from_slice(&chunk[..size]);
                        }
                        assert_eq!(&request[header_end..], $req_body.as_bytes(), "Incorrect request body !");
                        assert_eq!(request.len(), length, "Incorrect total request bytes !");

                        let mut resp_str = format!("{} {} OK\r\n", $version, $status);
                        $(
                            let header = format_header_str($resp_n, $resp_v);
                            resp_str.push_str(header.as_str());
                        )*
                        resp_str.push_str(crlf);
                        resp_str.push_str($resp_body);

                        stream.write_all(resp_str.as_bytes()).expect("server write response failed");
                    }
                )*
                shutdown_tx.send(()).expect("server send order failed !");
            });

            let handle = TcpHandle {
                addr: addr.to_string(),
                server_shutdown: shutdown_rx,
            };
            $handle_vec.push(handle);
        }
    }};
}
236
/// Creates a `ylong_http` request targeting `http://{Addr}{Path}`.
///
/// The request has the given `Method`, every listed `Header` in order, and a
/// `TextBody` wrapping the bytes of `Body`. Panics with "Request build failed"
/// if the builder rejects any of the inputs.
#[macro_export]
macro_rules! build_client_request {
    (
        Request: {
            Method: $method: expr,
            Path: $path: expr,
            Addr: $addr: expr,
            $(
                Header: $req_n: expr, $req_v: expr,
            )*
            Body: $req_body: expr,
        },
    ) => {{
        // Stage the builder in the same order as a single chain: method, then
        // URL, then headers, then body.
        let builder = ylong_http::request::RequestBuilder::new().method($method);
        let builder = builder.url(format!("http://{}{}", $addr, $path).as_str());
        builder
            $(.header($req_n, $req_v))*
            .body(ylong_http::body::TextBody::from_bytes($req_body.as_bytes()))
            .expect("Request build failed")
    }};
}
259