• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 use std::sync::mpsc::Receiver;
15 
16 #[cfg(feature = "async")]
17 mod async_utils;
18 
19 #[cfg(feature = "sync")]
20 mod sync_utils;
21 
/// Handle to one mock TCP server started by `start_tcp_server!`.
#[derive(Debug)]
pub struct TcpHandle {
    /// Address the server listens on, rendered from its `SocketAddr`
    /// (`ip:port`).
    pub addr: String,

    /// This channel allows the server to notify the client when it has shut
    /// down: it yields `()` once every expected request has been served, and
    /// disconnects instead if the server exits without sending.
    pub server_shutdown: Receiver<()>,
}
28 
/// Renders a single header line in the exact form the mock servers compare
/// against: the name in ASCII lowercase, a `:` with no surrounding spaces,
/// the value unchanged, and a terminating CRLF.
pub fn format_header_str(key: &str, value: &str) -> String {
    let mut line = key.to_ascii_lowercase();
    line.push(':');
    line.push_str(value);
    line.push_str("\r\n");
    line
}
32 
/// Starts `ServerNum` single-connection mock HTTP servers on `127.0.0.1`
/// (ephemeral port) and pushes a `TcpHandle` for each into `Handles`.
///
/// Each server accepts one TCP connection. For every `Request`/`Response`
/// pair, in order, it:
///
/// 1. reads until the expected number of request bytes has arrived (or the
///    peer closes);
/// 2. checks the request line, the `accept:*/*` and `host:<addr>` headers,
///    every listed `Header` (substring match, so order does not matter), the
///    total byte count, and the body;
/// 3. writes `{Version} {Status} OK`, the listed response headers, a blank
///    line, and the response body. The reason phrase is always `OK`.
///
/// After the last pair it sends `()` on the handle's `server_shutdown`
/// channel.
///
/// - `ASYNC;` runs each server as a `ylong_runtime::spawn` task. It expects
///   an origin-form request line (`METHOD /path HTTP/1.1`).
/// - `SYNC;` runs each server on a `std::thread` with 10-second read/write
///   timeouts. It expects an absolute-form request line
///   (`METHOD http://addr/path HTTP/1.1`).
///
/// The expansion names `TcpHandle` and `format_header_str` unqualified, so
/// both must be in scope at the call site.
///
/// # Panics
///
/// The server panics before signalling shutdown if a request does not match
/// or on a socket error.
#[macro_export]
macro_rules! start_tcp_server {
    (
        ASYNC;
        ServerNum: $server_num: expr,
        Handles: $handle_vec: expr,
        $(Request: {
            Method: $method: expr,
            Path: $path: expr,
            $(
                Header: $req_n: expr, $req_v: expr,
            )*
            Body: $req_body: expr,
        },
        Response: {
            Status: $status: expr,
            Version: $version: expr,
            $(
                Header: $resp_n: expr, $resp_v: expr,
            )*
            Body: $resp_body: expr,
        },)*

    ) => {{
        use std::sync::mpsc::channel;
        use ylong_runtime::io::{AsyncReadExt, AsyncWriteExt};
        use ylong_runtime::net::TcpListener;

        for _i in 0..$server_num {
            // The task keeps `shutdown_tx`; the client gets `shutdown_rx`.
            let (shutdown_tx, shutdown_rx) = channel();
            // The port is only known inside the task, so the handle is
            // built there and sent back out.
            let (handle_tx, handle_rx) = channel();

            ylong_runtime::spawn(async move {
                let server = TcpListener::bind("127.0.0.1:0").await.expect("server is failed to bind a address !");
                let addr = server.local_addr().expect("failed to get server address !");
                let handle = TcpHandle {
                    addr: addr.to_string(),
                    server_shutdown: shutdown_rx,
                };
                handle_tx.send(handle).expect("send TcpHandle out coroutine failed !");

                let (mut stream, _client) = server.accept().await.expect("failed to build a tcp stream");

                $(
                {
                    let crlf = "\r\n";

                    // Build every expected piece up front so the total request
                    // size is known before anything is read.
                    let request_line = format!("{} {} {}{}", $method, $path, "HTTP/1.1", crlf);
                    let accept = format_header_str("accept", "*/*");
                    let host = format_header_str("host", addr.to_string().as_str());
                    // (expected header line, panic message if it is missing)
                    let headers: Vec<(String, String)> = vec![
                        $(
                            (
                                format_header_str($req_n, $req_v),
                                format!("Incorrect {} header!", $req_n),
                            ),
                        )*
                    ];
                    let length = request_line.len()
                        + accept.len()
                        + host.len()
                        + headers.iter().map(|(header, _)| header.len()).sum::<usize>()
                        + crlf.len()
                        + $req_body.len();

                    // `read` may return any prefix of the data, so keep reading
                    // until the whole request is in or the peer closes.
                    let mut request: Vec<u8> = Vec::with_capacity(length);
                    let mut chunk = [0u8; 4096];
                    while request.len() < length {
                        let size = stream.read(&mut chunk).await.expect("tcp stream read error !");
                        if size == 0 {
                            break;
                        }
                        request.extend_from_slice(&chunk[..size]);
                    }
                    let request_str = String::from_utf8_lossy(&request);

                    assert!(request.starts_with(request_line.as_bytes()), "Incorrect Request-Line!");
                    assert!(request_str.contains(accept.as_str()), "Incorrect accept header!");
                    assert!(request_str.contains(host.as_str()), "Incorrect host header!");
                    for (header, message) in &headers {
                        assert!(request_str.contains(header.as_str()), "{}", message);
                    }
                    assert_eq!(request.len(), length, "Incorrect total request bytes !");
                    assert!(request.ends_with($req_body.as_bytes()), "Incorrect request body !");

                    let mut resp_str = format!("{} {} OK\r\n", $version, $status);
                    $(
                    resp_str.push_str(format_header_str($resp_n, $resp_v).as_str());
                    )*
                    resp_str.push_str(crlf);
                    resp_str.push_str($resp_body);

                    stream.write_all(resp_str.as_bytes()).await.expect("server write response failed");
                }
                )*
                shutdown_tx.send(()).expect("server send order failed !");
            });

            let handle = handle_rx.recv().expect("recv server handle failed !");

            $handle_vec.push(handle);
        }
    }};

    (
        SYNC;
        ServerNum: $server_num: expr,
        Handles: $handle_vec: expr,
        $(Request: {
            Method: $method: expr,
            Path: $path: expr,
            $(
                Header: $req_n: expr, $req_v: expr,
            )*
            Body: $req_body: expr,
        },
        Response: {
            Status: $status: expr,
            Version: $version: expr,
            $(
                Header: $resp_n: expr, $resp_v: expr,
            )*
            Body: $resp_body: expr,
        },)*

    ) => {{
        use std::io::{Read, Write};
        use std::net::TcpListener;
        use std::sync::mpsc::channel;
        use std::time::Duration;

        for _i in 0..$server_num {
            let server = TcpListener::bind("127.0.0.1:0").expect("server is failed to bind a address !");
            let addr = server.local_addr().expect("failed to get server address !");
            // The server thread keeps `shutdown_tx`; the client gets `shutdown_rx`.
            let (shutdown_tx, shutdown_rx) = channel();

            std::thread::spawn(move || {
                let (mut stream, _client) = server.accept().expect("failed to build a tcp stream");
                stream.set_read_timeout(Some(Duration::from_secs(10))).expect("tcp stream set read time out error !");
                stream.set_write_timeout(Some(Duration::from_secs(10))).expect("tcp stream set write time out error !");

                $(
                {
                    let crlf = "\r\n";

                    // Build every expected piece up front so the total request
                    // size is known before anything is read. The sync client
                    // sends an absolute-form request target.
                    let request_line = format!("{} http://{}{} {}{}", $method, addr.to_string().as_str(), $path, "HTTP/1.1", crlf);
                    let accept = format_header_str("accept", "*/*");
                    let host = format_header_str("host", addr.to_string().as_str());
                    // (expected header line, panic message if it is missing)
                    let headers: Vec<(String, String)> = vec![
                        $(
                            (
                                format_header_str($req_n, $req_v),
                                format!("Incorrect {} header!", $req_n),
                            ),
                        )*
                    ];
                    let length = request_line.len()
                        + accept.len()
                        + host.len()
                        + headers.iter().map(|(header, _)| header.len()).sum::<usize>()
                        + crlf.len()
                        + $req_body.len();

                    // `read` may return any prefix of the data, so keep reading
                    // until the whole request is in or the peer closes.
                    let mut request: Vec<u8> = Vec::with_capacity(length);
                    let mut chunk = [0u8; 4096];
                    while request.len() < length {
                        let size = stream.read(&mut chunk).expect("tcp stream read error !");
                        if size == 0 {
                            break;
                        }
                        request.extend_from_slice(&chunk[..size]);
                    }
                    let request_str = String::from_utf8_lossy(&request);

                    assert!(request.starts_with(request_line.as_bytes()), "Incorrect Request-Line!");
                    assert!(request_str.contains(accept.as_str()), "Incorrect accept header!");
                    assert!(request_str.contains(host.as_str()), "Incorrect host header!");
                    for (header, message) in &headers {
                        assert!(request_str.contains(header.as_str()), "{}", message);
                    }
                    assert_eq!(request.len(), length, "Incorrect total request bytes !");
                    assert!(request.ends_with($req_body.as_bytes()), "Incorrect request body !");

                    let mut resp_str = format!("{} {} OK\r\n", $version, $status);
                    $(
                    resp_str.push_str(format_header_str($resp_n, $resp_v).as_str());
                    )*
                    resp_str.push_str(crlf);
                    resp_str.push_str($resp_body);

                    stream.write_all(resp_str.as_bytes()).expect("server write response failed");
                }
                )*
                shutdown_tx.send(()).expect("server send order failed !");
            });

            let handle = TcpHandle {
                addr: addr.to_string(),
                server_shutdown: shutdown_rx,
            };
            $handle_vec.push(handle);
        }
    }}
}
236 
/// Creates a `Request`.
///
/// The request uses the given `Method` and the URL `http://{Addr}{Path}`. Each
/// `Header` name/value pair is added in the order listed, and `Body` is
/// attached as a `ylong_http::body::TextBody` over `Body.as_bytes()`.
///
/// # Panics
///
/// Panics with `"Request build failed"` if the builder rejects the request.
#[macro_export]
macro_rules! build_client_request {
    (
        Request: {
            Method: $method: expr,
            Path: $path: expr,
            Addr: $addr: expr,
            $(
                Header: $req_n: expr, $req_v: expr,
            )*
            Body: $req_body: expr,
        },
    ) => {{
        // Same evaluation order as a single chain: method, then URL, then
        // headers, then body.
        let builder = ylong_http::request::RequestBuilder::new().method($method);
        let url = format!("http://{}{}", $addr, $path);
        let builder = builder.url(url.as_str());
        $(
            let builder = builder.header($req_n, $req_v);
        )*
        builder
            .body(ylong_http::body::TextBody::from_bytes($req_body.as_bytes()))
            .expect("Request build failed")
    }};
}
259