• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (C) 2019-2021 Alibaba Cloud. All rights reserved.
2 // SPDX-License-Identifier: Apache-2.0
3 
4 use std::fs::File;
5 use std::mem;
6 
7 use base::AsRawDescriptor;
8 
9 use crate::message::*;
10 use crate::BackendReq;
11 use crate::Connection;
12 use crate::Error;
13 use crate::HandlerResult;
14 use crate::Result;
15 use crate::SystemStream;
16 
17 /// Trait for vhost-user frontends to respond to requests from the backend.
18 ///
19 /// Each method corresponds to a vhost-user protocol method. See the specification for details.
20 pub trait Frontend {
21     /// Handle device configuration change notifications.
handle_config_change(&mut self) -> HandlerResult<u64>22     fn handle_config_change(&mut self) -> HandlerResult<u64> {
23         Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
24     }
25 
26     /// Handle shared memory region mapping requests.
shmem_map( &mut self, _req: &VhostUserShmemMapMsg, _fd: &dyn AsRawDescriptor, ) -> HandlerResult<u64>27     fn shmem_map(
28         &mut self,
29         _req: &VhostUserShmemMapMsg,
30         _fd: &dyn AsRawDescriptor,
31     ) -> HandlerResult<u64> {
32         Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
33     }
34 
35     /// Handle shared memory region unmapping requests.
shmem_unmap(&mut self, _req: &VhostUserShmemUnmapMsg) -> HandlerResult<u64>36     fn shmem_unmap(&mut self, _req: &VhostUserShmemUnmapMsg) -> HandlerResult<u64> {
37         Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
38     }
39 
40     // fn handle_iotlb_msg(&mut self, iotlb: VhostUserIotlb);
41     // fn handle_vring_host_notifier(&mut self, area: VhostUserVringArea, fd: RawDescriptor);
42 
43     /// Handle GPU shared memory region mapping requests.
gpu_map( &mut self, _req: &VhostUserGpuMapMsg, _descriptor: &dyn AsRawDescriptor, ) -> HandlerResult<u64>44     fn gpu_map(
45         &mut self,
46         _req: &VhostUserGpuMapMsg,
47         _descriptor: &dyn AsRawDescriptor,
48     ) -> HandlerResult<u64> {
49         Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
50     }
51 
52     /// Handle external memory region mapping requests.
external_map(&mut self, _req: &VhostUserExternalMapMsg) -> HandlerResult<u64>53     fn external_map(&mut self, _req: &VhostUserExternalMapMsg) -> HandlerResult<u64> {
54         Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
55     }
56 }
57 
58 /// Handles requests from a vhost-user backend connection by dispatching them to [[Frontend]]
59 /// methods.
60 pub struct FrontendServer<S: Frontend> {
61     // underlying Unix domain socket for communication
62     pub(crate) sub_sock: Connection<BackendReq>,
63     // Protocol feature VHOST_USER_PROTOCOL_F_REPLY_ACK has been negotiated.
64     reply_ack_negotiated: bool,
65 
66     frontend: S,
67 }
68 
69 impl<S: Frontend> FrontendServer<S> {
70     /// Create a server to handle requests from `stream`.
new(frontend: S, stream: SystemStream) -> Result<Self>71     pub(crate) fn new(frontend: S, stream: SystemStream) -> Result<Self> {
72         Ok(FrontendServer {
73             sub_sock: Connection::from(stream),
74             reply_ack_negotiated: false,
75             frontend,
76         })
77     }
78 
79     /// Set the negotiation state of the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature.
80     ///
81     /// When the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature has been negotiated,
82     /// the "REPLY_ACK" flag will be set in the message header for every request message.
set_reply_ack_flag(&mut self, enable: bool)83     pub fn set_reply_ack_flag(&mut self, enable: bool) {
84         self.reply_ack_negotiated = enable;
85     }
86 
87     /// Get the underlying frontend
frontend_mut(&mut self) -> &mut S88     pub fn frontend_mut(&mut self) -> &mut S {
89         &mut self.frontend
90     }
91 
92     /// Process the next received request.
93     ///
94     /// The caller needs to:
95     /// - serialize calls to this function
96     /// - decide what to do when errer happens
97     /// - optional recover from failure
handle_request(&mut self) -> Result<u64>98     pub fn handle_request(&mut self) -> Result<u64> {
99         // The underlying communication channel is a Unix domain socket in
100         // stream mode, and recvmsg() is a little tricky here. To successfully
101         // receive attached file descriptors, we need to receive messages and
102         // corresponding attached file descriptors in this way:
103         // . recv messsage header and optional attached file
104         // . validate message header
105         // . recv optional message body and payload according size field in
106         //   message header
107         // . validate message body and optional payload
108         let (hdr, files) = self.sub_sock.recv_header()?;
109         self.check_attached_files(&hdr, &files)?;
110         let buf = self.sub_sock.recv_body_bytes(&hdr)?;
111         let size = buf.len();
112 
113         let res = match hdr.get_code() {
114             Ok(BackendReq::CONFIG_CHANGE_MSG) => {
115                 self.check_msg_size(&hdr, size, 0)?;
116                 self.frontend
117                     .handle_config_change()
118                     .map_err(Error::ReqHandlerError)
119             }
120             Ok(BackendReq::SHMEM_MAP) => {
121                 let msg = self.extract_msg_body::<VhostUserShmemMapMsg>(&hdr, size, &buf)?;
122                 // check_attached_files() has validated files
123                 self.frontend
124                     .shmem_map(&msg, &files[0])
125                     .map_err(Error::ReqHandlerError)
126             }
127             Ok(BackendReq::SHMEM_UNMAP) => {
128                 let msg = self.extract_msg_body::<VhostUserShmemUnmapMsg>(&hdr, size, &buf)?;
129                 self.frontend
130                     .shmem_unmap(&msg)
131                     .map_err(Error::ReqHandlerError)
132             }
133             Ok(BackendReq::GPU_MAP) => {
134                 let msg = self.extract_msg_body::<VhostUserGpuMapMsg>(&hdr, size, &buf)?;
135                 // check_attached_files() has validated files
136                 self.frontend
137                     .gpu_map(&msg, &files[0])
138                     .map_err(Error::ReqHandlerError)
139             }
140             Ok(BackendReq::EXTERNAL_MAP) => {
141                 let msg = self.extract_msg_body::<VhostUserExternalMapMsg>(&hdr, size, &buf)?;
142                 self.frontend
143                     .external_map(&msg)
144                     .map_err(Error::ReqHandlerError)
145             }
146             _ => Err(Error::InvalidMessage),
147         };
148 
149         self.send_reply(&hdr, &res)?;
150 
151         res
152     }
153 
check_msg_size( &self, hdr: &VhostUserMsgHeader<BackendReq>, size: usize, expected: usize, ) -> Result<()>154     fn check_msg_size(
155         &self,
156         hdr: &VhostUserMsgHeader<BackendReq>,
157         size: usize,
158         expected: usize,
159     ) -> Result<()> {
160         if hdr.get_size() as usize != expected
161             || hdr.is_reply()
162             || hdr.get_version() != 0x1
163             || size != expected
164         {
165             return Err(Error::InvalidMessage);
166         }
167         Ok(())
168     }
169 
check_attached_files( &self, hdr: &VhostUserMsgHeader<BackendReq>, files: &[File], ) -> Result<()>170     fn check_attached_files(
171         &self,
172         hdr: &VhostUserMsgHeader<BackendReq>,
173         files: &[File],
174     ) -> Result<()> {
175         let expected_num_files = match hdr.get_code().map_err(|_| Error::InvalidMessage)? {
176             // Expect a single file is passed.
177             BackendReq::SHMEM_MAP | BackendReq::GPU_MAP => 1,
178             _ => 0,
179         };
180 
181         if files.len() == expected_num_files {
182             Ok(())
183         } else {
184             Err(Error::InvalidMessage)
185         }
186     }
187 
extract_msg_body<T: Sized + VhostUserMsgValidator>( &self, hdr: &VhostUserMsgHeader<BackendReq>, size: usize, buf: &[u8], ) -> Result<T>188     fn extract_msg_body<T: Sized + VhostUserMsgValidator>(
189         &self,
190         hdr: &VhostUserMsgHeader<BackendReq>,
191         size: usize,
192         buf: &[u8],
193     ) -> Result<T> {
194         self.check_msg_size(hdr, size, mem::size_of::<T>())?;
195         // SAFETY: above check ensures that buf is `T` sized.
196         let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const T) };
197         if !msg.is_valid() {
198             return Err(Error::InvalidMessage);
199         }
200         Ok(msg)
201     }
202 
new_reply_header<T: Sized>( &self, req: &VhostUserMsgHeader<BackendReq>, ) -> Result<VhostUserMsgHeader<BackendReq>>203     fn new_reply_header<T: Sized>(
204         &self,
205         req: &VhostUserMsgHeader<BackendReq>,
206     ) -> Result<VhostUserMsgHeader<BackendReq>> {
207         Ok(VhostUserMsgHeader::new(
208             req.get_code().map_err(|_| Error::InvalidMessage)?,
209             VhostUserHeaderFlag::REPLY.bits(),
210             mem::size_of::<T>() as u32,
211         ))
212     }
213 
send_reply( &mut self, req: &VhostUserMsgHeader<BackendReq>, res: &Result<u64>, ) -> Result<()>214     fn send_reply(
215         &mut self,
216         req: &VhostUserMsgHeader<BackendReq>,
217         res: &Result<u64>,
218     ) -> Result<()> {
219         let code = req.get_code().map_err(|_| Error::InvalidMessage)?;
220         if code == BackendReq::SHMEM_MAP
221             || code == BackendReq::SHMEM_UNMAP
222             || code == BackendReq::GPU_MAP
223             || code == BackendReq::EXTERNAL_MAP
224             || (self.reply_ack_negotiated && req.is_need_reply())
225         {
226             let hdr = self.new_reply_header::<VhostUserU64>(req)?;
227             let def_err = libc::EINVAL;
228             let val = match res {
229                 Ok(n) => *n,
230                 Err(e) => match e {
231                     Error::ReqHandlerError(ioerr) => match ioerr.raw_os_error() {
232                         Some(rawerr) => -rawerr as u64,
233                         None => -def_err as u64,
234                     },
235                     _ => -def_err as u64,
236                 },
237             };
238             let msg = VhostUserU64::new(val);
239             self.sub_sock.send_message(&hdr, &msg, None)?;
240         }
241         Ok(())
242     }
243 }
244