1 // Copyright (C) 2019-2021 Alibaba Cloud. All rights reserved. 2 // SPDX-License-Identifier: Apache-2.0 3 4 use std::fs::File; 5 use std::mem; 6 7 use base::AsRawDescriptor; 8 9 use crate::message::*; 10 use crate::BackendReq; 11 use crate::Connection; 12 use crate::Error; 13 use crate::HandlerResult; 14 use crate::Result; 15 use crate::SystemStream; 16 17 /// Trait for vhost-user frontends to respond to requests from the backend. 18 /// 19 /// Each method corresponds to a vhost-user protocol method. See the specification for details. 20 pub trait Frontend { 21 /// Handle device configuration change notifications. handle_config_change(&mut self) -> HandlerResult<u64>22 fn handle_config_change(&mut self) -> HandlerResult<u64> { 23 Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) 24 } 25 26 /// Handle shared memory region mapping requests. shmem_map( &mut self, _req: &VhostUserShmemMapMsg, _fd: &dyn AsRawDescriptor, ) -> HandlerResult<u64>27 fn shmem_map( 28 &mut self, 29 _req: &VhostUserShmemMapMsg, 30 _fd: &dyn AsRawDescriptor, 31 ) -> HandlerResult<u64> { 32 Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) 33 } 34 35 /// Handle shared memory region unmapping requests. shmem_unmap(&mut self, _req: &VhostUserShmemUnmapMsg) -> HandlerResult<u64>36 fn shmem_unmap(&mut self, _req: &VhostUserShmemUnmapMsg) -> HandlerResult<u64> { 37 Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) 38 } 39 40 // fn handle_iotlb_msg(&mut self, iotlb: VhostUserIotlb); 41 // fn handle_vring_host_notifier(&mut self, area: VhostUserVringArea, fd: RawDescriptor); 42 43 /// Handle GPU shared memory region mapping requests. gpu_map( &mut self, _req: &VhostUserGpuMapMsg, _descriptor: &dyn AsRawDescriptor, ) -> HandlerResult<u64>44 fn gpu_map( 45 &mut self, 46 _req: &VhostUserGpuMapMsg, 47 _descriptor: &dyn AsRawDescriptor, 48 ) -> HandlerResult<u64> { 49 Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) 50 } 51 52 /// Handle external memory region mapping requests. external_map(&mut self, _req: &VhostUserExternalMapMsg) -> HandlerResult<u64>53 fn external_map(&mut self, _req: &VhostUserExternalMapMsg) -> HandlerResult<u64> { 54 Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) 55 } 56 } 57 58 /// Handles requests from a vhost-user backend connection by dispatching them to [[Frontend]] 59 /// methods. 60 pub struct FrontendServer<S: Frontend> { 61 // underlying Unix domain socket for communication 62 pub(crate) sub_sock: Connection<BackendReq>, 63 // Protocol feature VHOST_USER_PROTOCOL_F_REPLY_ACK has been negotiated. 64 reply_ack_negotiated: bool, 65 66 frontend: S, 67 } 68 69 impl<S: Frontend> FrontendServer<S> { 70 /// Create a server to handle requests from `stream`. new(frontend: S, stream: SystemStream) -> Result<Self>71 pub(crate) fn new(frontend: S, stream: SystemStream) -> Result<Self> { 72 Ok(FrontendServer { 73 sub_sock: Connection::from(stream), 74 reply_ack_negotiated: false, 75 frontend, 76 }) 77 } 78 79 /// Set the negotiation state of the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature. 80 /// 81 /// When the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature has been negotiated, 82 /// the "REPLY_ACK" flag will be set in the message header for every request message. set_reply_ack_flag(&mut self, enable: bool)83 pub fn set_reply_ack_flag(&mut self, enable: bool) { 84 self.reply_ack_negotiated = enable; 85 } 86 87 /// Get the underlying frontend frontend_mut(&mut self) -> &mut S88 pub fn frontend_mut(&mut self) -> &mut S { 89 &mut self.frontend 90 } 91 92 /// Process the next received request. 93 /// 94 /// The caller needs to: 95 /// - serialize calls to this function 96 /// - decide what to do when errer happens 97 /// - optional recover from failure handle_request(&mut self) -> Result<u64>98 pub fn handle_request(&mut self) -> Result<u64> { 99 // The underlying communication channel is a Unix domain socket in 100 // stream mode, and recvmsg() is a little tricky here. To successfully 101 // receive attached file descriptors, we need to receive messages and 102 // corresponding attached file descriptors in this way: 103 // . recv messsage header and optional attached file 104 // . validate message header 105 // . recv optional message body and payload according size field in 106 // message header 107 // . validate message body and optional payload 108 let (hdr, files) = self.sub_sock.recv_header()?; 109 self.check_attached_files(&hdr, &files)?; 110 let buf = self.sub_sock.recv_body_bytes(&hdr)?; 111 let size = buf.len(); 112 113 let res = match hdr.get_code() { 114 Ok(BackendReq::CONFIG_CHANGE_MSG) => { 115 self.check_msg_size(&hdr, size, 0)?; 116 self.frontend 117 .handle_config_change() 118 .map_err(Error::ReqHandlerError) 119 } 120 Ok(BackendReq::SHMEM_MAP) => { 121 let msg = self.extract_msg_body::<VhostUserShmemMapMsg>(&hdr, size, &buf)?; 122 // check_attached_files() has validated files 123 self.frontend 124 .shmem_map(&msg, &files[0]) 125 .map_err(Error::ReqHandlerError) 126 } 127 Ok(BackendReq::SHMEM_UNMAP) => { 128 let msg = self.extract_msg_body::<VhostUserShmemUnmapMsg>(&hdr, size, &buf)?; 129 self.frontend 130 .shmem_unmap(&msg) 131 .map_err(Error::ReqHandlerError) 132 } 133 Ok(BackendReq::GPU_MAP) => { 134 let msg = self.extract_msg_body::<VhostUserGpuMapMsg>(&hdr, size, &buf)?; 135 // check_attached_files() has validated files 136 self.frontend 137 .gpu_map(&msg, &files[0]) 138 .map_err(Error::ReqHandlerError) 139 } 140 Ok(BackendReq::EXTERNAL_MAP) => { 141 let msg = self.extract_msg_body::<VhostUserExternalMapMsg>(&hdr, size, &buf)?; 142 self.frontend 143 .external_map(&msg) 144 .map_err(Error::ReqHandlerError) 145 } 146 _ => Err(Error::InvalidMessage), 147 }; 148 149 self.send_reply(&hdr, &res)?; 150 151 res 152 } 153 check_msg_size( &self, hdr: &VhostUserMsgHeader<BackendReq>, size: usize, expected: usize, ) -> Result<()>154 fn check_msg_size( 155 &self, 156 hdr: &VhostUserMsgHeader<BackendReq>, 157 size: usize, 158 expected: usize, 159 ) -> Result<()> { 160 if hdr.get_size() as usize != expected 161 || hdr.is_reply() 162 || hdr.get_version() != 0x1 163 || size != expected 164 { 165 return Err(Error::InvalidMessage); 166 } 167 Ok(()) 168 } 169 check_attached_files( &self, hdr: &VhostUserMsgHeader<BackendReq>, files: &[File], ) -> Result<()>170 fn check_attached_files( 171 &self, 172 hdr: &VhostUserMsgHeader<BackendReq>, 173 files: &[File], 174 ) -> Result<()> { 175 let expected_num_files = match hdr.get_code().map_err(|_| Error::InvalidMessage)? { 176 // Expect a single file is passed. 177 BackendReq::SHMEM_MAP | BackendReq::GPU_MAP => 1, 178 _ => 0, 179 }; 180 181 if files.len() == expected_num_files { 182 Ok(()) 183 } else { 184 Err(Error::InvalidMessage) 185 } 186 } 187 extract_msg_body<T: Sized + VhostUserMsgValidator>( &self, hdr: &VhostUserMsgHeader<BackendReq>, size: usize, buf: &[u8], ) -> Result<T>188 fn extract_msg_body<T: Sized + VhostUserMsgValidator>( 189 &self, 190 hdr: &VhostUserMsgHeader<BackendReq>, 191 size: usize, 192 buf: &[u8], 193 ) -> Result<T> { 194 self.check_msg_size(hdr, size, mem::size_of::<T>())?; 195 // SAFETY: above check ensures that buf is `T` sized. 196 let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const T) }; 197 if !msg.is_valid() { 198 return Err(Error::InvalidMessage); 199 } 200 Ok(msg) 201 } 202 new_reply_header<T: Sized>( &self, req: &VhostUserMsgHeader<BackendReq>, ) -> Result<VhostUserMsgHeader<BackendReq>>203 fn new_reply_header<T: Sized>( 204 &self, 205 req: &VhostUserMsgHeader<BackendReq>, 206 ) -> Result<VhostUserMsgHeader<BackendReq>> { 207 Ok(VhostUserMsgHeader::new( 208 req.get_code().map_err(|_| Error::InvalidMessage)?, 209 VhostUserHeaderFlag::REPLY.bits(), 210 mem::size_of::<T>() as u32, 211 )) 212 } 213 send_reply( &mut self, req: &VhostUserMsgHeader<BackendReq>, res: &Result<u64>, ) -> Result<()>214 fn send_reply( 215 &mut self, 216 req: &VhostUserMsgHeader<BackendReq>, 217 res: &Result<u64>, 218 ) -> Result<()> { 219 let code = req.get_code().map_err(|_| Error::InvalidMessage)?; 220 if code == BackendReq::SHMEM_MAP 221 || code == BackendReq::SHMEM_UNMAP 222 || code == BackendReq::GPU_MAP 223 || code == BackendReq::EXTERNAL_MAP 224 || (self.reply_ack_negotiated && req.is_need_reply()) 225 { 226 let hdr = self.new_reply_header::<VhostUserU64>(req)?; 227 let def_err = libc::EINVAL; 228 let val = match res { 229 Ok(n) => *n, 230 Err(e) => match e { 231 Error::ReqHandlerError(ioerr) => match ioerr.raw_os_error() { 232 Some(rawerr) => -rawerr as u64, 233 None => -def_err as u64, 234 }, 235 _ => -def_err as u64, 236 }, 237 }; 238 let msg = VhostUserU64::new(val); 239 self.sub_sock.send_message(&hdr, &msg, None)?; 240 } 241 Ok(()) 242 } 243 } 244