1 use crate::codec::{Decoder, Encoder}; 2 3 use futures_core::Stream; 4 use tokio::{io::ReadBuf, net::UdpSocket}; 5 6 use bytes::{BufMut, BytesMut}; 7 use futures_core::ready; 8 use futures_sink::Sink; 9 use std::pin::Pin; 10 use std::task::{Context, Poll}; 11 use std::{ 12 borrow::Borrow, 13 net::{Ipv4Addr, SocketAddr, SocketAddrV4}, 14 }; 15 use std::{io, mem::MaybeUninit}; 16 17 /// A unified [`Stream`] and [`Sink`] interface to an underlying `UdpSocket`, using 18 /// the `Encoder` and `Decoder` traits to encode and decode frames. 19 /// 20 /// Raw UDP sockets work with datagrams, but higher-level code usually wants to 21 /// batch these into meaningful chunks, called "frames". This method layers 22 /// framing on top of this socket by using the `Encoder` and `Decoder` traits to 23 /// handle encoding and decoding of messages frames. Note that the incoming and 24 /// outgoing frame types may be distinct. 25 /// 26 /// This function returns a *single* object that is both [`Stream`] and [`Sink`]; 27 /// grouping this into a single object is often useful for layering things which 28 /// require both read and write access to the underlying object. 29 /// 30 /// If you want to work more directly with the streams and sink, consider 31 /// calling [`split`] on the `UdpFramed` returned by this method, which will break 32 /// them into separate objects, allowing them to interact more easily. 33 /// 34 /// [`Stream`]: futures_core::Stream 35 /// [`Sink`]: futures_sink::Sink 36 /// [`split`]: https://docs.rs/futures/0.3/futures/stream/trait.StreamExt.html#method.split 37 #[must_use = "sinks do nothing unless polled"] 38 #[derive(Debug)] 39 pub struct UdpFramed<C, T = UdpSocket> { 40 socket: T, 41 codec: C, 42 rd: BytesMut, 43 wr: BytesMut, 44 out_addr: SocketAddr, 45 flushed: bool, 46 is_readable: bool, 47 current_addr: Option<SocketAddr>, 48 } 49 50 const INITIAL_RD_CAPACITY: usize = 64 * 1024; 51 const INITIAL_WR_CAPACITY: usize = 8 * 1024; 52 53 impl<C, T> Unpin for UdpFramed<C, T> {} 54 55 impl<C, T> Stream for UdpFramed<C, T> 56 where 57 T: Borrow<UdpSocket>, 58 C: Decoder, 59 { 60 type Item = Result<(C::Item, SocketAddr), C::Error>; 61 poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>62 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { 63 let pin = self.get_mut(); 64 65 pin.rd.reserve(INITIAL_RD_CAPACITY); 66 67 loop { 68 // Are there still bytes left in the read buffer to decode? 69 if pin.is_readable { 70 if let Some(frame) = pin.codec.decode_eof(&mut pin.rd)? { 71 let current_addr = pin 72 .current_addr 73 .expect("will always be set before this line is called"); 74 75 return Poll::Ready(Some(Ok((frame, current_addr)))); 76 } 77 78 // if this line has been reached then decode has returned `None`. 79 pin.is_readable = false; 80 pin.rd.clear(); 81 } 82 83 // We're out of data. Try and fetch more data to decode 84 let addr = { 85 // Safety: `chunk_mut()` returns a `&mut UninitSlice`, and `UninitSlice` is a 86 // transparent wrapper around `[MaybeUninit<u8>]`. 87 let buf = unsafe { &mut *(pin.rd.chunk_mut() as *mut _ as *mut [MaybeUninit<u8>]) }; 88 let mut read = ReadBuf::uninit(buf); 89 let ptr = read.filled().as_ptr(); 90 let res = ready!(pin.socket.borrow().poll_recv_from(cx, &mut read)); 91 92 assert_eq!(ptr, read.filled().as_ptr()); 93 let addr = res?; 94 95 // Safety: This is guaranteed to be the number of initialized (and read) bytes due 96 // to the invariants provided by `ReadBuf::filled`. 97 unsafe { pin.rd.advance_mut(read.filled().len()) }; 98 99 addr 100 }; 101 102 pin.current_addr = Some(addr); 103 pin.is_readable = true; 104 } 105 } 106 } 107 108 impl<I, C, T> Sink<(I, SocketAddr)> for UdpFramed<C, T> 109 where 110 T: Borrow<UdpSocket>, 111 C: Encoder<I>, 112 { 113 type Error = C::Error; 114 poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>115 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { 116 if !self.flushed { 117 match self.poll_flush(cx)? { 118 Poll::Ready(()) => {} 119 Poll::Pending => return Poll::Pending, 120 } 121 } 122 123 Poll::Ready(Ok(())) 124 } 125 start_send(self: Pin<&mut Self>, item: (I, SocketAddr)) -> Result<(), Self::Error>126 fn start_send(self: Pin<&mut Self>, item: (I, SocketAddr)) -> Result<(), Self::Error> { 127 let (frame, out_addr) = item; 128 129 let pin = self.get_mut(); 130 131 pin.codec.encode(frame, &mut pin.wr)?; 132 pin.out_addr = out_addr; 133 pin.flushed = false; 134 135 Ok(()) 136 } 137 poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>138 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { 139 if self.flushed { 140 return Poll::Ready(Ok(())); 141 } 142 143 let Self { 144 ref socket, 145 ref mut out_addr, 146 ref mut wr, 147 .. 148 } = *self; 149 150 let n = ready!(socket.borrow().poll_send_to(cx, wr, *out_addr))?; 151 152 let wrote_all = n == self.wr.len(); 153 self.wr.clear(); 154 self.flushed = true; 155 156 let res = if wrote_all { 157 Ok(()) 158 } else { 159 Err(io::Error::new( 160 io::ErrorKind::Other, 161 "failed to write entire datagram to socket", 162 ) 163 .into()) 164 }; 165 166 Poll::Ready(res) 167 } 168 poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>169 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { 170 ready!(self.poll_flush(cx))?; 171 Poll::Ready(Ok(())) 172 } 173 } 174 175 impl<C, T> UdpFramed<C, T> 176 where 177 T: Borrow<UdpSocket>, 178 { 179 /// Create a new `UdpFramed` backed by the given socket and codec. 180 /// 181 /// See struct level documentation for more details. new(socket: T, codec: C) -> UdpFramed<C, T>182 pub fn new(socket: T, codec: C) -> UdpFramed<C, T> { 183 Self { 184 socket, 185 codec, 186 out_addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0)), 187 rd: BytesMut::with_capacity(INITIAL_RD_CAPACITY), 188 wr: BytesMut::with_capacity(INITIAL_WR_CAPACITY), 189 flushed: true, 190 is_readable: false, 191 current_addr: None, 192 } 193 } 194 195 /// Returns a reference to the underlying I/O stream wrapped by `Framed`. 196 /// 197 /// # Note 198 /// 199 /// Care should be taken to not tamper with the underlying stream of data 200 /// coming in as it may corrupt the stream of frames otherwise being worked 201 /// with. get_ref(&self) -> &T202 pub fn get_ref(&self) -> &T { 203 &self.socket 204 } 205 206 /// Returns a mutable reference to the underlying I/O stream wrapped by `Framed`. 207 /// 208 /// # Note 209 /// 210 /// Care should be taken to not tamper with the underlying stream of data 211 /// coming in as it may corrupt the stream of frames otherwise being worked 212 /// with. get_mut(&mut self) -> &mut T213 pub fn get_mut(&mut self) -> &mut T { 214 &mut self.socket 215 } 216 217 /// Returns a reference to the underlying codec wrapped by 218 /// `Framed`. 219 /// 220 /// Note that care should be taken to not tamper with the underlying codec 221 /// as it may corrupt the stream of frames otherwise being worked with. codec(&self) -> &C222 pub fn codec(&self) -> &C { 223 &self.codec 224 } 225 226 /// Returns a mutable reference to the underlying codec wrapped by 227 /// `UdpFramed`. 228 /// 229 /// Note that care should be taken to not tamper with the underlying codec 230 /// as it may corrupt the stream of frames otherwise being worked with. codec_mut(&mut self) -> &mut C231 pub fn codec_mut(&mut self) -> &mut C { 232 &mut self.codec 233 } 234 235 /// Returns a reference to the read buffer. read_buffer(&self) -> &BytesMut236 pub fn read_buffer(&self) -> &BytesMut { 237 &self.rd 238 } 239 240 /// Returns a mutable reference to the read buffer. read_buffer_mut(&mut self) -> &mut BytesMut241 pub fn read_buffer_mut(&mut self) -> &mut BytesMut { 242 &mut self.rd 243 } 244 245 /// Consumes the `Framed`, returning its underlying I/O stream. into_inner(self) -> T246 pub fn into_inner(self) -> T { 247 self.socket 248 } 249 } 250