• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
2 // SPDX-License-Identifier: Apache-2.0
3 
4 use std::fs::File;
5 use std::mem;
6 
7 use base::error;
8 use base::AsRawDescriptor;
9 use base::RawDescriptor;
10 use zerocopy::AsBytes;
11 use zerocopy::FromBytes;
12 use zerocopy::Ref;
13 
14 use crate::into_single_file;
15 use crate::message::*;
16 use crate::to_system_stream;
17 use crate::BackendReq;
18 use crate::Connection;
19 use crate::Error;
20 use crate::FrontendReq;
21 use crate::Result;
22 use crate::SystemStream;
23 
24 /// Trait for vhost-user backends.
25 ///
26 /// Each method corresponds to a vhost-user protocol method. See the specification for details.
27 #[allow(missing_docs)]
28 pub trait Backend {
set_owner(&mut self) -> Result<()>29     fn set_owner(&mut self) -> Result<()>;
reset_owner(&mut self) -> Result<()>30     fn reset_owner(&mut self) -> Result<()>;
get_features(&mut self) -> Result<u64>31     fn get_features(&mut self) -> Result<u64>;
set_features(&mut self, features: u64) -> Result<()>32     fn set_features(&mut self, features: u64) -> Result<()>;
set_mem_table(&mut self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>33     fn set_mem_table(&mut self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>;
set_vring_num(&mut self, index: u32, num: u32) -> Result<()>34     fn set_vring_num(&mut self, index: u32, num: u32) -> Result<()>;
set_vring_addr( &mut self, index: u32, flags: VhostUserVringAddrFlags, descriptor: u64, used: u64, available: u64, log: u64, ) -> Result<()>35     fn set_vring_addr(
36         &mut self,
37         index: u32,
38         flags: VhostUserVringAddrFlags,
39         descriptor: u64,
40         used: u64,
41         available: u64,
42         log: u64,
43     ) -> Result<()>;
44     // TODO: b/331466964 - Argument type is wrong for packed queues.
set_vring_base(&mut self, index: u32, base: u32) -> Result<()>45     fn set_vring_base(&mut self, index: u32, base: u32) -> Result<()>;
46     // TODO: b/331466964 - Return type is wrong for packed queues.
get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState>47     fn get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState>;
set_vring_kick(&mut self, index: u8, fd: Option<File>) -> Result<()>48     fn set_vring_kick(&mut self, index: u8, fd: Option<File>) -> Result<()>;
set_vring_call(&mut self, index: u8, fd: Option<File>) -> Result<()>49     fn set_vring_call(&mut self, index: u8, fd: Option<File>) -> Result<()>;
set_vring_err(&mut self, index: u8, fd: Option<File>) -> Result<()>50     fn set_vring_err(&mut self, index: u8, fd: Option<File>) -> Result<()>;
51 
get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures>52     fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures>;
set_protocol_features(&mut self, features: u64) -> Result<()>53     fn set_protocol_features(&mut self, features: u64) -> Result<()>;
get_queue_num(&mut self) -> Result<u64>54     fn get_queue_num(&mut self) -> Result<u64>;
set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()>55     fn set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()>;
get_config( &mut self, offset: u32, size: u32, flags: VhostUserConfigFlags, ) -> Result<Vec<u8>>56     fn get_config(
57         &mut self,
58         offset: u32,
59         size: u32,
60         flags: VhostUserConfigFlags,
61     ) -> Result<Vec<u8>>;
set_config(&mut self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>62     fn set_config(&mut self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>;
set_backend_req_fd(&mut self, _vu_req: Connection<BackendReq>)63     fn set_backend_req_fd(&mut self, _vu_req: Connection<BackendReq>) {}
get_inflight_fd( &mut self, inflight: &VhostUserInflight, ) -> Result<(VhostUserInflight, File)>64     fn get_inflight_fd(
65         &mut self,
66         inflight: &VhostUserInflight,
67     ) -> Result<(VhostUserInflight, File)>;
set_inflight_fd(&mut self, inflight: &VhostUserInflight, file: File) -> Result<()>68     fn set_inflight_fd(&mut self, inflight: &VhostUserInflight, file: File) -> Result<()>;
get_max_mem_slots(&mut self) -> Result<u64>69     fn get_max_mem_slots(&mut self) -> Result<u64>;
add_mem_region(&mut self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>70     fn add_mem_region(&mut self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>;
remove_mem_region(&mut self, region: &VhostUserSingleMemoryRegion) -> Result<()>71     fn remove_mem_region(&mut self, region: &VhostUserSingleMemoryRegion) -> Result<()>;
get_shared_memory_regions(&mut self) -> Result<Vec<VhostSharedMemoryRegion>>72     fn get_shared_memory_regions(&mut self) -> Result<Vec<VhostSharedMemoryRegion>>;
73     /// Request the device to sleep by stopping their workers. This should NOT be called if the
74     /// device is already asleep.
sleep(&mut self) -> Result<()>75     fn sleep(&mut self) -> Result<()>;
76     /// Request the device to wake up by starting up their workers. This should NOT be called if the
77     /// device is already awake.
wake(&mut self) -> Result<()>78     fn wake(&mut self) -> Result<()>;
snapshot(&mut self) -> Result<Vec<u8>>79     fn snapshot(&mut self) -> Result<Vec<u8>>;
restore(&mut self, data_bytes: &[u8], queue_evts: Vec<File>) -> Result<()>80     fn restore(&mut self, data_bytes: &[u8], queue_evts: Vec<File>) -> Result<()>;
81 }
82 
83 impl<T> Backend for T
84 where
85     T: AsMut<dyn Backend>,
86 {
set_owner(&mut self) -> Result<()>87     fn set_owner(&mut self) -> Result<()> {
88         self.as_mut().set_owner()
89     }
90 
reset_owner(&mut self) -> Result<()>91     fn reset_owner(&mut self) -> Result<()> {
92         self.as_mut().reset_owner()
93     }
94 
get_features(&mut self) -> Result<u64>95     fn get_features(&mut self) -> Result<u64> {
96         self.as_mut().get_features()
97     }
98 
set_features(&mut self, features: u64) -> Result<()>99     fn set_features(&mut self, features: u64) -> Result<()> {
100         self.as_mut().set_features(features)
101     }
102 
set_mem_table(&mut self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>103     fn set_mem_table(&mut self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()> {
104         self.as_mut().set_mem_table(ctx, files)
105     }
106 
set_vring_num(&mut self, index: u32, num: u32) -> Result<()>107     fn set_vring_num(&mut self, index: u32, num: u32) -> Result<()> {
108         self.as_mut().set_vring_num(index, num)
109     }
110 
set_vring_addr( &mut self, index: u32, flags: VhostUserVringAddrFlags, descriptor: u64, used: u64, available: u64, log: u64, ) -> Result<()>111     fn set_vring_addr(
112         &mut self,
113         index: u32,
114         flags: VhostUserVringAddrFlags,
115         descriptor: u64,
116         used: u64,
117         available: u64,
118         log: u64,
119     ) -> Result<()> {
120         self.as_mut()
121             .set_vring_addr(index, flags, descriptor, used, available, log)
122     }
123 
set_vring_base(&mut self, index: u32, base: u32) -> Result<()>124     fn set_vring_base(&mut self, index: u32, base: u32) -> Result<()> {
125         self.as_mut().set_vring_base(index, base)
126     }
127 
get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState>128     fn get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState> {
129         self.as_mut().get_vring_base(index)
130     }
131 
set_vring_kick(&mut self, index: u8, fd: Option<File>) -> Result<()>132     fn set_vring_kick(&mut self, index: u8, fd: Option<File>) -> Result<()> {
133         self.as_mut().set_vring_kick(index, fd)
134     }
135 
set_vring_call(&mut self, index: u8, fd: Option<File>) -> Result<()>136     fn set_vring_call(&mut self, index: u8, fd: Option<File>) -> Result<()> {
137         self.as_mut().set_vring_call(index, fd)
138     }
139 
set_vring_err(&mut self, index: u8, fd: Option<File>) -> Result<()>140     fn set_vring_err(&mut self, index: u8, fd: Option<File>) -> Result<()> {
141         self.as_mut().set_vring_err(index, fd)
142     }
143 
get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures>144     fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures> {
145         self.as_mut().get_protocol_features()
146     }
147 
set_protocol_features(&mut self, features: u64) -> Result<()>148     fn set_protocol_features(&mut self, features: u64) -> Result<()> {
149         self.as_mut().set_protocol_features(features)
150     }
151 
get_queue_num(&mut self) -> Result<u64>152     fn get_queue_num(&mut self) -> Result<u64> {
153         self.as_mut().get_queue_num()
154     }
155 
set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()>156     fn set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()> {
157         self.as_mut().set_vring_enable(index, enable)
158     }
159 
get_config( &mut self, offset: u32, size: u32, flags: VhostUserConfigFlags, ) -> Result<Vec<u8>>160     fn get_config(
161         &mut self,
162         offset: u32,
163         size: u32,
164         flags: VhostUserConfigFlags,
165     ) -> Result<Vec<u8>> {
166         self.as_mut().get_config(offset, size, flags)
167     }
168 
set_config(&mut self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>169     fn set_config(&mut self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()> {
170         self.as_mut().set_config(offset, buf, flags)
171     }
172 
set_backend_req_fd(&mut self, vu_req: Connection<BackendReq>)173     fn set_backend_req_fd(&mut self, vu_req: Connection<BackendReq>) {
174         self.as_mut().set_backend_req_fd(vu_req)
175     }
176 
get_inflight_fd( &mut self, inflight: &VhostUserInflight, ) -> Result<(VhostUserInflight, File)>177     fn get_inflight_fd(
178         &mut self,
179         inflight: &VhostUserInflight,
180     ) -> Result<(VhostUserInflight, File)> {
181         self.as_mut().get_inflight_fd(inflight)
182     }
183 
set_inflight_fd(&mut self, inflight: &VhostUserInflight, file: File) -> Result<()>184     fn set_inflight_fd(&mut self, inflight: &VhostUserInflight, file: File) -> Result<()> {
185         self.as_mut().set_inflight_fd(inflight, file)
186     }
187 
get_max_mem_slots(&mut self) -> Result<u64>188     fn get_max_mem_slots(&mut self) -> Result<u64> {
189         self.as_mut().get_max_mem_slots()
190     }
191 
add_mem_region(&mut self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>192     fn add_mem_region(&mut self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()> {
193         self.as_mut().add_mem_region(region, fd)
194     }
195 
remove_mem_region(&mut self, region: &VhostUserSingleMemoryRegion) -> Result<()>196     fn remove_mem_region(&mut self, region: &VhostUserSingleMemoryRegion) -> Result<()> {
197         self.as_mut().remove_mem_region(region)
198     }
199 
get_shared_memory_regions(&mut self) -> Result<Vec<VhostSharedMemoryRegion>>200     fn get_shared_memory_regions(&mut self) -> Result<Vec<VhostSharedMemoryRegion>> {
201         self.as_mut().get_shared_memory_regions()
202     }
203 
sleep(&mut self) -> Result<()>204     fn sleep(&mut self) -> Result<()> {
205         self.as_mut().sleep()
206     }
207 
wake(&mut self) -> Result<()>208     fn wake(&mut self) -> Result<()> {
209         self.as_mut().wake()
210     }
211 
snapshot(&mut self) -> Result<Vec<u8>>212     fn snapshot(&mut self) -> Result<Vec<u8>> {
213         self.as_mut().snapshot()
214     }
215 
restore(&mut self, data_bytes: &[u8], queue_evts: Vec<File>) -> Result<()>216     fn restore(&mut self, data_bytes: &[u8], queue_evts: Vec<File>) -> Result<()> {
217         self.as_mut().restore(data_bytes, queue_evts)
218     }
219 }
220 
221 /// Handles requests from a vhost-user connection by dispatching them to [[Backend]] methods.
222 pub struct BackendServer<S: Backend> {
223     /// Underlying connection for communication.
224     connection: Connection<FrontendReq>,
225     // the vhost-user backend device object
226     backend: S,
227 
228     virtio_features: u64,
229     acked_virtio_features: u64,
230     protocol_features: VhostUserProtocolFeatures,
231     acked_protocol_features: u64,
232 
233     /// Sending ack for messages without payload.
234     reply_ack_enabled: bool,
235 }
236 
237 impl<S: Backend> BackendServer<S> {
238     /// Create a backend server from a connected socket.
from_stream(socket: SystemStream, backend: S) -> Self239     pub fn from_stream(socket: SystemStream, backend: S) -> Self {
240         Self::new(Connection::from(socket), backend)
241     }
242 }
243 
244 impl<S: Backend> AsRef<S> for BackendServer<S> {
as_ref(&self) -> &S245     fn as_ref(&self) -> &S {
246         &self.backend
247     }
248 }
249 
250 impl<S: Backend> BackendServer<S> {
new(connection: Connection<FrontendReq>, backend: S) -> Self251     pub fn new(connection: Connection<FrontendReq>, backend: S) -> Self {
252         BackendServer {
253             connection,
254             backend,
255             virtio_features: 0,
256             acked_virtio_features: 0,
257             protocol_features: VhostUserProtocolFeatures::empty(),
258             acked_protocol_features: 0,
259             reply_ack_enabled: false,
260         }
261     }
262 
263     /// Receives and validates a vhost-user message header and optional files.
264     ///
265     /// Since the length of vhost-user messages are different among message types, regular
266     /// vhost-user messages are sent via an underlying communication channel in stream mode.
267     /// (e.g. `SOCK_STREAM` in UNIX)
268     /// So, the logic of receiving and handling a message consists of the following steps:
269     ///
270     /// 1. Receives a message header and optional attached file.
271     /// 2. Validates the message header.
272     /// 3. Check if optional payloads is expected.
273     /// 4. Wait for the optional payloads.
274     /// 5. Receives optional payloads.
275     /// 6. Processes the message.
276     ///
277     /// This method [`BackendServer::recv_header()`] is in charge of the step (1) and (2),
278     /// [`BackendServer::needs_wait_for_payload()`] is (3), and
279     /// [`BackendServer::process_message()`] is (5) and (6). We need to have the three method
280     /// separately for multi-platform supports; [`BackendServer::recv_header()`] and
281     /// [`BackendServer::process_message()`] need to be separated because the way of waiting for
282     /// incoming messages differs between Unix and Windows so it's the caller's responsibility to
283     /// wait before [`BackendServer::process_message()`].
284     ///
285     /// Note that some vhost-user protocol variant such as VVU doesn't assume stream mode. In this
286     /// case, a message header and its body are sent together so the step (4) is skipped. We handle
287     /// this case in [`BackendServer::needs_wait_for_payload()`].
288     ///
289     /// The following pseudo code describes how a caller should process incoming vhost-user
290     /// messages:
291     /// ```ignore
292     /// loop {
293     ///   // block until a message header comes.
294     ///   // The actual code differs, depending on platforms.
295     ///   connection.wait_readable().unwrap();
296     ///
297     ///   // (1) and (2)
298     ///   let (hdr, files) = backend_server.recv_header();
299     ///
300     ///   // (3)
301     ///   if backend_server.needs_wait_for_payload(&hdr) {
302     ///     // (4) block until a payload comes if needed.
303     ///     connection.wait_readable().unwrap();
304     ///   }
305     ///
306     ///   // (5) and (6)
307     ///   backend_server.process_message(&hdr, &files).unwrap();
308     /// }
309     /// ```
recv_header(&mut self) -> Result<(VhostUserMsgHeader<FrontendReq>, Vec<File>)>310     pub fn recv_header(&mut self) -> Result<(VhostUserMsgHeader<FrontendReq>, Vec<File>)> {
311         // The underlying communication channel is a Unix domain socket in
312         // stream mode, and recvmsg() is a little tricky here. To successfully
313         // receive attached file descriptors, we need to receive messages and
314         // corresponding attached file descriptors in this way:
315         // . recv messsage header and optional attached file
316         // . validate message header
317         // . recv optional message body and payload according size field in
318         //   message header
319         // . validate message body and optional payload
320         let (hdr, files) = match self.connection.recv_header() {
321             Ok((hdr, files)) => (hdr, files),
322             Err(Error::Disconnect) => {
323                 // If the client closed the connection before sending a header, this should be
324                 // handled as a legal exit.
325                 return Err(Error::ClientExit);
326             }
327             Err(e) => {
328                 return Err(e);
329             }
330         };
331 
332         self.check_attached_files(&hdr, &files)?;
333 
334         Ok((hdr, files))
335     }
336 
337     /// Returns whether the caller needs to wait for the incoming message before calling
338     /// [`BackendServer::process_message`].
339     ///
340     /// See [`BackendServer::recv_header`]'s doc comment for the usage.
needs_wait_for_payload(&self, hdr: &VhostUserMsgHeader<FrontendReq>) -> bool341     pub fn needs_wait_for_payload(&self, hdr: &VhostUserMsgHeader<FrontendReq>) -> bool {
342         // Since the vhost-user protocol uses stream mode, we need to wait until an additional
343         // payload is available if exists.
344         hdr.get_size() != 0
345     }
346 
347     /// Main entrance to request from the communication channel.
348     ///
349     /// Receive and handle one incoming request message from the frontend.
350     /// See [`BackendServer::recv_header`]'s doc comment for the usage.
351     ///
352     /// # Return:
353     /// * - `Ok(())`: one request was successfully handled.
354     /// * - `Err(ClientExit)`: the frontend closed the connection properly. This isn't an actual
355     ///   failure.
356     /// * - `Err(Disconnect)`: the connection was closed unexpectedly.
357     /// * - `Err(InvalidMessage)`: the vmm sent a illegal message.
358     /// * - other errors: failed to handle a request.
process_message( &mut self, hdr: VhostUserMsgHeader<FrontendReq>, files: Vec<File>, ) -> Result<()>359     pub fn process_message(
360         &mut self,
361         hdr: VhostUserMsgHeader<FrontendReq>,
362         files: Vec<File>,
363     ) -> Result<()> {
364         let buf = self.connection.recv_body_bytes(&hdr)?;
365         let size = buf.len();
366 
367         match hdr.get_code() {
368             Ok(FrontendReq::SET_OWNER) => {
369                 self.check_request_size(&hdr, size, 0)?;
370                 let res = self.backend.set_owner();
371                 self.send_ack_message(&hdr, res.is_ok())?;
372                 res?;
373             }
374             Ok(FrontendReq::RESET_OWNER) => {
375                 self.check_request_size(&hdr, size, 0)?;
376                 let res = self.backend.reset_owner();
377                 self.send_ack_message(&hdr, res.is_ok())?;
378                 res?;
379             }
380             Ok(FrontendReq::GET_FEATURES) => {
381                 self.check_request_size(&hdr, size, 0)?;
382                 let mut features = self.backend.get_features()?;
383 
384                 // Don't advertise packed queues even if the device does. We don't handle them
385                 // properly yet at the protocol layer.
386                 // TODO: b/331466964 - Remove once support is added.
387                 features &= !(1 << VIRTIO_F_RING_PACKED);
388 
389                 let msg = VhostUserU64::new(features);
390                 self.send_reply_message(&hdr, &msg)?;
391                 self.virtio_features = features;
392                 self.update_reply_ack_flag();
393             }
394             Ok(FrontendReq::SET_FEATURES) => {
395                 let mut msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?;
396 
397                 // Don't allow packed queues even if the device does. We don't handle them
398                 // properly yet at the protocol layer.
399                 // TODO: b/331466964 - Remove once support is added.
400                 msg.value &= !(1 << VIRTIO_F_RING_PACKED);
401 
402                 let res = self.backend.set_features(msg.value);
403                 self.acked_virtio_features = msg.value;
404                 self.update_reply_ack_flag();
405                 self.send_ack_message(&hdr, res.is_ok())?;
406                 res?;
407             }
408             Ok(FrontendReq::SET_MEM_TABLE) => {
409                 let res = self.set_mem_table(&hdr, size, &buf, files);
410                 self.send_ack_message(&hdr, res.is_ok())?;
411                 res?;
412             }
413             Ok(FrontendReq::SET_VRING_NUM) => {
414                 let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
415                 let res = self.backend.set_vring_num(msg.index, msg.num);
416                 self.send_ack_message(&hdr, res.is_ok())?;
417                 res?;
418             }
419             Ok(FrontendReq::SET_VRING_ADDR) => {
420                 let msg = self.extract_request_body::<VhostUserVringAddr>(&hdr, size, &buf)?;
421                 let flags = match VhostUserVringAddrFlags::from_bits(msg.flags) {
422                     Some(val) => val,
423                     None => return Err(Error::InvalidMessage),
424                 };
425                 let res = self.backend.set_vring_addr(
426                     msg.index,
427                     flags,
428                     msg.descriptor,
429                     msg.used,
430                     msg.available,
431                     msg.log,
432                 );
433                 self.send_ack_message(&hdr, res.is_ok())?;
434                 res?;
435             }
436             Ok(FrontendReq::SET_VRING_BASE) => {
437                 let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
438                 let res = self.backend.set_vring_base(msg.index, msg.num);
439                 self.send_ack_message(&hdr, res.is_ok())?;
440                 res?;
441             }
442             Ok(FrontendReq::GET_VRING_BASE) => {
443                 let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
444                 let reply = self.backend.get_vring_base(msg.index)?;
445                 self.send_reply_message(&hdr, &reply)?;
446             }
447             Ok(FrontendReq::SET_VRING_CALL) => {
448                 self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
449                 let (index, file) = self.handle_vring_fd_request(&buf, files)?;
450                 let res = self.backend.set_vring_call(index, file);
451                 self.send_ack_message(&hdr, res.is_ok())?;
452                 res?;
453             }
454             Ok(FrontendReq::SET_VRING_KICK) => {
455                 self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
456                 let (index, file) = self.handle_vring_fd_request(&buf, files)?;
457                 let res = self.backend.set_vring_kick(index, file);
458                 self.send_ack_message(&hdr, res.is_ok())?;
459                 res?;
460             }
461             Ok(FrontendReq::SET_VRING_ERR) => {
462                 self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
463                 let (index, file) = self.handle_vring_fd_request(&buf, files)?;
464                 let res = self.backend.set_vring_err(index, file);
465                 self.send_ack_message(&hdr, res.is_ok())?;
466                 res?;
467             }
468             Ok(FrontendReq::GET_PROTOCOL_FEATURES) => {
469                 self.check_request_size(&hdr, size, 0)?;
470                 let features = self.backend.get_protocol_features()?;
471                 let msg = VhostUserU64::new(features.bits());
472                 self.send_reply_message(&hdr, &msg)?;
473                 self.protocol_features = features;
474                 self.update_reply_ack_flag();
475             }
476             Ok(FrontendReq::SET_PROTOCOL_FEATURES) => {
477                 let msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?;
478                 let res = self.backend.set_protocol_features(msg.value);
479                 self.acked_protocol_features = msg.value;
480                 self.update_reply_ack_flag();
481                 self.send_ack_message(&hdr, res.is_ok())?;
482                 res?;
483             }
484             Ok(FrontendReq::GET_QUEUE_NUM) => {
485                 if self.acked_protocol_features & VhostUserProtocolFeatures::MQ.bits() == 0 {
486                     return Err(Error::InvalidOperation);
487                 }
488                 self.check_request_size(&hdr, size, 0)?;
489                 let num = self.backend.get_queue_num()?;
490                 let msg = VhostUserU64::new(num);
491                 self.send_reply_message(&hdr, &msg)?;
492             }
493             Ok(FrontendReq::SET_VRING_ENABLE) => {
494                 let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
495                 if self.acked_virtio_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES == 0 {
496                     return Err(Error::InvalidOperation);
497                 }
498                 let enable = match msg.num {
499                     1 => true,
500                     0 => false,
501                     _ => return Err(Error::InvalidParam),
502                 };
503 
504                 let res = self.backend.set_vring_enable(msg.index, enable);
505                 self.send_ack_message(&hdr, res.is_ok())?;
506                 res?;
507             }
508             Ok(FrontendReq::GET_CONFIG) => {
509                 if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
510                     return Err(Error::InvalidOperation);
511                 }
512                 self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
513                 self.get_config(&hdr, &buf)?;
514             }
515             Ok(FrontendReq::SET_CONFIG) => {
516                 if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
517                     return Err(Error::InvalidOperation);
518                 }
519                 self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
520                 let res = self.set_config(&buf);
521                 self.send_ack_message(&hdr, res.is_ok())?;
522                 res?;
523             }
524             Ok(FrontendReq::SET_BACKEND_REQ_FD) => {
525                 if self.acked_protocol_features & VhostUserProtocolFeatures::BACKEND_REQ.bits() == 0
526                 {
527                     return Err(Error::InvalidOperation);
528                 }
529                 self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
530                 let res = self.set_backend_req_fd(files);
531                 self.send_ack_message(&hdr, res.is_ok())?;
532                 res?;
533             }
534             Ok(FrontendReq::GET_INFLIGHT_FD) => {
535                 if self.acked_protocol_features & VhostUserProtocolFeatures::INFLIGHT_SHMFD.bits()
536                     == 0
537                 {
538                     return Err(Error::InvalidOperation);
539                 }
540 
541                 let msg = self.extract_request_body::<VhostUserInflight>(&hdr, size, &buf)?;
542                 let (inflight, file) = self.backend.get_inflight_fd(&msg)?;
543                 let reply_hdr = self.new_reply_header::<VhostUserInflight>(&hdr, 0)?;
544                 self.connection.send_message(
545                     &reply_hdr,
546                     &inflight,
547                     Some(&[file.as_raw_descriptor()]),
548                 )?;
549             }
550             Ok(FrontendReq::SET_INFLIGHT_FD) => {
551                 if self.acked_protocol_features & VhostUserProtocolFeatures::INFLIGHT_SHMFD.bits()
552                     == 0
553                 {
554                     return Err(Error::InvalidOperation);
555                 }
556                 let file = into_single_file(files).ok_or(Error::IncorrectFds)?;
557                 let msg = self.extract_request_body::<VhostUserInflight>(&hdr, size, &buf)?;
558                 let res = self.backend.set_inflight_fd(&msg, file);
559                 self.send_ack_message(&hdr, res.is_ok())?;
560                 res?;
561             }
562             Ok(FrontendReq::GET_MAX_MEM_SLOTS) => {
563                 if self.acked_protocol_features
564                     & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits()
565                     == 0
566                 {
567                     return Err(Error::InvalidOperation);
568                 }
569                 self.check_request_size(&hdr, size, 0)?;
570                 let num = self.backend.get_max_mem_slots()?;
571                 let msg = VhostUserU64::new(num);
572                 self.send_reply_message(&hdr, &msg)?;
573             }
574             Ok(FrontendReq::ADD_MEM_REG) => {
575                 if self.acked_protocol_features
576                     & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits()
577                     == 0
578                 {
579                     return Err(Error::InvalidOperation);
580                 }
581                 let file = into_single_file(files).ok_or(Error::InvalidParam)?;
582                 let msg =
583                     self.extract_request_body::<VhostUserSingleMemoryRegion>(&hdr, size, &buf)?;
584                 let res = self.backend.add_mem_region(&msg, file);
585                 self.send_ack_message(&hdr, res.is_ok())?;
586                 res?;
587             }
588             Ok(FrontendReq::REM_MEM_REG) => {
589                 if self.acked_protocol_features
590                     & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits()
591                     == 0
592                 {
593                     return Err(Error::InvalidOperation);
594                 }
595 
596                 let msg =
597                     self.extract_request_body::<VhostUserSingleMemoryRegion>(&hdr, size, &buf)?;
598                 let res = self.backend.remove_mem_region(&msg);
599                 self.send_ack_message(&hdr, res.is_ok())?;
600                 res?;
601             }
602             Ok(FrontendReq::GET_SHARED_MEMORY_REGIONS) => {
603                 let regions = self.backend.get_shared_memory_regions()?;
604                 let mut buf = Vec::new();
605                 let msg = VhostUserU64::new(regions.len() as u64);
606                 for r in regions {
607                     buf.extend_from_slice(r.as_bytes())
608                 }
609                 self.send_reply_with_payload(&hdr, &msg, buf.as_slice())?;
610             }
611             Ok(FrontendReq::SLEEP) => {
612                 let res = self.backend.sleep();
613                 let msg = VhostUserSuccess::new(res.is_ok());
614                 self.send_reply_message(&hdr, &msg)?;
615             }
616             Ok(FrontendReq::WAKE) => {
617                 let res = self.backend.wake();
618                 let msg = VhostUserSuccess::new(res.is_ok());
619                 self.send_reply_message(&hdr, &msg)?;
620             }
621             Ok(FrontendReq::SNAPSHOT) => {
622                 let (success_msg, payload) = match self.backend.snapshot() {
623                     Ok(snapshot_payload) => (VhostUserSuccess::new(true), snapshot_payload),
624                     Err(e) => {
625                         error!("Failed to snapshot: {}", e);
626                         (VhostUserSuccess::new(false), Vec::new())
627                     }
628                 };
629                 self.send_reply_with_payload(&hdr, &success_msg, payload.as_slice())?;
630             }
631             Ok(FrontendReq::RESTORE) => {
632                 let res = self.backend.restore(buf.as_slice(), files);
633                 let msg = VhostUserSuccess::new(res.is_ok());
634                 self.send_reply_message(&hdr, &msg)?;
635             }
636             _ => {
637                 return Err(Error::InvalidMessage);
638             }
639         }
640         Ok(())
641     }
642 
new_reply_header<T: Sized>( &self, req: &VhostUserMsgHeader<FrontendReq>, payload_size: usize, ) -> Result<VhostUserMsgHeader<FrontendReq>>643     fn new_reply_header<T: Sized>(
644         &self,
645         req: &VhostUserMsgHeader<FrontendReq>,
646         payload_size: usize,
647     ) -> Result<VhostUserMsgHeader<FrontendReq>> {
648         Ok(VhostUserMsgHeader::new(
649             req.get_code().map_err(|_| Error::InvalidMessage)?,
650             VhostUserHeaderFlag::REPLY.bits(),
651             (mem::size_of::<T>()
652                 .checked_add(payload_size)
653                 .ok_or(Error::OversizedMsg)?)
654             .try_into()
655             .map_err(Error::InvalidCastToInt)?,
656         ))
657     }
658 
659     /// Sends reply back to Vhost frontend in response to a message.
send_ack_message( &mut self, req: &VhostUserMsgHeader<FrontendReq>, success: bool, ) -> Result<()>660     fn send_ack_message(
661         &mut self,
662         req: &VhostUserMsgHeader<FrontendReq>,
663         success: bool,
664     ) -> Result<()> {
665         if self.reply_ack_enabled && req.is_need_reply() {
666             let hdr: VhostUserMsgHeader<FrontendReq> =
667                 self.new_reply_header::<VhostUserU64>(req, 0)?;
668             let val = if success { 0 } else { 1 };
669             let msg = VhostUserU64::new(val);
670             self.connection.send_message(&hdr, &msg, None)?;
671         }
672         Ok(())
673     }
674 
send_reply_message<T: Sized + AsBytes>( &mut self, req: &VhostUserMsgHeader<FrontendReq>, msg: &T, ) -> Result<()>675     fn send_reply_message<T: Sized + AsBytes>(
676         &mut self,
677         req: &VhostUserMsgHeader<FrontendReq>,
678         msg: &T,
679     ) -> Result<()> {
680         let hdr = self.new_reply_header::<T>(req, 0)?;
681         self.connection.send_message(&hdr, msg, None)?;
682         Ok(())
683     }
684 
send_reply_with_payload<T: Sized + AsBytes>( &mut self, req: &VhostUserMsgHeader<FrontendReq>, msg: &T, payload: &[u8], ) -> Result<()>685     fn send_reply_with_payload<T: Sized + AsBytes>(
686         &mut self,
687         req: &VhostUserMsgHeader<FrontendReq>,
688         msg: &T,
689         payload: &[u8],
690     ) -> Result<()> {
691         let hdr = self.new_reply_header::<T>(req, payload.len())?;
692         self.connection
693             .send_message_with_payload(&hdr, msg, payload, None)?;
694         Ok(())
695     }
696 
set_mem_table( &mut self, hdr: &VhostUserMsgHeader<FrontendReq>, size: usize, buf: &[u8], files: Vec<File>, ) -> Result<()>697     fn set_mem_table(
698         &mut self,
699         hdr: &VhostUserMsgHeader<FrontendReq>,
700         size: usize,
701         buf: &[u8],
702         files: Vec<File>,
703     ) -> Result<()> {
704         self.check_request_size(hdr, size, hdr.get_size() as usize)?;
705 
706         let (msg, regions) =
707             Ref::<_, VhostUserMemory>::new_from_prefix(buf).ok_or(Error::InvalidMessage)?;
708         if !msg.is_valid() {
709             return Err(Error::InvalidMessage);
710         }
711 
712         // validate number of fds matching number of memory regions
713         if files.len() != msg.num_regions as usize {
714             return Err(Error::InvalidMessage);
715         }
716 
717         let (regions, excess) = Ref::<_, [VhostUserMemoryRegion]>::new_slice_from_prefix(
718             regions,
719             msg.num_regions as usize,
720         )
721         .ok_or(Error::InvalidMessage)?;
722         if !excess.is_empty() {
723             return Err(Error::InvalidMessage);
724         }
725 
726         // Validate memory regions
727         for region in regions.iter() {
728             if !region.is_valid() {
729                 return Err(Error::InvalidMessage);
730             }
731         }
732 
733         self.backend.set_mem_table(&regions, files)
734     }
735 
get_config(&mut self, hdr: &VhostUserMsgHeader<FrontendReq>, buf: &[u8]) -> Result<()>736     fn get_config(&mut self, hdr: &VhostUserMsgHeader<FrontendReq>, buf: &[u8]) -> Result<()> {
737         let (msg, payload) =
738             Ref::<_, VhostUserConfig>::new_from_prefix(buf).ok_or(Error::InvalidMessage)?;
739         if !msg.is_valid() {
740             return Err(Error::InvalidMessage);
741         }
742         if payload.len() != msg.size as usize {
743             return Err(Error::InvalidMessage);
744         }
745         let flags = match VhostUserConfigFlags::from_bits(msg.flags) {
746             Some(val) => val,
747             None => return Err(Error::InvalidMessage),
748         };
749         let res = self.backend.get_config(msg.offset, msg.size, flags);
750 
751         // The response payload size MUST match the request payload size on success. A zero length
752         // response is used to indicate an error.
753         match res {
754             Ok(ref buf) if buf.len() == msg.size as usize => {
755                 let reply = VhostUserConfig::new(msg.offset, buf.len() as u32, flags);
756                 self.send_reply_with_payload(hdr, &reply, buf.as_slice())?;
757             }
758             Ok(_) => {
759                 let reply = VhostUserConfig::new(msg.offset, 0, flags);
760                 self.send_reply_message(hdr, &reply)?;
761             }
762             Err(_) => {
763                 let reply = VhostUserConfig::new(msg.offset, 0, flags);
764                 self.send_reply_message(hdr, &reply)?;
765             }
766         }
767         Ok(())
768     }
769 
set_config(&mut self, buf: &[u8]) -> Result<()>770     fn set_config(&mut self, buf: &[u8]) -> Result<()> {
771         let (msg, payload) =
772             Ref::<_, VhostUserConfig>::new_from_prefix(buf).ok_or(Error::InvalidMessage)?;
773         if !msg.is_valid() {
774             return Err(Error::InvalidMessage);
775         }
776         if payload.len() != msg.size as usize {
777             return Err(Error::InvalidMessage);
778         }
779         let flags: VhostUserConfigFlags = match VhostUserConfigFlags::from_bits(msg.flags) {
780             Some(val) => val,
781             None => return Err(Error::InvalidMessage),
782         };
783 
784         self.backend.set_config(msg.offset, payload, flags)
785     }
786 
set_backend_req_fd(&mut self, files: Vec<File>) -> Result<()>787     fn set_backend_req_fd(&mut self, files: Vec<File>) -> Result<()> {
788         let file = into_single_file(files).ok_or(Error::InvalidMessage)?;
789         let fd = file.into();
790         // SAFETY: Safe because the protocol promises the file represents the appropriate file type
791         // for the platform.
792         let stream = unsafe { to_system_stream(fd) }?;
793         self.backend.set_backend_req_fd(Connection::from(stream));
794         Ok(())
795     }
796 
797     /// Parses an incoming |SET_VRING_KICK| or |SET_VRING_CALL| message into a
798     /// Vring number and an fd.
handle_vring_fd_request( &mut self, buf: &[u8], files: Vec<File>, ) -> Result<(u8, Option<File>)>799     fn handle_vring_fd_request(
800         &mut self,
801         buf: &[u8],
802         files: Vec<File>,
803     ) -> Result<(u8, Option<File>)> {
804         let msg = VhostUserU64::read_from_prefix(buf).ok_or(Error::InvalidMessage)?;
805         if !msg.is_valid() {
806             return Err(Error::InvalidMessage);
807         }
808 
809         // Bits (0-7) of the payload contain the vring index. Bit 8 is the
810         // invalid FD flag (VHOST_USER_VRING_NOFD_MASK).
811         // This bit is set when there is no file descriptor
812         // in the ancillary data. This signals that polling will be used
813         // instead of waiting for the call.
814         // If Bit 8 is unset, the data must contain a file descriptor.
815         let has_fd = (msg.value & 0x100u64) == 0;
816 
817         let file = into_single_file(files);
818 
819         if has_fd && file.is_none() || !has_fd && file.is_some() {
820             return Err(Error::InvalidMessage);
821         }
822 
823         Ok((msg.value as u8, file))
824     }
825 
check_request_size( &self, hdr: &VhostUserMsgHeader<FrontendReq>, size: usize, expected: usize, ) -> Result<()>826     fn check_request_size(
827         &self,
828         hdr: &VhostUserMsgHeader<FrontendReq>,
829         size: usize,
830         expected: usize,
831     ) -> Result<()> {
832         if hdr.get_size() as usize != expected
833             || hdr.is_reply()
834             || hdr.get_version() != 0x1
835             || size != expected
836         {
837             return Err(Error::InvalidMessage);
838         }
839         Ok(())
840     }
841 
check_attached_files( &self, hdr: &VhostUserMsgHeader<FrontendReq>, files: &[File], ) -> Result<()>842     fn check_attached_files(
843         &self,
844         hdr: &VhostUserMsgHeader<FrontendReq>,
845         files: &[File],
846     ) -> Result<()> {
847         match hdr.get_code() {
848             Ok(FrontendReq::SET_MEM_TABLE)
849             | Ok(FrontendReq::SET_VRING_CALL)
850             | Ok(FrontendReq::SET_VRING_KICK)
851             | Ok(FrontendReq::SET_VRING_ERR)
852             | Ok(FrontendReq::SET_LOG_BASE)
853             | Ok(FrontendReq::SET_LOG_FD)
854             | Ok(FrontendReq::SET_BACKEND_REQ_FD)
855             | Ok(FrontendReq::SET_INFLIGHT_FD)
856             | Ok(FrontendReq::RESTORE)
857             | Ok(FrontendReq::ADD_MEM_REG) => Ok(()),
858             Err(_) => Err(Error::InvalidMessage),
859             _ if !files.is_empty() => Err(Error::InvalidMessage),
860             _ => Ok(()),
861         }
862     }
863 
extract_request_body<T: Sized + FromBytes + VhostUserMsgValidator>( &self, hdr: &VhostUserMsgHeader<FrontendReq>, size: usize, buf: &[u8], ) -> Result<T>864     fn extract_request_body<T: Sized + FromBytes + VhostUserMsgValidator>(
865         &self,
866         hdr: &VhostUserMsgHeader<FrontendReq>,
867         size: usize,
868         buf: &[u8],
869     ) -> Result<T> {
870         self.check_request_size(hdr, size, mem::size_of::<T>())?;
871         T::read_from_prefix(buf)
872             .filter(T::is_valid)
873             .map_or(Err(Error::InvalidMessage), Ok)
874     }
875 
update_reply_ack_flag(&mut self)876     fn update_reply_ack_flag(&mut self) {
877         let pflag = VhostUserProtocolFeatures::REPLY_ACK;
878         self.reply_ack_enabled = (self.virtio_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES) != 0
879             && self.protocol_features.contains(pflag)
880             && (self.acked_protocol_features & pflag.bits()) != 0;
881     }
882 }
883 
884 impl<S: Backend> AsRawDescriptor for BackendServer<S> {
as_raw_descriptor(&self) -> RawDescriptor885     fn as_raw_descriptor(&self) -> RawDescriptor {
886         // TODO(b/221882601): figure out if this used for polling.
887         self.connection.as_raw_descriptor()
888     }
889 }
890 
#[cfg(test)]
mod tests {
    use base::INVALID_DESCRIPTOR;

    use super::*;
    use crate::test_backend::TestBackend;
    use crate::Connection;
    use crate::SystemStream;

    /// A freshly constructed server exposes a valid descriptor for its connection.
    #[test]
    fn test_backend_server_new() {
        // Keep the peer end alive for the whole test so the connection stays open.
        let (backend_end, _frontend_end) = SystemStream::pair().unwrap();
        let server = BackendServer::new(Connection::from(backend_end), TestBackend::new());

        assert_ne!(server.as_raw_descriptor(), INVALID_DESCRIPTOR);
    }
}
910