• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2017 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 mod sys;
6 
7 use std::collections::BTreeMap;
8 use std::collections::VecDeque;
9 use std::io::Write;
10 use std::sync::Arc;
11 
12 use anyhow::anyhow;
13 use anyhow::Context;
14 use balloon_control::BalloonStats;
15 use balloon_control::BalloonTubeCommand;
16 use balloon_control::BalloonTubeResult;
17 use balloon_control::BalloonWS;
18 use balloon_control::WSBucket;
19 use balloon_control::VIRTIO_BALLOON_WS_MAX_NUM_BINS;
20 use balloon_control::VIRTIO_BALLOON_WS_MIN_NUM_BINS;
21 use base::debug;
22 use base::error;
23 use base::warn;
24 use base::AsRawDescriptor;
25 use base::Event;
26 use base::RawDescriptor;
27 #[cfg(feature = "registered_events")]
28 use base::SendTube;
29 use base::Tube;
30 use base::WorkerThread;
31 use cros_async::block_on;
32 use cros_async::sync::RwLock as AsyncRwLock;
33 use cros_async::AsyncTube;
34 use cros_async::EventAsync;
35 use cros_async::Executor;
36 #[cfg(feature = "registered_events")]
37 use cros_async::SendTubeAsync;
38 use data_model::Le16;
39 use data_model::Le32;
40 use data_model::Le64;
41 use futures::channel::mpsc;
42 use futures::channel::oneshot;
43 use futures::pin_mut;
44 use futures::select;
45 use futures::select_biased;
46 use futures::FutureExt;
47 use futures::StreamExt;
48 use remain::sorted;
49 use serde::Deserialize;
50 use serde::Serialize;
51 use thiserror::Error as ThisError;
52 use vm_control::api::VmMemoryClient;
53 #[cfg(feature = "registered_events")]
54 use vm_control::RegisteredEventWithData;
55 use vm_memory::GuestAddress;
56 use vm_memory::GuestMemory;
57 use zerocopy::AsBytes;
58 use zerocopy::FromBytes;
59 use zerocopy::FromZeroes;
60 
61 use super::async_utils;
62 use super::copy_config;
63 use super::create_stop_oneshot;
64 use super::DescriptorChain;
65 use super::DeviceType;
66 use super::Interrupt;
67 use super::Queue;
68 use super::Reader;
69 use super::StoppedWorker;
70 use super::VirtioDevice;
71 use crate::UnpinRequest;
72 use crate::UnpinResponse;
73 
74 #[sorted]
75 #[derive(ThisError, Debug)]
76 pub enum BalloonError {
77     /// Failed an async await
78     #[error("failed async await: {0}")]
79     AsyncAwait(cros_async::AsyncError),
80     /// Failed an async await
81     #[error("failed async await: {0}")]
82     AsyncAwaitAnyhow(anyhow::Error),
83     /// Failed to create event.
84     #[error("failed to create event: {0}")]
85     CreatingEvent(base::Error),
86     /// Failed to create async message receiver.
87     #[error("failed to create async message receiver: {0}")]
88     CreatingMessageReceiver(base::TubeError),
89     /// Failed to receive command message.
90     #[error("failed to receive command message: {0}")]
91     ReceivingCommand(base::TubeError),
92     /// Failed to send command response.
93     #[error("failed to send command response: {0}")]
94     SendResponse(base::TubeError),
95     /// Error while writing to virtqueue
96     #[error("failed to write to virtqueue: {0}")]
97     WriteQueue(std::io::Error),
98     /// Failed to write config event.
99     #[error("failed to write config event: {0}")]
100     WritingConfigEvent(base::Error),
101 }
102 pub type Result<T> = std::result::Result<T, BalloonError>;
103 
// The balloon exposes six virtqueues (Inflate, Deflate, Stats, Event, WsData, WsCmd), all of
// which share the same size.
const QUEUE_SIZE: u16 = 128;
const QUEUE_SIZES: &[u16] = &[QUEUE_SIZE; 6];

// A balloon page frame number is a guest address shifted right by this amount, so one balloon
// page is `1 << 12` bytes.
const VIRTIO_BALLOON_PFN_SHIFT: u32 = 12;
const VIRTIO_BALLOON_PF_SIZE: u64 = 1 << VIRTIO_BALLOON_PFN_SHIFT;

// Bit positions in the virtio balloon feature bitmap.
const VIRTIO_BALLOON_F_MUST_TELL_HOST: u32 = 0; // Tell before reclaiming pages
const VIRTIO_BALLOON_F_STATS_VQ: u32 = 1; // Stats reporting enabled
const VIRTIO_BALLOON_F_DEFLATE_ON_OOM: u32 = 2; // Deflate balloon on OOM
const VIRTIO_BALLOON_F_PAGE_REPORTING: u32 = 5; // Page reporting virtqueue
// TODO(b/273973298): this should maybe be bit 6? to be changed later
const VIRTIO_BALLOON_F_WS_REPORTING: u32 = 8; // Working Set Reporting virtqueues
120 
121 #[derive(Copy, Clone)]
122 #[repr(u32)]
123 // Balloon virtqueues
124 pub enum BalloonFeatures {
125     // Page Reporting enabled
126     PageReporting = VIRTIO_BALLOON_F_PAGE_REPORTING,
127     // WS Reporting enabled
128     WSReporting = VIRTIO_BALLOON_F_WS_REPORTING,
129 }
130 
// The following feature bits are not in the ratified spec; they come from this proposal:
//  https://lists.oasis-open.org/archives/virtio-comment/202201/msg00139.html
const VIRTIO_BALLOON_F_RESPONSIVE_DEVICE: u32 = 6; // Device actively watching guest memory
const VIRTIO_BALLOON_F_EVENTS_VQ: u32 = 7; // Event vq is enabled
135 
136 // virtio_balloon_config is the balloon device configuration space defined by the virtio spec.
137 #[derive(Copy, Clone, Debug, Default, AsBytes, FromZeroes, FromBytes)]
138 #[repr(C)]
139 struct virtio_balloon_config {
140     num_pages: Le32,
141     actual: Le32,
142     free_page_hint_cmd_id: Le32,
143     poison_val: Le32,
144     // WS field is part of proposed spec extension (b/273973298).
145     ws_num_bins: u8,
146     _reserved: [u8; 3],
147 }
148 
149 // BalloonState is shared by the worker and device thread.
150 #[derive(Clone, Default, Serialize, Deserialize)]
151 struct BalloonState {
152     num_pages: u32,
153     actual_pages: u32,
154     expecting_ws: bool,
155     // Flag indicating that the balloon is in the process of a failable update. This
156     // is set by an Adjust command that has allow_failure set, and is cleared when the
157     // Adjusted success/failure response is sent.
158     failable_update: bool,
159     pending_adjusted_responses: VecDeque<u32>,
160 }
161 
// Tag values identifying each statistic in a virtio_balloon_stat record. The two
// NONSTANDARD tags sit at the top of the u16 range.
const VIRTIO_BALLOON_S_SWAP_IN: u16 = 0;
const VIRTIO_BALLOON_S_SWAP_OUT: u16 = 1;
const VIRTIO_BALLOON_S_MAJFLT: u16 = 2;
const VIRTIO_BALLOON_S_MINFLT: u16 = 3;
const VIRTIO_BALLOON_S_MEMFREE: u16 = 4;
const VIRTIO_BALLOON_S_MEMTOT: u16 = 5;
const VIRTIO_BALLOON_S_AVAIL: u16 = 6;
const VIRTIO_BALLOON_S_CACHES: u16 = 7;
const VIRTIO_BALLOON_S_HTLB_PGALLOC: u16 = 8;
const VIRTIO_BALLOON_S_HTLB_PGFAIL: u16 = 9;
const VIRTIO_BALLOON_S_NONSTANDARD_SHMEM: u16 = 65534;
const VIRTIO_BALLOON_S_NONSTANDARD_UNEVICTABLE: u16 = 65535;
175 
176 // BalloonStat is used to deserialize stats from the stats_queue.
177 #[derive(Copy, Clone, FromZeroes, FromBytes, AsBytes)]
178 #[repr(C, packed)]
179 struct BalloonStat {
180     tag: Le16,
181     val: Le64,
182 }
183 
184 impl BalloonStat {
update_stats(&self, stats: &mut BalloonStats)185     fn update_stats(&self, stats: &mut BalloonStats) {
186         let val = Some(self.val.to_native());
187         match self.tag.to_native() {
188             VIRTIO_BALLOON_S_SWAP_IN => stats.swap_in = val,
189             VIRTIO_BALLOON_S_SWAP_OUT => stats.swap_out = val,
190             VIRTIO_BALLOON_S_MAJFLT => stats.major_faults = val,
191             VIRTIO_BALLOON_S_MINFLT => stats.minor_faults = val,
192             VIRTIO_BALLOON_S_MEMFREE => stats.free_memory = val,
193             VIRTIO_BALLOON_S_MEMTOT => stats.total_memory = val,
194             VIRTIO_BALLOON_S_AVAIL => stats.available_memory = val,
195             VIRTIO_BALLOON_S_CACHES => stats.disk_caches = val,
196             VIRTIO_BALLOON_S_HTLB_PGALLOC => stats.hugetlb_allocations = val,
197             VIRTIO_BALLOON_S_HTLB_PGFAIL => stats.hugetlb_failures = val,
198             VIRTIO_BALLOON_S_NONSTANDARD_SHMEM => stats.shared_memory = val,
199             VIRTIO_BALLOON_S_NONSTANDARD_UNEVICTABLE => stats.unevictable_memory = val,
200             _ => (),
201         }
202     }
203 }
204 
205 const VIRTIO_BALLOON_EVENT_PRESSURE: u32 = 1;
206 const VIRTIO_BALLOON_EVENT_PUFF_FAILURE: u32 = 2;
207 
208 #[repr(C)]
209 #[derive(Copy, Clone, Default, AsBytes, FromZeroes, FromBytes)]
210 struct virtio_balloon_event_header {
211     evt_type: Le32,
212 }
213 
214 // virtio_balloon_ws is used to deserialize from the ws data vq.
215 #[repr(C)]
216 #[derive(Copy, Clone, Debug, Default, AsBytes, FromZeroes, FromBytes)]
217 struct virtio_balloon_ws {
218     tag: Le16,
219     node_id: Le16,
220     // virtio prefers field members to align on a word boundary so we must pad. see:
221     // https://crsrc.org/o/src/third_party/kernel/v5.15/include/uapi/linux/virtio_balloon.h;l=105
222     _reserved: [u8; 4],
223     idle_age_ms: Le64,
224     // TODO(b/273973298): these should become separate fields - bytes for ANON and FILE
225     memory_size_bytes: [Le64; 2],
226 }
227 
228 impl virtio_balloon_ws {
update_ws(&self, ws: &mut BalloonWS)229     fn update_ws(&self, ws: &mut BalloonWS) {
230         let bucket = WSBucket {
231             age: self.idle_age_ms.to_native(),
232             bytes: [
233                 self.memory_size_bytes[0].to_native(),
234                 self.memory_size_bytes[1].to_native(),
235             ],
236         };
237         ws.ws.push(bucket);
238     }
239 }
240 
241 const _VIRTIO_BALLOON_WS_OP_INVALID: u16 = 0;
242 const VIRTIO_BALLOON_WS_OP_REQUEST: u16 = 1;
243 const VIRTIO_BALLOON_WS_OP_CONFIG: u16 = 2;
244 const _VIRTIO_BALLOON_WS_OP_DISCARD: u16 = 3;
245 
246 // virtio_balloon_op is used to serialize to the ws cmd vq.
247 #[repr(C, packed)]
248 #[derive(Copy, Clone, Debug, Default, AsBytes, FromZeroes, FromBytes)]
249 struct virtio_balloon_op {
250     type_: Le16,
251 }
252 
invoke_desc_handler<F>(ranges: Vec<(u64, u64)>, desc_handler: &mut F) where F: FnMut(GuestAddress, u64),253 fn invoke_desc_handler<F>(ranges: Vec<(u64, u64)>, desc_handler: &mut F)
254 where
255     F: FnMut(GuestAddress, u64),
256 {
257     for range in ranges {
258         desc_handler(GuestAddress(range.0), range.1);
259     }
260 }
261 
262 // Release a list of guest memory ranges back to the host system.
263 // Unpin requests for each inflate range will be sent via `release_memory_tube`
264 // if provided, and then `desc_handler` will be called for each inflate range.
release_ranges<F>( release_memory_tube: Option<&Tube>, inflate_ranges: Vec<(u64, u64)>, desc_handler: &mut F, ) -> anyhow::Result<()> where F: FnMut(GuestAddress, u64),265 fn release_ranges<F>(
266     release_memory_tube: Option<&Tube>,
267     inflate_ranges: Vec<(u64, u64)>,
268     desc_handler: &mut F,
269 ) -> anyhow::Result<()>
270 where
271     F: FnMut(GuestAddress, u64),
272 {
273     if let Some(tube) = release_memory_tube {
274         let unpin_ranges = inflate_ranges
275             .iter()
276             .map(|v| {
277                 (
278                     v.0 >> VIRTIO_BALLOON_PFN_SHIFT,
279                     v.1 / VIRTIO_BALLOON_PF_SIZE,
280                 )
281             })
282             .collect();
283         let req = UnpinRequest {
284             ranges: unpin_ranges,
285         };
286         if let Err(e) = tube.send(&req) {
287             error!("failed to send unpin request: {}", e);
288         } else {
289             match tube.recv() {
290                 Ok(resp) => match resp {
291                     UnpinResponse::Success => invoke_desc_handler(inflate_ranges, desc_handler),
292                     UnpinResponse::Failed => error!("failed to handle unpin request"),
293                 },
294                 Err(e) => error!("failed to handle get unpin response: {}", e),
295             }
296         }
297     } else {
298         invoke_desc_handler(inflate_ranges, desc_handler);
299     }
300 
301     Ok(())
302 }
303 
304 // Processes one message's list of addresses.
handle_address_chain<F>( release_memory_tube: Option<&Tube>, avail_desc: &mut DescriptorChain, desc_handler: &mut F, ) -> anyhow::Result<()> where F: FnMut(GuestAddress, u64),305 fn handle_address_chain<F>(
306     release_memory_tube: Option<&Tube>,
307     avail_desc: &mut DescriptorChain,
308     desc_handler: &mut F,
309 ) -> anyhow::Result<()>
310 where
311     F: FnMut(GuestAddress, u64),
312 {
313     // In a long-running system, there is no reason to expect that
314     // a significant number of freed pages are consecutive. However,
315     // batching is relatively simple and can result in significant
316     // gains in a newly booted system, so it's worth attempting.
317     let mut range_start = 0;
318     let mut range_size = 0;
319     let mut inflate_ranges: Vec<(u64, u64)> = Vec::new();
320     for res in avail_desc.reader.iter::<Le32>() {
321         let pfn = match res {
322             Ok(pfn) => pfn,
323             Err(e) => {
324                 error!("error while reading unused pages: {}", e);
325                 break;
326             }
327         };
328         let guest_address = (u64::from(pfn.to_native())) << VIRTIO_BALLOON_PFN_SHIFT;
329         if range_start + range_size == guest_address {
330             range_size += VIRTIO_BALLOON_PF_SIZE;
331         } else if range_start == guest_address + VIRTIO_BALLOON_PF_SIZE {
332             range_start = guest_address;
333             range_size += VIRTIO_BALLOON_PF_SIZE;
334         } else {
335             // Discontinuity, so flush the previous range. Note range_size
336             // will be 0 on the first iteration, so skip that.
337             if range_size != 0 {
338                 inflate_ranges.push((range_start, range_size));
339             }
340             range_start = guest_address;
341             range_size = VIRTIO_BALLOON_PF_SIZE;
342         }
343     }
344     if range_size != 0 {
345         inflate_ranges.push((range_start, range_size));
346     }
347 
348     release_ranges(release_memory_tube, inflate_ranges, desc_handler)
349 }
350 
351 // Async task that handles the main balloon inflate and deflate queues.
handle_queue<F>( mut queue: Queue, mut queue_event: EventAsync, release_memory_tube: Option<&Tube>, interrupt: Interrupt, mut desc_handler: F, mut stop_rx: oneshot::Receiver<()>, ) -> Queue where F: FnMut(GuestAddress, u64),352 async fn handle_queue<F>(
353     mut queue: Queue,
354     mut queue_event: EventAsync,
355     release_memory_tube: Option<&Tube>,
356     interrupt: Interrupt,
357     mut desc_handler: F,
358     mut stop_rx: oneshot::Receiver<()>,
359 ) -> Queue
360 where
361     F: FnMut(GuestAddress, u64),
362 {
363     loop {
364         let mut avail_desc = match queue
365             .next_async_interruptable(&mut queue_event, &mut stop_rx)
366             .await
367         {
368             Ok(Some(res)) => res,
369             Ok(None) => return queue,
370             Err(e) => {
371                 error!("Failed to read descriptor {}", e);
372                 return queue;
373             }
374         };
375         if let Err(e) =
376             handle_address_chain(release_memory_tube, &mut avail_desc, &mut desc_handler)
377         {
378             error!("balloon: failed to process inflate addresses: {}", e);
379         }
380         queue.add_used(avail_desc, 0);
381         queue.trigger_interrupt(&interrupt);
382     }
383 }
384 
385 // Processes one page-reporting descriptor.
handle_reported_buffer<F>( release_memory_tube: Option<&Tube>, avail_desc: &DescriptorChain, desc_handler: &mut F, ) -> anyhow::Result<()> where F: FnMut(GuestAddress, u64),386 fn handle_reported_buffer<F>(
387     release_memory_tube: Option<&Tube>,
388     avail_desc: &DescriptorChain,
389     desc_handler: &mut F,
390 ) -> anyhow::Result<()>
391 where
392     F: FnMut(GuestAddress, u64),
393 {
394     let reported_ranges: Vec<(u64, u64)> = avail_desc
395         .reader
396         .get_remaining_regions()
397         .chain(avail_desc.writer.get_remaining_regions())
398         .map(|r| (r.offset, r.len as u64))
399         .collect();
400 
401     release_ranges(release_memory_tube, reported_ranges, desc_handler)
402 }
403 
404 // Async task that handles the page reporting queue.
handle_reporting_queue<F>( mut queue: Queue, mut queue_event: EventAsync, release_memory_tube: Option<&Tube>, interrupt: Interrupt, mut desc_handler: F, mut stop_rx: oneshot::Receiver<()>, ) -> Queue where F: FnMut(GuestAddress, u64),405 async fn handle_reporting_queue<F>(
406     mut queue: Queue,
407     mut queue_event: EventAsync,
408     release_memory_tube: Option<&Tube>,
409     interrupt: Interrupt,
410     mut desc_handler: F,
411     mut stop_rx: oneshot::Receiver<()>,
412 ) -> Queue
413 where
414     F: FnMut(GuestAddress, u64),
415 {
416     loop {
417         let avail_desc = match queue
418             .next_async_interruptable(&mut queue_event, &mut stop_rx)
419             .await
420         {
421             Ok(Some(res)) => res,
422             Ok(None) => return queue,
423             Err(e) => {
424                 error!("Failed to read descriptor {}", e);
425                 return queue;
426             }
427         };
428         if let Err(e) = handle_reported_buffer(release_memory_tube, &avail_desc, &mut desc_handler)
429         {
430             error!("balloon: failed to process reported buffer: {}", e);
431         }
432         queue.add_used(avail_desc, 0);
433         queue.trigger_interrupt(&interrupt);
434     }
435 }
436 
parse_balloon_stats(reader: &mut Reader) -> BalloonStats437 fn parse_balloon_stats(reader: &mut Reader) -> BalloonStats {
438     let mut stats: BalloonStats = Default::default();
439     for res in reader.iter::<BalloonStat>() {
440         match res {
441             Ok(stat) => stat.update_stats(&mut stats),
442             Err(e) => {
443                 error!("error while reading stats: {}", e);
444                 break;
445             }
446         };
447     }
448     stats
449 }
450 
451 // Async task that handles the stats queue. Note that the cadence of this is driven by requests for
452 // balloon stats from the control pipe.
453 // The guests queues an initial buffer on boot, which is read and then this future will block until
454 // signaled from the command socket that stats should be collected again.
handle_stats_queue( mut queue: Queue, mut queue_event: EventAsync, mut stats_rx: mpsc::Receiver<()>, command_tube: &AsyncTube, #[cfg(feature = "registered_events")] registered_evt_q: Option<&SendTubeAsync>, state: Arc<AsyncRwLock<BalloonState>>, interrupt: Interrupt, mut stop_rx: oneshot::Receiver<()>, ) -> Queue455 async fn handle_stats_queue(
456     mut queue: Queue,
457     mut queue_event: EventAsync,
458     mut stats_rx: mpsc::Receiver<()>,
459     command_tube: &AsyncTube,
460     #[cfg(feature = "registered_events")] registered_evt_q: Option<&SendTubeAsync>,
461     state: Arc<AsyncRwLock<BalloonState>>,
462     interrupt: Interrupt,
463     mut stop_rx: oneshot::Receiver<()>,
464 ) -> Queue {
465     let mut avail_desc = match queue
466         .next_async_interruptable(&mut queue_event, &mut stop_rx)
467         .await
468     {
469         // Consume the first stats buffer sent from the guest at startup. It was not
470         // requested by anyone, and the stats are stale.
471         Ok(Some(res)) => res,
472         Ok(None) => return queue,
473         Err(e) => {
474             error!("Failed to read descriptor {}", e);
475             return queue;
476         }
477     };
478 
479     loop {
480         select_biased! {
481             msg = stats_rx.next() => {
482                 // Wait for a request to read the stats.
483                 match msg {
484                     Some(()) => (),
485                     None => {
486                         error!("stats signal channel was closed");
487                         return queue;
488                     }
489                 }
490             }
491             _ = stop_rx => return queue,
492         };
493 
494         // Request a new stats_desc to the guest.
495         queue.add_used(avail_desc, 0);
496         queue.trigger_interrupt(&interrupt);
497 
498         avail_desc = match queue.next_async(&mut queue_event).await {
499             Err(e) => {
500                 error!("Failed to read descriptor {}", e);
501                 return queue;
502             }
503             Ok(d) => d,
504         };
505         let stats = parse_balloon_stats(&mut avail_desc.reader);
506 
507         let actual_pages = state.lock().await.actual_pages as u64;
508         let result = BalloonTubeResult::Stats {
509             balloon_actual: actual_pages << VIRTIO_BALLOON_PFN_SHIFT,
510             stats,
511         };
512         let send_result = command_tube.send(result).await;
513         if let Err(e) = send_result {
514             error!("failed to send stats result: {}", e);
515         }
516 
517         #[cfg(feature = "registered_events")]
518         if let Some(registered_evt_q) = registered_evt_q {
519             if let Err(e) = registered_evt_q
520                 .send(&RegisteredEventWithData::VirtioBalloonResize)
521                 .await
522             {
523                 error!("failed to send VirtioBalloonResize event: {}", e);
524             }
525         }
526     }
527 }
528 
send_adjusted_response( tube: &AsyncTube, num_pages: u32, ) -> std::result::Result<(), base::TubeError>529 async fn send_adjusted_response(
530     tube: &AsyncTube,
531     num_pages: u32,
532 ) -> std::result::Result<(), base::TubeError> {
533     let num_bytes = (num_pages as u64) << VIRTIO_BALLOON_PFN_SHIFT;
534     let result = BalloonTubeResult::Adjusted { num_bytes };
535     tube.send(result).await
536 }
537 
handle_event( state: Arc<AsyncRwLock<BalloonState>>, interrupt: Interrupt, r: &mut Reader, command_tube: &AsyncTube, ) -> Result<()>538 async fn handle_event(
539     state: Arc<AsyncRwLock<BalloonState>>,
540     interrupt: Interrupt,
541     r: &mut Reader,
542     command_tube: &AsyncTube,
543 ) -> Result<()> {
544     match r.read_obj::<virtio_balloon_event_header>() {
545         Ok(hdr) => match hdr.evt_type.to_native() {
546             VIRTIO_BALLOON_EVENT_PRESSURE => {
547                 // TODO(b/213962590): See how this can be integrated this into memory rebalancing
548             }
549             VIRTIO_BALLOON_EVENT_PUFF_FAILURE => {
550                 let mut state = state.lock().await;
551                 if state.failable_update {
552                     state.num_pages = state.actual_pages;
553                     interrupt.signal_config_changed();
554 
555                     state.failable_update = false;
556                     send_adjusted_response(command_tube, state.actual_pages)
557                         .await
558                         .map_err(BalloonError::SendResponse)?;
559                 }
560             }
561             _ => {
562                 warn!("Unknown event {}", hdr.evt_type.to_native());
563             }
564         },
565         Err(e) => error!("Failed to parse event header {:?}", e),
566     }
567     Ok(())
568 }
569 
570 // Async task that handles the events queue.
handle_events_queue( mut queue: Queue, mut queue_event: EventAsync, state: Arc<AsyncRwLock<BalloonState>>, interrupt: Interrupt, command_tube: &AsyncTube, mut stop_rx: oneshot::Receiver<()>, ) -> Result<Queue>571 async fn handle_events_queue(
572     mut queue: Queue,
573     mut queue_event: EventAsync,
574     state: Arc<AsyncRwLock<BalloonState>>,
575     interrupt: Interrupt,
576     command_tube: &AsyncTube,
577     mut stop_rx: oneshot::Receiver<()>,
578 ) -> Result<Queue> {
579     while let Some(mut avail_desc) = queue
580         .next_async_interruptable(&mut queue_event, &mut stop_rx)
581         .await
582         .map_err(BalloonError::AsyncAwait)?
583     {
584         handle_event(
585             state.clone(),
586             interrupt.clone(),
587             &mut avail_desc.reader,
588             command_tube,
589         )
590         .await?;
591 
592         queue.add_used(avail_desc, 0);
593         queue.trigger_interrupt(&interrupt);
594     }
595     Ok(queue)
596 }
597 
// Working-set operations passed from the command handler to the ws op queue task.
enum WSOp {
    // Ask the guest for a working-set report.
    WSReport,
    // Send a working-set reporting configuration to the guest.
    WSConfig {
        bins: Vec<u32>,
        refresh_threshold: u32,
        report_threshold: u32,
    },
}
606 
handle_ws_op_queue( mut queue: Queue, mut queue_event: EventAsync, mut ws_op_rx: mpsc::Receiver<WSOp>, state: Arc<AsyncRwLock<BalloonState>>, interrupt: Interrupt, mut stop_rx: oneshot::Receiver<()>, ) -> Result<Queue>607 async fn handle_ws_op_queue(
608     mut queue: Queue,
609     mut queue_event: EventAsync,
610     mut ws_op_rx: mpsc::Receiver<WSOp>,
611     state: Arc<AsyncRwLock<BalloonState>>,
612     interrupt: Interrupt,
613     mut stop_rx: oneshot::Receiver<()>,
614 ) -> Result<Queue> {
615     loop {
616         let op = select_biased! {
617             next_op = ws_op_rx.next().fuse() => {
618                 match next_op {
619                     Some(op) => op,
620                     None => {
621                         error!("ws op tube was closed");
622                         break;
623                     }
624                 }
625             }
626             _ = stop_rx => {
627                 break;
628             }
629         };
630         let mut avail_desc = queue
631             .next_async(&mut queue_event)
632             .await
633             .map_err(BalloonError::AsyncAwait)?;
634         let writer = &mut avail_desc.writer;
635 
636         match op {
637             WSOp::WSReport => {
638                 {
639                     let mut state = state.lock().await;
640                     state.expecting_ws = true;
641                 }
642 
643                 let ws_r = virtio_balloon_op {
644                     type_: VIRTIO_BALLOON_WS_OP_REQUEST.into(),
645                 };
646 
647                 writer.write_obj(ws_r).map_err(BalloonError::WriteQueue)?;
648             }
649             WSOp::WSConfig {
650                 bins,
651                 refresh_threshold,
652                 report_threshold,
653             } => {
654                 let cmd = virtio_balloon_op {
655                     type_: VIRTIO_BALLOON_WS_OP_CONFIG.into(),
656                 };
657 
658                 writer.write_obj(cmd).map_err(BalloonError::WriteQueue)?;
659                 writer
660                     .write_all(bins.as_bytes())
661                     .map_err(BalloonError::WriteQueue)?;
662                 writer
663                     .write_obj(refresh_threshold)
664                     .map_err(BalloonError::WriteQueue)?;
665                 writer
666                     .write_obj(report_threshold)
667                     .map_err(BalloonError::WriteQueue)?;
668             }
669         }
670 
671         let len = writer.bytes_written() as u32;
672         queue.add_used(avail_desc, len);
673         queue.trigger_interrupt(&interrupt);
674     }
675 
676     Ok(queue)
677 }
678 
parse_balloon_ws(reader: &mut Reader) -> BalloonWS679 fn parse_balloon_ws(reader: &mut Reader) -> BalloonWS {
680     let mut ws = BalloonWS::new();
681     for res in reader.iter::<virtio_balloon_ws>() {
682         match res {
683             Ok(ws_msg) => {
684                 ws_msg.update_ws(&mut ws);
685             }
686             Err(e) => {
687                 error!("error while reading ws: {}", e);
688                 break;
689             }
690         }
691     }
692     if ws.ws.len() < VIRTIO_BALLOON_WS_MIN_NUM_BINS || ws.ws.len() > VIRTIO_BALLOON_WS_MAX_NUM_BINS
693     {
694         error!("unexpected number of WS buckets: {}", ws.ws.len());
695     }
696     ws
697 }
698 
699 // Async task that handles the stats queue. Note that the arrival of events on
700 // the WS vq may be the result of either a WS request (WS-R) command having
701 // been sent to the guest, or an unprompted send due to memory pressue in the
702 // guest. If the data was requested, we should also send that back on the
703 // command tube.
handle_ws_data_queue( mut queue: Queue, mut queue_event: EventAsync, command_tube: &AsyncTube, #[cfg(feature = "registered_events")] registered_evt_q: Option<&SendTubeAsync>, state: Arc<AsyncRwLock<BalloonState>>, interrupt: Interrupt, mut stop_rx: oneshot::Receiver<()>, ) -> Result<Queue>704 async fn handle_ws_data_queue(
705     mut queue: Queue,
706     mut queue_event: EventAsync,
707     command_tube: &AsyncTube,
708     #[cfg(feature = "registered_events")] registered_evt_q: Option<&SendTubeAsync>,
709     state: Arc<AsyncRwLock<BalloonState>>,
710     interrupt: Interrupt,
711     mut stop_rx: oneshot::Receiver<()>,
712 ) -> Result<Queue> {
713     loop {
714         let mut avail_desc = match queue
715             .next_async_interruptable(&mut queue_event, &mut stop_rx)
716             .await
717             .map_err(BalloonError::AsyncAwait)?
718         {
719             Some(res) => res,
720             None => return Ok(queue),
721         };
722 
723         let ws = parse_balloon_ws(&mut avail_desc.reader);
724 
725         let mut state = state.lock().await;
726 
727         // update ws report with balloon pages now that we have a lock on state
728         let balloon_actual = (state.actual_pages as u64) << VIRTIO_BALLOON_PFN_SHIFT;
729 
730         if state.expecting_ws {
731             let result = BalloonTubeResult::WorkingSet { ws, balloon_actual };
732             let send_result = command_tube.send(result).await;
733             if let Err(e) = send_result {
734                 error!("failed to send ws result: {}", e);
735             }
736 
737             state.expecting_ws = false;
738         } else {
739             #[cfg(feature = "registered_events")]
740             if let Some(registered_evt_q) = registered_evt_q {
741                 if let Err(e) = registered_evt_q
742                     .send(RegisteredEventWithData::from_ws(&ws, balloon_actual))
743                     .await
744                 {
745                     error!("failed to send VirtioBalloonWSReport event: {}", e);
746                 }
747             }
748         }
749 
750         queue.add_used(avail_desc, 0);
751         queue.trigger_interrupt(&interrupt);
752     }
753 }
754 
755 // Async task that handles the command socket. The command socket handles messages from the host
756 // requesting that the guest balloon be adjusted or to report guest memory statistics.
handle_command_tube( command_tube: &AsyncTube, interrupt: Interrupt, state: Arc<AsyncRwLock<BalloonState>>, mut stats_tx: mpsc::Sender<()>, mut ws_op_tx: mpsc::Sender<WSOp>, mut stop_rx: oneshot::Receiver<()>, ) -> Result<()>757 async fn handle_command_tube(
758     command_tube: &AsyncTube,
759     interrupt: Interrupt,
760     state: Arc<AsyncRwLock<BalloonState>>,
761     mut stats_tx: mpsc::Sender<()>,
762     mut ws_op_tx: mpsc::Sender<WSOp>,
763     mut stop_rx: oneshot::Receiver<()>,
764 ) -> Result<()> {
765     loop {
766         let cmd_res = select_biased! {
767             res = command_tube.next().fuse() => res,
768             _ = stop_rx => return Ok(())
769         };
770         match cmd_res {
771             Ok(command) => match command {
772                 BalloonTubeCommand::Adjust {
773                     num_bytes,
774                     allow_failure,
775                 } => {
776                     let num_pages = (num_bytes >> VIRTIO_BALLOON_PFN_SHIFT) as u32;
777                     let mut state = state.lock().await;
778 
779                     state.num_pages = num_pages;
780                     interrupt.signal_config_changed();
781 
782                     if allow_failure {
783                         if num_pages == state.actual_pages {
784                             send_adjusted_response(command_tube, num_pages)
785                                 .await
786                                 .map_err(BalloonError::SendResponse)?;
787                         } else {
788                             state.failable_update = true;
789                         }
790                     }
791                 }
792                 BalloonTubeCommand::WorkingSetConfig {
793                     bins,
794                     refresh_threshold,
795                     report_threshold,
796                 } => {
797                     if let Err(e) = ws_op_tx.try_send(WSOp::WSConfig {
798                         bins,
799                         refresh_threshold,
800                         report_threshold,
801                     }) {
802                         error!("failed to send config to ws handler: {}", e);
803                     }
804                 }
805                 BalloonTubeCommand::Stats => {
806                     if let Err(e) = stats_tx.try_send(()) {
807                         error!("failed to signal the stat handler: {}", e);
808                     }
809                 }
810                 BalloonTubeCommand::WorkingSet => {
811                     if let Err(e) = ws_op_tx.try_send(WSOp::WSReport) {
812                         error!("failed to send report request to ws handler: {}", e);
813                     }
814                 }
815             },
816             #[cfg(windows)]
817             Err(base::TubeError::Recv(e)) if e.kind() == std::io::ErrorKind::TimedOut => {
818                 // On Windows, async IO tasks like the next/recv above are cancelled as the VM is
819                 // shutting down. For the sake of consistency with unix, we can't *just* return
820                 // here; instead, we wait for the stop request to arrive, *and then* return.
821                 //
822                 // The real fix is to get rid of the global unblock pool, since then we won't
823                 // cancel the tasks early (b/196911556).
824                 let _ = stop_rx.await;
825                 return Ok(());
826             }
827             Err(e) => {
828                 return Err(BalloonError::ReceivingCommand(e));
829             }
830         }
831     }
832 }
833 
handle_pending_adjusted_responses( pending_adjusted_response_event: EventAsync, command_tube: &AsyncTube, state: Arc<AsyncRwLock<BalloonState>>, ) -> Result<()>834 async fn handle_pending_adjusted_responses(
835     pending_adjusted_response_event: EventAsync,
836     command_tube: &AsyncTube,
837     state: Arc<AsyncRwLock<BalloonState>>,
838 ) -> Result<()> {
839     loop {
840         pending_adjusted_response_event
841             .next_val()
842             .await
843             .map_err(BalloonError::AsyncAwait)?;
844         while let Some(num_pages) = state.lock().await.pending_adjusted_responses.pop_front() {
845             send_adjusted_response(command_tube, num_pages)
846                 .await
847                 .map_err(BalloonError::SendResponse)?;
848         }
849     }
850 }
851 
852 /// Represents queues & events for the balloon device.
853 struct BalloonQueues {
854     inflate: Queue,
855     deflate: Queue,
856     stats: Option<Queue>,
857     reporting: Option<Queue>,
858     events: Option<Queue>,
859     ws: (Option<Queue>, Option<Queue>),
860 }
861 
862 impl BalloonQueues {
new(inflate: Queue, deflate: Queue) -> Self863     fn new(inflate: Queue, deflate: Queue) -> Self {
864         BalloonQueues {
865             inflate,
866             deflate,
867             stats: None,
868             reporting: None,
869             events: None,
870             ws: (None, None),
871         }
872     }
873 }
874 
875 /// When the worker is stopped, the queues are preserved here.
876 struct PausedQueues {
877     inflate: Queue,
878     deflate: Queue,
879     stats: Option<Queue>,
880     reporting: Option<Queue>,
881     events: Option<Queue>,
882     ws: (Option<Queue>, Option<Queue>),
883 }
884 
885 impl PausedQueues {
new(inflate: Queue, deflate: Queue) -> Self886     fn new(inflate: Queue, deflate: Queue) -> Self {
887         PausedQueues {
888             inflate,
889             deflate,
890             stats: None,
891             reporting: None,
892             events: None,
893             ws: (None, None),
894         }
895     }
896 }
897 
apply_if_some<F, R>(queue_opt: Option<Queue>, mut func: F) where F: FnMut(Queue) -> R,898 fn apply_if_some<F, R>(queue_opt: Option<Queue>, mut func: F)
899 where
900     F: FnMut(Queue) -> R,
901 {
902     if let Some(queue) = queue_opt {
903         func(queue);
904     }
905 }
906 
907 impl From<Box<PausedQueues>> for BTreeMap<usize, Queue> {
from(queues: Box<PausedQueues>) -> BTreeMap<usize, Queue>908     fn from(queues: Box<PausedQueues>) -> BTreeMap<usize, Queue> {
909         let mut ret = Vec::new();
910         ret.push(queues.inflate);
911         ret.push(queues.deflate);
912         apply_if_some(queues.stats, |stats| ret.push(stats));
913         apply_if_some(queues.reporting, |reporting| ret.push(reporting));
914         apply_if_some(queues.events, |events| ret.push(events));
915         apply_if_some(queues.ws.0, |ws_data| ret.push(ws_data));
916         apply_if_some(queues.ws.1, |ws_op| ret.push(ws_op));
917         // WARNING: We don't use the indices from the virito spec on purpose, see comment in
918         // get_queues_from_map for the rationale.
919         ret.into_iter().enumerate().collect()
920     }
921 }
922 
923 /// Stores data from the worker when it stops so that data can be re-used when
924 /// the worker is restarted.
925 struct WorkerReturn {
926     release_memory_tube: Option<Tube>,
927     command_tube: Tube,
928     #[cfg(feature = "registered_events")]
929     registered_evt_q: Option<SendTube>,
930     paused_queues: Option<PausedQueues>,
931     vm_memory_client: VmMemoryClient,
932 }
933 
934 // The main worker thread. Initialized the asynchronous worker tasks and passes them to the executor
935 // to be processed.
run_worker( inflate_queue: Queue, deflate_queue: Queue, stats_queue: Option<Queue>, reporting_queue: Option<Queue>, events_queue: Option<Queue>, ws_queues: (Option<Queue>, Option<Queue>), command_tube: Tube, vm_memory_client: VmMemoryClient, release_memory_tube: Option<Tube>, interrupt: Interrupt, kill_evt: Event, target_reached_evt: Event, pending_adjusted_response_event: Event, mem: GuestMemory, state: Arc<AsyncRwLock<BalloonState>>, #[cfg(feature = "registered_events")] registered_evt_q: Option<SendTube>, ) -> WorkerReturn936 fn run_worker(
937     inflate_queue: Queue,
938     deflate_queue: Queue,
939     stats_queue: Option<Queue>,
940     reporting_queue: Option<Queue>,
941     events_queue: Option<Queue>,
942     ws_queues: (Option<Queue>, Option<Queue>),
943     command_tube: Tube,
944     vm_memory_client: VmMemoryClient,
945     release_memory_tube: Option<Tube>,
946     interrupt: Interrupt,
947     kill_evt: Event,
948     target_reached_evt: Event,
949     pending_adjusted_response_event: Event,
950     mem: GuestMemory,
951     state: Arc<AsyncRwLock<BalloonState>>,
952     #[cfg(feature = "registered_events")] registered_evt_q: Option<SendTube>,
953 ) -> WorkerReturn {
954     let ex = Executor::new().unwrap();
955     let command_tube = AsyncTube::new(&ex, command_tube).unwrap();
956     #[cfg(feature = "registered_events")]
957     let registered_evt_q_async = registered_evt_q
958         .as_ref()
959         .map(|q| SendTubeAsync::new(q.try_clone().unwrap(), &ex).unwrap());
960 
961     let mut stop_queue_oneshots = Vec::new();
962 
963     // We need a block to release all references to command_tube at the end before returning it.
964     let paused_queues = {
965         // The first queue is used for inflate messages
966         let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
967         let inflate_queue_evt = inflate_queue
968             .event()
969             .try_clone()
970             .expect("failed to clone queue event");
971         let inflate = handle_queue(
972             inflate_queue,
973             EventAsync::new(inflate_queue_evt, &ex).expect("failed to create async event"),
974             release_memory_tube.as_ref(),
975             interrupt.clone(),
976             |guest_address, len| {
977                 sys::free_memory(
978                     &guest_address,
979                     len,
980                     &vm_memory_client,
981                 )
982             },
983             stop_rx,
984         );
985         let inflate = inflate.fuse();
986         pin_mut!(inflate);
987 
988         // The second queue is used for deflate messages
989         let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
990         let deflate_queue_evt = deflate_queue
991             .event()
992             .try_clone()
993             .expect("failed to clone queue event");
994         let deflate = handle_queue(
995             deflate_queue,
996             EventAsync::new(deflate_queue_evt, &ex).expect("failed to create async event"),
997             None,
998             interrupt.clone(),
999             |guest_address, len| {
1000                 sys::reclaim_memory(
1001                     &guest_address,
1002                     len,
1003                     &vm_memory_client,
1004                 )
1005             },
1006             stop_rx,
1007         );
1008         let deflate = deflate.fuse();
1009         pin_mut!(deflate);
1010 
1011         // The next queue is used for stats messages if VIRTIO_BALLOON_F_STATS_VQ is negotiated.
1012         let (stats_tx, stats_rx) = mpsc::channel::<()>(1);
1013         let has_stats_queue = stats_queue.is_some();
1014         let stats = if let Some(stats_queue) = stats_queue {
1015             let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
1016             let stats_queue_evt = stats_queue
1017                 .event()
1018                 .try_clone()
1019                 .expect("failed to clone queue event");
1020             handle_stats_queue(
1021                 stats_queue,
1022                 EventAsync::new(stats_queue_evt, &ex).expect("failed to create async event"),
1023                 stats_rx,
1024                 &command_tube,
1025                 #[cfg(feature = "registered_events")]
1026                 registered_evt_q_async.as_ref(),
1027                 state.clone(),
1028                 interrupt.clone(),
1029                 stop_rx,
1030             )
1031             .left_future()
1032         } else {
1033             std::future::pending().right_future()
1034         };
1035         let stats = stats.fuse();
1036         pin_mut!(stats);
1037 
1038         // The next queue is used for reporting messages
1039         let has_reporting_queue = reporting_queue.is_some();
1040         let reporting = if let Some(reporting_queue) = reporting_queue {
1041             let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
1042             let reporting_queue_evt = reporting_queue
1043                 .event()
1044                 .try_clone()
1045                 .expect("failed to clone queue event");
1046             handle_reporting_queue(
1047                 reporting_queue,
1048                 EventAsync::new(reporting_queue_evt, &ex).expect("failed to create async event"),
1049                 release_memory_tube.as_ref(),
1050                 interrupt.clone(),
1051                 |guest_address, len| {
1052                     sys::free_memory(
1053                         &guest_address,
1054                         len,
1055                         &vm_memory_client,
1056                     )
1057                 },
1058                 stop_rx,
1059             )
1060             .left_future()
1061         } else {
1062             std::future::pending().right_future()
1063         };
1064         let reporting = reporting.fuse();
1065         pin_mut!(reporting);
1066 
1067         // If VIRTIO_BALLOON_F_WS_REPORTING is set 2 queues must handled - one for WS data and one
1068         // for WS notifications.
1069         let has_ws_data_queue = ws_queues.0.is_some();
1070         let ws_data = if let Some(ws_data_queue) = ws_queues.0 {
1071             let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
1072             let ws_data_queue_evt = ws_data_queue
1073                 .event()
1074                 .try_clone()
1075                 .expect("failed to clone queue event");
1076             handle_ws_data_queue(
1077                 ws_data_queue,
1078                 EventAsync::new(ws_data_queue_evt, &ex).expect("failed to create async event"),
1079                 &command_tube,
1080                 #[cfg(feature = "registered_events")]
1081                 registered_evt_q_async.as_ref(),
1082                 state.clone(),
1083                 interrupt.clone(),
1084                 stop_rx,
1085             )
1086             .left_future()
1087         } else {
1088             std::future::pending().right_future()
1089         };
1090         let ws_data = ws_data.fuse();
1091         pin_mut!(ws_data);
1092 
1093         let (ws_op_tx, ws_op_rx) = mpsc::channel::<WSOp>(1);
1094         let has_ws_op_queue = ws_queues.1.is_some();
1095         let ws_op = if let Some(ws_op_queue) = ws_queues.1 {
1096             let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
1097             let ws_op_queue_evt = ws_op_queue
1098                 .event()
1099                 .try_clone()
1100                 .expect("failed to clone queue event");
1101             handle_ws_op_queue(
1102                 ws_op_queue,
1103                 EventAsync::new(ws_op_queue_evt, &ex).expect("failed to create async event"),
1104                 ws_op_rx,
1105                 state.clone(),
1106                 interrupt.clone(),
1107                 stop_rx,
1108             )
1109             .left_future()
1110         } else {
1111             std::future::pending().right_future()
1112         };
1113         let ws_op = ws_op.fuse();
1114         pin_mut!(ws_op);
1115 
1116         // Future to handle command messages that resize the balloon.
1117         let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
1118         let command = handle_command_tube(
1119             &command_tube,
1120             interrupt.clone(),
1121             state.clone(),
1122             stats_tx,
1123             ws_op_tx,
1124             stop_rx,
1125         );
1126         pin_mut!(command);
1127 
1128         // Process any requests to resample the irq value.
1129         let resample = async_utils::handle_irq_resample(&ex, interrupt.clone());
1130         pin_mut!(resample);
1131 
1132         // Send a message if balloon target reached event is triggered.
1133         let target_reached = handle_target_reached(
1134             &ex,
1135             target_reached_evt,
1136             &vm_memory_client,
1137         );
1138         pin_mut!(target_reached);
1139 
1140         // Exit if the kill event is triggered.
1141         let kill = async_utils::await_and_exit(&ex, kill_evt);
1142         pin_mut!(kill);
1143 
1144         // The next queue is used for events if VIRTIO_BALLOON_F_EVENTS_VQ is negotiated.
1145         let has_events_queue = events_queue.is_some();
1146         let events = if let Some(events_queue) = events_queue {
1147             let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
1148             let events_queue_evt = events_queue
1149                 .event()
1150                 .try_clone()
1151                 .expect("failed to clone queue event");
1152             handle_events_queue(
1153                 events_queue,
1154                 EventAsync::new(events_queue_evt, &ex).expect("failed to create async event"),
1155                 state.clone(),
1156                 interrupt,
1157                 &command_tube,
1158                 stop_rx,
1159             )
1160             .left_future()
1161         } else {
1162             std::future::pending().right_future()
1163         };
1164         let events = events.fuse();
1165         pin_mut!(events);
1166 
1167         let pending_adjusted = handle_pending_adjusted_responses(
1168             EventAsync::new(pending_adjusted_response_event, &ex)
1169                 .expect("failed to create async event"),
1170             &command_tube,
1171             state,
1172         );
1173         pin_mut!(pending_adjusted);
1174 
1175         let res = ex.run_until(async {
1176             select! {
1177                 _ = kill.fuse() => (),
1178                 _ = inflate => return Err(anyhow!("inflate stopped unexpectedly")),
1179                 _ = deflate => return Err(anyhow!("deflate stopped unexpectedly")),
1180                 _ = stats => return Err(anyhow!("stats stopped unexpectedly")),
1181                 _ = reporting => return Err(anyhow!("reporting stopped unexpectedly")),
1182                 _ = command.fuse() => return Err(anyhow!("command stopped unexpectedly")),
1183                 _ = ws_op => return Err(anyhow!("ws_op stopped unexpectedly")),
1184                 _ = resample.fuse() => return Err(anyhow!("resample stopped unexpectedly")),
1185                 _ = events => return Err(anyhow!("events stopped unexpectedly")),
1186                 _ = pending_adjusted.fuse() => return Err(anyhow!("pending_adjusted stopped unexpectedly")),
1187                 _ = ws_data => return Err(anyhow!("ws_data stopped unexpectedly")),
1188                 _ = target_reached.fuse() => return Err(anyhow!("target_reached stopped unexpectedly")),
1189             }
1190 
1191             // Worker is shutting down. To recover the queues, we have to signal
1192             // all the queue futures to exit.
1193             for stop_tx in stop_queue_oneshots {
1194                 if stop_tx.send(()).is_err() {
1195                     return Err(anyhow!("failed to request stop for queue future"));
1196                 }
1197             }
1198 
1199             // Collect all the queues (awaiting any queue future should now
1200             // return its Queue immediately).
1201             let mut paused_queues = PausedQueues::new(
1202                 inflate.await,
1203                 deflate.await,
1204             );
1205             if has_reporting_queue {
1206                 paused_queues.reporting = Some(reporting.await);
1207             }
1208             if has_events_queue {
1209                 paused_queues.events = Some(events.await.context("failed to stop events queue")?);
1210             }
1211             if has_stats_queue {
1212                 paused_queues.stats = Some(stats.await);
1213             }
1214             if has_ws_op_queue {
1215                 paused_queues.ws.0 = Some(ws_op.await.context("failed to stop ws_op queue")?);
1216             }
1217             if has_ws_data_queue {
1218                 paused_queues.ws.1 = Some(ws_data.await.context("failed to stop ws_data queue")?);
1219             }
1220             Ok(paused_queues)
1221         });
1222 
1223         match res {
1224             Err(e) => {
1225                 error!("error happened in executor: {}", e);
1226                 None
1227             }
1228             Ok(main_future_res) => match main_future_res {
1229                 Ok(paused_queues) => Some(paused_queues),
1230                 Err(e) => {
1231                     error!("error happened in main balloon future: {}", e);
1232                     None
1233                 }
1234             },
1235         }
1236     };
1237 
1238     WorkerReturn {
1239         command_tube: command_tube.into(),
1240         paused_queues,
1241         release_memory_tube,
1242         #[cfg(feature = "registered_events")]
1243         registered_evt_q,
1244         vm_memory_client,
1245     }
1246 }
1247 
handle_target_reached( ex: &Executor, target_reached_evt: Event, vm_memory_client: &VmMemoryClient, ) -> anyhow::Result<()>1248 async fn handle_target_reached(
1249     ex: &Executor,
1250     target_reached_evt: Event,
1251     vm_memory_client: &VmMemoryClient,
1252 ) -> anyhow::Result<()> {
1253     let event_async =
1254         EventAsync::new(target_reached_evt, ex).context("failed to create EventAsync")?;
1255     loop {
1256         // Wait for target reached trigger.
1257         let _ = event_async.next_val().await;
1258         // Send the message to vm_control on the event. We don't have to read the current
1259         // size yet.
1260         sys::balloon_target_reached(
1261             0,
1262             vm_memory_client,
1263         );
1264     }
1265     // The above loop will never terminate and there is no reason to terminate it either. However,
1266     // the function is used in an executor that expects a Result<> return. Make sure that clippy
1267     // doesn't enforce the unreachable_code condition.
1268     #[allow(unreachable_code)]
1269     Ok(())
1270 }
1271 
1272 /// Virtio device for memory balloon inflation/deflation.
1273 pub struct Balloon {
1274     command_tube: Option<Tube>,
1275     vm_memory_client: Option<VmMemoryClient>,
1276     release_memory_tube: Option<Tube>,
1277     pending_adjusted_response_event: Event,
1278     state: Arc<AsyncRwLock<BalloonState>>,
1279     features: u64,
1280     acked_features: u64,
1281     worker_thread: Option<WorkerThread<WorkerReturn>>,
1282     #[cfg(feature = "registered_events")]
1283     registered_evt_q: Option<SendTube>,
1284     ws_num_bins: u8,
1285     target_reached_evt: Option<Event>,
1286 }
1287 
1288 /// Snapshot of the [Balloon] state.
1289 #[derive(Serialize, Deserialize)]
1290 struct BalloonSnapshot {
1291     state: BalloonState,
1292     features: u64,
1293     acked_features: u64,
1294     ws_num_bins: u8,
1295 }
1296 
/// Operation mode of the balloon. Selects which feature bit `Balloon::new` advertises.
#[derive(PartialEq, Eq)]
pub enum BalloonMode {
    /// The driver can access pages in the balloon (i.e. F_DEFLATE_ON_OOM).
    Relaxed,
    /// The driver cannot access pages in the balloon. Implies F_RESPONSIVE_DEVICE.
    Strict,
}
1305 
1306 impl Balloon {
1307     /// Creates a new virtio balloon device.
1308     /// To let Balloon able to successfully release the memory which are pinned
1309     /// by CoIOMMU to host, the release_memory_tube will be used to send the inflate
1310     /// ranges to CoIOMMU with UnpinRequest/UnpinResponse messages, so that The
1311     /// memory in the inflate range can be unpinned first.
new( base_features: u64, command_tube: Tube, vm_memory_client: VmMemoryClient, release_memory_tube: Option<Tube>, init_balloon_size: u64, mode: BalloonMode, enabled_features: u64, #[cfg(feature = "registered_events")] registered_evt_q: Option<SendTube>, ws_num_bins: u8, ) -> Result<Balloon>1312     pub fn new(
1313         base_features: u64,
1314         command_tube: Tube,
1315         vm_memory_client: VmMemoryClient,
1316         release_memory_tube: Option<Tube>,
1317         init_balloon_size: u64,
1318         mode: BalloonMode,
1319         enabled_features: u64,
1320         #[cfg(feature = "registered_events")] registered_evt_q: Option<SendTube>,
1321         ws_num_bins: u8,
1322     ) -> Result<Balloon> {
1323         let features = base_features
1324             | 1 << VIRTIO_BALLOON_F_MUST_TELL_HOST
1325             | 1 << VIRTIO_BALLOON_F_STATS_VQ
1326             | 1 << VIRTIO_BALLOON_F_EVENTS_VQ
1327             | enabled_features
1328             | if mode == BalloonMode::Strict {
1329                 1 << VIRTIO_BALLOON_F_RESPONSIVE_DEVICE
1330             } else {
1331                 1 << VIRTIO_BALLOON_F_DEFLATE_ON_OOM
1332             };
1333 
1334         Ok(Balloon {
1335             command_tube: Some(command_tube),
1336             vm_memory_client: Some(vm_memory_client),
1337             release_memory_tube,
1338             pending_adjusted_response_event: Event::new().map_err(BalloonError::CreatingEvent)?,
1339             state: Arc::new(AsyncRwLock::new(BalloonState {
1340                 num_pages: (init_balloon_size >> VIRTIO_BALLOON_PFN_SHIFT) as u32,
1341                 actual_pages: 0,
1342                 failable_update: false,
1343                 pending_adjusted_responses: VecDeque::new(),
1344                 expecting_ws: false,
1345             })),
1346             worker_thread: None,
1347             features,
1348             acked_features: 0,
1349             #[cfg(feature = "registered_events")]
1350             registered_evt_q,
1351             ws_num_bins,
1352             target_reached_evt: None,
1353         })
1354     }
1355 
get_config(&self) -> virtio_balloon_config1356     fn get_config(&self) -> virtio_balloon_config {
1357         let state = block_on(self.state.lock());
1358         virtio_balloon_config {
1359             num_pages: state.num_pages.into(),
1360             actual: state.actual_pages.into(),
1361             // crosvm does not (currently) use free_page_hint_cmd_id or
1362             // poison_val, but they must be present in the right order and size
1363             // for the virtio-balloon driver in the guest to deserialize the
1364             // config correctly.
1365             free_page_hint_cmd_id: 0.into(),
1366             poison_val: 0.into(),
1367             ws_num_bins: self.ws_num_bins,
1368             _reserved: [0, 0, 0],
1369         }
1370     }
1371 
num_expected_queues(acked_features: u64) -> usize1372     fn num_expected_queues(acked_features: u64) -> usize {
1373         // at minimum we have inflate and deflate vqueues.
1374         let mut num_queues = 2;
1375         // stats vqueue
1376         if acked_features & (1 << VIRTIO_BALLOON_F_STATS_VQ) != 0 {
1377             num_queues += 1;
1378         }
1379         // events vqueue
1380         if acked_features & (1 << VIRTIO_BALLOON_F_EVENTS_VQ) != 0 {
1381             num_queues += 1;
1382         }
1383         // page reporting vqueue
1384         if acked_features & (1 << VIRTIO_BALLOON_F_PAGE_REPORTING) != 0 {
1385             num_queues += 1;
1386         }
1387         // working set vqueues
1388         if acked_features & (1 << VIRTIO_BALLOON_F_WS_REPORTING) != 0 {
1389             num_queues += 2;
1390         }
1391 
1392         num_queues
1393     }
1394 
stop_worker(&mut self) -> StoppedWorker<PausedQueues>1395     fn stop_worker(&mut self) -> StoppedWorker<PausedQueues> {
1396         if let Some(worker_thread) = self.worker_thread.take() {
1397             let worker_ret = worker_thread.stop();
1398             self.release_memory_tube = worker_ret.release_memory_tube;
1399             self.command_tube = Some(worker_ret.command_tube);
1400             #[cfg(feature = "registered_events")]
1401             {
1402                 self.registered_evt_q = worker_ret.registered_evt_q;
1403             }
1404             self.vm_memory_client = Some(worker_ret.vm_memory_client);
1405 
1406             if let Some(queues) = worker_ret.paused_queues {
1407                 StoppedWorker::WithQueues(Box::new(queues))
1408             } else {
1409                 StoppedWorker::MissingQueues
1410             }
1411         } else {
1412             StoppedWorker::AlreadyStopped
1413         }
1414     }
1415 
1416     /// Given a filtered queue vector from [VirtioDevice::activate], extract
1417     /// the queues (accounting for queues that are missing because the features
1418     /// are not negotiated) into a structure that is easier to work with.
get_queues_from_map( &self, mut queues: BTreeMap<usize, Queue>, ) -> anyhow::Result<BalloonQueues>1419     fn get_queues_from_map(
1420         &self,
1421         mut queues: BTreeMap<usize, Queue>,
1422     ) -> anyhow::Result<BalloonQueues> {
1423         let expected_queues = Balloon::num_expected_queues(self.acked_features);
1424         if queues.len() != expected_queues {
1425             return Err(anyhow!(
1426                 "expected {} queues, got {}",
1427                 expected_queues,
1428                 queues.len()
1429             ));
1430         }
1431 
1432         // WARNING: We use `pop_first` instead of explicitly using the indices from the virtio spec
1433         // because the Linux virtio drivers only "allocates" queue indices that are used.
1434         let inflate_queue = queues.pop_first().unwrap().1;
1435         let deflate_queue = queues.pop_first().unwrap().1;
1436         let mut queue_struct = BalloonQueues::new(inflate_queue, deflate_queue);
1437 
1438         if self.acked_features & (1 << VIRTIO_BALLOON_F_STATS_VQ) != 0 {
1439             queue_struct.stats = Some(queues.pop_first().unwrap().1);
1440         }
1441         if self.acked_features & (1 << VIRTIO_BALLOON_F_PAGE_REPORTING) != 0 {
1442             queue_struct.reporting = Some(queues.pop_first().unwrap().1);
1443         }
1444         if self.acked_features & (1 << VIRTIO_BALLOON_F_EVENTS_VQ) != 0 {
1445             queue_struct.events = Some(queues.pop_first().unwrap().1);
1446         }
1447         if self.acked_features & (1 << VIRTIO_BALLOON_F_WS_REPORTING) != 0 {
1448             queue_struct.ws = (
1449                 Some(queues.pop_first().unwrap().1),
1450                 Some(queues.pop_first().unwrap().1),
1451             );
1452         }
1453         Ok(queue_struct)
1454     }
1455 
start_worker( &mut self, mem: GuestMemory, interrupt: Interrupt, queues: BalloonQueues, ) -> anyhow::Result<()>1456     fn start_worker(
1457         &mut self,
1458         mem: GuestMemory,
1459         interrupt: Interrupt,
1460         queues: BalloonQueues,
1461     ) -> anyhow::Result<()> {
1462         let (self_target_reached_evt, target_reached_evt) = Event::new()
1463             .and_then(|e| Ok((e.try_clone()?, e)))
1464             .context("failed to create target_reached Event pair: {}")?;
1465         self.target_reached_evt = Some(self_target_reached_evt);
1466 
1467         let state = self.state.clone();
1468 
1469         let command_tube = self.command_tube.take().unwrap();
1470 
1471         let vm_memory_client = self.vm_memory_client.take().unwrap();
1472         let release_memory_tube = self.release_memory_tube.take();
1473         #[cfg(feature = "registered_events")]
1474         let registered_evt_q = self.registered_evt_q.take();
1475         let pending_adjusted_response_event = self
1476             .pending_adjusted_response_event
1477             .try_clone()
1478             .context("failed to clone Event")?;
1479 
1480         self.worker_thread = Some(WorkerThread::start("v_balloon", move |kill_evt| {
1481             run_worker(
1482                 queues.inflate,
1483                 queues.deflate,
1484                 queues.stats,
1485                 queues.reporting,
1486                 queues.events,
1487                 queues.ws,
1488                 command_tube,
1489                 vm_memory_client,
1490                 release_memory_tube,
1491                 interrupt,
1492                 kill_evt,
1493                 target_reached_evt,
1494                 pending_adjusted_response_event,
1495                 mem,
1496                 state,
1497                 #[cfg(feature = "registered_events")]
1498                 registered_evt_q,
1499             )
1500         }));
1501 
1502         Ok(())
1503     }
1504 }
1505 
1506 impl VirtioDevice for Balloon {
keep_rds(&self) -> Vec<RawDescriptor>1507     fn keep_rds(&self) -> Vec<RawDescriptor> {
1508         let mut rds = Vec::new();
1509         if let Some(command_tube) = &self.command_tube {
1510             rds.push(command_tube.as_raw_descriptor());
1511         }
1512         if let Some(release_memory_tube) = &self.release_memory_tube {
1513             rds.push(release_memory_tube.as_raw_descriptor());
1514         }
1515         #[cfg(feature = "registered_events")]
1516         if let Some(registered_evt_q) = &self.registered_evt_q {
1517             rds.push(registered_evt_q.as_raw_descriptor());
1518         }
1519         rds.push(self.pending_adjusted_response_event.as_raw_descriptor());
1520         rds
1521     }
1522 
device_type(&self) -> DeviceType1523     fn device_type(&self) -> DeviceType {
1524         DeviceType::Balloon
1525     }
1526 
queue_max_sizes(&self) -> &[u16]1527     fn queue_max_sizes(&self) -> &[u16] {
1528         QUEUE_SIZES
1529     }
1530 
read_config(&self, offset: u64, data: &mut [u8])1531     fn read_config(&self, offset: u64, data: &mut [u8]) {
1532         copy_config(data, 0, self.get_config().as_bytes(), offset);
1533     }
1534 
write_config(&mut self, offset: u64, data: &[u8])1535     fn write_config(&mut self, offset: u64, data: &[u8]) {
1536         let mut config = self.get_config();
1537         copy_config(config.as_bytes_mut(), offset, data, 0);
1538         let mut state = block_on(self.state.lock());
1539         state.actual_pages = config.actual.to_native();
1540 
1541         // If balloon has updated to the requested memory, let the hypervisor know.
1542         if config.num_pages == config.actual {
1543             debug!(
1544                 "sending target reached event at {}",
1545                 u32::from(config.num_pages)
1546             );
1547             self.target_reached_evt.as_ref().map(|e| e.signal());
1548         }
1549         if state.failable_update && state.actual_pages == state.num_pages {
1550             state.failable_update = false;
1551             let num_pages = state.num_pages;
1552             state.pending_adjusted_responses.push_back(num_pages);
1553             let _ = self.pending_adjusted_response_event.signal();
1554         }
1555     }
1556 
features(&self) -> u641557     fn features(&self) -> u64 {
1558         self.features
1559     }
1560 
ack_features(&mut self, mut value: u64)1561     fn ack_features(&mut self, mut value: u64) {
1562         if value & !self.features != 0 {
1563             warn!("virtio_balloon got unknown feature ack {:x}", value);
1564             value &= self.features;
1565         }
1566         self.acked_features |= value;
1567     }
1568 
activate( &mut self, mem: GuestMemory, interrupt: Interrupt, queues: BTreeMap<usize, Queue>, ) -> anyhow::Result<()>1569     fn activate(
1570         &mut self,
1571         mem: GuestMemory,
1572         interrupt: Interrupt,
1573         queues: BTreeMap<usize, Queue>,
1574     ) -> anyhow::Result<()> {
1575         let queues = self.get_queues_from_map(queues)?;
1576         self.start_worker(mem, interrupt, queues)
1577     }
1578 
reset(&mut self) -> anyhow::Result<()>1579     fn reset(&mut self) -> anyhow::Result<()> {
1580         let _worker = self.stop_worker();
1581         Ok(())
1582     }
1583 
virtio_sleep(&mut self) -> anyhow::Result<Option<BTreeMap<usize, Queue>>>1584     fn virtio_sleep(&mut self) -> anyhow::Result<Option<BTreeMap<usize, Queue>>> {
1585         match self.stop_worker() {
1586             StoppedWorker::WithQueues(paused_queues) => Ok(Some(paused_queues.into())),
1587             StoppedWorker::MissingQueues => {
1588                 anyhow::bail!("balloon queue workers did not stop cleanly.")
1589             }
1590             StoppedWorker::AlreadyStopped => {
1591                 // Device hasn't been activated.
1592                 Ok(None)
1593             }
1594         }
1595     }
1596 
virtio_wake( &mut self, queues_state: Option<(GuestMemory, Interrupt, BTreeMap<usize, Queue>)>, ) -> anyhow::Result<()>1597     fn virtio_wake(
1598         &mut self,
1599         queues_state: Option<(GuestMemory, Interrupt, BTreeMap<usize, Queue>)>,
1600     ) -> anyhow::Result<()> {
1601         if let Some((mem, interrupt, queues)) = queues_state {
1602             if queues.len() < 2 {
1603                 anyhow::bail!("{} queues were found, but an activated balloon must have at least 2 active queues.", queues.len());
1604             }
1605 
1606             let balloon_queues = self.get_queues_from_map(queues)?;
1607             self.start_worker(mem, interrupt, balloon_queues)?;
1608         }
1609         Ok(())
1610     }
1611 
virtio_snapshot(&mut self) -> anyhow::Result<serde_json::Value>1612     fn virtio_snapshot(&mut self) -> anyhow::Result<serde_json::Value> {
1613         let state = self
1614             .state
1615             .lock()
1616             .now_or_never()
1617             .context("failed to acquire balloon lock")?;
1618         serde_json::to_value(BalloonSnapshot {
1619             features: self.features,
1620             acked_features: self.acked_features,
1621             state: state.clone(),
1622             ws_num_bins: self.ws_num_bins,
1623         })
1624         .context("failed to serialize balloon state")
1625     }
1626 
virtio_restore(&mut self, data: serde_json::Value) -> anyhow::Result<()>1627     fn virtio_restore(&mut self, data: serde_json::Value) -> anyhow::Result<()> {
1628         let snap: BalloonSnapshot = serde_json::from_value(data).context("error deserializing")?;
1629         if snap.features != self.features {
1630             anyhow::bail!(
1631                 "balloon: expected features to match, but they did not. Live: {:?}, snapshot {:?}",
1632                 self.features,
1633                 snap.features,
1634             );
1635         }
1636 
1637         let mut state = self
1638             .state
1639             .lock()
1640             .now_or_never()
1641             .context("failed to acquire balloon lock")?;
1642         *state = snap.state;
1643         self.ws_num_bins = snap.ws_num_bins;
1644         self.acked_features = snap.acked_features;
1645         Ok(())
1646     }
1647 }
1648 
#[cfg(test)]
mod tests {
    use super::*;
    use crate::suspendable_virtio_tests;
    use crate::virtio::descriptor_utils::create_descriptor_chain;
    use crate::virtio::descriptor_utils::DescriptorType;

    /// Checks that `handle_address_chain` reads every 32-bit PFN from the descriptor and passes
    /// the corresponding shifted guest address to the callback.
    #[test]
    fn desc_parsing_inflate() {
        const PFNS: [u32; 2] = [0x10, 0xaa55aa55];

        let guest_mem = GuestMemory::new(&[(GuestAddress(0x0), 0x10000)]).unwrap();
        // Place the two PFNs back to back at 0x100, where the descriptor's buffer starts.
        for (pfn, offset) in PFNS.iter().zip([0x100u64, 0x104]) {
            guest_mem
                .write_obj_at_addr(*pfn, GuestAddress(offset))
                .unwrap();
        }

        // One readable descriptor of 8 bytes, i.e. exactly the two PFNs written above.
        let mut chain = create_descriptor_chain(
            &guest_mem,
            GuestAddress(0x0),
            GuestAddress(0x100),
            vec![(DescriptorType::Readable, 8)],
            0,
        )
        .expect("create_descriptor_chain failed");

        let mut seen = Vec::new();
        let outcome = handle_address_chain(None, &mut chain, &mut |guest_address, len| {
            seen.push((guest_address, len));
        });

        assert!(outcome.is_ok());
        assert_eq!(seen.len(), 2);
        for (idx, pfn) in PFNS.iter().enumerate() {
            assert_eq!(
                seen[idx].0,
                GuestAddress(u64::from(*pfn) << VIRTIO_BALLOON_PFN_SHIFT)
            );
        }
    }

    /// Checks the queue count reported for several combinations of acked feature bits.
    #[test]
    fn num_expected_queues() {
        // Turns a list of feature bit positions into the matching bitmask.
        let feature_mask =
            |bits: &[u32]| -> u64 { bits.iter().fold(0, |mask, bit| mask | (1_u64 << bit)) };

        // (expected queue count, acked feature bits)
        let cases: &[(_, &[u32])] = &[
            (2, &[]),
            (2, &[VIRTIO_BALLOON_F_MUST_TELL_HOST]),
            (3, &[VIRTIO_BALLOON_F_STATS_VQ]),
            (
                5,
                &[
                    VIRTIO_BALLOON_F_STATS_VQ,
                    VIRTIO_BALLOON_F_EVENTS_VQ,
                    VIRTIO_BALLOON_F_PAGE_REPORTING,
                ],
            ),
            (
                7,
                &[
                    VIRTIO_BALLOON_F_STATS_VQ,
                    VIRTIO_BALLOON_F_EVENTS_VQ,
                    VIRTIO_BALLOON_F_PAGE_REPORTING,
                    VIRTIO_BALLOON_F_WS_REPORTING,
                ],
            ),
        ];

        for &(expected, features) in cases {
            assert_eq!(
                expected,
                Balloon::num_expected_queues(feature_mask(features))
            );
        }
    }

    /// Holds the peer ends of the tubes given to the test device so they stay open for the
    /// lifetime of the test.
    struct BalloonContext {
        _ctrl_tube: Tube,
        _mem_client_tube: Tube,
    }

    /// Changes the device by inverting every bit of `ws_num_bins`.
    fn modify_device(_balloon_context: &mut BalloonContext, balloon: &mut Balloon) {
        balloon.ws_num_bins = !balloon.ws_num_bins;
    }

    /// Builds a balloon device wired to freshly created tube pairs, returning the peer ends
    /// alongside the device.
    fn create_device() -> (BalloonContext, Balloon) {
        let (host_ctrl, device_ctrl) = Tube::pair().unwrap();
        let (host_mem_client, device_mem_client) = Tube::pair().unwrap();

        let context = BalloonContext {
            _ctrl_tube: host_ctrl,
            _mem_client_tube: host_mem_client,
        };
        let balloon = Balloon::new(
            0,
            device_ctrl,
            VmMemoryClient::new(device_mem_client),
            None,
            1024,
            BalloonMode::Relaxed,
            0,
            #[cfg(feature = "registered_events")]
            None,
            0,
        )
        .unwrap();

        (context, balloon)
    }

    suspendable_virtio_tests!(balloon, create_device, 2, modify_device);
}
1762