1 // Copyright 2017 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 use std::collections::BTreeMap;
6 use std::collections::VecDeque;
7 use std::io::Write;
8 use std::sync::Arc;
9
10 use anyhow::anyhow;
11 use anyhow::Context;
12 use balloon_control::BalloonStats;
13 use balloon_control::BalloonTubeCommand;
14 use balloon_control::BalloonTubeResult;
15 use balloon_control::BalloonWS;
16 use balloon_control::WSBucket;
17 use balloon_control::VIRTIO_BALLOON_WS_MAX_NUM_BINS;
18 use balloon_control::VIRTIO_BALLOON_WS_MIN_NUM_BINS;
19 use base::debug;
20 use base::error;
21 use base::warn;
22 use base::AsRawDescriptor;
23 use base::Event;
24 use base::RawDescriptor;
25 #[cfg(feature = "registered_events")]
26 use base::SendTube;
27 use base::Tube;
28 use base::WorkerThread;
29 use cros_async::block_on;
30 use cros_async::sync::RwLock as AsyncRwLock;
31 use cros_async::AsyncTube;
32 use cros_async::EventAsync;
33 use cros_async::Executor;
34 #[cfg(feature = "registered_events")]
35 use cros_async::SendTubeAsync;
36 use data_model::Le16;
37 use data_model::Le32;
38 use data_model::Le64;
39 use futures::channel::mpsc;
40 use futures::channel::oneshot;
41 use futures::pin_mut;
42 use futures::select;
43 use futures::select_biased;
44 use futures::FutureExt;
45 use futures::StreamExt;
46 use remain::sorted;
47 use serde::Deserialize;
48 use serde::Serialize;
49 use snapshot::AnySnapshot;
50 use thiserror::Error as ThisError;
51 use vm_control::api::VmMemoryClient;
52 #[cfg(feature = "registered_events")]
53 use vm_control::RegisteredEventWithData;
54 use vm_memory::GuestAddress;
55 use vm_memory::GuestMemory;
56 use zerocopy::FromBytes;
57 use zerocopy::Immutable;
58 use zerocopy::IntoBytes;
59 use zerocopy::KnownLayout;
60
61 use super::async_utils;
62 use super::copy_config;
63 use super::create_stop_oneshot;
64 use super::DescriptorChain;
65 use super::DeviceType;
66 use super::Interrupt;
67 use super::Queue;
68 use super::Reader;
69 use super::StoppedWorker;
70 use super::VirtioDevice;
71 use crate::UnpinRequest;
72 use crate::UnpinResponse;
73
74 #[sorted]
75 #[derive(ThisError, Debug)]
76 pub enum BalloonError {
77 /// Failed an async await
78 #[error("failed async await: {0}")]
79 AsyncAwait(cros_async::AsyncError),
80 /// Failed an async await
81 #[error("failed async await: {0}")]
82 AsyncAwaitAnyhow(anyhow::Error),
83 /// Failed to create event.
84 #[error("failed to create event: {0}")]
85 CreatingEvent(base::Error),
86 /// Failed to create async message receiver.
87 #[error("failed to create async message receiver: {0}")]
88 CreatingMessageReceiver(base::TubeError),
89 /// Failed to receive command message.
90 #[error("failed to receive command message: {0}")]
91 ReceivingCommand(base::TubeError),
92 /// Failed to send command response.
93 #[error("failed to send command response: {0}")]
94 SendResponse(base::TubeError),
95 /// Error while writing to virtqueue
96 #[error("failed to write to virtqueue: {0}")]
97 WriteQueue(std::io::Error),
98 /// Failed to write config event.
99 #[error("failed to write config event: {0}")]
100 WritingConfigEvent(base::Error),
101 }
102 pub type Result<T> = std::result::Result<T, BalloonError>;
103
104 // Balloon implements five virt IO queues: Inflate, Deflate, Stats, WsData, WsCmd.
105 const QUEUE_SIZE: u16 = 128;
106 const QUEUE_SIZES: &[u16] = &[QUEUE_SIZE, QUEUE_SIZE, QUEUE_SIZE, QUEUE_SIZE, QUEUE_SIZE];
107
108 // Virtqueue indexes
109 const INFLATEQ: usize = 0;
110 const DEFLATEQ: usize = 1;
111 const STATSQ: usize = 2;
112 const _FREE_PAGE_VQ: usize = 3;
113 const REPORTING_VQ: usize = 4;
114 const WS_DATA_VQ: usize = 5;
115 const WS_OP_VQ: usize = 6;
116
117 const VIRTIO_BALLOON_PFN_SHIFT: u32 = 12;
118 const VIRTIO_BALLOON_PF_SIZE: u64 = 1 << VIRTIO_BALLOON_PFN_SHIFT;
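// A balloon page-frame number (PFN) maps to a guest physical address by shifting left by
// VIRTIO_BALLOON_PFN_SHIFT; for example, PFN 0x10 corresponds to guest address 0x10000 and
// covers VIRTIO_BALLOON_PF_SIZE (4 KiB) of memory.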
119
120 // The feature bitmap for virtio balloon
121 const VIRTIO_BALLOON_F_MUST_TELL_HOST: u32 = 0; // Tell before reclaiming pages
122 const VIRTIO_BALLOON_F_STATS_VQ: u32 = 1; // Stats reporting enabled
123 const VIRTIO_BALLOON_F_DEFLATE_ON_OOM: u32 = 2; // Deflate balloon on OOM
124 const VIRTIO_BALLOON_F_PAGE_REPORTING: u32 = 5; // Page reporting virtqueue
125 // TODO(b/273973298): this should maybe be bit 6? to be changed later
126 const VIRTIO_BALLOON_F_WS_REPORTING: u32 = 8; // Working Set Reporting virtqueues
127
128 #[derive(Copy, Clone)]
129 #[repr(u32)]
130 // Balloon features
131 pub enum BalloonFeatures {
132 // Page Reporting enabled
133 PageReporting = VIRTIO_BALLOON_F_PAGE_REPORTING,
134 // WS Reporting enabled
135 WSReporting = VIRTIO_BALLOON_F_WS_REPORTING,
136 }
137
138 // virtio_balloon_config is the balloon device configuration space defined by the virtio spec.
139 #[derive(Copy, Clone, Debug, Default, FromBytes, Immutable, IntoBytes, KnownLayout)]
140 #[repr(C)]
141 struct virtio_balloon_config {
142 num_pages: Le32,
143 actual: Le32,
144 free_page_hint_cmd_id: Le32,
145 poison_val: Le32,
146 // WS field is part of proposed spec extension (b/273973298).
147 ws_num_bins: u8,
148 _reserved: [u8; 3],
149 }
150
151 // BalloonState is shared by the worker and device thread.
152 #[derive(Clone, Default, Serialize, Deserialize)]
153 struct BalloonState {
154 num_pages: u32,
155 actual_pages: u32,
156 expecting_ws: bool,
157 // Flag indicating that the balloon is in the process of a failable update. This
158 // is set by an Adjust command that has allow_failure set, and is cleared when the
159 // Adjusted success/failure response is sent.
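// Once the guest's config write shows the target was reached, write_config queues the success
// response on pending_adjusted_responses, which handle_pending_adjusted_responses then drains
// and sends over the command tube.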
160 failable_update: bool,
161 pending_adjusted_responses: VecDeque<u32>,
162 }
163
164 // The constants defining stats types in virtio_balloon_stat
165 const VIRTIO_BALLOON_S_SWAP_IN: u16 = 0;
166 const VIRTIO_BALLOON_S_SWAP_OUT: u16 = 1;
167 const VIRTIO_BALLOON_S_MAJFLT: u16 = 2;
168 const VIRTIO_BALLOON_S_MINFLT: u16 = 3;
169 const VIRTIO_BALLOON_S_MEMFREE: u16 = 4;
170 const VIRTIO_BALLOON_S_MEMTOT: u16 = 5;
171 const VIRTIO_BALLOON_S_AVAIL: u16 = 6;
172 const VIRTIO_BALLOON_S_CACHES: u16 = 7;
173 const VIRTIO_BALLOON_S_HTLB_PGALLOC: u16 = 8;
174 const VIRTIO_BALLOON_S_HTLB_PGFAIL: u16 = 9;
175 const VIRTIO_BALLOON_S_NONSTANDARD_SHMEM: u16 = 65534;
176 const VIRTIO_BALLOON_S_NONSTANDARD_UNEVICTABLE: u16 = 65535;
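// The NONSTANDARD tags are not part of the virtio specification; they sit at the top of the
// 16-bit tag space to avoid colliding with standard stats tags.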
177
178 // BalloonStat is used to deserialize stats from the stats_queue.
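// Each entry is a packed 10-byte (tag, value) pair; the guest fills the stats buffer with an
// array of these, which parse_balloon_stats() reads back one entry at a time.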
179 #[derive(Copy, Clone, FromBytes, Immutable, IntoBytes, KnownLayout)]
180 #[repr(C, packed)]
181 struct BalloonStat {
182 tag: Le16,
183 val: Le64,
184 }
185
186 impl BalloonStat {
187 fn update_stats(&self, stats: &mut BalloonStats) {
188 let val = Some(self.val.to_native());
189 match self.tag.to_native() {
190 VIRTIO_BALLOON_S_SWAP_IN => stats.swap_in = val,
191 VIRTIO_BALLOON_S_SWAP_OUT => stats.swap_out = val,
192 VIRTIO_BALLOON_S_MAJFLT => stats.major_faults = val,
193 VIRTIO_BALLOON_S_MINFLT => stats.minor_faults = val,
194 VIRTIO_BALLOON_S_MEMFREE => stats.free_memory = val,
195 VIRTIO_BALLOON_S_MEMTOT => stats.total_memory = val,
196 VIRTIO_BALLOON_S_AVAIL => stats.available_memory = val,
197 VIRTIO_BALLOON_S_CACHES => stats.disk_caches = val,
198 VIRTIO_BALLOON_S_HTLB_PGALLOC => stats.hugetlb_allocations = val,
199 VIRTIO_BALLOON_S_HTLB_PGFAIL => stats.hugetlb_failures = val,
200 VIRTIO_BALLOON_S_NONSTANDARD_SHMEM => stats.shared_memory = val,
201 VIRTIO_BALLOON_S_NONSTANDARD_UNEVICTABLE => stats.unevictable_memory = val,
202 _ => (),
203 }
204 }
205 }
206
207 // virtio_balloon_ws is used to deserialize from the ws data vq.
208 #[repr(C)]
209 #[derive(Copy, Clone, Debug, Default, FromBytes, Immutable, IntoBytes, KnownLayout)]
210 struct virtio_balloon_ws {
211 tag: Le16,
212 node_id: Le16,
213 // virtio prefers field members to align on a word boundary so we must pad. see:
214 // https://crsrc.org/o/src/third_party/kernel/v5.15/include/uapi/linux/virtio_balloon.h;l=105
215 _reserved: [u8; 4],
216 idle_age_ms: Le64,
217 // TODO(b/273973298): these should become separate fields - bytes for ANON and FILE
218 memory_size_bytes: [Le64; 2],
219 }
220
221 impl virtio_balloon_ws {
222 fn update_ws(&self, ws: &mut BalloonWS) {
223 let bucket = WSBucket {
224 age: self.idle_age_ms.to_native(),
225 bytes: [
226 self.memory_size_bytes[0].to_native(),
227 self.memory_size_bytes[1].to_native(),
228 ],
229 };
230 ws.ws.push(bucket);
231 }
232 }
233
234 const _VIRTIO_BALLOON_WS_OP_INVALID: u16 = 0;
235 const VIRTIO_BALLOON_WS_OP_REQUEST: u16 = 1;
236 const VIRTIO_BALLOON_WS_OP_CONFIG: u16 = 2;
237 const _VIRTIO_BALLOON_WS_OP_DISCARD: u16 = 3;
238
239 // virtio_balloon_op is used to serialize to the ws cmd vq.
240 #[repr(C, packed)]
241 #[derive(Copy, Clone, Debug, Default, FromBytes, Immutable, IntoBytes, KnownLayout)]
242 struct virtio_balloon_op {
243 type_: Le16,
244 }
245
246 fn invoke_desc_handler<F>(ranges: Vec<(u64, u64)>, desc_handler: &mut F)
247 where
248 F: FnMut(Vec<(GuestAddress, u64)>),
249 {
250 desc_handler(
251 ranges
252 .into_iter()
253 .map(|range| (GuestAddress(range.0), range.1))
254 .collect(),
255 );
256 }
257
258 // Release a list of guest memory ranges back to the host system.
259 // Unpin requests for each inflate range will be sent via `release_memory_tube`
260 // if provided, and then `desc_handler` will be called for each inflate range.
261 fn release_ranges<F>(
262 release_memory_tube: Option<&Tube>,
263 inflate_ranges: Vec<(u64, u64)>,
264 desc_handler: &mut F,
265 ) -> anyhow::Result<()>
266 where
267 F: FnMut(Vec<(GuestAddress, u64)>),
268 {
269 if let Some(tube) = release_memory_tube {
270 let unpin_ranges = inflate_ranges
271 .iter()
272 .map(|v| {
273 (
274 v.0 >> VIRTIO_BALLOON_PFN_SHIFT,
275 v.1 / VIRTIO_BALLOON_PF_SIZE,
276 )
277 })
278 .collect();
279 let req = UnpinRequest {
280 ranges: unpin_ranges,
281 };
282 if let Err(e) = tube.send(&req) {
283 error!("failed to send unpin request: {}", e);
284 } else {
285 match tube.recv() {
286 Ok(resp) => match resp {
287 UnpinResponse::Success => invoke_desc_handler(inflate_ranges, desc_handler),
288 UnpinResponse::Failed => error!("failed to handle unpin request"),
289 },
290 Err(e) => error!("failed to get unpin response: {}", e),
291 }
292 }
293 } else {
294 invoke_desc_handler(inflate_ranges, desc_handler);
295 }
296
297 Ok(())
298 }
299
300 // Processes one message's list of addresses.
301 fn handle_address_chain<F>(
302 release_memory_tube: Option<&Tube>,
303 avail_desc: &mut DescriptorChain,
304 desc_handler: &mut F,
305 ) -> anyhow::Result<()>
306 where
307 F: FnMut(Vec<(GuestAddress, u64)>),
308 {
309 // In a long-running system, there is no reason to expect that
310 // a significant number of freed pages are consecutive. However,
311 // batching is relatively simple and can result in significant
312 // gains in a newly booted system, so it's worth attempting.
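// For example, PFNs 0x100, 0x101, 0x102 and 0x200 coalesce into the two ranges
// (0x100000, 0x3000) and (0x200000, 0x1000).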
313 let mut range_start = 0;
314 let mut range_size = 0;
315 let mut inflate_ranges: Vec<(u64, u64)> = Vec::new();
316 for res in avail_desc.reader.iter::<Le32>() {
317 let pfn = match res {
318 Ok(pfn) => pfn,
319 Err(e) => {
320 error!("error while reading unused pages: {}", e);
321 break;
322 }
323 };
324 let guest_address = (u64::from(pfn.to_native())) << VIRTIO_BALLOON_PFN_SHIFT;
325 if range_start + range_size == guest_address {
326 range_size += VIRTIO_BALLOON_PF_SIZE;
327 } else if range_start == guest_address + VIRTIO_BALLOON_PF_SIZE {
328 range_start = guest_address;
329 range_size += VIRTIO_BALLOON_PF_SIZE;
330 } else {
331 // Discontinuity, so flush the previous range. Note range_size
332 // will be 0 on the first iteration, so skip that.
333 if range_size != 0 {
334 inflate_ranges.push((range_start, range_size));
335 }
336 range_start = guest_address;
337 range_size = VIRTIO_BALLOON_PF_SIZE;
338 }
339 }
340 if range_size != 0 {
341 inflate_ranges.push((range_start, range_size));
342 }
343
344 release_ranges(release_memory_tube, inflate_ranges, desc_handler)
345 }
346
347 // Async task that handles the main balloon inflate and deflate queues.
348 async fn handle_queue<F>(
349 mut queue: Queue,
350 mut queue_event: EventAsync,
351 release_memory_tube: Option<&Tube>,
352 mut desc_handler: F,
353 mut stop_rx: oneshot::Receiver<()>,
354 ) -> Queue
355 where
356 F: FnMut(Vec<(GuestAddress, u64)>),
357 {
358 loop {
359 let mut avail_desc = match queue
360 .next_async_interruptable(&mut queue_event, &mut stop_rx)
361 .await
362 {
363 Ok(Some(res)) => res,
364 Ok(None) => return queue,
365 Err(e) => {
366 error!("Failed to read descriptor {}", e);
367 return queue;
368 }
369 };
370 if let Err(e) =
371 handle_address_chain(release_memory_tube, &mut avail_desc, &mut desc_handler)
372 {
373 error!("balloon: failed to process inflate addresses: {}", e);
374 }
375 queue.add_used(avail_desc, 0);
376 queue.trigger_interrupt();
377 }
378 }
379
380 // Processes one page-reporting descriptor.
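// Unlike the inflate queue, which carries a list of page-frame numbers, the reporting queue
// describes the freed memory directly via the guest-physical regions of the descriptor chain.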
381 fn handle_reported_buffer<F>(
382 release_memory_tube: Option<&Tube>,
383 avail_desc: &DescriptorChain,
384 desc_handler: &mut F,
385 ) -> anyhow::Result<()>
386 where
387 F: FnMut(Vec<(GuestAddress, u64)>),
388 {
389 let reported_ranges: Vec<(u64, u64)> = avail_desc
390 .reader
391 .get_remaining_regions()
392 .chain(avail_desc.writer.get_remaining_regions())
393 .map(|r| (r.offset, r.len as u64))
394 .collect();
395
396 release_ranges(release_memory_tube, reported_ranges, desc_handler)
397 }
398
399 // Async task that handles the page reporting queue.
400 async fn handle_reporting_queue<F>(
401 mut queue: Queue,
402 mut queue_event: EventAsync,
403 release_memory_tube: Option<&Tube>,
404 mut desc_handler: F,
405 mut stop_rx: oneshot::Receiver<()>,
406 ) -> Queue
407 where
408 F: FnMut(Vec<(GuestAddress, u64)>),
409 {
410 loop {
411 let avail_desc = match queue
412 .next_async_interruptable(&mut queue_event, &mut stop_rx)
413 .await
414 {
415 Ok(Some(res)) => res,
416 Ok(None) => return queue,
417 Err(e) => {
418 error!("Failed to read descriptor {}", e);
419 return queue;
420 }
421 };
422 if let Err(e) = handle_reported_buffer(release_memory_tube, &avail_desc, &mut desc_handler)
423 {
424 error!("balloon: failed to process reported buffer: {}", e);
425 }
426 queue.add_used(avail_desc, 0);
427 queue.trigger_interrupt();
428 }
429 }
430
431 fn parse_balloon_stats(reader: &mut Reader) -> BalloonStats {
432 let mut stats: BalloonStats = Default::default();
433 for res in reader.iter::<BalloonStat>() {
434 match res {
435 Ok(stat) => stat.update_stats(&mut stats),
436 Err(e) => {
437 error!("error while reading stats: {}", e);
438 break;
439 }
440 };
441 }
442 stats
443 }
444
445 // Async task that handles the stats queue. Note that the cadence of this is driven by requests for
446 // balloon stats from the control pipe.
447 // The guest queues an initial buffer on boot, which is read, and then this future blocks until
448 // signaled from the command socket that stats should be collected again.
449 async fn handle_stats_queue(
450 mut queue: Queue,
451 mut queue_event: EventAsync,
452 mut stats_rx: mpsc::Receiver<()>,
453 command_tube: &AsyncTube,
454 #[cfg(feature = "registered_events")] registered_evt_q: Option<&SendTubeAsync>,
455 state: Arc<AsyncRwLock<BalloonState>>,
456 mut stop_rx: oneshot::Receiver<()>,
457 ) -> Queue {
458 let mut avail_desc = match queue
459 .next_async_interruptable(&mut queue_event, &mut stop_rx)
460 .await
461 {
462 // Consume the first stats buffer sent from the guest at startup. It was not
463 // requested by anyone, and the stats are stale.
464 Ok(Some(res)) => res,
465 Ok(None) => return queue,
466 Err(e) => {
467 error!("Failed to read descriptor {}", e);
468 return queue;
469 }
470 };
471
472 loop {
473 select_biased! {
474 msg = stats_rx.next() => {
475 // Wait for a request to read the stats.
476 match msg {
477 Some(()) => (),
478 None => {
479 error!("stats signal channel was closed");
480 return queue;
481 }
482 }
483 }
484 _ = stop_rx => return queue,
485 };
486
487 // Return the used stats descriptor to request fresh stats from the guest.
488 queue.add_used(avail_desc, 0);
489 queue.trigger_interrupt();
490
491 avail_desc = match queue.next_async(&mut queue_event).await {
492 Err(e) => {
493 error!("Failed to read descriptor {}", e);
494 return queue;
495 }
496 Ok(d) => d,
497 };
498 let stats = parse_balloon_stats(&mut avail_desc.reader);
499
500 let actual_pages = state.lock().await.actual_pages as u64;
501 let result = BalloonTubeResult::Stats {
502 balloon_actual: actual_pages << VIRTIO_BALLOON_PFN_SHIFT,
503 stats,
504 };
505 let send_result = command_tube.send(result).await;
506 if let Err(e) = send_result {
507 error!("failed to send stats result: {}", e);
508 }
509
510 #[cfg(feature = "registered_events")]
511 if let Some(registered_evt_q) = registered_evt_q {
512 if let Err(e) = registered_evt_q
513 .send(&RegisteredEventWithData::VirtioBalloonResize)
514 .await
515 {
516 error!("failed to send VirtioBalloonResize event: {}", e);
517 }
518 }
519 }
520 }
521
522 async fn send_adjusted_response(
523 tube: &AsyncTube,
524 num_pages: u32,
525 ) -> std::result::Result<(), base::TubeError> {
526 let num_bytes = (num_pages as u64) << VIRTIO_BALLOON_PFN_SHIFT;
527 let result = BalloonTubeResult::Adjusted { num_bytes };
528 tube.send(result).await
529 }
530
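// Requests forwarded from the command tube handler to the WS op queue handler. A WSConfig is
// serialized to the guest as a virtio_balloon_op header followed by the bin boundaries and the
// refresh/report thresholds (see handle_ws_op_queue).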
531 enum WSOp {
532 WSReport,
533 WSConfig {
534 bins: Vec<u32>,
535 refresh_threshold: u32,
536 report_threshold: u32,
537 },
538 }
539
540 async fn handle_ws_op_queue(
541 mut queue: Queue,
542 mut queue_event: EventAsync,
543 mut ws_op_rx: mpsc::Receiver<WSOp>,
544 state: Arc<AsyncRwLock<BalloonState>>,
545 mut stop_rx: oneshot::Receiver<()>,
546 ) -> Result<Queue> {
547 loop {
548 let op = select_biased! {
549 next_op = ws_op_rx.next().fuse() => {
550 match next_op {
551 Some(op) => op,
552 None => {
553 error!("ws op tube was closed");
554 break;
555 }
556 }
557 }
558 _ = stop_rx => {
559 break;
560 }
561 };
562 let mut avail_desc = queue
563 .next_async(&mut queue_event)
564 .await
565 .map_err(BalloonError::AsyncAwait)?;
566 let writer = &mut avail_desc.writer;
567
568 match op {
569 WSOp::WSReport => {
570 {
571 let mut state = state.lock().await;
572 state.expecting_ws = true;
573 }
574
575 let ws_r = virtio_balloon_op {
576 type_: VIRTIO_BALLOON_WS_OP_REQUEST.into(),
577 };
578
579 writer.write_obj(ws_r).map_err(BalloonError::WriteQueue)?;
580 }
581 WSOp::WSConfig {
582 bins,
583 refresh_threshold,
584 report_threshold,
585 } => {
586 let cmd = virtio_balloon_op {
587 type_: VIRTIO_BALLOON_WS_OP_CONFIG.into(),
588 };
589
590 writer.write_obj(cmd).map_err(BalloonError::WriteQueue)?;
591 writer
592 .write_all(bins.as_bytes())
593 .map_err(BalloonError::WriteQueue)?;
594 writer
595 .write_obj(refresh_threshold)
596 .map_err(BalloonError::WriteQueue)?;
597 writer
598 .write_obj(report_threshold)
599 .map_err(BalloonError::WriteQueue)?;
600 }
601 }
602
603 let len = writer.bytes_written() as u32;
604 queue.add_used(avail_desc, len);
605 queue.trigger_interrupt();
606 }
607
608 Ok(queue)
609 }
610
611 fn parse_balloon_ws(reader: &mut Reader) -> BalloonWS {
612 let mut ws = BalloonWS::new();
613 for res in reader.iter::<virtio_balloon_ws>() {
614 match res {
615 Ok(ws_msg) => {
616 ws_msg.update_ws(&mut ws);
617 }
618 Err(e) => {
619 error!("error while reading ws: {}", e);
620 break;
621 }
622 }
623 }
624 if ws.ws.len() < VIRTIO_BALLOON_WS_MIN_NUM_BINS || ws.ws.len() > VIRTIO_BALLOON_WS_MAX_NUM_BINS
625 {
626 error!("unexpected number of WS buckets: {}", ws.ws.len());
627 }
628 ws
629 }
630
631 // Async task that handles the WS data queue. Note that the arrival of events on
632 // the WS vq may be the result of either a WS request (WS-R) command having
633 // been sent to the guest, or an unprompted send due to memory pressure in the
634 // guest. If the data was requested, we should also send that back on the
635 // command tube.
636 async fn handle_ws_data_queue(
637 mut queue: Queue,
638 mut queue_event: EventAsync,
639 command_tube: &AsyncTube,
640 #[cfg(feature = "registered_events")] registered_evt_q: Option<&SendTubeAsync>,
641 state: Arc<AsyncRwLock<BalloonState>>,
642 mut stop_rx: oneshot::Receiver<()>,
643 ) -> Result<Queue> {
644 loop {
645 let mut avail_desc = match queue
646 .next_async_interruptable(&mut queue_event, &mut stop_rx)
647 .await
648 .map_err(BalloonError::AsyncAwait)?
649 {
650 Some(res) => res,
651 None => return Ok(queue),
652 };
653
654 let ws = parse_balloon_ws(&mut avail_desc.reader);
655
656 let mut state = state.lock().await;
657
658 // update ws report with balloon pages now that we have a lock on state
659 let balloon_actual = (state.actual_pages as u64) << VIRTIO_BALLOON_PFN_SHIFT;
660
661 if state.expecting_ws {
662 let result = BalloonTubeResult::WorkingSet { ws, balloon_actual };
663 let send_result = command_tube.send(result).await;
664 if let Err(e) = send_result {
665 error!("failed to send ws result: {}", e);
666 }
667
668 state.expecting_ws = false;
669 } else {
670 #[cfg(feature = "registered_events")]
671 if let Some(registered_evt_q) = registered_evt_q {
672 if let Err(e) = registered_evt_q
673 .send(RegisteredEventWithData::from_ws(&ws, balloon_actual))
674 .await
675 {
676 error!("failed to send VirtioBalloonWSReport event: {}", e);
677 }
678 }
679 }
680
681 queue.add_used(avail_desc, 0);
682 queue.trigger_interrupt();
683 }
684 }
685
686 // Async task that handles the command socket. The command socket handles messages from the host
687 // requesting that the guest balloon be adjusted or that guest memory statistics be reported.
688 async fn handle_command_tube(
689 command_tube: &AsyncTube,
690 interrupt: Interrupt,
691 state: Arc<AsyncRwLock<BalloonState>>,
692 mut stats_tx: mpsc::Sender<()>,
693 mut ws_op_tx: mpsc::Sender<WSOp>,
694 mut stop_rx: oneshot::Receiver<()>,
695 ) -> Result<()> {
696 loop {
697 let cmd_res = select_biased! {
698 res = command_tube.next().fuse() => res,
699 _ = stop_rx => return Ok(())
700 };
701 match cmd_res {
702 Ok(command) => match command {
703 BalloonTubeCommand::Adjust {
704 num_bytes,
705 allow_failure,
706 } => {
707 let num_pages = (num_bytes >> VIRTIO_BALLOON_PFN_SHIFT) as u32;
708 let mut state = state.lock().await;
709
710 state.num_pages = num_pages;
711 interrupt.signal_config_changed();
712
713 if allow_failure {
714 if num_pages == state.actual_pages {
715 send_adjusted_response(command_tube, num_pages)
716 .await
717 .map_err(BalloonError::SendResponse)?;
718 } else {
719 state.failable_update = true;
720 }
721 }
722 }
723 BalloonTubeCommand::WorkingSetConfig {
724 bins,
725 refresh_threshold,
726 report_threshold,
727 } => {
728 if let Err(e) = ws_op_tx.try_send(WSOp::WSConfig {
729 bins,
730 refresh_threshold,
731 report_threshold,
732 }) {
733 error!("failed to send config to ws handler: {}", e);
734 }
735 }
736 BalloonTubeCommand::Stats => {
737 if let Err(e) = stats_tx.try_send(()) {
738 error!("failed to signal the stat handler: {}", e);
739 }
740 }
741 BalloonTubeCommand::WorkingSet => {
742 if let Err(e) = ws_op_tx.try_send(WSOp::WSReport) {
743 error!("failed to send report request to ws handler: {}", e);
744 }
745 }
746 },
747 #[cfg(windows)]
748 Err(base::TubeError::Recv(e)) if e.kind() == std::io::ErrorKind::TimedOut => {
749 // On Windows, async IO tasks like the next/recv above are cancelled as the VM is
750 // shutting down. For the sake of consistency with unix, we can't *just* return
751 // here; instead, we wait for the stop request to arrive, *and then* return.
752 //
753 // The real fix is to get rid of the global unblock pool, since then we won't
754 // cancel the tasks early (b/196911556).
755 let _ = stop_rx.await;
756 return Ok(());
757 }
758 Err(e) => {
759 return Err(BalloonError::ReceivingCommand(e));
760 }
761 }
762 }
763 }
764
765 async fn handle_pending_adjusted_responses(
766 pending_adjusted_response_event: EventAsync,
767 command_tube: &AsyncTube,
768 state: Arc<AsyncRwLock<BalloonState>>,
769 ) -> Result<()> {
770 loop {
771 pending_adjusted_response_event
772 .next_val()
773 .await
774 .map_err(BalloonError::AsyncAwait)?;
775 while let Some(num_pages) = state.lock().await.pending_adjusted_responses.pop_front() {
776 send_adjusted_response(command_tube, num_pages)
777 .await
778 .map_err(BalloonError::SendResponse)?;
779 }
780 }
781 }
782
783 /// Represents queues & events for the balloon device.
784 struct BalloonQueues {
785 inflate: Queue,
786 deflate: Queue,
787 stats: Option<Queue>,
788 reporting: Option<Queue>,
789 ws_data: Option<Queue>,
790 ws_op: Option<Queue>,
791 }
792
793 impl BalloonQueues {
794 fn new(inflate: Queue, deflate: Queue) -> Self {
795 BalloonQueues {
796 inflate,
797 deflate,
798 stats: None,
799 reporting: None,
800 ws_data: None,
801 ws_op: None,
802 }
803 }
804 }
805
806 /// When the worker is stopped, the queues are preserved here.
807 struct PausedQueues {
808 inflate: Queue,
809 deflate: Queue,
810 stats: Option<Queue>,
811 reporting: Option<Queue>,
812 ws_data: Option<Queue>,
813 ws_op: Option<Queue>,
814 }
815
816 impl PausedQueues {
817 fn new(inflate: Queue, deflate: Queue) -> Self {
818 PausedQueues {
819 inflate,
820 deflate,
821 stats: None,
822 reporting: None,
823 ws_data: None,
824 ws_op: None,
825 }
826 }
827 }
828
829 fn apply_if_some<F, R>(queue_opt: Option<Queue>, mut func: F)
830 where
831 F: FnMut(Queue) -> R,
832 {
833 if let Some(queue) = queue_opt {
834 func(queue);
835 }
836 }
837
838 impl From<Box<PausedQueues>> for BTreeMap<usize, Queue> {
839 fn from(queues: Box<PausedQueues>) -> BTreeMap<usize, Queue> {
840 let mut ret = Vec::new();
841 ret.push(queues.inflate);
842 ret.push(queues.deflate);
843 apply_if_some(queues.stats, |stats| ret.push(stats));
844 apply_if_some(queues.reporting, |reporting| ret.push(reporting));
845 apply_if_some(queues.ws_data, |ws_data| ret.push(ws_data));
846 apply_if_some(queues.ws_op, |ws_op| ret.push(ws_op));
847 // WARNING: We don't use the indices from the virtio spec on purpose, see comment in
848 // get_queues_from_map for the rationale.
849 ret.into_iter().enumerate().collect()
850 }
851 }
852
853 fn free_memory(
854 vm_memory_client: &VmMemoryClient,
855 mem: &GuestMemory,
856 ranges: Vec<(GuestAddress, u64)>,
857 ) {
858 // When `--lock-guest-memory` is used, it is not possible to free the memory from the main
859 // process, so we free it from the sandboxed balloon process directly.
860 #[cfg(any(target_os = "android", target_os = "linux"))]
861 if mem.locked() {
862 for (guest_address, len) in ranges {
863 if let Err(e) = mem.remove_range(guest_address, len) {
864 warn!("Marking pages unused failed: {}, addr={}", e, guest_address);
865 }
866 }
867 return;
868 }
869 if let Err(e) = vm_memory_client.dynamically_free_memory_ranges(ranges) {
870 warn!("Failed to dynamically free memory ranges: {e:#}");
871 }
872 }
873
874 fn reclaim_memory(vm_memory_client: &VmMemoryClient, ranges: Vec<(GuestAddress, u64)>) {
875 if let Err(e) = vm_memory_client.dynamically_reclaim_memory_ranges(ranges) {
876 warn!("Failed to dynamically reclaim memory range: {e:#}");
877 }
878 }
879
880 /// Stores data from the worker when it stops so that data can be re-used when
881 /// the worker is restarted.
882 struct WorkerReturn {
883 release_memory_tube: Option<Tube>,
884 command_tube: Tube,
885 #[cfg(feature = "registered_events")]
886 registered_evt_q: Option<SendTube>,
887 paused_queues: Option<PausedQueues>,
888 vm_memory_client: VmMemoryClient,
889 }
890
891 // The main worker thread. Initializes the asynchronous worker tasks and passes them to the executor
892 // to be processed.
893 fn run_worker(
894 inflate_queue: Queue,
895 deflate_queue: Queue,
896 stats_queue: Option<Queue>,
897 reporting_queue: Option<Queue>,
898 ws_data_queue: Option<Queue>,
899 ws_op_queue: Option<Queue>,
900 command_tube: Tube,
901 vm_memory_client: VmMemoryClient,
902 mem: GuestMemory,
903 release_memory_tube: Option<Tube>,
904 interrupt: Interrupt,
905 kill_evt: Event,
906 target_reached_evt: Event,
907 pending_adjusted_response_event: Event,
908 state: Arc<AsyncRwLock<BalloonState>>,
909 #[cfg(feature = "registered_events")] registered_evt_q: Option<SendTube>,
910 ) -> WorkerReturn {
911 let ex = Executor::new().unwrap();
912 let command_tube = AsyncTube::new(&ex, command_tube).unwrap();
913 #[cfg(feature = "registered_events")]
914 let registered_evt_q_async = registered_evt_q
915 .as_ref()
916 .map(|q| SendTubeAsync::new(q.try_clone().unwrap(), &ex).unwrap());
917
918 let mut stop_queue_oneshots = Vec::new();
919
920 // We need a block to release all references to command_tube at the end before returning it.
921 let paused_queues = {
922 // The first queue is used for inflate messages
923 let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
924 let inflate_queue_evt = inflate_queue
925 .event()
926 .try_clone()
927 .expect("failed to clone queue event");
928 let inflate = handle_queue(
929 inflate_queue,
930 EventAsync::new(inflate_queue_evt, &ex).expect("failed to create async event"),
931 release_memory_tube.as_ref(),
932 |ranges| free_memory(&vm_memory_client, &mem, ranges),
933 stop_rx,
934 );
935 let inflate = inflate.fuse();
936 pin_mut!(inflate);
937
938 // The second queue is used for deflate messages
939 let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
940 let deflate_queue_evt = deflate_queue
941 .event()
942 .try_clone()
943 .expect("failed to clone queue event");
944 let deflate = handle_queue(
945 deflate_queue,
946 EventAsync::new(deflate_queue_evt, &ex).expect("failed to create async event"),
947 None,
948 |ranges| reclaim_memory(&vm_memory_client, ranges),
949 stop_rx,
950 );
951 let deflate = deflate.fuse();
952 pin_mut!(deflate);
953
954 // The next queue is used for stats messages if VIRTIO_BALLOON_F_STATS_VQ is negotiated.
955 let (stats_tx, stats_rx) = mpsc::channel::<()>(1);
956 let has_stats_queue = stats_queue.is_some();
957 let stats = if let Some(stats_queue) = stats_queue {
958 let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
959 let stats_queue_evt = stats_queue
960 .event()
961 .try_clone()
962 .expect("failed to clone queue event");
963 handle_stats_queue(
964 stats_queue,
965 EventAsync::new(stats_queue_evt, &ex).expect("failed to create async event"),
966 stats_rx,
967 &command_tube,
968 #[cfg(feature = "registered_events")]
969 registered_evt_q_async.as_ref(),
970 state.clone(),
971 stop_rx,
972 )
973 .left_future()
974 } else {
975 std::future::pending().right_future()
976 };
977 let stats = stats.fuse();
978 pin_mut!(stats);
979
980 // The next queue is used for reporting messages
981 let has_reporting_queue = reporting_queue.is_some();
982 let reporting = if let Some(reporting_queue) = reporting_queue {
983 let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
984 let reporting_queue_evt = reporting_queue
985 .event()
986 .try_clone()
987 .expect("failed to clone queue event");
988 handle_reporting_queue(
989 reporting_queue,
990 EventAsync::new(reporting_queue_evt, &ex).expect("failed to create async event"),
991 release_memory_tube.as_ref(),
992 |ranges| free_memory(&vm_memory_client, &mem, ranges),
993 stop_rx,
994 )
995 .left_future()
996 } else {
997 std::future::pending().right_future()
998 };
999 let reporting = reporting.fuse();
1000 pin_mut!(reporting);
1001
1002 // If VIRTIO_BALLOON_F_WS_REPORTING is set, two queues must be handled - one for WS data and one
1003 // for WS notifications.
1004 let has_ws_data_queue = ws_data_queue.is_some();
1005 let ws_data = if let Some(ws_data_queue) = ws_data_queue {
1006 let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
1007 let ws_data_queue_evt = ws_data_queue
1008 .event()
1009 .try_clone()
1010 .expect("failed to clone queue event");
1011 handle_ws_data_queue(
1012 ws_data_queue,
1013 EventAsync::new(ws_data_queue_evt, &ex).expect("failed to create async event"),
1014 &command_tube,
1015 #[cfg(feature = "registered_events")]
1016 registered_evt_q_async.as_ref(),
1017 state.clone(),
1018 stop_rx,
1019 )
1020 .left_future()
1021 } else {
1022 std::future::pending().right_future()
1023 };
1024 let ws_data = ws_data.fuse();
1025 pin_mut!(ws_data);
1026
1027 let (ws_op_tx, ws_op_rx) = mpsc::channel::<WSOp>(1);
1028 let has_ws_op_queue = ws_op_queue.is_some();
1029 let ws_op = if let Some(ws_op_queue) = ws_op_queue {
1030 let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
1031 let ws_op_queue_evt = ws_op_queue
1032 .event()
1033 .try_clone()
1034 .expect("failed to clone queue event");
1035 handle_ws_op_queue(
1036 ws_op_queue,
1037 EventAsync::new(ws_op_queue_evt, &ex).expect("failed to create async event"),
1038 ws_op_rx,
1039 state.clone(),
1040 stop_rx,
1041 )
1042 .left_future()
1043 } else {
1044 std::future::pending().right_future()
1045 };
1046 let ws_op = ws_op.fuse();
1047 pin_mut!(ws_op);
1048
1049 // Future to handle command messages that resize the balloon.
1050 let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
1051 let command = handle_command_tube(
1052 &command_tube,
1053 interrupt.clone(),
1054 state.clone(),
1055 stats_tx,
1056 ws_op_tx,
1057 stop_rx,
1058 );
1059 pin_mut!(command);
1060
1061 // Send a message if balloon target reached event is triggered.
1062 let target_reached = handle_target_reached(&ex, target_reached_evt, &vm_memory_client);
1063 pin_mut!(target_reached);
1064
1065 // Exit if the kill event is triggered.
1066 let kill = async_utils::await_and_exit(&ex, kill_evt);
1067 pin_mut!(kill);
1068
1069 let pending_adjusted = handle_pending_adjusted_responses(
1070 EventAsync::new(pending_adjusted_response_event, &ex)
1071 .expect("failed to create async event"),
1072 &command_tube,
1073 state,
1074 );
1075 pin_mut!(pending_adjusted);
1076
1077 let res = ex.run_until(async {
1078 select! {
1079 _ = kill.fuse() => (),
1080 _ = inflate => return Err(anyhow!("inflate stopped unexpectedly")),
1081 _ = deflate => return Err(anyhow!("deflate stopped unexpectedly")),
1082 _ = stats => return Err(anyhow!("stats stopped unexpectedly")),
1083 _ = reporting => return Err(anyhow!("reporting stopped unexpectedly")),
1084 _ = command.fuse() => return Err(anyhow!("command stopped unexpectedly")),
1085 _ = ws_op => return Err(anyhow!("ws_op stopped unexpectedly")),
1086 _ = pending_adjusted.fuse() => return Err(anyhow!("pending_adjusted stopped unexpectedly")),
1087 _ = ws_data => return Err(anyhow!("ws_data stopped unexpectedly")),
1088 _ = target_reached.fuse() => return Err(anyhow!("target_reached stopped unexpectedly")),
1089 }
1090
1091 // Worker is shutting down. To recover the queues, we have to signal
1092 // all the queue futures to exit.
1093 for stop_tx in stop_queue_oneshots {
1094 if stop_tx.send(()).is_err() {
1095 return Err(anyhow!("failed to request stop for queue future"));
1096 }
1097 }
1098
1099 // Collect all the queues (awaiting any queue future should now
1100 // return its Queue immediately).
1101 let mut paused_queues = PausedQueues::new(
1102 inflate.await,
1103 deflate.await,
1104 );
1105 if has_reporting_queue {
1106 paused_queues.reporting = Some(reporting.await);
1107 }
1108 if has_stats_queue {
1109 paused_queues.stats = Some(stats.await);
1110 }
1111 if has_ws_data_queue {
1112 paused_queues.ws_data = Some(ws_data.await.context("failed to stop ws_data queue")?);
1113 }
1114 if has_ws_op_queue {
1115 paused_queues.ws_op = Some(ws_op.await.context("failed to stop ws_op queue")?);
1116 }
1117 Ok(paused_queues)
1118 });
1119
1120 match res {
1121 Err(e) => {
1122 error!("error happened in executor: {}", e);
1123 None
1124 }
1125 Ok(main_future_res) => match main_future_res {
1126 Ok(paused_queues) => Some(paused_queues),
1127 Err(e) => {
1128 error!("error happened in main balloon future: {}", e);
1129 None
1130 }
1131 },
1132 }
1133 };
1134
1135 WorkerReturn {
1136 command_tube: command_tube.into(),
1137 paused_queues,
1138 release_memory_tube,
1139 #[cfg(feature = "registered_events")]
1140 registered_evt_q,
1141 vm_memory_client,
1142 }
1143 }
1144
1145 async fn handle_target_reached(
1146 ex: &Executor,
1147 target_reached_evt: Event,
1148 vm_memory_client: &VmMemoryClient,
1149 ) -> anyhow::Result<()> {
1150 let event_async =
1151 EventAsync::new(target_reached_evt, ex).context("failed to create EventAsync")?;
1152 loop {
1153 // Wait for target reached trigger.
1154 let _ = event_async.next_val().await;
1155 // Send the message to vm_control on the event. We don't have to read the current
1156 // size yet.
1157 if let Err(e) = vm_memory_client.balloon_target_reached(0) {
1158 warn!("Failed to send or receive allocation complete request: {e:#}");
1159 }
1160 }
1161 // The above loop will never terminate and there is no reason to terminate it either. However,
1162 // the function is used in an executor that expects a Result<> return. Make sure that clippy
1163 // doesn't enforce the unreachable_code condition.
1164 #[allow(unreachable_code)]
1165 Ok(())
1166 }
1167
1168 /// Virtio device for memory balloon inflation/deflation.
1169 pub struct Balloon {
1170 command_tube: Option<Tube>,
1171 vm_memory_client: Option<VmMemoryClient>,
1172 release_memory_tube: Option<Tube>,
1173 pending_adjusted_response_event: Event,
1174 state: Arc<AsyncRwLock<BalloonState>>,
1175 features: u64,
1176 acked_features: u64,
1177 worker_thread: Option<WorkerThread<WorkerReturn>>,
1178 #[cfg(feature = "registered_events")]
1179 registered_evt_q: Option<SendTube>,
1180 ws_num_bins: u8,
1181 target_reached_evt: Option<Event>,
1182 }
1183
1184 /// Snapshot of the [Balloon] state.
1185 #[derive(Serialize, Deserialize)]
1186 struct BalloonSnapshot {
1187 state: BalloonState,
1188 features: u64,
1189 acked_features: u64,
1190 ws_num_bins: u8,
1191 }
1192
1193 impl Balloon {
1194 /// Creates a new virtio balloon device.
1195 /// To let the balloon successfully release memory that is pinned to the host
1196 /// by CoIOMMU, the release_memory_tube is used to send the inflate ranges to
1197 /// CoIOMMU with UnpinRequest/UnpinResponse messages, so that the memory in the
1198 /// inflate ranges can be unpinned first.
1199 pub fn new(
1200 base_features: u64,
1201 command_tube: Tube,
1202 vm_memory_client: VmMemoryClient,
1203 release_memory_tube: Option<Tube>,
1204 init_balloon_size: u64,
1205 enabled_features: u64,
1206 #[cfg(feature = "registered_events")] registered_evt_q: Option<SendTube>,
1207 ws_num_bins: u8,
1208 ) -> Result<Balloon> {
1209 let features = base_features
1210 | 1 << VIRTIO_BALLOON_F_MUST_TELL_HOST
1211 | 1 << VIRTIO_BALLOON_F_STATS_VQ
1212 | 1 << VIRTIO_BALLOON_F_DEFLATE_ON_OOM
1213 | enabled_features;
1214
1215 Ok(Balloon {
1216 command_tube: Some(command_tube),
1217 vm_memory_client: Some(vm_memory_client),
1218 release_memory_tube,
1219 pending_adjusted_response_event: Event::new().map_err(BalloonError::CreatingEvent)?,
1220 state: Arc::new(AsyncRwLock::new(BalloonState {
1221 num_pages: (init_balloon_size >> VIRTIO_BALLOON_PFN_SHIFT) as u32,
1222 actual_pages: 0,
1223 failable_update: false,
1224 pending_adjusted_responses: VecDeque::new(),
1225 expecting_ws: false,
1226 })),
1227 worker_thread: None,
1228 features,
1229 acked_features: 0,
1230 #[cfg(feature = "registered_events")]
1231 registered_evt_q,
1232 ws_num_bins,
1233 target_reached_evt: None,
1234 })
1235 }
1236
1237 fn get_config(&self) -> virtio_balloon_config {
1238 let state = block_on(self.state.lock());
1239 virtio_balloon_config {
1240 num_pages: state.num_pages.into(),
1241 actual: state.actual_pages.into(),
1242 // crosvm does not (currently) use free_page_hint_cmd_id or
1243 // poison_val, but they must be present in the right order and size
1244 // for the virtio-balloon driver in the guest to deserialize the
1245 // config correctly.
1246 free_page_hint_cmd_id: 0.into(),
1247 poison_val: 0.into(),
1248 ws_num_bins: self.ws_num_bins,
1249 _reserved: [0, 0, 0],
1250 }
1251 }
1252
1253 fn stop_worker(&mut self) -> StoppedWorker<PausedQueues> {
1254 if let Some(worker_thread) = self.worker_thread.take() {
1255 let worker_ret = worker_thread.stop();
1256 self.release_memory_tube = worker_ret.release_memory_tube;
1257 self.command_tube = Some(worker_ret.command_tube);
1258 #[cfg(feature = "registered_events")]
1259 {
1260 self.registered_evt_q = worker_ret.registered_evt_q;
1261 }
1262 self.vm_memory_client = Some(worker_ret.vm_memory_client);
1263
1264 if let Some(queues) = worker_ret.paused_queues {
1265 StoppedWorker::WithQueues(Box::new(queues))
1266 } else {
1267 StoppedWorker::MissingQueues
1268 }
1269 } else {
1270 StoppedWorker::AlreadyStopped
1271 }
1272 }
1273
1274 /// Given a filtered queue vector from [VirtioDevice::activate], extract
1275 /// the queues (accounting for queues that are missing because the features
1276 /// are not negotiated) into a structure that is easier to work with.
1277 fn get_queues_from_map(
1278 &self,
1279 mut queues: BTreeMap<usize, Queue>,
1280 ) -> anyhow::Result<BalloonQueues> {
1281 fn pop_queue(
1282 queues: &mut BTreeMap<usize, Queue>,
1283 expected_index: usize,
1284 name: &str,
1285 ) -> anyhow::Result<Queue> {
1286 let (queue_index, queue) = queues
1287 .pop_first()
1288 .with_context(|| format!("missing {}", name))?;
1289
1290 if queue_index == expected_index {
1291 debug!("{name} index {queue_index}");
1292 } else {
1293 warn!("expected {name} index {expected_index}, got {queue_index}");
1294 }
1295
1296 Ok(queue)
1297 }
1298
1299 // WARNING: We use `pop_first` instead of explicitly using the indices from the virtio spec
1300 // because the Linux virtio drivers only "allocate" queue indices that are used, so queues
1301 // need to be removed in order of ascending virtqueue index.
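// (For example, a guest that negotiates page reporting but not free-page hinting may place the
// reporting queue at an index lower than REPORTING_VQ (4); pop_queue only warns about the
// mismatch.)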
1302 let inflate_queue = pop_queue(&mut queues, INFLATEQ, "inflateq")?;
1303 let deflate_queue = pop_queue(&mut queues, DEFLATEQ, "deflateq")?;
1304 let mut queue_struct = BalloonQueues::new(inflate_queue, deflate_queue);
1305
1306 if self.acked_features & (1 << VIRTIO_BALLOON_F_STATS_VQ) != 0 {
1307 queue_struct.stats = Some(pop_queue(&mut queues, STATSQ, "statsq")?);
1308 }
1309 if self.acked_features & (1 << VIRTIO_BALLOON_F_PAGE_REPORTING) != 0 {
1310 queue_struct.reporting = Some(pop_queue(&mut queues, REPORTING_VQ, "reporting_vq")?);
1311 }
1312 if self.acked_features & (1 << VIRTIO_BALLOON_F_WS_REPORTING) != 0 {
1313 queue_struct.ws_data = Some(pop_queue(&mut queues, WS_DATA_VQ, "ws_data_vq")?);
1314 queue_struct.ws_op = Some(pop_queue(&mut queues, WS_OP_VQ, "ws_op_vq")?);
1315 }
1316
1317 if !queues.is_empty() {
1318 return Err(anyhow!("unexpected queues {:?}", queues.into_keys()));
1319 }
1320
1321 Ok(queue_struct)
1322 }
1323
1324 fn start_worker(
1325 &mut self,
1326 mem: GuestMemory,
1327 interrupt: Interrupt,
1328 queues: BalloonQueues,
1329 ) -> anyhow::Result<()> {
1330 let (self_target_reached_evt, target_reached_evt) = Event::new()
1331 .and_then(|e| Ok((e.try_clone()?, e)))
1332 .context("failed to create target_reached Event pair: {}")?;
1333 self.target_reached_evt = Some(self_target_reached_evt);
1334
1335 let state = self.state.clone();
1336
1337 let command_tube = self.command_tube.take().unwrap();
1338
1339 let vm_memory_client = self.vm_memory_client.take().unwrap();
1340 let release_memory_tube = self.release_memory_tube.take();
1341 #[cfg(feature = "registered_events")]
1342 let registered_evt_q = self.registered_evt_q.take();
1343 let pending_adjusted_response_event = self
1344 .pending_adjusted_response_event
1345 .try_clone()
1346 .context("failed to clone Event")?;
1347
1348 self.worker_thread = Some(WorkerThread::start("v_balloon", move |kill_evt| {
1349 run_worker(
1350 queues.inflate,
1351 queues.deflate,
1352 queues.stats,
1353 queues.reporting,
1354 queues.ws_data,
1355 queues.ws_op,
1356 command_tube,
1357 vm_memory_client,
1358 mem,
1359 release_memory_tube,
1360 interrupt,
1361 kill_evt,
1362 target_reached_evt,
1363 pending_adjusted_response_event,
1364 state,
1365 #[cfg(feature = "registered_events")]
1366 registered_evt_q,
1367 )
1368 }));
1369
1370 Ok(())
1371 }
1372 }
1373
1374 impl VirtioDevice for Balloon {
1375 fn keep_rds(&self) -> Vec<RawDescriptor> {
1376 let mut rds = Vec::new();
1377 if let Some(command_tube) = &self.command_tube {
1378 rds.push(command_tube.as_raw_descriptor());
1379 }
1380 if let Some(vm_memory_client) = &self.vm_memory_client {
1381 rds.push(vm_memory_client.as_raw_descriptor());
1382 }
1383 if let Some(release_memory_tube) = &self.release_memory_tube {
1384 rds.push(release_memory_tube.as_raw_descriptor());
1385 }
1386 #[cfg(feature = "registered_events")]
1387 if let Some(registered_evt_q) = &self.registered_evt_q {
1388 rds.push(registered_evt_q.as_raw_descriptor());
1389 }
1390 rds.push(self.pending_adjusted_response_event.as_raw_descriptor());
1391 rds
1392 }
1393
1394 fn device_type(&self) -> DeviceType {
1395 DeviceType::Balloon
1396 }
1397
1398 fn queue_max_sizes(&self) -> &[u16] {
1399 QUEUE_SIZES
1400 }
1401
1402 fn read_config(&self, offset: u64, data: &mut [u8]) {
1403 copy_config(data, 0, self.get_config().as_bytes(), offset);
1404 }
1405
1406 fn write_config(&mut self, offset: u64, data: &[u8]) {
1407 let mut config = self.get_config();
1408 copy_config(config.as_mut_bytes(), offset, data, 0);
1409 let mut state = block_on(self.state.lock());
1410 state.actual_pages = config.actual.to_native();
1411
1412 // If balloon has updated to the requested memory, let the hypervisor know.
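// Signaling target_reached_evt wakes handle_target_reached() in the worker, which notifies
// vm_control via VmMemoryClient::balloon_target_reached().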
1413 if config.num_pages == config.actual {
1414 debug!(
1415 "sending target reached event at {}",
1416 u32::from(config.num_pages)
1417 );
1418 self.target_reached_evt.as_ref().map(|e| e.signal());
1419 }
1420 if state.failable_update && state.actual_pages == state.num_pages {
1421 state.failable_update = false;
1422 let num_pages = state.num_pages;
1423 state.pending_adjusted_responses.push_back(num_pages);
1424 let _ = self.pending_adjusted_response_event.signal();
1425 }
1426 }
1427
1428 fn features(&self) -> u64 {
1429 self.features
1430 }
1431
1432 fn ack_features(&mut self, mut value: u64) {
1433 if value & !self.features != 0 {
1434 warn!("virtio_balloon got unknown feature ack {:x}", value);
1435 value &= self.features;
1436 }
1437 self.acked_features |= value;
1438 }
1439
1440 fn activate(
1441 &mut self,
1442 mem: GuestMemory,
1443 interrupt: Interrupt,
1444 queues: BTreeMap<usize, Queue>,
1445 ) -> anyhow::Result<()> {
1446 let queues = self.get_queues_from_map(queues)?;
1447 self.start_worker(mem, interrupt, queues)
1448 }
1449
1450 fn reset(&mut self) -> anyhow::Result<()> {
1451 let _worker = self.stop_worker();
1452 Ok(())
1453 }
1454
1455 fn virtio_sleep(&mut self) -> anyhow::Result<Option<BTreeMap<usize, Queue>>> {
1456 match self.stop_worker() {
1457 StoppedWorker::WithQueues(paused_queues) => Ok(Some(paused_queues.into())),
1458 StoppedWorker::MissingQueues => {
1459 anyhow::bail!("balloon queue workers did not stop cleanly.")
1460 }
1461 StoppedWorker::AlreadyStopped => {
1462 // Device hasn't been activated.
1463 Ok(None)
1464 }
1465 }
1466 }
1467
1468 fn virtio_wake(
1469 &mut self,
1470 queues_state: Option<(GuestMemory, Interrupt, BTreeMap<usize, Queue>)>,
1471 ) -> anyhow::Result<()> {
1472 if let Some((mem, interrupt, queues)) = queues_state {
1473 if queues.len() < 2 {
1474 anyhow::bail!("{} queues were found, but an activated balloon must have at least 2 active queues.", queues.len());
1475 }
1476
1477 let balloon_queues = self.get_queues_from_map(queues)?;
1478 self.start_worker(mem, interrupt, balloon_queues)?;
1479 }
1480 Ok(())
1481 }
1482
1483 fn virtio_snapshot(&mut self) -> anyhow::Result<AnySnapshot> {
1484 let state = self
1485 .state
1486 .lock()
1487 .now_or_never()
1488 .context("failed to acquire balloon lock")?;
1489 AnySnapshot::to_any(BalloonSnapshot {
1490 features: self.features,
1491 acked_features: self.acked_features,
1492 state: state.clone(),
1493 ws_num_bins: self.ws_num_bins,
1494 })
1495 .context("failed to serialize balloon state")
1496 }
1497
1498 fn virtio_restore(&mut self, data: AnySnapshot) -> anyhow::Result<()> {
1499 let snap: BalloonSnapshot = AnySnapshot::from_any(data).context("error deserializing")?;
1500 if snap.features != self.features {
1501 anyhow::bail!(
1502 "balloon: expected features to match, but they did not. Live: {:?}, snapshot {:?}",
1503 self.features,
1504 snap.features,
1505 );
1506 }
1507
1508 let mut state = self
1509 .state
1510 .lock()
1511 .now_or_never()
1512 .context("failed to acquire balloon lock")?;
1513 *state = snap.state;
1514 self.ws_num_bins = snap.ws_num_bins;
1515 self.acked_features = snap.acked_features;
1516 Ok(())
1517 }
1518 }
1519
1520 #[cfg(test)]
1521 mod tests {
1522 use super::*;
1523 use crate::suspendable_virtio_tests;
1524 use crate::virtio::descriptor_utils::create_descriptor_chain;
1525 use crate::virtio::descriptor_utils::DescriptorType;
1526
1527 #[test]
1528 fn desc_parsing_inflate() {
1529 // Check that the memory addresses are parsed correctly by 'handle_address_chain' and passed
1530 // to the closure.
1531 let memory_start_addr = GuestAddress(0x0);
1532 let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
1533 memory
1534 .write_obj_at_addr(0x10u32, GuestAddress(0x100))
1535 .unwrap();
1536 memory
1537 .write_obj_at_addr(0xaa55aa55u32, GuestAddress(0x104))
1538 .unwrap();
1539
1540 let mut chain = create_descriptor_chain(
1541 &memory,
1542 GuestAddress(0x0),
1543 GuestAddress(0x100),
1544 vec![(DescriptorType::Readable, 8)],
1545 0,
1546 )
1547 .expect("create_descriptor_chain failed");
1548
1549 let mut addrs = Vec::new();
1550 let res = handle_address_chain(None, &mut chain, &mut |mut ranges| {
1551 addrs.append(&mut ranges)
1552 });
1553 assert!(res.is_ok());
1554 assert_eq!(addrs.len(), 2);
1555 assert_eq!(
1556 addrs[0].0,
1557 GuestAddress(0x10u64 << VIRTIO_BALLOON_PFN_SHIFT)
1558 );
1559 assert_eq!(
1560 addrs[1].0,
1561 GuestAddress(0xaa55aa55u64 << VIRTIO_BALLOON_PFN_SHIFT)
1562 );
1563 }
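
    // Companion test (not from the original file): a minimal sketch exercising
    // parse_balloon_stats() the same way desc_parsing_inflate exercises handle_address_chain,
    // assuming create_descriptor_chain places the payload at the given payload address as in
    // the test above.
    #[test]
    fn stats_parsing() {
        let memory_start_addr = GuestAddress(0x0);
        let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
        // Two packed 10-byte BalloonStat entries written back to back at 0x100.
        memory
            .write_obj_at_addr(
                BalloonStat {
                    tag: VIRTIO_BALLOON_S_MEMFREE.into(),
                    val: 1024u64.into(),
                },
                GuestAddress(0x100),
            )
            .unwrap();
        memory
            .write_obj_at_addr(
                BalloonStat {
                    tag: VIRTIO_BALLOON_S_MEMTOT.into(),
                    val: 2048u64.into(),
                },
                GuestAddress(0x10a),
            )
            .unwrap();

        let mut chain = create_descriptor_chain(
            &memory,
            GuestAddress(0x0),
            GuestAddress(0x100),
            vec![(DescriptorType::Readable, 20)],
            0,
        )
        .expect("create_descriptor_chain failed");

        let stats = parse_balloon_stats(&mut chain.reader);
        assert_eq!(stats.free_memory, Some(1024));
        assert_eq!(stats.total_memory, Some(2048));
    }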
1564
1565 struct BalloonContext {
1566 _ctrl_tube: Tube,
1567 _mem_client_tube: Tube,
1568 }
1569
1570 fn modify_device(_balloon_context: &mut BalloonContext, balloon: &mut Balloon) {
1571 balloon.ws_num_bins = !balloon.ws_num_bins;
1572 }
1573
1574 fn create_device() -> (BalloonContext, Balloon) {
1575 let (_ctrl_tube, ctrl_tube_device) = Tube::pair().unwrap();
1576 let (_mem_client_tube, mem_client_tube_device) = Tube::pair().unwrap();
1577 (
1578 BalloonContext {
1579 _ctrl_tube,
1580 _mem_client_tube,
1581 },
1582 Balloon::new(
1583 0,
1584 ctrl_tube_device,
1585 VmMemoryClient::new(mem_client_tube_device),
1586 None,
1587 1024,
1588 0,
1589 #[cfg(feature = "registered_events")]
1590 None,
1591 0,
1592 )
1593 .unwrap(),
1594 )
1595 }
1596
1597 suspendable_virtio_tests!(balloon, create_device, 2, modify_device);
1598 }
1599