• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2022 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 use std::io;
6 use std::io::Read;
7 use std::os::unix::io::AsRawFd;
8 use std::os::unix::io::RawFd;
9 use std::os::unix::net::UnixStream;
10 use std::time::Duration;
11 
12 use libc::c_void;
13 use serde::Deserialize;
14 use serde::Serialize;
15 
16 use super::super::net::UnixSeqpacket;
17 use crate::descriptor::AsRawDescriptor;
18 use crate::IntoRawDescriptor;
19 use crate::RawDescriptor;
20 use crate::ReadNotifier;
21 use crate::Result;
22 
/// Selects how data written to a [`StreamChannel`] is framed.
///
/// In this implementation `Message` channels are backed by a `UnixSeqpacket` socket pair and
/// `Byte` channels by a `UnixStream` socket pair (see `StreamChannel::pair`).
#[derive(Clone, Copy)]
pub enum FramingMode {
    /// Each write is delivered as one discrete message.
    Message,
    /// Data is delivered as an unframed byte stream.
    Byte,
}
28 
/// Selects whether the sockets created by `StreamChannel::pair` are put into non-blocking mode.
#[derive(Clone, Copy, Eq, PartialEq)]
pub enum BlockingMode {
    /// Leave both ends in blocking mode.
    Blocking,
    /// Put both ends into non-blocking mode.
    Nonblocking,
}
34 
35 impl io::Read for StreamChannel {
read(&mut self, buf: &mut [u8]) -> io::Result<usize>36     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
37         self.inner_read(buf)
38     }
39 }
40 
41 impl io::Read for &StreamChannel {
read(&mut self, buf: &mut [u8]) -> io::Result<usize>42     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
43         self.inner_read(buf)
44     }
45 }
46 
47 impl AsRawDescriptor for StreamChannel {
as_raw_descriptor(&self) -> RawDescriptor48     fn as_raw_descriptor(&self) -> RawDescriptor {
49         (&self).as_raw_descriptor()
50     }
51 }
52 
53 #[derive(Debug, Deserialize, Serialize)]
54 enum SocketType {
55     Message(UnixSeqpacket),
56     #[serde(with = "crate::with_as_descriptor")]
57     Byte(UnixStream),
58 }
59 
60 /// An abstraction over named pipes and unix socketpairs. This abstraction can be used in a blocking
61 /// and non blocking mode.
62 ///
63 /// WARNING: partial reads of messages behave differently depending on the platform.
64 /// See sys::unix::StreamChannel::inner_read for details.
65 #[derive(Debug, Deserialize, Serialize)]
66 pub struct StreamChannel {
67     stream: SocketType,
68 }
69 
70 impl StreamChannel {
set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()>71     pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
72         match &mut self.stream {
73             SocketType::Byte(sock) => sock.set_nonblocking(nonblocking),
74             SocketType::Message(sock) => sock.set_nonblocking(nonblocking),
75         }
76     }
77 
get_framing_mode(&self) -> FramingMode78     pub fn get_framing_mode(&self) -> FramingMode {
79         match &self.stream {
80             SocketType::Message(_) => FramingMode::Message,
81             SocketType::Byte(_) => FramingMode::Byte,
82         }
83     }
84 
inner_read(&self, buf: &mut [u8]) -> io::Result<usize>85     pub(super) fn inner_read(&self, buf: &mut [u8]) -> io::Result<usize> {
86         match &self.stream {
87             SocketType::Byte(sock) => (&mut &*sock).read(buf),
88 
89             // On Windows, reading from SOCK_SEQPACKET with a buffer that is too small is an error,
90             // and the extra data will be preserved inside the named pipe.
91             //
92             // Linux though, will silently truncate unless MSG_TRUNC is passed. So we pass it, but
93             // even in that case, Linux will still throw away the extra data. This means there is a
94             // slight behavior difference between platforms from the consumer's perspective.
95             // In practice on Linux, intentional partial reads of messages are usually accomplished
96             // by also passing MSG_PEEK. While we could do this, and hide this rough edge from
97             // consumers, it would add complexity & turn every read into two read syscalls.
98             //
99             // So the compromise is this:
100             // * On Linux: a partial read of a message is an Err and loses data.
101             // * On Windows: a partial read of a message is Ok and does not lose data.
102             SocketType::Message(sock) => {
103                 // SAFETY:
104                 // Safe because buf is valid, we pass buf's size to recv to bound the return
105                 // length, and we check the return code.
106                 let retval = unsafe {
107                     // TODO(nkgold|b/152067913): Move this into the UnixSeqpacket struct as a
108                     // recv_with_flags method once that struct's tests are working.
109                     libc::recv(
110                         sock.as_raw_descriptor(),
111                         buf.as_mut_ptr() as *mut c_void,
112                         buf.len(),
113                         libc::MSG_TRUNC,
114                     )
115                 };
116                 let receive_len = if retval < 0 {
117                     Err(std::io::Error::last_os_error())
118                 } else {
119                     Ok(retval)
120                 }? as usize;
121 
122                 if receive_len > buf.len() {
123                     Err(std::io::Error::new(
124                         std::io::ErrorKind::Other,
125                         format!(
126                             "packet size {:?} encountered, but buffer was only of size {:?}",
127                             receive_len,
128                             buf.len()
129                         ),
130                     ))
131                 } else {
132                     Ok(receive_len)
133                 }
134             }
135         }
136     }
137 
138     /// Creates a cross platform stream pair.
pair( blocking_mode: BlockingMode, framing_mode: FramingMode, ) -> Result<(StreamChannel, StreamChannel)>139     pub fn pair(
140         blocking_mode: BlockingMode,
141         framing_mode: FramingMode,
142     ) -> Result<(StreamChannel, StreamChannel)> {
143         let (pipe_a, pipe_b) = match framing_mode {
144             FramingMode::Byte => {
145                 let (pipe_a, pipe_b) = UnixStream::pair()?;
146                 (SocketType::Byte(pipe_a), SocketType::Byte(pipe_b))
147             }
148             FramingMode::Message => {
149                 let (pipe_a, pipe_b) = UnixSeqpacket::pair()?;
150                 (SocketType::Message(pipe_a), SocketType::Message(pipe_b))
151             }
152         };
153         let mut stream_a = StreamChannel { stream: pipe_a };
154         let mut stream_b = StreamChannel { stream: pipe_b };
155         let is_non_blocking = blocking_mode == BlockingMode::Nonblocking;
156         stream_a.set_nonblocking(is_non_blocking)?;
157         stream_b.set_nonblocking(is_non_blocking)?;
158         Ok((stream_a, stream_b))
159     }
160 
set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()>161     pub fn set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
162         match &self.stream {
163             SocketType::Byte(sock) => sock.set_read_timeout(timeout),
164             SocketType::Message(sock) => sock.set_read_timeout(timeout),
165         }
166     }
167 
set_write_timeout(&self, timeout: Option<Duration>) -> io::Result<()>168     pub fn set_write_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
169         match &self.stream {
170             SocketType::Byte(sock) => sock.set_write_timeout(timeout),
171             SocketType::Message(sock) => sock.set_write_timeout(timeout),
172         }
173     }
174 
175     // WARNING: Generally, multiple StreamChannel ends are not wanted. StreamChannel behavior with
176     // > 1 reader per end is not defined.
try_clone(&self) -> io::Result<Self>177     pub fn try_clone(&self) -> io::Result<Self> {
178         Ok(StreamChannel {
179             stream: match &self.stream {
180                 SocketType::Byte(sock) => SocketType::Byte(sock.try_clone()?),
181                 SocketType::Message(sock) => SocketType::Message(sock.try_clone()?),
182             },
183         })
184     }
185 }
186 
187 impl io::Write for StreamChannel {
write(&mut self, buf: &[u8]) -> io::Result<usize>188     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
189         match &mut self.stream {
190             SocketType::Byte(sock) => sock.write(buf),
191             SocketType::Message(sock) => sock.send(buf),
192         }
193     }
flush(&mut self) -> io::Result<()>194     fn flush(&mut self) -> io::Result<()> {
195         match &mut self.stream {
196             SocketType::Byte(sock) => sock.flush(),
197             SocketType::Message(_) => Ok(()),
198         }
199     }
200 }
201 
202 impl io::Write for &StreamChannel {
write(&mut self, buf: &[u8]) -> io::Result<usize>203     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
204         match &self.stream {
205             SocketType::Byte(sock) => (&mut &*sock).write(buf),
206             SocketType::Message(sock) => sock.send(buf),
207         }
208     }
flush(&mut self) -> io::Result<()>209     fn flush(&mut self) -> io::Result<()> {
210         match &self.stream {
211             SocketType::Byte(sock) => (&mut &*sock).flush(),
212             SocketType::Message(_) => Ok(()),
213         }
214     }
215 }
216 
217 impl AsRawFd for StreamChannel {
as_raw_fd(&self) -> RawFd218     fn as_raw_fd(&self) -> RawFd {
219         match &self.stream {
220             SocketType::Byte(sock) => sock.as_raw_descriptor(),
221             SocketType::Message(sock) => sock.as_raw_descriptor(),
222         }
223     }
224 }
225 
226 impl AsRawFd for &StreamChannel {
as_raw_fd(&self) -> RawFd227     fn as_raw_fd(&self) -> RawFd {
228         self.as_raw_descriptor()
229     }
230 }
231 
232 impl AsRawDescriptor for &StreamChannel {
as_raw_descriptor(&self) -> RawDescriptor233     fn as_raw_descriptor(&self) -> RawDescriptor {
234         match &self.stream {
235             SocketType::Byte(sock) => sock.as_raw_descriptor(),
236             SocketType::Message(sock) => sock.as_raw_descriptor(),
237         }
238     }
239 }
240 
241 impl IntoRawDescriptor for StreamChannel {
into_raw_descriptor(self) -> RawFd242     fn into_raw_descriptor(self) -> RawFd {
243         match self.stream {
244             SocketType::Byte(sock) => sock.into_raw_descriptor(),
245             SocketType::Message(sock) => sock.into_raw_descriptor(),
246         }
247     }
248 }
249 
250 impl ReadNotifier for StreamChannel {
251     /// Returns a RawDescriptor that can be polled for reads using PollContext.
get_read_notifier(&self) -> &dyn AsRawDescriptor252     fn get_read_notifier(&self) -> &dyn AsRawDescriptor {
253         self
254     }
255 }
256 
#[cfg(test)]
mod test {
    use std::io::Read;
    use std::io::Write;

    use super::*;
    use crate::EventContext;
    use crate::EventToken;
    use crate::ReadNotifier;

    #[derive(EventToken, Debug, Eq, PartialEq, Copy, Clone)]
    enum Token {
        ReceivedData,
    }

    /// Registers `receiver`'s read notifier with a new event context, waits for it to fire, and
    /// asserts that exactly one readable `Token::ReceivedData` event was reported.
    fn wait_until_readable(receiver: &StreamChannel) -> EventContext<Token> {
        let event_ctx: EventContext<Token> =
            EventContext::build_with(&[(receiver.get_read_notifier(), Token::ReceivedData)])
                .unwrap();
        let events = event_ctx.wait().unwrap();
        let readable: Vec<Token> = events
            .iter()
            .filter(|event| event.is_readable)
            .map(|event| event.token)
            .collect();
        assert_eq!(readable, vec![Token::ReceivedData]);
        event_ctx
    }

    /// Asserts that an immediate (zero timeout) poll of `event_ctx` reports no events.
    fn assert_no_pending_events(event_ctx: &EventContext<Token>) {
        let pending = event_ctx
            .wait_timeout(std::time::Duration::new(0, 0))
            .unwrap()
            .len();
        assert_eq!(pending, 0);
    }

    #[test]
    fn test_non_blocking_pair_byte() {
        let (mut sender, mut receiver) =
            StreamChannel::pair(BlockingMode::Nonblocking, FramingMode::Byte).unwrap();

        sender.write_all(&[75, 77, 54, 82, 76, 65]).unwrap();

        // Wait for the data to arrive.
        let event_ctx = wait_until_readable(&receiver);

        // The buffer is deliberately smaller than the payload so it arrives in two chunks.
        let mut recv_buffer = [0u8; 4];

        let first_len = receiver.read(&mut recv_buffer).unwrap();
        assert_eq!(first_len, 4);
        assert_eq!(recv_buffer, [75, 77, 54, 82]);

        let second_len = receiver.read(&mut recv_buffer).unwrap();
        assert_eq!(second_len, 2);
        assert_eq!(recv_buffer[0..2], [76, 65]);

        // Everything has been drained, so polling again should show no events.
        assert_no_pending_events(&event_ctx);
    }

    #[test]
    fn test_non_blocking_pair_message() {
        let (mut sender, mut receiver) =
            StreamChannel::pair(BlockingMode::Nonblocking, FramingMode::Message).unwrap();

        sender.write_all(&[75, 77, 54, 82, 76, 65]).unwrap();

        // Wait for the data to arrive.
        let event_ctx = wait_until_readable(&receiver);

        // Unlike Byte format, a Message mode read returns an error if the buffer is smaller than
        // the packet size (which would make the `unwrap` below panic); make the buffer the right
        // size.
        let mut recv_buffer = [0u8; 6];

        let size = receiver.read(&mut recv_buffer).unwrap();
        assert_eq!(size, 6);
        assert_eq!(recv_buffer, [75, 77, 54, 82, 76, 65]);

        // Everything has been drained, so polling again should show no events.
        assert_no_pending_events(&event_ctx);
    }

    #[test]
    fn test_non_blocking_pair_error_no_data() {
        let (mut sender, mut receiver) =
            StreamChannel::pair(BlockingMode::Nonblocking, FramingMode::Byte).unwrap();
        receiver
            .set_nonblocking(true)
            .expect("Failed to set receiver to nonblocking mode.");

        sender.write_all(&[75, 77]).unwrap();

        // Wait for the data to arrive.
        let _event_ctx = wait_until_readable(&receiver);

        // Only 2 bytes are available even though the buffer has room for 4.
        let mut recv_buffer = [0u8; 4];
        let size = receiver.read(&mut recv_buffer).unwrap();
        assert_eq!(size, 2);
        assert_eq!(recv_buffer, [75, 77, 0, 0]);

        // With nothing left to read, a non-blocking read must report an error instead of
        // blocking.
        assert!(receiver.read(&mut recv_buffer).is_err());
    }
}
382