• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2021 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 // virtio-sound spec: https://github.com/oasis-tcs/virtio-spec/blob/master/virtio-sound.tex
6 
7 use std::collections::BTreeMap;
8 use std::io;
9 use std::rc::Rc;
10 use std::sync::Arc;
11 
12 use anyhow::anyhow;
13 use anyhow::Context;
14 use audio_streams::BoxError;
15 use base::debug;
16 use base::error;
17 use base::warn;
18 use base::AsRawDescriptor;
19 use base::Descriptor;
20 use base::Error as SysError;
21 use base::Event;
22 use base::RawDescriptor;
23 use base::WorkerThread;
24 use cros_async::block_on;
25 use cros_async::sync::Condvar;
26 use cros_async::sync::RwLock as AsyncRwLock;
27 use cros_async::AsyncError;
28 use cros_async::EventAsync;
29 use cros_async::Executor;
30 use futures::channel::mpsc;
31 use futures::channel::oneshot;
32 use futures::channel::oneshot::Canceled;
33 use futures::future::FusedFuture;
34 use futures::join;
35 use futures::pin_mut;
36 use futures::select;
37 use futures::Future;
38 use futures::FutureExt;
39 use serde::Deserialize;
40 use serde::Serialize;
41 use thiserror::Error as ThisError;
42 use vm_memory::GuestMemory;
43 use zerocopy::AsBytes;
44 
45 use crate::virtio::async_utils;
46 use crate::virtio::copy_config;
47 use crate::virtio::device_constants::snd::virtio_snd_config;
48 use crate::virtio::snd::common_backend::async_funcs::*;
49 use crate::virtio::snd::common_backend::stream_info::StreamInfo;
50 use crate::virtio::snd::common_backend::stream_info::StreamInfoBuilder;
51 use crate::virtio::snd::common_backend::stream_info::StreamInfoSnapshot;
52 use crate::virtio::snd::constants::*;
53 use crate::virtio::snd::file_backend::create_file_stream_source_generators;
54 use crate::virtio::snd::file_backend::Error as FileError;
55 use crate::virtio::snd::layout::*;
56 use crate::virtio::snd::null_backend::create_null_stream_source_generators;
57 use crate::virtio::snd::parameters::Parameters;
58 use crate::virtio::snd::parameters::StreamSourceBackend;
59 use crate::virtio::snd::sys::create_stream_source_generators as sys_create_stream_source_generators;
60 use crate::virtio::snd::sys::set_audio_thread_priority;
61 use crate::virtio::snd::sys::SysAsyncStreamObjects;
62 use crate::virtio::snd::sys::SysAudioStreamSourceGenerator;
63 use crate::virtio::snd::sys::SysDirectionOutput;
64 use crate::virtio::DescriptorChain;
65 use crate::virtio::DeviceType;
66 use crate::virtio::Interrupt;
67 use crate::virtio::Queue;
68 use crate::virtio::VirtioDevice;
69 
70 pub mod async_funcs;
71 pub mod stream_info;
72 
73 // control + event + tx + rx queue
74 pub const MAX_QUEUE_NUM: usize = 4;
75 pub const MAX_VRING_LEN: u16 = 1024;
76 
77 #[derive(ThisError, Debug)]
78 pub enum Error {
79     /// next_async failed.
80     #[error("Failed to read descriptor asynchronously: {0}")]
81     Async(AsyncError),
82     /// Creating stream failed.
83     #[error("Failed to create stream: {0}")]
84     CreateStream(BoxError),
85     /// Creating stream failed.
86     #[error("No stream source found.")]
87     EmptyStreamSource,
88     /// Creating kill event failed.
89     #[error("Failed to create kill event: {0}")]
90     CreateKillEvent(SysError),
91     /// Creating WaitContext failed.
92     #[error("Failed to create wait context: {0}")]
93     CreateWaitContext(SysError),
94     #[error("Failed to create file stream source generator")]
95     CreateFileStreamSourceGenerator(FileError),
96     /// Cloning kill event failed.
97     #[error("Failed to clone kill event: {0}")]
98     CloneKillEvent(SysError),
99     // Future error.
100     #[error("Unexpected error. Done was not triggered before dropped: {0}")]
101     DoneNotTriggered(Canceled),
102     /// Error reading message from queue.
103     #[error("Failed to read message: {0}")]
104     ReadMessage(io::Error),
105     /// Failed writing a response to a control message.
106     #[error("Failed to write message response: {0}")]
107     WriteResponse(io::Error),
108     // Mpsc read error.
109     #[error("Error in mpsc: {0}")]
110     MpscSend(futures::channel::mpsc::SendError),
111     // Oneshot send error.
112     #[error("Error in oneshot send")]
113     OneshotSend(()),
114     /// Stream not found.
115     #[error("stream id ({0}) < num_streams ({1})")]
116     StreamNotFound(usize, usize),
117     /// Fetch buffer error
118     #[error("Failed to get buffer from CRAS: {0}")]
119     FetchBuffer(BoxError),
120     /// Invalid buffer size
121     #[error("Invalid buffer size")]
122     InvalidBufferSize,
123     /// IoError
124     #[error("I/O failed: {0}")]
125     Io(io::Error),
126     /// Operation not supported.
127     #[error("Operation not supported")]
128     OperationNotSupported,
129     /// Writing to a buffer in the guest failed.
130     #[error("failed to write to buffer: {0}")]
131     WriteBuffer(io::Error),
132     // Invalid PCM worker state.
133     #[error("Invalid PCM worker state")]
134     InvalidPCMWorkerState,
135     // Invalid backend.
136     #[error("Backend is not implemented")]
137     InvalidBackend,
138     // Failed to generate StreamSource
139     #[error("Failed to generate stream source: {0}")]
140     GenerateStreamSource(BoxError),
141     // PCM worker unexpectedly quitted.
142     #[error("PCM worker quitted unexpectedly")]
143     PCMWorkerQuittedUnexpectedly,
144 }
145 
146 pub enum DirectionalStream {
147     Input(
148         usize, // `period_size` in `usize`
149         Box<dyn CaptureBufferReader>,
150     ),
151     Output(SysDirectionOutput),
152 }
153 
/// Status of a PCM worker. Discriminants are fixed explicitly at 0, 1 and 2.
#[derive(Clone, Copy, PartialEq, Eq)]
pub enum WorkerStatus {
    /// Discriminant 0.
    Pause = 0,
    /// Discriminant 1.
    Running = 1,
    /// Discriminant 2.
    Quit = 2,
}
160 
161 // Stores constant data
162 #[derive(Clone, Serialize, Deserialize, PartialEq, Eq, Debug)]
163 pub struct SndData {
164     pub(crate) jack_info: Vec<virtio_snd_jack_info>,
165     pub(crate) pcm_info: Vec<virtio_snd_pcm_info>,
166     pub(crate) chmap_info: Vec<virtio_snd_chmap_info>,
167 }
168 
169 impl SndData {
pcm_info_len(&self) -> usize170     pub fn pcm_info_len(&self) -> usize {
171         self.pcm_info.len()
172     }
173 
pcm_info_iter(&self) -> std::slice::Iter<'_, virtio_snd_pcm_info>174     pub fn pcm_info_iter(&self) -> std::slice::Iter<'_, virtio_snd_pcm_info> {
175         self.pcm_info.iter()
176     }
177 }
178 
179 const SUPPORTED_FORMATS: u64 = 1 << VIRTIO_SND_PCM_FMT_U8
180     | 1 << VIRTIO_SND_PCM_FMT_S16
181     | 1 << VIRTIO_SND_PCM_FMT_S24
182     | 1 << VIRTIO_SND_PCM_FMT_S32;
183 const SUPPORTED_FRAME_RATES: u64 = 1 << VIRTIO_SND_PCM_RATE_8000
184     | 1 << VIRTIO_SND_PCM_RATE_11025
185     | 1 << VIRTIO_SND_PCM_RATE_16000
186     | 1 << VIRTIO_SND_PCM_RATE_22050
187     | 1 << VIRTIO_SND_PCM_RATE_32000
188     | 1 << VIRTIO_SND_PCM_RATE_44100
189     | 1 << VIRTIO_SND_PCM_RATE_48000;
190 
191 // Response from pcm_worker to pcm_queue
192 pub struct PcmResponse {
193     pub(crate) desc_chain: DescriptorChain,
194     pub(crate) status: virtio_snd_pcm_status, // response to the pcm message
195     pub(crate) done: Option<oneshot::Sender<()>>, // when pcm response is written to the queue
196 }
197 
198 pub struct VirtioSnd {
199     cfg: virtio_snd_config,
200     snd_data: SndData,
201     stream_info_builders: Vec<StreamInfoBuilder>,
202     avail_features: u64,
203     acked_features: u64,
204     queue_sizes: Box<[u16]>,
205     worker_thread: Option<WorkerThread<Result<WorkerReturn, String>>>,
206     keep_rds: Vec<Descriptor>,
207     streams_state: Option<Vec<StreamInfoSnapshot>>,
208 }
209 
210 #[derive(Serialize, Deserialize)]
211 struct VirtioSndSnapshot {
212     avail_features: u64,
213     acked_features: u64,
214     queue_sizes: Vec<u16>,
215     streams_state: Option<Vec<StreamInfoSnapshot>>,
216     snd_data: SndData,
217 }
218 
219 impl VirtioSnd {
new(base_features: u64, params: Parameters) -> Result<VirtioSnd, Error>220     pub fn new(base_features: u64, params: Parameters) -> Result<VirtioSnd, Error> {
221         let params = resize_parameters_pcm_device_config(params);
222         let cfg = hardcoded_virtio_snd_config(&params);
223         let snd_data = hardcoded_snd_data(&params);
224         let avail_features = base_features;
225         let mut keep_rds: Vec<RawDescriptor> = Vec::new();
226 
227         let stream_info_builders = create_stream_info_builders(&params, &snd_data, &mut keep_rds)?;
228 
229         Ok(VirtioSnd {
230             cfg,
231             snd_data,
232             stream_info_builders,
233             avail_features,
234             acked_features: 0,
235             queue_sizes: vec![MAX_VRING_LEN; MAX_QUEUE_NUM].into_boxed_slice(),
236             worker_thread: None,
237             keep_rds: keep_rds.iter().map(|rd| Descriptor(*rd)).collect(),
238             streams_state: None,
239         })
240     }
241 }
242 
create_stream_source_generators( params: &Parameters, snd_data: &SndData, keep_rds: &mut Vec<RawDescriptor>, ) -> Result<Vec<SysAudioStreamSourceGenerator>, Error>243 fn create_stream_source_generators(
244     params: &Parameters,
245     snd_data: &SndData,
246     keep_rds: &mut Vec<RawDescriptor>,
247 ) -> Result<Vec<SysAudioStreamSourceGenerator>, Error> {
248     let generators = match params.backend {
249         StreamSourceBackend::NULL => create_null_stream_source_generators(snd_data),
250         StreamSourceBackend::FILE => {
251             create_file_stream_source_generators(params, snd_data, keep_rds)
252                 .map_err(Error::CreateFileStreamSourceGenerator)?
253         }
254         StreamSourceBackend::Sys(backend) => {
255             sys_create_stream_source_generators(backend, params, snd_data)
256         }
257     };
258     Ok(generators)
259 }
260 
261 /// Creates [`StreamInfoBuilder`]s by calling [`create_stream_source_generators()`] then zip
262 /// them with [`crate::virtio::snd::parameters::PCMDeviceParameters`] from the params to set
263 /// the parameters on each [`StreamInfoBuilder`] (e.g. effects).
create_stream_info_builders( params: &Parameters, snd_data: &SndData, keep_rds: &mut Vec<RawDescriptor>, ) -> Result<Vec<StreamInfoBuilder>, Error>264 pub(crate) fn create_stream_info_builders(
265     params: &Parameters,
266     snd_data: &SndData,
267     keep_rds: &mut Vec<RawDescriptor>,
268 ) -> Result<Vec<StreamInfoBuilder>, Error> {
269     Ok(create_stream_source_generators(params, snd_data, keep_rds)?
270         .into_iter()
271         .map(Arc::new)
272         .zip(snd_data.pcm_info_iter())
273         .map(|(generator, pcm_info)| {
274             let device_params = params.get_device_params(pcm_info).unwrap_or_default();
275             StreamInfo::builder(generator).effects(device_params.effects.unwrap_or_default())
276         })
277         .collect())
278 }
279 
280 // To be used with hardcoded_snd_data
hardcoded_virtio_snd_config(params: &Parameters) -> virtio_snd_config281 pub fn hardcoded_virtio_snd_config(params: &Parameters) -> virtio_snd_config {
282     virtio_snd_config {
283         jacks: 0.into(),
284         streams: params.get_total_streams().into(),
285         chmaps: (params.num_output_devices * 3 + params.num_input_devices).into(),
286     }
287 }
288 
289 // To be used with hardcoded_virtio_snd_config
hardcoded_snd_data(params: &Parameters) -> SndData290 pub fn hardcoded_snd_data(params: &Parameters) -> SndData {
291     let jack_info: Vec<virtio_snd_jack_info> = Vec::new();
292     let mut pcm_info: Vec<virtio_snd_pcm_info> = Vec::new();
293     let mut chmap_info: Vec<virtio_snd_chmap_info> = Vec::new();
294 
295     for dev in 0..params.num_output_devices {
296         for _ in 0..params.num_output_streams {
297             pcm_info.push(virtio_snd_pcm_info {
298                 hdr: virtio_snd_info {
299                     hda_fn_nid: dev.into(),
300                 },
301                 features: 0.into(), /* 1 << VIRTIO_SND_PCM_F_XXX */
302                 formats: SUPPORTED_FORMATS.into(),
303                 rates: SUPPORTED_FRAME_RATES.into(),
304                 direction: VIRTIO_SND_D_OUTPUT,
305                 channels_min: 1,
306                 channels_max: 6,
307                 padding: [0; 5],
308             });
309         }
310     }
311     for dev in 0..params.num_input_devices {
312         for _ in 0..params.num_input_streams {
313             pcm_info.push(virtio_snd_pcm_info {
314                 hdr: virtio_snd_info {
315                     hda_fn_nid: dev.into(),
316                 },
317                 features: 0.into(), /* 1 << VIRTIO_SND_PCM_F_XXX */
318                 formats: SUPPORTED_FORMATS.into(),
319                 rates: SUPPORTED_FRAME_RATES.into(),
320                 direction: VIRTIO_SND_D_INPUT,
321                 channels_min: 1,
322                 channels_max: 2,
323                 padding: [0; 5],
324             });
325         }
326     }
327     // Use stereo channel map.
328     let mut positions = [VIRTIO_SND_CHMAP_NONE; VIRTIO_SND_CHMAP_MAX_SIZE];
329     positions[0] = VIRTIO_SND_CHMAP_FL;
330     positions[1] = VIRTIO_SND_CHMAP_FR;
331     for dev in 0..params.num_output_devices {
332         chmap_info.push(virtio_snd_chmap_info {
333             hdr: virtio_snd_info {
334                 hda_fn_nid: dev.into(),
335             },
336             direction: VIRTIO_SND_D_OUTPUT,
337             channels: 2,
338             positions,
339         });
340     }
341     for dev in 0..params.num_input_devices {
342         chmap_info.push(virtio_snd_chmap_info {
343             hdr: virtio_snd_info {
344                 hda_fn_nid: dev.into(),
345             },
346             direction: VIRTIO_SND_D_INPUT,
347             channels: 2,
348             positions,
349         });
350     }
351     positions[2] = VIRTIO_SND_CHMAP_RL;
352     positions[3] = VIRTIO_SND_CHMAP_RR;
353     for dev in 0..params.num_output_devices {
354         chmap_info.push(virtio_snd_chmap_info {
355             hdr: virtio_snd_info {
356                 hda_fn_nid: dev.into(),
357             },
358             direction: VIRTIO_SND_D_OUTPUT,
359             channels: 4,
360             positions,
361         });
362     }
363     positions[2] = VIRTIO_SND_CHMAP_FC;
364     positions[3] = VIRTIO_SND_CHMAP_LFE;
365     positions[4] = VIRTIO_SND_CHMAP_RL;
366     positions[5] = VIRTIO_SND_CHMAP_RR;
367     for dev in 0..params.num_output_devices {
368         chmap_info.push(virtio_snd_chmap_info {
369             hdr: virtio_snd_info {
370                 hda_fn_nid: dev.into(),
371             },
372             direction: VIRTIO_SND_D_OUTPUT,
373             channels: 6,
374             positions,
375         });
376     }
377 
378     SndData {
379         jack_info,
380         pcm_info,
381         chmap_info,
382     }
383 }
384 
resize_parameters_pcm_device_config(mut params: Parameters) -> Parameters385 fn resize_parameters_pcm_device_config(mut params: Parameters) -> Parameters {
386     if params.output_device_config.len() > params.num_output_devices as usize {
387         warn!("Truncating output device config due to length > number of output devices");
388     }
389     params
390         .output_device_config
391         .resize_with(params.num_output_devices as usize, Default::default);
392 
393     if params.input_device_config.len() > params.num_input_devices as usize {
394         warn!("Truncating input device config due to length > number of input devices");
395     }
396     params
397         .input_device_config
398         .resize_with(params.num_input_devices as usize, Default::default);
399 
400     params
401 }
402 
403 impl VirtioDevice for VirtioSnd {
keep_rds(&self) -> Vec<RawDescriptor>404     fn keep_rds(&self) -> Vec<RawDescriptor> {
405         self.keep_rds
406             .iter()
407             .map(|descr| descr.as_raw_descriptor())
408             .collect()
409     }
410 
device_type(&self) -> DeviceType411     fn device_type(&self) -> DeviceType {
412         DeviceType::Sound
413     }
414 
queue_max_sizes(&self) -> &[u16]415     fn queue_max_sizes(&self) -> &[u16] {
416         &self.queue_sizes
417     }
418 
features(&self) -> u64419     fn features(&self) -> u64 {
420         self.avail_features
421     }
422 
ack_features(&mut self, mut v: u64)423     fn ack_features(&mut self, mut v: u64) {
424         // Check if the guest is ACK'ing a feature that we didn't claim to have.
425         let unrequested_features = v & !self.avail_features;
426         if unrequested_features != 0 {
427             warn!("virtio_fs got unknown feature ack: {:x}", v);
428 
429             // Don't count these features as acked.
430             v &= !unrequested_features;
431         }
432         self.acked_features |= v;
433     }
434 
read_config(&self, offset: u64, data: &mut [u8])435     fn read_config(&self, offset: u64, data: &mut [u8]) {
436         copy_config(data, 0, self.cfg.as_bytes(), offset)
437     }
438 
activate( &mut self, _guest_mem: GuestMemory, interrupt: Interrupt, queues: BTreeMap<usize, Queue>, ) -> anyhow::Result<()>439     fn activate(
440         &mut self,
441         _guest_mem: GuestMemory,
442         interrupt: Interrupt,
443         queues: BTreeMap<usize, Queue>,
444     ) -> anyhow::Result<()> {
445         if queues.len() != self.queue_sizes.len() {
446             return Err(anyhow!(
447                 "snd: expected {} queues, got {}",
448                 self.queue_sizes.len(),
449                 queues.len()
450             ));
451         }
452 
453         let snd_data = self.snd_data.clone();
454         let stream_info_builders = self.stream_info_builders.to_vec();
455         let streams_state = self.streams_state.take();
456         self.worker_thread = Some(WorkerThread::start("v_snd_common", move |kill_evt| {
457             let _thread_priority_handle = set_audio_thread_priority();
458             if let Err(e) = _thread_priority_handle {
459                 warn!("Failed to set audio thread to real time: {}", e);
460             };
461             run_worker(
462                 interrupt,
463                 queues,
464                 snd_data,
465                 kill_evt,
466                 stream_info_builders,
467                 streams_state,
468             )
469         }));
470 
471         Ok(())
472     }
473 
reset(&mut self) -> anyhow::Result<()>474     fn reset(&mut self) -> anyhow::Result<()> {
475         if let Some(worker_thread) = self.worker_thread.take() {
476             let _ = worker_thread.stop();
477         }
478 
479         Ok(())
480     }
481 
virtio_sleep(&mut self) -> anyhow::Result<Option<BTreeMap<usize, Queue>>>482     fn virtio_sleep(&mut self) -> anyhow::Result<Option<BTreeMap<usize, Queue>>> {
483         if let Some(worker_thread) = self.worker_thread.take() {
484             let worker = worker_thread.stop().unwrap();
485             self.snd_data = worker.snd_data;
486             self.streams_state = Some(worker.streams_state);
487             return Ok(Some(BTreeMap::from_iter(
488                 worker.queues.into_iter().enumerate(),
489             )));
490         }
491         Ok(None)
492     }
493 
virtio_wake( &mut self, device_state: Option<(GuestMemory, Interrupt, BTreeMap<usize, Queue>)>, ) -> anyhow::Result<()>494     fn virtio_wake(
495         &mut self,
496         device_state: Option<(GuestMemory, Interrupt, BTreeMap<usize, Queue>)>,
497     ) -> anyhow::Result<()> {
498         match device_state {
499             None => Ok(()),
500             Some((mem, interrupt, queues)) => {
501                 // TODO: activate is just what we want at the moment, but we should probably move
502                 // it into a "start workers" function to make it obvious that it isn't strictly
503                 // used for activate events.
504                 self.activate(mem, interrupt, queues)?;
505                 Ok(())
506             }
507         }
508     }
509 
virtio_snapshot(&mut self) -> anyhow::Result<serde_json::Value>510     fn virtio_snapshot(&mut self) -> anyhow::Result<serde_json::Value> {
511         let streams_state = if let Some(states) = &self.streams_state {
512             let mut state_vec = Vec::new();
513             for state in states {
514                 state_vec.push(state.clone());
515             }
516             Some(state_vec)
517         } else {
518             None
519         };
520         serde_json::to_value(VirtioSndSnapshot {
521             avail_features: self.avail_features,
522             acked_features: self.acked_features,
523             queue_sizes: self.queue_sizes.to_vec(),
524             streams_state,
525             snd_data: self.snd_data.clone(),
526         })
527         .context("failed to Serialize Sound device")
528     }
529 
virtio_restore(&mut self, data: serde_json::Value) -> anyhow::Result<()>530     fn virtio_restore(&mut self, data: serde_json::Value) -> anyhow::Result<()> {
531         let mut deser: VirtioSndSnapshot =
532             serde_json::from_value(data).context("failed to Deserialize Sound device")?;
533         anyhow::ensure!(
534             deser.avail_features == self.avail_features,
535             "avail features doesn't match on restore: expected: {}, got: {}",
536             deser.avail_features,
537             self.avail_features
538         );
539         anyhow::ensure!(
540             deser.queue_sizes == self.queue_sizes.to_vec(),
541             "queue sizes doesn't match on restore: expected: {:?}, got: {:?}",
542             deser.queue_sizes,
543             self.queue_sizes.to_vec()
544         );
545         self.acked_features = deser.acked_features;
546         anyhow::ensure!(
547             deser.snd_data == self.snd_data,
548             "snd data doesn't match on restore: expected: {:?}, got: {:?}",
549             deser.snd_data,
550             self.snd_data
551         );
552         self.acked_features = deser.acked_features;
553         self.streams_state = deser.streams_state.take();
554         Ok(())
555     }
556 }
557 
/// Whether the main worker loop should run another iteration.
#[derive(PartialEq)]
enum LoopState {
    /// Reset the streams and run the workers again.
    Continue,
    /// Leave the main loop.
    Break,
}
563 
run_worker( interrupt: Interrupt, queues: BTreeMap<usize, Queue>, snd_data: SndData, kill_evt: Event, stream_info_builders: Vec<StreamInfoBuilder>, streams_state: Option<Vec<StreamInfoSnapshot>>, ) -> Result<WorkerReturn, String>564 fn run_worker(
565     interrupt: Interrupt,
566     queues: BTreeMap<usize, Queue>,
567     snd_data: SndData,
568     kill_evt: Event,
569     stream_info_builders: Vec<StreamInfoBuilder>,
570     streams_state: Option<Vec<StreamInfoSnapshot>>,
571 ) -> Result<WorkerReturn, String> {
572     let ex = Executor::new().expect("Failed to create an executor");
573 
574     if snd_data.pcm_info_len() != stream_info_builders.len() {
575         error!(
576             "snd: expected {} streams, got {}",
577             snd_data.pcm_info_len(),
578             stream_info_builders.len(),
579         );
580     }
581     let streams: Vec<AsyncRwLock<StreamInfo>> = stream_info_builders
582         .into_iter()
583         .map(StreamInfoBuilder::build)
584         .map(AsyncRwLock::new)
585         .collect();
586 
587     let (tx_send, mut tx_recv) = mpsc::unbounded();
588     let (rx_send, mut rx_recv) = mpsc::unbounded();
589     let tx_send_clone = tx_send.clone();
590     let rx_send_clone = rx_send.clone();
591     let restore_task = ex.spawn_local(async move {
592         if let Some(states) = &streams_state {
593             let ex = Executor::new().expect("Failed to create an executor");
594             for (stream, state) in streams.iter().zip(states.iter()) {
595                 stream.lock().await.restore(state);
596                 if state.state == VIRTIO_SND_R_PCM_START || state.state == VIRTIO_SND_R_PCM_PREPARE
597                 {
598                     stream
599                         .lock()
600                         .await
601                         .prepare(&ex, &tx_send_clone, &rx_send_clone)
602                         .await
603                         .expect("failed to prepare PCM");
604                 }
605                 if state.state == VIRTIO_SND_R_PCM_START {
606                     stream
607                         .lock()
608                         .await
609                         .start()
610                         .await
611                         .expect("failed to start PCM");
612                 }
613             }
614         }
615         streams
616     });
617     let streams = ex
618         .run_until(restore_task)
619         .expect("failed to restore streams");
620     let streams = Rc::new(AsyncRwLock::new(streams));
621 
622     let mut queues: Vec<(Queue, EventAsync)> = queues
623         .into_values()
624         .map(|q| {
625             let e = q.event().try_clone().expect("Failed to clone queue event");
626             (
627                 q,
628                 EventAsync::new(e, &ex).expect("Failed to create async event for queue"),
629             )
630         })
631         .collect();
632 
633     let (ctrl_queue, mut ctrl_queue_evt) = queues.remove(0);
634     let ctrl_queue = Rc::new(AsyncRwLock::new(ctrl_queue));
635     let (_event_queue, _event_queue_evt) = queues.remove(0);
636     let (tx_queue, tx_queue_evt) = queues.remove(0);
637     let (rx_queue, rx_queue_evt) = queues.remove(0);
638 
639     let tx_queue = Rc::new(AsyncRwLock::new(tx_queue));
640     let rx_queue = Rc::new(AsyncRwLock::new(rx_queue));
641 
642     let f_resample = async_utils::handle_irq_resample(&ex, interrupt.clone()).fuse();
643 
644     // Exit if the kill event is triggered.
645     let f_kill = async_utils::await_and_exit(&ex, kill_evt).fuse();
646 
647     pin_mut!(f_resample, f_kill);
648 
649     loop {
650         if run_worker_once(
651             &ex,
652             &streams,
653             interrupt.clone(),
654             &snd_data,
655             &mut f_kill,
656             &mut f_resample,
657             ctrl_queue.clone(),
658             &mut ctrl_queue_evt,
659             tx_queue.clone(),
660             &tx_queue_evt,
661             tx_send.clone(),
662             &mut tx_recv,
663             rx_queue.clone(),
664             &rx_queue_evt,
665             rx_send.clone(),
666             &mut rx_recv,
667         ) == LoopState::Break
668         {
669             break;
670         }
671 
672         if let Err(e) = reset_streams(
673             &ex,
674             &streams,
675             interrupt.clone(),
676             &tx_queue,
677             &mut tx_recv,
678             &rx_queue,
679             &mut rx_recv,
680         ) {
681             error!("Error reset streams: {}", e);
682             break;
683         }
684     }
685     let streams_state_task = ex.spawn_local(async move {
686         let mut v = Vec::new();
687         for stream in streams.read_lock().await.iter() {
688             v.push(stream.read_lock().await.snapshot());
689         }
690         v
691     });
692     let streams_state = ex
693         .run_until(streams_state_task)
694         .expect("failed to save streams state");
695     let ctrl_queue = match Rc::try_unwrap(ctrl_queue) {
696         Ok(q) => q.into_inner(),
697         Err(_) => panic!("Too many refs to ctrl_queue"),
698     };
699     let tx_queue = match Rc::try_unwrap(tx_queue) {
700         Ok(q) => q.into_inner(),
701         Err(_) => panic!("Too many refs to tx_queue"),
702     };
703     let rx_queue = match Rc::try_unwrap(rx_queue) {
704         Ok(q) => q.into_inner(),
705         Err(_) => panic!("Too many refs to rx_queue"),
706     };
707     let queues = vec![ctrl_queue, _event_queue, tx_queue, rx_queue];
708 
709     Ok(WorkerReturn {
710         queues,
711         snd_data,
712         streams_state,
713     })
714 }
715 
716 struct WorkerReturn {
717     queues: Vec<Queue>,
718     snd_data: SndData,
719     streams_state: Vec<StreamInfoSnapshot>,
720 }
721 
notify_reset_signal(reset_signal: &(AsyncRwLock<bool>, Condvar))722 async fn notify_reset_signal(reset_signal: &(AsyncRwLock<bool>, Condvar)) {
723     let (lock, cvar) = reset_signal;
724     *lock.lock().await = true;
725     cvar.notify_all();
726 }
727 
728 /// Runs all workers once and exit if any worker exit.
729 ///
730 /// Returns [`LoopState::Break`] if the worker `f_kill` or `f_resample` exit, or something went
731 /// wrong on shutdown process. The caller should not run the worker again and should exit the main
732 /// loop.
733 ///
734 /// If this function returns [`LoopState::Continue`], the caller can continue the main loop by
735 /// resetting the streams and run the worker again.
run_worker_once( ex: &Executor, streams: &Rc<AsyncRwLock<Vec<AsyncRwLock<StreamInfo>>>>, interrupt: Interrupt, snd_data: &SndData, mut f_kill: &mut (impl Future<Output = anyhow::Result<()>> + FusedFuture + Unpin), mut f_resample: &mut (impl Future<Output = anyhow::Result<()>> + FusedFuture + Unpin), ctrl_queue: Rc<AsyncRwLock<Queue>>, ctrl_queue_evt: &mut EventAsync, tx_queue: Rc<AsyncRwLock<Queue>>, tx_queue_evt: &EventAsync, tx_send: mpsc::UnboundedSender<PcmResponse>, tx_recv: &mut mpsc::UnboundedReceiver<PcmResponse>, rx_queue: Rc<AsyncRwLock<Queue>>, rx_queue_evt: &EventAsync, rx_send: mpsc::UnboundedSender<PcmResponse>, rx_recv: &mut mpsc::UnboundedReceiver<PcmResponse>, ) -> LoopState736 fn run_worker_once(
737     ex: &Executor,
738     streams: &Rc<AsyncRwLock<Vec<AsyncRwLock<StreamInfo>>>>,
739     interrupt: Interrupt,
740     snd_data: &SndData,
741     mut f_kill: &mut (impl Future<Output = anyhow::Result<()>> + FusedFuture + Unpin),
742     mut f_resample: &mut (impl Future<Output = anyhow::Result<()>> + FusedFuture + Unpin),
743     ctrl_queue: Rc<AsyncRwLock<Queue>>,
744     ctrl_queue_evt: &mut EventAsync,
745     tx_queue: Rc<AsyncRwLock<Queue>>,
746     tx_queue_evt: &EventAsync,
747     tx_send: mpsc::UnboundedSender<PcmResponse>,
748     tx_recv: &mut mpsc::UnboundedReceiver<PcmResponse>,
749     rx_queue: Rc<AsyncRwLock<Queue>>,
750     rx_queue_evt: &EventAsync,
751     rx_send: mpsc::UnboundedSender<PcmResponse>,
752     rx_recv: &mut mpsc::UnboundedReceiver<PcmResponse>,
753 ) -> LoopState {
754     let tx_send2 = tx_send.clone();
755     let rx_send2 = rx_send.clone();
756 
757     let reset_signal = (AsyncRwLock::new(false), Condvar::new());
758 
759     let f_ctrl = handle_ctrl_queue(
760         ex,
761         streams,
762         snd_data,
763         ctrl_queue,
764         ctrl_queue_evt,
765         interrupt.clone(),
766         tx_send,
767         rx_send,
768         Some(&reset_signal),
769     )
770     .fuse();
771 
772     // TODO(woodychow): Enable this when libcras sends jack connect/disconnect evts
773     // let f_event = handle_event_queue(
774     //     snd_state,
775     //     event_queue,
776     //     event_queue_evt,
777     //     interrupt,
778     // );
779     let f_tx = handle_pcm_queue(
780         streams,
781         tx_send2,
782         tx_queue.clone(),
783         tx_queue_evt,
784         Some(&reset_signal),
785     )
786     .fuse();
787     let f_tx_response =
788         send_pcm_response_worker(tx_queue, interrupt.clone(), tx_recv, Some(&reset_signal)).fuse();
789     let f_rx = handle_pcm_queue(
790         streams,
791         rx_send2,
792         rx_queue.clone(),
793         rx_queue_evt,
794         Some(&reset_signal),
795     )
796     .fuse();
797     let f_rx_response =
798         send_pcm_response_worker(rx_queue, interrupt, rx_recv, Some(&reset_signal)).fuse();
799 
800     pin_mut!(f_ctrl, f_tx, f_tx_response, f_rx, f_rx_response);
801 
802     let done = async {
803         select! {
804             res = f_ctrl => (res.context("error in handling ctrl queue"), LoopState::Continue),
805             res = f_tx => (res.context("error in handling tx queue"), LoopState::Continue),
806             res = f_tx_response => (res.context("error in handling tx response"), LoopState::Continue),
807             res = f_rx => (res.context("error in handling rx queue"), LoopState::Continue),
808             res = f_rx_response => (res.context("error in handling rx response"), LoopState::Continue),
809 
810             // For following workers, do not continue the loop
811             res = f_resample => (res.context("error in handle_irq_resample"), LoopState::Break),
812             res = f_kill => (res.context("error in await_and_exit"), LoopState::Break),
813         }
814     };
815 
816     match ex.run_until(done) {
817         Ok((res, loop_state)) => {
818             if let Err(e) = res {
819                 error!("Error in worker: {:#}", e);
820             }
821             if loop_state == LoopState::Break {
822                 return LoopState::Break;
823             }
824         }
825         Err(e) => {
826             error!("Error happened in executor: {}", e);
827         }
828     }
829 
830     warn!("Shutting down all workers for reset procedure");
831     block_on(notify_reset_signal(&reset_signal));
832 
833     let shutdown = async {
834         loop {
835             let (res, worker_name) = select!(
836                 res = f_ctrl => (res, "f_ctrl"),
837                 res = f_tx => (res, "f_tx"),
838                 res = f_tx_response => (res, "f_tx_response"),
839                 res = f_rx => (res, "f_rx"),
840                 res = f_rx_response => (res, "f_rx_response"),
841                 complete => break,
842             );
843             match res {
844                 Ok(_) => debug!("Worker {} stopped", worker_name),
845                 Err(e) => error!("Worker {} stopped with error {}", worker_name, e),
846             };
847         }
848     };
849 
850     if let Err(e) = ex.run_until(shutdown) {
851         error!("Error happened in executor while shutdown: {}", e);
852         return LoopState::Break;
853     }
854 
855     LoopState::Continue
856 }
857 
reset_streams( ex: &Executor, streams: &Rc<AsyncRwLock<Vec<AsyncRwLock<StreamInfo>>>>, interrupt: Interrupt, tx_queue: &Rc<AsyncRwLock<Queue>>, tx_recv: &mut mpsc::UnboundedReceiver<PcmResponse>, rx_queue: &Rc<AsyncRwLock<Queue>>, rx_recv: &mut mpsc::UnboundedReceiver<PcmResponse>, ) -> Result<(), AsyncError>858 fn reset_streams(
859     ex: &Executor,
860     streams: &Rc<AsyncRwLock<Vec<AsyncRwLock<StreamInfo>>>>,
861     interrupt: Interrupt,
862     tx_queue: &Rc<AsyncRwLock<Queue>>,
863     tx_recv: &mut mpsc::UnboundedReceiver<PcmResponse>,
864     rx_queue: &Rc<AsyncRwLock<Queue>>,
865     rx_recv: &mut mpsc::UnboundedReceiver<PcmResponse>,
866 ) -> Result<(), AsyncError> {
867     let reset_signal = (AsyncRwLock::new(false), Condvar::new());
868 
869     let do_reset = async {
870         let streams = streams.read_lock().await;
871         for stream_info in &*streams {
872             let mut stream_info = stream_info.lock().await;
873             if stream_info.state == VIRTIO_SND_R_PCM_START {
874                 if let Err(e) = stream_info.stop().await {
875                     error!("Error on stop while resetting stream: {}", e);
876                 }
877             }
878             if stream_info.state == VIRTIO_SND_R_PCM_STOP
879                 || stream_info.state == VIRTIO_SND_R_PCM_PREPARE
880             {
881                 if let Err(e) = stream_info.release().await {
882                     error!("Error on release while resetting stream: {}", e);
883                 }
884             }
885             stream_info.just_reset = true;
886         }
887 
888         notify_reset_signal(&reset_signal).await;
889     };
890 
891     // Run these in a loop to ensure that they will survive until do_reset is finished
892     let f_tx_response = async {
893         while send_pcm_response_worker(
894             tx_queue.clone(),
895             interrupt.clone(),
896             tx_recv,
897             Some(&reset_signal),
898         )
899         .await
900         .is_err()
901         {}
902     };
903 
904     let f_rx_response = async {
905         while send_pcm_response_worker(
906             rx_queue.clone(),
907             interrupt.clone(),
908             rx_recv,
909             Some(&reset_signal),
910         )
911         .await
912         .is_err()
913         {}
914     };
915 
916     let reset = async {
917         join!(f_tx_response, f_rx_response, do_reset);
918     };
919 
920     ex.run_until(reset)
921 }
922 
#[cfg(test)]
// The `Parameters { .., ..Default::default() }` literals below keep the struct-update tail even
// where clippy may consider it redundant, so `needless_update` is silenced module-wide.
#[allow(clippy::needless_update)]
mod tests {
    use audio_streams::StreamEffect;

    use super::*;
    use crate::virtio::snd::parameters::PCMDeviceParameters;

    /// Checks the default state of a new `VirtioSnd` and the stream / chmap layout derived from
    /// 3 output devices x 3 streams and 2 input devices x 2 streams.
    #[test]
    fn test_virtio_snd_new() {
        let params = Parameters {
            num_output_devices: 3,
            num_input_devices: 2,
            num_output_streams: 3,
            num_input_streams: 2,
            output_device_config: vec![PCMDeviceParameters {
                effects: Some(vec![StreamEffect::EchoCancellation]),
                ..PCMDeviceParameters::default()
            }],
            input_device_config: vec![PCMDeviceParameters {
                effects: Some(vec![StreamEffect::EchoCancellation]),
                ..PCMDeviceParameters::default()
            }],
            ..Default::default()
        };

        let res = VirtioSnd::new(123, params).unwrap();

        // Default values
        assert_eq!(res.snd_data.jack_info.len(), 0);
        assert_eq!(res.acked_features, 0);
        assert!(res.worker_thread.is_none());

        assert_eq!(res.avail_features, 123); // avail_features must be equal to the input
        assert_eq!(res.cfg.jacks.to_native(), 0);
        assert_eq!(res.cfg.streams.to_native(), 13); // (Output = 3*3) + (Input = 2*2)
        assert_eq!(res.cfg.chmaps.to_native(), 11); // (Output = 3*3) + (Input = 2*1)

        // Check snd_data.pcm_info
        assert_eq!(res.snd_data.pcm_info.len(), 13);
        // hda_fn_nid is the PCM device number each stream belongs to.
        let expected_pcm_hda_fn_nid = [0, 0, 0, 1, 1, 1, 2, 2, 2, 0, 0, 1, 1];
        for (i, pcm_info) in res.snd_data.pcm_info.iter().enumerate() {
            assert_eq!(
                pcm_info.hdr.hda_fn_nid.to_native(),
                expected_pcm_hda_fn_nid[i],
                "pcm_info index {} incorrect hda_fn_nid",
                i
            );
            // The first 9 entries (3*3) must be OUTPUT, the remaining 4 (2*2) INPUT.
            let expected_direction = if i < 9 {
                VIRTIO_SND_D_OUTPUT
            } else {
                VIRTIO_SND_D_INPUT
            };
            assert_eq!(
                pcm_info.direction, expected_direction,
                "pcm_info index {} incorrect direction",
                i
            );
        }

        // Check snd_data.chmap_info
        assert_eq!(res.snd_data.chmap_info.len(), 11);
        // hda_fn_nid (PCM device number) of each chmap entry.
        let expected_chmap_hda_fn_nid = [0, 1, 2, 0, 1, 0, 1, 2, 0, 1, 2];
        for (i, chmap_info) in res.snd_data.chmap_info.iter().enumerate() {
            assert_eq!(
                chmap_info.hdr.hda_fn_nid.to_native(),
                expected_chmap_hda_fn_nid[i],
                "chmap_info index {} incorrect hda_fn_nid",
                i
            );
        }
    }

    /// A per-device config list longer than the number of devices is truncated.
    #[test]
    fn test_resize_parameters_pcm_device_config_truncate() {
        let params = Parameters {
            num_output_devices: 1,
            num_input_devices: 1,
            output_device_config: vec![PCMDeviceParameters::default(); 3],
            input_device_config: vec![PCMDeviceParameters::default(); 3],
            ..Parameters::default()
        };
        let params = resize_parameters_pcm_device_config(params);
        assert_eq!(params.output_device_config.len(), 1);
        assert_eq!(params.input_device_config.len(), 1);
    }

    /// A per-device config list shorter than the number of devices is padded with defaults,
    /// keeping the user-provided entries in front.
    #[test]
    fn test_resize_parameters_pcm_device_config_extend() {
        let params = Parameters {
            num_output_devices: 3,
            num_input_devices: 2,
            num_output_streams: 3,
            num_input_streams: 2,
            output_device_config: vec![PCMDeviceParameters {
                effects: Some(vec![StreamEffect::EchoCancellation]),
                ..PCMDeviceParameters::default()
            }],
            input_device_config: vec![PCMDeviceParameters {
                effects: Some(vec![StreamEffect::EchoCancellation]),
                ..PCMDeviceParameters::default()
            }],
            ..Default::default()
        };

        let params = resize_parameters_pcm_device_config(params);

        // Check output_device_config correctly extended
        assert_eq!(
            params.output_device_config,
            vec![
                PCMDeviceParameters {
                    // Keep from the parameters
                    effects: Some(vec![StreamEffect::EchoCancellation]),
                    ..PCMDeviceParameters::default()
                },
                PCMDeviceParameters::default(), // Extended with default
                PCMDeviceParameters::default(), // Extended with default
            ]
        );

        // Check input_device_config correctly extended
        assert_eq!(
            params.input_device_config,
            vec![
                PCMDeviceParameters {
                    // Keep from the parameters
                    effects: Some(vec![StreamEffect::EchoCancellation]),
                    ..PCMDeviceParameters::default()
                },
                PCMDeviceParameters::default(), // Extended with default
            ]
        );
    }
}
1067