1 // Copyright 2021 The Chromium OS Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 // virtio-sound spec: https://github.com/oasis-tcs/virtio-spec/blob/master/virtio-sound.tex
6 
7 use std::cell::RefCell;
8 use std::fmt;
9 use std::io;
10 use std::num::ParseIntError;
11 use std::rc::Rc;
12 use std::str::{FromStr, ParseBoolError};
13 use std::thread;
14 
15 use anyhow::Context;
16 use audio_streams::{SampleFormat, StreamSource};
17 use base::{
18     error, set_rt_prio_limit, set_rt_round_robin, warn, Error as SysError, Event, RawDescriptor,
19 };
20 use cros_async::sync::Mutex as AsyncMutex;
21 use cros_async::{AsyncError, EventAsync, Executor};
22 use data_model::DataInit;
23 use futures::channel::{
24     mpsc,
25     oneshot::{self, Canceled},
26 };
27 use futures::{pin_mut, select, Future, FutureExt, TryFutureExt};
28 use libcras::{BoxError, CrasClient, CrasClientType, CrasSocketType};
29 use thiserror::Error as ThisError;
30 use vm_memory::GuestMemory;
31 
32 use crate::virtio::snd::common::*;
33 use crate::virtio::snd::constants::*;
34 use crate::virtio::snd::layout::*;
35 use crate::virtio::{
36     async_utils, copy_config, DescriptorChain, DescriptorError, Interrupt, Queue, VirtioDevice,
37     Writer, TYPE_SOUND,
38 };
39 
40 pub mod async_funcs;
41 use crate::virtio::snd::cras_backend::async_funcs::*;
42 
/// Number of virtqueues exposed by the device: control, event, tx and rx.
pub const MAX_QUEUE_NUM: usize = 4;
/// Maximum size advertised for each virtqueue (see `queue_max_sizes`).
pub const MAX_VRING_LEN: u16 = 1024;
/// Real-time priority limit and round-robin priority requested for the audio worker thread.
/// Matches other cros audio clients.
const AUDIO_THREAD_RTPRIO: u16 = 10;
47 
48 #[derive(ThisError, Debug)]
49 pub enum Error {
50     /// next_async failed.
51     #[error("Failed to read descriptor asynchronously: {0}")]
52     Async(AsyncError),
53     /// Creating stream failed.
54     #[error("Failed to create stream: {0}")]
55     CreateStream(BoxError),
56     /// Creating kill event failed.
57     #[error("Failed to create kill event: {0}")]
58     CreateKillEvent(SysError),
59     /// Creating WaitContext failed.
60     #[error("Failed to create wait context: {0}")]
61     CreateWaitContext(SysError),
62     /// Cloning kill event failed.
63     #[error("Failed to clone kill event: {0}")]
64     CloneKillEvent(SysError),
65     /// Descriptor chain was invalid.
66     #[error("Failed to valildate descriptor chain: {0}")]
67     DescriptorChain(DescriptorError),
68     // Future error.
69     #[error("Unexpected error. Done was not triggered before dropped: {0}")]
70     DoneNotTriggered(Canceled),
71     /// Error reading message from queue.
72     #[error("Failed to read message: {0}")]
73     ReadMessage(io::Error),
74     /// Failed writing a response to a control message.
75     #[error("Failed to write message response: {0}")]
76     WriteResponse(io::Error),
77     /// Libcras error.
78     #[error("Error in libcras: {0}")]
79     Libcras(libcras::Error),
80     // Mpsc read error.
81     #[error("Error in mpsc: {0}")]
82     MpscSend(futures::channel::mpsc::SendError),
83     // Oneshot send error.
84     #[error("Error in oneshot send")]
85     OneshotSend(()),
86     /// Stream not found.
87     #[error("stream id ({0}) < num_streams ({1})")]
88     StreamNotFound(usize, usize),
89     /// Fetch buffer error
90     #[error("Failed to get buffer from CRAS: {0}")]
91     FetchBuffer(BoxError),
92     /// Invalid buffer size
93     #[error("Invalid buffer size")]
94     InvalidBufferSize,
95     /// IoError
96     #[error("I/O failed: {0}")]
97     Io(io::Error),
98     /// Operation not supported.
99     #[error("Operation not supported")]
100     OperationNotSupported,
101     /// Writing to a buffer in the guest failed.
102     #[error("failed to write to buffer: {0}")]
103     WriteBuffer(io::Error),
104     /// Failed to parse parameters.
105     #[error("Invalid cras snd parameter: {0}")]
106     UnknownParameter(String),
107     /// Unknown cras snd parameter value.
108     #[error("Invalid cras snd parameter value ({0}): {1}")]
109     InvalidParameterValue(String, String),
110     /// Failed to parse bool value.
111     #[error("Invalid bool value: {0}")]
112     InvalidBoolValue(ParseBoolError),
113     /// Failed to parse int value.
114     #[error("Invalid int value: {0}")]
115     InvalidIntValue(ParseIntError),
116     // Invalid PCM worker state.
117     #[error("Invalid PCM worker state")]
118     InvalidPCMWorkerState,
119 }
120 
121 /// Holds the parameters for a cras sound device
122 #[derive(Debug, Clone)]
123 pub struct Parameters {
124     pub capture: bool,
125     pub client_type: CrasClientType,
126     pub socket_type: CrasSocketType,
127     pub num_output_streams: u32,
128     pub num_input_streams: u32,
129 }
130 
131 impl Default for Parameters {
default() -> Self132     fn default() -> Self {
133         Parameters {
134             capture: false,
135             client_type: CrasClientType::CRAS_CLIENT_TYPE_CROSVM,
136             socket_type: CrasSocketType::Unified,
137             num_output_streams: 1,
138             num_input_streams: 1,
139         }
140     }
141 }
142 
143 impl FromStr for Parameters {
144     type Err = Error;
from_str(s: &str) -> std::result::Result<Self, Self::Err>145     fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
146         let mut params: Parameters = Default::default();
147         let opts = s
148             .split(',')
149             .map(|frag| frag.split('='))
150             .map(|mut kv| (kv.next().unwrap_or(""), kv.next().unwrap_or("")));
151 
152         for (k, v) in opts {
153             match k {
154                 "capture" => {
155                     params.capture = v.parse::<bool>().map_err(Error::InvalidBoolValue)?;
156                 }
157                 "client_type" => {
158                     params.client_type = v.parse().map_err(|e: libcras::CrasSysError| {
159                         Error::InvalidParameterValue(v.to_string(), e.to_string())
160                     })?;
161                 }
162                 "socket_type" => {
163                     params.socket_type = v.parse().map_err(|e: libcras::Error| {
164                         Error::InvalidParameterValue(v.to_string(), e.to_string())
165                     })?;
166                 }
167                 "num_output_streams" => {
168                     params.num_output_streams = v.parse::<u32>().map_err(Error::InvalidIntValue)?;
169                 }
170                 "num_input_streams" => {
171                     params.num_input_streams = v.parse::<u32>().map_err(Error::InvalidIntValue)?;
172                 }
173                 _ => {
174                     return Err(Error::UnknownParameter(k.to_string()));
175                 }
176             }
177         }
178 
179         Ok(params)
180     }
181 }
182 
183 pub enum DirectionalStream {
184     Input(Box<dyn audio_streams::capture::AsyncCaptureBufferStream>),
185     Output(Box<dyn audio_streams::AsyncPlaybackBufferStream>),
186 }
187 
/// Status shared (through an async mutex) between a `StreamInfo` and its PCM worker.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum WorkerStatus {
    /// Set when a worker is created by PREPARE and again on STOP.
    Pause = 0,
    /// Set on START.
    Running = 1,
    /// Set when the worker is released, asking it to exit.
    Quit = 2,
}
194 
195 pub struct StreamInfo<'a> {
196     client: Option<CrasClient<'a>>,
197     channels: u8,
198     format: SampleFormat,
199     frame_rate: u32,
200     buffer_bytes: usize,
201     period_bytes: usize,
202     direction: u8, // VIRTIO_SND_D_*
203     state: u32,    // VIRTIO_SND_R_PCM_SET_PARAMS -> VIRTIO_SND_R_PCM_STOP, or 0 (uninitialized)
204 
205     // Worker related
206     status_mutex: Rc<AsyncMutex<WorkerStatus>>,
207     sender: Option<mpsc::UnboundedSender<DescriptorChain>>,
208     worker_future: Option<Box<dyn Future<Output = Result<(), Error>> + Unpin>>,
209 }
210 
211 impl fmt::Debug for StreamInfo<'_> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result212     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
213         f.debug_struct("StreamInfo")
214             .field("channels", &self.channels)
215             .field("format", &self.format)
216             .field("frame_rate", &self.frame_rate)
217             .field("buffer_bytes", &self.buffer_bytes)
218             .field("period_bytes", &self.period_bytes)
219             .field("direction", &get_virtio_direction_name(self.direction))
220             .field("state", &get_virtio_snd_r_pcm_cmd_name(self.state))
221             .finish()
222     }
223 }
224 
225 impl Default for StreamInfo<'_> {
default() -> Self226     fn default() -> Self {
227         StreamInfo {
228             client: None,
229             channels: 0,
230             format: SampleFormat::U8,
231             frame_rate: 0,
232             buffer_bytes: 0,
233             period_bytes: 0,
234             direction: 0,
235             state: 0,
236             status_mutex: Rc::new(AsyncMutex::new(WorkerStatus::Pause)),
237             sender: None,
238             worker_future: None,
239         }
240     }
241 }
242 
243 // Stores constant data
244 pub struct SndData {
245     jack_info: Vec<virtio_snd_jack_info>,
246     pcm_info: Vec<virtio_snd_pcm_info>,
247     chmap_info: Vec<virtio_snd_chmap_info>,
248 }
249 
250 impl SndData {
pcm_info_len(&self) -> usize251     pub fn pcm_info_len(&self) -> usize {
252         self.pcm_info.len()
253     }
254 }
255 
256 const SUPPORTED_FORMATS: u64 = 1 << VIRTIO_SND_PCM_FMT_U8
257     | 1 << VIRTIO_SND_PCM_FMT_S16
258     | 1 << VIRTIO_SND_PCM_FMT_S24
259     | 1 << VIRTIO_SND_PCM_FMT_S32;
260 const SUPPORTED_FRAME_RATES: u64 = 1 << VIRTIO_SND_PCM_RATE_8000
261     | 1 << VIRTIO_SND_PCM_RATE_11025
262     | 1 << VIRTIO_SND_PCM_RATE_16000
263     | 1 << VIRTIO_SND_PCM_RATE_22050
264     | 1 << VIRTIO_SND_PCM_RATE_32000
265     | 1 << VIRTIO_SND_PCM_RATE_44100
266     | 1 << VIRTIO_SND_PCM_RATE_48000;
267 
268 // Response from pcm_worker to pcm_queue
269 pub struct PcmResponse {
270     desc_index: u16,
271     status: virtio_snd_pcm_status, // response to the pcm message
272     writer: Writer,
273     done: Option<oneshot::Sender<()>>, // when pcm response is written to the queue
274 }
275 
276 impl<'a> StreamInfo<'a> {
prepare( &mut self, ex: &Executor, mem: GuestMemory, tx_send: &mpsc::UnboundedSender<PcmResponse>, rx_send: &mpsc::UnboundedSender<PcmResponse>, params: &Parameters, ) -> Result<(), Error>277     async fn prepare(
278         &mut self,
279         ex: &Executor,
280         mem: GuestMemory,
281         tx_send: &mpsc::UnboundedSender<PcmResponse>,
282         rx_send: &mpsc::UnboundedSender<PcmResponse>,
283         params: &Parameters,
284     ) -> Result<(), Error> {
285         if self.state != VIRTIO_SND_R_PCM_SET_PARAMS
286             && self.state != VIRTIO_SND_R_PCM_PREPARE
287             && self.state != VIRTIO_SND_R_PCM_RELEASE
288         {
289             error!(
290                 "Invalid PCM state transition from {} to {}",
291                 get_virtio_snd_r_pcm_cmd_name(self.state),
292                 get_virtio_snd_r_pcm_cmd_name(VIRTIO_SND_R_PCM_PREPARE)
293             );
294             return Err(Error::OperationNotSupported);
295         }
296         if self.state == VIRTIO_SND_R_PCM_PREPARE {
297             self.release_worker().await?;
298         }
299         let frame_size = self.channels as usize * self.format.sample_bytes();
300         if self.period_bytes % frame_size != 0 {
301             error!("period_bytes must be divisible by frame size");
302             return Err(Error::OperationNotSupported);
303         }
304         if self.client.is_none() {
305             let mut client = CrasClient::with_type(params.socket_type).map_err(Error::Libcras)?;
306             if params.capture {
307                 client.enable_cras_capture();
308             }
309             client.set_client_type(params.client_type);
310             self.client = Some(client);
311         }
312         // (*)
313         // `buffer_size` in `audio_streams` API indicates the buffer size in bytes that the stream
314         // consumes (or transmits) each time (next_playback/capture_buffer).
315         // `period_bytes` in virtio-snd device (or ALSA) indicates the device transmits (or
316         // consumes) for each PCM message.
317         // Therefore, `buffer_size` in `audio_streams` == `period_bytes` in virtio-snd.
318         let (stream, pcm_sender) = match self.direction {
319             VIRTIO_SND_D_OUTPUT => (
320                 DirectionalStream::Output(
321                     self.client
322                         .as_mut()
323                         .unwrap()
324                         .new_async_playback_stream(
325                             self.channels as usize,
326                             self.format,
327                             self.frame_rate,
328                             // See (*)
329                             self.period_bytes / frame_size,
330                             ex,
331                         )
332                         .map_err(Error::CreateStream)?
333                         .1,
334                 ),
335                 tx_send.clone(),
336             ),
337             VIRTIO_SND_D_INPUT => {
338                 (
339                     DirectionalStream::Input(
340                         self.client
341                             .as_mut()
342                             .unwrap()
343                             .new_async_capture_stream(
344                                 self.channels as usize,
345                                 self.format,
346                                 self.frame_rate,
347                                 // See (*)
348                                 self.period_bytes / frame_size,
349                                 &[],
350                                 ex,
351                             )
352                             .map_err(Error::CreateStream)?
353                             .1,
354                     ),
355                     rx_send.clone(),
356                 )
357             }
358             _ => unreachable!(),
359         };
360 
361         let (sender, receiver) = mpsc::unbounded();
362         self.sender = Some(sender);
363         self.state = VIRTIO_SND_R_PCM_PREPARE;
364 
365         self.status_mutex = Rc::new(AsyncMutex::new(WorkerStatus::Pause));
366         let f = start_pcm_worker(
367             ex.clone(),
368             stream,
369             receiver,
370             self.status_mutex.clone(),
371             mem,
372             pcm_sender,
373             self.period_bytes,
374         );
375         self.worker_future = Some(Box::new(ex.spawn_local(f).into_future()));
376         Ok(())
377     }
378 
start(&mut self) -> Result<(), Error>379     async fn start(&mut self) -> Result<(), Error> {
380         if self.state != VIRTIO_SND_R_PCM_PREPARE && self.state != VIRTIO_SND_R_PCM_STOP {
381             error!(
382                 "Invalid PCM state transition from {} to {}",
383                 get_virtio_snd_r_pcm_cmd_name(self.state),
384                 get_virtio_snd_r_pcm_cmd_name(VIRTIO_SND_R_PCM_START)
385             );
386             return Err(Error::OperationNotSupported);
387         }
388         self.state = VIRTIO_SND_R_PCM_START;
389         *self.status_mutex.lock().await = WorkerStatus::Running;
390         Ok(())
391     }
392 
stop(&mut self) -> Result<(), Error>393     async fn stop(&mut self) -> Result<(), Error> {
394         if self.state != VIRTIO_SND_R_PCM_START {
395             error!(
396                 "Invalid PCM state transition from {} to {}",
397                 get_virtio_snd_r_pcm_cmd_name(self.state),
398                 get_virtio_snd_r_pcm_cmd_name(VIRTIO_SND_R_PCM_STOP)
399             );
400             return Err(Error::OperationNotSupported);
401         }
402         self.state = VIRTIO_SND_R_PCM_STOP;
403         *self.status_mutex.lock().await = WorkerStatus::Pause;
404         Ok(())
405     }
406 
release(&mut self) -> Result<(), Error>407     async fn release(&mut self) -> Result<(), Error> {
408         if self.state != VIRTIO_SND_R_PCM_PREPARE && self.state != VIRTIO_SND_R_PCM_STOP {
409             error!(
410                 "Invalid PCM state transition from {} to {}",
411                 get_virtio_snd_r_pcm_cmd_name(self.state),
412                 get_virtio_snd_r_pcm_cmd_name(VIRTIO_SND_R_PCM_RELEASE)
413             );
414             return Err(Error::OperationNotSupported);
415         }
416         self.state = VIRTIO_SND_R_PCM_RELEASE;
417         self.release_worker().await?;
418         self.client = None;
419         Ok(())
420     }
421 
release_worker(&mut self) -> Result<(), Error>422     async fn release_worker(&mut self) -> Result<(), Error> {
423         *self.status_mutex.lock().await = WorkerStatus::Quit;
424         match self.sender.take() {
425             Some(s) => s.close_channel(),
426             None => (),
427         }
428         match self.worker_future.take() {
429             Some(f) => f.await?,
430             None => (),
431         }
432         Ok(())
433     }
434 }
435 
436 pub struct VirtioSndCras {
437     cfg: virtio_snd_config,
438     avail_features: u64,
439     acked_features: u64,
440     queue_sizes: Box<[u16]>,
441     worker_threads: Vec<thread::JoinHandle<()>>,
442     kill_evt: Option<Event>,
443     params: Parameters,
444 }
445 
446 impl VirtioSndCras {
new(base_features: u64, params: Parameters) -> Result<VirtioSndCras, Error>447     pub fn new(base_features: u64, params: Parameters) -> Result<VirtioSndCras, Error> {
448         let cfg = hardcoded_virtio_snd_config(¶ms);
449 
450         let avail_features = base_features;
451 
452         Ok(VirtioSndCras {
453             cfg,
454             avail_features,
455             acked_features: 0,
456             queue_sizes: vec![MAX_VRING_LEN; MAX_QUEUE_NUM].into_boxed_slice(),
457             worker_threads: Vec::new(),
458             kill_evt: None,
459             params,
460         })
461     }
462 }
463 
464 // To be used with hardcoded_snd_data
hardcoded_virtio_snd_config(params: &Parameters) -> virtio_snd_config465 pub fn hardcoded_virtio_snd_config(params: &Parameters) -> virtio_snd_config {
466     virtio_snd_config {
467         jacks: 0.into(),
468         streams: (params.num_output_streams + params.num_input_streams).into(),
469         chmaps: 4.into(),
470     }
471 }
472 
473 // To be used with hardcoded_virtio_snd_config
hardcoded_snd_data(params: &Parameters) -> SndData474 pub fn hardcoded_snd_data(params: &Parameters) -> SndData {
475     let jack_info: Vec<virtio_snd_jack_info> = Vec::new();
476     let mut pcm_info: Vec<virtio_snd_pcm_info> = Vec::new();
477     let mut chmap_info: Vec<virtio_snd_chmap_info> = Vec::new();
478 
479     for _ in 0..params.num_output_streams {
480         pcm_info.push(virtio_snd_pcm_info {
481             hdr: virtio_snd_info {
482                 hda_fn_nid: 0.into(),
483             },
484             features: 0.into(), /* 1 << VIRTIO_SND_PCM_F_XXX */
485             formats: SUPPORTED_FORMATS.into(),
486             rates: SUPPORTED_FRAME_RATES.into(),
487             direction: VIRTIO_SND_D_OUTPUT,
488             channels_min: 1,
489             channels_max: 6,
490             padding: [0; 5],
491         });
492     }
493     for _ in 0..params.num_input_streams {
494         pcm_info.push(virtio_snd_pcm_info {
495             hdr: virtio_snd_info {
496                 hda_fn_nid: 0.into(),
497             },
498             features: 0.into(), /* 1 << VIRTIO_SND_PCM_F_XXX */
499             formats: SUPPORTED_FORMATS.into(),
500             rates: SUPPORTED_FRAME_RATES.into(),
501             direction: VIRTIO_SND_D_INPUT,
502             channels_min: 1,
503             channels_max: 2,
504             padding: [0; 5],
505         });
506     }
507 
508     // Use stereo channel map.
509     let mut positions = [VIRTIO_SND_CHMAP_NONE; VIRTIO_SND_CHMAP_MAX_SIZE];
510     positions[0] = VIRTIO_SND_CHMAP_FL;
511     positions[1] = VIRTIO_SND_CHMAP_FR;
512 
513     chmap_info.push(virtio_snd_chmap_info {
514         hdr: virtio_snd_info {
515             hda_fn_nid: 0.into(),
516         },
517         direction: VIRTIO_SND_D_OUTPUT,
518         channels: 2,
519         positions,
520     });
521     chmap_info.push(virtio_snd_chmap_info {
522         hdr: virtio_snd_info {
523             hda_fn_nid: 0.into(),
524         },
525         direction: VIRTIO_SND_D_INPUT,
526         channels: 2,
527         positions,
528     });
529     positions[2] = VIRTIO_SND_CHMAP_RL;
530     positions[3] = VIRTIO_SND_CHMAP_RR;
531     chmap_info.push(virtio_snd_chmap_info {
532         hdr: virtio_snd_info {
533             hda_fn_nid: 0.into(),
534         },
535         direction: VIRTIO_SND_D_OUTPUT,
536         channels: 4,
537         positions,
538     });
539     positions[2] = VIRTIO_SND_CHMAP_FC;
540     positions[3] = VIRTIO_SND_CHMAP_LFE;
541     positions[4] = VIRTIO_SND_CHMAP_RL;
542     positions[5] = VIRTIO_SND_CHMAP_RR;
543     chmap_info.push(virtio_snd_chmap_info {
544         hdr: virtio_snd_info {
545             hda_fn_nid: 0.into(),
546         },
547         direction: VIRTIO_SND_D_OUTPUT,
548         channels: 6,
549         positions,
550     });
551 
552     SndData {
553         jack_info,
554         pcm_info,
555         chmap_info,
556     }
557 }
558 
559 impl VirtioDevice for VirtioSndCras {
keep_rds(&self) -> Vec<RawDescriptor>560     fn keep_rds(&self) -> Vec<RawDescriptor> {
561         Vec::new()
562     }
563 
device_type(&self) -> u32564     fn device_type(&self) -> u32 {
565         TYPE_SOUND
566     }
567 
queue_max_sizes(&self) -> &[u16]568     fn queue_max_sizes(&self) -> &[u16] {
569         &self.queue_sizes
570     }
571 
features(&self) -> u64572     fn features(&self) -> u64 {
573         self.avail_features
574     }
575 
ack_features(&mut self, mut v: u64)576     fn ack_features(&mut self, mut v: u64) {
577         // Check if the guest is ACK'ing a feature that we didn't claim to have.
578         let unrequested_features = v & !self.avail_features;
579         if unrequested_features != 0 {
580             warn!("virtio_fs got unknown feature ack: {:x}", v);
581 
582             // Don't count these features as acked.
583             v &= !unrequested_features;
584         }
585         self.acked_features |= v;
586     }
587 
read_config(&self, offset: u64, data: &mut [u8])588     fn read_config(&self, offset: u64, data: &mut [u8]) {
589         copy_config(data, 0, self.cfg.as_slice(), offset)
590     }
591 
activate( &mut self, guest_mem: GuestMemory, interrupt: Interrupt, queues: Vec<Queue>, queue_evts: Vec<Event>, )592     fn activate(
593         &mut self,
594         guest_mem: GuestMemory,
595         interrupt: Interrupt,
596         queues: Vec<Queue>,
597         queue_evts: Vec<Event>,
598     ) {
599         if queues.len() != self.queue_sizes.len() || queue_evts.len() != self.queue_sizes.len() {
600             error!(
601                 "snd: expected {} queues, got {}",
602                 self.queue_sizes.len(),
603                 queues.len()
604             );
605         }
606 
607         let (self_kill_evt, kill_evt) =
608             match Event::new().and_then(|evt| Ok((evt.try_clone()?, evt))) {
609                 Ok(v) => v,
610                 Err(e) => {
611                     error!("failed to create kill Event pair: {}", e);
612                     return;
613                 }
614             };
615         self.kill_evt = Some(self_kill_evt);
616 
617         let params = self.params.clone();
618 
619         let worker_result = thread::Builder::new()
620             .name("virtio_snd w".to_string())
621             .spawn(move || {
622                 if let Err(e) = set_rt_prio_limit(u64::from(AUDIO_THREAD_RTPRIO))
623                     .and_then(|_| set_rt_round_robin(i32::from(AUDIO_THREAD_RTPRIO)))
624                 {
625                     warn!("Failed to set audio thread to real time: {}", e);
626                 }
627 
628                 if let Err(err_string) = run_worker(
629                     interrupt,
630                     queues,
631                     guest_mem,
632                     hardcoded_snd_data(¶ms),
633                     queue_evts,
634                     kill_evt,
635                     params,
636                 ) {
637                     error!("{}", err_string);
638                 }
639             });
640 
641         match worker_result {
642             Err(e) => {
643                 error!("failed to spawn virtio_snd worker: {}", e);
644                 return;
645             }
646             Ok(join_handle) => self.worker_threads.push(join_handle),
647         }
648     }
649 
reset(&mut self) -> bool650     fn reset(&mut self) -> bool {
651         if let Some(kill_evt) = self.kill_evt.take() {
652             // Ignore the result because there is nothing we can do about it.
653             let _ = kill_evt.write(1);
654         }
655 
656         true
657     }
658 }
659 
660 impl Drop for VirtioSndCras {
drop(&mut self)661     fn drop(&mut self) {
662         self.reset();
663     }
664 }
665 
run_worker( interrupt: Interrupt, mut queues: Vec<Queue>, mem: GuestMemory, snd_data: SndData, queue_evts: Vec<Event>, kill_evt: Event, params: Parameters, ) -> Result<(), String>666 fn run_worker(
667     interrupt: Interrupt,
668     mut queues: Vec<Queue>,
669     mem: GuestMemory,
670     snd_data: SndData,
671     queue_evts: Vec<Event>,
672     kill_evt: Event,
673     params: Parameters,
674 ) -> Result<(), String> {
675     let ex = Executor::new().expect("Failed to create an executor");
676 
677     let mut streams: Vec<AsyncMutex<StreamInfo>> = Vec::new();
678     streams.resize_with(snd_data.pcm_info.len(), Default::default);
679     let streams = Rc::new(AsyncMutex::new(streams));
680 
681     let interrupt = Rc::new(RefCell::new(interrupt));
682     let interrupt_ref = &*interrupt.borrow();
683 
684     let ctrl_queue = queues.remove(0);
685     let _event_queue = queues.remove(0);
686     let tx_queue = Rc::new(AsyncMutex::new(queues.remove(0)));
687     let rx_queue = Rc::new(AsyncMutex::new(queues.remove(0)));
688 
689     let mut evts_async: Vec<EventAsync> = queue_evts
690         .into_iter()
691         .map(|e| EventAsync::new(e.0, &ex).expect("Failed to create async event for queue"))
692         .collect();
693 
694     let ctrl_queue_evt = evts_async.remove(0);
695     let _event_queue_evt = evts_async.remove(0);
696     let tx_queue_evt = evts_async.remove(0);
697     let rx_queue_evt = evts_async.remove(0);
698 
699     let (tx_send, mut tx_recv) = mpsc::unbounded();
700     let (rx_send, mut rx_recv) = mpsc::unbounded();
701     let tx_send2 = tx_send.clone();
702     let rx_send2 = rx_send.clone();
703 
704     let f_ctrl = handle_ctrl_queue(
705         &ex,
706         &mem,
707         &streams,
708         &snd_data,
709         ctrl_queue,
710         ctrl_queue_evt,
711         interrupt_ref,
712         tx_send,
713         rx_send,
714         ¶ms,
715     );
716 
717     // TODO(woodychow): Enable this when libcras sends jack connect/disconnect evts
718     // let f_event = handle_event_queue(
719     //     &mem,
720     //     snd_state,
721     //     event_queue,
722     //     event_queue_evt,
723     //     interrupt,
724     // );
725 
726     let f_tx = handle_pcm_queue(&mem, &streams, tx_send2, &tx_queue, tx_queue_evt);
727 
728     let f_tx_response = send_pcm_response_worker(&mem, &tx_queue, interrupt_ref, &mut tx_recv);
729 
730     let f_rx = handle_pcm_queue(&mem, &streams, rx_send2, &rx_queue, rx_queue_evt);
731 
732     let f_rx_response = send_pcm_response_worker(&mem, &rx_queue, interrupt_ref, &mut rx_recv);
733 
734     let f_resample = async_utils::handle_irq_resample(&ex, interrupt.clone());
735 
736     // Exit if the kill event is triggered.
737     let f_kill = async_utils::await_and_exit(&ex, kill_evt);
738 
739     pin_mut!(
740         f_ctrl,
741         f_tx,
742         f_tx_response,
743         f_rx,
744         f_rx_response,
745         f_resample,
746         f_kill
747     );
748 
749     let done = async {
750         select! {
751             res = f_ctrl.fuse() => res.context("error in handling ctrl queue"),
752             res = f_tx.fuse() => res.context("error in handling tx queue"),
753             res = f_tx_response.fuse() => res.context("error in handling tx response"),
754             res = f_rx.fuse() => res.context("error in handling rx queue"),
755             res = f_rx_response.fuse() => res.context("error in handling rx response"),
756             res = f_resample.fuse() => res.context("error in handle_irq_resample"),
757             res = f_kill.fuse() => res.context("error in await_and_exit"),
758         }
759     };
760     match ex.run_until(done) {
761         Ok(Ok(())) => {}
762         Ok(Err(e)) => error!("Error in worker: {}", e),
763         Err(e) => error!("Error happened in executor: {}", e),
764     }
765 
766     Ok(())
767 }
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises `Parameters::from_str` on valid option strings and on malformed keys/values.
    #[test]
    fn parameters_fromstr() {
        fn check_success(
            s: &str,
            capture: bool,
            client_type: CrasClientType,
            socket_type: CrasSocketType,
            num_output_streams: u32,
            num_input_streams: u32,
        ) {
            let params = s.parse::<Parameters>().expect("parse should have succeeded");
            assert_eq!(params.capture, capture);
            assert_eq!(params.client_type, client_type);
            assert_eq!(params.socket_type, socket_type);
            assert_eq!(params.num_output_streams, num_output_streams);
            assert_eq!(params.num_input_streams, num_input_streams);
        }
        fn check_failure(s: &str) {
            s.parse::<Parameters>()
                .expect_err("parse should have failed");
        }

        check_success(
            "capture=false",
            false,
            CrasClientType::CRAS_CLIENT_TYPE_CROSVM,
            CrasSocketType::Unified,
            1,
            1,
        );
        check_success(
            "capture=true,client_type=crosvm",
            true,
            CrasClientType::CRAS_CLIENT_TYPE_CROSVM,
            CrasSocketType::Unified,
            1,
            1,
        );
        check_success(
            "capture=true,client_type=arcvm",
            true,
            CrasClientType::CRAS_CLIENT_TYPE_ARCVM,
            CrasSocketType::Unified,
            1,
            1,
        );
        check_failure("capture=true,client_type=none");
        check_success(
            "socket_type=legacy",
            false,
            CrasClientType::CRAS_CLIENT_TYPE_CROSVM,
            CrasSocketType::Legacy,
            1,
            1,
        );
        check_success(
            "socket_type=unified",
            false,
            CrasClientType::CRAS_CLIENT_TYPE_CROSVM,
            CrasSocketType::Unified,
            1,
            1,
        );
        check_success(
            "capture=true,client_type=arcvm,num_output_streams=2,num_input_streams=3",
            true,
            CrasClientType::CRAS_CLIENT_TYPE_ARCVM,
            CrasSocketType::Unified,
            2,
            3,
        );

        // Malformed values and unknown keys must be rejected.
        check_failure("capture=maybe");
        check_failure("num_output_streams=-1");
        check_failure("num_input_streams=two");
        check_failure("socket_type=nonexistent");
        check_failure("unknown_key=1");
    }
}
844