• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2021 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 // virtio-sound spec: https://github.com/oasis-tcs/virtio-spec/blob/master/virtio-sound.tex
6 
7 use std::collections::BTreeMap;
8 use std::io;
9 use std::rc::Rc;
10 use std::sync::Arc;
11 
12 use anyhow::anyhow;
13 use anyhow::Context;
14 use audio_streams::BoxError;
15 use base::debug;
16 use base::error;
17 use base::warn;
18 use base::AsRawDescriptor;
19 use base::Descriptor;
20 use base::Error as SysError;
21 use base::Event;
22 use base::RawDescriptor;
23 use base::Tube;
24 use base::WorkerThread;
25 use cros_async::block_on;
26 use cros_async::sync::Condvar;
27 use cros_async::sync::RwLock as AsyncRwLock;
28 use cros_async::AsyncError;
29 use cros_async::AsyncTube;
30 use cros_async::EventAsync;
31 use cros_async::Executor;
32 use futures::channel::mpsc;
33 use futures::channel::oneshot;
34 use futures::channel::oneshot::Canceled;
35 use futures::future::FusedFuture;
36 use futures::join;
37 use futures::pin_mut;
38 use futures::select;
39 use futures::FutureExt;
40 use serde::Deserialize;
41 use serde::Serialize;
42 use snapshot::AnySnapshot;
43 use thiserror::Error as ThisError;
44 use vm_memory::GuestMemory;
45 use zerocopy::IntoBytes;
46 
47 use crate::virtio::async_utils;
48 use crate::virtio::copy_config;
49 use crate::virtio::device_constants::snd::virtio_snd_config;
50 use crate::virtio::snd::common_backend::async_funcs::*;
51 use crate::virtio::snd::common_backend::stream_info::StreamInfo;
52 use crate::virtio::snd::common_backend::stream_info::StreamInfoBuilder;
53 use crate::virtio::snd::common_backend::stream_info::StreamInfoSnapshot;
54 use crate::virtio::snd::constants::*;
55 use crate::virtio::snd::file_backend::create_file_stream_source_generators;
56 use crate::virtio::snd::file_backend::Error as FileError;
57 use crate::virtio::snd::layout::*;
58 use crate::virtio::snd::null_backend::create_null_stream_source_generators;
59 use crate::virtio::snd::parameters::Parameters;
60 use crate::virtio::snd::parameters::StreamSourceBackend;
61 use crate::virtio::snd::sys::create_stream_source_generators as sys_create_stream_source_generators;
62 use crate::virtio::snd::sys::set_audio_thread_priority;
63 use crate::virtio::snd::sys::SysAsyncStreamObjects;
64 use crate::virtio::snd::sys::SysAudioStreamSourceGenerator;
65 use crate::virtio::snd::sys::SysDirectionOutput;
66 use crate::virtio::DescriptorChain;
67 use crate::virtio::DeviceType;
68 use crate::virtio::Interrupt;
69 use crate::virtio::Queue;
70 use crate::virtio::VirtioDevice;
71 
72 pub mod async_funcs;
73 pub mod stream_info;
74 
75 // control + event + tx + rx queue
76 pub const MAX_QUEUE_NUM: usize = 4;
77 pub const MAX_VRING_LEN: u16 = 1024;
78 
79 #[derive(ThisError, Debug)]
80 pub enum Error {
81     /// next_async failed.
82     #[error("Failed to read descriptor asynchronously: {0}")]
83     Async(AsyncError),
84     /// Creating stream failed.
85     #[error("Failed to create stream: {0}")]
86     CreateStream(BoxError),
87     /// Creating stream failed.
88     #[error("No stream source found.")]
89     EmptyStreamSource,
90     /// Creating kill event failed.
91     #[error("Failed to create kill event: {0}")]
92     CreateKillEvent(SysError),
93     /// Creating WaitContext failed.
94     #[error("Failed to create wait context: {0}")]
95     CreateWaitContext(SysError),
96     #[error("Failed to create file stream source generator")]
97     CreateFileStreamSourceGenerator(FileError),
98     /// Cloning kill event failed.
99     #[error("Failed to clone kill event: {0}")]
100     CloneKillEvent(SysError),
101     // Future error.
102     #[error("Unexpected error. Done was not triggered before dropped: {0}")]
103     DoneNotTriggered(Canceled),
104     /// Error reading message from queue.
105     #[error("Failed to read message: {0}")]
106     ReadMessage(io::Error),
107     /// Failed writing a response to a control message.
108     #[error("Failed to write message response: {0}")]
109     WriteResponse(io::Error),
110     // Mpsc read error.
111     #[error("Error in mpsc: {0}")]
112     MpscSend(futures::channel::mpsc::SendError),
113     // Oneshot send error.
114     #[error("Error in oneshot send")]
115     OneshotSend(()),
116     /// Failure in communicating with the host
117     #[error("Failed to send/receive to/from control tube")]
118     ControlTubeError(base::TubeError),
119     /// Stream not found.
120     #[error("stream id ({0}) < num_streams ({1})")]
121     StreamNotFound(usize, usize),
122     /// Fetch buffer error
123     #[error("Failed to get buffer from CRAS: {0}")]
124     FetchBuffer(BoxError),
125     /// Invalid buffer size
126     #[error("Invalid buffer size")]
127     InvalidBufferSize,
128     /// IoError
129     #[error("I/O failed: {0}")]
130     Io(io::Error),
131     /// Operation not supported.
132     #[error("Operation not supported")]
133     OperationNotSupported,
134     /// Writing to a buffer in the guest failed.
135     #[error("failed to write to buffer: {0}")]
136     WriteBuffer(io::Error),
137     // Invalid PCM worker state.
138     #[error("Invalid PCM worker state")]
139     InvalidPCMWorkerState,
140     // Invalid backend.
141     #[error("Backend is not implemented")]
142     InvalidBackend,
143     // Failed to generate StreamSource
144     #[error("Failed to generate stream source: {0}")]
145     GenerateStreamSource(BoxError),
146     // PCM worker unexpectedly quitted.
147     #[error("PCM worker quitted unexpectedly")]
148     PCMWorkerQuittedUnexpectedly,
149 }
150 
151 pub enum DirectionalStream {
152     Input(
153         usize, // `period_size` in `usize`
154         Box<dyn CaptureBufferReader>,
155     ),
156     Output(SysDirectionOutput),
157 }
158 
/// Status of a PCM worker, with explicit discriminants 0..=2.
#[derive(Clone, Copy, Eq, PartialEq)]
pub enum WorkerStatus {
    /// Discriminant 0.
    Pause = 0,
    /// Discriminant 1.
    Running = 1,
    /// Discriminant 2.
    Quit = 2,
}
165 
166 // Stores constant data
167 #[derive(Clone, Serialize, Deserialize, PartialEq, Eq, Debug)]
168 pub struct SndData {
169     pub(crate) jack_info: Vec<virtio_snd_jack_info>,
170     pub(crate) pcm_info: Vec<virtio_snd_pcm_info>,
171     pub(crate) chmap_info: Vec<virtio_snd_chmap_info>,
172 }
173 
174 impl SndData {
pcm_info_len(&self) -> usize175     pub fn pcm_info_len(&self) -> usize {
176         self.pcm_info.len()
177     }
178 
pcm_info_iter(&self) -> std::slice::Iter<'_, virtio_snd_pcm_info>179     pub fn pcm_info_iter(&self) -> std::slice::Iter<'_, virtio_snd_pcm_info> {
180         self.pcm_info.iter()
181     }
182 }
183 
184 const SUPPORTED_FORMATS: u64 = 1 << VIRTIO_SND_PCM_FMT_U8
185     | 1 << VIRTIO_SND_PCM_FMT_S16
186     | 1 << VIRTIO_SND_PCM_FMT_S24
187     | 1 << VIRTIO_SND_PCM_FMT_S32;
188 const SUPPORTED_FRAME_RATES: u64 = 1 << VIRTIO_SND_PCM_RATE_8000
189     | 1 << VIRTIO_SND_PCM_RATE_11025
190     | 1 << VIRTIO_SND_PCM_RATE_16000
191     | 1 << VIRTIO_SND_PCM_RATE_22050
192     | 1 << VIRTIO_SND_PCM_RATE_32000
193     | 1 << VIRTIO_SND_PCM_RATE_44100
194     | 1 << VIRTIO_SND_PCM_RATE_48000;
195 
196 // Response from pcm_worker to pcm_queue
197 pub struct PcmResponse {
198     pub(crate) desc_chain: DescriptorChain,
199     pub(crate) status: virtio_snd_pcm_status, // response to the pcm message
200     pub(crate) done: Option<oneshot::Sender<()>>, // when pcm response is written to the queue
201 }
202 
203 pub struct VirtioSnd {
204     control_tube: Option<Tube>,
205     cfg: virtio_snd_config,
206     snd_data: SndData,
207     stream_info_builders: Vec<StreamInfoBuilder>,
208     avail_features: u64,
209     acked_features: u64,
210     queue_sizes: Box<[u16]>,
211     worker_thread: Option<WorkerThread<Result<WorkerReturn, String>>>,
212     keep_rds: Vec<Descriptor>,
213     streams_state: Option<Vec<StreamInfoSnapshot>>,
214     card_index: usize,
215 }
216 
217 #[derive(Serialize, Deserialize)]
218 struct VirtioSndSnapshot {
219     avail_features: u64,
220     acked_features: u64,
221     queue_sizes: Vec<u16>,
222     streams_state: Option<Vec<StreamInfoSnapshot>>,
223     snd_data: SndData,
224 }
225 
226 impl VirtioSnd {
new( base_features: u64, params: Parameters, control_tube: Tube, ) -> Result<VirtioSnd, Error>227     pub fn new(
228         base_features: u64,
229         params: Parameters,
230         control_tube: Tube,
231     ) -> Result<VirtioSnd, Error> {
232         let params = resize_parameters_pcm_device_config(params);
233         let cfg = hardcoded_virtio_snd_config(&params);
234         let snd_data = hardcoded_snd_data(&params);
235         let avail_features = base_features;
236         let mut keep_rds: Vec<RawDescriptor> = Vec::new();
237         keep_rds.push(control_tube.as_raw_descriptor());
238 
239         let stream_info_builders =
240             create_stream_info_builders(&params, &snd_data, &mut keep_rds, params.card_index)?;
241 
242         Ok(VirtioSnd {
243             control_tube: Some(control_tube),
244             cfg,
245             snd_data,
246             stream_info_builders,
247             avail_features,
248             acked_features: 0,
249             queue_sizes: vec![MAX_VRING_LEN; MAX_QUEUE_NUM].into_boxed_slice(),
250             worker_thread: None,
251             keep_rds: keep_rds.iter().map(|rd| Descriptor(*rd)).collect(),
252             streams_state: None,
253             card_index: params.card_index,
254         })
255     }
256 }
257 
create_stream_source_generators( params: &Parameters, snd_data: &SndData, keep_rds: &mut Vec<RawDescriptor>, ) -> Result<Vec<SysAudioStreamSourceGenerator>, Error>258 fn create_stream_source_generators(
259     params: &Parameters,
260     snd_data: &SndData,
261     keep_rds: &mut Vec<RawDescriptor>,
262 ) -> Result<Vec<SysAudioStreamSourceGenerator>, Error> {
263     let generators = match params.backend {
264         StreamSourceBackend::NULL => create_null_stream_source_generators(snd_data),
265         StreamSourceBackend::FILE => {
266             create_file_stream_source_generators(params, snd_data, keep_rds)
267                 .map_err(Error::CreateFileStreamSourceGenerator)?
268         }
269         StreamSourceBackend::Sys(backend) => {
270             sys_create_stream_source_generators(backend, params, snd_data)
271         }
272     };
273     Ok(generators)
274 }
275 
276 /// Creates [`StreamInfoBuilder`]s by calling [`create_stream_source_generators()`] then zip
277 /// them with [`crate::virtio::snd::parameters::PCMDeviceParameters`] from the params to set
278 /// the parameters on each [`StreamInfoBuilder`] (e.g. effects).
create_stream_info_builders( params: &Parameters, snd_data: &SndData, keep_rds: &mut Vec<RawDescriptor>, card_index: usize, ) -> Result<Vec<StreamInfoBuilder>, Error>279 pub(crate) fn create_stream_info_builders(
280     params: &Parameters,
281     snd_data: &SndData,
282     keep_rds: &mut Vec<RawDescriptor>,
283     card_index: usize,
284 ) -> Result<Vec<StreamInfoBuilder>, Error> {
285     Ok(create_stream_source_generators(params, snd_data, keep_rds)?
286         .into_iter()
287         .map(Arc::new)
288         .zip(snd_data.pcm_info_iter())
289         .map(|(generator, pcm_info)| {
290             let device_params = params.get_device_params(pcm_info).unwrap_or_default();
291             StreamInfo::builder(generator, card_index)
292                 .effects(device_params.effects.unwrap_or_default())
293         })
294         .collect())
295 }
296 
297 // To be used with hardcoded_snd_data
hardcoded_virtio_snd_config(params: &Parameters) -> virtio_snd_config298 pub fn hardcoded_virtio_snd_config(params: &Parameters) -> virtio_snd_config {
299     virtio_snd_config {
300         jacks: 0.into(),
301         streams: params.get_total_streams().into(),
302         chmaps: (params.num_output_devices * 3 + params.num_input_devices).into(),
303     }
304 }
305 
306 // To be used with hardcoded_virtio_snd_config
hardcoded_snd_data(params: &Parameters) -> SndData307 pub fn hardcoded_snd_data(params: &Parameters) -> SndData {
308     let jack_info: Vec<virtio_snd_jack_info> = Vec::new();
309     let mut pcm_info: Vec<virtio_snd_pcm_info> = Vec::new();
310     let mut chmap_info: Vec<virtio_snd_chmap_info> = Vec::new();
311 
312     for dev in 0..params.num_output_devices {
313         for _ in 0..params.num_output_streams {
314             pcm_info.push(virtio_snd_pcm_info {
315                 hdr: virtio_snd_info {
316                     hda_fn_nid: dev.into(),
317                 },
318                 features: 0.into(), /* 1 << VIRTIO_SND_PCM_F_XXX */
319                 formats: SUPPORTED_FORMATS.into(),
320                 rates: SUPPORTED_FRAME_RATES.into(),
321                 direction: VIRTIO_SND_D_OUTPUT,
322                 channels_min: 1,
323                 channels_max: 6,
324                 padding: [0; 5],
325             });
326         }
327     }
328     for dev in 0..params.num_input_devices {
329         for _ in 0..params.num_input_streams {
330             pcm_info.push(virtio_snd_pcm_info {
331                 hdr: virtio_snd_info {
332                     hda_fn_nid: dev.into(),
333                 },
334                 features: 0.into(), /* 1 << VIRTIO_SND_PCM_F_XXX */
335                 formats: SUPPORTED_FORMATS.into(),
336                 rates: SUPPORTED_FRAME_RATES.into(),
337                 direction: VIRTIO_SND_D_INPUT,
338                 channels_min: 1,
339                 channels_max: 2,
340                 padding: [0; 5],
341             });
342         }
343     }
344     // Use stereo channel map.
345     let mut positions = [VIRTIO_SND_CHMAP_NONE; VIRTIO_SND_CHMAP_MAX_SIZE];
346     positions[0] = VIRTIO_SND_CHMAP_FL;
347     positions[1] = VIRTIO_SND_CHMAP_FR;
348     for dev in 0..params.num_output_devices {
349         chmap_info.push(virtio_snd_chmap_info {
350             hdr: virtio_snd_info {
351                 hda_fn_nid: dev.into(),
352             },
353             direction: VIRTIO_SND_D_OUTPUT,
354             channels: 2,
355             positions,
356         });
357     }
358     for dev in 0..params.num_input_devices {
359         chmap_info.push(virtio_snd_chmap_info {
360             hdr: virtio_snd_info {
361                 hda_fn_nid: dev.into(),
362             },
363             direction: VIRTIO_SND_D_INPUT,
364             channels: 2,
365             positions,
366         });
367     }
368     positions[2] = VIRTIO_SND_CHMAP_RL;
369     positions[3] = VIRTIO_SND_CHMAP_RR;
370     for dev in 0..params.num_output_devices {
371         chmap_info.push(virtio_snd_chmap_info {
372             hdr: virtio_snd_info {
373                 hda_fn_nid: dev.into(),
374             },
375             direction: VIRTIO_SND_D_OUTPUT,
376             channels: 4,
377             positions,
378         });
379     }
380     positions[2] = VIRTIO_SND_CHMAP_FC;
381     positions[3] = VIRTIO_SND_CHMAP_LFE;
382     positions[4] = VIRTIO_SND_CHMAP_RL;
383     positions[5] = VIRTIO_SND_CHMAP_RR;
384     for dev in 0..params.num_output_devices {
385         chmap_info.push(virtio_snd_chmap_info {
386             hdr: virtio_snd_info {
387                 hda_fn_nid: dev.into(),
388             },
389             direction: VIRTIO_SND_D_OUTPUT,
390             channels: 6,
391             positions,
392         });
393     }
394 
395     SndData {
396         jack_info,
397         pcm_info,
398         chmap_info,
399     }
400 }
401 
resize_parameters_pcm_device_config(mut params: Parameters) -> Parameters402 fn resize_parameters_pcm_device_config(mut params: Parameters) -> Parameters {
403     if params.output_device_config.len() > params.num_output_devices as usize {
404         warn!("Truncating output device config due to length > number of output devices");
405     }
406     params
407         .output_device_config
408         .resize_with(params.num_output_devices as usize, Default::default);
409 
410     if params.input_device_config.len() > params.num_input_devices as usize {
411         warn!("Truncating input device config due to length > number of input devices");
412     }
413     params
414         .input_device_config
415         .resize_with(params.num_input_devices as usize, Default::default);
416 
417     params
418 }
419 
420 impl VirtioDevice for VirtioSnd {
keep_rds(&self) -> Vec<RawDescriptor>421     fn keep_rds(&self) -> Vec<RawDescriptor> {
422         self.keep_rds
423             .iter()
424             .map(|descr| descr.as_raw_descriptor())
425             .collect()
426     }
427 
device_type(&self) -> DeviceType428     fn device_type(&self) -> DeviceType {
429         DeviceType::Sound
430     }
431 
queue_max_sizes(&self) -> &[u16]432     fn queue_max_sizes(&self) -> &[u16] {
433         &self.queue_sizes
434     }
435 
features(&self) -> u64436     fn features(&self) -> u64 {
437         self.avail_features
438     }
439 
ack_features(&mut self, mut v: u64)440     fn ack_features(&mut self, mut v: u64) {
441         // Check if the guest is ACK'ing a feature that we didn't claim to have.
442         let unrequested_features = v & !self.avail_features;
443         if unrequested_features != 0 {
444             warn!("virtio_fs got unknown feature ack: {:x}", v);
445 
446             // Don't count these features as acked.
447             v &= !unrequested_features;
448         }
449         self.acked_features |= v;
450     }
451 
read_config(&self, offset: u64, data: &mut [u8])452     fn read_config(&self, offset: u64, data: &mut [u8]) {
453         copy_config(data, 0, self.cfg.as_bytes(), offset)
454     }
455 
activate( &mut self, _guest_mem: GuestMemory, _interrupt: Interrupt, queues: BTreeMap<usize, Queue>, ) -> anyhow::Result<()>456     fn activate(
457         &mut self,
458         _guest_mem: GuestMemory,
459         _interrupt: Interrupt,
460         queues: BTreeMap<usize, Queue>,
461     ) -> anyhow::Result<()> {
462         if queues.len() != self.queue_sizes.len() {
463             return Err(anyhow!(
464                 "snd: expected {} queues, got {}",
465                 self.queue_sizes.len(),
466                 queues.len()
467             ));
468         }
469 
470         let snd_data = self.snd_data.clone();
471         let stream_info_builders = self.stream_info_builders.to_vec();
472         let streams_state = self.streams_state.take();
473         let card_index = self.card_index;
474         let control_tube = self.control_tube.take().unwrap();
475         self.worker_thread = Some(WorkerThread::start("v_snd_common", move |kill_evt| {
476             let _thread_priority_handle = set_audio_thread_priority();
477             if let Err(e) = _thread_priority_handle {
478                 warn!("Failed to set audio thread to real time: {}", e);
479             };
480             run_worker(
481                 queues,
482                 snd_data,
483                 kill_evt,
484                 stream_info_builders,
485                 streams_state,
486                 card_index,
487                 control_tube,
488             )
489         }));
490 
491         Ok(())
492     }
493 
reset(&mut self) -> anyhow::Result<()>494     fn reset(&mut self) -> anyhow::Result<()> {
495         if let Some(worker_thread) = self.worker_thread.take() {
496             let worker = worker_thread.stop().unwrap();
497             self.control_tube = Some(worker.control_tube);
498         }
499 
500         Ok(())
501     }
502 
virtio_sleep(&mut self) -> anyhow::Result<Option<BTreeMap<usize, Queue>>>503     fn virtio_sleep(&mut self) -> anyhow::Result<Option<BTreeMap<usize, Queue>>> {
504         if let Some(worker_thread) = self.worker_thread.take() {
505             let worker = worker_thread.stop().unwrap();
506             self.control_tube = Some(worker.control_tube);
507             self.snd_data = worker.snd_data;
508             self.streams_state = Some(worker.streams_state);
509             return Ok(Some(BTreeMap::from_iter(
510                 worker.queues.into_iter().enumerate(),
511             )));
512         }
513         Ok(None)
514     }
515 
virtio_wake( &mut self, device_state: Option<(GuestMemory, Interrupt, BTreeMap<usize, Queue>)>, ) -> anyhow::Result<()>516     fn virtio_wake(
517         &mut self,
518         device_state: Option<(GuestMemory, Interrupt, BTreeMap<usize, Queue>)>,
519     ) -> anyhow::Result<()> {
520         match device_state {
521             None => Ok(()),
522             Some((mem, interrupt, queues)) => {
523                 // TODO: activate is just what we want at the moment, but we should probably move
524                 // it into a "start workers" function to make it obvious that it isn't strictly
525                 // used for activate events.
526                 self.activate(mem, interrupt, queues)?;
527                 Ok(())
528             }
529         }
530     }
531 
virtio_snapshot(&mut self) -> anyhow::Result<AnySnapshot>532     fn virtio_snapshot(&mut self) -> anyhow::Result<AnySnapshot> {
533         let streams_state = if let Some(states) = &self.streams_state {
534             let mut state_vec = Vec::new();
535             for state in states {
536                 state_vec.push(state.clone());
537             }
538             Some(state_vec)
539         } else {
540             None
541         };
542         AnySnapshot::to_any(VirtioSndSnapshot {
543             avail_features: self.avail_features,
544             acked_features: self.acked_features,
545             queue_sizes: self.queue_sizes.to_vec(),
546             streams_state,
547             snd_data: self.snd_data.clone(),
548         })
549         .context("failed to Serialize Sound device")
550     }
551 
virtio_restore(&mut self, data: AnySnapshot) -> anyhow::Result<()>552     fn virtio_restore(&mut self, data: AnySnapshot) -> anyhow::Result<()> {
553         let mut deser: VirtioSndSnapshot =
554             AnySnapshot::from_any(data).context("failed to Deserialize Sound device")?;
555         anyhow::ensure!(
556             deser.avail_features == self.avail_features,
557             "avail features doesn't match on restore: expected: {}, got: {}",
558             deser.avail_features,
559             self.avail_features
560         );
561         anyhow::ensure!(
562             deser.queue_sizes == self.queue_sizes.to_vec(),
563             "queue sizes doesn't match on restore: expected: {:?}, got: {:?}",
564             deser.queue_sizes,
565             self.queue_sizes.to_vec()
566         );
567         self.acked_features = deser.acked_features;
568         anyhow::ensure!(
569             deser.snd_data == self.snd_data,
570             "snd data doesn't match on restore: expected: {:?}, got: {:?}",
571             deser.snd_data,
572             self.snd_data
573         );
574         self.acked_features = deser.acked_features;
575         self.streams_state = deser.streams_state.take();
576         Ok(())
577     }
578 }
579 
/// Tells the main worker loop whether to run again.
#[derive(PartialEq)]
enum LoopState {
    /// Reset the streams and run the workers again.
    Continue,
    /// Leave the main loop.
    Break,
}
585 
run_worker( queues: BTreeMap<usize, Queue>, snd_data: SndData, kill_evt: Event, stream_info_builders: Vec<StreamInfoBuilder>, streams_state: Option<Vec<StreamInfoSnapshot>>, card_index: usize, control_tube: Tube, ) -> Result<WorkerReturn, String>586 fn run_worker(
587     queues: BTreeMap<usize, Queue>,
588     snd_data: SndData,
589     kill_evt: Event,
590     stream_info_builders: Vec<StreamInfoBuilder>,
591     streams_state: Option<Vec<StreamInfoSnapshot>>,
592     card_index: usize,
593     control_tube: Tube,
594 ) -> Result<WorkerReturn, String> {
595     let ex = Executor::new().expect("Failed to create an executor");
596     let control_tube = AsyncTube::new(&ex, control_tube).expect("failed to create async snd tube");
597 
598     if snd_data.pcm_info_len() != stream_info_builders.len() {
599         error!(
600             "snd: expected {} streams, got {}",
601             snd_data.pcm_info_len(),
602             stream_info_builders.len(),
603         );
604     }
605     let streams: Vec<AsyncRwLock<StreamInfo>> = stream_info_builders
606         .into_iter()
607         .map(StreamInfoBuilder::build)
608         .map(AsyncRwLock::new)
609         .collect();
610 
611     let (tx_send, mut tx_recv) = mpsc::unbounded();
612     let (rx_send, mut rx_recv) = mpsc::unbounded();
613     let tx_send_clone = tx_send.clone();
614     let rx_send_clone = rx_send.clone();
615     let restore_task = ex.spawn_local(async move {
616         if let Some(states) = &streams_state {
617             let ex = Executor::new().expect("Failed to create an executor");
618             for (stream, state) in streams.iter().zip(states.iter()) {
619                 stream.lock().await.restore(state);
620                 if state.state == VIRTIO_SND_R_PCM_START || state.state == VIRTIO_SND_R_PCM_PREPARE
621                 {
622                     stream
623                         .lock()
624                         .await
625                         .prepare(&ex, &tx_send_clone, &rx_send_clone)
626                         .await
627                         .expect("failed to prepare PCM");
628                 }
629                 if state.state == VIRTIO_SND_R_PCM_START {
630                     stream
631                         .lock()
632                         .await
633                         .start()
634                         .await
635                         .expect("failed to start PCM");
636                 }
637             }
638         }
639         streams
640     });
641     let streams = ex
642         .run_until(restore_task)
643         .expect("failed to restore streams");
644     let streams = Rc::new(AsyncRwLock::new(streams));
645 
646     let mut queues: Vec<(Queue, EventAsync)> = queues
647         .into_values()
648         .map(|q| {
649             let e = q.event().try_clone().expect("Failed to clone queue event");
650             (
651                 q,
652                 EventAsync::new(e, &ex).expect("Failed to create async event for queue"),
653             )
654         })
655         .collect();
656 
657     let (ctrl_queue, mut ctrl_queue_evt) = queues.remove(0);
658     let ctrl_queue = Rc::new(AsyncRwLock::new(ctrl_queue));
659     let (_event_queue, _event_queue_evt) = queues.remove(0);
660     let (tx_queue, tx_queue_evt) = queues.remove(0);
661     let (rx_queue, rx_queue_evt) = queues.remove(0);
662 
663     let tx_queue = Rc::new(AsyncRwLock::new(tx_queue));
664     let rx_queue = Rc::new(AsyncRwLock::new(rx_queue));
665 
666     // Exit if the kill event is triggered.
667     let f_kill = async_utils::await_and_exit(&ex, kill_evt).fuse();
668 
669     pin_mut!(f_kill);
670 
671     loop {
672         if run_worker_once(
673             &ex,
674             &streams,
675             &snd_data,
676             &mut f_kill,
677             ctrl_queue.clone(),
678             &mut ctrl_queue_evt,
679             tx_queue.clone(),
680             &tx_queue_evt,
681             tx_send.clone(),
682             &mut tx_recv,
683             rx_queue.clone(),
684             &rx_queue_evt,
685             rx_send.clone(),
686             &mut rx_recv,
687             card_index,
688             &control_tube,
689         ) == LoopState::Break
690         {
691             break;
692         }
693 
694         if let Err(e) = reset_streams(
695             &ex,
696             &streams,
697             &tx_queue,
698             &mut tx_recv,
699             &rx_queue,
700             &mut rx_recv,
701         ) {
702             error!("Error reset streams: {}", e);
703             break;
704         }
705     }
706     let streams_state_task = ex.spawn_local(async move {
707         let mut v = Vec::new();
708         for stream in streams.read_lock().await.iter() {
709             v.push(stream.read_lock().await.snapshot());
710         }
711         v
712     });
713     let streams_state = ex
714         .run_until(streams_state_task)
715         .expect("failed to save streams state");
716     let ctrl_queue = match Rc::try_unwrap(ctrl_queue) {
717         Ok(q) => q.into_inner(),
718         Err(_) => panic!("Too many refs to ctrl_queue"),
719     };
720     let tx_queue = match Rc::try_unwrap(tx_queue) {
721         Ok(q) => q.into_inner(),
722         Err(_) => panic!("Too many refs to tx_queue"),
723     };
724     let rx_queue = match Rc::try_unwrap(rx_queue) {
725         Ok(q) => q.into_inner(),
726         Err(_) => panic!("Too many refs to rx_queue"),
727     };
728     let queues = vec![ctrl_queue, _event_queue, tx_queue, rx_queue];
729 
730     Ok(WorkerReturn {
731         control_tube: control_tube.into(),
732         queues,
733         snd_data,
734         streams_state,
735     })
736 }
737 
738 struct WorkerReturn {
739     control_tube: Tube,
740     queues: Vec<Queue>,
741     snd_data: SndData,
742     streams_state: Vec<StreamInfoSnapshot>,
743 }
744 
notify_reset_signal(reset_signal: &(AsyncRwLock<bool>, Condvar))745 async fn notify_reset_signal(reset_signal: &(AsyncRwLock<bool>, Condvar)) {
746     let (lock, cvar) = reset_signal;
747     *lock.lock().await = true;
748     cvar.notify_all();
749 }
750 
751 /// Runs all workers once and exit if any worker exit.
752 ///
753 /// Returns [`LoopState::Break`] if the worker `f_kill` exits, or something went
754 /// wrong on shutdown process. The caller should not run the worker again and should exit the main
755 /// loop.
756 ///
757 /// If this function returns [`LoopState::Continue`], the caller can continue the main loop by
758 /// resetting the streams and run the worker again.
run_worker_once( ex: &Executor, streams: &Rc<AsyncRwLock<Vec<AsyncRwLock<StreamInfo>>>>, snd_data: &SndData, mut f_kill: &mut (impl FusedFuture<Output = anyhow::Result<()>> + Unpin), ctrl_queue: Rc<AsyncRwLock<Queue>>, ctrl_queue_evt: &mut EventAsync, tx_queue: Rc<AsyncRwLock<Queue>>, tx_queue_evt: &EventAsync, tx_send: mpsc::UnboundedSender<PcmResponse>, tx_recv: &mut mpsc::UnboundedReceiver<PcmResponse>, rx_queue: Rc<AsyncRwLock<Queue>>, rx_queue_evt: &EventAsync, rx_send: mpsc::UnboundedSender<PcmResponse>, rx_recv: &mut mpsc::UnboundedReceiver<PcmResponse>, card_index: usize, control_tube: &AsyncTube, ) -> LoopState759 fn run_worker_once(
760     ex: &Executor,
761     streams: &Rc<AsyncRwLock<Vec<AsyncRwLock<StreamInfo>>>>,
762     snd_data: &SndData,
763     mut f_kill: &mut (impl FusedFuture<Output = anyhow::Result<()>> + Unpin),
764     ctrl_queue: Rc<AsyncRwLock<Queue>>,
765     ctrl_queue_evt: &mut EventAsync,
766     tx_queue: Rc<AsyncRwLock<Queue>>,
767     tx_queue_evt: &EventAsync,
768     tx_send: mpsc::UnboundedSender<PcmResponse>,
769     tx_recv: &mut mpsc::UnboundedReceiver<PcmResponse>,
770     rx_queue: Rc<AsyncRwLock<Queue>>,
771     rx_queue_evt: &EventAsync,
772     rx_send: mpsc::UnboundedSender<PcmResponse>,
773     rx_recv: &mut mpsc::UnboundedReceiver<PcmResponse>,
774     card_index: usize,
775     control_tube: &AsyncTube,
776 ) -> LoopState {
777     let tx_send2 = tx_send.clone();
778     let rx_send2 = rx_send.clone();
779 
780     let reset_signal = (AsyncRwLock::new(false), Condvar::new());
781 
782     let f_host_ctrl = handle_ctrl_tube(streams, control_tube, Some(&reset_signal)).fuse();
783 
784     let f_ctrl = handle_ctrl_queue(
785         ex,
786         streams,
787         snd_data,
788         ctrl_queue,
789         ctrl_queue_evt,
790         tx_send,
791         rx_send,
792         card_index,
793         Some(&reset_signal),
794     )
795     .fuse();
796 
797     // TODO(woodychow): Enable this when libcras sends jack connect/disconnect evts
798     // let f_event = handle_event_queue(
799     //     snd_state,
800     //     event_queue,
801     //     event_queue_evt,
802     // );
803     let f_tx = handle_pcm_queue(
804         streams,
805         tx_send2,
806         tx_queue.clone(),
807         tx_queue_evt,
808         card_index,
809         Some(&reset_signal),
810     )
811     .fuse();
812     let f_tx_response = send_pcm_response_worker(tx_queue, tx_recv, Some(&reset_signal)).fuse();
813     let f_rx = handle_pcm_queue(
814         streams,
815         rx_send2,
816         rx_queue.clone(),
817         rx_queue_evt,
818         card_index,
819         Some(&reset_signal),
820     )
821     .fuse();
822     let f_rx_response = send_pcm_response_worker(rx_queue, rx_recv, Some(&reset_signal)).fuse();
823 
824     pin_mut!(
825         f_host_ctrl,
826         f_ctrl,
827         f_tx,
828         f_tx_response,
829         f_rx,
830         f_rx_response
831     );
832 
833     let done = async {
834         select! {
835             res = f_host_ctrl => (res.context("error in handling host control command"), LoopState::Continue),
836             res = f_ctrl => (res.context("error in handling ctrl queue"), LoopState::Continue),
837             res = f_tx => (res.context("error in handling tx queue"), LoopState::Continue),
838             res = f_tx_response => (res.context("error in handling tx response"), LoopState::Continue),
839             res = f_rx => (res.context("error in handling rx queue"), LoopState::Continue),
840             res = f_rx_response => (res.context("error in handling rx response"), LoopState::Continue),
841 
842             // For following workers, do not continue the loop
843             res = f_kill => (res.context("error in await_and_exit"), LoopState::Break),
844         }
845     };
846 
847     match ex.run_until(done) {
848         Ok((res, loop_state)) => {
849             if let Err(e) = res {
850                 error!("Error in worker: {:#}", e);
851             }
852             if loop_state == LoopState::Break {
853                 return LoopState::Break;
854             }
855         }
856         Err(e) => {
857             error!("Error happened in executor: {}", e);
858         }
859     }
860 
861     warn!("Shutting down all workers for reset procedure");
862     block_on(notify_reset_signal(&reset_signal));
863 
864     let shutdown = async {
865         loop {
866             let (res, worker_name) = select!(
867                 res = f_ctrl => (res, "f_ctrl"),
868                 res = f_tx => (res, "f_tx"),
869                 res = f_tx_response => (res, "f_tx_response"),
870                 res = f_rx => (res, "f_rx"),
871                 res = f_rx_response => (res, "f_rx_response"),
872                 complete => break,
873             );
874             match res {
875                 Ok(_) => debug!("Worker {} stopped", worker_name),
876                 Err(e) => error!("Worker {} stopped with error {}", worker_name, e),
877             };
878         }
879     };
880 
881     if let Err(e) = ex.run_until(shutdown) {
882         error!("Error happened in executor while shutdown: {}", e);
883         return LoopState::Break;
884     }
885 
886     LoopState::Continue
887 }
888 
reset_streams( ex: &Executor, streams: &Rc<AsyncRwLock<Vec<AsyncRwLock<StreamInfo>>>>, tx_queue: &Rc<AsyncRwLock<Queue>>, tx_recv: &mut mpsc::UnboundedReceiver<PcmResponse>, rx_queue: &Rc<AsyncRwLock<Queue>>, rx_recv: &mut mpsc::UnboundedReceiver<PcmResponse>, ) -> Result<(), AsyncError>889 fn reset_streams(
890     ex: &Executor,
891     streams: &Rc<AsyncRwLock<Vec<AsyncRwLock<StreamInfo>>>>,
892     tx_queue: &Rc<AsyncRwLock<Queue>>,
893     tx_recv: &mut mpsc::UnboundedReceiver<PcmResponse>,
894     rx_queue: &Rc<AsyncRwLock<Queue>>,
895     rx_recv: &mut mpsc::UnboundedReceiver<PcmResponse>,
896 ) -> Result<(), AsyncError> {
897     let reset_signal = (AsyncRwLock::new(false), Condvar::new());
898 
899     let do_reset = async {
900         let streams = streams.read_lock().await;
901         for stream_info in &*streams {
902             let mut stream_info = stream_info.lock().await;
903             if stream_info.state == VIRTIO_SND_R_PCM_START {
904                 if let Err(e) = stream_info.stop().await {
905                     error!("Error on stop while resetting stream: {}", e);
906                 }
907             }
908             if stream_info.state == VIRTIO_SND_R_PCM_STOP
909                 || stream_info.state == VIRTIO_SND_R_PCM_PREPARE
910             {
911                 if let Err(e) = stream_info.release().await {
912                     error!("Error on release while resetting stream: {}", e);
913                 }
914             }
915             stream_info.just_reset = true;
916         }
917 
918         notify_reset_signal(&reset_signal).await;
919     };
920 
921     // Run these in a loop to ensure that they will survive until do_reset is finished
922     let f_tx_response = async {
923         while send_pcm_response_worker(tx_queue.clone(), tx_recv, Some(&reset_signal))
924             .await
925             .is_err()
926         {}
927     };
928 
929     let f_rx_response = async {
930         while send_pcm_response_worker(rx_queue.clone(), rx_recv, Some(&reset_signal))
931             .await
932             .is_err()
933         {}
934     };
935 
936     let reset = async {
937         join!(f_tx_response, f_rx_response, do_reset);
938     };
939 
940     ex.run_until(reset)
941 }
942 
#[cfg(test)]
// Struct-update syntax (`..Default::default()`) is used uniformly in these tests, even where it
// may not currently fill in any extra fields.
#[allow(clippy::needless_update)]
mod tests {
    use audio_streams::StreamEffect;

    use super::*;
    use crate::virtio::snd::parameters::PCMDeviceParameters;

    /// Parameters for 3 output devices with 3 streams each and 2 input devices with 2 streams
    /// each, with a single echo-cancellation `PCMDeviceParameters` entry per direction.
    fn params_3_out_2_in() -> Parameters {
        Parameters {
            num_output_devices: 3,
            num_input_devices: 2,
            num_output_streams: 3,
            num_input_streams: 2,
            output_device_config: vec![PCMDeviceParameters {
                effects: Some(vec![StreamEffect::EchoCancellation]),
                ..PCMDeviceParameters::default()
            }],
            input_device_config: vec![PCMDeviceParameters {
                effects: Some(vec![StreamEffect::EchoCancellation]),
                ..PCMDeviceParameters::default()
            }],
            ..Default::default()
        }
    }

    #[test]
    fn test_virtio_snd_new() {
        let params = params_3_out_2_in();

        let (t0, _t1) = Tube::pair().expect("failed to create tube");
        let res = VirtioSnd::new(123, params, t0).unwrap();

        // Default values
        assert_eq!(res.snd_data.jack_info.len(), 0);
        assert_eq!(res.acked_features, 0);
        assert!(res.worker_thread.is_none());

        assert_eq!(res.avail_features, 123); // avail_features must be equal to the input
        assert_eq!(res.cfg.jacks.to_native(), 0);
        assert_eq!(res.cfg.streams.to_native(), 13); // (Output = 3*3) + (Input = 2*2)
        assert_eq!(res.cfg.chmaps.to_native(), 11); // (Output = 3*3) + (Input = 2*1)

        // snd_data.pcm_info has one entry per stream.
        assert_eq!(res.snd_data.pcm_info.len(), 13);
        // hda_fn_nid is the PCM device number each stream belongs to.
        let expected_hda_fn_nid = [0, 0, 0, 1, 1, 1, 2, 2, 2, 0, 0, 1, 1];
        for (i, pcm_info) in res.snd_data.pcm_info.iter().enumerate() {
            assert_eq!(
                pcm_info.hdr.hda_fn_nid.to_native(),
                expected_hda_fn_nid[i],
                "pcm_info index {} incorrect hda_fn_nid",
                i
            );
        }
        // The first 9 streams (3 output devices x 3 streams) must be OUTPUT and the remaining
        // 4 streams (2 input devices x 2 streams) must be INPUT.
        for (i, pcm_info) in res.snd_data.pcm_info.iter().enumerate() {
            let expected_direction = if i < 9 {
                VIRTIO_SND_D_OUTPUT
            } else {
                VIRTIO_SND_D_INPUT
            };
            assert_eq!(
                pcm_info.direction, expected_direction,
                "pcm_info index {} incorrect direction",
                i
            );
        }

        // Check snd_data.chmap_info
        assert_eq!(res.snd_data.chmap_info.len(), 11);
        // hda_fn_nid is the PCM device number each channel map belongs to.
        let expected_hda_fn_nid = [0, 1, 2, 0, 1, 0, 1, 2, 0, 1, 2];
        for (i, chmap_info) in res.snd_data.chmap_info.iter().enumerate() {
            assert_eq!(
                chmap_info.hdr.hda_fn_nid.to_native(),
                expected_hda_fn_nid[i],
                "chmap_info index {} incorrect hda_fn_nid",
                i
            );
        }
    }

    #[test]
    fn test_resize_parameters_pcm_device_config_truncate() {
        // If pcm_device_config is larger than number of devices, it will be truncated
        let params = Parameters {
            num_output_devices: 1,
            num_input_devices: 1,
            output_device_config: vec![PCMDeviceParameters::default(); 3],
            input_device_config: vec![PCMDeviceParameters::default(); 3],
            ..Parameters::default()
        };
        let params = resize_parameters_pcm_device_config(params);
        assert_eq!(params.output_device_config.len(), 1);
        assert_eq!(params.input_device_config.len(), 1);
    }

    #[test]
    fn test_resize_parameters_pcm_device_config_extend() {
        let params = resize_parameters_pcm_device_config(params_3_out_2_in());

        // Check output_device_config correctly extended
        assert_eq!(
            params.output_device_config,
            vec![
                PCMDeviceParameters {
                    // Keep from the parameters
                    effects: Some(vec![StreamEffect::EchoCancellation]),
                    ..PCMDeviceParameters::default()
                },
                PCMDeviceParameters::default(), // Extended with default
                PCMDeviceParameters::default(), // Extended with default
            ]
        );

        // Check input_device_config correctly extended
        assert_eq!(
            params.input_device_config,
            vec![
                PCMDeviceParameters {
                    // Keep from the parameters
                    effects: Some(vec![StreamEffect::EchoCancellation]),
                    ..PCMDeviceParameters::default()
                },
                PCMDeviceParameters::default(), // Extended with default
            ]
        );
    }
}
1088