• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //! Client handshake machine.
2 
3 use std::{
4     io::{Read, Write},
5     marker::PhantomData,
6 };
7 
8 use http::{
9     header::HeaderName, HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode,
10 };
11 use httparse::Status;
12 use log::*;
13 
14 use super::{
15     derive_accept_key,
16     headers::{FromHttparse, MAX_HEADERS},
17     machine::{HandshakeMachine, StageResult, TryParse},
18     HandshakeRole, MidHandshake, ProcessingResult,
19 };
20 use crate::{
21     error::{Error, ProtocolError, Result, SubProtocolError, UrlError},
22     protocol::{Role, WebSocket, WebSocketConfig},
23 };
24 
25 /// Client request type.
///
/// An HTTP request whose body type is the unit type `()`.
26 pub type Request = HttpRequest<()>;
27 
28 /// Client response type.
///
/// The body is `None` when the response is first built from the parsed head.
/// When the handshake fails with `Error::Http`, the `tail` bytes reported by
/// the read stage are stored in the body as `Some(..)` before the error is returned.
29 pub type Response = HttpResponse<Option<Vec<u8>>>;
30 
31 /// Client handshake role.
32 #[derive(Debug)]
33 pub struct ClientHandshake<S> {
34     verify_data: VerifyData,
35     config: Option<WebSocketConfig>,
36     _marker: PhantomData<S>,
37 }
38 
39 impl<S: Read + Write> ClientHandshake<S> {
40     /// Initiate a client handshake.
start( stream: S, request: Request, config: Option<WebSocketConfig>, ) -> Result<MidHandshake<Self>>41     pub fn start(
42         stream: S,
43         request: Request,
44         config: Option<WebSocketConfig>,
45     ) -> Result<MidHandshake<Self>> {
46         if request.method() != http::Method::GET {
47             return Err(Error::Protocol(ProtocolError::WrongHttpMethod));
48         }
49 
50         if request.version() < http::Version::HTTP_11 {
51             return Err(Error::Protocol(ProtocolError::WrongHttpVersion));
52         }
53 
54         // Check the URI scheme: only ws or wss are supported
55         let _ = crate::client::uri_mode(request.uri())?;
56 
57         let subprotocols = extract_subprotocols_from_request(&request)?;
58 
59         // Convert and verify the `http::Request` and turn it into the request as per RFC.
60         // Also extract the key from it (it must be present in a correct request).
61         let (request, key) = generate_request(request)?;
62 
63         let machine = HandshakeMachine::start_write(stream, request);
64 
65         let client = {
66             let accept_key = derive_accept_key(key.as_ref());
67             ClientHandshake {
68                 verify_data: VerifyData { accept_key, subprotocols },
69                 config,
70                 _marker: PhantomData,
71             }
72         };
73 
74         trace!("Client handshake initiated.");
75         Ok(MidHandshake { role: client, machine })
76     }
77 }
78 
79 impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
80     type IncomingData = Response;
81     type InternalStream = S;
82     type FinalResult = (WebSocket<S>, Response);
stage_finished( &mut self, finish: StageResult<Self::IncomingData, Self::InternalStream>, ) -> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>>83     fn stage_finished(
84         &mut self,
85         finish: StageResult<Self::IncomingData, Self::InternalStream>,
86     ) -> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>> {
87         Ok(match finish {
88             StageResult::DoneWriting(stream) => {
89                 ProcessingResult::Continue(HandshakeMachine::start_read(stream))
90             }
91             StageResult::DoneReading { stream, result, tail } => {
92                 let result = match self.verify_data.verify_response(result) {
93                     Ok(r) => r,
94                     Err(Error::Http(mut e)) => {
95                         *e.body_mut() = Some(tail);
96                         return Err(Error::Http(e));
97                     }
98                     Err(e) => return Err(e),
99                 };
100 
101                 debug!("Client handshake done.");
102                 let websocket =
103                     WebSocket::from_partially_read(stream, tail, Role::Client, self.config);
104                 ProcessingResult::Done((websocket, result))
105             }
106         })
107     }
108 }
109 
110 /// Verifies and generates a client WebSocket request from the original request and extracts a WebSocket key from it.
generate_request(mut request: Request) -> Result<(Vec<u8>, String)>111 pub fn generate_request(mut request: Request) -> Result<(Vec<u8>, String)> {
112     let mut req = Vec::new();
113     write!(
114         req,
115         "GET {path} {version:?}\r\n",
116         path = request.uri().path_and_query().ok_or(Error::Url(UrlError::NoPathOrQuery))?.as_str(),
117         version = request.version()
118     )
119     .unwrap();
120 
121     // Headers that must be present in a correct request.
122     const KEY_HEADERNAME: &str = "Sec-WebSocket-Key";
123     const WEBSOCKET_HEADERS: [&str; 5] =
124         ["Host", "Connection", "Upgrade", "Sec-WebSocket-Version", KEY_HEADERNAME];
125 
126     // We must extract a WebSocket key from a properly formed request or fail if it's not present.
127     let key = request
128         .headers()
129         .get(KEY_HEADERNAME)
130         .ok_or_else(|| {
131             Error::Protocol(ProtocolError::InvalidHeader(
132                 HeaderName::from_bytes(KEY_HEADERNAME.as_bytes()).unwrap(),
133             ))
134         })?
135         .to_str()?
136         .to_owned();
137 
138     // We must check that all necessary headers for a valid request are present. Note that we have to
139     // deal with the fact that some apps seem to have a case-sensitive check for headers which is not
140     // correct and should not considered the correct behavior, but it seems like some apps ignore it.
141     // `http` by default writes all headers in lower-case which is fine (and does not violate the RFC)
142     // but some servers seem to be poorely written and ignore RFC.
143     //
144     // See similar problem in `hyper`: https://github.com/hyperium/hyper/issues/1492
145     let headers = request.headers_mut();
146     for &header in &WEBSOCKET_HEADERS {
147         let value = headers.remove(header).ok_or_else(|| {
148             Error::Protocol(ProtocolError::InvalidHeader(
149                 HeaderName::from_bytes(header.as_bytes()).unwrap(),
150             ))
151         })?;
152         write!(req, "{header}: {value}\r\n", header = header, value = value.to_str()?).unwrap();
153     }
154 
155     // Now we must ensure that the headers that we've written once are not anymore present in the map.
156     // If they do, then the request is invalid (some headers are duplicated there for some reason).
157     let insensitive: Vec<String> =
158         WEBSOCKET_HEADERS.iter().map(|h| h.to_ascii_lowercase()).collect();
159     for (k, v) in headers {
160         let mut name = k.as_str();
161 
162         // We have already written the necessary headers once (above) and removed them from the map.
163         // If we encounter them again, then the request is considered invalid and error is returned.
164         // Note that we can't use `.contains()`, since `&str` does not coerce to `&String` in Rust.
165         if insensitive.iter().any(|x| x == name) {
166             return Err(Error::Protocol(ProtocolError::InvalidHeader(k.clone())));
167         }
168 
169         // Relates to the issue of some servers treating headers in a case-sensitive way, please see:
170         // https://github.com/snapview/tungstenite-rs/pull/119 (original fix of the problem)
171         if name == "sec-websocket-protocol" {
172             name = "Sec-WebSocket-Protocol";
173         }
174 
175         if name == "origin" {
176             name = "Origin";
177         }
178 
179         writeln!(req, "{}: {}\r", name, v.to_str()?).unwrap();
180     }
181 
182     writeln!(req, "\r").unwrap();
183     trace!("Request: {:?}", String::from_utf8_lossy(&req));
184     Ok((req, key))
185 }
186 
extract_subprotocols_from_request(request: &Request) -> Result<Option<Vec<String>>>187 fn extract_subprotocols_from_request(request: &Request) -> Result<Option<Vec<String>>> {
188     if let Some(subprotocols) = request.headers().get("Sec-WebSocket-Protocol") {
189         Ok(Some(subprotocols.to_str()?.split(",").map(|s| s.to_string()).collect()))
190     } else {
191         Ok(None)
192     }
193 }
194 
195 /// Information for handshake verification.
196 #[derive(Debug)]
197 struct VerifyData {
198     /// Accepted server key.
199     accept_key: String,
200 
201     /// Accepted subprotocols
202     subprotocols: Option<Vec<String>>,
203 }
204 
205 impl VerifyData {
verify_response(&self, response: Response) -> Result<Response>206     pub fn verify_response(&self, response: Response) -> Result<Response> {
207         // 1. If the status code received from the server is not 101, the
208         // client handles the response per HTTP [RFC2616] procedures. (RFC 6455)
209         if response.status() != StatusCode::SWITCHING_PROTOCOLS {
210             return Err(Error::Http(response));
211         }
212 
213         let headers = response.headers();
214 
215         // 2. If the response lacks an |Upgrade| header field or the |Upgrade|
216         // header field contains a value that is not an ASCII case-
217         // insensitive match for the value "websocket", the client MUST
218         // _Fail the WebSocket Connection_. (RFC 6455)
219         if !headers
220             .get("Upgrade")
221             .and_then(|h| h.to_str().ok())
222             .map(|h| h.eq_ignore_ascii_case("websocket"))
223             .unwrap_or(false)
224         {
225             return Err(Error::Protocol(ProtocolError::MissingUpgradeWebSocketHeader));
226         }
227         // 3.  If the response lacks a |Connection| header field or the
228         // |Connection| header field doesn't contain a token that is an
229         // ASCII case-insensitive match for the value "Upgrade", the client
230         // MUST _Fail the WebSocket Connection_. (RFC 6455)
231         if !headers
232             .get("Connection")
233             .and_then(|h| h.to_str().ok())
234             .map(|h| h.eq_ignore_ascii_case("Upgrade"))
235             .unwrap_or(false)
236         {
237             return Err(Error::Protocol(ProtocolError::MissingConnectionUpgradeHeader));
238         }
239         // 4.  If the response lacks a |Sec-WebSocket-Accept| header field or
240         // the |Sec-WebSocket-Accept| contains a value other than the
241         // base64-encoded SHA-1 of ... the client MUST _Fail the WebSocket
242         // Connection_. (RFC 6455)
243         if !headers.get("Sec-WebSocket-Accept").map(|h| h == &self.accept_key).unwrap_or(false) {
244             return Err(Error::Protocol(ProtocolError::SecWebSocketAcceptKeyMismatch));
245         }
246         // 5.  If the response includes a |Sec-WebSocket-Extensions| header
247         // field and this header field indicates the use of an extension
248         // that was not present in the client's handshake (the server has
249         // indicated an extension not requested by the client), the client
250         // MUST _Fail the WebSocket Connection_. (RFC 6455)
251         // TODO
252 
253         // 6.  If the response includes a |Sec-WebSocket-Protocol| header field
254         // and this header field indicates the use of a subprotocol that was
255         // not present in the client's handshake (the server has indicated a
256         // subprotocol not requested by the client), the client MUST _Fail
257         // the WebSocket Connection_. (RFC 6455)
258         if headers.get("Sec-WebSocket-Protocol").is_none() && self.subprotocols.is_some() {
259             return Err(Error::Protocol(ProtocolError::SecWebSocketSubProtocolError(
260                 SubProtocolError::NoSubProtocol,
261             )));
262         }
263 
264         if headers.get("Sec-WebSocket-Protocol").is_some() && self.subprotocols.is_none() {
265             return Err(Error::Protocol(ProtocolError::SecWebSocketSubProtocolError(
266                 SubProtocolError::ServerSentSubProtocolNoneRequested,
267             )));
268         }
269 
270         if let Some(returned_subprotocol) = headers.get("Sec-WebSocket-Protocol") {
271             if let Some(accepted_subprotocols) = &self.subprotocols {
272                 if !accepted_subprotocols.contains(&returned_subprotocol.to_str()?.to_string()) {
273                     return Err(Error::Protocol(ProtocolError::SecWebSocketSubProtocolError(
274                         SubProtocolError::InvalidSubProtocol,
275                     )));
276                 }
277             }
278         }
279 
280         Ok(response)
281     }
282 }
283 
284 impl TryParse for Response {
try_parse(buf: &[u8]) -> Result<Option<(usize, Self)>>285     fn try_parse(buf: &[u8]) -> Result<Option<(usize, Self)>> {
286         let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS];
287         let mut req = httparse::Response::new(&mut hbuffer);
288         Ok(match req.parse(buf)? {
289             Status::Partial => None,
290             Status::Complete(size) => Some((size, Response::from_httparse(req)?)),
291         })
292     }
293 }
294 
295 impl<'h, 'b: 'h> FromHttparse<httparse::Response<'h, 'b>> for Response {
from_httparse(raw: httparse::Response<'h, 'b>) -> Result<Self>296     fn from_httparse(raw: httparse::Response<'h, 'b>) -> Result<Self> {
297         if raw.version.expect("Bug: no HTTP version") < /*1.*/1 {
298             return Err(Error::Protocol(ProtocolError::WrongHttpVersion));
299         }
300 
301         let headers = HeaderMap::from_httparse(raw.headers)?;
302 
303         let mut response = Response::new(None);
304         *response.status_mut() = StatusCode::from_u16(raw.code.expect("Bug: no HTTP status code"))?;
305         *response.headers_mut() = headers;
306         // TODO: httparse only supports HTTP 0.9/1.0/1.1 but not HTTP 2.0
307         // so the only valid value we could get in the response would be 1.1.
308         *response.version_mut() = http::Version::HTTP_11;
309 
310         Ok(response)
311     }
312 }
313 
314 /// Generate a random key for the `Sec-WebSocket-Key` header.
generate_key() -> String315 pub fn generate_key() -> String {
316     // a base64-encoded (see Section 4 of [RFC4648]) value that,
317     // when decoded, is 16 bytes in length (RFC 6455)
318     let r: [u8; 16] = rand::random();
319     data_encoding::BASE64.encode(&r)
320 }
321 
#[cfg(test)]
mod tests {
    use super::{super::machine::TryParse, generate_key, generate_request, Response};
    use crate::client::IntoClientRequest;

    /// Two generated keys differ and both look like base64 of 16 bytes.
    #[test]
    fn random_keys() {
        let k1 = generate_key();
        println!("Generated random key 1: {k1}");
        let k2 = generate_key();
        println!("Generated random key 2: {k2}");

        assert_ne!(k1, k2);
        assert_eq!(k1.len(), k2.len());
        for key in [&k1, &k2].iter() {
            // 16 bytes encode to 22 significant characters plus two padding characters.
            assert_eq!(key.len(), 24);
            assert!(key.ends_with("=="));
            assert!(key[..22].find('=').is_none());
        }
    }

    /// Builds the exact bytes expected for a request to `/getCaseCount` on `host` using `key`.
    fn construct_expected(host: &str, key: &str) -> Vec<u8> {
        let mut expected = String::from("GET /getCaseCount HTTP/1.1\r\n");
        expected.push_str(&format!("Host: {host}\r\n"));
        expected.push_str("Connection: Upgrade\r\n");
        expected.push_str("Upgrade: websocket\r\n");
        expected.push_str("Sec-WebSocket-Version: 13\r\n");
        expected.push_str(&format!("Sec-WebSocket-Key: {key}\r\n"));
        expected.push_str("\r\n");
        expected.into_bytes()
    }

    #[test]
    fn request_formatting() {
        let client_request = "ws://localhost/getCaseCount".into_client_request().unwrap();
        let (serialized, key) = generate_request(client_request).unwrap();
        let expected = construct_expected("localhost", &key);
        assert_eq!(&serialized[..], &expected[..]);
    }

    #[test]
    fn request_formatting_with_host() {
        let client_request = "wss://localhost:9001/getCaseCount".into_client_request().unwrap();
        let (serialized, key) = generate_request(client_request).unwrap();
        let expected = construct_expected("localhost:9001", &key);
        assert_eq!(&serialized[..], &expected[..]);
    }

    #[test]
    fn request_formatting_with_at() {
        let client_request =
            "wss://user:pass@localhost:9001/getCaseCount".into_client_request().unwrap();
        let (serialized, key) = generate_request(client_request).unwrap();
        let expected = construct_expected("localhost:9001", &key);
        assert_eq!(&serialized[..], &expected[..]);
    }

    #[test]
    fn response_parsing() {
        const DATA: &[u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n";
        let (_consumed, parsed) = Response::try_parse(DATA).unwrap().unwrap();
        assert_eq!(parsed.status(), http::StatusCode::OK);
        assert_eq!(parsed.headers().get("Content-Type").unwrap(), &b"text/html"[..],);
    }

    /// A request without the mandatory WebSocket headers must be rejected.
    #[test]
    fn invalid_custom_request() {
        let bare_request = http::Request::builder().method("GET").body(()).unwrap();
        assert!(generate_request(bare_request).is_err());
    }
}
395