• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2017 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 use std::collections::BTreeMap;
6 use std::collections::VecDeque;
7 use std::io::Write;
8 use std::sync::Arc;
9 
10 use anyhow::anyhow;
11 use anyhow::Context;
12 use balloon_control::BalloonStats;
13 use balloon_control::BalloonTubeCommand;
14 use balloon_control::BalloonTubeResult;
15 use balloon_control::BalloonWS;
16 use balloon_control::WSBucket;
17 use balloon_control::VIRTIO_BALLOON_WS_MAX_NUM_BINS;
18 use balloon_control::VIRTIO_BALLOON_WS_MIN_NUM_BINS;
19 use base::debug;
20 use base::error;
21 use base::warn;
22 use base::AsRawDescriptor;
23 use base::Event;
24 use base::RawDescriptor;
25 #[cfg(feature = "registered_events")]
26 use base::SendTube;
27 use base::Tube;
28 use base::WorkerThread;
29 use cros_async::block_on;
30 use cros_async::sync::RwLock as AsyncRwLock;
31 use cros_async::AsyncTube;
32 use cros_async::EventAsync;
33 use cros_async::Executor;
34 #[cfg(feature = "registered_events")]
35 use cros_async::SendTubeAsync;
36 use data_model::Le16;
37 use data_model::Le32;
38 use data_model::Le64;
39 use futures::channel::mpsc;
40 use futures::channel::oneshot;
41 use futures::pin_mut;
42 use futures::select;
43 use futures::select_biased;
44 use futures::FutureExt;
45 use futures::StreamExt;
46 use remain::sorted;
47 use serde::Deserialize;
48 use serde::Serialize;
49 use snapshot::AnySnapshot;
50 use thiserror::Error as ThisError;
51 use vm_control::api::VmMemoryClient;
52 #[cfg(feature = "registered_events")]
53 use vm_control::RegisteredEventWithData;
54 use vm_memory::GuestAddress;
55 use vm_memory::GuestMemory;
56 use zerocopy::FromBytes;
57 use zerocopy::Immutable;
58 use zerocopy::IntoBytes;
59 use zerocopy::KnownLayout;
60 
61 use super::async_utils;
62 use super::copy_config;
63 use super::create_stop_oneshot;
64 use super::DescriptorChain;
65 use super::DeviceType;
66 use super::Interrupt;
67 use super::Queue;
68 use super::Reader;
69 use super::StoppedWorker;
70 use super::VirtioDevice;
71 use crate::UnpinRequest;
72 use crate::UnpinResponse;
73 
#[sorted]
#[derive(ThisError, Debug)]
/// Errors produced by the balloon device's worker tasks.
///
/// Variants are listed in alphabetical order (see `#[sorted]`).
pub enum BalloonError {
    /// Failed an async await; carries the underlying `cros_async::AsyncError`.
    #[error("failed async await: {0}")]
    AsyncAwait(cros_async::AsyncError),
    /// Failed an async await; carries the underlying `anyhow::Error`.
    #[error("failed async await: {0}")]
    AsyncAwaitAnyhow(anyhow::Error),
    /// Failed to create event.
    #[error("failed to create event: {0}")]
    CreatingEvent(base::Error),
    /// Failed to create async message receiver.
    #[error("failed to create async message receiver: {0}")]
    CreatingMessageReceiver(base::TubeError),
    /// Failed to receive command message.
    #[error("failed to receive command message: {0}")]
    ReceivingCommand(base::TubeError),
    /// Failed to send command response.
    #[error("failed to send command response: {0}")]
    SendResponse(base::TubeError),
    /// Error while writing to virtqueue
    #[error("failed to write to virtqueue: {0}")]
    WriteQueue(std::io::Error),
    /// Failed to write config event.
    #[error("failed to write config event: {0}")]
    WritingConfigEvent(base::Error),
}
/// Result type used throughout this file; the error is always a [`BalloonError`].
pub type Result<T> = std::result::Result<T, BalloonError>;
103 
// Balloon implements five virt IO queues: Inflate, Deflate, Stats, WsData, WsCmd.
// NOTE(review): the virtqueue indexes below run from 0 to 6 and include a page reporting
// queue, while only five sizes are listed here -- confirm how QUEUE_SIZES is sliced by its users.
const QUEUE_SIZE: u16 = 128;
const QUEUE_SIZES: &[u16] = &[QUEUE_SIZE, QUEUE_SIZE, QUEUE_SIZE, QUEUE_SIZE, QUEUE_SIZE];

// Virtqueue indexes
const INFLATEQ: usize = 0;
const DEFLATEQ: usize = 1;
const STATSQ: usize = 2;
// The leading underscore marks this index as currently unused in this file.
const _FREE_PAGE_VQ: usize = 3;
const REPORTING_VQ: usize = 4;
const WS_DATA_VQ: usize = 5;
const WS_OP_VQ: usize = 6;

// A page frame number is shifted left by this many bits to obtain a guest address.
const VIRTIO_BALLOON_PFN_SHIFT: u32 = 12;
// Size in bytes of one balloon page (1 << VIRTIO_BALLOON_PFN_SHIFT).
const VIRTIO_BALLOON_PF_SIZE: u64 = 1 << VIRTIO_BALLOON_PFN_SHIFT;

// The feature bitmap for virtio balloon
const VIRTIO_BALLOON_F_MUST_TELL_HOST: u32 = 0; // Tell before reclaiming pages
const VIRTIO_BALLOON_F_STATS_VQ: u32 = 1; // Stats reporting enabled
const VIRTIO_BALLOON_F_DEFLATE_ON_OOM: u32 = 2; // Deflate balloon on OOM
const VIRTIO_BALLOON_F_PAGE_REPORTING: u32 = 5; // Page reporting virtqueue
// TODO(b/273973298): the WS reporting bit should maybe be bit 6? to be changed later
const VIRTIO_BALLOON_F_WS_REPORTING: u32 = 8; // Working Set Reporting virtqueues
127 
#[derive(Copy, Clone)]
#[repr(u32)]
/// Balloon features that can be referred to by name. Each discriminant is the matching
/// `VIRTIO_BALLOON_F_*` feature bit number.
pub enum BalloonFeatures {
    /// Page Reporting enabled
    PageReporting = VIRTIO_BALLOON_F_PAGE_REPORTING,
    /// WS Reporting enabled
    WSReporting = VIRTIO_BALLOON_F_WS_REPORTING,
}
137 
// virtio_balloon_config is the balloon device configuration space defined by the virtio spec.
// The struct is `repr(C)` and derives the zerocopy traits so it can be viewed as raw bytes;
// multi-byte fields are stored little-endian (`Le32`).
#[derive(Copy, Clone, Debug, Default, FromBytes, Immutable, IntoBytes, KnownLayout)]
#[repr(C)]
struct virtio_balloon_config {
    num_pages: Le32,
    actual: Le32,
    free_page_hint_cmd_id: Le32,
    poison_val: Le32,
    // WS field is part of proposed spec extension (b/273973298).
    ws_num_bins: u8,
    // Explicit padding after the single-byte `ws_num_bins` field.
    _reserved: [u8; 3],
}
150 
// BalloonState is shared by the worker and device thread.
#[derive(Clone, Default, Serialize, Deserialize)]
struct BalloonState {
    // Balloon size in pages most recently requested through an Adjust command.
    num_pages: u32,
    // Current balloon size in pages; reported (converted to bytes) alongside stats and
    // working set results.
    actual_pages: u32,
    // Set when a working set report has been requested from the guest and cleared once the
    // report has been forwarded on the command tube.
    expecting_ws: bool,
    // Flag indicating that the balloon is in the process of a failable update. This
    // is set by an Adjust command that has allow_failure set, and is cleared when the
    // Adjusted success/failure response is sent.
    failable_update: bool,
    // Page counts for which an Adjusted response still has to be sent on the command tube.
    pending_adjusted_responses: VecDeque<u32>,
}
163 
// The constants defining stats types in virtio_balloon_stat. Each value is compared against
// the `tag` field of a `BalloonStat` record to decide which statistic the record carries.
const VIRTIO_BALLOON_S_SWAP_IN: u16 = 0;
const VIRTIO_BALLOON_S_SWAP_OUT: u16 = 1;
const VIRTIO_BALLOON_S_MAJFLT: u16 = 2;
const VIRTIO_BALLOON_S_MINFLT: u16 = 3;
const VIRTIO_BALLOON_S_MEMFREE: u16 = 4;
const VIRTIO_BALLOON_S_MEMTOT: u16 = 5;
const VIRTIO_BALLOON_S_AVAIL: u16 = 6;
const VIRTIO_BALLOON_S_CACHES: u16 = 7;
const VIRTIO_BALLOON_S_HTLB_PGALLOC: u16 = 8;
const VIRTIO_BALLOON_S_HTLB_PGFAIL: u16 = 9;
// Non-standard tags, taken from the top of the u16 range.
const VIRTIO_BALLOON_S_NONSTANDARD_SHMEM: u16 = 65534;
const VIRTIO_BALLOON_S_NONSTANDARD_UNEVICTABLE: u16 = 65535;
177 
// BalloonStat is used to deserialize stats from the stats_queue.
// The struct is packed, so the 16-bit tag is immediately followed by the 64-bit value with no
// padding in between.
#[derive(Copy, Clone, FromBytes, Immutable, IntoBytes, KnownLayout)]
#[repr(C, packed)]
struct BalloonStat {
    // One of the VIRTIO_BALLOON_S_* constants identifying the statistic.
    tag: Le16,
    // Value of the statistic.
    val: Le64,
}
185 
186 impl BalloonStat {
update_stats(&self, stats: &mut BalloonStats)187     fn update_stats(&self, stats: &mut BalloonStats) {
188         let val = Some(self.val.to_native());
189         match self.tag.to_native() {
190             VIRTIO_BALLOON_S_SWAP_IN => stats.swap_in = val,
191             VIRTIO_BALLOON_S_SWAP_OUT => stats.swap_out = val,
192             VIRTIO_BALLOON_S_MAJFLT => stats.major_faults = val,
193             VIRTIO_BALLOON_S_MINFLT => stats.minor_faults = val,
194             VIRTIO_BALLOON_S_MEMFREE => stats.free_memory = val,
195             VIRTIO_BALLOON_S_MEMTOT => stats.total_memory = val,
196             VIRTIO_BALLOON_S_AVAIL => stats.available_memory = val,
197             VIRTIO_BALLOON_S_CACHES => stats.disk_caches = val,
198             VIRTIO_BALLOON_S_HTLB_PGALLOC => stats.hugetlb_allocations = val,
199             VIRTIO_BALLOON_S_HTLB_PGFAIL => stats.hugetlb_failures = val,
200             VIRTIO_BALLOON_S_NONSTANDARD_SHMEM => stats.shared_memory = val,
201             VIRTIO_BALLOON_S_NONSTANDARD_UNEVICTABLE => stats.unevictable_memory = val,
202             _ => (),
203         }
204     }
205 }
206 
// virtio_balloon_ws is used to deserialize from the ws data vq. Each record read from the
// queue becomes one working set bucket (idle age plus two memory sizes).
#[repr(C)]
#[derive(Copy, Clone, Debug, Default, FromBytes, Immutable, IntoBytes, KnownLayout)]
struct virtio_balloon_ws {
    tag: Le16,
    node_id: Le16,
    // virtio prefers field members to align on a word boundary so we must pad. see:
    // https://crsrc.org/o/src/third_party/kernel/v5.15/include/uapi/linux/virtio_balloon.h;l=105
    _reserved: [u8; 4],
    // Idle age of the bucket, in milliseconds (per the field name).
    idle_age_ms: Le64,
    // TODO(b/273973298): these should become separate fields - bytes for ANON and FILE
    memory_size_bytes: [Le64; 2],
}
220 
221 impl virtio_balloon_ws {
update_ws(&self, ws: &mut BalloonWS)222     fn update_ws(&self, ws: &mut BalloonWS) {
223         let bucket = WSBucket {
224             age: self.idle_age_ms.to_native(),
225             bytes: [
226                 self.memory_size_bytes[0].to_native(),
227                 self.memory_size_bytes[1].to_native(),
228             ],
229         };
230         ws.ws.push(bucket);
231     }
232 }
233 
// Operation codes written into `virtio_balloon_op::type_` on the ws op queue. The
// underscore-prefixed values are not used by the code in this part of the file.
const _VIRTIO_BALLOON_WS_OP_INVALID: u16 = 0;
// Asks the guest for a working set report.
const VIRTIO_BALLOON_WS_OP_REQUEST: u16 = 1;
// Sends working set configuration (bins and thresholds) to the guest.
const VIRTIO_BALLOON_WS_OP_CONFIG: u16 = 2;
const _VIRTIO_BALLOON_WS_OP_DISCARD: u16 = 3;
238 
// virtio_balloon_op is used to serialize to the ws cmd vq. It is the header written first for
// every operation; any operation-specific payload follows it in the same buffer.
#[repr(C, packed)]
#[derive(Copy, Clone, Debug, Default, FromBytes, Immutable, IntoBytes, KnownLayout)]
struct virtio_balloon_op {
    // One of the VIRTIO_BALLOON_WS_OP_* constants.
    type_: Le16,
}
245 
invoke_desc_handler<F>(ranges: Vec<(u64, u64)>, desc_handler: &mut F) where F: FnMut(Vec<(GuestAddress, u64)>),246 fn invoke_desc_handler<F>(ranges: Vec<(u64, u64)>, desc_handler: &mut F)
247 where
248     F: FnMut(Vec<(GuestAddress, u64)>),
249 {
250     desc_handler(
251         ranges
252             .into_iter()
253             .map(|range| (GuestAddress(range.0), range.1))
254             .collect(),
255     );
256 }
257 
258 // Release a list of guest memory ranges back to the host system.
259 // Unpin requests for each inflate range will be sent via `release_memory_tube`
260 // if provided, and then `desc_handler` will be called for each inflate range.
release_ranges<F>( release_memory_tube: Option<&Tube>, inflate_ranges: Vec<(u64, u64)>, desc_handler: &mut F, ) -> anyhow::Result<()> where F: FnMut(Vec<(GuestAddress, u64)>),261 fn release_ranges<F>(
262     release_memory_tube: Option<&Tube>,
263     inflate_ranges: Vec<(u64, u64)>,
264     desc_handler: &mut F,
265 ) -> anyhow::Result<()>
266 where
267     F: FnMut(Vec<(GuestAddress, u64)>),
268 {
269     if let Some(tube) = release_memory_tube {
270         let unpin_ranges = inflate_ranges
271             .iter()
272             .map(|v| {
273                 (
274                     v.0 >> VIRTIO_BALLOON_PFN_SHIFT,
275                     v.1 / VIRTIO_BALLOON_PF_SIZE,
276                 )
277             })
278             .collect();
279         let req = UnpinRequest {
280             ranges: unpin_ranges,
281         };
282         if let Err(e) = tube.send(&req) {
283             error!("failed to send unpin request: {}", e);
284         } else {
285             match tube.recv() {
286                 Ok(resp) => match resp {
287                     UnpinResponse::Success => invoke_desc_handler(inflate_ranges, desc_handler),
288                     UnpinResponse::Failed => error!("failed to handle unpin request"),
289                 },
290                 Err(e) => error!("failed to handle get unpin response: {}", e),
291             }
292         }
293     } else {
294         invoke_desc_handler(inflate_ranges, desc_handler);
295     }
296 
297     Ok(())
298 }
299 
300 // Processes one message's list of addresses.
handle_address_chain<F>( release_memory_tube: Option<&Tube>, avail_desc: &mut DescriptorChain, desc_handler: &mut F, ) -> anyhow::Result<()> where F: FnMut(Vec<(GuestAddress, u64)>),301 fn handle_address_chain<F>(
302     release_memory_tube: Option<&Tube>,
303     avail_desc: &mut DescriptorChain,
304     desc_handler: &mut F,
305 ) -> anyhow::Result<()>
306 where
307     F: FnMut(Vec<(GuestAddress, u64)>),
308 {
309     // In a long-running system, there is no reason to expect that
310     // a significant number of freed pages are consecutive. However,
311     // batching is relatively simple and can result in significant
312     // gains in a newly booted system, so it's worth attempting.
313     let mut range_start = 0;
314     let mut range_size = 0;
315     let mut inflate_ranges: Vec<(u64, u64)> = Vec::new();
316     for res in avail_desc.reader.iter::<Le32>() {
317         let pfn = match res {
318             Ok(pfn) => pfn,
319             Err(e) => {
320                 error!("error while reading unused pages: {}", e);
321                 break;
322             }
323         };
324         let guest_address = (u64::from(pfn.to_native())) << VIRTIO_BALLOON_PFN_SHIFT;
325         if range_start + range_size == guest_address {
326             range_size += VIRTIO_BALLOON_PF_SIZE;
327         } else if range_start == guest_address + VIRTIO_BALLOON_PF_SIZE {
328             range_start = guest_address;
329             range_size += VIRTIO_BALLOON_PF_SIZE;
330         } else {
331             // Discontinuity, so flush the previous range. Note range_size
332             // will be 0 on the first iteration, so skip that.
333             if range_size != 0 {
334                 inflate_ranges.push((range_start, range_size));
335             }
336             range_start = guest_address;
337             range_size = VIRTIO_BALLOON_PF_SIZE;
338         }
339     }
340     if range_size != 0 {
341         inflate_ranges.push((range_start, range_size));
342     }
343 
344     release_ranges(release_memory_tube, inflate_ranges, desc_handler)
345 }
346 
347 // Async task that handles the main balloon inflate and deflate queues.
handle_queue<F>( mut queue: Queue, mut queue_event: EventAsync, release_memory_tube: Option<&Tube>, mut desc_handler: F, mut stop_rx: oneshot::Receiver<()>, ) -> Queue where F: FnMut(Vec<(GuestAddress, u64)>),348 async fn handle_queue<F>(
349     mut queue: Queue,
350     mut queue_event: EventAsync,
351     release_memory_tube: Option<&Tube>,
352     mut desc_handler: F,
353     mut stop_rx: oneshot::Receiver<()>,
354 ) -> Queue
355 where
356     F: FnMut(Vec<(GuestAddress, u64)>),
357 {
358     loop {
359         let mut avail_desc = match queue
360             .next_async_interruptable(&mut queue_event, &mut stop_rx)
361             .await
362         {
363             Ok(Some(res)) => res,
364             Ok(None) => return queue,
365             Err(e) => {
366                 error!("Failed to read descriptor {}", e);
367                 return queue;
368             }
369         };
370         if let Err(e) =
371             handle_address_chain(release_memory_tube, &mut avail_desc, &mut desc_handler)
372         {
373             error!("balloon: failed to process inflate addresses: {}", e);
374         }
375         queue.add_used(avail_desc, 0);
376         queue.trigger_interrupt();
377     }
378 }
379 
380 // Processes one page-reporting descriptor.
handle_reported_buffer<F>( release_memory_tube: Option<&Tube>, avail_desc: &DescriptorChain, desc_handler: &mut F, ) -> anyhow::Result<()> where F: FnMut(Vec<(GuestAddress, u64)>),381 fn handle_reported_buffer<F>(
382     release_memory_tube: Option<&Tube>,
383     avail_desc: &DescriptorChain,
384     desc_handler: &mut F,
385 ) -> anyhow::Result<()>
386 where
387     F: FnMut(Vec<(GuestAddress, u64)>),
388 {
389     let reported_ranges: Vec<(u64, u64)> = avail_desc
390         .reader
391         .get_remaining_regions()
392         .chain(avail_desc.writer.get_remaining_regions())
393         .map(|r| (r.offset, r.len as u64))
394         .collect();
395 
396     release_ranges(release_memory_tube, reported_ranges, desc_handler)
397 }
398 
399 // Async task that handles the page reporting queue.
handle_reporting_queue<F>( mut queue: Queue, mut queue_event: EventAsync, release_memory_tube: Option<&Tube>, mut desc_handler: F, mut stop_rx: oneshot::Receiver<()>, ) -> Queue where F: FnMut(Vec<(GuestAddress, u64)>),400 async fn handle_reporting_queue<F>(
401     mut queue: Queue,
402     mut queue_event: EventAsync,
403     release_memory_tube: Option<&Tube>,
404     mut desc_handler: F,
405     mut stop_rx: oneshot::Receiver<()>,
406 ) -> Queue
407 where
408     F: FnMut(Vec<(GuestAddress, u64)>),
409 {
410     loop {
411         let avail_desc = match queue
412             .next_async_interruptable(&mut queue_event, &mut stop_rx)
413             .await
414         {
415             Ok(Some(res)) => res,
416             Ok(None) => return queue,
417             Err(e) => {
418                 error!("Failed to read descriptor {}", e);
419                 return queue;
420             }
421         };
422         if let Err(e) = handle_reported_buffer(release_memory_tube, &avail_desc, &mut desc_handler)
423         {
424             error!("balloon: failed to process reported buffer: {}", e);
425         }
426         queue.add_used(avail_desc, 0);
427         queue.trigger_interrupt();
428     }
429 }
430 
parse_balloon_stats(reader: &mut Reader) -> BalloonStats431 fn parse_balloon_stats(reader: &mut Reader) -> BalloonStats {
432     let mut stats: BalloonStats = Default::default();
433     for res in reader.iter::<BalloonStat>() {
434         match res {
435             Ok(stat) => stat.update_stats(&mut stats),
436             Err(e) => {
437                 error!("error while reading stats: {}", e);
438                 break;
439             }
440         };
441     }
442     stats
443 }
444 
// Async task that handles the stats queue. Note that the cadence of this is driven by requests for
// balloon stats from the control pipe.
// The guest queues an initial buffer on boot, which is read and then this future will block until
// signaled from the command socket that stats should be collected again.
//
// For every signal received on `stats_rx` the held descriptor is returned to the guest, the next
// descriptor is awaited and parsed, and the stats (together with the current balloon size in
// bytes) are sent on `command_tube`. Returns the queue when stopped, when the signal channel
// closes, or when reading a descriptor fails.
async fn handle_stats_queue(
    mut queue: Queue,
    mut queue_event: EventAsync,
    mut stats_rx: mpsc::Receiver<()>,
    command_tube: &AsyncTube,
    #[cfg(feature = "registered_events")] registered_evt_q: Option<&SendTubeAsync>,
    state: Arc<AsyncRwLock<BalloonState>>,
    mut stop_rx: oneshot::Receiver<()>,
) -> Queue {
    let mut avail_desc = match queue
        .next_async_interruptable(&mut queue_event, &mut stop_rx)
        .await
    {
        // Consume the first stats buffer sent from the guest at startup. It was not
        // requested by anyone, and the stats are stale.
        Ok(Some(res)) => res,
        Ok(None) => return queue,
        Err(e) => {
            error!("Failed to read descriptor {}", e);
            return queue;
        }
    };

    loop {
        // `select_biased!` polls its arms in order, so a pending stats request is picked up
        // before the stop signal is checked.
        select_biased! {
            msg = stats_rx.next() => {
                // Wait for a request to read the stats.
                match msg {
                    Some(()) => (),
                    None => {
                        error!("stats signal channel was closed");
                        return queue;
                    }
                }
            }
            _ = stop_rx => return queue,
        };

        // Request a new stats_desc to the guest.
        queue.add_used(avail_desc, 0);
        queue.trigger_interrupt();

        // NOTE(review): unlike the initial wait above, this wait does not observe `stop_rx`,
        // so a stop request is only seen once the guest has supplied the next stats buffer.
        avail_desc = match queue.next_async(&mut queue_event).await {
            Err(e) => {
                error!("Failed to read descriptor {}", e);
                return queue;
            }
            Ok(d) => d,
        };
        let stats = parse_balloon_stats(&mut avail_desc.reader);

        // The lock guard is a temporary of this statement, so it is released before sending.
        let actual_pages = state.lock().await.actual_pages as u64;
        let result = BalloonTubeResult::Stats {
            balloon_actual: actual_pages << VIRTIO_BALLOON_PFN_SHIFT,
            stats,
        };
        let send_result = command_tube.send(result).await;
        if let Err(e) = send_result {
            error!("failed to send stats result: {}", e);
        }

        // When registered events are compiled in, also emit a resize event after each report.
        #[cfg(feature = "registered_events")]
        if let Some(registered_evt_q) = registered_evt_q {
            if let Err(e) = registered_evt_q
                .send(&RegisteredEventWithData::VirtioBalloonResize)
                .await
            {
                error!("failed to send VirtioBalloonResize event: {}", e);
            }
        }
    }
}
521 
send_adjusted_response( tube: &AsyncTube, num_pages: u32, ) -> std::result::Result<(), base::TubeError>522 async fn send_adjusted_response(
523     tube: &AsyncTube,
524     num_pages: u32,
525 ) -> std::result::Result<(), base::TubeError> {
526     let num_bytes = (num_pages as u64) << VIRTIO_BALLOON_PFN_SHIFT;
527     let result = BalloonTubeResult::Adjusted { num_bytes };
528     tube.send(result).await
529 }
530 
// Working set operations forwarded from the command tube handler to the ws op queue handler.
enum WSOp {
    // Ask the guest for a working set report.
    WSReport,
    // Send working set configuration to the guest; the fields are written to the queue in the
    // order bins, refresh_threshold, report_threshold.
    WSConfig {
        bins: Vec<u32>,
        refresh_threshold: u32,
        report_threshold: u32,
    },
}
539 
handle_ws_op_queue( mut queue: Queue, mut queue_event: EventAsync, mut ws_op_rx: mpsc::Receiver<WSOp>, state: Arc<AsyncRwLock<BalloonState>>, mut stop_rx: oneshot::Receiver<()>, ) -> Result<Queue>540 async fn handle_ws_op_queue(
541     mut queue: Queue,
542     mut queue_event: EventAsync,
543     mut ws_op_rx: mpsc::Receiver<WSOp>,
544     state: Arc<AsyncRwLock<BalloonState>>,
545     mut stop_rx: oneshot::Receiver<()>,
546 ) -> Result<Queue> {
547     loop {
548         let op = select_biased! {
549             next_op = ws_op_rx.next().fuse() => {
550                 match next_op {
551                     Some(op) => op,
552                     None => {
553                         error!("ws op tube was closed");
554                         break;
555                     }
556                 }
557             }
558             _ = stop_rx => {
559                 break;
560             }
561         };
562         let mut avail_desc = queue
563             .next_async(&mut queue_event)
564             .await
565             .map_err(BalloonError::AsyncAwait)?;
566         let writer = &mut avail_desc.writer;
567 
568         match op {
569             WSOp::WSReport => {
570                 {
571                     let mut state = state.lock().await;
572                     state.expecting_ws = true;
573                 }
574 
575                 let ws_r = virtio_balloon_op {
576                     type_: VIRTIO_BALLOON_WS_OP_REQUEST.into(),
577                 };
578 
579                 writer.write_obj(ws_r).map_err(BalloonError::WriteQueue)?;
580             }
581             WSOp::WSConfig {
582                 bins,
583                 refresh_threshold,
584                 report_threshold,
585             } => {
586                 let cmd = virtio_balloon_op {
587                     type_: VIRTIO_BALLOON_WS_OP_CONFIG.into(),
588                 };
589 
590                 writer.write_obj(cmd).map_err(BalloonError::WriteQueue)?;
591                 writer
592                     .write_all(bins.as_bytes())
593                     .map_err(BalloonError::WriteQueue)?;
594                 writer
595                     .write_obj(refresh_threshold)
596                     .map_err(BalloonError::WriteQueue)?;
597                 writer
598                     .write_obj(report_threshold)
599                     .map_err(BalloonError::WriteQueue)?;
600             }
601         }
602 
603         let len = writer.bytes_written() as u32;
604         queue.add_used(avail_desc, len);
605         queue.trigger_interrupt();
606     }
607 
608     Ok(queue)
609 }
610 
parse_balloon_ws(reader: &mut Reader) -> BalloonWS611 fn parse_balloon_ws(reader: &mut Reader) -> BalloonWS {
612     let mut ws = BalloonWS::new();
613     for res in reader.iter::<virtio_balloon_ws>() {
614         match res {
615             Ok(ws_msg) => {
616                 ws_msg.update_ws(&mut ws);
617             }
618             Err(e) => {
619                 error!("error while reading ws: {}", e);
620                 break;
621             }
622         }
623     }
624     if ws.ws.len() < VIRTIO_BALLOON_WS_MIN_NUM_BINS || ws.ws.len() > VIRTIO_BALLOON_WS_MAX_NUM_BINS
625     {
626         error!("unexpected number of WS buckets: {}", ws.ws.len());
627     }
628     ws
629 }
630 
631 // Async task that handles the stats queue. Note that the arrival of events on
632 // the WS vq may be the result of either a WS request (WS-R) command having
633 // been sent to the guest, or an unprompted send due to memory pressue in the
634 // guest. If the data was requested, we should also send that back on the
635 // command tube.
handle_ws_data_queue( mut queue: Queue, mut queue_event: EventAsync, command_tube: &AsyncTube, #[cfg(feature = "registered_events")] registered_evt_q: Option<&SendTubeAsync>, state: Arc<AsyncRwLock<BalloonState>>, mut stop_rx: oneshot::Receiver<()>, ) -> Result<Queue>636 async fn handle_ws_data_queue(
637     mut queue: Queue,
638     mut queue_event: EventAsync,
639     command_tube: &AsyncTube,
640     #[cfg(feature = "registered_events")] registered_evt_q: Option<&SendTubeAsync>,
641     state: Arc<AsyncRwLock<BalloonState>>,
642     mut stop_rx: oneshot::Receiver<()>,
643 ) -> Result<Queue> {
644     loop {
645         let mut avail_desc = match queue
646             .next_async_interruptable(&mut queue_event, &mut stop_rx)
647             .await
648             .map_err(BalloonError::AsyncAwait)?
649         {
650             Some(res) => res,
651             None => return Ok(queue),
652         };
653 
654         let ws = parse_balloon_ws(&mut avail_desc.reader);
655 
656         let mut state = state.lock().await;
657 
658         // update ws report with balloon pages now that we have a lock on state
659         let balloon_actual = (state.actual_pages as u64) << VIRTIO_BALLOON_PFN_SHIFT;
660 
661         if state.expecting_ws {
662             let result = BalloonTubeResult::WorkingSet { ws, balloon_actual };
663             let send_result = command_tube.send(result).await;
664             if let Err(e) = send_result {
665                 error!("failed to send ws result: {}", e);
666             }
667 
668             state.expecting_ws = false;
669         } else {
670             #[cfg(feature = "registered_events")]
671             if let Some(registered_evt_q) = registered_evt_q {
672                 if let Err(e) = registered_evt_q
673                     .send(RegisteredEventWithData::from_ws(&ws, balloon_actual))
674                     .await
675                 {
676                     error!("failed to send VirtioBalloonWSReport event: {}", e);
677                 }
678             }
679         }
680 
681         queue.add_used(avail_desc, 0);
682         queue.trigger_interrupt();
683     }
684 }
685 
686 // Async task that handles the command socket. The command socket handles messages from the host
687 // requesting that the guest balloon be adjusted or to report guest memory statistics.
handle_command_tube( command_tube: &AsyncTube, interrupt: Interrupt, state: Arc<AsyncRwLock<BalloonState>>, mut stats_tx: mpsc::Sender<()>, mut ws_op_tx: mpsc::Sender<WSOp>, mut stop_rx: oneshot::Receiver<()>, ) -> Result<()>688 async fn handle_command_tube(
689     command_tube: &AsyncTube,
690     interrupt: Interrupt,
691     state: Arc<AsyncRwLock<BalloonState>>,
692     mut stats_tx: mpsc::Sender<()>,
693     mut ws_op_tx: mpsc::Sender<WSOp>,
694     mut stop_rx: oneshot::Receiver<()>,
695 ) -> Result<()> {
696     loop {
697         let cmd_res = select_biased! {
698             res = command_tube.next().fuse() => res,
699             _ = stop_rx => return Ok(())
700         };
701         match cmd_res {
702             Ok(command) => match command {
703                 BalloonTubeCommand::Adjust {
704                     num_bytes,
705                     allow_failure,
706                 } => {
707                     let num_pages = (num_bytes >> VIRTIO_BALLOON_PFN_SHIFT) as u32;
708                     let mut state = state.lock().await;
709 
710                     state.num_pages = num_pages;
711                     interrupt.signal_config_changed();
712 
713                     if allow_failure {
714                         if num_pages == state.actual_pages {
715                             send_adjusted_response(command_tube, num_pages)
716                                 .await
717                                 .map_err(BalloonError::SendResponse)?;
718                         } else {
719                             state.failable_update = true;
720                         }
721                     }
722                 }
723                 BalloonTubeCommand::WorkingSetConfig {
724                     bins,
725                     refresh_threshold,
726                     report_threshold,
727                 } => {
728                     if let Err(e) = ws_op_tx.try_send(WSOp::WSConfig {
729                         bins,
730                         refresh_threshold,
731                         report_threshold,
732                     }) {
733                         error!("failed to send config to ws handler: {}", e);
734                     }
735                 }
736                 BalloonTubeCommand::Stats => {
737                     if let Err(e) = stats_tx.try_send(()) {
738                         error!("failed to signal the stat handler: {}", e);
739                     }
740                 }
741                 BalloonTubeCommand::WorkingSet => {
742                     if let Err(e) = ws_op_tx.try_send(WSOp::WSReport) {
743                         error!("failed to send report request to ws handler: {}", e);
744                     }
745                 }
746             },
747             #[cfg(windows)]
748             Err(base::TubeError::Recv(e)) if e.kind() == std::io::ErrorKind::TimedOut => {
749                 // On Windows, async IO tasks like the next/recv above are cancelled as the VM is
750                 // shutting down. For the sake of consistency with unix, we can't *just* return
751                 // here; instead, we wait for the stop request to arrive, *and then* return.
752                 //
753                 // The real fix is to get rid of the global unblock pool, since then we won't
754                 // cancel the tasks early (b/196911556).
755                 let _ = stop_rx.await;
756                 return Ok(());
757             }
758             Err(e) => {
759                 return Err(BalloonError::ReceivingCommand(e));
760             }
761         }
762     }
763 }
764 
handle_pending_adjusted_responses( pending_adjusted_response_event: EventAsync, command_tube: &AsyncTube, state: Arc<AsyncRwLock<BalloonState>>, ) -> Result<()>765 async fn handle_pending_adjusted_responses(
766     pending_adjusted_response_event: EventAsync,
767     command_tube: &AsyncTube,
768     state: Arc<AsyncRwLock<BalloonState>>,
769 ) -> Result<()> {
770     loop {
771         pending_adjusted_response_event
772             .next_val()
773             .await
774             .map_err(BalloonError::AsyncAwait)?;
775         while let Some(num_pages) = state.lock().await.pending_adjusted_responses.pop_front() {
776             send_adjusted_response(command_tube, num_pages)
777                 .await
778                 .map_err(BalloonError::SendResponse)?;
779         }
780     }
781 }
782 
783 /// Represents queues & events for the balloon device.
784 struct BalloonQueues {
785     inflate: Queue,
786     deflate: Queue,
787     stats: Option<Queue>,
788     reporting: Option<Queue>,
789     ws_data: Option<Queue>,
790     ws_op: Option<Queue>,
791 }
792 
793 impl BalloonQueues {
new(inflate: Queue, deflate: Queue) -> Self794     fn new(inflate: Queue, deflate: Queue) -> Self {
795         BalloonQueues {
796             inflate,
797             deflate,
798             stats: None,
799             reporting: None,
800             ws_data: None,
801             ws_op: None,
802         }
803     }
804 }
805 
806 /// When the worker is stopped, the queues are preserved here.
807 struct PausedQueues {
808     inflate: Queue,
809     deflate: Queue,
810     stats: Option<Queue>,
811     reporting: Option<Queue>,
812     ws_data: Option<Queue>,
813     ws_op: Option<Queue>,
814 }
815 
816 impl PausedQueues {
new(inflate: Queue, deflate: Queue) -> Self817     fn new(inflate: Queue, deflate: Queue) -> Self {
818         PausedQueues {
819             inflate,
820             deflate,
821             stats: None,
822             reporting: None,
823             ws_data: None,
824             ws_op: None,
825         }
826     }
827 }
828 
apply_if_some<F, R>(queue_opt: Option<Queue>, mut func: F) where F: FnMut(Queue) -> R,829 fn apply_if_some<F, R>(queue_opt: Option<Queue>, mut func: F)
830 where
831     F: FnMut(Queue) -> R,
832 {
833     if let Some(queue) = queue_opt {
834         func(queue);
835     }
836 }
837 
838 impl From<Box<PausedQueues>> for BTreeMap<usize, Queue> {
from(queues: Box<PausedQueues>) -> BTreeMap<usize, Queue>839     fn from(queues: Box<PausedQueues>) -> BTreeMap<usize, Queue> {
840         let mut ret = Vec::new();
841         ret.push(queues.inflate);
842         ret.push(queues.deflate);
843         apply_if_some(queues.stats, |stats| ret.push(stats));
844         apply_if_some(queues.reporting, |reporting| ret.push(reporting));
845         apply_if_some(queues.ws_data, |ws_data| ret.push(ws_data));
846         apply_if_some(queues.ws_op, |ws_op| ret.push(ws_op));
847         // WARNING: We don't use the indices from the virito spec on purpose, see comment in
848         // get_queues_from_map for the rationale.
849         ret.into_iter().enumerate().collect()
850     }
851 }
852 
free_memory( vm_memory_client: &VmMemoryClient, mem: &GuestMemory, ranges: Vec<(GuestAddress, u64)>, )853 fn free_memory(
854     vm_memory_client: &VmMemoryClient,
855     mem: &GuestMemory,
856     ranges: Vec<(GuestAddress, u64)>,
857 ) {
858     // When `--lock-guest-memory` is used, it is not possible to free the memory from the main
859     // process, so we free it from the sandboxed balloon process directly.
860     #[cfg(any(target_os = "android", target_os = "linux"))]
861     if mem.locked() {
862         for (guest_address, len) in ranges {
863             if let Err(e) = mem.remove_range(guest_address, len) {
864                 warn!("Marking pages unused failed: {}, addr={}", e, guest_address);
865             }
866         }
867         return;
868     }
869     if let Err(e) = vm_memory_client.dynamically_free_memory_ranges(ranges) {
870         warn!("Failed to dynamically free memory ranges: {e:#}");
871     }
872 }
873 
reclaim_memory(vm_memory_client: &VmMemoryClient, ranges: Vec<(GuestAddress, u64)>)874 fn reclaim_memory(vm_memory_client: &VmMemoryClient, ranges: Vec<(GuestAddress, u64)>) {
875     if let Err(e) = vm_memory_client.dynamically_reclaim_memory_ranges(ranges) {
876         warn!("Failed to dynamically reclaim memory range: {e:#}");
877     }
878 }
879 
880 /// Stores data from the worker when it stops so that data can be re-used when
881 /// the worker is restarted.
882 struct WorkerReturn {
883     release_memory_tube: Option<Tube>,
884     command_tube: Tube,
885     #[cfg(feature = "registered_events")]
886     registered_evt_q: Option<SendTube>,
887     paused_queues: Option<PausedQueues>,
888     vm_memory_client: VmMemoryClient,
889 }
890 
891 // The main worker thread. Initialized the asynchronous worker tasks and passes them to the executor
892 // to be processed.
run_worker( inflate_queue: Queue, deflate_queue: Queue, stats_queue: Option<Queue>, reporting_queue: Option<Queue>, ws_data_queue: Option<Queue>, ws_op_queue: Option<Queue>, command_tube: Tube, vm_memory_client: VmMemoryClient, mem: GuestMemory, release_memory_tube: Option<Tube>, interrupt: Interrupt, kill_evt: Event, target_reached_evt: Event, pending_adjusted_response_event: Event, state: Arc<AsyncRwLock<BalloonState>>, #[cfg(feature = "registered_events")] registered_evt_q: Option<SendTube>, ) -> WorkerReturn893 fn run_worker(
894     inflate_queue: Queue,
895     deflate_queue: Queue,
896     stats_queue: Option<Queue>,
897     reporting_queue: Option<Queue>,
898     ws_data_queue: Option<Queue>,
899     ws_op_queue: Option<Queue>,
900     command_tube: Tube,
901     vm_memory_client: VmMemoryClient,
902     mem: GuestMemory,
903     release_memory_tube: Option<Tube>,
904     interrupt: Interrupt,
905     kill_evt: Event,
906     target_reached_evt: Event,
907     pending_adjusted_response_event: Event,
908     state: Arc<AsyncRwLock<BalloonState>>,
909     #[cfg(feature = "registered_events")] registered_evt_q: Option<SendTube>,
910 ) -> WorkerReturn {
911     let ex = Executor::new().unwrap();
912     let command_tube = AsyncTube::new(&ex, command_tube).unwrap();
913     #[cfg(feature = "registered_events")]
914     let registered_evt_q_async = registered_evt_q
915         .as_ref()
916         .map(|q| SendTubeAsync::new(q.try_clone().unwrap(), &ex).unwrap());
917 
918     let mut stop_queue_oneshots = Vec::new();
919 
920     // We need a block to release all references to command_tube at the end before returning it.
921     let paused_queues = {
922         // The first queue is used for inflate messages
923         let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
924         let inflate_queue_evt = inflate_queue
925             .event()
926             .try_clone()
927             .expect("failed to clone queue event");
928         let inflate = handle_queue(
929             inflate_queue,
930             EventAsync::new(inflate_queue_evt, &ex).expect("failed to create async event"),
931             release_memory_tube.as_ref(),
932             |ranges| free_memory(&vm_memory_client, &mem, ranges),
933             stop_rx,
934         );
935         let inflate = inflate.fuse();
936         pin_mut!(inflate);
937 
938         // The second queue is used for deflate messages
939         let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
940         let deflate_queue_evt = deflate_queue
941             .event()
942             .try_clone()
943             .expect("failed to clone queue event");
944         let deflate = handle_queue(
945             deflate_queue,
946             EventAsync::new(deflate_queue_evt, &ex).expect("failed to create async event"),
947             None,
948             |ranges| reclaim_memory(&vm_memory_client, ranges),
949             stop_rx,
950         );
951         let deflate = deflate.fuse();
952         pin_mut!(deflate);
953 
954         // The next queue is used for stats messages if VIRTIO_BALLOON_F_STATS_VQ is negotiated.
955         let (stats_tx, stats_rx) = mpsc::channel::<()>(1);
956         let has_stats_queue = stats_queue.is_some();
957         let stats = if let Some(stats_queue) = stats_queue {
958             let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
959             let stats_queue_evt = stats_queue
960                 .event()
961                 .try_clone()
962                 .expect("failed to clone queue event");
963             handle_stats_queue(
964                 stats_queue,
965                 EventAsync::new(stats_queue_evt, &ex).expect("failed to create async event"),
966                 stats_rx,
967                 &command_tube,
968                 #[cfg(feature = "registered_events")]
969                 registered_evt_q_async.as_ref(),
970                 state.clone(),
971                 stop_rx,
972             )
973             .left_future()
974         } else {
975             std::future::pending().right_future()
976         };
977         let stats = stats.fuse();
978         pin_mut!(stats);
979 
980         // The next queue is used for reporting messages
981         let has_reporting_queue = reporting_queue.is_some();
982         let reporting = if let Some(reporting_queue) = reporting_queue {
983             let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
984             let reporting_queue_evt = reporting_queue
985                 .event()
986                 .try_clone()
987                 .expect("failed to clone queue event");
988             handle_reporting_queue(
989                 reporting_queue,
990                 EventAsync::new(reporting_queue_evt, &ex).expect("failed to create async event"),
991                 release_memory_tube.as_ref(),
992                 |ranges| free_memory(&vm_memory_client, &mem, ranges),
993                 stop_rx,
994             )
995             .left_future()
996         } else {
997             std::future::pending().right_future()
998         };
999         let reporting = reporting.fuse();
1000         pin_mut!(reporting);
1001 
1002         // If VIRTIO_BALLOON_F_WS_REPORTING is set 2 queues must handled - one for WS data and one
1003         // for WS notifications.
1004         let has_ws_data_queue = ws_data_queue.is_some();
1005         let ws_data = if let Some(ws_data_queue) = ws_data_queue {
1006             let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
1007             let ws_data_queue_evt = ws_data_queue
1008                 .event()
1009                 .try_clone()
1010                 .expect("failed to clone queue event");
1011             handle_ws_data_queue(
1012                 ws_data_queue,
1013                 EventAsync::new(ws_data_queue_evt, &ex).expect("failed to create async event"),
1014                 &command_tube,
1015                 #[cfg(feature = "registered_events")]
1016                 registered_evt_q_async.as_ref(),
1017                 state.clone(),
1018                 stop_rx,
1019             )
1020             .left_future()
1021         } else {
1022             std::future::pending().right_future()
1023         };
1024         let ws_data = ws_data.fuse();
1025         pin_mut!(ws_data);
1026 
1027         let (ws_op_tx, ws_op_rx) = mpsc::channel::<WSOp>(1);
1028         let has_ws_op_queue = ws_op_queue.is_some();
1029         let ws_op = if let Some(ws_op_queue) = ws_op_queue {
1030             let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
1031             let ws_op_queue_evt = ws_op_queue
1032                 .event()
1033                 .try_clone()
1034                 .expect("failed to clone queue event");
1035             handle_ws_op_queue(
1036                 ws_op_queue,
1037                 EventAsync::new(ws_op_queue_evt, &ex).expect("failed to create async event"),
1038                 ws_op_rx,
1039                 state.clone(),
1040                 stop_rx,
1041             )
1042             .left_future()
1043         } else {
1044             std::future::pending().right_future()
1045         };
1046         let ws_op = ws_op.fuse();
1047         pin_mut!(ws_op);
1048 
1049         // Future to handle command messages that resize the balloon.
1050         let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
1051         let command = handle_command_tube(
1052             &command_tube,
1053             interrupt.clone(),
1054             state.clone(),
1055             stats_tx,
1056             ws_op_tx,
1057             stop_rx,
1058         );
1059         pin_mut!(command);
1060 
1061         // Send a message if balloon target reached event is triggered.
1062         let target_reached = handle_target_reached(&ex, target_reached_evt, &vm_memory_client);
1063         pin_mut!(target_reached);
1064 
1065         // Exit if the kill event is triggered.
1066         let kill = async_utils::await_and_exit(&ex, kill_evt);
1067         pin_mut!(kill);
1068 
1069         let pending_adjusted = handle_pending_adjusted_responses(
1070             EventAsync::new(pending_adjusted_response_event, &ex)
1071                 .expect("failed to create async event"),
1072             &command_tube,
1073             state,
1074         );
1075         pin_mut!(pending_adjusted);
1076 
1077         let res = ex.run_until(async {
1078             select! {
1079                 _ = kill.fuse() => (),
1080                 _ = inflate => return Err(anyhow!("inflate stopped unexpectedly")),
1081                 _ = deflate => return Err(anyhow!("deflate stopped unexpectedly")),
1082                 _ = stats => return Err(anyhow!("stats stopped unexpectedly")),
1083                 _ = reporting => return Err(anyhow!("reporting stopped unexpectedly")),
1084                 _ = command.fuse() => return Err(anyhow!("command stopped unexpectedly")),
1085                 _ = ws_op => return Err(anyhow!("ws_op stopped unexpectedly")),
1086                 _ = pending_adjusted.fuse() => return Err(anyhow!("pending_adjusted stopped unexpectedly")),
1087                 _ = ws_data => return Err(anyhow!("ws_data stopped unexpectedly")),
1088                 _ = target_reached.fuse() => return Err(anyhow!("target_reached stopped unexpectedly")),
1089             }
1090 
1091             // Worker is shutting down. To recover the queues, we have to signal
1092             // all the queue futures to exit.
1093             for stop_tx in stop_queue_oneshots {
1094                 if stop_tx.send(()).is_err() {
1095                     return Err(anyhow!("failed to request stop for queue future"));
1096                 }
1097             }
1098 
1099             // Collect all the queues (awaiting any queue future should now
1100             // return its Queue immediately).
1101             let mut paused_queues = PausedQueues::new(
1102                 inflate.await,
1103                 deflate.await,
1104             );
1105             if has_reporting_queue {
1106                 paused_queues.reporting = Some(reporting.await);
1107             }
1108             if has_stats_queue {
1109                 paused_queues.stats = Some(stats.await);
1110             }
1111             if has_ws_data_queue {
1112                 paused_queues.ws_data = Some(ws_data.await.context("failed to stop ws_data queue")?);
1113             }
1114             if has_ws_op_queue {
1115                 paused_queues.ws_op = Some(ws_op.await.context("failed to stop ws_op queue")?);
1116             }
1117             Ok(paused_queues)
1118         });
1119 
1120         match res {
1121             Err(e) => {
1122                 error!("error happened in executor: {}", e);
1123                 None
1124             }
1125             Ok(main_future_res) => match main_future_res {
1126                 Ok(paused_queues) => Some(paused_queues),
1127                 Err(e) => {
1128                     error!("error happened in main balloon future: {}", e);
1129                     None
1130                 }
1131             },
1132         }
1133     };
1134 
1135     WorkerReturn {
1136         command_tube: command_tube.into(),
1137         paused_queues,
1138         release_memory_tube,
1139         #[cfg(feature = "registered_events")]
1140         registered_evt_q,
1141         vm_memory_client,
1142     }
1143 }
1144 
handle_target_reached( ex: &Executor, target_reached_evt: Event, vm_memory_client: &VmMemoryClient, ) -> anyhow::Result<()>1145 async fn handle_target_reached(
1146     ex: &Executor,
1147     target_reached_evt: Event,
1148     vm_memory_client: &VmMemoryClient,
1149 ) -> anyhow::Result<()> {
1150     let event_async =
1151         EventAsync::new(target_reached_evt, ex).context("failed to create EventAsync")?;
1152     loop {
1153         // Wait for target reached trigger.
1154         let _ = event_async.next_val().await;
1155         // Send the message to vm_control on the event. We don't have to read the current
1156         // size yet.
1157         if let Err(e) = vm_memory_client.balloon_target_reached(0) {
1158             warn!("Failed to send or receive allocation complete request: {e:#}");
1159         }
1160     }
1161     // The above loop will never terminate and there is no reason to terminate it either. However,
1162     // the function is used in an executor that expects a Result<> return. Make sure that clippy
1163     // doesn't enforce the unreachable_code condition.
1164     #[allow(unreachable_code)]
1165     Ok(())
1166 }
1167 
1168 /// Virtio device for memory balloon inflation/deflation.
1169 pub struct Balloon {
1170     command_tube: Option<Tube>,
1171     vm_memory_client: Option<VmMemoryClient>,
1172     release_memory_tube: Option<Tube>,
1173     pending_adjusted_response_event: Event,
1174     state: Arc<AsyncRwLock<BalloonState>>,
1175     features: u64,
1176     acked_features: u64,
1177     worker_thread: Option<WorkerThread<WorkerReturn>>,
1178     #[cfg(feature = "registered_events")]
1179     registered_evt_q: Option<SendTube>,
1180     ws_num_bins: u8,
1181     target_reached_evt: Option<Event>,
1182 }
1183 
1184 /// Snapshot of the [Balloon] state.
1185 #[derive(Serialize, Deserialize)]
1186 struct BalloonSnapshot {
1187     state: BalloonState,
1188     features: u64,
1189     acked_features: u64,
1190     ws_num_bins: u8,
1191 }
1192 
1193 impl Balloon {
1194     /// Creates a new virtio balloon device.
1195     /// To let Balloon able to successfully release the memory which are pinned
1196     /// by CoIOMMU to host, the release_memory_tube will be used to send the inflate
1197     /// ranges to CoIOMMU with UnpinRequest/UnpinResponse messages, so that The
1198     /// memory in the inflate range can be unpinned first.
new( base_features: u64, command_tube: Tube, vm_memory_client: VmMemoryClient, release_memory_tube: Option<Tube>, init_balloon_size: u64, enabled_features: u64, #[cfg(feature = "registered_events")] registered_evt_q: Option<SendTube>, ws_num_bins: u8, ) -> Result<Balloon>1199     pub fn new(
1200         base_features: u64,
1201         command_tube: Tube,
1202         vm_memory_client: VmMemoryClient,
1203         release_memory_tube: Option<Tube>,
1204         init_balloon_size: u64,
1205         enabled_features: u64,
1206         #[cfg(feature = "registered_events")] registered_evt_q: Option<SendTube>,
1207         ws_num_bins: u8,
1208     ) -> Result<Balloon> {
1209         let features = base_features
1210             | 1 << VIRTIO_BALLOON_F_MUST_TELL_HOST
1211             | 1 << VIRTIO_BALLOON_F_STATS_VQ
1212             | 1 << VIRTIO_BALLOON_F_DEFLATE_ON_OOM
1213             | enabled_features;
1214 
1215         Ok(Balloon {
1216             command_tube: Some(command_tube),
1217             vm_memory_client: Some(vm_memory_client),
1218             release_memory_tube,
1219             pending_adjusted_response_event: Event::new().map_err(BalloonError::CreatingEvent)?,
1220             state: Arc::new(AsyncRwLock::new(BalloonState {
1221                 num_pages: (init_balloon_size >> VIRTIO_BALLOON_PFN_SHIFT) as u32,
1222                 actual_pages: 0,
1223                 failable_update: false,
1224                 pending_adjusted_responses: VecDeque::new(),
1225                 expecting_ws: false,
1226             })),
1227             worker_thread: None,
1228             features,
1229             acked_features: 0,
1230             #[cfg(feature = "registered_events")]
1231             registered_evt_q,
1232             ws_num_bins,
1233             target_reached_evt: None,
1234         })
1235     }
1236 
get_config(&self) -> virtio_balloon_config1237     fn get_config(&self) -> virtio_balloon_config {
1238         let state = block_on(self.state.lock());
1239         virtio_balloon_config {
1240             num_pages: state.num_pages.into(),
1241             actual: state.actual_pages.into(),
1242             // crosvm does not (currently) use free_page_hint_cmd_id or
1243             // poison_val, but they must be present in the right order and size
1244             // for the virtio-balloon driver in the guest to deserialize the
1245             // config correctly.
1246             free_page_hint_cmd_id: 0.into(),
1247             poison_val: 0.into(),
1248             ws_num_bins: self.ws_num_bins,
1249             _reserved: [0, 0, 0],
1250         }
1251     }
1252 
stop_worker(&mut self) -> StoppedWorker<PausedQueues>1253     fn stop_worker(&mut self) -> StoppedWorker<PausedQueues> {
1254         if let Some(worker_thread) = self.worker_thread.take() {
1255             let worker_ret = worker_thread.stop();
1256             self.release_memory_tube = worker_ret.release_memory_tube;
1257             self.command_tube = Some(worker_ret.command_tube);
1258             #[cfg(feature = "registered_events")]
1259             {
1260                 self.registered_evt_q = worker_ret.registered_evt_q;
1261             }
1262             self.vm_memory_client = Some(worker_ret.vm_memory_client);
1263 
1264             if let Some(queues) = worker_ret.paused_queues {
1265                 StoppedWorker::WithQueues(Box::new(queues))
1266             } else {
1267                 StoppedWorker::MissingQueues
1268             }
1269         } else {
1270             StoppedWorker::AlreadyStopped
1271         }
1272     }
1273 
1274     /// Given a filtered queue vector from [VirtioDevice::activate], extract
1275     /// the queues (accounting for queues that are missing because the features
1276     /// are not negotiated) into a structure that is easier to work with.
get_queues_from_map( &self, mut queues: BTreeMap<usize, Queue>, ) -> anyhow::Result<BalloonQueues>1277     fn get_queues_from_map(
1278         &self,
1279         mut queues: BTreeMap<usize, Queue>,
1280     ) -> anyhow::Result<BalloonQueues> {
1281         fn pop_queue(
1282             queues: &mut BTreeMap<usize, Queue>,
1283             expected_index: usize,
1284             name: &str,
1285         ) -> anyhow::Result<Queue> {
1286             let (queue_index, queue) = queues
1287                 .pop_first()
1288                 .with_context(|| format!("missing {}", name))?;
1289 
1290             if queue_index == expected_index {
1291                 debug!("{name} index {queue_index}");
1292             } else {
1293                 warn!("expected {name} index {expected_index}, got {queue_index}");
1294             }
1295 
1296             Ok(queue)
1297         }
1298 
1299         // WARNING: We use `pop_first` instead of explicitly using the indices from the virtio spec
1300         // because the Linux virtio drivers only "allocates" queue indices that are used, so queues
1301         // need to be removed in order of ascending virtqueue index.
1302         let inflate_queue = pop_queue(&mut queues, INFLATEQ, "inflateq")?;
1303         let deflate_queue = pop_queue(&mut queues, DEFLATEQ, "deflateq")?;
1304         let mut queue_struct = BalloonQueues::new(inflate_queue, deflate_queue);
1305 
1306         if self.acked_features & (1 << VIRTIO_BALLOON_F_STATS_VQ) != 0 {
1307             queue_struct.stats = Some(pop_queue(&mut queues, STATSQ, "statsq")?);
1308         }
1309         if self.acked_features & (1 << VIRTIO_BALLOON_F_PAGE_REPORTING) != 0 {
1310             queue_struct.reporting = Some(pop_queue(&mut queues, REPORTING_VQ, "reporting_vq")?);
1311         }
1312         if self.acked_features & (1 << VIRTIO_BALLOON_F_WS_REPORTING) != 0 {
1313             queue_struct.ws_data = Some(pop_queue(&mut queues, WS_DATA_VQ, "ws_data_vq")?);
1314             queue_struct.ws_op = Some(pop_queue(&mut queues, WS_OP_VQ, "ws_op_vq")?);
1315         }
1316 
1317         if !queues.is_empty() {
1318             return Err(anyhow!("unexpected queues {:?}", queues.into_keys()));
1319         }
1320 
1321         Ok(queue_struct)
1322     }
1323 
start_worker( &mut self, mem: GuestMemory, interrupt: Interrupt, queues: BalloonQueues, ) -> anyhow::Result<()>1324     fn start_worker(
1325         &mut self,
1326         mem: GuestMemory,
1327         interrupt: Interrupt,
1328         queues: BalloonQueues,
1329     ) -> anyhow::Result<()> {
1330         let (self_target_reached_evt, target_reached_evt) = Event::new()
1331             .and_then(|e| Ok((e.try_clone()?, e)))
1332             .context("failed to create target_reached Event pair: {}")?;
1333         self.target_reached_evt = Some(self_target_reached_evt);
1334 
1335         let state = self.state.clone();
1336 
1337         let command_tube = self.command_tube.take().unwrap();
1338 
1339         let vm_memory_client = self.vm_memory_client.take().unwrap();
1340         let release_memory_tube = self.release_memory_tube.take();
1341         #[cfg(feature = "registered_events")]
1342         let registered_evt_q = self.registered_evt_q.take();
1343         let pending_adjusted_response_event = self
1344             .pending_adjusted_response_event
1345             .try_clone()
1346             .context("failed to clone Event")?;
1347 
1348         self.worker_thread = Some(WorkerThread::start("v_balloon", move |kill_evt| {
1349             run_worker(
1350                 queues.inflate,
1351                 queues.deflate,
1352                 queues.stats,
1353                 queues.reporting,
1354                 queues.ws_data,
1355                 queues.ws_op,
1356                 command_tube,
1357                 vm_memory_client,
1358                 mem,
1359                 release_memory_tube,
1360                 interrupt,
1361                 kill_evt,
1362                 target_reached_evt,
1363                 pending_adjusted_response_event,
1364                 state,
1365                 #[cfg(feature = "registered_events")]
1366                 registered_evt_q,
1367             )
1368         }));
1369 
1370         Ok(())
1371     }
1372 }
1373 
1374 impl VirtioDevice for Balloon {
keep_rds(&self) -> Vec<RawDescriptor>1375     fn keep_rds(&self) -> Vec<RawDescriptor> {
1376         let mut rds = Vec::new();
1377         if let Some(command_tube) = &self.command_tube {
1378             rds.push(command_tube.as_raw_descriptor());
1379         }
1380         if let Some(vm_memory_client) = &self.vm_memory_client {
1381             rds.push(vm_memory_client.as_raw_descriptor());
1382         }
1383         if let Some(release_memory_tube) = &self.release_memory_tube {
1384             rds.push(release_memory_tube.as_raw_descriptor());
1385         }
1386         #[cfg(feature = "registered_events")]
1387         if let Some(registered_evt_q) = &self.registered_evt_q {
1388             rds.push(registered_evt_q.as_raw_descriptor());
1389         }
1390         rds.push(self.pending_adjusted_response_event.as_raw_descriptor());
1391         rds
1392     }
1393 
device_type(&self) -> DeviceType1394     fn device_type(&self) -> DeviceType {
1395         DeviceType::Balloon
1396     }
1397 
queue_max_sizes(&self) -> &[u16]1398     fn queue_max_sizes(&self) -> &[u16] {
1399         QUEUE_SIZES
1400     }
1401 
read_config(&self, offset: u64, data: &mut [u8])1402     fn read_config(&self, offset: u64, data: &mut [u8]) {
1403         copy_config(data, 0, self.get_config().as_bytes(), offset);
1404     }
1405 
write_config(&mut self, offset: u64, data: &[u8])1406     fn write_config(&mut self, offset: u64, data: &[u8]) {
1407         let mut config = self.get_config();
1408         copy_config(config.as_mut_bytes(), offset, data, 0);
1409         let mut state = block_on(self.state.lock());
1410         state.actual_pages = config.actual.to_native();
1411 
1412         // If balloon has updated to the requested memory, let the hypervisor know.
1413         if config.num_pages == config.actual {
1414             debug!(
1415                 "sending target reached event at {}",
1416                 u32::from(config.num_pages)
1417             );
1418             self.target_reached_evt.as_ref().map(|e| e.signal());
1419         }
1420         if state.failable_update && state.actual_pages == state.num_pages {
1421             state.failable_update = false;
1422             let num_pages = state.num_pages;
1423             state.pending_adjusted_responses.push_back(num_pages);
1424             let _ = self.pending_adjusted_response_event.signal();
1425         }
1426     }
1427 
features(&self) -> u641428     fn features(&self) -> u64 {
1429         self.features
1430     }
1431 
ack_features(&mut self, mut value: u64)1432     fn ack_features(&mut self, mut value: u64) {
1433         if value & !self.features != 0 {
1434             warn!("virtio_balloon got unknown feature ack {:x}", value);
1435             value &= self.features;
1436         }
1437         self.acked_features |= value;
1438     }
1439 
activate( &mut self, mem: GuestMemory, interrupt: Interrupt, queues: BTreeMap<usize, Queue>, ) -> anyhow::Result<()>1440     fn activate(
1441         &mut self,
1442         mem: GuestMemory,
1443         interrupt: Interrupt,
1444         queues: BTreeMap<usize, Queue>,
1445     ) -> anyhow::Result<()> {
1446         let queues = self.get_queues_from_map(queues)?;
1447         self.start_worker(mem, interrupt, queues)
1448     }
1449 
reset(&mut self) -> anyhow::Result<()>1450     fn reset(&mut self) -> anyhow::Result<()> {
1451         let _worker = self.stop_worker();
1452         Ok(())
1453     }
1454 
virtio_sleep(&mut self) -> anyhow::Result<Option<BTreeMap<usize, Queue>>>1455     fn virtio_sleep(&mut self) -> anyhow::Result<Option<BTreeMap<usize, Queue>>> {
1456         match self.stop_worker() {
1457             StoppedWorker::WithQueues(paused_queues) => Ok(Some(paused_queues.into())),
1458             StoppedWorker::MissingQueues => {
1459                 anyhow::bail!("balloon queue workers did not stop cleanly.")
1460             }
1461             StoppedWorker::AlreadyStopped => {
1462                 // Device hasn't been activated.
1463                 Ok(None)
1464             }
1465         }
1466     }
1467 
virtio_wake( &mut self, queues_state: Option<(GuestMemory, Interrupt, BTreeMap<usize, Queue>)>, ) -> anyhow::Result<()>1468     fn virtio_wake(
1469         &mut self,
1470         queues_state: Option<(GuestMemory, Interrupt, BTreeMap<usize, Queue>)>,
1471     ) -> anyhow::Result<()> {
1472         if let Some((mem, interrupt, queues)) = queues_state {
1473             if queues.len() < 2 {
1474                 anyhow::bail!("{} queues were found, but an activated balloon must have at least 2 active queues.", queues.len());
1475             }
1476 
1477             let balloon_queues = self.get_queues_from_map(queues)?;
1478             self.start_worker(mem, interrupt, balloon_queues)?;
1479         }
1480         Ok(())
1481     }
1482 
virtio_snapshot(&mut self) -> anyhow::Result<AnySnapshot>1483     fn virtio_snapshot(&mut self) -> anyhow::Result<AnySnapshot> {
1484         let state = self
1485             .state
1486             .lock()
1487             .now_or_never()
1488             .context("failed to acquire balloon lock")?;
1489         AnySnapshot::to_any(BalloonSnapshot {
1490             features: self.features,
1491             acked_features: self.acked_features,
1492             state: state.clone(),
1493             ws_num_bins: self.ws_num_bins,
1494         })
1495         .context("failed to serialize balloon state")
1496     }
1497 
virtio_restore(&mut self, data: AnySnapshot) -> anyhow::Result<()>1498     fn virtio_restore(&mut self, data: AnySnapshot) -> anyhow::Result<()> {
1499         let snap: BalloonSnapshot = AnySnapshot::from_any(data).context("error deserializing")?;
1500         if snap.features != self.features {
1501             anyhow::bail!(
1502                 "balloon: expected features to match, but they did not. Live: {:?}, snapshot {:?}",
1503                 self.features,
1504                 snap.features,
1505             );
1506         }
1507 
1508         let mut state = self
1509             .state
1510             .lock()
1511             .now_or_never()
1512             .context("failed to acquire balloon lock")?;
1513         *state = snap.state;
1514         self.ws_num_bins = snap.ws_num_bins;
1515         self.acked_features = snap.acked_features;
1516         Ok(())
1517     }
1518 }
1519 
#[cfg(test)]
mod tests {
    use super::*;
    use crate::suspendable_virtio_tests;
    use crate::virtio::descriptor_utils::create_descriptor_chain;
    use crate::virtio::descriptor_utils::DescriptorType;

    /// `handle_address_chain` must decode each 32-bit page frame number in the
    /// descriptor buffer and hand the resulting guest addresses to the closure.
    #[test]
    fn desc_parsing_inflate() {
        // (guest address of the entry, page frame number stored there)
        let entries: [(u64, u32); 2] = [(0x100, 0x10), (0x104, 0xaa55aa55)];

        let memory = GuestMemory::new(&[(GuestAddress(0x0), 0x10000)]).unwrap();
        for (entry_addr, pfn) in entries.iter() {
            memory
                .write_obj_at_addr(*pfn, GuestAddress(*entry_addr))
                .unwrap();
        }

        // One readable 8-byte descriptor covering both PFN entries at 0x100.
        let mut chain = create_descriptor_chain(
            &memory,
            GuestAddress(0x0),
            GuestAddress(0x100),
            vec![(DescriptorType::Readable, 8)],
            0,
        )
        .expect("create_descriptor_chain failed");

        // Collect every batch of ranges the handler reports.
        let mut addrs = Vec::new();
        let res = handle_address_chain(None, &mut chain, &mut |ranges| addrs.extend(ranges));
        assert!(res.is_ok());

        assert_eq!(addrs.len(), entries.len());
        for (range, (_, pfn)) in addrs.iter().zip(entries.iter()) {
            assert_eq!(
                range.0,
                GuestAddress(u64::from(*pfn) << VIRTIO_BALLOON_PFN_SHIFT)
            );
        }
    }

    /// Holds the peer ends of the device's tubes so they stay open for the
    /// lifetime of a test.
    struct BalloonContext {
        _ctrl_tube: Tube,
        _mem_client_tube: Tube,
    }

    /// Mutates the device so that its state differs from a freshly created one.
    fn modify_device(_balloon_context: &mut BalloonContext, balloon: &mut Balloon) {
        // `!x` never equals `x`, so the field is guaranteed to change.
        let flipped = !balloon.ws_num_bins;
        balloon.ws_num_bins = flipped;
    }

    /// Builds a balloon device together with the context that keeps its tube peers
    /// alive.
    fn create_device() -> (BalloonContext, Balloon) {
        let (ctrl_peer, ctrl_device_end) = Tube::pair().unwrap();
        let (mem_client_peer, mem_client_device_end) = Tube::pair().unwrap();

        let context = BalloonContext {
            _ctrl_tube: ctrl_peer,
            _mem_client_tube: mem_client_peer,
        };
        let balloon = Balloon::new(
            0,
            ctrl_device_end,
            VmMemoryClient::new(mem_client_device_end),
            None,
            1024,
            0,
            #[cfg(feature = "registered_events")]
            None,
            0,
        )
        .unwrap();

        (context, balloon)
    }

    // NOTE(review): the literal `2` is presumably the number of queues the generated
    // tests activate; confirm against the macro definition.
    suspendable_virtio_tests!(balloon, create_device, 2, modify_device);
}
1599