1 // Copyright 2017 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 mod sys;
6 
7 use std::collections::VecDeque;
8 use std::sync::Arc;
9 
10 use anyhow::anyhow;
11 use anyhow::Context;
12 use balloon_control::BalloonStats;
13 use balloon_control::BalloonTubeCommand;
14 use balloon_control::BalloonTubeResult;
15 use balloon_control::BalloonWSS;
16 use balloon_control::VIRTIO_BALLOON_WSS_CONFIG_SIZE;
17 use balloon_control::VIRTIO_BALLOON_WSS_NUM_BINS;
18 use base::error;
19 use base::warn;
20 use base::AsRawDescriptor;
21 use base::Event;
22 use base::RawDescriptor;
23 use base::SendTube;
24 use base::Tube;
25 use base::WorkerThread;
26 use cros_async::block_on;
27 use cros_async::select12;
28 use cros_async::sync::Mutex as AsyncMutex;
29 use cros_async::AsyncTube;
30 use cros_async::EventAsync;
31 use cros_async::Executor;
32 use cros_async::SendTubeAsync;
33 use data_model::Le16;
34 use data_model::Le32;
35 use data_model::Le64;
36 use futures::channel::mpsc;
37 use futures::pin_mut;
38 use futures::FutureExt;
39 use futures::StreamExt;
40 use remain::sorted;
41 use thiserror::Error as ThisError;
42 use vm_control::RegisteredEvent;
43 use vm_memory::GuestAddress;
44 use vm_memory::GuestMemory;
45 use zerocopy::AsBytes;
46 use zerocopy::FromBytes;
47 
48 use super::async_utils;
49 use super::copy_config;
50 use super::descriptor_utils;
51 use super::DescriptorChain;
52 use super::DescriptorError;
53 use super::DeviceType;
54 use super::Interrupt;
55 use super::Queue;
56 use super::Reader;
57 use super::SignalableInterrupt;
58 use super::VirtioDevice;
59 use super::Writer;
60 use crate::Suspendable;
61 use crate::UnpinRequest;
62 use crate::UnpinResponse;
63 
/// Errors reported by the virtio balloon device and its worker tasks.
///
/// Each variant wraps the underlying error it was produced from. The variants are listed
/// in alphabetical order, which the `#[sorted]` attribute checks at compile time.
#[sorted]
#[derive(ThisError, Debug)]
pub enum BalloonError {
    /// Failed an async await
    #[error("failed async await: {0}")]
    AsyncAwait(cros_async::AsyncError),
    /// Failed to create event.
    #[error("failed to create event: {0}")]
    CreatingEvent(base::Error),
    /// Failed to create async message receiver.
    #[error("failed to create async message receiver: {0}")]
    CreatingMessageReceiver(base::TubeError),
    /// Virtio descriptor error
    #[error("virtio descriptor error: {0}")]
    Descriptor(DescriptorError),
    /// Failed to receive command message.
    #[error("failed to receive command message: {0}")]
    ReceivingCommand(base::TubeError),
    /// Failed to send command response.
    #[error("failed to send command response: {0}")]
    SendResponse(base::TubeError),
    /// Error while writing to virtqueue
    #[error("failed to write to virtqueue: {0}")]
    WriteQueue(std::io::Error),
    /// Failed to write config event.
    #[error("failed to write config event: {0}")]
    WritingConfigEvent(base::Error),
}
/// Result type whose error is always a [`BalloonError`].
pub type Result<T> = std::result::Result<T, BalloonError>;
93 
// Balloon implements six virt IO queues: Inflate, Deflate, Stats, Event, WssData, WssCmd.
// Every queue is given the same size, QUEUE_SIZE.
const QUEUE_SIZE: u16 = 128;
// One entry per queue listed above.
const QUEUE_SIZES: &[u16] = &[
    QUEUE_SIZE, QUEUE_SIZE, QUEUE_SIZE, QUEUE_SIZE, QUEUE_SIZE, QUEUE_SIZE,
];
99 
// Balloon page frame numbers count 4 KiB pages: shifting a PFN left by
// VIRTIO_BALLOON_PFN_SHIFT gives the guest address of the page.
const VIRTIO_BALLOON_PFN_SHIFT: u32 = 12;
// Size in bytes of one balloon page (1 << 12 = 4096).
const VIRTIO_BALLOON_PF_SIZE: u64 = 1 << VIRTIO_BALLOON_PFN_SHIFT;
102 
// The feature bitmap for virtio balloon. Each constant is a bit position, not a mask.
const VIRTIO_BALLOON_F_MUST_TELL_HOST: u32 = 0; // Tell before reclaiming pages
const VIRTIO_BALLOON_F_STATS_VQ: u32 = 1; // Stats reporting enabled
const VIRTIO_BALLOON_F_DEFLATE_ON_OOM: u32 = 2; // Deflate balloon on OOM
const VIRTIO_BALLOON_F_PAGE_REPORTING: u32 = 5; // Page reporting virtqueue
// TODO(b/273973298): this should maybe be bit 6? to be changed later
const VIRTIO_BALLOON_F_WSS_REPORTING: u32 = 8; // Working Set Size reporting virtqueues
110 
/// Balloon features that can be referred to by name.
///
/// The discriminant of each variant is the bit position of the matching
/// `VIRTIO_BALLOON_F_*` feature constant.
#[derive(Copy, Clone)]
#[repr(u32)]
pub enum BalloonFeatures {
    // Page Reporting enabled
    PageReporting = VIRTIO_BALLOON_F_PAGE_REPORTING,
    // WSS Reporting enabled
    WSSReporting = VIRTIO_BALLOON_F_WSS_REPORTING,
}
120 
// These feature bits are part of the proposal:
//  https://lists.oasis-open.org/archives/virtio-comment/202201/msg00139.html
// Like the feature constants above, they are bit positions.
const VIRTIO_BALLOON_F_RESPONSIVE_DEVICE: u32 = 6; // Device actively watching guest memory
const VIRTIO_BALLOON_F_EVENTS_VQ: u32 = 7; // Event vq is enabled
125 
// virtio_balloon_config is the balloon device configuration space defined by the virtio spec.
// Every field is a little-endian 32-bit value and `#[repr(C)]` keeps them in declaration
// order, so the struct can be copied to and from the guest-visible config bytes as-is.
#[derive(Copy, Clone, Debug, Default, AsBytes, FromBytes)]
#[repr(C)]
struct virtio_balloon_config {
    // NOTE(review): presumably the requested balloon size in pages, mirroring
    // BalloonState::num_pages; confirm against the config read path.
    num_pages: Le32,
    // NOTE(review): presumably the size the guest reports having reached, mirroring
    // BalloonState::actual_pages; confirm against the config write path.
    actual: Le32,
    free_page_hint_cmd_id: Le32,
    poison_val: Le32,
    // WSS field is part of proposed spec extension (b/273973298).
    wss_num_bins: Le32,
}
137 
// BalloonState is shared by the worker and device thread.
#[derive(Default)]
struct BalloonState {
    // Requested balloon size in pages. Set by the Adjust command and reset to
    // `actual_pages` when the guest reports a puff failure during a failable update.
    num_pages: u32,
    // Current balloon size in pages; reported (converted to bytes) in stats results and
    // Adjusted responses. NOTE(review): presumably updated from guest config writes,
    // which are not part of this section; confirm.
    actual_pages: u32,
    // True between forwarding a WSS request to the guest and relaying the next WSS
    // report back on the WSS tube.
    expecting_wss: bool,
    // Id of the outstanding WSS request; echoed back in the WorkingSetSize result.
    expected_wss_id: u64,
    // Flag indicating that the balloon is in the process of a failable update. This
    // is set by an Adjust command that has allow_failure set, and is cleared when the
    // Adjusted success/failure response is sent.
    failable_update: bool,
    // NOTE(review): not referenced by the handlers in this section; presumably page
    // counts for Adjusted responses that still have to be sent. Confirm at the use site.
    pending_adjusted_responses: VecDeque<u32>,
}
151 
// The constants defining stats types in virtio_balloon_stat. They are the tag values
// matched by BalloonStat::update_stats.
const VIRTIO_BALLOON_S_SWAP_IN: u16 = 0;
const VIRTIO_BALLOON_S_SWAP_OUT: u16 = 1;
const VIRTIO_BALLOON_S_MAJFLT: u16 = 2;
const VIRTIO_BALLOON_S_MINFLT: u16 = 3;
const VIRTIO_BALLOON_S_MEMFREE: u16 = 4;
const VIRTIO_BALLOON_S_MEMTOT: u16 = 5;
const VIRTIO_BALLOON_S_AVAIL: u16 = 6;
const VIRTIO_BALLOON_S_CACHES: u16 = 7;
const VIRTIO_BALLOON_S_HTLB_PGALLOC: u16 = 8;
const VIRTIO_BALLOON_S_HTLB_PGFAIL: u16 = 9;
// Tags taken from the top of the u16 range, as the NONSTANDARD names indicate.
const VIRTIO_BALLOON_S_NONSTANDARD_SHMEM: u16 = 65534;
const VIRTIO_BALLOON_S_NONSTANDARD_UNEVICTABLE: u16 = 65535;
165 
// BalloonStat is used to deserialize stats from the stats_queue.
// `packed` removes the padding that would otherwise follow the 16-bit tag, so one entry
// is exactly 10 bytes.
#[derive(Copy, Clone, FromBytes, AsBytes)]
#[repr(C, packed)]
struct BalloonStat {
    // One of the VIRTIO_BALLOON_S_* constants (little-endian).
    tag: Le16,
    // Value reported for that tag (little-endian).
    val: Le64,
}
173 
174 impl BalloonStat {
update_stats(&self, stats: &mut BalloonStats)175     fn update_stats(&self, stats: &mut BalloonStats) {
176         let val = Some(self.val.to_native());
177         match self.tag.to_native() {
178             VIRTIO_BALLOON_S_SWAP_IN => stats.swap_in = val,
179             VIRTIO_BALLOON_S_SWAP_OUT => stats.swap_out = val,
180             VIRTIO_BALLOON_S_MAJFLT => stats.major_faults = val,
181             VIRTIO_BALLOON_S_MINFLT => stats.minor_faults = val,
182             VIRTIO_BALLOON_S_MEMFREE => stats.free_memory = val,
183             VIRTIO_BALLOON_S_MEMTOT => stats.total_memory = val,
184             VIRTIO_BALLOON_S_AVAIL => stats.available_memory = val,
185             VIRTIO_BALLOON_S_CACHES => stats.disk_caches = val,
186             VIRTIO_BALLOON_S_HTLB_PGALLOC => stats.hugetlb_allocations = val,
187             VIRTIO_BALLOON_S_HTLB_PGFAIL => stats.hugetlb_failures = val,
188             VIRTIO_BALLOON_S_NONSTANDARD_SHMEM => stats.shared_memory = val,
189             VIRTIO_BALLOON_S_NONSTANDARD_UNEVICTABLE => stats.unevictable_memory = val,
190             _ => (),
191         }
192     }
193 }
194 
// Event types carried in virtio_balloon_event_header::evt_type.
const VIRTIO_BALLOON_EVENT_PRESSURE: u32 = 1;
const VIRTIO_BALLOON_EVENT_PUFF_FAILURE: u32 = 2;
197 
// Header read from each buffer on the events queue.
#[repr(C)]
#[derive(Copy, Clone, Default, AsBytes, FromBytes)]
struct virtio_balloon_event_header {
    // One of the VIRTIO_BALLOON_EVENT_* constants (little-endian).
    evt_type: Le32,
}
203 
// virtio_balloon_wss is used to deserialize from the wss data vq.
// One instance describes one working-set bin; see update_wss for how it is copied into
// a BalloonWSS.
#[repr(C)]
#[derive(Copy, Clone, Debug, Default, AsBytes, FromBytes)]
struct virtio_balloon_wss {
    tag: Le16,
    node_id: Le16,
    // virtio prefers field members to align on a word boundary so we must pad. see:
    // https://crsrc.org/o/src/third_party/kernel/v5.15/include/uapi/linux/virtio_balloon.h;l=105
    _reserved: [u8; 4],
    // Copied into the bin's `age`.
    idle_age_ms: Le64,
    // TODO(b/273973298): these should become separate fields - bytes for ANON and FILE
    // Copied element by element into the bin's `bytes`.
    memory_size_bytes: [Le64; 2],
}
217 
218 impl virtio_balloon_wss {
update_wss(&self, wss: &mut BalloonWSS, index: usize)219     fn update_wss(&self, wss: &mut BalloonWSS, index: usize) {
220         if index >= VIRTIO_BALLOON_WSS_NUM_BINS {
221             error!(
222                 "index {} outside of known WSS bins: {}",
223                 index, VIRTIO_BALLOON_WSS_NUM_BINS
224             );
225             return;
226         }
227         wss.wss[index].age = self.idle_age_ms.to_native();
228         wss.wss[index].bytes[0] = self.memory_size_bytes[0].to_native();
229         wss.wss[index].bytes[1] = self.memory_size_bytes[1].to_native();
230     }
231 }
232 
// Operation codes written into virtio_balloon_op::type_. The names with a leading
// underscore are not used in this section; the underscore silences the unused warning.
const _VIRTIO_BALLOON_WSS_OP_INVALID: u16 = 0;
const VIRTIO_BALLOON_WSS_OP_REQUEST: u16 = 1;
const VIRTIO_BALLOON_WSS_OP_CONFIG: u16 = 2;
const _VIRTIO_BALLOON_WSS_OP_DISCARD: u16 = 3;
237 
// virtio_balloon_op is used to serialize to the wss cmd vq.
// It is a single little-endian 16-bit operation code.
#[repr(C, packed)]
#[derive(Copy, Clone, Debug, Default, AsBytes, FromBytes)]
struct virtio_balloon_op {
    // One of the VIRTIO_BALLOON_WSS_OP_* constants.
    type_: Le16,
}
244 
invoke_desc_handler<F>(ranges: Vec<(u64, u64)>, desc_handler: &mut F) where F: FnMut(GuestAddress, u64),245 fn invoke_desc_handler<F>(ranges: Vec<(u64, u64)>, desc_handler: &mut F)
246 where
247     F: FnMut(GuestAddress, u64),
248 {
249     for range in ranges {
250         desc_handler(GuestAddress(range.0), range.1);
251     }
252 }
253 
254 // Release a list of guest memory ranges back to the host system.
255 // Unpin requests for each inflate range will be sent via `release_memory_tube`
256 // if provided, and then `desc_handler` will be called for each inflate range.
release_ranges<F>( release_memory_tube: Option<&Tube>, inflate_ranges: Vec<(u64, u64)>, desc_handler: &mut F, ) -> descriptor_utils::Result<()> where F: FnMut(GuestAddress, u64),257 fn release_ranges<F>(
258     release_memory_tube: Option<&Tube>,
259     inflate_ranges: Vec<(u64, u64)>,
260     desc_handler: &mut F,
261 ) -> descriptor_utils::Result<()>
262 where
263     F: FnMut(GuestAddress, u64),
264 {
265     if let Some(tube) = release_memory_tube {
266         let unpin_ranges = inflate_ranges
267             .iter()
268             .map(|v| {
269                 (
270                     v.0 >> VIRTIO_BALLOON_PFN_SHIFT,
271                     v.1 / VIRTIO_BALLOON_PF_SIZE,
272                 )
273             })
274             .collect();
275         let req = UnpinRequest {
276             ranges: unpin_ranges,
277         };
278         if let Err(e) = tube.send(&req) {
279             error!("failed to send unpin request: {}", e);
280         } else {
281             match tube.recv() {
282                 Ok(resp) => match resp {
283                     UnpinResponse::Success => invoke_desc_handler(inflate_ranges, desc_handler),
284                     UnpinResponse::Failed => error!("failed to handle unpin request"),
285                 },
286                 Err(e) => error!("failed to handle get unpin response: {}", e),
287             }
288         }
289     } else {
290         invoke_desc_handler(inflate_ranges, desc_handler);
291     }
292 
293     Ok(())
294 }
295 
296 // Processes one message's list of addresses.
handle_address_chain<F>( release_memory_tube: Option<&Tube>, avail_desc: DescriptorChain, mem: &GuestMemory, desc_handler: &mut F, ) -> descriptor_utils::Result<()> where F: FnMut(GuestAddress, u64),297 fn handle_address_chain<F>(
298     release_memory_tube: Option<&Tube>,
299     avail_desc: DescriptorChain,
300     mem: &GuestMemory,
301     desc_handler: &mut F,
302 ) -> descriptor_utils::Result<()>
303 where
304     F: FnMut(GuestAddress, u64),
305 {
306     // In a long-running system, there is no reason to expect that
307     // a significant number of freed pages are consecutive. However,
308     // batching is relatively simple and can result in significant
309     // gains in a newly booted system, so it's worth attempting.
310     let mut range_start = 0;
311     let mut range_size = 0;
312     let mut reader = Reader::new(mem.clone(), avail_desc)?;
313     let mut inflate_ranges: Vec<(u64, u64)> = Vec::new();
314     for res in reader.iter::<Le32>() {
315         let pfn = match res {
316             Ok(pfn) => pfn,
317             Err(e) => {
318                 error!("error while reading unused pages: {}", e);
319                 break;
320             }
321         };
322         let guest_address = (u64::from(pfn.to_native())) << VIRTIO_BALLOON_PFN_SHIFT;
323         if range_start + range_size == guest_address {
324             range_size += VIRTIO_BALLOON_PF_SIZE;
325         } else if range_start == guest_address + VIRTIO_BALLOON_PF_SIZE {
326             range_start = guest_address;
327             range_size += VIRTIO_BALLOON_PF_SIZE;
328         } else {
329             // Discontinuity, so flush the previous range. Note range_size
330             // will be 0 on the first iteration, so skip that.
331             if range_size != 0 {
332                 inflate_ranges.push((range_start, range_size));
333             }
334             range_start = guest_address;
335             range_size = VIRTIO_BALLOON_PF_SIZE;
336         }
337     }
338     if range_size != 0 {
339         inflate_ranges.push((range_start, range_size));
340     }
341 
342     release_ranges(release_memory_tube, inflate_ranges, desc_handler)
343 }
344 
345 // Async task that handles the main balloon inflate and deflate queues.
handle_queue<F>( mem: &GuestMemory, mut queue: Queue, mut queue_event: EventAsync, release_memory_tube: Option<&Tube>, interrupt: Interrupt, mut desc_handler: F, ) where F: FnMut(GuestAddress, u64),346 async fn handle_queue<F>(
347     mem: &GuestMemory,
348     mut queue: Queue,
349     mut queue_event: EventAsync,
350     release_memory_tube: Option<&Tube>,
351     interrupt: Interrupt,
352     mut desc_handler: F,
353 ) where
354     F: FnMut(GuestAddress, u64),
355 {
356     loop {
357         let avail_desc = match queue.next_async(mem, &mut queue_event).await {
358             Err(e) => {
359                 error!("Failed to read descriptor {}", e);
360                 return;
361             }
362             Ok(d) => d,
363         };
364         let index = avail_desc.index;
365         if let Err(e) =
366             handle_address_chain(release_memory_tube, avail_desc, mem, &mut desc_handler)
367         {
368             error!("balloon: failed to process inflate addresses: {}", e);
369         }
370         queue.add_used(mem, index, 0);
371         queue.trigger_interrupt(mem, &interrupt);
372     }
373 }
374 
375 // Processes one page-reporting descriptor.
handle_reported_buffer<F>( release_memory_tube: Option<&Tube>, avail_desc: DescriptorChain, desc_handler: &mut F, ) -> descriptor_utils::Result<()> where F: FnMut(GuestAddress, u64),376 fn handle_reported_buffer<F>(
377     release_memory_tube: Option<&Tube>,
378     avail_desc: DescriptorChain,
379     desc_handler: &mut F,
380 ) -> descriptor_utils::Result<()>
381 where
382     F: FnMut(GuestAddress, u64),
383 {
384     let mut reported_ranges: Vec<(u64, u64)> = Vec::new();
385     let regions = avail_desc.into_iter();
386     for desc in regions {
387         let (desc_regions, _exported) = desc.into_mem_regions();
388         for r in desc_regions {
389             reported_ranges.push((r.gpa.offset(), r.len));
390         }
391     }
392 
393     release_ranges(release_memory_tube, reported_ranges, desc_handler)
394 }
395 
396 // Async task that handles the page reporting queue.
handle_reporting_queue<F>( mem: &GuestMemory, mut queue: Queue, mut queue_event: EventAsync, release_memory_tube: Option<&Tube>, interrupt: Interrupt, mut desc_handler: F, ) where F: FnMut(GuestAddress, u64),397 async fn handle_reporting_queue<F>(
398     mem: &GuestMemory,
399     mut queue: Queue,
400     mut queue_event: EventAsync,
401     release_memory_tube: Option<&Tube>,
402     interrupt: Interrupt,
403     mut desc_handler: F,
404 ) where
405     F: FnMut(GuestAddress, u64),
406 {
407     loop {
408         let avail_desc = match queue.next_async(mem, &mut queue_event).await {
409             Err(e) => {
410                 error!("Failed to read descriptor {}", e);
411                 return;
412             }
413             Ok(d) => d,
414         };
415         let index = avail_desc.index;
416         if let Err(e) = handle_reported_buffer(release_memory_tube, avail_desc, &mut desc_handler) {
417             error!("balloon: failed to process reported buffer: {}", e);
418         }
419         queue.add_used(mem, index, 0);
420         queue.trigger_interrupt(mem, &interrupt);
421     }
422 }
423 
parse_balloon_stats(reader: &mut Reader) -> BalloonStats424 fn parse_balloon_stats(reader: &mut Reader) -> BalloonStats {
425     let mut stats: BalloonStats = Default::default();
426     for res in reader.iter::<BalloonStat>() {
427         match res {
428             Ok(stat) => stat.update_stats(&mut stats),
429             Err(e) => {
430                 error!("error while reading stats: {}", e);
431                 break;
432             }
433         };
434     }
435     stats
436 }
437 
438 // Async task that handles the stats queue. Note that the cadence of this is driven by requests for
439 // balloon stats from the control pipe.
440 // The guests queues an initial buffer on boot, which is read and then this future will block until
441 // signaled from the command socket that stats should be collected again.
handle_stats_queue( mem: &GuestMemory, mut queue: Queue, mut queue_event: EventAsync, mut stats_rx: mpsc::Receiver<u64>, command_tube: &AsyncTube, registered_evt_q: Option<&SendTubeAsync>, state: Arc<AsyncMutex<BalloonState>>, interrupt: Interrupt, )442 async fn handle_stats_queue(
443     mem: &GuestMemory,
444     mut queue: Queue,
445     mut queue_event: EventAsync,
446     mut stats_rx: mpsc::Receiver<u64>,
447     command_tube: &AsyncTube,
448     registered_evt_q: Option<&SendTubeAsync>,
449     state: Arc<AsyncMutex<BalloonState>>,
450     interrupt: Interrupt,
451 ) {
452     // Consume the first stats buffer sent from the guest at startup. It was not
453     // requested by anyone, and the stats are stale.
454     let mut index = match queue.next_async(mem, &mut queue_event).await {
455         Err(e) => {
456             error!("Failed to read descriptor {}", e);
457             return;
458         }
459         Ok(d) => d.index,
460     };
461     loop {
462         // Wait for a request to read the stats.
463         let id = match stats_rx.next().await {
464             Some(id) => id,
465             None => {
466                 error!("stats signal tube was closed");
467                 break;
468             }
469         };
470 
471         // Request a new stats_desc to the guest.
472         queue.add_used(mem, index, 0);
473         queue.trigger_interrupt(mem, &interrupt);
474 
475         let stats_desc = match queue.next_async(mem, &mut queue_event).await {
476             Err(e) => {
477                 error!("Failed to read descriptor {}", e);
478                 return;
479             }
480             Ok(d) => d,
481         };
482         index = stats_desc.index;
483         let mut reader = match Reader::new(mem.clone(), stats_desc) {
484             Ok(r) => r,
485             Err(e) => {
486                 error!("balloon: failed to CREATE Reader: {}", e);
487                 continue;
488             }
489         };
490         let stats = parse_balloon_stats(&mut reader);
491 
492         let actual_pages = state.lock().await.actual_pages as u64;
493         let result = BalloonTubeResult::Stats {
494             balloon_actual: actual_pages << VIRTIO_BALLOON_PFN_SHIFT,
495             stats,
496             id,
497         };
498         let send_result = command_tube.send(result).await;
499         if let Err(e) = send_result {
500             error!("failed to send stats result: {}", e);
501         }
502 
503         if let Some(registered_evt_q) = registered_evt_q {
504             if let Err(e) = registered_evt_q
505                 .send(&RegisteredEvent::VirtioBalloonResize)
506                 .await
507             {
508                 error!("failed to send VirtioBalloonResize event: {}", e);
509             }
510         }
511     }
512 }
513 
send_adjusted_response( tube: &AsyncTube, num_pages: u32, ) -> std::result::Result<(), base::TubeError>514 async fn send_adjusted_response(
515     tube: &AsyncTube,
516     num_pages: u32,
517 ) -> std::result::Result<(), base::TubeError> {
518     let num_bytes = (num_pages as u64) << VIRTIO_BALLOON_PFN_SHIFT;
519     let result = BalloonTubeResult::Adjusted { num_bytes };
520     tube.send(result).await
521 }
522 
handle_event( state: Arc<AsyncMutex<BalloonState>>, interrupt: Interrupt, r: &mut Reader, command_tube: &AsyncTube, ) -> Result<()>523 async fn handle_event(
524     state: Arc<AsyncMutex<BalloonState>>,
525     interrupt: Interrupt,
526     r: &mut Reader,
527     command_tube: &AsyncTube,
528 ) -> Result<()> {
529     match r.read_obj::<virtio_balloon_event_header>() {
530         Ok(hdr) => match hdr.evt_type.to_native() {
531             VIRTIO_BALLOON_EVENT_PRESSURE => {
532                 // TODO(b/213962590): See how this can be integrated this into memory rebalancing
533             }
534             VIRTIO_BALLOON_EVENT_PUFF_FAILURE => {
535                 let mut state = state.lock().await;
536                 if state.failable_update {
537                     state.num_pages = state.actual_pages;
538                     interrupt.signal_config_changed();
539 
540                     state.failable_update = false;
541                     send_adjusted_response(command_tube, state.actual_pages)
542                         .await
543                         .map_err(BalloonError::SendResponse)?;
544                 }
545             }
546             _ => {
547                 warn!("Unknown event {}", hdr.evt_type.to_native());
548             }
549         },
550         Err(e) => error!("Failed to parse event header {:?}", e),
551     }
552     Ok(())
553 }
554 
555 // Async task that handles the events queue.
handle_events_queue( mem: &GuestMemory, mut queue: Queue, mut queue_event: EventAsync, state: Arc<AsyncMutex<BalloonState>>, interrupt: Interrupt, command_tube: &AsyncTube, ) -> Result<()>556 async fn handle_events_queue(
557     mem: &GuestMemory,
558     mut queue: Queue,
559     mut queue_event: EventAsync,
560     state: Arc<AsyncMutex<BalloonState>>,
561     interrupt: Interrupt,
562     command_tube: &AsyncTube,
563 ) -> Result<()> {
564     loop {
565         let avail_desc = queue
566             .next_async(mem, &mut queue_event)
567             .await
568             .map_err(BalloonError::AsyncAwait)?;
569         let index = avail_desc.index;
570         match Reader::new(mem.clone(), avail_desc) {
571             Ok(mut r) => {
572                 handle_event(state.clone(), interrupt.clone(), &mut r, command_tube).await?
573             }
574             Err(e) => error!("balloon: failed to CREATE Reader: {}", e),
575         };
576 
577         queue.add_used(mem, index, 0);
578         queue.trigger_interrupt(mem, &interrupt);
579     }
580 }
581 
handle_wss_queue( mem: &GuestMemory, mut queue: Queue, mut queue_event: EventAsync, mut wss_rx: mpsc::Receiver<u64>, state: Arc<AsyncMutex<BalloonState>>, interrupt: Interrupt, ) -> Result<()>582 async fn handle_wss_queue(
583     mem: &GuestMemory,
584     mut queue: Queue,
585     mut queue_event: EventAsync,
586     mut wss_rx: mpsc::Receiver<u64>,
587     state: Arc<AsyncMutex<BalloonState>>,
588     interrupt: Interrupt,
589 ) -> Result<()> {
590     if let Err(e) =
591         send_initial_wss_config(mem, &mut queue, &mut queue_event, interrupt.clone()).await
592     {
593         error!("unable to send initial WSS config to guest: {}", e);
594     }
595 
596     loop {
597         let id = match wss_rx.next().await {
598             Some(id) => id,
599             None => {
600                 error!("wss signal tube was closed");
601                 break;
602             }
603         };
604 
605         {
606             let mut state = state.lock().await;
607             state.expecting_wss = true;
608             state.expected_wss_id = id;
609         }
610 
611         let avail_desc = queue
612             .next_async(mem, &mut queue_event)
613             .await
614             .map_err(BalloonError::AsyncAwait)?;
615 
616         let index = avail_desc.index;
617 
618         let mut writer = match Writer::new(mem.clone(), avail_desc) {
619             Ok(w) => w,
620             Err(e) => {
621                 error!("balloon: failed to CREATE Writer: {}", e);
622                 continue;
623             }
624         };
625 
626         let wss_r = virtio_balloon_op {
627             type_: VIRTIO_BALLOON_WSS_OP_REQUEST.into(),
628         };
629 
630         if let Err(e) = writer.write_obj(wss_r) {
631             error!("failed to write wss-r command: {}", e);
632         }
633 
634         queue.add_used(mem, index, writer.bytes_written() as u32);
635         queue.trigger_interrupt(mem, &interrupt);
636     }
637 
638     Ok(())
639 }
640 
parse_balloon_wss(reader: &mut Reader) -> BalloonWSS641 fn parse_balloon_wss(reader: &mut Reader) -> BalloonWSS {
642     let mut count = 0;
643     let mut wss = BalloonWSS::new();
644     for res in reader.iter::<virtio_balloon_wss>() {
645         match res {
646             Ok(wss_msg) => {
647                 wss_msg.update_wss(&mut wss, count);
648                 count += 1;
649                 if count > VIRTIO_BALLOON_WSS_NUM_BINS {
650                     error!(
651                         "we should never receive more than {} wss buckets",
652                         VIRTIO_BALLOON_WSS_NUM_BINS
653                     );
654                     break;
655                 }
656             }
657             Err(e) => {
658                 error!("error while reading wss: {}", e);
659                 break;
660             }
661         }
662     }
663     wss
664 }
665 
666 // Async task that handles the stats queue. Note that the arrival of events on
667 // the WSS vq may be the result of either a WSS request (WSS-R) command having
668 // been sent to the guest, or an unprompted send due to memory pressue in the
669 // guest. If the data was requested, we should also send that back on the
670 // command tube.
handle_wss_data_queue( mem: &GuestMemory, mut queue: Queue, mut queue_event: EventAsync, wss_op_tube: Option<&AsyncTube>, registered_evt_q: Option<&SendTubeAsync>, state: Arc<AsyncMutex<BalloonState>>, interrupt: Interrupt, ) -> Result<()>671 async fn handle_wss_data_queue(
672     mem: &GuestMemory,
673     mut queue: Queue,
674     mut queue_event: EventAsync,
675     wss_op_tube: Option<&AsyncTube>,
676     registered_evt_q: Option<&SendTubeAsync>,
677     state: Arc<AsyncMutex<BalloonState>>,
678     interrupt: Interrupt,
679 ) -> Result<()> {
680     if let Some(wss_op_tube) = wss_op_tube {
681         loop {
682             let avail_desc = queue
683                 .next_async(mem, &mut queue_event)
684                 .await
685                 .map_err(BalloonError::AsyncAwait)?;
686             let index = avail_desc.index;
687             let mut reader = match Reader::new(mem.clone(), avail_desc) {
688                 Ok(r) => r,
689                 Err(e) => {
690                     error!("balloon: failed to CREATE Reader: {}", e);
691                     continue;
692                 }
693             };
694 
695             let wss = parse_balloon_wss(&mut reader);
696 
697             // Closure to hold the mutex for handling a WSS-R command response
698             {
699                 let mut state = state.lock().await;
700                 if state.expecting_wss {
701                     let result = BalloonTubeResult::WorkingSetSize {
702                         wss,
703                         id: state.expected_wss_id,
704                     };
705                     let send_result = wss_op_tube.send(result).await;
706                     if let Err(e) = send_result {
707                         error!("failed to send wss result: {}", e);
708                     }
709 
710                     state.expecting_wss = false;
711                 }
712             }
713 
714             // TODO: pipe back the wss to the registered event socket, needs
715             // event-with-payload, currently events are simple enums
716             if let Some(registered_evt_q) = registered_evt_q {
717                 if let Err(e) = registered_evt_q
718                     .send(&RegisteredEvent::VirtioBalloonWssReport)
719                     .await
720                 {
721                     error!("failed to send VirtioBalloonWSSReport event: {}", e);
722                 }
723             }
724 
725             queue.add_used(mem, index, 0);
726             queue.trigger_interrupt(mem, &interrupt);
727         }
728     } else {
729         error!("no wss device tube even though we have wss vqueues");
730         Ok(())
731     }
732 }
733 
send_wss_config( writer: &mut Writer, config: [u64; VIRTIO_BALLOON_WSS_CONFIG_SIZE], queue: &mut Queue, mem: &GuestMemory, index: u16, interrupt: Interrupt, ) -> Result<()>734 async fn send_wss_config(
735     writer: &mut Writer,
736     config: [u64; VIRTIO_BALLOON_WSS_CONFIG_SIZE],
737     queue: &mut Queue,
738     mem: &GuestMemory,
739     index: u16,
740     interrupt: Interrupt,
741 ) -> Result<()> {
742     let cmd = virtio_balloon_op {
743         type_: VIRTIO_BALLOON_WSS_OP_CONFIG.into(),
744     };
745 
746     writer.write_obj(cmd).map_err(BalloonError::WriteQueue)?;
747 
748     writer.write_obj(config).map_err(BalloonError::WriteQueue)?;
749 
750     queue.add_used(mem, index, writer.bytes_written() as u32);
751     queue.trigger_interrupt(mem, &interrupt);
752 
753     Ok(())
754 }
755 
send_initial_wss_config( mem: &GuestMemory, queue: &mut Queue, queue_event: &mut EventAsync, interrupt: Interrupt, ) -> Result<()>756 async fn send_initial_wss_config(
757     mem: &GuestMemory,
758     queue: &mut Queue,
759     queue_event: &mut EventAsync,
760     interrupt: Interrupt,
761 ) -> Result<()> {
762     let avail_desc = queue
763         .next_async(mem, queue_event)
764         .await
765         .map_err(BalloonError::AsyncAwait)?;
766     let index = avail_desc.index;
767 
768     let mut writer = Writer::new(mem.clone(), avail_desc).map_err(BalloonError::Descriptor)?;
769 
770     // NOTE: first VIRTIO_BALLOON_WSS_NUM_BINS - 1 values are the
771     // interval boundaries, then refresh and reporting thresholds.
772     let config: [u64; VIRTIO_BALLOON_WSS_CONFIG_SIZE] = [1, 5, 10, 750, 1000];
773 
774     send_wss_config(&mut writer, config, queue, mem, index, interrupt).await
775 }
776 
777 // Async task that handles the command socket. The command socket handles messages from the host
778 // requesting that the guest balloon be adjusted or to report guest memory statistics.
handle_command_tube( command_tube: &AsyncTube, interrupt: Interrupt, state: Arc<AsyncMutex<BalloonState>>, mut stats_tx: mpsc::Sender<u64>, ) -> Result<()>779 async fn handle_command_tube(
780     command_tube: &AsyncTube,
781     interrupt: Interrupt,
782     state: Arc<AsyncMutex<BalloonState>>,
783     mut stats_tx: mpsc::Sender<u64>,
784 ) -> Result<()> {
785     loop {
786         match command_tube.next().await {
787             Ok(command) => match command {
788                 BalloonTubeCommand::Adjust {
789                     num_bytes,
790                     allow_failure,
791                 } => {
792                     let num_pages = (num_bytes >> VIRTIO_BALLOON_PFN_SHIFT) as u32;
793                     let mut state = state.lock().await;
794 
795                     state.num_pages = num_pages;
796                     interrupt.signal_config_changed();
797 
798                     if allow_failure {
799                         if num_pages == state.actual_pages {
800                             send_adjusted_response(command_tube, num_pages)
801                                 .await
802                                 .map_err(BalloonError::SendResponse)?;
803                         } else {
804                             state.failable_update = true;
805                         }
806                     }
807                 }
808                 BalloonTubeCommand::Stats { id } => {
809                     if let Err(e) = stats_tx.try_send(id) {
810                         error!("failed to signal the stat handler: {}", e);
811                     }
812                 }
813                 BalloonTubeCommand::WorkingSetSize { .. } => {
814                     error!("should not get a working set size command on this tube!");
815                 }
816             },
817             Err(e) => {
818                 return Err(BalloonError::ReceivingCommand(e));
819             }
820         }
821     }
822 }
823 
824 // Async task that handles the command socket. The command socket handles messages from the host
825 // requesting that the guest balloon be adjusted or to report guest memory statistics.
handle_wss_op_tube( wss_op_tube: Option<&AsyncTube>, mut wss_tx: mpsc::Sender<u64>, ) -> Result<()>826 async fn handle_wss_op_tube(
827     wss_op_tube: Option<&AsyncTube>,
828     mut wss_tx: mpsc::Sender<u64>,
829 ) -> Result<()> {
830     if let Some(wss_op_tube) = wss_op_tube {
831         loop {
832             match wss_op_tube.next().await {
833                 Ok(command) => match command {
834                     BalloonTubeCommand::WorkingSetSize { id } => {
835                         if let Err(e) = wss_tx.try_send(id) {
836                             error!("failed to signal the wss handler: {}", e);
837                         }
838                     }
839                     _ => {
840                         error!("should only ever get a working set size command on this tube!");
841                     }
842                 },
843                 Err(e) => {
844                     return Err(BalloonError::ReceivingCommand(e));
845                 }
846             }
847         }
848     } else {
849         // No wss_op_tube; just park this future.
850         futures::future::pending::<()>().await;
851         Ok(())
852     }
853 }
854 
handle_pending_adjusted_responses( pending_adjusted_response_event: EventAsync, command_tube: &AsyncTube, state: Arc<AsyncMutex<BalloonState>>, ) -> Result<()>855 async fn handle_pending_adjusted_responses(
856     pending_adjusted_response_event: EventAsync,
857     command_tube: &AsyncTube,
858     state: Arc<AsyncMutex<BalloonState>>,
859 ) -> Result<()> {
860     loop {
861         pending_adjusted_response_event
862             .next_val()
863             .await
864             .map_err(BalloonError::AsyncAwait)?;
865         while let Some(num_pages) = state.lock().await.pending_adjusted_responses.pop_front() {
866             send_adjusted_response(command_tube, num_pages)
867                 .await
868                 .map_err(BalloonError::SendResponse)?;
869         }
870     }
871 }
872 
// The main worker thread. Initializes the asynchronous worker tasks and passes them to the
// executor to be processed.
//
// One future is built per virtqueue and per control channel. Queues that were not handed in
// (`None`) are replaced by a future that never completes so that the fixed-arity `select12`
// can always be used. Once the executor returns, the tubes that were moved into the worker
// are handed back to the caller as
// `(release_memory_tube, command_tube, wss_op_tube, registered_evt_q)`.
//
// Panics if the executor or any of the async wrappers (tubes, events) cannot be created.
fn run_worker(
    inflate_queue: (Queue, Event),
    deflate_queue: (Queue, Event),
    stats_queue: Option<(Queue, Event)>,
    reporting_queue: Option<(Queue, Event)>,
    events_queue: Option<(Queue, Event)>,
    wss_queues: (Option<(Queue, Event)>, Option<(Queue, Event)>),
    command_tube: Tube,
    wss_op_tube: Option<Tube>,
    #[cfg(windows)] dynamic_mapping_tube: Tube,
    release_memory_tube: Option<Tube>,
    interrupt: Interrupt,
    kill_evt: Event,
    pending_adjusted_response_event: Event,
    mem: GuestMemory,
    state: Arc<AsyncMutex<BalloonState>>,
    registered_evt_q: Option<SendTube>,
) -> (Option<Tube>, Tube, Option<Tube>, Option<SendTube>) {
    let ex = Executor::new().unwrap();
    let command_tube = AsyncTube::new(&ex, command_tube).unwrap();
    let wss_op_tube = wss_op_tube.map(|t| AsyncTube::new(&ex, t).unwrap());
    // The async wrapper works on a clone so that the original `registered_evt_q` can be
    // returned to the caller untouched.
    let registered_evt_q_async = registered_evt_q
        .as_ref()
        .map(|q| SendTubeAsync::new(q.try_clone().unwrap(), &ex).unwrap());

    // We need a block to release all references to command_tube at the end before returning it.
    {
        // The first queue is used for inflate messages
        let inflate = handle_queue(
            &mem,
            inflate_queue.0,
            EventAsync::new(inflate_queue.1, &ex).expect("failed to create async event"),
            release_memory_tube.as_ref(),
            interrupt.clone(),
            |guest_address, len| {
                sys::free_memory(
                    &guest_address,
                    len,
                    #[cfg(windows)]
                    &dynamic_mapping_tube,
                    #[cfg(unix)]
                    &mem,
                )
            },
        );
        pin_mut!(inflate);

        // The second queue is used for deflate messages
        let deflate = handle_queue(
            &mem,
            deflate_queue.0,
            EventAsync::new(deflate_queue.1, &ex).expect("failed to create async event"),
            None,
            interrupt.clone(),
            |guest_address, len| {
                sys::reclaim_memory(
                    &guest_address,
                    len,
                    #[cfg(windows)]
                    &dynamic_mapping_tube,
                )
            },
        );
        pin_mut!(deflate);

        // The next queue is used for stats messages if VIRTIO_BALLOON_F_STATS_VQ is negotiated.
        // The message type is the id of the stats request, so we can detect if there are any stale
        // stats results that were queued during an error condition.
        let (stats_tx, stats_rx) = mpsc::channel::<u64>(1);
        let stats = if let Some((stats_queue, stats_queue_evt)) = stats_queue {
            handle_stats_queue(
                &mem,
                stats_queue,
                EventAsync::new(stats_queue_evt, &ex).expect("failed to create async event"),
                stats_rx,
                &command_tube,
                registered_evt_q_async.as_ref(),
                state.clone(),
                interrupt.clone(),
            )
            .left_future()
        } else {
            std::future::pending().right_future()
        };
        pin_mut!(stats);

        // The next queue is used for reporting messages
        let reporting = if let Some((reporting_queue, reporting_queue_evt)) = reporting_queue {
            handle_reporting_queue(
                &mem,
                reporting_queue,
                EventAsync::new(reporting_queue_evt, &ex).expect("failed to create async event"),
                release_memory_tube.as_ref(),
                interrupt.clone(),
                |guest_address, len| {
                    sys::free_memory(
                        &guest_address,
                        len,
                        #[cfg(windows)]
                        &dynamic_mapping_tube,
                        #[cfg(unix)]
                        &mem,
                    )
                },
            )
            .left_future()
        } else {
            std::future::pending().right_future()
        };
        pin_mut!(reporting);

        // If VIRTIO_BALLOON_F_WSS_REPORTING is set, 2 queues must be handled - one
        // for WSS data and one for WSS notifications.
        let wss_data = if let Some((wss_data_queue, wss_data_queue_evt)) = wss_queues.0 {
            handle_wss_data_queue(
                &mem,
                wss_data_queue,
                EventAsync::new(wss_data_queue_evt, &ex).expect("failed to create async event"),
                wss_op_tube.as_ref(),
                registered_evt_q_async.as_ref(),
                state.clone(),
                interrupt.clone(),
            )
            .left_future()
        } else {
            std::future::pending().right_future()
        };
        pin_mut!(wss_data);

        // Channel over which the WSS op tube handler passes request ids to the WSS command
        // queue handler.
        let (wss_tx, wss_rx) = mpsc::channel::<u64>(1);
        let wss_queue = if let Some((wss_cmd_queue, wss_cmd_queue_evt)) = wss_queues.1 {
            handle_wss_queue(
                &mem,
                wss_cmd_queue,
                EventAsync::new(wss_cmd_queue_evt, &ex).expect("failed to create async event"),
                wss_rx,
                state.clone(),
                interrupt.clone(),
            )
            .left_future()
        } else {
            std::future::pending().right_future()
        };
        pin_mut!(wss_queue);

        // Future to handle command messages that resize the balloon.
        let command =
            handle_command_tube(&command_tube, interrupt.clone(), state.clone(), stats_tx);
        pin_mut!(command);

        // Future to handle wss command messages for the balloon.
        let wss_op = handle_wss_op_tube(wss_op_tube.as_ref(), wss_tx);
        pin_mut!(wss_op);

        // Process any requests to resample the irq value.
        let resample = async_utils::handle_irq_resample(&ex, interrupt.clone());
        pin_mut!(resample);

        // Exit if the kill event is triggered.
        let kill = async_utils::await_and_exit(&ex, kill_evt);
        pin_mut!(kill);

        // The next queue is used for events if VIRTIO_BALLOON_F_EVENTS_VQ is negotiated.
        // This is the last user of `interrupt`, so it is moved here instead of cloned.
        let events = if let Some((events_queue, events_queue_evt)) = events_queue {
            handle_events_queue(
                &mem,
                events_queue,
                EventAsync::new(events_queue_evt, &ex).expect("failed to create async event"),
                state.clone(),
                interrupt,
                &command_tube,
            )
            .left_future()
        } else {
            std::future::pending().right_future()
        };
        pin_mut!(events);

        // Future that sends the deferred "adjusted" responses queued in `state`. This is the
        // last user of `state`, so it is moved here instead of cloned.
        let pending_adjusted = handle_pending_adjusted_responses(
            EventAsync::new(pending_adjusted_response_event, &ex)
                .expect("failed to create async event"),
            &command_tube,
            state,
        );
        pin_mut!(pending_adjusted);

        // Drive all tasks on this thread. The per-task results of the select are discarded;
        // only a failure of the executor itself is logged.
        // NOTE(review): presumably `select12` resolves as soon as any one task finishes (e.g.
        // the kill task) - confirm against cros_async.
        if let Err(e) = ex
            .run_until(select12(
                inflate,
                deflate,
                stats,
                reporting,
                command,
                wss_op,
                resample,
                kill,
                events,
                pending_adjusted,
                wss_data,
                wss_queue,
            ))
            .map(|_| ())
        {
            error!("error happened in executor: {}", e);
        }
    }

    // Convert the async wrappers back into plain tubes and return them to the caller.
    (
        release_memory_tube,
        command_tube.into(),
        wss_op_tube.map(Into::into),
        registered_evt_q,
    )
}
1089 
/// Virtio device for memory balloon inflation/deflation.
///
/// The tubes are held here while the device is inactive; `activate` moves them into the worker
/// thread and `reset` takes back the ones the worker returns.
pub struct Balloon {
    // Tube on which `BalloonTubeCommand`s are received and responses are sent.
    command_tube: Option<Tube>,
    // Optional tube on which working set size commands are received.
    wss_op_tube: Option<Tube>,
    // Windows only: tube passed to the `sys` memory free/reclaim helpers by the worker.
    #[cfg(windows)]
    dynamic_mapping_tube: Option<Tube>,
    // Optional tube used to ask CoIOMMU to unpin inflated ranges (see `Balloon::new`).
    release_memory_tube: Option<Tube>,
    // Signaled by `write_config` after it queues a deferred "adjusted" response; a clone of
    // this event is waited on by the worker.
    pending_adjusted_response_event: Event,
    // Balloon state shared between the device object and the worker tasks.
    state: Arc<AsyncMutex<BalloonState>>,
    // Feature bits offered to the driver.
    features: u64,
    // Feature bits the driver has acknowledged so far.
    acked_features: u64,
    // Running worker, if the device is activated. Stopping it yields
    // `(release_memory_tube, command_tube, wss_op_tube, registered_evt_q)`.
    worker_thread: Option<WorkerThread<(Option<Tube>, Tube, Option<Tube>, Option<SendTube>)>>,
    // Optional send-only tube handed to the stats and WSS data queue handlers.
    // NOTE(review): presumably used to publish `RegisteredEvent`s - confirm in those handlers.
    registered_evt_q: Option<SendTube>,
}
1104 
/// Operation mode of the balloon.
///
/// The mode decides which of two feature bits `Balloon::new` advertises to the driver.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum BalloonMode {
    /// The driver can access pages in the balloon (i.e. F_DEFLATE_ON_OOM)
    Relaxed,
    /// The driver cannot access pages in the balloon. Implies F_RESPONSIVE_DEVICE.
    Strict,
}
1113 
1114 impl Balloon {
1115     /// Creates a new virtio balloon device.
1116     /// To let Balloon able to successfully release the memory which are pinned
1117     /// by CoIOMMU to host, the release_memory_tube will be used to send the inflate
1118     /// ranges to CoIOMMU with UnpinRequest/UnpinResponse messages, so that The
1119     /// memory in the inflate range can be unpinned first.
new( base_features: u64, command_tube: Tube, wss_op_tube: Option<Tube>, #[cfg(windows)] dynamic_mapping_tube: Tube, release_memory_tube: Option<Tube>, init_balloon_size: u64, mode: BalloonMode, enabled_features: u64, registered_evt_q: Option<SendTube>, ) -> Result<Balloon>1120     pub fn new(
1121         base_features: u64,
1122         command_tube: Tube,
1123         wss_op_tube: Option<Tube>,
1124         #[cfg(windows)] dynamic_mapping_tube: Tube,
1125         release_memory_tube: Option<Tube>,
1126         init_balloon_size: u64,
1127         mode: BalloonMode,
1128         enabled_features: u64,
1129         registered_evt_q: Option<SendTube>,
1130     ) -> Result<Balloon> {
1131         let features = base_features
1132             | 1 << VIRTIO_BALLOON_F_MUST_TELL_HOST
1133             | 1 << VIRTIO_BALLOON_F_STATS_VQ
1134             | 1 << VIRTIO_BALLOON_F_EVENTS_VQ
1135             | enabled_features
1136             | if mode == BalloonMode::Strict {
1137                 1 << VIRTIO_BALLOON_F_RESPONSIVE_DEVICE
1138             } else {
1139                 1 << VIRTIO_BALLOON_F_DEFLATE_ON_OOM
1140             };
1141 
1142         Ok(Balloon {
1143             command_tube: Some(command_tube),
1144             wss_op_tube,
1145             #[cfg(windows)]
1146             dynamic_mapping_tube: Some(dynamic_mapping_tube),
1147             release_memory_tube,
1148             pending_adjusted_response_event: Event::new().map_err(BalloonError::CreatingEvent)?,
1149             state: Arc::new(AsyncMutex::new(BalloonState {
1150                 num_pages: (init_balloon_size >> VIRTIO_BALLOON_PFN_SHIFT) as u32,
1151                 actual_pages: 0,
1152                 failable_update: false,
1153                 pending_adjusted_responses: VecDeque::new(),
1154                 expecting_wss: false,
1155                 expected_wss_id: 0,
1156             })),
1157             worker_thread: None,
1158             features,
1159             acked_features: 0,
1160             registered_evt_q,
1161         })
1162     }
1163 
get_config(&self) -> virtio_balloon_config1164     fn get_config(&self) -> virtio_balloon_config {
1165         let state = block_on(self.state.lock());
1166         virtio_balloon_config {
1167             num_pages: state.num_pages.into(),
1168             actual: state.actual_pages.into(),
1169             // crosvm does not (currently) use free_page_hint_cmd_id or
1170             // poison_val, but they must be present in the right order and size
1171             // for the virtio-balloon driver in the guest to deserialize the
1172             // config correctly.
1173             free_page_hint_cmd_id: 0.into(),
1174             poison_val: 0.into(),
1175             wss_num_bins: (VIRTIO_BALLOON_WSS_NUM_BINS as u32).into(),
1176         }
1177     }
1178 
num_expected_queues(acked_features: u64) -> usize1179     fn num_expected_queues(acked_features: u64) -> usize {
1180         // at minimum we have inflate and deflate vqueues.
1181         let mut num_queues = 2;
1182         // stats vqueue
1183         if acked_features & (1 << VIRTIO_BALLOON_F_STATS_VQ) != 0 {
1184             num_queues += 1;
1185         }
1186         // events vqueue
1187         if acked_features & (1 << VIRTIO_BALLOON_F_EVENTS_VQ) != 0 {
1188             num_queues += 1;
1189         }
1190         // page reporting vqueue
1191         if acked_features & (1 << VIRTIO_BALLOON_F_PAGE_REPORTING) != 0 {
1192             num_queues += 1;
1193         }
1194         // working set size vqueues
1195         if acked_features & (1 << VIRTIO_BALLOON_F_WSS_REPORTING) != 0 {
1196             num_queues += 2;
1197         }
1198 
1199         num_queues
1200     }
1201 }
1202 
1203 impl VirtioDevice for Balloon {
keep_rds(&self) -> Vec<RawDescriptor>1204     fn keep_rds(&self) -> Vec<RawDescriptor> {
1205         let mut rds = Vec::new();
1206         if let Some(command_tube) = &self.command_tube {
1207             rds.push(command_tube.as_raw_descriptor());
1208         }
1209         if let Some(wss_op_tube) = &self.wss_op_tube {
1210             rds.push(wss_op_tube.as_raw_descriptor());
1211         }
1212         if let Some(release_memory_tube) = &self.release_memory_tube {
1213             rds.push(release_memory_tube.as_raw_descriptor());
1214         }
1215         if let Some(registered_evt_q) = &self.registered_evt_q {
1216             rds.push(registered_evt_q.as_raw_descriptor());
1217         }
1218         rds.push(self.pending_adjusted_response_event.as_raw_descriptor());
1219         rds
1220     }
1221 
device_type(&self) -> DeviceType1222     fn device_type(&self) -> DeviceType {
1223         DeviceType::Balloon
1224     }
1225 
queue_max_sizes(&self) -> &[u16]1226     fn queue_max_sizes(&self) -> &[u16] {
1227         QUEUE_SIZES
1228     }
1229 
read_config(&self, offset: u64, data: &mut [u8])1230     fn read_config(&self, offset: u64, data: &mut [u8]) {
1231         copy_config(data, 0, self.get_config().as_bytes(), offset);
1232     }
1233 
write_config(&mut self, offset: u64, data: &[u8])1234     fn write_config(&mut self, offset: u64, data: &[u8]) {
1235         let mut config = self.get_config();
1236         copy_config(config.as_bytes_mut(), offset, data, 0);
1237         let mut state = block_on(self.state.lock());
1238         state.actual_pages = config.actual.to_native();
1239         if state.failable_update && state.actual_pages == state.num_pages {
1240             state.failable_update = false;
1241             let num_pages = state.num_pages;
1242             state.pending_adjusted_responses.push_back(num_pages);
1243             let _ = self.pending_adjusted_response_event.signal();
1244         }
1245     }
1246 
features(&self) -> u641247     fn features(&self) -> u64 {
1248         self.features
1249     }
1250 
ack_features(&mut self, mut value: u64)1251     fn ack_features(&mut self, mut value: u64) {
1252         if value & !self.features != 0 {
1253             warn!("virtio_balloon got unknown feature ack {:x}", value);
1254             value &= self.features;
1255         }
1256         self.acked_features |= value;
1257     }
1258 
activate( &mut self, mem: GuestMemory, interrupt: Interrupt, mut queues: Vec<(Queue, Event)>, ) -> anyhow::Result<()>1259     fn activate(
1260         &mut self,
1261         mem: GuestMemory,
1262         interrupt: Interrupt,
1263         mut queues: Vec<(Queue, Event)>,
1264     ) -> anyhow::Result<()> {
1265         let expected_queues = Balloon::num_expected_queues(self.acked_features);
1266         if queues.len() != expected_queues {
1267             return Err(anyhow!(
1268                 "expected {} queues, got {}",
1269                 expected_queues,
1270                 queues.len()
1271             ));
1272         }
1273 
1274         let inflate_queue = queues.remove(0);
1275         let deflate_queue = queues.remove(0);
1276         let stats_queue = if self.acked_features & (1 << VIRTIO_BALLOON_F_STATS_VQ) != 0 {
1277             Some(queues.remove(0))
1278         } else {
1279             None
1280         };
1281         let reporting_queue = if self.acked_features & (1 << VIRTIO_BALLOON_F_PAGE_REPORTING) != 0 {
1282             Some(queues.remove(0))
1283         } else {
1284             None
1285         };
1286         let events_queue = if self.acked_features & (1 << VIRTIO_BALLOON_F_EVENTS_VQ) != 0 {
1287             Some(queues.remove(0))
1288         } else {
1289             None
1290         };
1291         let wss_queues = if self.acked_features & (1 << VIRTIO_BALLOON_F_WSS_REPORTING) != 0 {
1292             (Some(queues.remove(0)), Some(queues.remove(0)))
1293         } else {
1294             (None, None)
1295         };
1296 
1297         let state = self.state.clone();
1298 
1299         let command_tube = self.command_tube.take().unwrap();
1300 
1301         let wss_op_tube = self.wss_op_tube.take();
1302 
1303         #[cfg(windows)]
1304         let mapping_tube = self.dynamic_mapping_tube.take().unwrap();
1305         let release_memory_tube = self.release_memory_tube.take();
1306         let registered_evt_q = self.registered_evt_q.take();
1307         let pending_adjusted_response_event = self
1308             .pending_adjusted_response_event
1309             .try_clone()
1310             .context("failed to clone Event")?;
1311 
1312         self.worker_thread = Some(WorkerThread::start("v_balloon", move |kill_evt| {
1313             run_worker(
1314                 inflate_queue,
1315                 deflate_queue,
1316                 stats_queue,
1317                 reporting_queue,
1318                 events_queue,
1319                 wss_queues,
1320                 command_tube,
1321                 wss_op_tube,
1322                 #[cfg(windows)]
1323                 mapping_tube,
1324                 release_memory_tube,
1325                 interrupt,
1326                 kill_evt,
1327                 pending_adjusted_response_event,
1328                 mem,
1329                 state,
1330                 registered_evt_q,
1331             )
1332         }));
1333 
1334         Ok(())
1335     }
1336 
reset(&mut self) -> bool1337     fn reset(&mut self) -> bool {
1338         if let Some(worker_thread) = self.worker_thread.take() {
1339             let (release_memory_tube, command_tube, wss_op_tube, registered_evt_q) =
1340                 worker_thread.stop();
1341             self.release_memory_tube = release_memory_tube;
1342             self.command_tube = Some(command_tube);
1343             self.registered_evt_q = registered_evt_q;
1344             self.wss_op_tube = wss_op_tube;
1345             return true;
1346         }
1347         false
1348     }
1349 }
1350 
// Empty impl: the balloon overrides nothing and relies entirely on the default method
// implementations provided by the `Suspendable` trait.
impl Suspendable for Balloon {}
1352 
#[cfg(test)]
mod tests {
    use super::*;
    use crate::virtio::descriptor_utils::create_descriptor_chain;
    use crate::virtio::descriptor_utils::DescriptorType;

    #[test]
    fn desc_parsing_inflate() {
        // Check that the memory addresses are parsed correctly by 'handle_address_chain' and passed
        // to the closure.
        let memory = GuestMemory::new(&[(GuestAddress(0x0), 0x10000)]).unwrap();

        // Two 32-bit page frame numbers stored back to back at 0x100.
        for &(pfn, offset) in [(0x10u32, 0x100u64), (0xaa55aa55u32, 0x104u64)].iter() {
            memory
                .write_obj_at_addr(pfn, GuestAddress(offset))
                .unwrap();
        }

        // A single readable 8-byte descriptor covering both values.
        let chain = create_descriptor_chain(
            &memory,
            GuestAddress(0x0),
            GuestAddress(0x100),
            vec![(DescriptorType::Readable, 8)],
            0,
        )
        .expect("create_descriptor_chain failed");

        let mut addrs = Vec::new();
        let res = handle_address_chain(None, chain, &memory, &mut |guest_address, len| {
            addrs.push((guest_address, len));
        });
        assert!(res.is_ok());

        // Each page frame number must come back as a guest address shifted by the PFN shift.
        let expected_pfns = [0x10u64, 0xaa55aa55u64];
        assert_eq!(addrs.len(), expected_pfns.len());
        for (reported, pfn) in addrs.iter().zip(expected_pfns.iter()) {
            assert_eq!(reported.0, GuestAddress(*pfn << VIRTIO_BALLOON_PFN_SHIFT));
        }
    }

    #[test]
    fn num_expected_queues() {
        let feature_mask = |bits: &[u32]| -> u64 {
            bits.iter()
                .map(|bit| 1_u64 << *bit)
                .fold(0, |mask, flag| mask | flag)
        };

        // (expected queue count, acknowledged feature bits)
        let cases: [(usize, &[u32]); 5] = [
            (2, &[]),
            (2, &[VIRTIO_BALLOON_F_MUST_TELL_HOST]),
            (3, &[VIRTIO_BALLOON_F_STATS_VQ]),
            (
                5,
                &[
                    VIRTIO_BALLOON_F_STATS_VQ,
                    VIRTIO_BALLOON_F_EVENTS_VQ,
                    VIRTIO_BALLOON_F_PAGE_REPORTING,
                ],
            ),
            (
                7,
                &[
                    VIRTIO_BALLOON_F_STATS_VQ,
                    VIRTIO_BALLOON_F_EVENTS_VQ,
                    VIRTIO_BALLOON_F_PAGE_REPORTING,
                    VIRTIO_BALLOON_F_WSS_REPORTING,
                ],
            ),
        ];

        for &(expected, bits) in cases.iter() {
            assert_eq!(expected, Balloon::num_expected_queues(feature_mask(bits)));
        }
    }
}
1430