1 // Copyright 2022 The ChromiumOS Authors 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 use std::io; 6 use std::io::Read; 7 use std::os::unix::io::AsRawFd; 8 use std::os::unix::io::RawFd; 9 use std::os::unix::net::UnixStream; 10 use std::time::Duration; 11 12 use libc::c_void; 13 use serde::Deserialize; 14 use serde::Serialize; 15 16 use super::super::net::UnixSeqpacket; 17 use super::super::Result; 18 use super::RawDescriptor; 19 use crate::descriptor::AsRawDescriptor; 20 use crate::ReadNotifier; 21 22 #[derive(Copy, Clone)] 23 pub enum FramingMode { 24 Message, 25 Byte, 26 } 27 28 #[derive(Copy, Clone, PartialEq, Eq)] 29 pub enum BlockingMode { 30 Blocking, 31 Nonblocking, 32 } 33 34 impl io::Read for StreamChannel { read(&mut self, buf: &mut [u8]) -> io::Result<usize>35 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { 36 self.inner_read(buf) 37 } 38 } 39 40 impl io::Read for &StreamChannel { read(&mut self, buf: &mut [u8]) -> io::Result<usize>41 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { 42 self.inner_read(buf) 43 } 44 } 45 46 impl AsRawDescriptor for StreamChannel { as_raw_descriptor(&self) -> RawDescriptor47 fn as_raw_descriptor(&self) -> RawDescriptor { 48 (&self).as_raw_descriptor() 49 } 50 } 51 52 #[derive(Debug, Deserialize, Serialize)] 53 enum SocketType { 54 Message(UnixSeqpacket), 55 #[serde(with = "super::with_as_descriptor")] 56 Byte(UnixStream), 57 } 58 59 /// An abstraction over named pipes and unix socketpairs. This abstraction can be used in a blocking 60 /// and non blocking mode. 61 #[derive(Debug, Deserialize, Serialize)] 62 pub struct StreamChannel { 63 stream: SocketType, 64 } 65 66 impl StreamChannel { set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()>67 pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> { 68 match &mut self.stream { 69 SocketType::Byte(sock) => sock.set_nonblocking(nonblocking), 70 SocketType::Message(sock) => sock.set_nonblocking(nonblocking), 71 } 72 } 73 get_framing_mode(&self) -> FramingMode74 pub fn get_framing_mode(&self) -> FramingMode { 75 match &self.stream { 76 SocketType::Message(_) => FramingMode::Message, 77 SocketType::Byte(_) => FramingMode::Byte, 78 } 79 } 80 inner_read(&self, buf: &mut [u8]) -> io::Result<usize>81 pub(super) fn inner_read(&self, buf: &mut [u8]) -> io::Result<usize> { 82 match &self.stream { 83 SocketType::Byte(sock) => (&mut &*sock).read(buf), 84 85 // On Windows, reading from SOCK_SEQPACKET with a buffer that is too small is an error, 86 // but on Linux will silently truncate unless MSG_TRUNC is passed. Here, we emulate 87 // Windows behavior on POSIX. 88 // 89 // Note that Rust translates ERROR_MORE_DATA into io::ErrorKind::Other 90 // (see sys::decode_error_kind) on Windows, so we preserve this behavior on POSIX even 91 // though one could argue ErrorKind::UnexpectedEof is a closer match to the true error. 92 SocketType::Message(sock) => { 93 // Safe because buf is valid, we pass buf's size to recv to bound the return 94 // length, and we check the return code. 95 let retval = unsafe { 96 // TODO(nkgold|b/152067913): Move this into the UnixSeqpacket struct as a 97 // recv_with_flags method once that struct's tests are working. 98 libc::recv( 99 sock.as_raw_descriptor(), 100 buf.as_mut_ptr() as *mut c_void, 101 buf.len(), 102 libc::MSG_TRUNC, 103 ) 104 }; 105 let receive_len = if retval < 0 { 106 Err(std::io::Error::last_os_error()) 107 } else { 108 Ok(retval) 109 }? as usize; 110 111 if receive_len > buf.len() { 112 Err(std::io::Error::new( 113 std::io::ErrorKind::Other, 114 format!( 115 "packet size {:?} encountered, but buffer was only of size {:?}", 116 receive_len, 117 buf.len() 118 ), 119 )) 120 } else { 121 Ok(receive_len) 122 } 123 } 124 } 125 } 126 127 /// Creates a cross platform stream pair. pair( blocking_mode: BlockingMode, framing_mode: FramingMode, ) -> Result<(StreamChannel, StreamChannel)>128 pub fn pair( 129 blocking_mode: BlockingMode, 130 framing_mode: FramingMode, 131 ) -> Result<(StreamChannel, StreamChannel)> { 132 let (pipe_a, pipe_b) = match framing_mode { 133 FramingMode::Byte => { 134 let (pipe_a, pipe_b) = UnixStream::pair()?; 135 (SocketType::Byte(pipe_a), SocketType::Byte(pipe_b)) 136 } 137 FramingMode::Message => { 138 let (pipe_a, pipe_b) = UnixSeqpacket::pair()?; 139 (SocketType::Message(pipe_a), SocketType::Message(pipe_b)) 140 } 141 }; 142 let mut stream_a = StreamChannel { stream: pipe_a }; 143 let mut stream_b = StreamChannel { stream: pipe_b }; 144 let is_non_blocking = blocking_mode == BlockingMode::Nonblocking; 145 stream_a.set_nonblocking(is_non_blocking)?; 146 stream_b.set_nonblocking(is_non_blocking)?; 147 Ok((stream_a, stream_b)) 148 } 149 from_unix_seqpacket(sock: UnixSeqpacket) -> StreamChannel150 pub fn from_unix_seqpacket(sock: UnixSeqpacket) -> StreamChannel { 151 StreamChannel { 152 stream: SocketType::Message(sock), 153 } 154 } 155 peek_size(&self) -> io::Result<usize>156 pub fn peek_size(&self) -> io::Result<usize> { 157 match &self.stream { 158 SocketType::Byte(_) => Err(std::io::Error::new( 159 std::io::ErrorKind::Other, 160 "Cannot check the size of streamed data", 161 )), 162 SocketType::Message(sock) => Ok(sock.next_packet_size()?), 163 } 164 } 165 set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()>166 pub fn set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()> { 167 match &self.stream { 168 SocketType::Byte(sock) => sock.set_read_timeout(timeout), 169 SocketType::Message(sock) => sock.set_read_timeout(timeout), 170 } 171 } 172 set_write_timeout(&self, timeout: Option<Duration>) -> io::Result<()>173 pub fn set_write_timeout(&self, timeout: Option<Duration>) -> io::Result<()> { 174 match &self.stream { 175 SocketType::Byte(sock) => sock.set_write_timeout(timeout), 176 SocketType::Message(sock) => sock.set_write_timeout(timeout), 177 } 178 } 179 180 // WARNING: Generally, multiple StreamChannel ends are not wanted. StreamChannel behavior with 181 // > 1 reader per end is not defined. try_clone(&self) -> io::Result<Self>182 pub fn try_clone(&self) -> io::Result<Self> { 183 Ok(StreamChannel { 184 stream: match &self.stream { 185 SocketType::Byte(sock) => SocketType::Byte(sock.try_clone()?), 186 SocketType::Message(sock) => SocketType::Message(sock.try_clone()?), 187 }, 188 }) 189 } 190 } 191 192 impl io::Write for StreamChannel { write(&mut self, buf: &[u8]) -> io::Result<usize>193 fn write(&mut self, buf: &[u8]) -> io::Result<usize> { 194 match &mut self.stream { 195 SocketType::Byte(sock) => sock.write(buf), 196 SocketType::Message(sock) => sock.send(buf), 197 } 198 } flush(&mut self) -> io::Result<()>199 fn flush(&mut self) -> io::Result<()> { 200 match &mut self.stream { 201 SocketType::Byte(sock) => sock.flush(), 202 SocketType::Message(_) => Ok(()), 203 } 204 } 205 } 206 207 impl io::Write for &StreamChannel { write(&mut self, buf: &[u8]) -> io::Result<usize>208 fn write(&mut self, buf: &[u8]) -> io::Result<usize> { 209 match &self.stream { 210 SocketType::Byte(sock) => (&mut &*sock).write(buf), 211 SocketType::Message(sock) => sock.send(buf), 212 } 213 } flush(&mut self) -> io::Result<()>214 fn flush(&mut self) -> io::Result<()> { 215 match &self.stream { 216 SocketType::Byte(sock) => (&mut &*sock).flush(), 217 SocketType::Message(_) => Ok(()), 218 } 219 } 220 } 221 222 impl AsRawFd for StreamChannel { as_raw_fd(&self) -> RawFd223 fn as_raw_fd(&self) -> RawFd { 224 match &self.stream { 225 SocketType::Byte(sock) => sock.as_raw_descriptor(), 226 SocketType::Message(sock) => sock.as_raw_descriptor(), 227 } 228 } 229 } 230 231 impl AsRawFd for &StreamChannel { as_raw_fd(&self) -> RawFd232 fn as_raw_fd(&self) -> RawFd { 233 self.as_raw_descriptor() 234 } 235 } 236 237 impl AsRawDescriptor for &StreamChannel { as_raw_descriptor(&self) -> RawDescriptor238 fn as_raw_descriptor(&self) -> RawDescriptor { 239 match &self.stream { 240 SocketType::Byte(sock) => sock.as_raw_descriptor(), 241 SocketType::Message(sock) => sock.as_raw_descriptor(), 242 } 243 } 244 } 245 246 impl ReadNotifier for StreamChannel { 247 /// Returns a RawDescriptor that can be polled for reads using PollContext. get_read_notifier(&self) -> &dyn AsRawDescriptor248 fn get_read_notifier(&self) -> &dyn AsRawDescriptor { 249 self 250 } 251 } 252 253 #[cfg(test)] 254 mod test { 255 use std::io::Read; 256 use std::io::Write; 257 258 use super::*; 259 use crate::EventContext; 260 use crate::EventToken; 261 use crate::ReadNotifier; 262 263 #[derive(EventToken, Debug, Eq, PartialEq, Copy, Clone)] 264 enum Token { 265 ReceivedData, 266 } 267 268 #[test] test_non_blocking_pair_byte()269 fn test_non_blocking_pair_byte() { 270 let (mut sender, mut receiver) = 271 StreamChannel::pair(BlockingMode::Nonblocking, FramingMode::Byte).unwrap(); 272 273 sender.write_all(&[75, 77, 54, 82, 76, 65]).unwrap(); 274 275 // Wait for the data to arrive. 276 let event_ctx: EventContext<Token> = 277 EventContext::build_with(&[(receiver.get_read_notifier(), Token::ReceivedData)]) 278 .unwrap(); 279 let events = event_ctx.wait().unwrap(); 280 let tokens: Vec<Token> = events 281 .iter() 282 .filter(|e| e.is_readable) 283 .map(|e| e.token) 284 .collect(); 285 assert_eq!(tokens, vec! {Token::ReceivedData}); 286 287 // Smaller than what we sent so we get multiple chunks 288 let mut recv_buffer: [u8; 4] = [0; 4]; 289 290 let mut size = receiver.read(&mut recv_buffer).unwrap(); 291 assert_eq!(size, 4); 292 assert_eq!(recv_buffer, [75, 77, 54, 82]); 293 294 size = receiver.read(&mut recv_buffer).unwrap(); 295 assert_eq!(size, 2); 296 assert_eq!(recv_buffer[0..2], [76, 65]); 297 298 // Now that we've polled for & received all data, polling again should show no events. 299 assert_eq!( 300 event_ctx 301 .wait_timeout(std::time::Duration::new(0, 0)) 302 .unwrap() 303 .len(), 304 0 305 ); 306 } 307 308 #[test] test_non_blocking_pair_message()309 fn test_non_blocking_pair_message() { 310 let (mut sender, mut receiver) = 311 StreamChannel::pair(BlockingMode::Nonblocking, FramingMode::Message).unwrap(); 312 313 sender.write_all(&[75, 77, 54, 82, 76, 65]).unwrap(); 314 315 // Wait for the data to arrive. 316 let event_ctx: EventContext<Token> = 317 EventContext::build_with(&[(receiver.get_read_notifier(), Token::ReceivedData)]) 318 .unwrap(); 319 let events = event_ctx.wait().unwrap(); 320 let tokens: Vec<Token> = events 321 .iter() 322 .filter(|e| e.is_readable) 323 .map(|e| e.token) 324 .collect(); 325 assert_eq!(tokens, vec! {Token::ReceivedData}); 326 327 // Unlike Byte format, Message mode panics if the buffer is smaller than the packet size; 328 // make the buffer the right size. 329 let mut recv_buffer: [u8; 6] = [0; 6]; 330 331 let size = receiver.read(&mut recv_buffer).unwrap(); 332 assert_eq!(size, 6); 333 assert_eq!(recv_buffer, [75, 77, 54, 82, 76, 65]); 334 335 // Now that we've polled for & received all data, polling again should show no events. 336 assert_eq!( 337 event_ctx 338 .wait_timeout(std::time::Duration::new(0, 0)) 339 .unwrap() 340 .len(), 341 0 342 ); 343 } 344 345 #[test] test_non_blocking_pair_error_no_data()346 fn test_non_blocking_pair_error_no_data() { 347 let (mut sender, mut receiver) = 348 StreamChannel::pair(BlockingMode::Nonblocking, FramingMode::Byte).unwrap(); 349 receiver 350 .set_nonblocking(true) 351 .expect("Failed to set receiver to nonblocking mode."); 352 353 sender.write_all(&[75, 77]).unwrap(); 354 355 // Wait for the data to arrive. 356 let event_ctx: EventContext<Token> = 357 EventContext::build_with(&[(receiver.get_read_notifier(), Token::ReceivedData)]) 358 .unwrap(); 359 let events = event_ctx.wait().unwrap(); 360 let tokens: Vec<Token> = events 361 .iter() 362 .filter(|e| e.is_readable) 363 .map(|e| e.token) 364 .collect(); 365 assert_eq!(tokens, vec! {Token::ReceivedData}); 366 367 // We only read 2 bytes, even though we requested 4 bytes. 368 let mut recv_buffer: [u8; 4] = [0; 4]; 369 let size = receiver.read(&mut recv_buffer).unwrap(); 370 assert_eq!(size, 2); 371 assert_eq!(recv_buffer, [75, 77, 00, 00]); 372 373 // Further reads should encounter an error since there is no available data and this is a 374 // non blocking pipe. 375 assert!(receiver.read(&mut recv_buffer).is_err()); 376 } 377 378 #[test] test_from_unix_seqpacket()379 fn test_from_unix_seqpacket() { 380 let (sock_sender, sock_receiver) = UnixSeqpacket::pair().unwrap(); 381 let mut sender = StreamChannel::from_unix_seqpacket(sock_sender); 382 let mut receiver = StreamChannel::from_unix_seqpacket(sock_receiver); 383 384 sender.write_all(&[75, 77, 54, 82, 76, 65]).unwrap(); 385 386 // Wait for the data to arrive. 387 let event_ctx: EventContext<Token> = 388 EventContext::build_with(&[(receiver.get_read_notifier(), Token::ReceivedData)]) 389 .unwrap(); 390 let events = event_ctx.wait().unwrap(); 391 let tokens: Vec<Token> = events 392 .iter() 393 .filter(|e| e.is_readable) 394 .map(|e| e.token) 395 .collect(); 396 assert_eq!(tokens, vec! {Token::ReceivedData}); 397 398 let mut recv_buffer: [u8; 6] = [0; 6]; 399 400 let size = receiver.read(&mut recv_buffer).unwrap(); 401 assert_eq!(size, 6); 402 assert_eq!(recv_buffer, [75, 77, 54, 82, 76, 65]); 403 404 // Now that we've polled for & received all data, polling again should show no events. 405 assert_eq!( 406 event_ctx 407 .wait_timeout(std::time::Duration::new(0, 0)) 408 .unwrap() 409 .len(), 410 0 411 ); 412 } 413 } 414