1 // Copyright 2017 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 mod sys;
6
7 use std::collections::BTreeMap;
8 use std::collections::VecDeque;
9 use std::io::Write;
10 use std::sync::Arc;
11
12 use anyhow::anyhow;
13 use anyhow::Context;
14 use balloon_control::BalloonStats;
15 use balloon_control::BalloonTubeCommand;
16 use balloon_control::BalloonTubeResult;
17 use balloon_control::BalloonWS;
18 use balloon_control::WSBucket;
19 use balloon_control::VIRTIO_BALLOON_WS_MAX_NUM_BINS;
20 use balloon_control::VIRTIO_BALLOON_WS_MIN_NUM_BINS;
21 use base::debug;
22 use base::error;
23 use base::warn;
24 use base::AsRawDescriptor;
25 use base::Event;
26 use base::RawDescriptor;
27 #[cfg(feature = "registered_events")]
28 use base::SendTube;
29 use base::Tube;
30 use base::WorkerThread;
31 use cros_async::block_on;
32 use cros_async::sync::RwLock as AsyncRwLock;
33 use cros_async::AsyncTube;
34 use cros_async::EventAsync;
35 use cros_async::Executor;
36 #[cfg(feature = "registered_events")]
37 use cros_async::SendTubeAsync;
38 use data_model::Le16;
39 use data_model::Le32;
40 use data_model::Le64;
41 use futures::channel::mpsc;
42 use futures::channel::oneshot;
43 use futures::pin_mut;
44 use futures::select;
45 use futures::select_biased;
46 use futures::FutureExt;
47 use futures::StreamExt;
48 use remain::sorted;
49 use serde::Deserialize;
50 use serde::Serialize;
51 use thiserror::Error as ThisError;
52 use vm_control::api::VmMemoryClient;
53 #[cfg(feature = "registered_events")]
54 use vm_control::RegisteredEventWithData;
55 use vm_memory::GuestAddress;
56 use vm_memory::GuestMemory;
57 use zerocopy::AsBytes;
58 use zerocopy::FromBytes;
59 use zerocopy::FromZeroes;
60
61 use super::async_utils;
62 use super::copy_config;
63 use super::create_stop_oneshot;
64 use super::DescriptorChain;
65 use super::DeviceType;
66 use super::Interrupt;
67 use super::Queue;
68 use super::Reader;
69 use super::StoppedWorker;
70 use super::VirtioDevice;
71 use crate::UnpinRequest;
72 use crate::UnpinResponse;
73
/// Errors produced by the balloon device and its async worker tasks.
#[sorted]
#[derive(ThisError, Debug)]
pub enum BalloonError {
    /// Failed an async await
    #[error("failed async await: {0}")]
    AsyncAwait(cros_async::AsyncError),
    /// Failed an async await; same message as `AsyncAwait`, but wraps an `anyhow::Error`.
    #[error("failed async await: {0}")]
    AsyncAwaitAnyhow(anyhow::Error),
    /// Failed to create event.
    #[error("failed to create event: {0}")]
    CreatingEvent(base::Error),
    /// Failed to create async message receiver.
    #[error("failed to create async message receiver: {0}")]
    CreatingMessageReceiver(base::TubeError),
    /// Failed to receive command message.
    #[error("failed to receive command message: {0}")]
    ReceivingCommand(base::TubeError),
    /// Failed to send command response.
    #[error("failed to send command response: {0}")]
    SendResponse(base::TubeError),
    /// Error while writing to virtqueue
    #[error("failed to write to virtqueue: {0}")]
    WriteQueue(std::io::Error),
    /// Failed to write config event.
    #[error("failed to write config event: {0}")]
    WritingConfigEvent(base::Error),
}
/// Result alias whose error type is [`BalloonError`].
pub type Result<T> = std::result::Result<T, BalloonError>;
103
// Balloon implements six virtio queues: Inflate, Deflate, Stats, Event, WsData, WsCmd.
// Every queue uses the same size.
const QUEUE_SIZE: u16 = 128;
const QUEUE_SIZES: &[u16] = &[
    QUEUE_SIZE, QUEUE_SIZE, QUEUE_SIZE, QUEUE_SIZE, QUEUE_SIZE, QUEUE_SIZE,
];

// Shift that converts between a balloon page frame number and a byte address/size.
const VIRTIO_BALLOON_PFN_SHIFT: u32 = 12;
// Size in bytes of one balloon page (1 << VIRTIO_BALLOON_PFN_SHIFT).
const VIRTIO_BALLOON_PF_SIZE: u64 = 1 << VIRTIO_BALLOON_PFN_SHIFT;

// The feature bitmap for virtio balloon
const VIRTIO_BALLOON_F_MUST_TELL_HOST: u32 = 0; // Tell before reclaiming pages
const VIRTIO_BALLOON_F_STATS_VQ: u32 = 1; // Stats reporting enabled
const VIRTIO_BALLOON_F_DEFLATE_ON_OOM: u32 = 2; // Deflate balloon on OOM
const VIRTIO_BALLOON_F_PAGE_REPORTING: u32 = 5; // Page reporting virtqueue
// TODO(b/273973298): this should maybe be bit 6? to be changed later
const VIRTIO_BALLOON_F_WS_REPORTING: u32 = 8; // Working Set Reporting virtqueues
120
// Optional balloon features. Each discriminant is the matching
// VIRTIO_BALLOON_F_* feature bit number.
#[derive(Copy, Clone)]
#[repr(u32)]
pub enum BalloonFeatures {
    // Page Reporting enabled
    PageReporting = VIRTIO_BALLOON_F_PAGE_REPORTING,
    // WS Reporting enabled
    WSReporting = VIRTIO_BALLOON_F_WS_REPORTING,
}
130
// These feature bits are not part of the base feature list above; they come from the proposal:
// https://lists.oasis-open.org/archives/virtio-comment/202201/msg00139.html
const VIRTIO_BALLOON_F_RESPONSIVE_DEVICE: u32 = 6; // Device actively watching guest memory
const VIRTIO_BALLOON_F_EVENTS_VQ: u32 = 7; // Event vq is enabled
135
// virtio_balloon_config is the balloon device configuration space defined by the virtio spec.
// Multi-byte fields are stored little-endian (`Le32`).
#[derive(Copy, Clone, Debug, Default, AsBytes, FromZeroes, FromBytes)]
#[repr(C)]
struct virtio_balloon_config {
    num_pages: Le32,
    actual: Le32,
    free_page_hint_cmd_id: Le32,
    poison_val: Le32,
    // WS field is part of proposed spec extension (b/273973298).
    ws_num_bins: u8,
    // Explicit padding after the single-byte `ws_num_bins` field.
    _reserved: [u8; 3],
}
148
// BalloonState is shared by the worker and device thread.
#[derive(Clone, Default, Serialize, Deserialize)]
struct BalloonState {
    // Target balloon size in pages; written when an Adjust command is handled.
    num_pages: u32,
    // Balloon size in pages that is reported back (shifted to bytes) in stats and
    // working-set results.
    actual_pages: u32,
    // Set when a working-set report request is written to the guest, and cleared once
    // the resulting report has been sent on the command tube.
    expecting_ws: bool,
    // Flag indicating that the balloon is in the process of a failable update. This
    // is set by an Adjust command that has allow_failure set, and is cleared when the
    // Adjusted success/failure response is sent.
    failable_update: bool,
    // Page counts waiting to be sent as Adjusted responses; drained front-to-back by
    // `handle_pending_adjusted_responses`.
    pending_adjusted_responses: VecDeque<u32>,
}
161
// The constants defining stats types in virtio_balloon_stat. Each one is matched against
// `BalloonStat::tag` to select the `BalloonStats` field that receives the value.
const VIRTIO_BALLOON_S_SWAP_IN: u16 = 0;
const VIRTIO_BALLOON_S_SWAP_OUT: u16 = 1;
const VIRTIO_BALLOON_S_MAJFLT: u16 = 2;
const VIRTIO_BALLOON_S_MINFLT: u16 = 3;
const VIRTIO_BALLOON_S_MEMFREE: u16 = 4;
const VIRTIO_BALLOON_S_MEMTOT: u16 = 5;
const VIRTIO_BALLOON_S_AVAIL: u16 = 6;
const VIRTIO_BALLOON_S_CACHES: u16 = 7;
const VIRTIO_BALLOON_S_HTLB_PGALLOC: u16 = 8;
const VIRTIO_BALLOON_S_HTLB_PGFAIL: u16 = 9;
const VIRTIO_BALLOON_S_NONSTANDARD_SHMEM: u16 = 65534;
const VIRTIO_BALLOON_S_NONSTANDARD_UNEVICTABLE: u16 = 65535;
175
// BalloonStat is used to deserialize stats from the stats_queue.
// It is a packed (tag, value) pair, so the 64-bit value directly follows the 16-bit tag.
#[derive(Copy, Clone, FromZeroes, FromBytes, AsBytes)]
#[repr(C, packed)]
struct BalloonStat {
    // One of the VIRTIO_BALLOON_S_* constants.
    tag: Le16,
    // Value reported for that tag.
    val: Le64,
}
183
184 impl BalloonStat {
update_stats(&self, stats: &mut BalloonStats)185 fn update_stats(&self, stats: &mut BalloonStats) {
186 let val = Some(self.val.to_native());
187 match self.tag.to_native() {
188 VIRTIO_BALLOON_S_SWAP_IN => stats.swap_in = val,
189 VIRTIO_BALLOON_S_SWAP_OUT => stats.swap_out = val,
190 VIRTIO_BALLOON_S_MAJFLT => stats.major_faults = val,
191 VIRTIO_BALLOON_S_MINFLT => stats.minor_faults = val,
192 VIRTIO_BALLOON_S_MEMFREE => stats.free_memory = val,
193 VIRTIO_BALLOON_S_MEMTOT => stats.total_memory = val,
194 VIRTIO_BALLOON_S_AVAIL => stats.available_memory = val,
195 VIRTIO_BALLOON_S_CACHES => stats.disk_caches = val,
196 VIRTIO_BALLOON_S_HTLB_PGALLOC => stats.hugetlb_allocations = val,
197 VIRTIO_BALLOON_S_HTLB_PGFAIL => stats.hugetlb_failures = val,
198 VIRTIO_BALLOON_S_NONSTANDARD_SHMEM => stats.shared_memory = val,
199 VIRTIO_BALLOON_S_NONSTANDARD_UNEVICTABLE => stats.unevictable_memory = val,
200 _ => (),
201 }
202 }
203 }
204
// Event types carried in `virtio_balloon_event_header::evt_type`.
const VIRTIO_BALLOON_EVENT_PRESSURE: u32 = 1;
const VIRTIO_BALLOON_EVENT_PUFF_FAILURE: u32 = 2;
207
// Header read from each buffer on the events queue; `evt_type` is compared against the
// VIRTIO_BALLOON_EVENT_* constants.
#[repr(C)]
#[derive(Copy, Clone, Default, AsBytes, FromZeroes, FromBytes)]
struct virtio_balloon_event_header {
    evt_type: Le32,
}
213
// virtio_balloon_ws is used to deserialize from the ws data vq.
// Each entry becomes one `WSBucket` (see `update_ws`).
#[repr(C)]
#[derive(Copy, Clone, Debug, Default, AsBytes, FromZeroes, FromBytes)]
struct virtio_balloon_ws {
    tag: Le16,
    node_id: Le16,
    // virtio prefers field members to align on a word boundary so we must pad. see:
    // https://crsrc.org/o/src/third_party/kernel/v5.15/include/uapi/linux/virtio_balloon.h;l=105
    _reserved: [u8; 4],
    // Copied into `WSBucket::age`.
    idle_age_ms: Le64,
    // TODO(b/273973298): these should become separate fields - bytes for ANON and FILE
    memory_size_bytes: [Le64; 2],
}
227
228 impl virtio_balloon_ws {
update_ws(&self, ws: &mut BalloonWS)229 fn update_ws(&self, ws: &mut BalloonWS) {
230 let bucket = WSBucket {
231 age: self.idle_age_ms.to_native(),
232 bytes: [
233 self.memory_size_bytes[0].to_native(),
234 self.memory_size_bytes[1].to_native(),
235 ],
236 };
237 ws.ws.push(bucket);
238 }
239 }
240
// Operation codes written into `virtio_balloon_op::type_`. The underscore-prefixed
// codes are not referenced in the code visible here.
const _VIRTIO_BALLOON_WS_OP_INVALID: u16 = 0;
const VIRTIO_BALLOON_WS_OP_REQUEST: u16 = 1;
const VIRTIO_BALLOON_WS_OP_CONFIG: u16 = 2;
const _VIRTIO_BALLOON_WS_OP_DISCARD: u16 = 3;
245
// virtio_balloon_op is used to serialize to the ws cmd vq.
// `type_` holds one of the VIRTIO_BALLOON_WS_OP_* codes.
#[repr(C, packed)]
#[derive(Copy, Clone, Debug, Default, AsBytes, FromZeroes, FromBytes)]
struct virtio_balloon_op {
    type_: Le16,
}
252
invoke_desc_handler<F>(ranges: Vec<(u64, u64)>, desc_handler: &mut F) where F: FnMut(GuestAddress, u64),253 fn invoke_desc_handler<F>(ranges: Vec<(u64, u64)>, desc_handler: &mut F)
254 where
255 F: FnMut(GuestAddress, u64),
256 {
257 for range in ranges {
258 desc_handler(GuestAddress(range.0), range.1);
259 }
260 }
261
262 // Release a list of guest memory ranges back to the host system.
263 // Unpin requests for each inflate range will be sent via `release_memory_tube`
264 // if provided, and then `desc_handler` will be called for each inflate range.
release_ranges<F>( release_memory_tube: Option<&Tube>, inflate_ranges: Vec<(u64, u64)>, desc_handler: &mut F, ) -> anyhow::Result<()> where F: FnMut(GuestAddress, u64),265 fn release_ranges<F>(
266 release_memory_tube: Option<&Tube>,
267 inflate_ranges: Vec<(u64, u64)>,
268 desc_handler: &mut F,
269 ) -> anyhow::Result<()>
270 where
271 F: FnMut(GuestAddress, u64),
272 {
273 if let Some(tube) = release_memory_tube {
274 let unpin_ranges = inflate_ranges
275 .iter()
276 .map(|v| {
277 (
278 v.0 >> VIRTIO_BALLOON_PFN_SHIFT,
279 v.1 / VIRTIO_BALLOON_PF_SIZE,
280 )
281 })
282 .collect();
283 let req = UnpinRequest {
284 ranges: unpin_ranges,
285 };
286 if let Err(e) = tube.send(&req) {
287 error!("failed to send unpin request: {}", e);
288 } else {
289 match tube.recv() {
290 Ok(resp) => match resp {
291 UnpinResponse::Success => invoke_desc_handler(inflate_ranges, desc_handler),
292 UnpinResponse::Failed => error!("failed to handle unpin request"),
293 },
294 Err(e) => error!("failed to handle get unpin response: {}", e),
295 }
296 }
297 } else {
298 invoke_desc_handler(inflate_ranges, desc_handler);
299 }
300
301 Ok(())
302 }
303
304 // Processes one message's list of addresses.
handle_address_chain<F>( release_memory_tube: Option<&Tube>, avail_desc: &mut DescriptorChain, desc_handler: &mut F, ) -> anyhow::Result<()> where F: FnMut(GuestAddress, u64),305 fn handle_address_chain<F>(
306 release_memory_tube: Option<&Tube>,
307 avail_desc: &mut DescriptorChain,
308 desc_handler: &mut F,
309 ) -> anyhow::Result<()>
310 where
311 F: FnMut(GuestAddress, u64),
312 {
313 // In a long-running system, there is no reason to expect that
314 // a significant number of freed pages are consecutive. However,
315 // batching is relatively simple and can result in significant
316 // gains in a newly booted system, so it's worth attempting.
317 let mut range_start = 0;
318 let mut range_size = 0;
319 let mut inflate_ranges: Vec<(u64, u64)> = Vec::new();
320 for res in avail_desc.reader.iter::<Le32>() {
321 let pfn = match res {
322 Ok(pfn) => pfn,
323 Err(e) => {
324 error!("error while reading unused pages: {}", e);
325 break;
326 }
327 };
328 let guest_address = (u64::from(pfn.to_native())) << VIRTIO_BALLOON_PFN_SHIFT;
329 if range_start + range_size == guest_address {
330 range_size += VIRTIO_BALLOON_PF_SIZE;
331 } else if range_start == guest_address + VIRTIO_BALLOON_PF_SIZE {
332 range_start = guest_address;
333 range_size += VIRTIO_BALLOON_PF_SIZE;
334 } else {
335 // Discontinuity, so flush the previous range. Note range_size
336 // will be 0 on the first iteration, so skip that.
337 if range_size != 0 {
338 inflate_ranges.push((range_start, range_size));
339 }
340 range_start = guest_address;
341 range_size = VIRTIO_BALLOON_PF_SIZE;
342 }
343 }
344 if range_size != 0 {
345 inflate_ranges.push((range_start, range_size));
346 }
347
348 release_ranges(release_memory_tube, inflate_ranges, desc_handler)
349 }
350
351 // Async task that handles the main balloon inflate and deflate queues.
handle_queue<F>( mut queue: Queue, mut queue_event: EventAsync, release_memory_tube: Option<&Tube>, interrupt: Interrupt, mut desc_handler: F, mut stop_rx: oneshot::Receiver<()>, ) -> Queue where F: FnMut(GuestAddress, u64),352 async fn handle_queue<F>(
353 mut queue: Queue,
354 mut queue_event: EventAsync,
355 release_memory_tube: Option<&Tube>,
356 interrupt: Interrupt,
357 mut desc_handler: F,
358 mut stop_rx: oneshot::Receiver<()>,
359 ) -> Queue
360 where
361 F: FnMut(GuestAddress, u64),
362 {
363 loop {
364 let mut avail_desc = match queue
365 .next_async_interruptable(&mut queue_event, &mut stop_rx)
366 .await
367 {
368 Ok(Some(res)) => res,
369 Ok(None) => return queue,
370 Err(e) => {
371 error!("Failed to read descriptor {}", e);
372 return queue;
373 }
374 };
375 if let Err(e) =
376 handle_address_chain(release_memory_tube, &mut avail_desc, &mut desc_handler)
377 {
378 error!("balloon: failed to process inflate addresses: {}", e);
379 }
380 queue.add_used(avail_desc, 0);
381 queue.trigger_interrupt(&interrupt);
382 }
383 }
384
385 // Processes one page-reporting descriptor.
handle_reported_buffer<F>( release_memory_tube: Option<&Tube>, avail_desc: &DescriptorChain, desc_handler: &mut F, ) -> anyhow::Result<()> where F: FnMut(GuestAddress, u64),386 fn handle_reported_buffer<F>(
387 release_memory_tube: Option<&Tube>,
388 avail_desc: &DescriptorChain,
389 desc_handler: &mut F,
390 ) -> anyhow::Result<()>
391 where
392 F: FnMut(GuestAddress, u64),
393 {
394 let reported_ranges: Vec<(u64, u64)> = avail_desc
395 .reader
396 .get_remaining_regions()
397 .chain(avail_desc.writer.get_remaining_regions())
398 .map(|r| (r.offset, r.len as u64))
399 .collect();
400
401 release_ranges(release_memory_tube, reported_ranges, desc_handler)
402 }
403
404 // Async task that handles the page reporting queue.
handle_reporting_queue<F>( mut queue: Queue, mut queue_event: EventAsync, release_memory_tube: Option<&Tube>, interrupt: Interrupt, mut desc_handler: F, mut stop_rx: oneshot::Receiver<()>, ) -> Queue where F: FnMut(GuestAddress, u64),405 async fn handle_reporting_queue<F>(
406 mut queue: Queue,
407 mut queue_event: EventAsync,
408 release_memory_tube: Option<&Tube>,
409 interrupt: Interrupt,
410 mut desc_handler: F,
411 mut stop_rx: oneshot::Receiver<()>,
412 ) -> Queue
413 where
414 F: FnMut(GuestAddress, u64),
415 {
416 loop {
417 let avail_desc = match queue
418 .next_async_interruptable(&mut queue_event, &mut stop_rx)
419 .await
420 {
421 Ok(Some(res)) => res,
422 Ok(None) => return queue,
423 Err(e) => {
424 error!("Failed to read descriptor {}", e);
425 return queue;
426 }
427 };
428 if let Err(e) = handle_reported_buffer(release_memory_tube, &avail_desc, &mut desc_handler)
429 {
430 error!("balloon: failed to process reported buffer: {}", e);
431 }
432 queue.add_used(avail_desc, 0);
433 queue.trigger_interrupt(&interrupt);
434 }
435 }
436
parse_balloon_stats(reader: &mut Reader) -> BalloonStats437 fn parse_balloon_stats(reader: &mut Reader) -> BalloonStats {
438 let mut stats: BalloonStats = Default::default();
439 for res in reader.iter::<BalloonStat>() {
440 match res {
441 Ok(stat) => stat.update_stats(&mut stats),
442 Err(e) => {
443 error!("error while reading stats: {}", e);
444 break;
445 }
446 };
447 }
448 stats
449 }
450
451 // Async task that handles the stats queue. Note that the cadence of this is driven by requests for
452 // balloon stats from the control pipe.
453 // The guests queues an initial buffer on boot, which is read and then this future will block until
454 // signaled from the command socket that stats should be collected again.
handle_stats_queue( mut queue: Queue, mut queue_event: EventAsync, mut stats_rx: mpsc::Receiver<()>, command_tube: &AsyncTube, #[cfg(feature = "registered_events")] registered_evt_q: Option<&SendTubeAsync>, state: Arc<AsyncRwLock<BalloonState>>, interrupt: Interrupt, mut stop_rx: oneshot::Receiver<()>, ) -> Queue455 async fn handle_stats_queue(
456 mut queue: Queue,
457 mut queue_event: EventAsync,
458 mut stats_rx: mpsc::Receiver<()>,
459 command_tube: &AsyncTube,
460 #[cfg(feature = "registered_events")] registered_evt_q: Option<&SendTubeAsync>,
461 state: Arc<AsyncRwLock<BalloonState>>,
462 interrupt: Interrupt,
463 mut stop_rx: oneshot::Receiver<()>,
464 ) -> Queue {
465 let mut avail_desc = match queue
466 .next_async_interruptable(&mut queue_event, &mut stop_rx)
467 .await
468 {
469 // Consume the first stats buffer sent from the guest at startup. It was not
470 // requested by anyone, and the stats are stale.
471 Ok(Some(res)) => res,
472 Ok(None) => return queue,
473 Err(e) => {
474 error!("Failed to read descriptor {}", e);
475 return queue;
476 }
477 };
478
479 loop {
480 select_biased! {
481 msg = stats_rx.next() => {
482 // Wait for a request to read the stats.
483 match msg {
484 Some(()) => (),
485 None => {
486 error!("stats signal channel was closed");
487 return queue;
488 }
489 }
490 }
491 _ = stop_rx => return queue,
492 };
493
494 // Request a new stats_desc to the guest.
495 queue.add_used(avail_desc, 0);
496 queue.trigger_interrupt(&interrupt);
497
498 avail_desc = match queue.next_async(&mut queue_event).await {
499 Err(e) => {
500 error!("Failed to read descriptor {}", e);
501 return queue;
502 }
503 Ok(d) => d,
504 };
505 let stats = parse_balloon_stats(&mut avail_desc.reader);
506
507 let actual_pages = state.lock().await.actual_pages as u64;
508 let result = BalloonTubeResult::Stats {
509 balloon_actual: actual_pages << VIRTIO_BALLOON_PFN_SHIFT,
510 stats,
511 };
512 let send_result = command_tube.send(result).await;
513 if let Err(e) = send_result {
514 error!("failed to send stats result: {}", e);
515 }
516
517 #[cfg(feature = "registered_events")]
518 if let Some(registered_evt_q) = registered_evt_q {
519 if let Err(e) = registered_evt_q
520 .send(&RegisteredEventWithData::VirtioBalloonResize)
521 .await
522 {
523 error!("failed to send VirtioBalloonResize event: {}", e);
524 }
525 }
526 }
527 }
528
send_adjusted_response( tube: &AsyncTube, num_pages: u32, ) -> std::result::Result<(), base::TubeError>529 async fn send_adjusted_response(
530 tube: &AsyncTube,
531 num_pages: u32,
532 ) -> std::result::Result<(), base::TubeError> {
533 let num_bytes = (num_pages as u64) << VIRTIO_BALLOON_PFN_SHIFT;
534 let result = BalloonTubeResult::Adjusted { num_bytes };
535 tube.send(result).await
536 }
537
handle_event( state: Arc<AsyncRwLock<BalloonState>>, interrupt: Interrupt, r: &mut Reader, command_tube: &AsyncTube, ) -> Result<()>538 async fn handle_event(
539 state: Arc<AsyncRwLock<BalloonState>>,
540 interrupt: Interrupt,
541 r: &mut Reader,
542 command_tube: &AsyncTube,
543 ) -> Result<()> {
544 match r.read_obj::<virtio_balloon_event_header>() {
545 Ok(hdr) => match hdr.evt_type.to_native() {
546 VIRTIO_BALLOON_EVENT_PRESSURE => {
547 // TODO(b/213962590): See how this can be integrated this into memory rebalancing
548 }
549 VIRTIO_BALLOON_EVENT_PUFF_FAILURE => {
550 let mut state = state.lock().await;
551 if state.failable_update {
552 state.num_pages = state.actual_pages;
553 interrupt.signal_config_changed();
554
555 state.failable_update = false;
556 send_adjusted_response(command_tube, state.actual_pages)
557 .await
558 .map_err(BalloonError::SendResponse)?;
559 }
560 }
561 _ => {
562 warn!("Unknown event {}", hdr.evt_type.to_native());
563 }
564 },
565 Err(e) => error!("Failed to parse event header {:?}", e),
566 }
567 Ok(())
568 }
569
570 // Async task that handles the events queue.
handle_events_queue( mut queue: Queue, mut queue_event: EventAsync, state: Arc<AsyncRwLock<BalloonState>>, interrupt: Interrupt, command_tube: &AsyncTube, mut stop_rx: oneshot::Receiver<()>, ) -> Result<Queue>571 async fn handle_events_queue(
572 mut queue: Queue,
573 mut queue_event: EventAsync,
574 state: Arc<AsyncRwLock<BalloonState>>,
575 interrupt: Interrupt,
576 command_tube: &AsyncTube,
577 mut stop_rx: oneshot::Receiver<()>,
578 ) -> Result<Queue> {
579 while let Some(mut avail_desc) = queue
580 .next_async_interruptable(&mut queue_event, &mut stop_rx)
581 .await
582 .map_err(BalloonError::AsyncAwait)?
583 {
584 handle_event(
585 state.clone(),
586 interrupt.clone(),
587 &mut avail_desc.reader,
588 command_tube,
589 )
590 .await?;
591
592 queue.add_used(avail_desc, 0);
593 queue.trigger_interrupt(&interrupt);
594 }
595 Ok(queue)
596 }
597
// Working-set operations passed from the command-tube task to the WS op queue task.
enum WSOp {
    // Ask the guest for a working-set report.
    WSReport,
    // Send a working-set configuration; the fields are written to the guest after the
    // op header, in the order bins, refresh_threshold, report_threshold.
    WSConfig {
        bins: Vec<u32>,
        refresh_threshold: u32,
        report_threshold: u32,
    },
}
606
handle_ws_op_queue( mut queue: Queue, mut queue_event: EventAsync, mut ws_op_rx: mpsc::Receiver<WSOp>, state: Arc<AsyncRwLock<BalloonState>>, interrupt: Interrupt, mut stop_rx: oneshot::Receiver<()>, ) -> Result<Queue>607 async fn handle_ws_op_queue(
608 mut queue: Queue,
609 mut queue_event: EventAsync,
610 mut ws_op_rx: mpsc::Receiver<WSOp>,
611 state: Arc<AsyncRwLock<BalloonState>>,
612 interrupt: Interrupt,
613 mut stop_rx: oneshot::Receiver<()>,
614 ) -> Result<Queue> {
615 loop {
616 let op = select_biased! {
617 next_op = ws_op_rx.next().fuse() => {
618 match next_op {
619 Some(op) => op,
620 None => {
621 error!("ws op tube was closed");
622 break;
623 }
624 }
625 }
626 _ = stop_rx => {
627 break;
628 }
629 };
630 let mut avail_desc = queue
631 .next_async(&mut queue_event)
632 .await
633 .map_err(BalloonError::AsyncAwait)?;
634 let writer = &mut avail_desc.writer;
635
636 match op {
637 WSOp::WSReport => {
638 {
639 let mut state = state.lock().await;
640 state.expecting_ws = true;
641 }
642
643 let ws_r = virtio_balloon_op {
644 type_: VIRTIO_BALLOON_WS_OP_REQUEST.into(),
645 };
646
647 writer.write_obj(ws_r).map_err(BalloonError::WriteQueue)?;
648 }
649 WSOp::WSConfig {
650 bins,
651 refresh_threshold,
652 report_threshold,
653 } => {
654 let cmd = virtio_balloon_op {
655 type_: VIRTIO_BALLOON_WS_OP_CONFIG.into(),
656 };
657
658 writer.write_obj(cmd).map_err(BalloonError::WriteQueue)?;
659 writer
660 .write_all(bins.as_bytes())
661 .map_err(BalloonError::WriteQueue)?;
662 writer
663 .write_obj(refresh_threshold)
664 .map_err(BalloonError::WriteQueue)?;
665 writer
666 .write_obj(report_threshold)
667 .map_err(BalloonError::WriteQueue)?;
668 }
669 }
670
671 let len = writer.bytes_written() as u32;
672 queue.add_used(avail_desc, len);
673 queue.trigger_interrupt(&interrupt);
674 }
675
676 Ok(queue)
677 }
678
parse_balloon_ws(reader: &mut Reader) -> BalloonWS679 fn parse_balloon_ws(reader: &mut Reader) -> BalloonWS {
680 let mut ws = BalloonWS::new();
681 for res in reader.iter::<virtio_balloon_ws>() {
682 match res {
683 Ok(ws_msg) => {
684 ws_msg.update_ws(&mut ws);
685 }
686 Err(e) => {
687 error!("error while reading ws: {}", e);
688 break;
689 }
690 }
691 }
692 if ws.ws.len() < VIRTIO_BALLOON_WS_MIN_NUM_BINS || ws.ws.len() > VIRTIO_BALLOON_WS_MAX_NUM_BINS
693 {
694 error!("unexpected number of WS buckets: {}", ws.ws.len());
695 }
696 ws
697 }
698
699 // Async task that handles the stats queue. Note that the arrival of events on
700 // the WS vq may be the result of either a WS request (WS-R) command having
701 // been sent to the guest, or an unprompted send due to memory pressue in the
702 // guest. If the data was requested, we should also send that back on the
703 // command tube.
handle_ws_data_queue( mut queue: Queue, mut queue_event: EventAsync, command_tube: &AsyncTube, #[cfg(feature = "registered_events")] registered_evt_q: Option<&SendTubeAsync>, state: Arc<AsyncRwLock<BalloonState>>, interrupt: Interrupt, mut stop_rx: oneshot::Receiver<()>, ) -> Result<Queue>704 async fn handle_ws_data_queue(
705 mut queue: Queue,
706 mut queue_event: EventAsync,
707 command_tube: &AsyncTube,
708 #[cfg(feature = "registered_events")] registered_evt_q: Option<&SendTubeAsync>,
709 state: Arc<AsyncRwLock<BalloonState>>,
710 interrupt: Interrupt,
711 mut stop_rx: oneshot::Receiver<()>,
712 ) -> Result<Queue> {
713 loop {
714 let mut avail_desc = match queue
715 .next_async_interruptable(&mut queue_event, &mut stop_rx)
716 .await
717 .map_err(BalloonError::AsyncAwait)?
718 {
719 Some(res) => res,
720 None => return Ok(queue),
721 };
722
723 let ws = parse_balloon_ws(&mut avail_desc.reader);
724
725 let mut state = state.lock().await;
726
727 // update ws report with balloon pages now that we have a lock on state
728 let balloon_actual = (state.actual_pages as u64) << VIRTIO_BALLOON_PFN_SHIFT;
729
730 if state.expecting_ws {
731 let result = BalloonTubeResult::WorkingSet { ws, balloon_actual };
732 let send_result = command_tube.send(result).await;
733 if let Err(e) = send_result {
734 error!("failed to send ws result: {}", e);
735 }
736
737 state.expecting_ws = false;
738 } else {
739 #[cfg(feature = "registered_events")]
740 if let Some(registered_evt_q) = registered_evt_q {
741 if let Err(e) = registered_evt_q
742 .send(RegisteredEventWithData::from_ws(&ws, balloon_actual))
743 .await
744 {
745 error!("failed to send VirtioBalloonWSReport event: {}", e);
746 }
747 }
748 }
749
750 queue.add_used(avail_desc, 0);
751 queue.trigger_interrupt(&interrupt);
752 }
753 }
754
755 // Async task that handles the command socket. The command socket handles messages from the host
756 // requesting that the guest balloon be adjusted or to report guest memory statistics.
handle_command_tube( command_tube: &AsyncTube, interrupt: Interrupt, state: Arc<AsyncRwLock<BalloonState>>, mut stats_tx: mpsc::Sender<()>, mut ws_op_tx: mpsc::Sender<WSOp>, mut stop_rx: oneshot::Receiver<()>, ) -> Result<()>757 async fn handle_command_tube(
758 command_tube: &AsyncTube,
759 interrupt: Interrupt,
760 state: Arc<AsyncRwLock<BalloonState>>,
761 mut stats_tx: mpsc::Sender<()>,
762 mut ws_op_tx: mpsc::Sender<WSOp>,
763 mut stop_rx: oneshot::Receiver<()>,
764 ) -> Result<()> {
765 loop {
766 let cmd_res = select_biased! {
767 res = command_tube.next().fuse() => res,
768 _ = stop_rx => return Ok(())
769 };
770 match cmd_res {
771 Ok(command) => match command {
772 BalloonTubeCommand::Adjust {
773 num_bytes,
774 allow_failure,
775 } => {
776 let num_pages = (num_bytes >> VIRTIO_BALLOON_PFN_SHIFT) as u32;
777 let mut state = state.lock().await;
778
779 state.num_pages = num_pages;
780 interrupt.signal_config_changed();
781
782 if allow_failure {
783 if num_pages == state.actual_pages {
784 send_adjusted_response(command_tube, num_pages)
785 .await
786 .map_err(BalloonError::SendResponse)?;
787 } else {
788 state.failable_update = true;
789 }
790 }
791 }
792 BalloonTubeCommand::WorkingSetConfig {
793 bins,
794 refresh_threshold,
795 report_threshold,
796 } => {
797 if let Err(e) = ws_op_tx.try_send(WSOp::WSConfig {
798 bins,
799 refresh_threshold,
800 report_threshold,
801 }) {
802 error!("failed to send config to ws handler: {}", e);
803 }
804 }
805 BalloonTubeCommand::Stats => {
806 if let Err(e) = stats_tx.try_send(()) {
807 error!("failed to signal the stat handler: {}", e);
808 }
809 }
810 BalloonTubeCommand::WorkingSet => {
811 if let Err(e) = ws_op_tx.try_send(WSOp::WSReport) {
812 error!("failed to send report request to ws handler: {}", e);
813 }
814 }
815 },
816 #[cfg(windows)]
817 Err(base::TubeError::Recv(e)) if e.kind() == std::io::ErrorKind::TimedOut => {
818 // On Windows, async IO tasks like the next/recv above are cancelled as the VM is
819 // shutting down. For the sake of consistency with unix, we can't *just* return
820 // here; instead, we wait for the stop request to arrive, *and then* return.
821 //
822 // The real fix is to get rid of the global unblock pool, since then we won't
823 // cancel the tasks early (b/196911556).
824 let _ = stop_rx.await;
825 return Ok(());
826 }
827 Err(e) => {
828 return Err(BalloonError::ReceivingCommand(e));
829 }
830 }
831 }
832 }
833
handle_pending_adjusted_responses( pending_adjusted_response_event: EventAsync, command_tube: &AsyncTube, state: Arc<AsyncRwLock<BalloonState>>, ) -> Result<()>834 async fn handle_pending_adjusted_responses(
835 pending_adjusted_response_event: EventAsync,
836 command_tube: &AsyncTube,
837 state: Arc<AsyncRwLock<BalloonState>>,
838 ) -> Result<()> {
839 loop {
840 pending_adjusted_response_event
841 .next_val()
842 .await
843 .map_err(BalloonError::AsyncAwait)?;
844 while let Some(num_pages) = state.lock().await.pending_adjusted_responses.pop_front() {
845 send_adjusted_response(command_tube, num_pages)
846 .await
847 .map_err(BalloonError::SendResponse)?;
848 }
849 }
850 }
851
852 /// Represents queues & events for the balloon device.
853 struct BalloonQueues {
854 inflate: Queue,
855 deflate: Queue,
856 stats: Option<Queue>,
857 reporting: Option<Queue>,
858 events: Option<Queue>,
859 ws: (Option<Queue>, Option<Queue>),
860 }
861
862 impl BalloonQueues {
new(inflate: Queue, deflate: Queue) -> Self863 fn new(inflate: Queue, deflate: Queue) -> Self {
864 BalloonQueues {
865 inflate,
866 deflate,
867 stats: None,
868 reporting: None,
869 events: None,
870 ws: (None, None),
871 }
872 }
873 }
874
875 /// When the worker is stopped, the queues are preserved here.
876 struct PausedQueues {
877 inflate: Queue,
878 deflate: Queue,
879 stats: Option<Queue>,
880 reporting: Option<Queue>,
881 events: Option<Queue>,
882 ws: (Option<Queue>, Option<Queue>),
883 }
884
885 impl PausedQueues {
new(inflate: Queue, deflate: Queue) -> Self886 fn new(inflate: Queue, deflate: Queue) -> Self {
887 PausedQueues {
888 inflate,
889 deflate,
890 stats: None,
891 reporting: None,
892 events: None,
893 ws: (None, None),
894 }
895 }
896 }
897
apply_if_some<F, R>(queue_opt: Option<Queue>, mut func: F) where F: FnMut(Queue) -> R,898 fn apply_if_some<F, R>(queue_opt: Option<Queue>, mut func: F)
899 where
900 F: FnMut(Queue) -> R,
901 {
902 if let Some(queue) = queue_opt {
903 func(queue);
904 }
905 }
906
907 impl From<Box<PausedQueues>> for BTreeMap<usize, Queue> {
from(queues: Box<PausedQueues>) -> BTreeMap<usize, Queue>908 fn from(queues: Box<PausedQueues>) -> BTreeMap<usize, Queue> {
909 let mut ret = Vec::new();
910 ret.push(queues.inflate);
911 ret.push(queues.deflate);
912 apply_if_some(queues.stats, |stats| ret.push(stats));
913 apply_if_some(queues.reporting, |reporting| ret.push(reporting));
914 apply_if_some(queues.events, |events| ret.push(events));
915 apply_if_some(queues.ws.0, |ws_data| ret.push(ws_data));
916 apply_if_some(queues.ws.1, |ws_op| ret.push(ws_op));
917 // WARNING: We don't use the indices from the virito spec on purpose, see comment in
918 // get_queues_from_map for the rationale.
919 ret.into_iter().enumerate().collect()
920 }
921 }
922
/// Stores data from the worker when it stops so that data can be re-used when
/// the worker is restarted.
struct WorkerReturn {
    release_memory_tube: Option<Tube>,
    command_tube: Tube,
    // Only present when the "registered_events" feature is enabled.
    #[cfg(feature = "registered_events")]
    registered_evt_q: Option<SendTube>,
    paused_queues: Option<PausedQueues>,
    vm_memory_client: VmMemoryClient,
}
933
// The main worker thread. Initializes the asynchronous worker tasks and passes them to the executor
// to be processed.
run_worker( inflate_queue: Queue, deflate_queue: Queue, stats_queue: Option<Queue>, reporting_queue: Option<Queue>, events_queue: Option<Queue>, ws_queues: (Option<Queue>, Option<Queue>), command_tube: Tube, vm_memory_client: VmMemoryClient, release_memory_tube: Option<Tube>, interrupt: Interrupt, kill_evt: Event, target_reached_evt: Event, pending_adjusted_response_event: Event, mem: GuestMemory, state: Arc<AsyncRwLock<BalloonState>>, #[cfg(feature = "registered_events")] registered_evt_q: Option<SendTube>, ) -> WorkerReturn936 fn run_worker(
937 inflate_queue: Queue,
938 deflate_queue: Queue,
939 stats_queue: Option<Queue>,
940 reporting_queue: Option<Queue>,
941 events_queue: Option<Queue>,
942 ws_queues: (Option<Queue>, Option<Queue>),
943 command_tube: Tube,
944 vm_memory_client: VmMemoryClient,
945 release_memory_tube: Option<Tube>,
946 interrupt: Interrupt,
947 kill_evt: Event,
948 target_reached_evt: Event,
949 pending_adjusted_response_event: Event,
950 mem: GuestMemory,
951 state: Arc<AsyncRwLock<BalloonState>>,
952 #[cfg(feature = "registered_events")] registered_evt_q: Option<SendTube>,
953 ) -> WorkerReturn {
954 let ex = Executor::new().unwrap();
955 let command_tube = AsyncTube::new(&ex, command_tube).unwrap();
956 #[cfg(feature = "registered_events")]
957 let registered_evt_q_async = registered_evt_q
958 .as_ref()
959 .map(|q| SendTubeAsync::new(q.try_clone().unwrap(), &ex).unwrap());
960
961 let mut stop_queue_oneshots = Vec::new();
962
963 // We need a block to release all references to command_tube at the end before returning it.
964 let paused_queues = {
965 // The first queue is used for inflate messages
966 let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
967 let inflate_queue_evt = inflate_queue
968 .event()
969 .try_clone()
970 .expect("failed to clone queue event");
971 let inflate = handle_queue(
972 inflate_queue,
973 EventAsync::new(inflate_queue_evt, &ex).expect("failed to create async event"),
974 release_memory_tube.as_ref(),
975 interrupt.clone(),
976 |guest_address, len| {
977 sys::free_memory(
978 &guest_address,
979 len,
980 &vm_memory_client,
981 )
982 },
983 stop_rx,
984 );
985 let inflate = inflate.fuse();
986 pin_mut!(inflate);
987
988 // The second queue is used for deflate messages
989 let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
990 let deflate_queue_evt = deflate_queue
991 .event()
992 .try_clone()
993 .expect("failed to clone queue event");
994 let deflate = handle_queue(
995 deflate_queue,
996 EventAsync::new(deflate_queue_evt, &ex).expect("failed to create async event"),
997 None,
998 interrupt.clone(),
999 |guest_address, len| {
1000 sys::reclaim_memory(
1001 &guest_address,
1002 len,
1003 &vm_memory_client,
1004 )
1005 },
1006 stop_rx,
1007 );
1008 let deflate = deflate.fuse();
1009 pin_mut!(deflate);
1010
1011 // The next queue is used for stats messages if VIRTIO_BALLOON_F_STATS_VQ is negotiated.
1012 let (stats_tx, stats_rx) = mpsc::channel::<()>(1);
1013 let has_stats_queue = stats_queue.is_some();
1014 let stats = if let Some(stats_queue) = stats_queue {
1015 let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
1016 let stats_queue_evt = stats_queue
1017 .event()
1018 .try_clone()
1019 .expect("failed to clone queue event");
1020 handle_stats_queue(
1021 stats_queue,
1022 EventAsync::new(stats_queue_evt, &ex).expect("failed to create async event"),
1023 stats_rx,
1024 &command_tube,
1025 #[cfg(feature = "registered_events")]
1026 registered_evt_q_async.as_ref(),
1027 state.clone(),
1028 interrupt.clone(),
1029 stop_rx,
1030 )
1031 .left_future()
1032 } else {
1033 std::future::pending().right_future()
1034 };
1035 let stats = stats.fuse();
1036 pin_mut!(stats);
1037
1038 // The next queue is used for reporting messages
1039 let has_reporting_queue = reporting_queue.is_some();
1040 let reporting = if let Some(reporting_queue) = reporting_queue {
1041 let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
1042 let reporting_queue_evt = reporting_queue
1043 .event()
1044 .try_clone()
1045 .expect("failed to clone queue event");
1046 handle_reporting_queue(
1047 reporting_queue,
1048 EventAsync::new(reporting_queue_evt, &ex).expect("failed to create async event"),
1049 release_memory_tube.as_ref(),
1050 interrupt.clone(),
1051 |guest_address, len| {
1052 sys::free_memory(
1053 &guest_address,
1054 len,
1055 &vm_memory_client,
1056 )
1057 },
1058 stop_rx,
1059 )
1060 .left_future()
1061 } else {
1062 std::future::pending().right_future()
1063 };
1064 let reporting = reporting.fuse();
1065 pin_mut!(reporting);
1066
1067 // If VIRTIO_BALLOON_F_WS_REPORTING is set 2 queues must handled - one for WS data and one
1068 // for WS notifications.
1069 let has_ws_data_queue = ws_queues.0.is_some();
1070 let ws_data = if let Some(ws_data_queue) = ws_queues.0 {
1071 let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
1072 let ws_data_queue_evt = ws_data_queue
1073 .event()
1074 .try_clone()
1075 .expect("failed to clone queue event");
1076 handle_ws_data_queue(
1077 ws_data_queue,
1078 EventAsync::new(ws_data_queue_evt, &ex).expect("failed to create async event"),
1079 &command_tube,
1080 #[cfg(feature = "registered_events")]
1081 registered_evt_q_async.as_ref(),
1082 state.clone(),
1083 interrupt.clone(),
1084 stop_rx,
1085 )
1086 .left_future()
1087 } else {
1088 std::future::pending().right_future()
1089 };
1090 let ws_data = ws_data.fuse();
1091 pin_mut!(ws_data);
1092
1093 let (ws_op_tx, ws_op_rx) = mpsc::channel::<WSOp>(1);
1094 let has_ws_op_queue = ws_queues.1.is_some();
1095 let ws_op = if let Some(ws_op_queue) = ws_queues.1 {
1096 let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
1097 let ws_op_queue_evt = ws_op_queue
1098 .event()
1099 .try_clone()
1100 .expect("failed to clone queue event");
1101 handle_ws_op_queue(
1102 ws_op_queue,
1103 EventAsync::new(ws_op_queue_evt, &ex).expect("failed to create async event"),
1104 ws_op_rx,
1105 state.clone(),
1106 interrupt.clone(),
1107 stop_rx,
1108 )
1109 .left_future()
1110 } else {
1111 std::future::pending().right_future()
1112 };
1113 let ws_op = ws_op.fuse();
1114 pin_mut!(ws_op);
1115
1116 // Future to handle command messages that resize the balloon.
1117 let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
1118 let command = handle_command_tube(
1119 &command_tube,
1120 interrupt.clone(),
1121 state.clone(),
1122 stats_tx,
1123 ws_op_tx,
1124 stop_rx,
1125 );
1126 pin_mut!(command);
1127
1128 // Process any requests to resample the irq value.
1129 let resample = async_utils::handle_irq_resample(&ex, interrupt.clone());
1130 pin_mut!(resample);
1131
1132 // Send a message if balloon target reached event is triggered.
1133 let target_reached = handle_target_reached(
1134 &ex,
1135 target_reached_evt,
1136 &vm_memory_client,
1137 );
1138 pin_mut!(target_reached);
1139
1140 // Exit if the kill event is triggered.
1141 let kill = async_utils::await_and_exit(&ex, kill_evt);
1142 pin_mut!(kill);
1143
1144 // The next queue is used for events if VIRTIO_BALLOON_F_EVENTS_VQ is negotiated.
1145 let has_events_queue = events_queue.is_some();
1146 let events = if let Some(events_queue) = events_queue {
1147 let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
1148 let events_queue_evt = events_queue
1149 .event()
1150 .try_clone()
1151 .expect("failed to clone queue event");
1152 handle_events_queue(
1153 events_queue,
1154 EventAsync::new(events_queue_evt, &ex).expect("failed to create async event"),
1155 state.clone(),
1156 interrupt,
1157 &command_tube,
1158 stop_rx,
1159 )
1160 .left_future()
1161 } else {
1162 std::future::pending().right_future()
1163 };
1164 let events = events.fuse();
1165 pin_mut!(events);
1166
1167 let pending_adjusted = handle_pending_adjusted_responses(
1168 EventAsync::new(pending_adjusted_response_event, &ex)
1169 .expect("failed to create async event"),
1170 &command_tube,
1171 state,
1172 );
1173 pin_mut!(pending_adjusted);
1174
1175 let res = ex.run_until(async {
1176 select! {
1177 _ = kill.fuse() => (),
1178 _ = inflate => return Err(anyhow!("inflate stopped unexpectedly")),
1179 _ = deflate => return Err(anyhow!("deflate stopped unexpectedly")),
1180 _ = stats => return Err(anyhow!("stats stopped unexpectedly")),
1181 _ = reporting => return Err(anyhow!("reporting stopped unexpectedly")),
1182 _ = command.fuse() => return Err(anyhow!("command stopped unexpectedly")),
1183 _ = ws_op => return Err(anyhow!("ws_op stopped unexpectedly")),
1184 _ = resample.fuse() => return Err(anyhow!("resample stopped unexpectedly")),
1185 _ = events => return Err(anyhow!("events stopped unexpectedly")),
1186 _ = pending_adjusted.fuse() => return Err(anyhow!("pending_adjusted stopped unexpectedly")),
1187 _ = ws_data => return Err(anyhow!("ws_data stopped unexpectedly")),
1188 _ = target_reached.fuse() => return Err(anyhow!("target_reached stopped unexpectedly")),
1189 }
1190
1191 // Worker is shutting down. To recover the queues, we have to signal
1192 // all the queue futures to exit.
1193 for stop_tx in stop_queue_oneshots {
1194 if stop_tx.send(()).is_err() {
1195 return Err(anyhow!("failed to request stop for queue future"));
1196 }
1197 }
1198
1199 // Collect all the queues (awaiting any queue future should now
1200 // return its Queue immediately).
1201 let mut paused_queues = PausedQueues::new(
1202 inflate.await,
1203 deflate.await,
1204 );
1205 if has_reporting_queue {
1206 paused_queues.reporting = Some(reporting.await);
1207 }
1208 if has_events_queue {
1209 paused_queues.events = Some(events.await.context("failed to stop events queue")?);
1210 }
1211 if has_stats_queue {
1212 paused_queues.stats = Some(stats.await);
1213 }
1214 if has_ws_op_queue {
1215 paused_queues.ws.0 = Some(ws_op.await.context("failed to stop ws_op queue")?);
1216 }
1217 if has_ws_data_queue {
1218 paused_queues.ws.1 = Some(ws_data.await.context("failed to stop ws_data queue")?);
1219 }
1220 Ok(paused_queues)
1221 });
1222
1223 match res {
1224 Err(e) => {
1225 error!("error happened in executor: {}", e);
1226 None
1227 }
1228 Ok(main_future_res) => match main_future_res {
1229 Ok(paused_queues) => Some(paused_queues),
1230 Err(e) => {
1231 error!("error happened in main balloon future: {}", e);
1232 None
1233 }
1234 },
1235 }
1236 };
1237
1238 WorkerReturn {
1239 command_tube: command_tube.into(),
1240 paused_queues,
1241 release_memory_tube,
1242 #[cfg(feature = "registered_events")]
1243 registered_evt_q,
1244 vm_memory_client,
1245 }
1246 }
1247
handle_target_reached( ex: &Executor, target_reached_evt: Event, vm_memory_client: &VmMemoryClient, ) -> anyhow::Result<()>1248 async fn handle_target_reached(
1249 ex: &Executor,
1250 target_reached_evt: Event,
1251 vm_memory_client: &VmMemoryClient,
1252 ) -> anyhow::Result<()> {
1253 let event_async =
1254 EventAsync::new(target_reached_evt, ex).context("failed to create EventAsync")?;
1255 loop {
1256 // Wait for target reached trigger.
1257 let _ = event_async.next_val().await;
1258 // Send the message to vm_control on the event. We don't have to read the current
1259 // size yet.
1260 sys::balloon_target_reached(
1261 0,
1262 vm_memory_client,
1263 );
1264 }
1265 // The above loop will never terminate and there is no reason to terminate it either. However,
1266 // the function is used in an executor that expects a Result<> return. Make sure that clippy
1267 // doesn't enforce the unreachable_code condition.
1268 #[allow(unreachable_code)]
1269 Ok(())
1270 }
1271
/// Virtio device for memory balloon inflation/deflation.
pub struct Balloon {
    /// Command tube. Moved into the worker by `start_worker` and put back by `stop_worker`, so
    /// it is `None` while a worker is running.
    command_tube: Option<Tube>,
    /// Memory client. Moved into the worker and put back the same way as `command_tube`.
    vm_memory_client: Option<VmMemoryClient>,
    /// Optional tube used to release memory; lent to the worker while it runs.
    release_memory_tube: Option<Tube>,
    /// Signaled from `write_config` when a failable update reaches its target; the worker waits
    /// on a clone of it.
    pending_adjusted_response_event: Event,
    /// Balloon state shared between the device and its worker.
    state: Arc<AsyncRwLock<BalloonState>>,
    /// Feature bits offered by the device.
    features: u64,
    /// Feature bits acknowledged so far through `ack_features`.
    acked_features: u64,
    /// Handle of the running worker thread, `None` when no worker is running.
    worker_thread: Option<WorkerThread<WorkerReturn>>,
    /// Optional registered-events tube; lent to the worker while it runs.
    #[cfg(feature = "registered_events")]
    registered_evt_q: Option<SendTube>,
    /// Value reported in the `ws_num_bins` field of the device config.
    ws_num_bins: u8,
    /// Device-side handle of the target-reached event, created each time the worker is started
    /// and signaled from `write_config` when `num_pages` equals `actual`.
    target_reached_evt: Option<Event>,
}
1287
/// Snapshot of the [Balloon] state.
///
/// Produced by `virtio_snapshot` and consumed by `virtio_restore`.
#[derive(Serialize, Deserialize)]
struct BalloonSnapshot {
    /// Copy of the shared balloon state at snapshot time.
    state: BalloonState,
    /// Offered feature bits; `virtio_restore` rejects a snapshot whose value differs from the
    /// live device.
    features: u64,
    /// Feature bits that had been acknowledged at snapshot time.
    acked_features: u64,
    /// Value of the device's `ws_num_bins` at snapshot time.
    ws_num_bins: u8,
}
1296
/// Operation mode of the balloon.
///
/// `Balloon::new` uses the mode to pick which one of two feature bits is offered:
/// `VIRTIO_BALLOON_F_RESPONSIVE_DEVICE` for `Strict`, `VIRTIO_BALLOON_F_DEFLATE_ON_OOM`
/// otherwise.
#[derive(PartialEq, Eq)]
pub enum BalloonMode {
    /// The driver can access pages in the balloon (i.e. F_DEFLATE_ON_OOM)
    Relaxed,
    /// The driver cannot access pages in the balloon. Implies F_RESPONSIVE_DEVICE.
    Strict,
}
1305
1306 impl Balloon {
1307 /// Creates a new virtio balloon device.
1308 /// To let Balloon able to successfully release the memory which are pinned
1309 /// by CoIOMMU to host, the release_memory_tube will be used to send the inflate
1310 /// ranges to CoIOMMU with UnpinRequest/UnpinResponse messages, so that The
1311 /// memory in the inflate range can be unpinned first.
new( base_features: u64, command_tube: Tube, vm_memory_client: VmMemoryClient, release_memory_tube: Option<Tube>, init_balloon_size: u64, mode: BalloonMode, enabled_features: u64, #[cfg(feature = "registered_events")] registered_evt_q: Option<SendTube>, ws_num_bins: u8, ) -> Result<Balloon>1312 pub fn new(
1313 base_features: u64,
1314 command_tube: Tube,
1315 vm_memory_client: VmMemoryClient,
1316 release_memory_tube: Option<Tube>,
1317 init_balloon_size: u64,
1318 mode: BalloonMode,
1319 enabled_features: u64,
1320 #[cfg(feature = "registered_events")] registered_evt_q: Option<SendTube>,
1321 ws_num_bins: u8,
1322 ) -> Result<Balloon> {
1323 let features = base_features
1324 | 1 << VIRTIO_BALLOON_F_MUST_TELL_HOST
1325 | 1 << VIRTIO_BALLOON_F_STATS_VQ
1326 | 1 << VIRTIO_BALLOON_F_EVENTS_VQ
1327 | enabled_features
1328 | if mode == BalloonMode::Strict {
1329 1 << VIRTIO_BALLOON_F_RESPONSIVE_DEVICE
1330 } else {
1331 1 << VIRTIO_BALLOON_F_DEFLATE_ON_OOM
1332 };
1333
1334 Ok(Balloon {
1335 command_tube: Some(command_tube),
1336 vm_memory_client: Some(vm_memory_client),
1337 release_memory_tube,
1338 pending_adjusted_response_event: Event::new().map_err(BalloonError::CreatingEvent)?,
1339 state: Arc::new(AsyncRwLock::new(BalloonState {
1340 num_pages: (init_balloon_size >> VIRTIO_BALLOON_PFN_SHIFT) as u32,
1341 actual_pages: 0,
1342 failable_update: false,
1343 pending_adjusted_responses: VecDeque::new(),
1344 expecting_ws: false,
1345 })),
1346 worker_thread: None,
1347 features,
1348 acked_features: 0,
1349 #[cfg(feature = "registered_events")]
1350 registered_evt_q,
1351 ws_num_bins,
1352 target_reached_evt: None,
1353 })
1354 }
1355
get_config(&self) -> virtio_balloon_config1356 fn get_config(&self) -> virtio_balloon_config {
1357 let state = block_on(self.state.lock());
1358 virtio_balloon_config {
1359 num_pages: state.num_pages.into(),
1360 actual: state.actual_pages.into(),
1361 // crosvm does not (currently) use free_page_hint_cmd_id or
1362 // poison_val, but they must be present in the right order and size
1363 // for the virtio-balloon driver in the guest to deserialize the
1364 // config correctly.
1365 free_page_hint_cmd_id: 0.into(),
1366 poison_val: 0.into(),
1367 ws_num_bins: self.ws_num_bins,
1368 _reserved: [0, 0, 0],
1369 }
1370 }
1371
num_expected_queues(acked_features: u64) -> usize1372 fn num_expected_queues(acked_features: u64) -> usize {
1373 // at minimum we have inflate and deflate vqueues.
1374 let mut num_queues = 2;
1375 // stats vqueue
1376 if acked_features & (1 << VIRTIO_BALLOON_F_STATS_VQ) != 0 {
1377 num_queues += 1;
1378 }
1379 // events vqueue
1380 if acked_features & (1 << VIRTIO_BALLOON_F_EVENTS_VQ) != 0 {
1381 num_queues += 1;
1382 }
1383 // page reporting vqueue
1384 if acked_features & (1 << VIRTIO_BALLOON_F_PAGE_REPORTING) != 0 {
1385 num_queues += 1;
1386 }
1387 // working set vqueues
1388 if acked_features & (1 << VIRTIO_BALLOON_F_WS_REPORTING) != 0 {
1389 num_queues += 2;
1390 }
1391
1392 num_queues
1393 }
1394
stop_worker(&mut self) -> StoppedWorker<PausedQueues>1395 fn stop_worker(&mut self) -> StoppedWorker<PausedQueues> {
1396 if let Some(worker_thread) = self.worker_thread.take() {
1397 let worker_ret = worker_thread.stop();
1398 self.release_memory_tube = worker_ret.release_memory_tube;
1399 self.command_tube = Some(worker_ret.command_tube);
1400 #[cfg(feature = "registered_events")]
1401 {
1402 self.registered_evt_q = worker_ret.registered_evt_q;
1403 }
1404 self.vm_memory_client = Some(worker_ret.vm_memory_client);
1405
1406 if let Some(queues) = worker_ret.paused_queues {
1407 StoppedWorker::WithQueues(Box::new(queues))
1408 } else {
1409 StoppedWorker::MissingQueues
1410 }
1411 } else {
1412 StoppedWorker::AlreadyStopped
1413 }
1414 }
1415
1416 /// Given a filtered queue vector from [VirtioDevice::activate], extract
1417 /// the queues (accounting for queues that are missing because the features
1418 /// are not negotiated) into a structure that is easier to work with.
get_queues_from_map( &self, mut queues: BTreeMap<usize, Queue>, ) -> anyhow::Result<BalloonQueues>1419 fn get_queues_from_map(
1420 &self,
1421 mut queues: BTreeMap<usize, Queue>,
1422 ) -> anyhow::Result<BalloonQueues> {
1423 let expected_queues = Balloon::num_expected_queues(self.acked_features);
1424 if queues.len() != expected_queues {
1425 return Err(anyhow!(
1426 "expected {} queues, got {}",
1427 expected_queues,
1428 queues.len()
1429 ));
1430 }
1431
1432 // WARNING: We use `pop_first` instead of explicitly using the indices from the virtio spec
1433 // because the Linux virtio drivers only "allocates" queue indices that are used.
1434 let inflate_queue = queues.pop_first().unwrap().1;
1435 let deflate_queue = queues.pop_first().unwrap().1;
1436 let mut queue_struct = BalloonQueues::new(inflate_queue, deflate_queue);
1437
1438 if self.acked_features & (1 << VIRTIO_BALLOON_F_STATS_VQ) != 0 {
1439 queue_struct.stats = Some(queues.pop_first().unwrap().1);
1440 }
1441 if self.acked_features & (1 << VIRTIO_BALLOON_F_PAGE_REPORTING) != 0 {
1442 queue_struct.reporting = Some(queues.pop_first().unwrap().1);
1443 }
1444 if self.acked_features & (1 << VIRTIO_BALLOON_F_EVENTS_VQ) != 0 {
1445 queue_struct.events = Some(queues.pop_first().unwrap().1);
1446 }
1447 if self.acked_features & (1 << VIRTIO_BALLOON_F_WS_REPORTING) != 0 {
1448 queue_struct.ws = (
1449 Some(queues.pop_first().unwrap().1),
1450 Some(queues.pop_first().unwrap().1),
1451 );
1452 }
1453 Ok(queue_struct)
1454 }
1455
start_worker( &mut self, mem: GuestMemory, interrupt: Interrupt, queues: BalloonQueues, ) -> anyhow::Result<()>1456 fn start_worker(
1457 &mut self,
1458 mem: GuestMemory,
1459 interrupt: Interrupt,
1460 queues: BalloonQueues,
1461 ) -> anyhow::Result<()> {
1462 let (self_target_reached_evt, target_reached_evt) = Event::new()
1463 .and_then(|e| Ok((e.try_clone()?, e)))
1464 .context("failed to create target_reached Event pair: {}")?;
1465 self.target_reached_evt = Some(self_target_reached_evt);
1466
1467 let state = self.state.clone();
1468
1469 let command_tube = self.command_tube.take().unwrap();
1470
1471 let vm_memory_client = self.vm_memory_client.take().unwrap();
1472 let release_memory_tube = self.release_memory_tube.take();
1473 #[cfg(feature = "registered_events")]
1474 let registered_evt_q = self.registered_evt_q.take();
1475 let pending_adjusted_response_event = self
1476 .pending_adjusted_response_event
1477 .try_clone()
1478 .context("failed to clone Event")?;
1479
1480 self.worker_thread = Some(WorkerThread::start("v_balloon", move |kill_evt| {
1481 run_worker(
1482 queues.inflate,
1483 queues.deflate,
1484 queues.stats,
1485 queues.reporting,
1486 queues.events,
1487 queues.ws,
1488 command_tube,
1489 vm_memory_client,
1490 release_memory_tube,
1491 interrupt,
1492 kill_evt,
1493 target_reached_evt,
1494 pending_adjusted_response_event,
1495 mem,
1496 state,
1497 #[cfg(feature = "registered_events")]
1498 registered_evt_q,
1499 )
1500 }));
1501
1502 Ok(())
1503 }
1504 }
1505
1506 impl VirtioDevice for Balloon {
keep_rds(&self) -> Vec<RawDescriptor>1507 fn keep_rds(&self) -> Vec<RawDescriptor> {
1508 let mut rds = Vec::new();
1509 if let Some(command_tube) = &self.command_tube {
1510 rds.push(command_tube.as_raw_descriptor());
1511 }
1512 if let Some(release_memory_tube) = &self.release_memory_tube {
1513 rds.push(release_memory_tube.as_raw_descriptor());
1514 }
1515 #[cfg(feature = "registered_events")]
1516 if let Some(registered_evt_q) = &self.registered_evt_q {
1517 rds.push(registered_evt_q.as_raw_descriptor());
1518 }
1519 rds.push(self.pending_adjusted_response_event.as_raw_descriptor());
1520 rds
1521 }
1522
device_type(&self) -> DeviceType1523 fn device_type(&self) -> DeviceType {
1524 DeviceType::Balloon
1525 }
1526
queue_max_sizes(&self) -> &[u16]1527 fn queue_max_sizes(&self) -> &[u16] {
1528 QUEUE_SIZES
1529 }
1530
read_config(&self, offset: u64, data: &mut [u8])1531 fn read_config(&self, offset: u64, data: &mut [u8]) {
1532 copy_config(data, 0, self.get_config().as_bytes(), offset);
1533 }
1534
write_config(&mut self, offset: u64, data: &[u8])1535 fn write_config(&mut self, offset: u64, data: &[u8]) {
1536 let mut config = self.get_config();
1537 copy_config(config.as_bytes_mut(), offset, data, 0);
1538 let mut state = block_on(self.state.lock());
1539 state.actual_pages = config.actual.to_native();
1540
1541 // If balloon has updated to the requested memory, let the hypervisor know.
1542 if config.num_pages == config.actual {
1543 debug!(
1544 "sending target reached event at {}",
1545 u32::from(config.num_pages)
1546 );
1547 self.target_reached_evt.as_ref().map(|e| e.signal());
1548 }
1549 if state.failable_update && state.actual_pages == state.num_pages {
1550 state.failable_update = false;
1551 let num_pages = state.num_pages;
1552 state.pending_adjusted_responses.push_back(num_pages);
1553 let _ = self.pending_adjusted_response_event.signal();
1554 }
1555 }
1556
features(&self) -> u641557 fn features(&self) -> u64 {
1558 self.features
1559 }
1560
ack_features(&mut self, mut value: u64)1561 fn ack_features(&mut self, mut value: u64) {
1562 if value & !self.features != 0 {
1563 warn!("virtio_balloon got unknown feature ack {:x}", value);
1564 value &= self.features;
1565 }
1566 self.acked_features |= value;
1567 }
1568
activate( &mut self, mem: GuestMemory, interrupt: Interrupt, queues: BTreeMap<usize, Queue>, ) -> anyhow::Result<()>1569 fn activate(
1570 &mut self,
1571 mem: GuestMemory,
1572 interrupt: Interrupt,
1573 queues: BTreeMap<usize, Queue>,
1574 ) -> anyhow::Result<()> {
1575 let queues = self.get_queues_from_map(queues)?;
1576 self.start_worker(mem, interrupt, queues)
1577 }
1578
reset(&mut self) -> anyhow::Result<()>1579 fn reset(&mut self) -> anyhow::Result<()> {
1580 let _worker = self.stop_worker();
1581 Ok(())
1582 }
1583
virtio_sleep(&mut self) -> anyhow::Result<Option<BTreeMap<usize, Queue>>>1584 fn virtio_sleep(&mut self) -> anyhow::Result<Option<BTreeMap<usize, Queue>>> {
1585 match self.stop_worker() {
1586 StoppedWorker::WithQueues(paused_queues) => Ok(Some(paused_queues.into())),
1587 StoppedWorker::MissingQueues => {
1588 anyhow::bail!("balloon queue workers did not stop cleanly.")
1589 }
1590 StoppedWorker::AlreadyStopped => {
1591 // Device hasn't been activated.
1592 Ok(None)
1593 }
1594 }
1595 }
1596
virtio_wake( &mut self, queues_state: Option<(GuestMemory, Interrupt, BTreeMap<usize, Queue>)>, ) -> anyhow::Result<()>1597 fn virtio_wake(
1598 &mut self,
1599 queues_state: Option<(GuestMemory, Interrupt, BTreeMap<usize, Queue>)>,
1600 ) -> anyhow::Result<()> {
1601 if let Some((mem, interrupt, queues)) = queues_state {
1602 if queues.len() < 2 {
1603 anyhow::bail!("{} queues were found, but an activated balloon must have at least 2 active queues.", queues.len());
1604 }
1605
1606 let balloon_queues = self.get_queues_from_map(queues)?;
1607 self.start_worker(mem, interrupt, balloon_queues)?;
1608 }
1609 Ok(())
1610 }
1611
virtio_snapshot(&mut self) -> anyhow::Result<serde_json::Value>1612 fn virtio_snapshot(&mut self) -> anyhow::Result<serde_json::Value> {
1613 let state = self
1614 .state
1615 .lock()
1616 .now_or_never()
1617 .context("failed to acquire balloon lock")?;
1618 serde_json::to_value(BalloonSnapshot {
1619 features: self.features,
1620 acked_features: self.acked_features,
1621 state: state.clone(),
1622 ws_num_bins: self.ws_num_bins,
1623 })
1624 .context("failed to serialize balloon state")
1625 }
1626
virtio_restore(&mut self, data: serde_json::Value) -> anyhow::Result<()>1627 fn virtio_restore(&mut self, data: serde_json::Value) -> anyhow::Result<()> {
1628 let snap: BalloonSnapshot = serde_json::from_value(data).context("error deserializing")?;
1629 if snap.features != self.features {
1630 anyhow::bail!(
1631 "balloon: expected features to match, but they did not. Live: {:?}, snapshot {:?}",
1632 self.features,
1633 snap.features,
1634 );
1635 }
1636
1637 let mut state = self
1638 .state
1639 .lock()
1640 .now_or_never()
1641 .context("failed to acquire balloon lock")?;
1642 *state = snap.state;
1643 self.ws_num_bins = snap.ws_num_bins;
1644 self.acked_features = snap.acked_features;
1645 Ok(())
1646 }
1647 }
1648
#[cfg(test)]
mod tests {
    use super::*;
    use crate::suspendable_virtio_tests;
    use crate::virtio::descriptor_utils::create_descriptor_chain;
    use crate::virtio::descriptor_utils::DescriptorType;

    #[test]
    fn desc_parsing_inflate() {
        // Check that the memory addresses are parsed correctly by 'handle_address_chain' and passed
        // to the closure.
        let pfns = [0x10u32, 0xaa55aa55u32];

        let memory_start_addr = GuestAddress(0x0);
        let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
        // Lay the two 32-bit page frame numbers out back to back at 0x100.
        let mut write_addr = 0x100u64;
        for pfn in pfns.iter() {
            memory
                .write_obj_at_addr(*pfn, GuestAddress(write_addr))
                .unwrap();
            write_addr += 4;
        }

        let mut chain = create_descriptor_chain(
            &memory,
            GuestAddress(0x0),
            GuestAddress(0x100),
            vec![(DescriptorType::Readable, 8)],
            0,
        )
        .expect("create_descriptor_chain failed");

        let mut addrs = Vec::new();
        let res = handle_address_chain(None, &mut chain, &mut |guest_address, len| {
            addrs.push((guest_address, len));
        });
        assert!(res.is_ok());
        assert_eq!(addrs.len(), 2);
        for (parsed, pfn) in addrs.iter().zip(pfns.iter()) {
            assert_eq!(
                parsed.0,
                GuestAddress(u64::from(*pfn) << VIRTIO_BALLOON_PFN_SHIFT)
            );
        }
    }

    #[test]
    fn num_expected_queues() {
        let to_feature_bits =
            |features: &[u32]| -> u64 { features.iter().fold(0, |acc, f| acc | (1_u64 << f)) };

        // (expected queue count, acked feature bits)
        let cases: [(usize, &[u32]); 5] = [
            (2, &[]),
            (2, &[VIRTIO_BALLOON_F_MUST_TELL_HOST]),
            (3, &[VIRTIO_BALLOON_F_STATS_VQ]),
            (
                5,
                &[
                    VIRTIO_BALLOON_F_STATS_VQ,
                    VIRTIO_BALLOON_F_EVENTS_VQ,
                    VIRTIO_BALLOON_F_PAGE_REPORTING,
                ],
            ),
            (
                7,
                &[
                    VIRTIO_BALLOON_F_STATS_VQ,
                    VIRTIO_BALLOON_F_EVENTS_VQ,
                    VIRTIO_BALLOON_F_PAGE_REPORTING,
                    VIRTIO_BALLOON_F_WS_REPORTING,
                ],
            ),
        ];

        for (expected, features) in cases.iter() {
            assert_eq!(
                *expected,
                Balloon::num_expected_queues(to_feature_bits(features))
            );
        }
    }

    /// Keeps the host-side tube ends alive for as long as the device under test exists.
    struct BalloonContext {
        _ctrl_tube: Tube,
        _mem_client_tube: Tube,
    }

    /// Changes a field of the device so it differs from a freshly created one.
    fn modify_device(_balloon_context: &mut BalloonContext, balloon: &mut Balloon) {
        balloon.ws_num_bins = !balloon.ws_num_bins;
    }

    /// Builds a balloon in relaxed mode together with the tube ends it is connected to.
    fn create_device() -> (BalloonContext, Balloon) {
        let (_ctrl_tube, ctrl_tube_device) = Tube::pair().unwrap();
        let (_mem_client_tube, mem_client_tube_device) = Tube::pair().unwrap();

        let context = BalloonContext {
            _ctrl_tube,
            _mem_client_tube,
        };
        let balloon = Balloon::new(
            0,
            ctrl_tube_device,
            VmMemoryClient::new(mem_client_tube_device),
            None,
            1024,
            BalloonMode::Relaxed,
            0,
            #[cfg(feature = "registered_events")]
            None,
            0,
        )
        .unwrap();

        (context, balloon)
    }

    suspendable_virtio_tests!(balloon, create_device, 2, modify_device);
}
1762