• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2020 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 use std::collections::HashMap;
6 use std::collections::VecDeque;
7 use std::fs::File;
8 use std::io::Error as IOError;
9 use std::io::ErrorKind as IOErrorKind;
10 use std::io::Seek;
11 use std::io::SeekFrom;
12 use std::path::Path;
13 use std::path::PathBuf;
14 use std::sync::mpsc::channel;
15 use std::sync::mpsc::Receiver;
16 use std::sync::mpsc::RecvError;
17 use std::sync::mpsc::Sender;
18 use std::sync::Arc;
19 
20 use base::error;
21 use base::AsRawDescriptor;
22 use base::Error as BaseError;
23 use base::Event;
24 use base::EventToken;
25 use base::FromRawDescriptor;
26 use base::IntoRawDescriptor;
27 use base::MemoryMapping;
28 use base::MemoryMappingBuilder;
29 use base::MmapError;
30 use base::RawDescriptor;
31 use base::SafeDescriptor;
32 use base::ScmSocket;
33 use base::UnixSeqpacket;
34 use base::VolatileMemory;
35 use base::VolatileMemoryError;
36 use base::VolatileSlice;
37 use base::WaitContext;
38 use base::WorkerThread;
39 use remain::sorted;
40 use serde::Deserialize;
41 use serde::Serialize;
42 use sync::Mutex;
43 use thiserror::Error as ThisError;
44 use zerocopy::FromBytes;
45 use zerocopy::Immutable;
46 use zerocopy::IntoBytes;
47 use zerocopy::KnownLayout;
48 
49 use crate::virtio::snd::constants::*;
50 use crate::virtio::snd::layout::*;
51 use crate::virtio::snd::vios_backend::streams::StreamState;
52 
53 pub type Result<T> = std::result::Result<T, Error>;
54 
55 #[sorted]
56 #[derive(ThisError, Debug)]
57 pub enum Error {
58     #[error("Error memory mapping client_shm: {0}")]
59     BaseMmapError(BaseError),
60     #[error("Sender was dropped without sending buffer status, the recv thread may have exited")]
61     BufferStatusSenderLost(RecvError),
62     #[error("Command failed with status {0}")]
63     CommandFailed(u32),
64     #[error("Error duplicating file descriptor: {0}")]
65     DupError(BaseError),
66     #[error("Failed to create Recv event: {0}")]
67     EventCreateError(BaseError),
68     #[error("Failed to dup Recv event: {0}")]
69     EventDupError(BaseError),
70     #[error("Failed to signal event: {0}")]
71     EventWriteError(BaseError),
72     #[error("Failed to get size of tx shared memory: {0}")]
73     FileSizeError(IOError),
74     #[error("Error accessing guest's shared memory: {0}")]
75     GuestMmapError(MmapError),
76     #[error("No jack with id {0}")]
77     InvalidJackId(u32),
78     #[error("No stream with id {0}")]
79     InvalidStreamId(u32),
80     #[error("IO buffer operation failed: status = {0}")]
81     IOBufferError(u32),
82     #[error("No PCM streams available")]
83     NoStreamsAvailable,
84     #[error("Insuficient space for the new buffer in the queue's buffer area")]
85     OutOfSpace,
86     #[error("Platform not supported")]
87     PlatformNotSupported,
88     #[error("{0}")]
89     ProtocolError(ProtocolErrorKind),
90     #[error("Failed to connect to VioS server {1}: {0:?}")]
91     ServerConnectionError(IOError, PathBuf),
92     #[error("Failed to communicate with VioS server: {0:?}")]
93     ServerError(IOError),
94     #[error("Failed to communicate with VioS server: {0:?}")]
95     ServerIOError(IOError),
96     #[error("Error accessing VioS server's shared memory: {0}")]
97     ServerMmapError(MmapError),
98     #[error("Failed to duplicate UnixSeqpacket: {0}")]
99     UnixSeqpacketDupError(IOError),
100     #[error("Unsupported frame rate: {0}")]
101     UnsupportedFrameRate(u32),
102     #[error("Error accessing volatile memory: {0}")]
103     VolatileMemoryError(VolatileMemoryError),
104     #[error("Failed to create Recv thread's WaitContext: {0}")]
105     WaitContextCreateError(BaseError),
106     #[error("Error waiting for events")]
107     WaitError(BaseError),
108     #[error("Invalid operation for stream direction: {0}")]
109     WrongDirection(u8),
110     #[error("Set saved params should only be used while restoring the device")]
111     WrongSetParams,
112 }
113 
114 #[derive(ThisError, Debug)]
115 pub enum ProtocolErrorKind {
116     #[error("The server sent a config of the wrong size: {0}")]
117     UnexpectedConfigSize(usize),
118     #[error("Received {1} file descriptors from the server, expected {0}")]
119     UnexpectedNumberOfFileDescriptors(usize, usize), // expected, received
120     #[error("Server's version ({0}) doesn't match client's")]
121     VersionMismatch(u32),
122     #[error("Received a msg with an unexpected size: expected {0}, received {1}")]
123     UnexpectedMessageSize(usize, usize), // expected, received
124 }
125 
126 /// The client for the VioS backend
127 ///
128 /// Uses a protocol equivalent to virtio-snd over a shared memory file and a unix socket for
129 /// notifications. It's thread safe, it can be encapsulated in an Arc smart pointer and shared
130 /// between threads.
131 pub struct VioSClient {
132     // These mutexes should almost never be held simultaneously. If at some point they have to the
133     // locking order should match the order in which they are declared here.
134     config: VioSConfig,
135     jacks: Vec<virtio_snd_jack_info>,
136     streams: Vec<virtio_snd_pcm_info>,
137     chmaps: Vec<virtio_snd_chmap_info>,
138     // The control socket is used from multiple threads to send and wait for a reply, which needs
139     // to happen atomically, hence the need for a mutex instead of just sharing clones of the
140     // socket.
141     control_socket: Mutex<UnixSeqpacket>,
142     event_socket: UnixSeqpacket,
143     // These are thread safe and don't require locking
144     tx: IoBufferQueue,
145     rx: IoBufferQueue,
146     // This is accessed by the recv_thread and whatever thread processes the events
147     events: Arc<Mutex<VecDeque<virtio_snd_event>>>,
148     event_notifier: Event,
149     // These are accessed by the recv_thread and the stream threads
150     tx_subscribers: Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>>,
151     rx_subscribers: Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>>,
152     recv_thread_state: Arc<Mutex<ThreadFlags>>,
153     recv_thread: Mutex<Option<WorkerThread<Result<()>>>>,
154     // Params are required to be stored for snapshot/restore. On restore, we don't have the params
155     // locally available as the VM is started anew, so they need to be restored.
156     params: HashMap<u32, virtio_snd_pcm_set_params>,
157 }
158 
159 #[derive(Serialize, Deserialize)]
160 pub struct VioSClientSnapshot {
161     config: VioSConfig,
162     jacks: Vec<virtio_snd_jack_info>,
163     streams: Vec<virtio_snd_pcm_info>,
164     chmaps: Vec<virtio_snd_chmap_info>,
165     params: HashMap<u32, virtio_snd_pcm_set_params>,
166 }
167 
168 impl VioSClient {
169     /// Create a new client given the path to the audio server's socket.
try_new<P: AsRef<Path>>(server: P) -> Result<VioSClient>170     pub fn try_new<P: AsRef<Path>>(server: P) -> Result<VioSClient> {
171         let client_socket = ScmSocket::try_from(
172             UnixSeqpacket::connect(server.as_ref())
173                 .map_err(|e| Error::ServerConnectionError(e, server.as_ref().into()))?,
174         )
175         .map_err(|e| Error::ServerConnectionError(e, server.as_ref().into()))?;
176         let mut config: VioSConfig = Default::default();
177         const NUM_FDS: usize = 5;
178         let (recv_size, mut safe_fds) = client_socket
179             .recv_with_fds(config.as_mut_bytes(), NUM_FDS)
180             .map_err(Error::ServerError)?;
181 
182         if recv_size != std::mem::size_of::<VioSConfig>() {
183             return Err(Error::ProtocolError(
184                 ProtocolErrorKind::UnexpectedConfigSize(recv_size),
185             ));
186         }
187 
188         if config.version != VIOS_VERSION {
189             return Err(Error::ProtocolError(ProtocolErrorKind::VersionMismatch(
190                 config.version,
191             )));
192         }
193 
194         fn pop<T: FromRawDescriptor>(
195             safe_fds: &mut Vec<SafeDescriptor>,
196             expected: usize,
197             received: usize,
198         ) -> Result<T> {
199             // SAFETY:
200             // Safe because we transfer ownership from the SafeDescriptor to T
201             unsafe {
202                 Ok(T::from_raw_descriptor(
203                     safe_fds
204                         .pop()
205                         .ok_or(Error::ProtocolError(
206                             ProtocolErrorKind::UnexpectedNumberOfFileDescriptors(
207                                 expected, received,
208                             ),
209                         ))?
210                         .into_raw_descriptor(),
211                 ))
212             }
213         }
214 
215         let fd_count = safe_fds.len();
216         let rx_shm_file = pop::<File>(&mut safe_fds, NUM_FDS, fd_count)?;
217         let tx_shm_file = pop::<File>(&mut safe_fds, NUM_FDS, fd_count)?;
218         let rx_socket = pop::<UnixSeqpacket>(&mut safe_fds, NUM_FDS, fd_count)?;
219         let tx_socket = pop::<UnixSeqpacket>(&mut safe_fds, NUM_FDS, fd_count)?;
220         let event_socket = pop::<UnixSeqpacket>(&mut safe_fds, NUM_FDS, fd_count)?;
221 
222         if !safe_fds.is_empty() {
223             return Err(Error::ProtocolError(
224                 ProtocolErrorKind::UnexpectedNumberOfFileDescriptors(NUM_FDS, fd_count),
225             ));
226         }
227 
228         let tx_subscribers: Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>> =
229             Arc::new(Mutex::new(HashMap::new()));
230         let rx_subscribers: Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>> =
231             Arc::new(Mutex::new(HashMap::new()));
232         let recv_thread_state = Arc::new(Mutex::new(ThreadFlags {
233             reporting_events: false,
234         }));
235 
236         let mut client = VioSClient {
237             config,
238             jacks: Vec::new(),
239             streams: Vec::new(),
240             chmaps: Vec::new(),
241             control_socket: Mutex::new(client_socket.into_inner()),
242             event_socket,
243             tx: IoBufferQueue::new(tx_socket, tx_shm_file)?,
244             rx: IoBufferQueue::new(rx_socket, rx_shm_file)?,
245             events: Arc::new(Mutex::new(VecDeque::new())),
246             event_notifier: Event::new().map_err(Error::EventCreateError)?,
247             tx_subscribers,
248             rx_subscribers,
249             recv_thread_state,
250             recv_thread: Mutex::new(None),
251             params: HashMap::new(),
252         };
253         client.request_and_cache_info()?;
254         Ok(client)
255     }
256 
257     /// Get the number of jacks
num_jacks(&self) -> u32258     pub fn num_jacks(&self) -> u32 {
259         self.config.jacks
260     }
261 
262     /// Get the number of pcm streams
num_streams(&self) -> u32263     pub fn num_streams(&self) -> u32 {
264         self.config.streams
265     }
266 
267     /// Get the number of channel maps
num_chmaps(&self) -> u32268     pub fn num_chmaps(&self) -> u32 {
269         self.config.chmaps
270     }
271 
272     /// Get the configuration information on a jack
jack_info(&self, idx: u32) -> Option<virtio_snd_jack_info>273     pub fn jack_info(&self, idx: u32) -> Option<virtio_snd_jack_info> {
274         self.jacks.get(idx as usize).copied()
275     }
276 
277     /// Get the configuration information on a pcm stream
stream_info(&self, idx: u32) -> Option<virtio_snd_pcm_info>278     pub fn stream_info(&self, idx: u32) -> Option<virtio_snd_pcm_info> {
279         self.streams.get(idx as usize).cloned()
280     }
281 
282     /// Get the configuration information on a channel map
chmap_info(&self, idx: u32) -> Option<virtio_snd_chmap_info>283     pub fn chmap_info(&self, idx: u32) -> Option<virtio_snd_chmap_info> {
284         self.chmaps.get(idx as usize).copied()
285     }
286 
287     /// Starts the background thread that receives release messages from the server. If the thread
288     /// was already started this function does nothing.
289     /// This thread must be started prior to attempting any stream IO operation or the calling
290     /// thread would block.
start_bg_thread(&self) -> Result<()>291     pub fn start_bg_thread(&self) -> Result<()> {
292         if self.recv_thread.lock().is_some() {
293             return Ok(());
294         }
295         let tx_socket = self.tx.try_clone_socket()?;
296         let rx_socket = self.rx.try_clone_socket()?;
297         let event_socket = self
298             .event_socket
299             .try_clone()
300             .map_err(Error::UnixSeqpacketDupError)?;
301         let mut opt = self.recv_thread.lock();
302         // The lock on recv_thread was released above to avoid holding more than one lock at a time
303         // while duplicating the fds. So we have to check the condition again.
304         if opt.is_none() {
305             *opt = Some(spawn_recv_thread(
306                 self.tx_subscribers.clone(),
307                 self.rx_subscribers.clone(),
308                 self.event_notifier
309                     .try_clone()
310                     .map_err(Error::EventDupError)?,
311                 self.events.clone(),
312                 self.recv_thread_state.clone(),
313                 tx_socket,
314                 rx_socket,
315                 event_socket,
316             ));
317         }
318         Ok(())
319     }
320 
321     /// Stops the background thread.
stop_bg_thread(&self) -> Result<()>322     pub fn stop_bg_thread(&self) -> Result<()> {
323         if let Some(recv_thread) = self.recv_thread.lock().take() {
324             recv_thread.stop()?;
325         }
326         Ok(())
327     }
328 
329     /// Gets an Event object that will trigger every time an event is received from the server
get_event_notifier(&self) -> Result<Event>330     pub fn get_event_notifier(&self) -> Result<Event> {
331         // Let the background thread know that there is at least one consumer of events
332         self.recv_thread_state.lock().reporting_events = true;
333         self.event_notifier
334             .try_clone()
335             .map_err(Error::EventDupError)
336     }
337 
338     /// Retrieves one event. Callers should have received a notification through the event notifier
339     /// before calling this function.
pop_event(&self) -> Option<virtio_snd_event>340     pub fn pop_event(&self) -> Option<virtio_snd_event> {
341         self.events.lock().pop_front()
342     }
343 
344     /// Remap a jack. This should only be called if the jack announces support for the operation
345     /// through the features field in the corresponding virtio_snd_jack_info struct.
remap_jack(&self, jack_id: u32, association: u32, sequence: u32) -> Result<()>346     pub fn remap_jack(&self, jack_id: u32, association: u32, sequence: u32) -> Result<()> {
347         if jack_id >= self.config.jacks {
348             return Err(Error::InvalidJackId(jack_id));
349         }
350         let msg = virtio_snd_jack_remap {
351             hdr: virtio_snd_jack_hdr {
352                 hdr: virtio_snd_hdr {
353                     code: VIRTIO_SND_R_JACK_REMAP.into(),
354                 },
355                 jack_id: jack_id.into(),
356             },
357             association: association.into(),
358             sequence: sequence.into(),
359         };
360         let control_socket_lock = self.control_socket.lock();
361         send_cmd(&control_socket_lock, msg)
362     }
363 
364     /// Configures a stream with the given parameters.
set_stream_parameters( &mut self, stream_id: u32, params: VioSStreamParams, ) -> Result<()>365     pub fn set_stream_parameters(
366         &mut self,
367         stream_id: u32,
368         params: VioSStreamParams,
369     ) -> Result<()> {
370         self.streams
371             .get(stream_id as usize)
372             .ok_or(Error::InvalidStreamId(stream_id))?;
373         let raw_params: virtio_snd_pcm_set_params = (stream_id, params).into();
374         // Old value is not needed and is dropped
375         let _ = self.params.insert(stream_id, raw_params);
376         let control_socket_lock = self.control_socket.lock();
377         send_cmd(&control_socket_lock, raw_params)
378     }
379 
380     /// Configures a stream with the given parameters.
set_stream_parameters_raw( &mut self, raw_params: virtio_snd_pcm_set_params, ) -> Result<()>381     pub fn set_stream_parameters_raw(
382         &mut self,
383         raw_params: virtio_snd_pcm_set_params,
384     ) -> Result<()> {
385         let stream_id = raw_params.hdr.stream_id.to_native();
386         // Old value is not needed and is dropped
387         let _ = self.params.insert(stream_id, raw_params);
388         self.streams
389             .get(stream_id as usize)
390             .ok_or(Error::InvalidStreamId(stream_id))?;
391         let control_socket_lock = self.control_socket.lock();
392         send_cmd(&control_socket_lock, raw_params)
393     }
394 
395     /// Send the PREPARE_STREAM command to the server.
prepare_stream(&self, stream_id: u32) -> Result<()>396     pub fn prepare_stream(&self, stream_id: u32) -> Result<()> {
397         self.common_stream_op(stream_id, VIRTIO_SND_R_PCM_PREPARE)
398     }
399 
400     /// Send the RELEASE_STREAM command to the server.
release_stream(&self, stream_id: u32) -> Result<()>401     pub fn release_stream(&self, stream_id: u32) -> Result<()> {
402         self.common_stream_op(stream_id, VIRTIO_SND_R_PCM_RELEASE)
403     }
404 
405     /// Send the START_STREAM command to the server.
start_stream(&self, stream_id: u32) -> Result<()>406     pub fn start_stream(&self, stream_id: u32) -> Result<()> {
407         self.common_stream_op(stream_id, VIRTIO_SND_R_PCM_START)
408     }
409 
410     /// Send the STOP_STREAM command to the server.
stop_stream(&self, stream_id: u32) -> Result<()>411     pub fn stop_stream(&self, stream_id: u32) -> Result<()> {
412         self.common_stream_op(stream_id, VIRTIO_SND_R_PCM_STOP)
413     }
414 
415     /// Send audio frames to the server. Blocks the calling thread until the server acknowledges
416     /// the data.
inject_audio_data<R, Cb: FnOnce(VolatileSlice) -> R>( &self, stream_id: u32, size: usize, callback: Cb, ) -> Result<(u32, R)>417     pub fn inject_audio_data<R, Cb: FnOnce(VolatileSlice) -> R>(
418         &self,
419         stream_id: u32,
420         size: usize,
421         callback: Cb,
422     ) -> Result<(u32, R)> {
423         if self
424             .streams
425             .get(stream_id as usize)
426             .ok_or(Error::InvalidStreamId(stream_id))?
427             .direction
428             != VIRTIO_SND_D_OUTPUT
429         {
430             return Err(Error::WrongDirection(VIRTIO_SND_D_OUTPUT));
431         }
432         self.streams
433             .get(stream_id as usize)
434             .ok_or(Error::InvalidStreamId(stream_id))?;
435         let dst_offset = self.tx.allocate_buffer(size)?;
436         let buffer_slice = self.tx.buffer_at(dst_offset, size)?;
437         let ret = callback(buffer_slice);
438         // Register to receive the status before sending the buffer to the server
439         let (sender, receiver): (Sender<BufferReleaseMsg>, Receiver<BufferReleaseMsg>) = channel();
440         self.tx_subscribers.lock().insert(dst_offset, sender);
441         self.tx.send_buffer(stream_id, dst_offset, size)?;
442         let (_, latency) = await_status(receiver)?;
443         Ok((latency, ret))
444     }
445 
446     /// Request audio frames from the server. It blocks until the data is available.
request_audio_data<R, Cb: FnOnce(&VolatileSlice) -> R>( &self, stream_id: u32, size: usize, callback: Cb, ) -> Result<(u32, R)>447     pub fn request_audio_data<R, Cb: FnOnce(&VolatileSlice) -> R>(
448         &self,
449         stream_id: u32,
450         size: usize,
451         callback: Cb,
452     ) -> Result<(u32, R)> {
453         if self
454             .streams
455             .get(stream_id as usize)
456             .ok_or(Error::InvalidStreamId(stream_id))?
457             .direction
458             != VIRTIO_SND_D_INPUT
459         {
460             return Err(Error::WrongDirection(VIRTIO_SND_D_INPUT));
461         }
462         let src_offset = self.rx.allocate_buffer(size)?;
463         // Register to receive the status before sending the buffer to the server
464         let (sender, receiver): (Sender<BufferReleaseMsg>, Receiver<BufferReleaseMsg>) = channel();
465         self.rx_subscribers.lock().insert(src_offset, sender);
466         self.rx.send_buffer(stream_id, src_offset, size)?;
467         // Make sure no mutexes are held while awaiting for the buffer to be written to
468         let (recv_size, latency) = await_status(receiver)?;
469         let buffer_slice = self.rx.buffer_at(src_offset, recv_size)?;
470         Ok((latency, callback(&buffer_slice)))
471     }
472 
473     /// Get a list of file descriptors used by the implementation.
keep_rds(&self) -> Vec<RawDescriptor>474     pub fn keep_rds(&self) -> Vec<RawDescriptor> {
475         let control_desc = self.control_socket.lock().as_raw_descriptor();
476         let event_desc = self.event_socket.as_raw_descriptor();
477         let event_notifier = self.event_notifier.as_raw_descriptor();
478         let mut ret = vec![control_desc, event_desc, event_notifier];
479         ret.append(&mut self.tx.keep_rds());
480         ret.append(&mut self.rx.keep_rds());
481         ret
482     }
483 
common_stream_op(&self, stream_id: u32, op: u32) -> Result<()>484     fn common_stream_op(&self, stream_id: u32, op: u32) -> Result<()> {
485         self.streams
486             .get(stream_id as usize)
487             .ok_or(Error::InvalidStreamId(stream_id))?;
488         let msg = virtio_snd_pcm_hdr {
489             hdr: virtio_snd_hdr { code: op.into() },
490             stream_id: stream_id.into(),
491         };
492         let control_socket_lock = self.control_socket.lock();
493         send_cmd(&control_socket_lock, msg)
494     }
495 
request_and_cache_info(&mut self) -> Result<()>496     fn request_and_cache_info(&mut self) -> Result<()> {
497         self.request_and_cache_jacks_info()?;
498         self.request_and_cache_streams_info()?;
499         self.request_and_cache_chmaps_info()?;
500         Ok(())
501     }
502 
request_info<T: IntoBytes + FromBytes + Default + Copy + Clone>( &self, req_code: u32, count: usize, ) -> Result<Vec<T>>503     fn request_info<T: IntoBytes + FromBytes + Default + Copy + Clone>(
504         &self,
505         req_code: u32,
506         count: usize,
507     ) -> Result<Vec<T>> {
508         let info_size = std::mem::size_of::<T>();
509         let status_size = std::mem::size_of::<virtio_snd_hdr>();
510         let req = virtio_snd_query_info {
511             hdr: virtio_snd_hdr {
512                 code: req_code.into(),
513             },
514             start_id: 0u32.into(),
515             count: (count as u32).into(),
516             size: (std::mem::size_of::<virtio_snd_query_info>() as u32).into(),
517         };
518         let control_socket_lock = self.control_socket.lock();
519         seq_socket_send(&control_socket_lock, req.as_bytes())?;
520         let reply = control_socket_lock
521             .recv_as_vec()
522             .map_err(Error::ServerIOError)?;
523         let mut status: virtio_snd_hdr = Default::default();
524         status
525             .as_mut_bytes()
526             .copy_from_slice(&reply[0..status_size]);
527         if status.code.to_native() != VIRTIO_SND_S_OK {
528             return Err(Error::CommandFailed(status.code.to_native()));
529         }
530         if reply.len() != status_size + count * info_size {
531             return Err(Error::ProtocolError(
532                 ProtocolErrorKind::UnexpectedMessageSize(count * info_size, reply.len()),
533             ));
534         }
535         Ok(reply[status_size..]
536             .chunks(info_size)
537             .map(|info_buffer| T::read_from_bytes(info_buffer).unwrap())
538             .collect())
539     }
540 
request_and_cache_jacks_info(&mut self) -> Result<()>541     fn request_and_cache_jacks_info(&mut self) -> Result<()> {
542         let num_jacks = self.config.jacks as usize;
543         if num_jacks == 0 {
544             return Ok(());
545         }
546         self.jacks = self.request_info(VIRTIO_SND_R_JACK_INFO, num_jacks)?;
547         Ok(())
548     }
549 
request_and_cache_streams_info(&mut self) -> Result<()>550     fn request_and_cache_streams_info(&mut self) -> Result<()> {
551         let num_streams = self.config.streams as usize;
552         if num_streams == 0 {
553             return Ok(());
554         }
555         self.streams = self.request_info(VIRTIO_SND_R_PCM_INFO, num_streams)?;
556         Ok(())
557     }
558 
request_and_cache_chmaps_info(&mut self) -> Result<()>559     fn request_and_cache_chmaps_info(&mut self) -> Result<()> {
560         let num_chmaps = self.config.chmaps as usize;
561         if num_chmaps == 0 {
562             return Ok(());
563         }
564         self.chmaps = self.request_info(VIRTIO_SND_R_CHMAP_INFO, num_chmaps)?;
565         Ok(())
566     }
567 
snapshot(&self) -> VioSClientSnapshot568     pub fn snapshot(&self) -> VioSClientSnapshot {
569         VioSClientSnapshot {
570             config: self.config,
571             jacks: self.jacks.clone(),
572             streams: self.streams.clone(),
573             chmaps: self.chmaps.clone(),
574             params: self.params.clone(),
575         }
576     }
577 
578     // Function called `restore` to signify it will happen as part of the snapshot/restore flow. No
579     // data is actually restored in the case of VioSClient.
restore(&mut self, data: VioSClientSnapshot) -> anyhow::Result<()>580     pub fn restore(&mut self, data: VioSClientSnapshot) -> anyhow::Result<()> {
581         anyhow::ensure!(
582             data.config == self.config,
583             "config doesn't match on restore: expected: {:?}, got: {:?}",
584             data.config,
585             self.config
586         );
587         self.jacks = data.jacks;
588         self.streams = data.streams;
589         self.chmaps = data.chmaps;
590         self.params = data.params;
591         Ok(())
592     }
593 
restore_stream(&mut self, stream_id: u32, state: StreamState) -> Result<()>594     pub fn restore_stream(&mut self, stream_id: u32, state: StreamState) -> Result<()> {
595         if let Some(params) = self.params.get(&stream_id).cloned() {
596             self.set_stream_parameters_raw(params)?;
597         }
598         match state {
599             StreamState::Started => {
600                 // If state != prepared, start will always fail.
601                 // As such, it is fine to only print the first error without returning, as the
602                 // second action will then fail.
603                 if let Err(e) = self.prepare_stream(stream_id) {
604                     error!("failed to prepare stream: {}", e);
605                 };
606                 self.start_stream(stream_id)
607             }
608             StreamState::Prepared => self.prepare_stream(stream_id),
609             // Nothing to do here
610             _ => Ok(()),
611         }
612     }
613 }
614 
/// State shared between the client and its background receive thread.
#[derive(Copy, Clone)]
struct ThreadFlags {
    /// Set once some consumer has requested the event notifier. While false, the receive
    /// thread drops incoming events instead of queueing them.
    reporting_events: bool,
}
619 
620 #[derive(EventToken)]
621 enum Token {
622     Notification,
623     TxBufferMsg,
624     RxBufferMsg,
625     EventMsg,
626 }
627 
recv_buffer_status_msg( socket: &UnixSeqpacket, subscribers: &Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>>, ) -> Result<()>628 fn recv_buffer_status_msg(
629     socket: &UnixSeqpacket,
630     subscribers: &Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>>,
631 ) -> Result<()> {
632     let mut msg: IoStatusMsg = Default::default();
633     let size = socket
634         .recv(msg.as_mut_bytes())
635         .map_err(Error::ServerIOError)?;
636     if size != std::mem::size_of::<IoStatusMsg>() {
637         return Err(Error::ProtocolError(
638             ProtocolErrorKind::UnexpectedMessageSize(std::mem::size_of::<IoStatusMsg>(), size),
639         ));
640     }
641     let mut status = msg.status.status.into();
642     if status == u32::MAX {
643         // Anyone waiting for this would continue to wait for as long as status is
644         // u32::MAX
645         status -= 1;
646     }
647     let latency = msg.status.latency_bytes.into();
648     let offset = msg.buffer_offset as usize;
649     let consumed_len = msg.consumed_len as usize;
650     let promise_opt = subscribers.lock().remove(&offset);
651     match promise_opt {
652         None => error!(
653             "Received an unexpected buffer status message: {}. This is a BUG!!",
654             offset
655         ),
656         Some(sender) => {
657             if let Err(e) = sender.send(BufferReleaseMsg {
658                 status,
659                 latency,
660                 consumed_len,
661             }) {
662                 error!("Failed to notify waiting thread: {:?}", e);
663             }
664         }
665     }
666     Ok(())
667 }
668 
recv_event(socket: &UnixSeqpacket) -> Result<virtio_snd_event>669 fn recv_event(socket: &UnixSeqpacket) -> Result<virtio_snd_event> {
670     let mut msg: virtio_snd_event = Default::default();
671     let size = socket
672         .recv(msg.as_mut_bytes())
673         .map_err(Error::ServerIOError)?;
674     if size != std::mem::size_of::<virtio_snd_event>() {
675         return Err(Error::ProtocolError(
676             ProtocolErrorKind::UnexpectedMessageSize(std::mem::size_of::<virtio_snd_event>(), size),
677         ));
678     }
679     Ok(msg)
680 }
681 
spawn_recv_thread( tx_subscribers: Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>>, rx_subscribers: Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>>, event_notifier: Event, event_queue: Arc<Mutex<VecDeque<virtio_snd_event>>>, state: Arc<Mutex<ThreadFlags>>, tx_socket: UnixSeqpacket, rx_socket: UnixSeqpacket, event_socket: UnixSeqpacket, ) -> WorkerThread<Result<()>>682 fn spawn_recv_thread(
683     tx_subscribers: Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>>,
684     rx_subscribers: Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>>,
685     event_notifier: Event,
686     event_queue: Arc<Mutex<VecDeque<virtio_snd_event>>>,
687     state: Arc<Mutex<ThreadFlags>>,
688     tx_socket: UnixSeqpacket,
689     rx_socket: UnixSeqpacket,
690     event_socket: UnixSeqpacket,
691 ) -> WorkerThread<Result<()>> {
692     WorkerThread::start("shm_vios", move |event| {
693         let wait_ctx: WaitContext<Token> = WaitContext::build_with(&[
694             (&tx_socket, Token::TxBufferMsg),
695             (&rx_socket, Token::RxBufferMsg),
696             (&event_socket, Token::EventMsg),
697             (&event, Token::Notification),
698         ])
699         .map_err(Error::WaitContextCreateError)?;
700         let mut running = true;
701         while running {
702             let events = wait_ctx.wait().map_err(Error::WaitError)?;
703             for evt in events {
704                 match evt.token {
705                     Token::TxBufferMsg => recv_buffer_status_msg(&tx_socket, &tx_subscribers)?,
706                     Token::RxBufferMsg => recv_buffer_status_msg(&rx_socket, &rx_subscribers)?,
707                     Token::EventMsg => {
708                         let evt = recv_event(&event_socket)?;
709                         let state_cpy = *state.lock();
710                         if state_cpy.reporting_events {
711                             event_queue.lock().push_back(evt);
712                             event_notifier.signal().map_err(Error::EventWriteError)?;
713                         } // else just drop the events
714                     }
715                     Token::Notification => {
716                         // Just consume the notification and check for termination on the next
717                         // iteration
718                         if let Err(e) = event.wait() {
719                             error!("Failed to consume notification from recv thread: {:?}", e);
720                         }
721                         running = false;
722                     }
723                 }
724             }
725         }
726         Ok(())
727     })
728 }
729 
await_status(promise: Receiver<BufferReleaseMsg>) -> Result<(usize, u32)>730 fn await_status(promise: Receiver<BufferReleaseMsg>) -> Result<(usize, u32)> {
731     let BufferReleaseMsg {
732         status,
733         latency,
734         consumed_len,
735     } = promise.recv().map_err(Error::BufferStatusSenderLost)?;
736     if status == VIRTIO_SND_S_OK {
737         Ok((consumed_len, latency))
738     } else {
739         Err(Error::IOBufferError(status))
740     }
741 }
742 
743 struct IoBufferQueue {
744     socket: UnixSeqpacket,
745     file: File,
746     mmap: MemoryMapping,
747     size: usize,
748     next: Mutex<usize>,
749 }
750 
751 impl IoBufferQueue {
new(socket: UnixSeqpacket, mut file: File) -> Result<IoBufferQueue>752     fn new(socket: UnixSeqpacket, mut file: File) -> Result<IoBufferQueue> {
753         let size = file.seek(SeekFrom::End(0)).map_err(Error::FileSizeError)? as usize;
754 
755         let mmap = MemoryMappingBuilder::new(size)
756             .from_file(&file)
757             .build()
758             .map_err(Error::ServerMmapError)?;
759 
760         Ok(IoBufferQueue {
761             socket,
762             file,
763             mmap,
764             size,
765             next: Mutex::new(0),
766         })
767     }
768 
allocate_buffer(&self, size: usize) -> Result<usize>769     fn allocate_buffer(&self, size: usize) -> Result<usize> {
770         if size > self.size {
771             return Err(Error::OutOfSpace);
772         }
773         let mut next_lock = self.next.lock();
774         let offset = if size > self.size - *next_lock {
775             // Can't fit the new buffer at the end of the area, so put it at the beginning
776             0
777         } else {
778             *next_lock
779         };
780         *next_lock = offset + size;
781         Ok(offset)
782     }
783 
buffer_at(&self, offset: usize, len: usize) -> Result<VolatileSlice>784     fn buffer_at(&self, offset: usize, len: usize) -> Result<VolatileSlice> {
785         self.mmap
786             .get_slice(offset, len)
787             .map_err(Error::VolatileMemoryError)
788     }
789 
try_clone_socket(&self) -> Result<UnixSeqpacket>790     fn try_clone_socket(&self) -> Result<UnixSeqpacket> {
791         self.socket
792             .try_clone()
793             .map_err(Error::UnixSeqpacketDupError)
794     }
795 
send_buffer(&self, stream_id: u32, offset: usize, size: usize) -> Result<()>796     fn send_buffer(&self, stream_id: u32, offset: usize, size: usize) -> Result<()> {
797         let msg = IoTransferMsg::new(stream_id, offset, size);
798         seq_socket_send(&self.socket, msg.as_bytes())
799     }
800 
keep_rds(&self) -> Vec<RawDescriptor>801     fn keep_rds(&self) -> Vec<RawDescriptor> {
802         vec![
803             self.file.as_raw_descriptor(),
804             self.socket.as_raw_descriptor(),
805         ]
806     }
807 }
808 
/// Groups the parameters used to configure a stream prior to using it.
///
/// Paired with a stream id, these convert into a `virtio_snd_pcm_set_params` request. Each field
/// is copied into the request field of the same name. The type is a small bundle of integers, so
/// it derives `Copy` and can be printed and compared.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct VioSStreamParams {
    /// Requested buffer size, in bytes.
    pub buffer_bytes: u32,
    /// Requested period size, in bytes.
    pub period_bytes: u32,
    /// Feature bits for the stream (presumably `VIRTIO_SND_PCM_F_*` flags — confirm).
    pub features: u32,
    /// Number of channels.
    pub channels: u8,
    /// Sample format (presumably a `VIRTIO_SND_PCM_FMT_*` value — confirm).
    pub format: u8,
    /// Frame rate (presumably a `VIRTIO_SND_PCM_RATE_*` value — confirm).
    pub rate: u8,
}
818 
819 impl From<(u32, VioSStreamParams)> for virtio_snd_pcm_set_params {
from(val: (u32, VioSStreamParams)) -> Self820     fn from(val: (u32, VioSStreamParams)) -> Self {
821         virtio_snd_pcm_set_params {
822             hdr: virtio_snd_pcm_hdr {
823                 hdr: virtio_snd_hdr {
824                     code: VIRTIO_SND_R_PCM_SET_PARAMS.into(),
825                 },
826                 stream_id: val.0.into(),
827             },
828             buffer_bytes: val.1.buffer_bytes.into(),
829             period_bytes: val.1.period_bytes.into(),
830             features: val.1.features.into(),
831             channels: val.1.channels,
832             format: val.1.format,
833             rate: val.1.rate,
834             padding: 0u8,
835         }
836     }
837 }
838 
send_cmd<T: Immutable + IntoBytes>(control_socket: &UnixSeqpacket, data: T) -> Result<()>839 fn send_cmd<T: Immutable + IntoBytes>(control_socket: &UnixSeqpacket, data: T) -> Result<()> {
840     seq_socket_send(control_socket, data.as_bytes())?;
841     recv_cmd_status(control_socket)
842 }
843 
recv_cmd_status(control_socket: &UnixSeqpacket) -> Result<()>844 fn recv_cmd_status(control_socket: &UnixSeqpacket) -> Result<()> {
845     let mut status: virtio_snd_hdr = Default::default();
846     control_socket
847         .recv(status.as_mut_bytes())
848         .map_err(Error::ServerIOError)?;
849     if status.code.to_native() == VIRTIO_SND_S_OK {
850         Ok(())
851     } else {
852         Err(Error::CommandFailed(status.code.to_native()))
853     }
854 }
855 
seq_socket_send(socket: &UnixSeqpacket, data: &[u8]) -> Result<()>856 fn seq_socket_send(socket: &UnixSeqpacket, data: &[u8]) -> Result<()> {
857     loop {
858         let send_res = socket.send(data);
859         if let Err(e) = send_res {
860             match e.kind() {
861                 // Retry if interrupted
862                 IOErrorKind::Interrupted => continue,
863                 _ => return Err(Error::ServerIOError(e)),
864             }
865         }
866         // Success
867         break;
868     }
869     Ok(())
870 }
871 
/// VioS protocol version spoken by this client. Presumably compared with the `version` field of
/// the `VioSConfig` reported by the server — confirm against the handshake code.
const VIOS_VERSION: u32 = 2;
873 
874 #[repr(C)]
875 #[derive(
876     Copy,
877     Clone,
878     Default,
879     FromBytes,
880     Immutable,
881     IntoBytes,
882     KnownLayout,
883     Serialize,
884     Deserialize,
885     PartialEq,
886     Eq,
887     Debug,
888 )]
889 struct VioSConfig {
890     version: u32,
891     jacks: u32,
892     streams: u32,
893     chmaps: u32,
894 }
895 
/// Report on a single buffer, delivered over an mpsc channel and consumed by `await_status`.
struct BufferReleaseMsg {
    /// Status code; `await_status` treats anything other than `VIRTIO_SND_S_OK` as an error.
    status: u32,
    /// Latency value handed back to the caller of `await_status` (units not established here).
    latency: u32,
    /// Length handed back to the caller of `await_status` (presumably bytes consumed).
    consumed_len: usize,
}
901 
902 #[repr(C)]
903 #[derive(Copy, Clone, FromBytes, Immutable, IntoBytes, KnownLayout)]
904 struct IoTransferMsg {
905     io_xfer: virtio_snd_pcm_xfer,
906     buffer_offset: u32,
907     buffer_len: u32,
908 }
909 
910 impl IoTransferMsg {
new(stream_id: u32, buffer_offset: usize, buffer_len: usize) -> IoTransferMsg911     fn new(stream_id: u32, buffer_offset: usize, buffer_len: usize) -> IoTransferMsg {
912         IoTransferMsg {
913             io_xfer: virtio_snd_pcm_xfer {
914                 stream_id: stream_id.into(),
915             },
916             buffer_offset: buffer_offset as u32,
917             buffer_len: buffer_len as u32,
918         }
919     }
920 }
921 
922 #[repr(C)]
923 #[derive(Copy, Clone, Default, FromBytes, Immutable, IntoBytes, KnownLayout)]
924 struct IoStatusMsg {
925     status: virtio_snd_pcm_status,
926     buffer_offset: u32,
927     consumed_len: u32,
928 }
929