1 // Copyright 2022 The ChromiumOS Authors 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 use std::io; 6 use std::io::Read; 7 use std::os::unix::io::AsRawFd; 8 use std::os::unix::io::RawFd; 9 use std::os::unix::net::UnixStream; 10 use std::time::Duration; 11 12 use libc::c_void; 13 use serde::Deserialize; 14 use serde::Serialize; 15 16 use super::super::net::UnixSeqpacket; 17 use crate::descriptor::AsRawDescriptor; 18 use crate::IntoRawDescriptor; 19 use crate::RawDescriptor; 20 use crate::ReadNotifier; 21 use crate::Result; 22 23 #[derive(Copy, Clone)] 24 pub enum FramingMode { 25 Message, 26 Byte, 27 } 28 29 #[derive(Copy, Clone, PartialEq, Eq)] 30 pub enum BlockingMode { 31 Blocking, 32 Nonblocking, 33 } 34 35 impl io::Read for StreamChannel { read(&mut self, buf: &mut [u8]) -> io::Result<usize>36 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { 37 self.inner_read(buf) 38 } 39 } 40 41 impl io::Read for &StreamChannel { read(&mut self, buf: &mut [u8]) -> io::Result<usize>42 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { 43 self.inner_read(buf) 44 } 45 } 46 47 impl AsRawDescriptor for StreamChannel { as_raw_descriptor(&self) -> RawDescriptor48 fn as_raw_descriptor(&self) -> RawDescriptor { 49 (&self).as_raw_descriptor() 50 } 51 } 52 53 #[derive(Debug, Deserialize, Serialize)] 54 enum SocketType { 55 Message(UnixSeqpacket), 56 #[serde(with = "crate::with_as_descriptor")] 57 Byte(UnixStream), 58 } 59 60 /// An abstraction over named pipes and unix socketpairs. This abstraction can be used in a blocking 61 /// and non blocking mode. 62 /// 63 /// WARNING: partial reads of messages behave differently depending on the platform. 64 /// See sys::unix::StreamChannel::inner_read for details. 65 #[derive(Debug, Deserialize, Serialize)] 66 pub struct StreamChannel { 67 stream: SocketType, 68 } 69 70 impl StreamChannel { set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()>71 pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> { 72 match &mut self.stream { 73 SocketType::Byte(sock) => sock.set_nonblocking(nonblocking), 74 SocketType::Message(sock) => sock.set_nonblocking(nonblocking), 75 } 76 } 77 get_framing_mode(&self) -> FramingMode78 pub fn get_framing_mode(&self) -> FramingMode { 79 match &self.stream { 80 SocketType::Message(_) => FramingMode::Message, 81 SocketType::Byte(_) => FramingMode::Byte, 82 } 83 } 84 inner_read(&self, buf: &mut [u8]) -> io::Result<usize>85 pub(super) fn inner_read(&self, buf: &mut [u8]) -> io::Result<usize> { 86 match &self.stream { 87 SocketType::Byte(sock) => (&mut &*sock).read(buf), 88 89 // On Windows, reading from SOCK_SEQPACKET with a buffer that is too small is an error, 90 // and the extra data will be preserved inside the named pipe. 91 // 92 // Linux though, will silently truncate unless MSG_TRUNC is passed. So we pass it, but 93 // even in that case, Linux will still throw away the extra data. This means there is a 94 // slight behavior difference between platforms from the consumer's perspective. 95 // In practice on Linux, intentional partial reads of messages are usually accomplished 96 // by also passing MSG_PEEK. While we could do this, and hide this rough edge from 97 // consumers, it would add complexity & turn every read into two read syscalls. 98 // 99 // So the compromise is this: 100 // * On Linux: a partial read of a message is an Err and loses data. 101 // * On Windows: a partial read of a message is Ok and does not lose data. 102 SocketType::Message(sock) => { 103 // SAFETY: 104 // Safe because buf is valid, we pass buf's size to recv to bound the return 105 // length, and we check the return code. 106 let retval = unsafe { 107 // TODO(nkgold|b/152067913): Move this into the UnixSeqpacket struct as a 108 // recv_with_flags method once that struct's tests are working. 109 libc::recv( 110 sock.as_raw_descriptor(), 111 buf.as_mut_ptr() as *mut c_void, 112 buf.len(), 113 libc::MSG_TRUNC, 114 ) 115 }; 116 let receive_len = if retval < 0 { 117 Err(std::io::Error::last_os_error()) 118 } else { 119 Ok(retval) 120 }? as usize; 121 122 if receive_len > buf.len() { 123 Err(std::io::Error::new( 124 std::io::ErrorKind::Other, 125 format!( 126 "packet size {:?} encountered, but buffer was only of size {:?}", 127 receive_len, 128 buf.len() 129 ), 130 )) 131 } else { 132 Ok(receive_len) 133 } 134 } 135 } 136 } 137 138 /// Creates a cross platform stream pair. pair( blocking_mode: BlockingMode, framing_mode: FramingMode, ) -> Result<(StreamChannel, StreamChannel)>139 pub fn pair( 140 blocking_mode: BlockingMode, 141 framing_mode: FramingMode, 142 ) -> Result<(StreamChannel, StreamChannel)> { 143 let (pipe_a, pipe_b) = match framing_mode { 144 FramingMode::Byte => { 145 let (pipe_a, pipe_b) = UnixStream::pair()?; 146 (SocketType::Byte(pipe_a), SocketType::Byte(pipe_b)) 147 } 148 FramingMode::Message => { 149 let (pipe_a, pipe_b) = UnixSeqpacket::pair()?; 150 (SocketType::Message(pipe_a), SocketType::Message(pipe_b)) 151 } 152 }; 153 let mut stream_a = StreamChannel { stream: pipe_a }; 154 let mut stream_b = StreamChannel { stream: pipe_b }; 155 let is_non_blocking = blocking_mode == BlockingMode::Nonblocking; 156 stream_a.set_nonblocking(is_non_blocking)?; 157 stream_b.set_nonblocking(is_non_blocking)?; 158 Ok((stream_a, stream_b)) 159 } 160 set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()>161 pub fn set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()> { 162 match &self.stream { 163 SocketType::Byte(sock) => sock.set_read_timeout(timeout), 164 SocketType::Message(sock) => sock.set_read_timeout(timeout), 165 } 166 } 167 set_write_timeout(&self, timeout: Option<Duration>) -> io::Result<()>168 pub fn set_write_timeout(&self, timeout: Option<Duration>) -> io::Result<()> { 169 match &self.stream { 170 SocketType::Byte(sock) => sock.set_write_timeout(timeout), 171 SocketType::Message(sock) => sock.set_write_timeout(timeout), 172 } 173 } 174 175 // WARNING: Generally, multiple StreamChannel ends are not wanted. StreamChannel behavior with 176 // > 1 reader per end is not defined. try_clone(&self) -> io::Result<Self>177 pub fn try_clone(&self) -> io::Result<Self> { 178 Ok(StreamChannel { 179 stream: match &self.stream { 180 SocketType::Byte(sock) => SocketType::Byte(sock.try_clone()?), 181 SocketType::Message(sock) => SocketType::Message(sock.try_clone()?), 182 }, 183 }) 184 } 185 } 186 187 impl io::Write for StreamChannel { write(&mut self, buf: &[u8]) -> io::Result<usize>188 fn write(&mut self, buf: &[u8]) -> io::Result<usize> { 189 match &mut self.stream { 190 SocketType::Byte(sock) => sock.write(buf), 191 SocketType::Message(sock) => sock.send(buf), 192 } 193 } flush(&mut self) -> io::Result<()>194 fn flush(&mut self) -> io::Result<()> { 195 match &mut self.stream { 196 SocketType::Byte(sock) => sock.flush(), 197 SocketType::Message(_) => Ok(()), 198 } 199 } 200 } 201 202 impl io::Write for &StreamChannel { write(&mut self, buf: &[u8]) -> io::Result<usize>203 fn write(&mut self, buf: &[u8]) -> io::Result<usize> { 204 match &self.stream { 205 SocketType::Byte(sock) => (&mut &*sock).write(buf), 206 SocketType::Message(sock) => sock.send(buf), 207 } 208 } flush(&mut self) -> io::Result<()>209 fn flush(&mut self) -> io::Result<()> { 210 match &self.stream { 211 SocketType::Byte(sock) => (&mut &*sock).flush(), 212 SocketType::Message(_) => Ok(()), 213 } 214 } 215 } 216 217 impl AsRawFd for StreamChannel { as_raw_fd(&self) -> RawFd218 fn as_raw_fd(&self) -> RawFd { 219 match &self.stream { 220 SocketType::Byte(sock) => sock.as_raw_descriptor(), 221 SocketType::Message(sock) => sock.as_raw_descriptor(), 222 } 223 } 224 } 225 226 impl AsRawFd for &StreamChannel { as_raw_fd(&self) -> RawFd227 fn as_raw_fd(&self) -> RawFd { 228 self.as_raw_descriptor() 229 } 230 } 231 232 impl AsRawDescriptor for &StreamChannel { as_raw_descriptor(&self) -> RawDescriptor233 fn as_raw_descriptor(&self) -> RawDescriptor { 234 match &self.stream { 235 SocketType::Byte(sock) => sock.as_raw_descriptor(), 236 SocketType::Message(sock) => sock.as_raw_descriptor(), 237 } 238 } 239 } 240 241 impl IntoRawDescriptor for StreamChannel { into_raw_descriptor(self) -> RawFd242 fn into_raw_descriptor(self) -> RawFd { 243 match self.stream { 244 SocketType::Byte(sock) => sock.into_raw_descriptor(), 245 SocketType::Message(sock) => sock.into_raw_descriptor(), 246 } 247 } 248 } 249 250 impl ReadNotifier for StreamChannel { 251 /// Returns a RawDescriptor that can be polled for reads using PollContext. get_read_notifier(&self) -> &dyn AsRawDescriptor252 fn get_read_notifier(&self) -> &dyn AsRawDescriptor { 253 self 254 } 255 } 256 257 #[cfg(test)] 258 mod test { 259 use std::io::Read; 260 use std::io::Write; 261 262 use super::*; 263 use crate::EventContext; 264 use crate::EventToken; 265 use crate::ReadNotifier; 266 267 #[derive(EventToken, Debug, Eq, PartialEq, Copy, Clone)] 268 enum Token { 269 ReceivedData, 270 } 271 272 #[test] test_non_blocking_pair_byte()273 fn test_non_blocking_pair_byte() { 274 let (mut sender, mut receiver) = 275 StreamChannel::pair(BlockingMode::Nonblocking, FramingMode::Byte).unwrap(); 276 277 sender.write_all(&[75, 77, 54, 82, 76, 65]).unwrap(); 278 279 // Wait for the data to arrive. 280 let event_ctx: EventContext<Token> = 281 EventContext::build_with(&[(receiver.get_read_notifier(), Token::ReceivedData)]) 282 .unwrap(); 283 let events = event_ctx.wait().unwrap(); 284 let tokens: Vec<Token> = events 285 .iter() 286 .filter(|e| e.is_readable) 287 .map(|e| e.token) 288 .collect(); 289 assert_eq!(tokens, vec! {Token::ReceivedData}); 290 291 // Smaller than what we sent so we get multiple chunks 292 let mut recv_buffer: [u8; 4] = [0; 4]; 293 294 let mut size = receiver.read(&mut recv_buffer).unwrap(); 295 assert_eq!(size, 4); 296 assert_eq!(recv_buffer, [75, 77, 54, 82]); 297 298 size = receiver.read(&mut recv_buffer).unwrap(); 299 assert_eq!(size, 2); 300 assert_eq!(recv_buffer[0..2], [76, 65]); 301 302 // Now that we've polled for & received all data, polling again should show no events. 303 assert_eq!( 304 event_ctx 305 .wait_timeout(std::time::Duration::new(0, 0)) 306 .unwrap() 307 .len(), 308 0 309 ); 310 } 311 312 #[test] test_non_blocking_pair_message()313 fn test_non_blocking_pair_message() { 314 let (mut sender, mut receiver) = 315 StreamChannel::pair(BlockingMode::Nonblocking, FramingMode::Message).unwrap(); 316 317 sender.write_all(&[75, 77, 54, 82, 76, 65]).unwrap(); 318 319 // Wait for the data to arrive. 320 let event_ctx: EventContext<Token> = 321 EventContext::build_with(&[(receiver.get_read_notifier(), Token::ReceivedData)]) 322 .unwrap(); 323 let events = event_ctx.wait().unwrap(); 324 let tokens: Vec<Token> = events 325 .iter() 326 .filter(|e| e.is_readable) 327 .map(|e| e.token) 328 .collect(); 329 assert_eq!(tokens, vec! {Token::ReceivedData}); 330 331 // Unlike Byte format, Message mode panics if the buffer is smaller than the packet size; 332 // make the buffer the right size. 333 let mut recv_buffer: [u8; 6] = [0; 6]; 334 335 let size = receiver.read(&mut recv_buffer).unwrap(); 336 assert_eq!(size, 6); 337 assert_eq!(recv_buffer, [75, 77, 54, 82, 76, 65]); 338 339 // Now that we've polled for & received all data, polling again should show no events. 340 assert_eq!( 341 event_ctx 342 .wait_timeout(std::time::Duration::new(0, 0)) 343 .unwrap() 344 .len(), 345 0 346 ); 347 } 348 349 #[test] test_non_blocking_pair_error_no_data()350 fn test_non_blocking_pair_error_no_data() { 351 let (mut sender, mut receiver) = 352 StreamChannel::pair(BlockingMode::Nonblocking, FramingMode::Byte).unwrap(); 353 receiver 354 .set_nonblocking(true) 355 .expect("Failed to set receiver to nonblocking mode."); 356 357 sender.write_all(&[75, 77]).unwrap(); 358 359 // Wait for the data to arrive. 360 let event_ctx: EventContext<Token> = 361 EventContext::build_with(&[(receiver.get_read_notifier(), Token::ReceivedData)]) 362 .unwrap(); 363 let events = event_ctx.wait().unwrap(); 364 let tokens: Vec<Token> = events 365 .iter() 366 .filter(|e| e.is_readable) 367 .map(|e| e.token) 368 .collect(); 369 assert_eq!(tokens, vec! {Token::ReceivedData}); 370 371 // We only read 2 bytes, even though we requested 4 bytes. 372 let mut recv_buffer: [u8; 4] = [0; 4]; 373 let size = receiver.read(&mut recv_buffer).unwrap(); 374 assert_eq!(size, 2); 375 assert_eq!(recv_buffer, [75, 77, 00, 00]); 376 377 // Further reads should encounter an error since there is no available data and this is a 378 // non blocking pipe. 379 assert!(receiver.read(&mut recv_buffer).is_err()); 380 } 381 } 382