1 // Copyright 2021 The Chromium OS Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 use std::io::{self, IoSlice}; 6 use std::marker::PhantomData; 7 use std::ops::Deref; 8 use std::os::unix::prelude::{AsRawFd, RawFd}; 9 use std::time::Duration; 10 11 use crate::{net::UnixSeqpacket, FromRawDescriptor, SafeDescriptor, ScmSocket, UnsyncMarker}; 12 13 use cros_async::{Executor, IntoAsync, IoSourceExt}; 14 use serde::{de::DeserializeOwned, Serialize}; 15 use sys_util::{ 16 deserialize_with_descriptors, AsRawDescriptor, RawDescriptor, SerializeDescriptors, 17 }; 18 use thiserror::Error as ThisError; 19 20 #[derive(ThisError, Debug)] 21 pub enum Error { 22 #[error("failed to serialize/deserialize json from packet: {0}")] 23 Json(serde_json::Error), 24 #[error("failed to send packet: {0}")] 25 Send(sys_util::Error), 26 #[error("failed to receive packet: {0}")] 27 Recv(io::Error), 28 #[error("tube was disconnected")] 29 Disconnected, 30 #[error("failed to crate tube pair: {0}")] 31 Pair(io::Error), 32 #[error("failed to set send timeout: {0}")] 33 SetSendTimeout(io::Error), 34 #[error("failed to set recv timeout: {0}")] 35 SetRecvTimeout(io::Error), 36 #[error("failed to create async tube: {0}")] 37 CreateAsync(cros_async::AsyncError), 38 } 39 40 pub type Result<T> = std::result::Result<T, Error>; 41 42 /// Bidirectional tube that support both send and recv. 43 pub struct Tube { 44 socket: UnixSeqpacket, 45 _unsync_marker: UnsyncMarker, 46 } 47 48 impl Tube { 49 /// Create a pair of connected tubes. Request is send in one direction while response is in the 50 /// other direction. pair() -> Result<(Tube, Tube)>51 pub fn pair() -> Result<(Tube, Tube)> { 52 let (socket1, socket2) = UnixSeqpacket::pair().map_err(Error::Pair)?; 53 let tube1 = Tube::new(socket1); 54 let tube2 = Tube::new(socket2); 55 Ok((tube1, tube2)) 56 } 57 58 // Create a new `Tube`. new(socket: UnixSeqpacket) -> Tube59 pub fn new(socket: UnixSeqpacket) -> Tube { 60 Tube { 61 socket, 62 _unsync_marker: PhantomData, 63 } 64 } 65 into_async_tube(self, ex: &Executor) -> Result<AsyncTube>66 pub fn into_async_tube(self, ex: &Executor) -> Result<AsyncTube> { 67 let inner = ex.async_from(self).map_err(Error::CreateAsync)?; 68 Ok(AsyncTube { inner }) 69 } 70 send<T: Serialize>(&self, msg: &T) -> Result<()>71 pub fn send<T: Serialize>(&self, msg: &T) -> Result<()> { 72 let msg_serialize = SerializeDescriptors::new(&msg); 73 let msg_json = serde_json::to_vec(&msg_serialize).map_err(Error::Json)?; 74 let msg_descriptors = msg_serialize.into_descriptors(); 75 76 self.socket 77 .send_with_fds(&[IoSlice::new(&msg_json)], &msg_descriptors) 78 .map_err(Error::Send)?; 79 Ok(()) 80 } 81 recv<T: DeserializeOwned>(&self) -> Result<T>82 pub fn recv<T: DeserializeOwned>(&self) -> Result<T> { 83 let (msg_json, msg_descriptors) = 84 self.socket.recv_as_vec_with_fds().map_err(Error::Recv)?; 85 86 if msg_json.is_empty() { 87 return Err(Error::Disconnected); 88 } 89 90 let mut msg_descriptors_safe = msg_descriptors 91 .into_iter() 92 .map(|v| { 93 Some(unsafe { 94 // Safe because the socket returns new fds that are owned locally by this scope. 95 SafeDescriptor::from_raw_descriptor(v) 96 }) 97 }) 98 .collect(); 99 100 deserialize_with_descriptors( 101 || serde_json::from_slice(&msg_json), 102 &mut msg_descriptors_safe, 103 ) 104 .map_err(Error::Json) 105 } 106 107 /// Returns true if there is a packet ready to `recv` without blocking. 108 /// 109 /// If there is an error trying to determine if there is a packet ready, this returns false. is_packet_ready(&self) -> bool110 pub fn is_packet_ready(&self) -> bool { 111 self.socket.get_readable_bytes().unwrap_or(0) > 0 112 } 113 set_send_timeout(&self, timeout: Option<Duration>) -> Result<()>114 pub fn set_send_timeout(&self, timeout: Option<Duration>) -> Result<()> { 115 self.socket 116 .set_write_timeout(timeout) 117 .map_err(Error::SetSendTimeout) 118 } 119 set_recv_timeout(&self, timeout: Option<Duration>) -> Result<()>120 pub fn set_recv_timeout(&self, timeout: Option<Duration>) -> Result<()> { 121 self.socket 122 .set_read_timeout(timeout) 123 .map_err(Error::SetRecvTimeout) 124 } 125 } 126 127 impl AsRawDescriptor for Tube { as_raw_descriptor(&self) -> RawDescriptor128 fn as_raw_descriptor(&self) -> RawDescriptor { 129 self.socket.as_raw_descriptor() 130 } 131 } 132 133 impl AsRawFd for Tube { as_raw_fd(&self) -> RawFd134 fn as_raw_fd(&self) -> RawFd { 135 self.socket.as_raw_fd() 136 } 137 } 138 139 impl IntoAsync for Tube {} 140 141 pub struct AsyncTube { 142 inner: Box<dyn IoSourceExt<Tube>>, 143 } 144 145 impl AsyncTube { next<T: DeserializeOwned>(&self) -> Result<T>146 pub async fn next<T: DeserializeOwned>(&self) -> Result<T> { 147 self.inner.wait_readable().await.unwrap(); 148 self.inner.as_source().recv() 149 } 150 } 151 152 impl Deref for AsyncTube { 153 type Target = Tube; 154 deref(&self) -> &Self::Target155 fn deref(&self) -> &Self::Target { 156 self.inner.as_source() 157 } 158 } 159 160 impl Into<Tube> for AsyncTube { into(self) -> Tube161 fn into(self) -> Tube { 162 self.inner.into_source() 163 } 164 } 165 166 #[cfg(test)] 167 mod tests { 168 use super::*; 169 use crate::Event; 170 171 use std::collections::HashMap; 172 use std::time::Duration; 173 174 use serde::{Deserialize, Serialize}; 175 176 #[track_caller] test_event_pair(send: Event, mut recv: Event)177 fn test_event_pair(send: Event, mut recv: Event) { 178 send.write(1).unwrap(); 179 recv.read_timeout(Duration::from_secs(1)).unwrap(); 180 } 181 182 #[test] send_recv_no_fd()183 fn send_recv_no_fd() { 184 let (s1, s2) = Tube::pair().unwrap(); 185 186 let test_msg = "hello world"; 187 s1.send(&test_msg).unwrap(); 188 let recv_msg: String = s2.recv().unwrap(); 189 190 assert_eq!(test_msg, recv_msg); 191 } 192 193 #[test] send_recv_one_fd()194 fn send_recv_one_fd() { 195 #[derive(Serialize, Deserialize)] 196 struct EventStruct { 197 x: u32, 198 b: Event, 199 } 200 201 let (s1, s2) = Tube::pair().unwrap(); 202 203 let test_msg = EventStruct { 204 x: 100, 205 b: Event::new().unwrap(), 206 }; 207 s1.send(&test_msg).unwrap(); 208 let recv_msg: EventStruct = s2.recv().unwrap(); 209 210 assert_eq!(test_msg.x, recv_msg.x); 211 212 test_event_pair(test_msg.b, recv_msg.b); 213 } 214 215 #[test] send_recv_hash_map()216 fn send_recv_hash_map() { 217 let (s1, s2) = Tube::pair().unwrap(); 218 219 let mut test_msg = HashMap::new(); 220 test_msg.insert("Red".to_owned(), Event::new().unwrap()); 221 test_msg.insert("White".to_owned(), Event::new().unwrap()); 222 test_msg.insert("Blue".to_owned(), Event::new().unwrap()); 223 test_msg.insert("Orange".to_owned(), Event::new().unwrap()); 224 test_msg.insert("Green".to_owned(), Event::new().unwrap()); 225 s1.send(&test_msg).unwrap(); 226 let mut recv_msg: HashMap<String, Event> = s2.recv().unwrap(); 227 228 let mut test_msg_keys: Vec<_> = test_msg.keys().collect(); 229 test_msg_keys.sort(); 230 let mut recv_msg_keys: Vec<_> = recv_msg.keys().collect(); 231 recv_msg_keys.sort(); 232 assert_eq!(test_msg_keys, recv_msg_keys); 233 234 for (key, test_event) in test_msg { 235 let recv_event = recv_msg.remove(&key).unwrap(); 236 test_event_pair(test_event, recv_event); 237 } 238 } 239 } 240