1 // Copyright 2021 The ChromiumOS Authors 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 use std::os::unix::prelude::AsRawFd; 6 use std::os::unix::prelude::RawFd; 7 use std::time::Duration; 8 9 use serde::de::DeserializeOwned; 10 use serde::Deserialize; 11 use serde::Serialize; 12 13 use crate::descriptor::AsRawDescriptor; 14 use crate::descriptor_reflection::deserialize_with_descriptors; 15 use crate::descriptor_reflection::SerializeDescriptors; 16 use crate::handle_eintr; 17 use crate::tube::Error; 18 use crate::tube::RecvTube; 19 use crate::tube::Result; 20 use crate::tube::SendTube; 21 use crate::BlockingMode; 22 use crate::FramingMode; 23 use crate::RawDescriptor; 24 use crate::ReadNotifier; 25 use crate::ScmSocket; 26 use crate::StreamChannel; 27 use crate::UnixSeqpacket; 28 29 // This size matches the inline buffer size of CmsgBuffer. 30 const TUBE_MAX_FDS: usize = 32; 31 32 /// Bidirectional tube that support both send and recv. 33 #[derive(Serialize, Deserialize)] 34 pub struct Tube { 35 socket: ScmSocket<StreamChannel>, 36 } 37 38 impl Tube { 39 /// Create a pair of connected tubes. Request is sent in one direction while response is in the 40 /// other direction. pair() -> Result<(Tube, Tube)>41 pub fn pair() -> Result<(Tube, Tube)> { 42 let (socket1, socket2) = StreamChannel::pair(BlockingMode::Blocking, FramingMode::Message) 43 .map_err(|errno| Error::Pair(std::io::Error::from(errno)))?; 44 let tube1 = Tube::new(socket1)?; 45 let tube2 = Tube::new(socket2)?; 46 Ok((tube1, tube2)) 47 } 48 49 /// Create a new `Tube` from a `StreamChannel`. 50 /// The StreamChannel must use FramingMode::Message (meaning, must use a SOCK_SEQPACKET as the 51 /// underlying socket type), otherwise, this method returns an error. 
new(socket: StreamChannel) -> Result<Tube>52 pub fn new(socket: StreamChannel) -> Result<Tube> { 53 match socket.get_framing_mode() { 54 FramingMode::Message => Ok(Tube { 55 socket: socket.try_into().map_err(Error::DupDescriptor)?, 56 }), 57 FramingMode::Byte => Err(Error::InvalidFramingMode), 58 } 59 } 60 61 /// Create a new `Tube` from a UnixSeqpacket. The StreamChannel is implicitly constructed to 62 /// have the right FramingMode by being constructed from a UnixSeqpacket. new_from_unix_seqpacket(sock: UnixSeqpacket) -> Result<Tube>63 pub fn new_from_unix_seqpacket(sock: UnixSeqpacket) -> Result<Tube> { 64 Ok(Tube { 65 socket: StreamChannel::from_unix_seqpacket(sock) 66 .try_into() 67 .map_err(Error::DupDescriptor)?, 68 }) 69 } 70 71 /// DO NOT USE this method directly as it will become private soon (b/221484449). Use a 72 /// directional Tube pair instead. 73 #[deprecated] try_clone(&self) -> Result<Self>74 pub fn try_clone(&self) -> Result<Self> { 75 self.socket 76 .inner() 77 .try_clone() 78 .map(Tube::new) 79 .map_err(Error::Clone)? 80 } 81 send<T: Serialize>(&self, msg: &T) -> Result<()>82 pub fn send<T: Serialize>(&self, msg: &T) -> Result<()> { 83 let msg_serialize = SerializeDescriptors::new(&msg); 84 let msg_json = serde_json::to_vec(&msg_serialize).map_err(Error::Json)?; 85 let msg_descriptors = msg_serialize.into_descriptors(); 86 87 if msg_descriptors.len() > TUBE_MAX_FDS { 88 return Err(Error::SendTooManyFds); 89 } 90 91 handle_eintr!(self.socket.send_with_fds(&msg_json, &msg_descriptors)) 92 .map_err(Error::Send)?; 93 Ok(()) 94 } 95 recv<T: DeserializeOwned>(&self) -> Result<T>96 pub fn recv<T: DeserializeOwned>(&self) -> Result<T> { 97 let msg_size = handle_eintr!(self.socket.inner().peek_size()).map_err(Error::Recv)?; 98 // This buffer is the right size, as the size received in peek_size() represents the size 99 // of only the message itself and not the file descriptors. The descriptors are stored 100 // separately in msghdr::msg_control. 
101 let mut msg_json = vec![0u8; msg_size]; 102 103 let (msg_json_size, msg_descriptors) = 104 handle_eintr!(self.socket.recv_with_fds(&mut msg_json, TUBE_MAX_FDS)) 105 .map_err(Error::Recv)?; 106 107 if msg_json_size == 0 { 108 return Err(Error::Disconnected); 109 } 110 111 deserialize_with_descriptors( 112 || serde_json::from_slice(&msg_json[0..msg_json_size]), 113 msg_descriptors, 114 ) 115 .map_err(Error::Json) 116 } 117 set_send_timeout(&self, timeout: Option<Duration>) -> Result<()>118 pub fn set_send_timeout(&self, timeout: Option<Duration>) -> Result<()> { 119 self.socket 120 .inner() 121 .set_write_timeout(timeout) 122 .map_err(Error::SetSendTimeout) 123 } 124 set_recv_timeout(&self, timeout: Option<Duration>) -> Result<()>125 pub fn set_recv_timeout(&self, timeout: Option<Duration>) -> Result<()> { 126 self.socket 127 .inner() 128 .set_read_timeout(timeout) 129 .map_err(Error::SetRecvTimeout) 130 } 131 132 #[cfg(feature = "proto_tube")] send_proto<M: protobuf::Message>(&self, msg: &M) -> Result<()>133 fn send_proto<M: protobuf::Message>(&self, msg: &M) -> Result<()> { 134 let bytes = msg.write_to_bytes().map_err(Error::Proto)?; 135 let no_fds: [RawFd; 0] = []; 136 137 handle_eintr!(self.socket.send_with_fds(&bytes, &no_fds)).map_err(Error::Send)?; 138 139 Ok(()) 140 } 141 142 #[cfg(feature = "proto_tube")] recv_proto<M: protobuf::Message>(&self) -> Result<M>143 fn recv_proto<M: protobuf::Message>(&self) -> Result<M> { 144 let msg_size = handle_eintr!(self.socket.inner().peek_size()).map_err(Error::Recv)?; 145 let mut msg_bytes = vec![0u8; msg_size]; 146 147 let (msg_bytes_size, _) = 148 handle_eintr!(self.socket.recv_with_fds(&mut msg_bytes, TUBE_MAX_FDS)) 149 .map_err(Error::Recv)?; 150 151 if msg_bytes_size == 0 { 152 return Err(Error::Disconnected); 153 } 154 155 protobuf::Message::parse_from_bytes(&msg_bytes).map_err(Error::Proto) 156 } 157 } 158 159 impl AsRawDescriptor for Tube { as_raw_descriptor(&self) -> RawDescriptor160 fn 
as_raw_descriptor(&self) -> RawDescriptor { 161 self.socket.as_raw_descriptor() 162 } 163 } 164 165 impl AsRawFd for Tube { as_raw_fd(&self) -> RawFd166 fn as_raw_fd(&self) -> RawFd { 167 self.socket.inner().as_raw_fd() 168 } 169 } 170 171 impl ReadNotifier for Tube { get_read_notifier(&self) -> &dyn AsRawDescriptor172 fn get_read_notifier(&self) -> &dyn AsRawDescriptor { 173 &self.socket 174 } 175 } 176 177 impl AsRawDescriptor for SendTube { as_raw_descriptor(&self) -> RawDescriptor178 fn as_raw_descriptor(&self) -> RawDescriptor { 179 self.0.as_raw_descriptor() 180 } 181 } 182 183 impl AsRawDescriptor for RecvTube { as_raw_descriptor(&self) -> RawDescriptor184 fn as_raw_descriptor(&self) -> RawDescriptor { 185 self.0.as_raw_descriptor() 186 } 187 } 188 189 /// Wrapper for Tube used for sending and receiving protos - avoids extra overhead of serialization 190 /// via serde_json. Since protos should be standalone objects we do not support sending of file 191 /// descriptors as a normal Tube would. 
#[cfg(feature = "proto_tube")]
pub struct ProtoTube(Tube);

#[cfg(feature = "proto_tube")]
impl ProtoTube {
    /// Creates two connected `ProtoTube`s by wrapping each end of a `Tube::pair()`.
    pub fn pair() -> Result<(ProtoTube, ProtoTube)> {
        let (first, second) = Tube::pair()?;
        Ok((ProtoTube(first), ProtoTube(second)))
    }

    /// Sends `msg` through the wrapped `Tube` using its proto encoding path.
    pub fn send_proto<M: protobuf::Message>(&self, msg: &M) -> Result<()> {
        let ProtoTube(tube) = self;
        tube.send_proto(msg)
    }

    /// Receives the next message from the wrapped `Tube` and parses it as `M`.
    pub fn recv_proto<M: protobuf::Message>(&self) -> Result<M> {
        let ProtoTube(tube) = self;
        tube.recv_proto()
    }

    /// Builds a `ProtoTube` on top of a `Tube` created from `sock`.
    pub fn new_from_unix_seqpacket(sock: UnixSeqpacket) -> Result<ProtoTube> {
        Tube::new_from_unix_seqpacket(sock).map(ProtoTube)
    }
}

#[cfg(all(feature = "proto_tube", test))]
#[allow(unused_variables)]
mod tests {
    // Any existing proto works here; the subject under test is ProtoTube, not this message type.
    use protos::cdisk_spec::ComponentDisk;

    use super::*;

    #[test]
    fn tube_serializes_and_deserializes() {
        let (sender, receiver) = ProtoTube::pair().unwrap();

        let sent = ComponentDisk {
            file_path: "/some/cool/path".to_string(),
            offset: 99,
            ..ComponentDisk::new()
        };
        sender.send_proto(&sent).unwrap();

        // The round-tripped message must compare equal to what was sent.
        let received: ComponentDisk = receiver.recv_proto().unwrap();
        assert!(sent.eq(&received));
    }
}