• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2017 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 use std::collections::BTreeMap;
6 
7 use base::error;
8 use base::Error as SysError;
9 use base::Event;
10 use base::EventToken;
11 use base::Tube;
12 use base::WaitContext;
13 use libc::EIO;
14 use serde::Deserialize;
15 use serde::Serialize;
16 use vhost::Vhost;
17 use vm_memory::GuestMemory;
18 
19 use super::control_socket::VhostDevRequest;
20 use super::control_socket::VhostDevResponse;
21 use super::Error;
22 use super::Result;
23 use crate::virtio::Interrupt;
24 use crate::virtio::Queue;
25 use crate::virtio::VIRTIO_F_ACCESS_PLATFORM;
26 
27 #[derive(Clone, Serialize, Deserialize)]
28 pub struct VringBase {
29     pub index: usize,
30     pub base: u16,
31 }
32 
33 /// Worker that takes care of running the vhost device.
34 pub struct Worker<T: Vhost> {
35     interrupt: Interrupt,
36     pub queues: BTreeMap<usize, Queue>,
37     pub vhost_handle: T,
38     pub vhost_interrupt: Vec<Event>,
39     acked_features: u64,
40     pub response_tube: Option<Tube>,
41 }
42 
43 impl<T: Vhost> Worker<T> {
new( queues: BTreeMap<usize, Queue>, vhost_handle: T, vhost_interrupt: Vec<Event>, interrupt: Interrupt, acked_features: u64, response_tube: Option<Tube>, ) -> Worker<T>44     pub fn new(
45         queues: BTreeMap<usize, Queue>,
46         vhost_handle: T,
47         vhost_interrupt: Vec<Event>,
48         interrupt: Interrupt,
49         acked_features: u64,
50         response_tube: Option<Tube>,
51     ) -> Worker<T> {
52         Worker {
53             interrupt,
54             queues,
55             vhost_handle,
56             vhost_interrupt,
57             acked_features,
58             response_tube,
59         }
60     }
61 
init<F1>( &mut self, mem: GuestMemory, queue_sizes: &[u16], activate_vqs: F1, queue_vrings_base: Option<Vec<VringBase>>, ) -> Result<()> where F1: FnOnce(&T) -> Result<()>,62     pub fn init<F1>(
63         &mut self,
64         mem: GuestMemory,
65         queue_sizes: &[u16],
66         activate_vqs: F1,
67         queue_vrings_base: Option<Vec<VringBase>>,
68     ) -> Result<()>
69     where
70         F1: FnOnce(&T) -> Result<()>,
71     {
72         let avail_features = self
73             .vhost_handle
74             .get_features()
75             .map_err(Error::VhostGetFeatures)?;
76 
77         let mut features = self.acked_features & avail_features;
78         if self.acked_features & (1u64 << VIRTIO_F_ACCESS_PLATFORM) != 0 {
79             // The vhost API is a bit poorly named, this flag in the context of vhost
80             // means that it will do address translation via its IOTLB APIs. If the
81             // underlying virtio device doesn't use viommu, it doesn't need vhost
82             // translation.
83             features &= !(1u64 << VIRTIO_F_ACCESS_PLATFORM);
84         }
85 
86         self.vhost_handle
87             .set_features(features)
88             .map_err(Error::VhostSetFeatures)?;
89 
90         self.vhost_handle
91             .set_mem_table(&mem)
92             .map_err(Error::VhostSetMemTable)?;
93 
94         for (&queue_index, queue) in self.queues.iter() {
95             self.vhost_handle
96                 .set_vring_num(queue_index, queue.size())
97                 .map_err(Error::VhostSetVringNum)?;
98 
99             self.vhost_handle
100                 .set_vring_addr(
101                     &mem,
102                     queue_sizes[queue_index],
103                     queue.size(),
104                     queue_index,
105                     0,
106                     queue.desc_table(),
107                     queue.used_ring(),
108                     queue.avail_ring(),
109                     None,
110                 )
111                 .map_err(Error::VhostSetVringAddr)?;
112             if let Some(vrings_base) = &queue_vrings_base {
113                 let base = if let Some(vring_base) = vrings_base
114                     .iter()
115                     .find(|vring_base| vring_base.index == queue_index)
116                 {
117                     vring_base.base
118                 } else {
119                     return Err(Error::VringBaseMissing);
120                 };
121                 self.vhost_handle
122                     .set_vring_base(queue_index, base)
123                     .map_err(Error::VhostSetVringBase)?;
124             } else {
125                 self.vhost_handle
126                     .set_vring_base(queue_index, 0)
127                     .map_err(Error::VhostSetVringBase)?;
128             }
129             self.set_vring_call_for_entry(queue_index, queue.vector() as usize)?;
130             self.vhost_handle
131                 .set_vring_kick(queue_index, queue.event())
132                 .map_err(Error::VhostSetVringKick)?;
133         }
134 
135         activate_vqs(&self.vhost_handle)?;
136         Ok(())
137     }
138 
run<F1>(&mut self, cleanup_vqs: F1, kill_evt: Event) -> Result<()> where F1: FnOnce(&T) -> Result<()>,139     pub fn run<F1>(&mut self, cleanup_vqs: F1, kill_evt: Event) -> Result<()>
140     where
141         F1: FnOnce(&T) -> Result<()>,
142     {
143         #[derive(EventToken)]
144         enum Token {
145             VhostIrqi { index: usize },
146             Kill,
147             ControlNotify,
148         }
149 
150         let wait_ctx: WaitContext<Token> = WaitContext::build_with(&[(&kill_evt, Token::Kill)])
151             .map_err(Error::CreateWaitContext)?;
152 
153         for (index, vhost_int) in self.vhost_interrupt.iter().enumerate() {
154             wait_ctx
155                 .add(vhost_int, Token::VhostIrqi { index })
156                 .map_err(Error::CreateWaitContext)?;
157         }
158         if let Some(socket) = &self.response_tube {
159             wait_ctx
160                 .add(socket, Token::ControlNotify)
161                 .map_err(Error::CreateWaitContext)?;
162         }
163 
164         'wait: loop {
165             let events = wait_ctx.wait().map_err(Error::WaitError)?;
166 
167             for event in events.iter().filter(|e| e.is_readable) {
168                 match event.token {
169                     Token::VhostIrqi { index } => {
170                         self.vhost_interrupt[index]
171                             .wait()
172                             .map_err(Error::VhostIrqRead)?;
173                         self.interrupt
174                             .signal_used_queue(self.queues[&index].vector());
175                     }
176                     Token::Kill => {
177                         let _ = kill_evt.wait();
178                         break 'wait;
179                     }
180                     Token::ControlNotify => {
181                         if let Some(socket) = &self.response_tube {
182                             match socket.recv() {
183                                 Ok(VhostDevRequest::MsixEntryChanged(index)) => {
184                                     let mut qindex = 0;
185                                     for (&queue_index, queue) in self.queues.iter() {
186                                         if queue.vector() == index as u16 {
187                                             qindex = queue_index;
188                                             break;
189                                         }
190                                     }
191                                     let response =
192                                         match self.set_vring_call_for_entry(qindex, index) {
193                                             Ok(()) => VhostDevResponse::Ok,
194                                             Err(e) => {
195                                                 error!(
196                                                 "Set vring call failed for masked entry {}: {:?}",
197                                                 index, e
198                                             );
199                                                 VhostDevResponse::Err(SysError::new(EIO))
200                                             }
201                                         };
202                                     if let Err(e) = socket.send(&response) {
203                                         error!("Vhost failed to send VhostMsixEntryMasked Response for entry {}: {:?}", index, e);
204                                     }
205                                 }
206                                 Ok(VhostDevRequest::MsixChanged) => {
207                                     let response = match self.set_vring_calls() {
208                                         Ok(()) => VhostDevResponse::Ok,
209                                         Err(e) => {
210                                             error!("Set vring calls failed: {:?}", e);
211                                             VhostDevResponse::Err(SysError::new(EIO))
212                                         }
213                                     };
214                                     if let Err(e) = socket.send(&response) {
215                                         error!(
216                                             "Vhost failed to send VhostMsixMasked Response: {:?}",
217                                             e
218                                         );
219                                     }
220                                 }
221                                 Err(e) => {
222                                     error!("Vhost failed to receive Control request: {:?}", e);
223                                 }
224                             }
225                         }
226                     }
227                 }
228             }
229         }
230         cleanup_vqs(&self.vhost_handle)?;
231         Ok(())
232     }
233 
set_vring_call_for_entry(&self, queue_index: usize, vector: usize) -> Result<()>234     fn set_vring_call_for_entry(&self, queue_index: usize, vector: usize) -> Result<()> {
235         // No response_socket means it doesn't have any control related
236         // with the msix. Due to this, cannot use the direct irq fd but
237         // should fall back to indirect irq fd.
238         if self.response_tube.is_some() {
239             if let Some(msix_config) = self.interrupt.get_msix_config() {
240                 let msix_config = msix_config.lock();
241                 let msix_masked = msix_config.masked();
242                 if msix_masked {
243                     return Ok(());
244                 }
245                 if !msix_config.table_masked(vector) {
246                     if let Some(irqfd) = msix_config.get_irqfd(vector) {
247                         self.vhost_handle
248                             .set_vring_call(queue_index, irqfd)
249                             .map_err(Error::VhostSetVringCall)?;
250                     } else {
251                         self.vhost_handle
252                             .set_vring_call(queue_index, &self.vhost_interrupt[queue_index])
253                             .map_err(Error::VhostSetVringCall)?;
254                     }
255                     return Ok(());
256                 }
257             }
258         }
259 
260         self.vhost_handle
261             .set_vring_call(queue_index, &self.vhost_interrupt[queue_index])
262             .map_err(Error::VhostSetVringCall)?;
263         Ok(())
264     }
265 
set_vring_calls(&self) -> Result<()>266     fn set_vring_calls(&self) -> Result<()> {
267         if let Some(msix_config) = self.interrupt.get_msix_config() {
268             let msix_config = msix_config.lock();
269             if msix_config.masked() {
270                 for (&queue_index, _) in self.queues.iter() {
271                     self.vhost_handle
272                         .set_vring_call(queue_index, &self.vhost_interrupt[queue_index])
273                         .map_err(Error::VhostSetVringCall)?;
274                 }
275             } else {
276                 for (&queue_index, queue) in self.queues.iter() {
277                     let vector = queue.vector() as usize;
278                     if !msix_config.table_masked(vector) {
279                         if let Some(irqfd) = msix_config.get_irqfd(vector) {
280                             self.vhost_handle
281                                 .set_vring_call(queue_index, irqfd)
282                                 .map_err(Error::VhostSetVringCall)?;
283                         } else {
284                             self.vhost_handle
285                                 .set_vring_call(queue_index, &self.vhost_interrupt[queue_index])
286                                 .map_err(Error::VhostSetVringCall)?;
287                         }
288                     } else {
289                         self.vhost_handle
290                             .set_vring_call(queue_index, &self.vhost_interrupt[queue_index])
291                             .map_err(Error::VhostSetVringCall)?;
292                     }
293                 }
294             }
295         }
296         Ok(())
297     }
298 }
299