• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use crate::codec::{Decoder, Encoder};
2 
3 use futures_core::Stream;
4 use tokio::{io::ReadBuf, net::UdpSocket};
5 
6 use bytes::{BufMut, BytesMut};
7 use futures_core::ready;
8 use futures_sink::Sink;
9 use std::pin::Pin;
10 use std::task::{Context, Poll};
11 use std::{
12     borrow::Borrow,
13     net::{Ipv4Addr, SocketAddr, SocketAddrV4},
14 };
15 use std::{io, mem::MaybeUninit};
16 
17 /// A unified [`Stream`] and [`Sink`] interface to an underlying `UdpSocket`, using
18 /// the `Encoder` and `Decoder` traits to encode and decode frames.
19 ///
20 /// Raw UDP sockets work with datagrams, but higher-level code usually wants to
21 /// batch these into meaningful chunks, called "frames". This method layers
22 /// framing on top of this socket by using the `Encoder` and `Decoder` traits to
23 /// handle encoding and decoding of messages frames. Note that the incoming and
24 /// outgoing frame types may be distinct.
25 ///
26 /// This function returns a *single* object that is both [`Stream`] and [`Sink`];
27 /// grouping this into a single object is often useful for layering things which
28 /// require both read and write access to the underlying object.
29 ///
30 /// If you want to work more directly with the streams and sink, consider
31 /// calling [`split`] on the `UdpFramed` returned by this method, which will break
32 /// them into separate objects, allowing them to interact more easily.
33 ///
34 /// [`Stream`]: futures_core::Stream
35 /// [`Sink`]: futures_sink::Sink
36 /// [`split`]: https://docs.rs/futures/0.3/futures/stream/trait.StreamExt.html#method.split
37 #[must_use = "sinks do nothing unless polled"]
38 #[derive(Debug)]
39 pub struct UdpFramed<C, T = UdpSocket> {
40     socket: T,
41     codec: C,
42     rd: BytesMut,
43     wr: BytesMut,
44     out_addr: SocketAddr,
45     flushed: bool,
46     is_readable: bool,
47     current_addr: Option<SocketAddr>,
48 }
49 
/// Initial capacity of the read buffer (64 KiB); `poll_next` also reserves this much room.
const INITIAL_RD_CAPACITY: usize = 1 << 16;
/// Initial capacity of the write buffer (8 KiB).
const INITIAL_WR_CAPACITY: usize = 1 << 13;
52 
53 impl<C, T> Unpin for UdpFramed<C, T> {}
54 
55 impl<C, T> Stream for UdpFramed<C, T>
56 where
57     T: Borrow<UdpSocket>,
58     C: Decoder,
59 {
60     type Item = Result<(C::Item, SocketAddr), C::Error>;
61 
poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>62     fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
63         let pin = self.get_mut();
64 
65         pin.rd.reserve(INITIAL_RD_CAPACITY);
66 
67         loop {
68             // Are there still bytes left in the read buffer to decode?
69             if pin.is_readable {
70                 if let Some(frame) = pin.codec.decode_eof(&mut pin.rd)? {
71                     let current_addr = pin
72                         .current_addr
73                         .expect("will always be set before this line is called");
74 
75                     return Poll::Ready(Some(Ok((frame, current_addr))));
76                 }
77 
78                 // if this line has been reached then decode has returned `None`.
79                 pin.is_readable = false;
80                 pin.rd.clear();
81             }
82 
83             // We're out of data. Try and fetch more data to decode
84             let addr = {
85                 // Safety: `chunk_mut()` returns a `&mut UninitSlice`, and `UninitSlice` is a
86                 // transparent wrapper around `[MaybeUninit<u8>]`.
87                 let buf = unsafe { &mut *(pin.rd.chunk_mut() as *mut _ as *mut [MaybeUninit<u8>]) };
88                 let mut read = ReadBuf::uninit(buf);
89                 let ptr = read.filled().as_ptr();
90                 let res = ready!(pin.socket.borrow().poll_recv_from(cx, &mut read));
91 
92                 assert_eq!(ptr, read.filled().as_ptr());
93                 let addr = res?;
94 
95                 // Safety: This is guaranteed to be the number of initialized (and read) bytes due
96                 // to the invariants provided by `ReadBuf::filled`.
97                 unsafe { pin.rd.advance_mut(read.filled().len()) };
98 
99                 addr
100             };
101 
102             pin.current_addr = Some(addr);
103             pin.is_readable = true;
104         }
105     }
106 }
107 
108 impl<I, C, T> Sink<(I, SocketAddr)> for UdpFramed<C, T>
109 where
110     T: Borrow<UdpSocket>,
111     C: Encoder<I>,
112 {
113     type Error = C::Error;
114 
poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>115     fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
116         if !self.flushed {
117             match self.poll_flush(cx)? {
118                 Poll::Ready(()) => {}
119                 Poll::Pending => return Poll::Pending,
120             }
121         }
122 
123         Poll::Ready(Ok(()))
124     }
125 
start_send(self: Pin<&mut Self>, item: (I, SocketAddr)) -> Result<(), Self::Error>126     fn start_send(self: Pin<&mut Self>, item: (I, SocketAddr)) -> Result<(), Self::Error> {
127         let (frame, out_addr) = item;
128 
129         let pin = self.get_mut();
130 
131         pin.codec.encode(frame, &mut pin.wr)?;
132         pin.out_addr = out_addr;
133         pin.flushed = false;
134 
135         Ok(())
136     }
137 
poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>138     fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
139         if self.flushed {
140             return Poll::Ready(Ok(()));
141         }
142 
143         let Self {
144             ref socket,
145             ref mut out_addr,
146             ref mut wr,
147             ..
148         } = *self;
149 
150         let n = ready!(socket.borrow().poll_send_to(cx, wr, *out_addr))?;
151 
152         let wrote_all = n == self.wr.len();
153         self.wr.clear();
154         self.flushed = true;
155 
156         let res = if wrote_all {
157             Ok(())
158         } else {
159             Err(io::Error::new(
160                 io::ErrorKind::Other,
161                 "failed to write entire datagram to socket",
162             )
163             .into())
164         };
165 
166         Poll::Ready(res)
167     }
168 
poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>169     fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
170         ready!(self.poll_flush(cx))?;
171         Poll::Ready(Ok(()))
172     }
173 }
174 
175 impl<C, T> UdpFramed<C, T>
176 where
177     T: Borrow<UdpSocket>,
178 {
179     /// Create a new `UdpFramed` backed by the given socket and codec.
180     ///
181     /// See struct level documentation for more details.
new(socket: T, codec: C) -> UdpFramed<C, T>182     pub fn new(socket: T, codec: C) -> UdpFramed<C, T> {
183         Self {
184             socket,
185             codec,
186             out_addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0)),
187             rd: BytesMut::with_capacity(INITIAL_RD_CAPACITY),
188             wr: BytesMut::with_capacity(INITIAL_WR_CAPACITY),
189             flushed: true,
190             is_readable: false,
191             current_addr: None,
192         }
193     }
194 
195     /// Returns a reference to the underlying I/O stream wrapped by `Framed`.
196     ///
197     /// # Note
198     ///
199     /// Care should be taken to not tamper with the underlying stream of data
200     /// coming in as it may corrupt the stream of frames otherwise being worked
201     /// with.
get_ref(&self) -> &T202     pub fn get_ref(&self) -> &T {
203         &self.socket
204     }
205 
206     /// Returns a mutable reference to the underlying I/O stream wrapped by `Framed`.
207     ///
208     /// # Note
209     ///
210     /// Care should be taken to not tamper with the underlying stream of data
211     /// coming in as it may corrupt the stream of frames otherwise being worked
212     /// with.
get_mut(&mut self) -> &mut T213     pub fn get_mut(&mut self) -> &mut T {
214         &mut self.socket
215     }
216 
217     /// Returns a reference to the underlying codec wrapped by
218     /// `Framed`.
219     ///
220     /// Note that care should be taken to not tamper with the underlying codec
221     /// as it may corrupt the stream of frames otherwise being worked with.
codec(&self) -> &C222     pub fn codec(&self) -> &C {
223         &self.codec
224     }
225 
226     /// Returns a mutable reference to the underlying codec wrapped by
227     /// `UdpFramed`.
228     ///
229     /// Note that care should be taken to not tamper with the underlying codec
230     /// as it may corrupt the stream of frames otherwise being worked with.
codec_mut(&mut self) -> &mut C231     pub fn codec_mut(&mut self) -> &mut C {
232         &mut self.codec
233     }
234 
235     /// Returns a reference to the read buffer.
read_buffer(&self) -> &BytesMut236     pub fn read_buffer(&self) -> &BytesMut {
237         &self.rd
238     }
239 
240     /// Returns a mutable reference to the read buffer.
read_buffer_mut(&mut self) -> &mut BytesMut241     pub fn read_buffer_mut(&mut self) -> &mut BytesMut {
242         &mut self.rd
243     }
244 
245     /// Consumes the `Framed`, returning its underlying I/O stream.
into_inner(self) -> T246     pub fn into_inner(self) -> T {
247         self.socket
248     }
249 }
250