• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //! WebSocket handshake machine.
2 
3 use bytes::Buf;
4 use log::*;
5 use std::io::{Cursor, Read, Write};
6 
7 use crate::{
8     error::{Error, ProtocolError, Result},
9     util::NonBlockingResult,
10     ReadBuffer,
11 };
12 
13 /// A generic handshake state machine.
14 #[derive(Debug)]
15 pub struct HandshakeMachine<Stream> {
16     stream: Stream,
17     state: HandshakeState,
18 }
19 
20 impl<Stream> HandshakeMachine<Stream> {
21     /// Start reading data from the peer.
start_read(stream: Stream) -> Self22     pub fn start_read(stream: Stream) -> Self {
23         Self { stream, state: HandshakeState::Reading(ReadBuffer::new(), AttackCheck::new()) }
24     }
25     /// Start writing data to the peer.
start_write<D: Into<Vec<u8>>>(stream: Stream, data: D) -> Self26     pub fn start_write<D: Into<Vec<u8>>>(stream: Stream, data: D) -> Self {
27         HandshakeMachine { stream, state: HandshakeState::Writing(Cursor::new(data.into())) }
28     }
29     /// Returns a shared reference to the inner stream.
get_ref(&self) -> &Stream30     pub fn get_ref(&self) -> &Stream {
31         &self.stream
32     }
33     /// Returns a mutable reference to the inner stream.
get_mut(&mut self) -> &mut Stream34     pub fn get_mut(&mut self) -> &mut Stream {
35         &mut self.stream
36     }
37 }
38 
39 impl<Stream: Read + Write> HandshakeMachine<Stream> {
40     /// Perform a single handshake round.
single_round<Obj: TryParse>(mut self) -> Result<RoundResult<Obj, Stream>>41     pub fn single_round<Obj: TryParse>(mut self) -> Result<RoundResult<Obj, Stream>> {
42         trace!("Doing handshake round.");
43         match self.state {
44             HandshakeState::Reading(mut buf, mut attack_check) => {
45                 let read = buf.read_from(&mut self.stream).no_block()?;
46                 match read {
47                     Some(0) => Err(Error::Protocol(ProtocolError::HandshakeIncomplete)),
48                     Some(count) => {
49                         attack_check.check_incoming_packet_size(count)?;
50                         // TODO: this is slow for big headers with too many small packets.
51                         // The parser has to be reworked in order to work on streams instead
52                         // of buffers.
53                         Ok(if let Some((size, obj)) = Obj::try_parse(Buf::chunk(&buf))? {
54                             buf.advance(size);
55                             RoundResult::StageFinished(StageResult::DoneReading {
56                                 result: obj,
57                                 stream: self.stream,
58                                 tail: buf.into_vec(),
59                             })
60                         } else {
61                             RoundResult::Incomplete(HandshakeMachine {
62                                 state: HandshakeState::Reading(buf, attack_check),
63                                 ..self
64                             })
65                         })
66                     }
67                     None => Ok(RoundResult::WouldBlock(HandshakeMachine {
68                         state: HandshakeState::Reading(buf, attack_check),
69                         ..self
70                     })),
71                 }
72             }
73             HandshakeState::Writing(mut buf) => {
74                 assert!(buf.has_remaining());
75                 if let Some(size) = self.stream.write(Buf::chunk(&buf)).no_block()? {
76                     assert!(size > 0);
77                     buf.advance(size);
78                     Ok(if buf.has_remaining() {
79                         RoundResult::Incomplete(HandshakeMachine {
80                             state: HandshakeState::Writing(buf),
81                             ..self
82                         })
83                     } else {
84                         RoundResult::Incomplete(HandshakeMachine {
85                             state: HandshakeState::Flushing,
86                             ..self
87                         })
88                     })
89                 } else {
90                     Ok(RoundResult::WouldBlock(HandshakeMachine {
91                         state: HandshakeState::Writing(buf),
92                         ..self
93                     }))
94                 }
95             }
96             HandshakeState::Flushing => Ok(match self.stream.flush().no_block()? {
97                 Some(()) => RoundResult::StageFinished(StageResult::DoneWriting(self.stream)),
98                 None => RoundResult::WouldBlock(HandshakeMachine {
99                     state: HandshakeState::Flushing,
100                     ..self
101                 }),
102             }),
103         }
104     }
105 }
106 
107 /// The result of the round.
108 #[derive(Debug)]
109 pub enum RoundResult<Obj, Stream> {
110     /// Round not done, I/O would block.
111     WouldBlock(HandshakeMachine<Stream>),
112     /// Round done, state unchanged.
113     Incomplete(HandshakeMachine<Stream>),
114     /// Stage complete.
115     StageFinished(StageResult<Obj, Stream>),
116 }
117 
/// The result of a finished handshake stage.
#[derive(Debug)]
pub enum StageResult<Obj, Stream> {
    /// Reading stage finished: an object was parsed from the incoming data.
    DoneReading {
        /// The parsed object.
        result: Obj,
        /// The underlying stream, handed back to the caller.
        stream: Stream,
        /// The read buffer's contents after the parsed object's bytes were
        /// consumed (presumably any bytes that arrived after the object).
        tail: Vec<u8>,
    },
    /// Writing stage finished: all data was written and flushed; holds the
    /// underlying stream.
    DoneWriting(Stream),
}
127 
128 /// The parseable object.
129 pub trait TryParse: Sized {
130     /// Return Ok(None) if incomplete, Err on syntax error.
try_parse(data: &[u8]) -> Result<Option<(usize, Self)>>131     fn try_parse(data: &[u8]) -> Result<Option<(usize, Self)>>;
132 }
133 
134 /// The handshake state.
135 #[derive(Debug)]
136 enum HandshakeState {
137     /// Reading data from the peer.
138     Reading(ReadBuffer, AttackCheck),
139     /// Sending data to the peer.
140     Writing(Cursor<Vec<u8>>),
141     /// Flushing data to ensure that all intermediately buffered contents reach their destination.
142     Flushing,
143 }
144 
145 /// Attack mitigation. Contains counters needed to prevent DoS attacks
146 /// and reject valid but useless headers.
147 #[derive(Debug)]
148 pub(crate) struct AttackCheck {
149     /// Number of HTTP header successful reads (TCP packets).
150     number_of_packets: usize,
151     /// Total number of bytes in HTTP header.
152     number_of_bytes: usize,
153 }
154 
155 impl AttackCheck {
156     /// Initialize attack checking for incoming buffer.
new() -> Self157     fn new() -> Self {
158         Self { number_of_packets: 0, number_of_bytes: 0 }
159     }
160 
161     /// Check the size of an incoming packet. To be called immediately after `read()`
162     /// passing its returned bytes count as `size`.
check_incoming_packet_size(&mut self, size: usize) -> Result<()>163     fn check_incoming_packet_size(&mut self, size: usize) -> Result<()> {
164         self.number_of_packets += 1;
165         self.number_of_bytes += size;
166 
167         // TODO: these values are hardcoded. Instead of making them configurable,
168         // rework the way HTTP header is parsed to remove this check at all.
169         const MAX_BYTES: usize = 65536;
170         const MAX_PACKETS: usize = 512;
171         const MIN_PACKET_SIZE: usize = 128;
172         const MIN_PACKET_CHECK_THRESHOLD: usize = 64;
173 
174         if self.number_of_bytes > MAX_BYTES {
175             return Err(Error::AttackAttempt);
176         }
177 
178         if self.number_of_packets > MAX_PACKETS {
179             return Err(Error::AttackAttempt);
180         }
181 
182         if self.number_of_packets > MIN_PACKET_CHECK_THRESHOLD
183             && self.number_of_packets * MIN_PACKET_SIZE > self.number_of_bytes
184         {
185             return Err(Error::AttackAttempt);
186         }
187 
188         Ok(())
189     }
190 }
191