• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //! Methods to connect to a WebSocket as a client.
2 
3 use std::{
4     convert::TryFrom,
5     io::{Read, Write},
6     net::{SocketAddr, TcpStream, ToSocketAddrs},
7     result::Result as StdResult,
8 };
9 
10 use http::{request::Parts, HeaderName, Uri};
11 use log::*;
12 
13 use crate::{
14     handshake::client::{generate_key, Request, Response},
15     protocol::WebSocketConfig,
16     stream::MaybeTlsStream,
17 };
18 
19 use crate::{
20     error::{Error, Result, UrlError},
21     handshake::{client::ClientHandshake, HandshakeError},
22     protocol::WebSocket,
23     stream::{Mode, NoDelay},
24 };
25 
26 /// Connect to the given WebSocket in blocking mode.
27 ///
28 /// Uses a websocket configuration passed as an argument to the function. Calling it with `None` is
29 /// equal to calling `connect()` function.
30 ///
31 /// The URL may be either ws:// or wss://.
32 /// To support wss:// URLs, you must activate the TLS feature on the crate level. Please refer to the
33 /// project's [README][readme] for more information on available features.
34 ///
35 /// This function "just works" for those who wants a simple blocking solution
36 /// similar to `std::net::TcpStream`. If you want a non-blocking or other
37 /// custom stream, call `client` instead.
38 ///
39 /// This function uses `native_tls` or `rustls` to do TLS depending on the feature flags enabled. If
40 /// you want to use other TLS libraries, use `client` instead. There is no need to enable any of
41 /// the `*-tls` features if you don't call `connect` since it's the only function that uses them.
42 ///
43 /// [readme]: https://github.com/snapview/tungstenite-rs/#features
connect_with_config<Req: IntoClientRequest>( request: Req, config: Option<WebSocketConfig>, max_redirects: u8, ) -> Result<(WebSocket<MaybeTlsStream<TcpStream>>, Response)>44 pub fn connect_with_config<Req: IntoClientRequest>(
45     request: Req,
46     config: Option<WebSocketConfig>,
47     max_redirects: u8,
48 ) -> Result<(WebSocket<MaybeTlsStream<TcpStream>>, Response)> {
49     fn try_client_handshake(
50         request: Request,
51         config: Option<WebSocketConfig>,
52     ) -> Result<(WebSocket<MaybeTlsStream<TcpStream>>, Response)> {
53         let uri = request.uri();
54         let mode = uri_mode(uri)?;
55 
56         #[cfg(not(any(feature = "native-tls", feature = "__rustls-tls")))]
57         if let Mode::Tls = mode {
58             return Err(Error::Url(UrlError::TlsFeatureNotEnabled));
59         }
60 
61         let host = request.uri().host().ok_or(Error::Url(UrlError::NoHostName))?;
62         let host = if host.starts_with('[') { &host[1..host.len() - 1] } else { host };
63         let port = uri.port_u16().unwrap_or(match mode {
64             Mode::Plain => 80,
65             Mode::Tls => 443,
66         });
67         let addrs = (host, port).to_socket_addrs()?;
68         let mut stream = connect_to_some(addrs.as_slice(), request.uri())?;
69         NoDelay::set_nodelay(&mut stream, true)?;
70 
71         #[cfg(not(any(feature = "native-tls", feature = "__rustls-tls")))]
72         let client = client_with_config(request, MaybeTlsStream::Plain(stream), config);
73         #[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
74         let client = crate::tls::client_tls_with_config(request, stream, config, None);
75 
76         client.map_err(|e| match e {
77             HandshakeError::Failure(f) => f,
78             HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"),
79         })
80     }
81 
82     fn create_request(parts: &Parts, uri: &Uri) -> Request {
83         let mut builder =
84             Request::builder().uri(uri.clone()).method(parts.method.clone()).version(parts.version);
85         *builder.headers_mut().expect("Failed to create `Request`") = parts.headers.clone();
86         builder.body(()).expect("Failed to create `Request`")
87     }
88 
89     let (parts, _) = request.into_client_request()?.into_parts();
90     let mut uri = parts.uri.clone();
91 
92     for attempt in 0..(max_redirects + 1) {
93         let request = create_request(&parts, &uri);
94 
95         match try_client_handshake(request, config) {
96             Err(Error::Http(res)) if res.status().is_redirection() && attempt < max_redirects => {
97                 if let Some(location) = res.headers().get("Location") {
98                     uri = location.to_str()?.parse::<Uri>()?;
99                     debug!("Redirecting to {:?}", uri);
100                     continue;
101                 } else {
102                     warn!("No `Location` found in redirect");
103                     return Err(Error::Http(res));
104                 }
105             }
106             other => return other,
107         }
108     }
109 
110     unreachable!("Bug in a redirect handling logic")
111 }
112 
113 /// Connect to the given WebSocket in blocking mode.
114 ///
115 /// The URL may be either ws:// or wss://.
116 /// To support wss:// URLs, feature `native-tls` or `rustls-tls` must be turned on.
117 ///
118 /// This function "just works" for those who wants a simple blocking solution
119 /// similar to `std::net::TcpStream`. If you want a non-blocking or other
120 /// custom stream, call `client` instead.
121 ///
122 /// This function uses `native_tls` or `rustls` to do TLS depending on the feature flags enabled. If
123 /// you want to use other TLS libraries, use `client` instead. There is no need to enable any of
124 /// the `*-tls` features if you don't call `connect` since it's the only function that uses them.
connect<Req: IntoClientRequest>( request: Req, ) -> Result<(WebSocket<MaybeTlsStream<TcpStream>>, Response)>125 pub fn connect<Req: IntoClientRequest>(
126     request: Req,
127 ) -> Result<(WebSocket<MaybeTlsStream<TcpStream>>, Response)> {
128     connect_with_config(request, None, 3)
129 }
130 
connect_to_some(addrs: &[SocketAddr], uri: &Uri) -> Result<TcpStream>131 fn connect_to_some(addrs: &[SocketAddr], uri: &Uri) -> Result<TcpStream> {
132     for addr in addrs {
133         debug!("Trying to contact {} at {}...", uri, addr);
134         if let Ok(stream) = TcpStream::connect(addr) {
135             return Ok(stream);
136         }
137     }
138     Err(Error::Url(UrlError::UnableToConnect(uri.to_string())))
139 }
140 
141 /// Get the mode of the given URL.
142 ///
143 /// This function may be used to ease the creation of custom TLS streams
144 /// in non-blocking algorithms or for use with TLS libraries other than `native_tls` or `rustls`.
uri_mode(uri: &Uri) -> Result<Mode>145 pub fn uri_mode(uri: &Uri) -> Result<Mode> {
146     match uri.scheme_str() {
147         Some("ws") => Ok(Mode::Plain),
148         Some("wss") => Ok(Mode::Tls),
149         _ => Err(Error::Url(UrlError::UnsupportedUrlScheme)),
150     }
151 }
152 
153 /// Do the client handshake over the given stream given a web socket configuration. Passing `None`
154 /// as configuration is equal to calling `client()` function.
155 ///
156 /// Use this function if you need a nonblocking handshake support or if you
157 /// want to use a custom stream like `mio::net::TcpStream` or `openssl::ssl::SslStream`.
158 /// Any stream supporting `Read + Write` will do.
client_with_config<Stream, Req>( request: Req, stream: Stream, config: Option<WebSocketConfig>, ) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>> where Stream: Read + Write, Req: IntoClientRequest,159 pub fn client_with_config<Stream, Req>(
160     request: Req,
161     stream: Stream,
162     config: Option<WebSocketConfig>,
163 ) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>>
164 where
165     Stream: Read + Write,
166     Req: IntoClientRequest,
167 {
168     ClientHandshake::start(stream, request.into_client_request()?, config)?.handshake()
169 }
170 
171 /// Do the client handshake over the given stream.
172 ///
173 /// Use this function if you need a nonblocking handshake support or if you
174 /// want to use a custom stream like `mio::net::TcpStream` or `openssl::ssl::SslStream`.
175 /// Any stream supporting `Read + Write` will do.
client<Stream, Req>( request: Req, stream: Stream, ) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>> where Stream: Read + Write, Req: IntoClientRequest,176 pub fn client<Stream, Req>(
177     request: Req,
178     stream: Stream,
179 ) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>>
180 where
181     Stream: Read + Write,
182     Req: IntoClientRequest,
183 {
184     client_with_config(request, stream, None)
185 }
186 
187 /// Trait for converting various types into HTTP requests used for a client connection.
188 ///
189 /// This trait is implemented by default for string slices, strings, `http::Uri` and
190 /// `http::Request<()>`. Note that the implementation for `http::Request<()>` is trivial and will
191 /// simply take your request and pass it as is further without altering any headers or URLs, so
192 /// be aware of this. If you just want to connect to the endpoint with a certain URL, better pass
193 /// a regular string containing the URL in which case `tungstenite-rs` will take care for generating
194 /// the proper `http::Request<()>` for you.
195 pub trait IntoClientRequest {
196     /// Convert into a `Request` that can be used for a client connection.
into_client_request(self) -> Result<Request>197     fn into_client_request(self) -> Result<Request>;
198 }
199 
200 impl<'a> IntoClientRequest for &'a str {
into_client_request(self) -> Result<Request>201     fn into_client_request(self) -> Result<Request> {
202         self.parse::<Uri>()?.into_client_request()
203     }
204 }
205 
206 impl<'a> IntoClientRequest for &'a String {
into_client_request(self) -> Result<Request>207     fn into_client_request(self) -> Result<Request> {
208         <&str as IntoClientRequest>::into_client_request(self)
209     }
210 }
211 
212 impl IntoClientRequest for String {
into_client_request(self) -> Result<Request>213     fn into_client_request(self) -> Result<Request> {
214         <&str as IntoClientRequest>::into_client_request(&self)
215     }
216 }
217 
218 impl<'a> IntoClientRequest for &'a Uri {
into_client_request(self) -> Result<Request>219     fn into_client_request(self) -> Result<Request> {
220         self.clone().into_client_request()
221     }
222 }
223 
224 impl IntoClientRequest for Uri {
into_client_request(self) -> Result<Request>225     fn into_client_request(self) -> Result<Request> {
226         let authority = self.authority().ok_or(Error::Url(UrlError::NoHostName))?.as_str();
227         let host = authority
228             .find('@')
229             .map(|idx| authority.split_at(idx + 1).1)
230             .unwrap_or_else(|| authority);
231 
232         if host.is_empty() {
233             return Err(Error::Url(UrlError::EmptyHostName));
234         }
235 
236         let req = Request::builder()
237             .method("GET")
238             .header("Host", host)
239             .header("Connection", "Upgrade")
240             .header("Upgrade", "websocket")
241             .header("Sec-WebSocket-Version", "13")
242             .header("Sec-WebSocket-Key", generate_key())
243             .uri(self)
244             .body(())?;
245         Ok(req)
246     }
247 }
248 
#[cfg(feature = "url")]
impl<'a> IntoClientRequest for &'a url::Url {
    /// Converts through the URL's string form.
    fn into_client_request(self) -> Result<Request> {
        let text: &str = self.as_str();
        text.into_client_request()
    }
}

#[cfg(feature = "url")]
impl IntoClientRequest for url::Url {
    /// Delegates to the borrowed `&url::Url` implementation.
    fn into_client_request(self) -> Result<Request> {
        <&url::Url as IntoClientRequest>::into_client_request(&self)
    }
}
262 
263 impl IntoClientRequest for Request {
into_client_request(self) -> Result<Request>264     fn into_client_request(self) -> Result<Request> {
265         Ok(self)
266     }
267 }
268 
269 impl<'h, 'b> IntoClientRequest for httparse::Request<'h, 'b> {
into_client_request(self) -> Result<Request>270     fn into_client_request(self) -> Result<Request> {
271         use crate::handshake::headers::FromHttparse;
272         Request::from_httparse(self)
273     }
274 }
275 
276 /// Builder for a custom [`IntoClientRequest`] with options to add
277 /// custom additional headers and sub protocols.
278 ///
279 /// # Example
280 ///
281 /// ```rust no_run
282 /// # use crate::*;
283 /// use http::Uri;
284 /// use tungstenite::{connect, ClientRequestBuilder};
285 ///
286 /// let uri: Uri = "ws://localhost:3012/socket".parse().unwrap();
287 /// let token = "my_jwt_token";
288 /// let builder = ClientRequestBuilder::new(uri)
289 ///     .with_header("Authorization", format!("Bearer {token}"))
290 ///     .with_sub_protocol("my_sub_protocol");
291 /// let socket = connect(builder).unwrap();
292 /// ```
293 #[derive(Debug, Clone)]
294 pub struct ClientRequestBuilder {
295     uri: Uri,
296     /// Additional [`Request`] handshake headers
297     additional_headers: Vec<(String, String)>,
298     /// Handsake subprotocols
299     subprotocols: Vec<String>,
300 }
301 
302 impl ClientRequestBuilder {
303     /// Initializes an empty request builder
304     #[must_use]
new(uri: Uri) -> Self305     pub const fn new(uri: Uri) -> Self {
306         Self { uri, additional_headers: Vec::new(), subprotocols: Vec::new() }
307     }
308 
309     /// Adds (`key`, `value`) as an additional header to the handshake request
with_header<K, V>(mut self, key: K, value: V) -> Self where K: Into<String>, V: Into<String>,310     pub fn with_header<K, V>(mut self, key: K, value: V) -> Self
311     where
312         K: Into<String>,
313         V: Into<String>,
314     {
315         self.additional_headers.push((key.into(), value.into()));
316         self
317     }
318 
319     /// Adds `protocol` to the handshake request subprotocols (`Sec-WebSocket-Protocol`)
with_sub_protocol<P>(mut self, protocol: P) -> Self where P: Into<String>,320     pub fn with_sub_protocol<P>(mut self, protocol: P) -> Self
321     where
322         P: Into<String>,
323     {
324         self.subprotocols.push(protocol.into());
325         self
326     }
327 }
328 
329 impl IntoClientRequest for ClientRequestBuilder {
into_client_request(self) -> Result<Request>330     fn into_client_request(self) -> Result<Request> {
331         let mut request = self.uri.into_client_request()?;
332         let headers = request.headers_mut();
333         for (k, v) in self.additional_headers {
334             let key = HeaderName::try_from(k)?;
335             let value = v.parse()?;
336             headers.append(key, value);
337         }
338         if !self.subprotocols.is_empty() {
339             let protocols = self.subprotocols.join(", ").parse()?;
340             headers.append("Sec-WebSocket-Protocol", protocols);
341         }
342         Ok(request)
343     }
344 }
345