1 //! WebSocket handshake machine. 2 3 use bytes::Buf; 4 use log::*; 5 use std::io::{Cursor, Read, Write}; 6 7 use crate::{ 8 error::{Error, ProtocolError, Result}, 9 util::NonBlockingResult, 10 ReadBuffer, 11 }; 12 13 /// A generic handshake state machine. 14 #[derive(Debug)] 15 pub struct HandshakeMachine<Stream> { 16 stream: Stream, 17 state: HandshakeState, 18 } 19 20 impl<Stream> HandshakeMachine<Stream> { 21 /// Start reading data from the peer. start_read(stream: Stream) -> Self22 pub fn start_read(stream: Stream) -> Self { 23 Self { stream, state: HandshakeState::Reading(ReadBuffer::new(), AttackCheck::new()) } 24 } 25 /// Start writing data to the peer. start_write<D: Into<Vec<u8>>>(stream: Stream, data: D) -> Self26 pub fn start_write<D: Into<Vec<u8>>>(stream: Stream, data: D) -> Self { 27 HandshakeMachine { stream, state: HandshakeState::Writing(Cursor::new(data.into())) } 28 } 29 /// Returns a shared reference to the inner stream. get_ref(&self) -> &Stream30 pub fn get_ref(&self) -> &Stream { 31 &self.stream 32 } 33 /// Returns a mutable reference to the inner stream. get_mut(&mut self) -> &mut Stream34 pub fn get_mut(&mut self) -> &mut Stream { 35 &mut self.stream 36 } 37 } 38 39 impl<Stream: Read + Write> HandshakeMachine<Stream> { 40 /// Perform a single handshake round. single_round<Obj: TryParse>(mut self) -> Result<RoundResult<Obj, Stream>>41 pub fn single_round<Obj: TryParse>(mut self) -> Result<RoundResult<Obj, Stream>> { 42 trace!("Doing handshake round."); 43 match self.state { 44 HandshakeState::Reading(mut buf, mut attack_check) => { 45 let read = buf.read_from(&mut self.stream).no_block()?; 46 match read { 47 Some(0) => Err(Error::Protocol(ProtocolError::HandshakeIncomplete)), 48 Some(count) => { 49 attack_check.check_incoming_packet_size(count)?; 50 // TODO: this is slow for big headers with too many small packets. 51 // The parser has to be reworked in order to work on streams instead 52 // of buffers. 53 Ok(if let Some((size, obj)) = Obj::try_parse(Buf::chunk(&buf))? { 54 buf.advance(size); 55 RoundResult::StageFinished(StageResult::DoneReading { 56 result: obj, 57 stream: self.stream, 58 tail: buf.into_vec(), 59 }) 60 } else { 61 RoundResult::Incomplete(HandshakeMachine { 62 state: HandshakeState::Reading(buf, attack_check), 63 ..self 64 }) 65 }) 66 } 67 None => Ok(RoundResult::WouldBlock(HandshakeMachine { 68 state: HandshakeState::Reading(buf, attack_check), 69 ..self 70 })), 71 } 72 } 73 HandshakeState::Writing(mut buf) => { 74 assert!(buf.has_remaining()); 75 if let Some(size) = self.stream.write(Buf::chunk(&buf)).no_block()? { 76 assert!(size > 0); 77 buf.advance(size); 78 Ok(if buf.has_remaining() { 79 RoundResult::Incomplete(HandshakeMachine { 80 state: HandshakeState::Writing(buf), 81 ..self 82 }) 83 } else { 84 RoundResult::Incomplete(HandshakeMachine { 85 state: HandshakeState::Flushing, 86 ..self 87 }) 88 }) 89 } else { 90 Ok(RoundResult::WouldBlock(HandshakeMachine { 91 state: HandshakeState::Writing(buf), 92 ..self 93 })) 94 } 95 } 96 HandshakeState::Flushing => Ok(match self.stream.flush().no_block()? { 97 Some(()) => RoundResult::StageFinished(StageResult::DoneWriting(self.stream)), 98 None => RoundResult::WouldBlock(HandshakeMachine { 99 state: HandshakeState::Flushing, 100 ..self 101 }), 102 }), 103 } 104 } 105 } 106 107 /// The result of the round. 108 #[derive(Debug)] 109 pub enum RoundResult<Obj, Stream> { 110 /// Round not done, I/O would block. 111 WouldBlock(HandshakeMachine<Stream>), 112 /// Round done, state unchanged. 113 Incomplete(HandshakeMachine<Stream>), 114 /// Stage complete. 115 StageFinished(StageResult<Obj, Stream>), 116 } 117 118 /// The result of the stage. 119 #[derive(Debug)] 120 pub enum StageResult<Obj, Stream> { 121 /// Reading round finished. 122 #[allow(missing_docs)] 123 DoneReading { result: Obj, stream: Stream, tail: Vec<u8> }, 124 /// Writing round finished. 125 DoneWriting(Stream), 126 } 127 128 /// The parseable object. 129 pub trait TryParse: Sized { 130 /// Return Ok(None) if incomplete, Err on syntax error. try_parse(data: &[u8]) -> Result<Option<(usize, Self)>>131 fn try_parse(data: &[u8]) -> Result<Option<(usize, Self)>>; 132 } 133 134 /// The handshake state. 135 #[derive(Debug)] 136 enum HandshakeState { 137 /// Reading data from the peer. 138 Reading(ReadBuffer, AttackCheck), 139 /// Sending data to the peer. 140 Writing(Cursor<Vec<u8>>), 141 /// Flushing data to ensure that all intermediately buffered contents reach their destination. 142 Flushing, 143 } 144 145 /// Attack mitigation. Contains counters needed to prevent DoS attacks 146 /// and reject valid but useless headers. 147 #[derive(Debug)] 148 pub(crate) struct AttackCheck { 149 /// Number of HTTP header successful reads (TCP packets). 150 number_of_packets: usize, 151 /// Total number of bytes in HTTP header. 152 number_of_bytes: usize, 153 } 154 155 impl AttackCheck { 156 /// Initialize attack checking for incoming buffer. new() -> Self157 fn new() -> Self { 158 Self { number_of_packets: 0, number_of_bytes: 0 } 159 } 160 161 /// Check the size of an incoming packet. To be called immediately after `read()` 162 /// passing its returned bytes count as `size`. check_incoming_packet_size(&mut self, size: usize) -> Result<()>163 fn check_incoming_packet_size(&mut self, size: usize) -> Result<()> { 164 self.number_of_packets += 1; 165 self.number_of_bytes += size; 166 167 // TODO: these values are hardcoded. Instead of making them configurable, 168 // rework the way HTTP header is parsed to remove this check at all. 169 const MAX_BYTES: usize = 65536; 170 const MAX_PACKETS: usize = 512; 171 const MIN_PACKET_SIZE: usize = 128; 172 const MIN_PACKET_CHECK_THRESHOLD: usize = 64; 173 174 if self.number_of_bytes > MAX_BYTES { 175 return Err(Error::AttackAttempt); 176 } 177 178 if self.number_of_packets > MAX_PACKETS { 179 return Err(Error::AttackAttempt); 180 } 181 182 if self.number_of_packets > MIN_PACKET_CHECK_THRESHOLD 183 && self.number_of_packets * MIN_PACKET_SIZE > self.number_of_bytes 184 { 185 return Err(Error::AttackAttempt); 186 } 187 188 Ok(()) 189 } 190 } 191