1 // Copyright 2021 The Chromium OS Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 use remain::sorted; 6 use std::io; 7 8 use thiserror::Error as ThisError; 9 10 #[cfg_attr(windows, path = "windows/tube.rs")] 11 #[cfg_attr(not(windows), path = "unix/tube.rs")] 12 mod tube; 13 use serde::{de::DeserializeOwned, Deserialize, Serialize}; 14 use std::time::Duration; 15 pub use tube::*; 16 17 impl Tube { 18 /// Creates a Send/Recv pair of Tubes. directional_pair() -> Result<(SendTube, RecvTube)>19 pub fn directional_pair() -> Result<(SendTube, RecvTube)> { 20 let (t1, t2) = Self::pair()?; 21 Ok((SendTube(t1), RecvTube(t2))) 22 } 23 } 24 25 #[derive(Serialize, Deserialize)] 26 #[serde(transparent)] 27 /// A Tube end which can only send messages. Cloneable. 28 pub struct SendTube(Tube); 29 30 #[allow(dead_code)] 31 impl SendTube { 32 /// TODO(b/145998747, b/184398671): this method should be removed. set_send_timeout(&self, _timeout: Option<Duration>) -> Result<()>33 pub fn set_send_timeout(&self, _timeout: Option<Duration>) -> Result<()> { 34 unimplemented!("To be removed/refactored upstream."); 35 } 36 send<T: Serialize>(&self, msg: &T) -> Result<()>37 pub fn send<T: Serialize>(&self, msg: &T) -> Result<()> { 38 self.0.send(msg) 39 } 40 try_clone(&self) -> Result<Self>41 pub fn try_clone(&self) -> Result<Self> { 42 Ok(SendTube( 43 #[allow(deprecated)] 44 self.0.try_clone()?, 45 )) 46 } 47 48 /// Never call this function, it is for use by cros_async to provide 49 /// directional wrapper types only. Using it in any other context may 50 /// violate concurrency assumptions. (Type splitting across crates has put 51 /// us in a situation where we can't use Rust privacy to enforce this.) 52 #[deprecated] into_tube(self) -> Tube53 pub fn into_tube(self) -> Tube { 54 self.0 55 } 56 } 57 58 #[derive(Serialize, Deserialize)] 59 #[serde(transparent)] 60 /// A Tube end which can only recv messages. 61 pub struct RecvTube(Tube); 62 63 #[allow(dead_code)] 64 impl RecvTube { recv<T: DeserializeOwned>(&self) -> Result<T>65 pub fn recv<T: DeserializeOwned>(&self) -> Result<T> { 66 self.0.recv() 67 } 68 69 /// TODO(b/145998747, b/184398671): this method should be removed. set_recv_timeout(&self, _timeout: Option<Duration>) -> Result<()>70 pub fn set_recv_timeout(&self, _timeout: Option<Duration>) -> Result<()> { 71 unimplemented!("To be removed/refactored upstream."); 72 } 73 74 /// Never call this function, it is for use by cros_async to provide 75 /// directional wrapper types only. Using it in any other context may 76 /// violate concurrency assumptions. (Type splitting across crates has put 77 /// us in a situation where we can't use Rust privacy to enforce this.) 78 #[deprecated] into_tube(self) -> Tube79 pub fn into_tube(self) -> Tube { 80 self.0 81 } 82 } 83 84 #[sorted] 85 #[derive(ThisError, Debug)] 86 pub enum Error { 87 #[cfg(windows)] 88 #[error("attempt to duplicate descriptor via broker failed")] 89 BrokerDupDescriptor, 90 #[error("failed to clone transport: {0}")] 91 Clone(io::Error), 92 #[error("tube was disconnected")] 93 Disconnected, 94 #[error("failed to duplicate descriptor: {0}")] 95 DupDescriptor(io::Error), 96 #[cfg(windows)] 97 #[error("failed to flush named pipe: {0}")] 98 Flush(io::Error), 99 #[error("failed to serialize/deserialize json from packet: {0}")] 100 Json(serde_json::Error), 101 #[error("cancelled a queued async operation")] 102 OperationCancelled, 103 #[error("failed to crate tube pair: {0}")] 104 Pair(io::Error), 105 #[error("failed to receive packet: {0}")] 106 Recv(io::Error), 107 #[error("Received a message with a zero sized body. This should not happen.")] 108 RecvUnexpectedEmptyBody, 109 #[error("failed to send packet: {0}")] 110 Send(crate::platform::Error), 111 #[error("failed to send packet: {0}")] 112 SendIo(io::Error), 113 #[error("failed to write packet to intermediate buffer: {0}")] 114 SendIoBuf(io::Error), 115 #[error("failed to set recv timeout: {0}")] 116 SetRecvTimeout(io::Error), 117 #[error("failed to set send timeout: {0}")] 118 SetSendTimeout(io::Error), 119 } 120 121 pub type Result<T> = std::result::Result<T, Error>; 122 123 #[cfg(test)] 124 mod tests { 125 use super::*; 126 use crate::Event; 127 128 use std::{collections::HashMap, time::Duration}; 129 130 use serde::{Deserialize, Serialize}; 131 use std::{ 132 sync::{Arc, Barrier}, 133 thread, 134 }; 135 136 #[derive(Serialize, Deserialize)] 137 struct DataStruct { 138 x: u32, 139 } 140 141 // Magics to identify which producer sent a message (& detect corruption). 142 const PRODUCER_ID_1: u32 = 801279273; 143 const PRODUCER_ID_2: u32 = 345234861; 144 145 #[track_caller] test_event_pair(send: Event, recv: Event)146 fn test_event_pair(send: Event, recv: Event) { 147 send.write(1).unwrap(); 148 recv.read_timeout(Duration::from_secs(1)).unwrap(); 149 } 150 151 #[test] send_recv_no_fd()152 fn send_recv_no_fd() { 153 let (s1, s2) = Tube::pair().unwrap(); 154 155 let test_msg = "hello world"; 156 s1.send(&test_msg).unwrap(); 157 let recv_msg: String = s2.recv().unwrap(); 158 159 assert_eq!(test_msg, recv_msg); 160 } 161 162 #[test] send_recv_one_fd()163 fn send_recv_one_fd() { 164 #[derive(Serialize, Deserialize)] 165 struct EventStruct { 166 x: u32, 167 b: Event, 168 } 169 170 let (s1, s2) = Tube::pair().unwrap(); 171 172 let test_msg = EventStruct { 173 x: 100, 174 b: Event::new().unwrap(), 175 }; 176 s1.send(&test_msg).unwrap(); 177 let recv_msg: EventStruct = s2.recv().unwrap(); 178 179 assert_eq!(test_msg.x, recv_msg.x); 180 181 test_event_pair(test_msg.b, recv_msg.b); 182 } 183 184 /// Send messages to a Tube with the given identifier (see `consume_messages`; we use this to 185 /// track different message producers). 186 #[track_caller] produce_messages(tube: SendTube, data: u32, barrier: Arc<Barrier>) -> SendTube187 fn produce_messages(tube: SendTube, data: u32, barrier: Arc<Barrier>) -> SendTube { 188 let data = DataStruct { x: data }; 189 barrier.wait(); 190 for _ in 0..100 { 191 tube.send(&data).unwrap(); 192 } 193 tube 194 } 195 196 /// Consumes the given number of messages from a Tube, returning the number messages read with 197 /// each producer ID. 198 #[track_caller] consume_messages( tube: RecvTube, count: usize, barrier: Arc<Barrier>, ) -> (RecvTube, usize, usize)199 fn consume_messages( 200 tube: RecvTube, 201 count: usize, 202 barrier: Arc<Barrier>, 203 ) -> (RecvTube, usize, usize) { 204 barrier.wait(); 205 206 let mut id1_count = 0usize; 207 let mut id2_count = 0usize; 208 209 for _ in 0..count { 210 let msg = tube.recv::<DataStruct>().unwrap(); 211 match msg.x { 212 PRODUCER_ID_1 => id1_count += 1, 213 PRODUCER_ID_2 => id2_count += 1, 214 _ => panic!( 215 "want message with ID {} or {}; got message w/ ID {}.", 216 PRODUCER_ID_1, PRODUCER_ID_2, msg.x 217 ), 218 } 219 } 220 (tube, id1_count, id2_count) 221 } 222 223 #[test] send_recv_mpsc()224 fn send_recv_mpsc() { 225 let (p1, consumer) = Tube::directional_pair().unwrap(); 226 let p2 = p1.try_clone().unwrap(); 227 let start_block_p1 = Arc::new(Barrier::new(3)); 228 let start_block_p2 = start_block_p1.clone(); 229 let start_block_consumer = start_block_p1.clone(); 230 231 let p1_thread = thread::spawn(move || produce_messages(p1, PRODUCER_ID_1, start_block_p1)); 232 let p2_thread = thread::spawn(move || produce_messages(p2, PRODUCER_ID_2, start_block_p2)); 233 234 let (_tube, id1_count, id2_count) = consume_messages(consumer, 200, start_block_consumer); 235 assert_eq!(id1_count, 100); 236 assert_eq!(id2_count, 100); 237 238 p1_thread.join().unwrap(); 239 p2_thread.join().unwrap(); 240 } 241 242 #[test] send_recv_hash_map()243 fn send_recv_hash_map() { 244 let (s1, s2) = Tube::pair().unwrap(); 245 246 let mut test_msg = HashMap::new(); 247 test_msg.insert("Red".to_owned(), Event::new().unwrap()); 248 test_msg.insert("White".to_owned(), Event::new().unwrap()); 249 test_msg.insert("Blue".to_owned(), Event::new().unwrap()); 250 test_msg.insert("Orange".to_owned(), Event::new().unwrap()); 251 test_msg.insert("Green".to_owned(), Event::new().unwrap()); 252 s1.send(&test_msg).unwrap(); 253 let mut recv_msg: HashMap<String, Event> = s2.recv().unwrap(); 254 255 let mut test_msg_keys: Vec<_> = test_msg.keys().collect(); 256 test_msg_keys.sort(); 257 let mut recv_msg_keys: Vec<_> = recv_msg.keys().collect(); 258 recv_msg_keys.sort(); 259 assert_eq!(test_msg_keys, recv_msg_keys); 260 261 for (key, test_event) in test_msg { 262 let recv_event = recv_msg.remove(&key).unwrap(); 263 test_event_pair(test_event, recv_event); 264 } 265 } 266 } 267