• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2022 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 use std::io;
6 use std::io::Read;
7 use std::os::unix::io::AsRawFd;
8 use std::os::unix::io::RawFd;
9 use std::os::unix::net::UnixStream;
10 use std::time::Duration;
11 
12 use libc::c_void;
13 use serde::Deserialize;
14 use serde::Serialize;
15 
16 use super::super::net::UnixSeqpacket;
17 use super::super::Result;
18 use super::RawDescriptor;
19 use crate::descriptor::AsRawDescriptor;
20 use crate::ReadNotifier;
21 
/// Selects how data sent through a [`StreamChannel`] is delimited.
///
/// The mode chosen in [`StreamChannel::pair`] decides which kind of socket backs the channel.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum FramingMode {
    /// Message-oriented channel, backed by a `UnixSeqpacket` socket pair.
    Message,
    /// Byte-oriented channel, backed by a `UnixStream` socket pair.
    Byte,
}
27 
/// Selects whether both ends created by [`StreamChannel::pair`] start out in blocking or
/// non-blocking mode.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum BlockingMode {
    /// The sockets are left in (or switched to) blocking mode.
    Blocking,
    /// The sockets are switched to non-blocking mode.
    Nonblocking,
}
33 
34 impl io::Read for StreamChannel {
read(&mut self, buf: &mut [u8]) -> io::Result<usize>35     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
36         self.inner_read(buf)
37     }
38 }
39 
40 impl io::Read for &StreamChannel {
read(&mut self, buf: &mut [u8]) -> io::Result<usize>41     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
42         self.inner_read(buf)
43     }
44 }
45 
46 impl AsRawDescriptor for StreamChannel {
as_raw_descriptor(&self) -> RawDescriptor47     fn as_raw_descriptor(&self) -> RawDescriptor {
48         (&self).as_raw_descriptor()
49     }
50 }
51 
52 #[derive(Debug, Deserialize, Serialize)]
53 enum SocketType {
54     Message(UnixSeqpacket),
55     #[serde(with = "super::with_as_descriptor")]
56     Byte(UnixStream),
57 }
58 
59 /// An abstraction over named pipes and unix socketpairs. This abstraction can be used in a blocking
60 /// and non blocking mode.
61 #[derive(Debug, Deserialize, Serialize)]
62 pub struct StreamChannel {
63     stream: SocketType,
64 }
65 
66 impl StreamChannel {
set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()>67     pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
68         match &mut self.stream {
69             SocketType::Byte(sock) => sock.set_nonblocking(nonblocking),
70             SocketType::Message(sock) => sock.set_nonblocking(nonblocking),
71         }
72     }
73 
get_framing_mode(&self) -> FramingMode74     pub fn get_framing_mode(&self) -> FramingMode {
75         match &self.stream {
76             SocketType::Message(_) => FramingMode::Message,
77             SocketType::Byte(_) => FramingMode::Byte,
78         }
79     }
80 
inner_read(&self, buf: &mut [u8]) -> io::Result<usize>81     pub(super) fn inner_read(&self, buf: &mut [u8]) -> io::Result<usize> {
82         match &self.stream {
83             SocketType::Byte(sock) => (&mut &*sock).read(buf),
84 
85             // On Windows, reading from SOCK_SEQPACKET with a buffer that is too small is an error,
86             // but on Linux will silently truncate unless MSG_TRUNC is passed. Here, we emulate
87             // Windows behavior on POSIX.
88             //
89             // Note that Rust translates ERROR_MORE_DATA into io::ErrorKind::Other
90             // (see sys::decode_error_kind) on Windows, so we preserve this behavior on POSIX even
91             // though one could argue ErrorKind::UnexpectedEof is a closer match to the true error.
92             SocketType::Message(sock) => {
93                 // Safe because buf is valid, we pass buf's size to recv to bound the return
94                 // length, and we check the return code.
95                 let retval = unsafe {
96                     // TODO(nkgold|b/152067913): Move this into the UnixSeqpacket struct as a
97                     // recv_with_flags method once that struct's tests are working.
98                     libc::recv(
99                         sock.as_raw_descriptor(),
100                         buf.as_mut_ptr() as *mut c_void,
101                         buf.len(),
102                         libc::MSG_TRUNC,
103                     )
104                 };
105                 let receive_len = if retval < 0 {
106                     Err(std::io::Error::last_os_error())
107                 } else {
108                     Ok(retval)
109                 }? as usize;
110 
111                 if receive_len > buf.len() {
112                     Err(std::io::Error::new(
113                         std::io::ErrorKind::Other,
114                         format!(
115                             "packet size {:?} encountered, but buffer was only of size {:?}",
116                             receive_len,
117                             buf.len()
118                         ),
119                     ))
120                 } else {
121                     Ok(receive_len)
122                 }
123             }
124         }
125     }
126 
127     /// Creates a cross platform stream pair.
pair( blocking_mode: BlockingMode, framing_mode: FramingMode, ) -> Result<(StreamChannel, StreamChannel)>128     pub fn pair(
129         blocking_mode: BlockingMode,
130         framing_mode: FramingMode,
131     ) -> Result<(StreamChannel, StreamChannel)> {
132         let (pipe_a, pipe_b) = match framing_mode {
133             FramingMode::Byte => {
134                 let (pipe_a, pipe_b) = UnixStream::pair()?;
135                 (SocketType::Byte(pipe_a), SocketType::Byte(pipe_b))
136             }
137             FramingMode::Message => {
138                 let (pipe_a, pipe_b) = UnixSeqpacket::pair()?;
139                 (SocketType::Message(pipe_a), SocketType::Message(pipe_b))
140             }
141         };
142         let mut stream_a = StreamChannel { stream: pipe_a };
143         let mut stream_b = StreamChannel { stream: pipe_b };
144         let is_non_blocking = blocking_mode == BlockingMode::Nonblocking;
145         stream_a.set_nonblocking(is_non_blocking)?;
146         stream_b.set_nonblocking(is_non_blocking)?;
147         Ok((stream_a, stream_b))
148     }
149 
from_unix_seqpacket(sock: UnixSeqpacket) -> StreamChannel150     pub fn from_unix_seqpacket(sock: UnixSeqpacket) -> StreamChannel {
151         StreamChannel {
152             stream: SocketType::Message(sock),
153         }
154     }
155 
peek_size(&self) -> io::Result<usize>156     pub fn peek_size(&self) -> io::Result<usize> {
157         match &self.stream {
158             SocketType::Byte(_) => Err(std::io::Error::new(
159                 std::io::ErrorKind::Other,
160                 "Cannot check the size of streamed data",
161             )),
162             SocketType::Message(sock) => Ok(sock.next_packet_size()?),
163         }
164     }
165 
set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()>166     pub fn set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
167         match &self.stream {
168             SocketType::Byte(sock) => sock.set_read_timeout(timeout),
169             SocketType::Message(sock) => sock.set_read_timeout(timeout),
170         }
171     }
172 
set_write_timeout(&self, timeout: Option<Duration>) -> io::Result<()>173     pub fn set_write_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
174         match &self.stream {
175             SocketType::Byte(sock) => sock.set_write_timeout(timeout),
176             SocketType::Message(sock) => sock.set_write_timeout(timeout),
177         }
178     }
179 
180     // WARNING: Generally, multiple StreamChannel ends are not wanted. StreamChannel behavior with
181     // > 1 reader per end is not defined.
try_clone(&self) -> io::Result<Self>182     pub fn try_clone(&self) -> io::Result<Self> {
183         Ok(StreamChannel {
184             stream: match &self.stream {
185                 SocketType::Byte(sock) => SocketType::Byte(sock.try_clone()?),
186                 SocketType::Message(sock) => SocketType::Message(sock.try_clone()?),
187             },
188         })
189     }
190 }
191 
192 impl io::Write for StreamChannel {
write(&mut self, buf: &[u8]) -> io::Result<usize>193     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
194         match &mut self.stream {
195             SocketType::Byte(sock) => sock.write(buf),
196             SocketType::Message(sock) => sock.send(buf),
197         }
198     }
flush(&mut self) -> io::Result<()>199     fn flush(&mut self) -> io::Result<()> {
200         match &mut self.stream {
201             SocketType::Byte(sock) => sock.flush(),
202             SocketType::Message(_) => Ok(()),
203         }
204     }
205 }
206 
207 impl io::Write for &StreamChannel {
write(&mut self, buf: &[u8]) -> io::Result<usize>208     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
209         match &self.stream {
210             SocketType::Byte(sock) => (&mut &*sock).write(buf),
211             SocketType::Message(sock) => sock.send(buf),
212         }
213     }
flush(&mut self) -> io::Result<()>214     fn flush(&mut self) -> io::Result<()> {
215         match &self.stream {
216             SocketType::Byte(sock) => (&mut &*sock).flush(),
217             SocketType::Message(_) => Ok(()),
218         }
219     }
220 }
221 
222 impl AsRawFd for StreamChannel {
as_raw_fd(&self) -> RawFd223     fn as_raw_fd(&self) -> RawFd {
224         match &self.stream {
225             SocketType::Byte(sock) => sock.as_raw_descriptor(),
226             SocketType::Message(sock) => sock.as_raw_descriptor(),
227         }
228     }
229 }
230 
231 impl AsRawFd for &StreamChannel {
as_raw_fd(&self) -> RawFd232     fn as_raw_fd(&self) -> RawFd {
233         self.as_raw_descriptor()
234     }
235 }
236 
237 impl AsRawDescriptor for &StreamChannel {
as_raw_descriptor(&self) -> RawDescriptor238     fn as_raw_descriptor(&self) -> RawDescriptor {
239         match &self.stream {
240             SocketType::Byte(sock) => sock.as_raw_descriptor(),
241             SocketType::Message(sock) => sock.as_raw_descriptor(),
242         }
243     }
244 }
245 
246 impl ReadNotifier for StreamChannel {
247     /// Returns a RawDescriptor that can be polled for reads using PollContext.
get_read_notifier(&self) -> &dyn AsRawDescriptor248     fn get_read_notifier(&self) -> &dyn AsRawDescriptor {
249         self
250     }
251 }
252 
#[cfg(test)]
mod test {
    use std::io::Read;
    use std::io::Write;

    use super::*;
    use crate::EventContext;
    use crate::EventToken;
    use crate::ReadNotifier;

    #[derive(EventToken, Debug, Eq, PartialEq, Copy, Clone)]
    enum Token {
        ReceivedData,
    }

    /// The six-byte payload most tests push through the channel.
    const PAYLOAD: [u8; 6] = [75, 77, 54, 82, 76, 65];

    /// Registers `receiver` for read notifications, blocks until an event fires and asserts
    /// that exactly one readable event was reported. The context is returned so the caller
    /// can poll it again later.
    fn wait_until_readable(receiver: &StreamChannel) -> EventContext<Token> {
        let event_ctx: EventContext<Token> =
            EventContext::build_with(&[(receiver.get_read_notifier(), Token::ReceivedData)])
                .unwrap();
        let tokens: Vec<Token> = event_ctx
            .wait()
            .unwrap()
            .iter()
            .filter(|e| e.is_readable)
            .map(|e| e.token)
            .collect();
        assert_eq!(tokens, vec![Token::ReceivedData]);
        event_ctx
    }

    /// Asserts that a zero-timeout poll on `event_ctx` reports no events.
    fn assert_no_pending_events(event_ctx: &EventContext<Token>) {
        let pending = event_ctx
            .wait_timeout(std::time::Duration::new(0, 0))
            .unwrap()
            .len();
        assert_eq!(pending, 0);
    }

    #[test]
    fn test_non_blocking_pair_byte() {
        let (mut sender, mut receiver) =
            StreamChannel::pair(BlockingMode::Nonblocking, FramingMode::Byte).unwrap();

        sender.write_all(&PAYLOAD).unwrap();

        // Wait for the data to arrive.
        let event_ctx = wait_until_readable(&receiver);

        // The buffer is smaller than what was sent, so the data arrives in multiple chunks.
        let mut recv_buffer: [u8; 4] = [0; 4];

        let first_chunk = receiver.read(&mut recv_buffer).unwrap();
        assert_eq!(first_chunk, 4);
        assert_eq!(recv_buffer, [75, 77, 54, 82]);

        let second_chunk = receiver.read(&mut recv_buffer).unwrap();
        assert_eq!(second_chunk, 2);
        assert_eq!(recv_buffer[0..2], [76, 65]);

        // Everything that was polled for has been received, so polling again shows no events.
        assert_no_pending_events(&event_ctx);
    }

    #[test]
    fn test_non_blocking_pair_message() {
        let (mut sender, mut receiver) =
            StreamChannel::pair(BlockingMode::Nonblocking, FramingMode::Message).unwrap();

        sender.write_all(&PAYLOAD).unwrap();

        // Wait for the data to arrive.
        let event_ctx = wait_until_readable(&receiver);

        // Unlike Byte format, a Message mode read returns an error if the buffer is smaller
        // than the packet size, so the buffer is sized to fit the whole packet.
        let mut recv_buffer: [u8; 6] = [0; 6];

        let size = receiver.read(&mut recv_buffer).unwrap();
        assert_eq!(size, 6);
        assert_eq!(recv_buffer, PAYLOAD);

        // Everything that was polled for has been received, so polling again shows no events.
        assert_no_pending_events(&event_ctx);
    }

    #[test]
    fn test_non_blocking_pair_error_no_data() {
        let (mut sender, mut receiver) =
            StreamChannel::pair(BlockingMode::Nonblocking, FramingMode::Byte).unwrap();
        receiver
            .set_nonblocking(true)
            .expect("Failed to set receiver to nonblocking mode.");

        sender.write_all(&[75, 77]).unwrap();

        // Wait for the data to arrive.
        let _event_ctx = wait_until_readable(&receiver);

        // Only 2 bytes are read even though the buffer has room for 4.
        let mut recv_buffer: [u8; 4] = [0; 4];
        let size = receiver.read(&mut recv_buffer).unwrap();
        assert_eq!(size, 2);
        assert_eq!(recv_buffer, [75, 77, 0, 0]);

        // Further reads should encounter an error since there is no available data and this is
        // a non blocking pipe.
        assert!(receiver.read(&mut recv_buffer).is_err());
    }

    #[test]
    fn test_from_unix_seqpacket() {
        let (sock_sender, sock_receiver) = UnixSeqpacket::pair().unwrap();
        let mut sender = StreamChannel::from_unix_seqpacket(sock_sender);
        let mut receiver = StreamChannel::from_unix_seqpacket(sock_receiver);

        sender.write_all(&PAYLOAD).unwrap();

        // Wait for the data to arrive.
        let event_ctx = wait_until_readable(&receiver);

        let mut recv_buffer: [u8; 6] = [0; 6];

        let size = receiver.read(&mut recv_buffer).unwrap();
        assert_eq!(size, 6);
        assert_eq!(recv_buffer, PAYLOAD);

        // Everything that was polled for has been received, so polling again shows no events.
        assert_no_pending_events(&event_ctx);
    }
}
414