• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2020 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 use std::collections::HashMap;
6 use std::collections::VecDeque;
7 use std::fs::File;
8 use std::io::Error as IOError;
9 use std::io::ErrorKind as IOErrorKind;
10 use std::io::IoSliceMut;
11 use std::io::Seek;
12 use std::io::SeekFrom;
13 use std::os::unix::io::RawFd;
14 use std::path::Path;
15 use std::path::PathBuf;
16 use std::sync::mpsc::channel;
17 use std::sync::mpsc::Receiver;
18 use std::sync::mpsc::RecvError;
19 use std::sync::mpsc::Sender;
20 use std::sync::Arc;
21 
22 use base::error;
23 use base::AsRawDescriptor;
24 use base::Error as BaseError;
25 use base::Event;
26 use base::EventToken;
27 use base::FromRawDescriptor;
28 use base::IntoRawDescriptor;
29 use base::MemoryMapping;
30 use base::MemoryMappingBuilder;
31 use base::MmapError;
32 use base::RawDescriptor;
33 use base::SafeDescriptor;
34 use base::ScmSocket;
35 use base::UnixSeqpacket;
36 use base::WaitContext;
37 use base::WorkerThread;
38 use data_model::VolatileMemory;
39 use data_model::VolatileMemoryError;
40 use data_model::VolatileSlice;
41 use remain::sorted;
42 use sync::Mutex;
43 use thiserror::Error as ThisError;
44 use zerocopy::AsBytes;
45 use zerocopy::FromBytes;
46 
47 use crate::virtio::snd::constants::*;
48 use crate::virtio::snd::layout::*;
49 
/// Result type returned by the fallible operations of this module, carrying its [`Error`].
pub type Result<T> = std::result::Result<T, Error>;
51 
52 #[sorted]
53 #[derive(ThisError, Debug)]
54 pub enum Error {
55     #[error("Error memory mapping client_shm: {0}")]
56     BaseMmapError(BaseError),
57     #[error("Sender was dropped without sending buffer status, the recv thread may have exited")]
58     BufferStatusSenderLost(RecvError),
59     #[error("Command failed with status {0}")]
60     CommandFailed(u32),
61     #[error("Error duplicating file descriptor: {0}")]
62     DupError(BaseError),
63     #[error("Failed to create Recv event: {0}")]
64     EventCreateError(BaseError),
65     #[error("Failed to dup Recv event: {0}")]
66     EventDupError(BaseError),
67     #[error("Failed to signal event: {0}")]
68     EventWriteError(BaseError),
69     #[error("Failed to get size of tx shared memory: {0}")]
70     FileSizeError(IOError),
71     #[error("Error accessing guest's shared memory: {0}")]
72     GuestMmapError(MmapError),
73     #[error("No jack with id {0}")]
74     InvalidJackId(u32),
75     #[error("No stream with id {0}")]
76     InvalidStreamId(u32),
77     #[error("IO buffer operation failed: status = {0}")]
78     IOBufferError(u32),
79     #[error("No PCM streams available")]
80     NoStreamsAvailable,
81     #[error("Insuficient space for the new buffer in the queue's buffer area")]
82     OutOfSpace,
83     #[error("Platform not supported")]
84     PlatformNotSupported,
85     #[error("{0}")]
86     ProtocolError(ProtocolErrorKind),
87     #[error("Failed to connect to VioS server {1}: {0:?}")]
88     ServerConnectionError(IOError, PathBuf),
89     #[error("Failed to communicate with VioS server: {0:?}")]
90     ServerError(BaseError),
91     #[error("Failed to communicate with VioS server: {0:?}")]
92     ServerIOError(IOError),
93     #[error("Error accessing VioS server's shared memory: {0}")]
94     ServerMmapError(MmapError),
95     #[error("Failed to duplicate UnixSeqpacket: {0}")]
96     UnixSeqpacketDupError(IOError),
97     #[error("Unsupported frame rate: {0}")]
98     UnsupportedFrameRate(u32),
99     #[error("Error accessing volatile memory: {0}")]
100     VolatileMemoryError(VolatileMemoryError),
101     #[error("Failed to create Recv thread's WaitContext: {0}")]
102     WaitContextCreateError(BaseError),
103     #[error("Error waiting for events")]
104     WaitError(BaseError),
105     #[error("Invalid operation for stream direction: {0}")]
106     WrongDirection(u8),
107 }
108 
/// Ways in which messages received from the VioS server violate the expected protocol.
#[derive(ThisError, Debug)]
pub enum ProtocolErrorKind {
    /// The initial config message had this size instead of `size_of::<VioSConfig>()`.
    #[error("The server sent a config of the wrong size: {0}")]
    UnexpectedConfigSize(usize),
    /// The initial config message came with a different number of file descriptors.
    #[error("Received {1} file descriptors from the server, expected {0}")]
    UnexpectedNumberOfFileDescriptors(usize, usize), // expected, received
    /// The server's config reported this version, which differs from `VIOS_VERSION`.
    #[error("Server's version ({0}) doesn't match client's")]
    VersionMismatch(u32),
    /// A message of the wrong size (in bytes) was received.
    #[error("Received a msg with an unexpected size: expected {0}, received {1}")]
    UnexpectedMessageSize(usize, usize), // expected, received
}
120 
/// The client for the VioS backend
///
/// Uses a protocol equivalent to virtio-snd over a shared memory file and a unix socket for
/// notifications. It's thread safe, it can be encapsulated in an Arc smart pointer and shared
/// between threads.
pub struct VioSClient {
    // These mutexes should almost never be held simultaneously. If at some point they have to,
    // the locking order should match the order in which they are declared here.
    /// Configuration received from the server on connection; includes the protocol version and
    /// the number of jacks, streams and channel maps.
    config: VioSConfig,
    /// Jack info queried from the server in `try_new`; empty if the server reports no jacks.
    jacks: Vec<virtio_snd_jack_info>,
    /// PCM stream info queried from the server in `try_new`; indexed by stream id.
    streams: Vec<virtio_snd_pcm_info>,
    /// Channel map info queried from the server in `try_new`.
    chmaps: Vec<virtio_snd_chmap_info>,
    // The control socket is used from multiple threads to send and wait for a reply, which needs
    // to happen atomically, hence the need for a mutex instead of just sharing clones of the
    // socket.
    control_socket: Mutex<UnixSeqpacket>,
    /// Socket carrying `virtio_snd_event`s from the server; a clone is read by the recv thread.
    event_socket: UnixSeqpacket,
    // These are thread safe and don't require locking
    /// Buffer queue used by `inject_audio_data` (output streams).
    tx: IoBufferQueue,
    /// Buffer queue used by `request_audio_data` (input streams).
    rx: IoBufferQueue,
    // This is accessed by the recv_thread and whatever thread processes the events
    events: Arc<Mutex<VecDeque<virtio_snd_event>>>,
    /// Signaled by the recv thread every time it queues an event in `events`.
    event_notifier: Event,
    // These are accessed by the recv_thread and the stream threads
    /// Maps a buffer's offset in the tx area to the channel awaiting its release status.
    tx_subscribers: Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>>,
    /// Maps a buffer's offset in the rx area to the channel awaiting its release status.
    rx_subscribers: Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>>,
    /// Flags shared with the recv thread (e.g. whether events are being consumed).
    recv_thread_state: Arc<Mutex<ThreadFlags>>,
    /// Handle of the recv thread: `None` before `start_bg_thread` and after `stop_bg_thread`.
    recv_thread: Mutex<Option<WorkerThread<Result<()>>>>,
}
150 
151 impl VioSClient {
152     /// Create a new client given the path to the audio server's socket.
try_new<P: AsRef<Path>>(server: P) -> Result<VioSClient>153     pub fn try_new<P: AsRef<Path>>(server: P) -> Result<VioSClient> {
154         let client_socket = UnixSeqpacket::connect(server.as_ref())
155             .map_err(|e| Error::ServerConnectionError(e, server.as_ref().into()))?;
156         let mut config: VioSConfig = Default::default();
157         let mut fds: Vec<RawFd> = Vec::new();
158         const NUM_FDS: usize = 5;
159         fds.resize(NUM_FDS, 0);
160         let (recv_size, fd_count) = client_socket
161             .recv_with_fds(IoSliceMut::new(config.as_bytes_mut()), &mut fds)
162             .map_err(Error::ServerError)?;
163 
164         // Resize the vector to the actual number of file descriptors received and wrap them in
165         // SafeDescriptors to prevent leaks
166         fds.resize(fd_count, -1);
167         let mut safe_fds: Vec<SafeDescriptor> = fds
168             .into_iter()
169             .map(|fd| unsafe {
170                 // safe because the SafeDescriptor object completely assumes ownership of the fd.
171                 SafeDescriptor::from_raw_descriptor(fd)
172             })
173             .collect();
174 
175         if recv_size != std::mem::size_of::<VioSConfig>() {
176             return Err(Error::ProtocolError(
177                 ProtocolErrorKind::UnexpectedConfigSize(recv_size),
178             ));
179         }
180 
181         if config.version != VIOS_VERSION {
182             return Err(Error::ProtocolError(ProtocolErrorKind::VersionMismatch(
183                 config.version,
184             )));
185         }
186 
187         fn pop<T: FromRawDescriptor>(
188             safe_fds: &mut Vec<SafeDescriptor>,
189             expected: usize,
190             received: usize,
191         ) -> Result<T> {
192             unsafe {
193                 // Safe because we transfer ownership from the SafeDescriptor to T
194                 Ok(T::from_raw_descriptor(
195                     safe_fds
196                         .pop()
197                         .ok_or(Error::ProtocolError(
198                             ProtocolErrorKind::UnexpectedNumberOfFileDescriptors(
199                                 expected, received,
200                             ),
201                         ))?
202                         .into_raw_descriptor(),
203                 ))
204             }
205         }
206 
207         let rx_shm_file = pop::<File>(&mut safe_fds, NUM_FDS, fd_count)?;
208         let tx_shm_file = pop::<File>(&mut safe_fds, NUM_FDS, fd_count)?;
209         let rx_socket = pop::<UnixSeqpacket>(&mut safe_fds, NUM_FDS, fd_count)?;
210         let tx_socket = pop::<UnixSeqpacket>(&mut safe_fds, NUM_FDS, fd_count)?;
211         let event_socket = pop::<UnixSeqpacket>(&mut safe_fds, NUM_FDS, fd_count)?;
212 
213         if !safe_fds.is_empty() {
214             return Err(Error::ProtocolError(
215                 ProtocolErrorKind::UnexpectedNumberOfFileDescriptors(NUM_FDS, fd_count),
216             ));
217         }
218 
219         let tx_subscribers: Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>> =
220             Arc::new(Mutex::new(HashMap::new()));
221         let rx_subscribers: Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>> =
222             Arc::new(Mutex::new(HashMap::new()));
223         let recv_thread_state = Arc::new(Mutex::new(ThreadFlags {
224             reporting_events: false,
225         }));
226 
227         let mut client = VioSClient {
228             config,
229             jacks: Vec::new(),
230             streams: Vec::new(),
231             chmaps: Vec::new(),
232             control_socket: Mutex::new(client_socket),
233             event_socket,
234             tx: IoBufferQueue::new(tx_socket, tx_shm_file)?,
235             rx: IoBufferQueue::new(rx_socket, rx_shm_file)?,
236             events: Arc::new(Mutex::new(VecDeque::new())),
237             event_notifier: Event::new().map_err(Error::EventCreateError)?,
238             tx_subscribers,
239             rx_subscribers,
240             recv_thread_state,
241             recv_thread: Mutex::new(None),
242         };
243         client.request_and_cache_info()?;
244         Ok(client)
245     }
246 
247     /// Get the number of jacks
num_jacks(&self) -> u32248     pub fn num_jacks(&self) -> u32 {
249         self.config.jacks
250     }
251 
252     /// Get the number of pcm streams
num_streams(&self) -> u32253     pub fn num_streams(&self) -> u32 {
254         self.config.streams
255     }
256 
257     /// Get the number of channel maps
num_chmaps(&self) -> u32258     pub fn num_chmaps(&self) -> u32 {
259         self.config.chmaps
260     }
261 
262     /// Get the configuration information on a jack
jack_info(&self, idx: u32) -> Option<virtio_snd_jack_info>263     pub fn jack_info(&self, idx: u32) -> Option<virtio_snd_jack_info> {
264         self.jacks.get(idx as usize).copied()
265     }
266 
267     /// Get the configuration information on a pcm stream
stream_info(&self, idx: u32) -> Option<virtio_snd_pcm_info>268     pub fn stream_info(&self, idx: u32) -> Option<virtio_snd_pcm_info> {
269         self.streams.get(idx as usize).cloned()
270     }
271 
272     /// Get the configuration information on a channel map
chmap_info(&self, idx: u32) -> Option<virtio_snd_chmap_info>273     pub fn chmap_info(&self, idx: u32) -> Option<virtio_snd_chmap_info> {
274         self.chmaps.get(idx as usize).copied()
275     }
276 
277     /// Starts the background thread that receives release messages from the server. If the thread
278     /// was already started this function does nothing.
279     /// This thread must be started prior to attempting any stream IO operation or the calling
280     /// thread would block.
start_bg_thread(&self) -> Result<()>281     pub fn start_bg_thread(&self) -> Result<()> {
282         if self.recv_thread.lock().is_some() {
283             return Ok(());
284         }
285         let tx_socket = self.tx.try_clone_socket()?;
286         let rx_socket = self.rx.try_clone_socket()?;
287         let event_socket = self
288             .event_socket
289             .try_clone()
290             .map_err(Error::UnixSeqpacketDupError)?;
291         let mut opt = self.recv_thread.lock();
292         // The lock on recv_thread was released above to avoid holding more than one lock at a time
293         // while duplicating the fds. So we have to check the condition again.
294         if opt.is_none() {
295             *opt = Some(spawn_recv_thread(
296                 self.tx_subscribers.clone(),
297                 self.rx_subscribers.clone(),
298                 self.event_notifier
299                     .try_clone()
300                     .map_err(Error::EventDupError)?,
301                 self.events.clone(),
302                 self.recv_thread_state.clone(),
303                 tx_socket,
304                 rx_socket,
305                 event_socket,
306             ));
307         }
308         Ok(())
309     }
310 
311     /// Stops the background thread.
stop_bg_thread(&self) -> Result<()>312     pub fn stop_bg_thread(&self) -> Result<()> {
313         if let Some(recv_thread) = self.recv_thread.lock().take() {
314             recv_thread.stop()?;
315         }
316         Ok(())
317     }
318 
319     /// Gets an Event object that will trigger every time an event is received from the server
get_event_notifier(&self) -> Result<Event>320     pub fn get_event_notifier(&self) -> Result<Event> {
321         // Let the background thread know that there is at least one consumer of events
322         self.recv_thread_state.lock().reporting_events = true;
323         self.event_notifier
324             .try_clone()
325             .map_err(Error::EventDupError)
326     }
327 
328     /// Retrieves one event. Callers should have received a notification through the event notifier
329     /// before calling this function.
pop_event(&self) -> Option<virtio_snd_event>330     pub fn pop_event(&self) -> Option<virtio_snd_event> {
331         self.events.lock().pop_front()
332     }
333 
334     /// Remap a jack. This should only be called if the jack announces support for the operation
335     /// through the features field in the corresponding virtio_snd_jack_info struct.
remap_jack(&self, jack_id: u32, association: u32, sequence: u32) -> Result<()>336     pub fn remap_jack(&self, jack_id: u32, association: u32, sequence: u32) -> Result<()> {
337         if jack_id >= self.config.jacks {
338             return Err(Error::InvalidJackId(jack_id));
339         }
340         let msg = virtio_snd_jack_remap {
341             hdr: virtio_snd_jack_hdr {
342                 hdr: virtio_snd_hdr {
343                     code: VIRTIO_SND_R_JACK_REMAP.into(),
344                 },
345                 jack_id: jack_id.into(),
346             },
347             association: association.into(),
348             sequence: sequence.into(),
349         };
350         let control_socket_lock = self.control_socket.lock();
351         send_cmd(&control_socket_lock, msg)
352     }
353 
354     /// Configures a stream with the given parameters.
set_stream_parameters(&self, stream_id: u32, params: VioSStreamParams) -> Result<()>355     pub fn set_stream_parameters(&self, stream_id: u32, params: VioSStreamParams) -> Result<()> {
356         self.streams
357             .get(stream_id as usize)
358             .ok_or(Error::InvalidStreamId(stream_id))?;
359         let raw_params: virtio_snd_pcm_set_params = (stream_id, params).into();
360         let control_socket_lock = self.control_socket.lock();
361         send_cmd(&control_socket_lock, raw_params)
362     }
363 
364     /// Configures a stream with the given parameters.
set_stream_parameters_raw(&self, raw_params: virtio_snd_pcm_set_params) -> Result<()>365     pub fn set_stream_parameters_raw(&self, raw_params: virtio_snd_pcm_set_params) -> Result<()> {
366         let stream_id = raw_params.hdr.stream_id.to_native();
367         self.streams
368             .get(stream_id as usize)
369             .ok_or(Error::InvalidStreamId(stream_id))?;
370         let control_socket_lock = self.control_socket.lock();
371         send_cmd(&control_socket_lock, raw_params)
372     }
373 
374     /// Send the PREPARE_STREAM command to the server.
prepare_stream(&self, stream_id: u32) -> Result<()>375     pub fn prepare_stream(&self, stream_id: u32) -> Result<()> {
376         self.common_stream_op(stream_id, VIRTIO_SND_R_PCM_PREPARE)
377     }
378 
379     /// Send the RELEASE_STREAM command to the server.
release_stream(&self, stream_id: u32) -> Result<()>380     pub fn release_stream(&self, stream_id: u32) -> Result<()> {
381         self.common_stream_op(stream_id, VIRTIO_SND_R_PCM_RELEASE)
382     }
383 
384     /// Send the START_STREAM command to the server.
start_stream(&self, stream_id: u32) -> Result<()>385     pub fn start_stream(&self, stream_id: u32) -> Result<()> {
386         self.common_stream_op(stream_id, VIRTIO_SND_R_PCM_START)
387     }
388 
389     /// Send the STOP_STREAM command to the server.
stop_stream(&self, stream_id: u32) -> Result<()>390     pub fn stop_stream(&self, stream_id: u32) -> Result<()> {
391         self.common_stream_op(stream_id, VIRTIO_SND_R_PCM_STOP)
392     }
393 
394     /// Send audio frames to the server. Blocks the calling thread until the server acknowledges
395     /// the data.
inject_audio_data<R, Cb: FnOnce(VolatileSlice) -> R>( &self, stream_id: u32, size: usize, callback: Cb, ) -> Result<(u32, R)>396     pub fn inject_audio_data<R, Cb: FnOnce(VolatileSlice) -> R>(
397         &self,
398         stream_id: u32,
399         size: usize,
400         callback: Cb,
401     ) -> Result<(u32, R)> {
402         if self
403             .streams
404             .get(stream_id as usize)
405             .ok_or(Error::InvalidStreamId(stream_id))?
406             .direction
407             != VIRTIO_SND_D_OUTPUT
408         {
409             return Err(Error::WrongDirection(VIRTIO_SND_D_OUTPUT));
410         }
411         self.streams
412             .get(stream_id as usize)
413             .ok_or(Error::InvalidStreamId(stream_id))?;
414         let dst_offset = self.tx.allocate_buffer(size)?;
415         let buffer_slice = self.tx.buffer_at(dst_offset, size)?;
416         let ret = callback(buffer_slice);
417         // Register to receive the status before sending the buffer to the server
418         let (sender, receiver): (Sender<BufferReleaseMsg>, Receiver<BufferReleaseMsg>) = channel();
419         self.tx_subscribers.lock().insert(dst_offset, sender);
420         self.tx.send_buffer(stream_id, dst_offset, size)?;
421         let (_, latency) = await_status(receiver)?;
422         Ok((latency, ret))
423     }
424 
425     /// Request audio frames from the server. It blocks until the data is available.
request_audio_data<R, Cb: FnOnce(&VolatileSlice) -> R>( &self, stream_id: u32, size: usize, callback: Cb, ) -> Result<(u32, R)>426     pub fn request_audio_data<R, Cb: FnOnce(&VolatileSlice) -> R>(
427         &self,
428         stream_id: u32,
429         size: usize,
430         callback: Cb,
431     ) -> Result<(u32, R)> {
432         if self
433             .streams
434             .get(stream_id as usize)
435             .ok_or(Error::InvalidStreamId(stream_id))?
436             .direction
437             != VIRTIO_SND_D_INPUT
438         {
439             return Err(Error::WrongDirection(VIRTIO_SND_D_INPUT));
440         }
441         let src_offset = self.rx.allocate_buffer(size)?;
442         // Register to receive the status before sending the buffer to the server
443         let (sender, receiver): (Sender<BufferReleaseMsg>, Receiver<BufferReleaseMsg>) = channel();
444         self.rx_subscribers.lock().insert(src_offset, sender);
445         self.rx.send_buffer(stream_id, src_offset, size)?;
446         // Make sure no mutexes are held while awaiting for the buffer to be written to
447         let (recv_size, latency) = await_status(receiver)?;
448         let buffer_slice = self.rx.buffer_at(src_offset, recv_size)?;
449         Ok((latency, callback(&buffer_slice)))
450     }
451 
452     /// Get a list of file descriptors used by the implementation.
keep_rds(&self) -> Vec<RawDescriptor>453     pub fn keep_rds(&self) -> Vec<RawDescriptor> {
454         let control_desc = self.control_socket.lock().as_raw_descriptor();
455         let event_desc = self.event_socket.as_raw_descriptor();
456         let event_notifier = self.event_notifier.as_raw_descriptor();
457         let mut ret = vec![control_desc, event_desc, event_notifier];
458         ret.append(&mut self.tx.keep_rds());
459         ret.append(&mut self.rx.keep_rds());
460         ret
461     }
462 
common_stream_op(&self, stream_id: u32, op: u32) -> Result<()>463     fn common_stream_op(&self, stream_id: u32, op: u32) -> Result<()> {
464         self.streams
465             .get(stream_id as usize)
466             .ok_or(Error::InvalidStreamId(stream_id))?;
467         let msg = virtio_snd_pcm_hdr {
468             hdr: virtio_snd_hdr { code: op.into() },
469             stream_id: stream_id.into(),
470         };
471         let control_socket_lock = self.control_socket.lock();
472         send_cmd(&control_socket_lock, msg)
473     }
474 
request_and_cache_info(&mut self) -> Result<()>475     fn request_and_cache_info(&mut self) -> Result<()> {
476         self.request_and_cache_jacks_info()?;
477         self.request_and_cache_streams_info()?;
478         self.request_and_cache_chmaps_info()?;
479         Ok(())
480     }
481 
request_info<T: AsBytes + FromBytes + Default + Copy + Clone>( &self, req_code: u32, count: usize, ) -> Result<Vec<T>>482     fn request_info<T: AsBytes + FromBytes + Default + Copy + Clone>(
483         &self,
484         req_code: u32,
485         count: usize,
486     ) -> Result<Vec<T>> {
487         let info_size = std::mem::size_of::<T>();
488         let status_size = std::mem::size_of::<virtio_snd_hdr>();
489         let req = virtio_snd_query_info {
490             hdr: virtio_snd_hdr {
491                 code: req_code.into(),
492             },
493             start_id: 0u32.into(),
494             count: (count as u32).into(),
495             size: (std::mem::size_of::<virtio_snd_query_info>() as u32).into(),
496         };
497         let control_socket_lock = self.control_socket.lock();
498         seq_socket_send(&control_socket_lock, req)?;
499         let reply = control_socket_lock
500             .recv_as_vec()
501             .map_err(Error::ServerIOError)?;
502         let mut status: virtio_snd_hdr = Default::default();
503         status
504             .as_bytes_mut()
505             .copy_from_slice(&reply[0..status_size]);
506         if status.code.to_native() != VIRTIO_SND_S_OK {
507             return Err(Error::CommandFailed(status.code.to_native()));
508         }
509         if reply.len() != status_size + count * info_size {
510             return Err(Error::ProtocolError(
511                 ProtocolErrorKind::UnexpectedMessageSize(count * info_size, reply.len()),
512             ));
513         }
514         Ok(reply[status_size..]
515             .chunks(info_size)
516             .map(|info_buffer| T::read_from(info_buffer).unwrap())
517             .collect())
518     }
519 
request_and_cache_jacks_info(&mut self) -> Result<()>520     fn request_and_cache_jacks_info(&mut self) -> Result<()> {
521         let num_jacks = self.config.jacks as usize;
522         if num_jacks == 0 {
523             return Ok(());
524         }
525         self.jacks = self.request_info(VIRTIO_SND_R_JACK_INFO, num_jacks)?;
526         Ok(())
527     }
528 
request_and_cache_streams_info(&mut self) -> Result<()>529     fn request_and_cache_streams_info(&mut self) -> Result<()> {
530         let num_streams = self.config.streams as usize;
531         if num_streams == 0 {
532             return Ok(());
533         }
534         self.streams = self.request_info(VIRTIO_SND_R_PCM_INFO, num_streams)?;
535         Ok(())
536     }
537 
request_and_cache_chmaps_info(&mut self) -> Result<()>538     fn request_and_cache_chmaps_info(&mut self) -> Result<()> {
539         let num_chmaps = self.config.chmaps as usize;
540         if num_chmaps == 0 {
541             return Ok(());
542         }
543         self.chmaps = self.request_info(VIRTIO_SND_R_CHMAP_INFO, num_chmaps)?;
544         Ok(())
545     }
546 }
547 
/// State shared between the client and the background receive thread.
#[derive(Clone, Copy)]
struct ThreadFlags {
    /// Set by `get_event_notifier`; while false the receive thread drops incoming events
    /// instead of queuing them.
    reporting_events: bool,
}
552 
/// Identifies which descriptor woke the receive thread's WaitContext.
#[derive(EventToken)]
enum Token {
    /// The worker thread's own event was signaled; ends the receive loop.
    Notification,
    /// A buffer status message is readable on the tx socket.
    TxBufferMsg,
    /// A buffer status message is readable on the rx socket.
    RxBufferMsg,
    /// A `virtio_snd_event` is readable on the event socket.
    EventMsg,
}
560 
recv_buffer_status_msg( socket: &UnixSeqpacket, subscribers: &Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>>, ) -> Result<()>561 fn recv_buffer_status_msg(
562     socket: &UnixSeqpacket,
563     subscribers: &Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>>,
564 ) -> Result<()> {
565     let mut msg: IoStatusMsg = Default::default();
566     let size = socket
567         .recv(msg.as_bytes_mut())
568         .map_err(Error::ServerIOError)?;
569     if size != std::mem::size_of::<IoStatusMsg>() {
570         return Err(Error::ProtocolError(
571             ProtocolErrorKind::UnexpectedMessageSize(std::mem::size_of::<IoStatusMsg>(), size),
572         ));
573     }
574     let mut status = msg.status.status.into();
575     if status == u32::MAX {
576         // Anyone waiting for this would continue to wait for as long as status is
577         // u32::MAX
578         status -= 1;
579     }
580     let latency = msg.status.latency_bytes.into();
581     let offset = msg.buffer_offset as usize;
582     let consumed_len = msg.consumed_len as usize;
583     let promise_opt = subscribers.lock().remove(&offset);
584     match promise_opt {
585         None => error!(
586             "Received an unexpected buffer status message: {}. This is a BUG!!",
587             offset
588         ),
589         Some(sender) => {
590             if let Err(e) = sender.send(BufferReleaseMsg {
591                 status,
592                 latency,
593                 consumed_len,
594             }) {
595                 error!("Failed to notify waiting thread: {:?}", e);
596             }
597         }
598     }
599     Ok(())
600 }
601 
recv_event(socket: &UnixSeqpacket) -> Result<virtio_snd_event>602 fn recv_event(socket: &UnixSeqpacket) -> Result<virtio_snd_event> {
603     let mut msg: virtio_snd_event = Default::default();
604     let size = socket
605         .recv(msg.as_bytes_mut())
606         .map_err(Error::ServerIOError)?;
607     if size != std::mem::size_of::<virtio_snd_event>() {
608         return Err(Error::ProtocolError(
609             ProtocolErrorKind::UnexpectedMessageSize(std::mem::size_of::<virtio_snd_event>(), size),
610         ));
611     }
612     Ok(msg)
613 }
614 
spawn_recv_thread( tx_subscribers: Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>>, rx_subscribers: Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>>, event_notifier: Event, event_queue: Arc<Mutex<VecDeque<virtio_snd_event>>>, state: Arc<Mutex<ThreadFlags>>, tx_socket: UnixSeqpacket, rx_socket: UnixSeqpacket, event_socket: UnixSeqpacket, ) -> WorkerThread<Result<()>>615 fn spawn_recv_thread(
616     tx_subscribers: Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>>,
617     rx_subscribers: Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>>,
618     event_notifier: Event,
619     event_queue: Arc<Mutex<VecDeque<virtio_snd_event>>>,
620     state: Arc<Mutex<ThreadFlags>>,
621     tx_socket: UnixSeqpacket,
622     rx_socket: UnixSeqpacket,
623     event_socket: UnixSeqpacket,
624 ) -> WorkerThread<Result<()>> {
625     WorkerThread::start("shm_vios", move |event| {
626         let wait_ctx: WaitContext<Token> = WaitContext::build_with(&[
627             (&tx_socket, Token::TxBufferMsg),
628             (&rx_socket, Token::RxBufferMsg),
629             (&event_socket, Token::EventMsg),
630             (&event, Token::Notification),
631         ])
632         .map_err(Error::WaitContextCreateError)?;
633         let mut running = true;
634         while running {
635             let events = wait_ctx.wait().map_err(Error::WaitError)?;
636             for evt in events {
637                 match evt.token {
638                     Token::TxBufferMsg => recv_buffer_status_msg(&tx_socket, &tx_subscribers)?,
639                     Token::RxBufferMsg => recv_buffer_status_msg(&rx_socket, &rx_subscribers)?,
640                     Token::EventMsg => {
641                         let evt = recv_event(&event_socket)?;
642                         let state_cpy = *state.lock();
643                         if state_cpy.reporting_events {
644                             event_queue.lock().push_back(evt);
645                             event_notifier.signal().map_err(Error::EventWriteError)?;
646                         } // else just drop the events
647                     }
648                     Token::Notification => {
649                         // Just consume the notification and check for termination on the next
650                         // iteration
651                         if let Err(e) = event.wait() {
652                             error!("Failed to consume notification from recv thread: {:?}", e);
653                         }
654                         running = false;
655                     }
656                 }
657             }
658         }
659         Ok(())
660     })
661 }
662 
await_status(promise: Receiver<BufferReleaseMsg>) -> Result<(usize, u32)>663 fn await_status(promise: Receiver<BufferReleaseMsg>) -> Result<(usize, u32)> {
664     let BufferReleaseMsg {
665         status,
666         latency,
667         consumed_len,
668     } = promise.recv().map_err(Error::BufferStatusSenderLost)?;
669     if status == VIRTIO_SND_S_OK {
670         Ok((consumed_len, latency))
671     } else {
672         Err(Error::IOBufferError(status))
673     }
674 }
675 
/// A shared-memory buffer area plus the socket used to hand buffers in it to the server.
struct IoBufferQueue {
    /// Socket over which buffer transfer messages are sent; a clone of it is read by the
    /// receive thread for buffer status messages.
    socket: UnixSeqpacket,
    /// The shared memory file backing `mmap`.
    file: File,
    /// Mapping of the whole of `file`.
    mmap: MemoryMapping,
    /// Length in bytes of `file` (and of `mmap`).
    size: usize,
    /// Offset at which the next buffer is placed if it fits before the end of the area.
    next: Mutex<usize>,
}
683 
684 impl IoBufferQueue {
new(socket: UnixSeqpacket, mut file: File) -> Result<IoBufferQueue>685     fn new(socket: UnixSeqpacket, mut file: File) -> Result<IoBufferQueue> {
686         let size = file.seek(SeekFrom::End(0)).map_err(Error::FileSizeError)? as usize;
687 
688         let mmap = MemoryMappingBuilder::new(size)
689             .from_file(&file)
690             .build()
691             .map_err(Error::ServerMmapError)?;
692 
693         Ok(IoBufferQueue {
694             socket,
695             file,
696             mmap,
697             size,
698             next: Mutex::new(0),
699         })
700     }
701 
allocate_buffer(&self, size: usize) -> Result<usize>702     fn allocate_buffer(&self, size: usize) -> Result<usize> {
703         if size > self.size {
704             return Err(Error::OutOfSpace);
705         }
706         let mut next_lock = self.next.lock();
707         let offset = if size > self.size - *next_lock {
708             // Can't fit the new buffer at the end of the area, so put it at the beginning
709             0
710         } else {
711             *next_lock
712         };
713         *next_lock = offset + size;
714         Ok(offset)
715     }
716 
buffer_at(&self, offset: usize, len: usize) -> Result<VolatileSlice>717     fn buffer_at(&self, offset: usize, len: usize) -> Result<VolatileSlice> {
718         self.mmap
719             .get_slice(offset, len)
720             .map_err(Error::VolatileMemoryError)
721     }
722 
try_clone_socket(&self) -> Result<UnixSeqpacket>723     fn try_clone_socket(&self) -> Result<UnixSeqpacket> {
724         self.socket
725             .try_clone()
726             .map_err(Error::UnixSeqpacketDupError)
727     }
728 
send_buffer(&self, stream_id: u32, offset: usize, size: usize) -> Result<()>729     fn send_buffer(&self, stream_id: u32, offset: usize, size: usize) -> Result<()> {
730         let msg = IoTransferMsg::new(stream_id, offset, size);
731         seq_socket_send(&self.socket, msg)
732     }
733 
keep_rds(&self) -> Vec<RawDescriptor>734     fn keep_rds(&self) -> Vec<RawDescriptor> {
735         vec![
736             self.file.as_raw_descriptor(),
737             self.socket.as_raw_descriptor(),
738         ]
739     }
740 }
741 
/// Settings that must be applied to a stream before the stream can be used.
///
/// Together with a stream id, these are turned into a `virtio_snd_pcm_set_params`
/// request through `From<(u32, VioSStreamParams)>`.
pub struct VioSStreamParams {
    /// Copied into the request's `buffer_bytes` field.
    pub buffer_bytes: u32,
    /// Copied into the request's `period_bytes` field.
    pub period_bytes: u32,
    /// Copied into the request's `features` field.
    pub features: u32,
    /// Copied into the request's `channels` field.
    pub channels: u8,
    /// Copied into the request's `format` field (presumably a virtio-snd format code; confirm).
    pub format: u8,
    /// Copied into the request's `rate` field (presumably a virtio-snd rate code; confirm).
    pub rate: u8,
}
751 
752 impl From<(u32, VioSStreamParams)> for virtio_snd_pcm_set_params {
from(val: (u32, VioSStreamParams)) -> Self753     fn from(val: (u32, VioSStreamParams)) -> Self {
754         virtio_snd_pcm_set_params {
755             hdr: virtio_snd_pcm_hdr {
756                 hdr: virtio_snd_hdr {
757                     code: VIRTIO_SND_R_PCM_SET_PARAMS.into(),
758                 },
759                 stream_id: val.0.into(),
760             },
761             buffer_bytes: val.1.buffer_bytes.into(),
762             period_bytes: val.1.period_bytes.into(),
763             features: val.1.features.into(),
764             channels: val.1.channels,
765             format: val.1.format,
766             rate: val.1.rate,
767             padding: 0u8,
768         }
769     }
770 }
771 
send_cmd<T: AsBytes>(control_socket: &UnixSeqpacket, data: T) -> Result<()>772 fn send_cmd<T: AsBytes>(control_socket: &UnixSeqpacket, data: T) -> Result<()> {
773     seq_socket_send(control_socket, data)?;
774     recv_cmd_status(control_socket)
775 }
776 
recv_cmd_status(control_socket: &UnixSeqpacket) -> Result<()>777 fn recv_cmd_status(control_socket: &UnixSeqpacket) -> Result<()> {
778     let mut status: virtio_snd_hdr = Default::default();
779     control_socket
780         .recv(status.as_bytes_mut())
781         .map_err(Error::ServerIOError)?;
782     if status.code.to_native() == VIRTIO_SND_S_OK {
783         Ok(())
784     } else {
785         Err(Error::CommandFailed(status.code.to_native()))
786     }
787 }
788 
seq_socket_send<T: AsBytes>(socket: &UnixSeqpacket, data: T) -> Result<()>789 fn seq_socket_send<T: AsBytes>(socket: &UnixSeqpacket, data: T) -> Result<()> {
790     loop {
791         let send_res = socket.send(data.as_bytes());
792         if let Err(e) = send_res {
793             match e.kind() {
794                 // Retry if interrupted
795                 IOErrorKind::Interrupted => continue,
796                 _ => return Err(Error::ServerIOError(e)),
797             }
798         }
799         // Success
800         break;
801     }
802     Ok(())
803 }
804 
/// Version number of the VIOS protocol spoken by this client.
///
/// NOTE(review): presumably compared against `VioSConfig::version` from the server; confirm.
const VIOS_VERSION: u32 = 2;
806 
807 #[repr(C)]
808 #[derive(Copy, Clone, Default, AsBytes, FromBytes)]
809 struct VioSConfig {
810     version: u32,
811     jacks: u32,
812     streams: u32,
813     chmaps: u32,
814 }
815 
/// Completion report for a buffer, delivered through a channel to the waiting caller.
///
/// A `status` equal to `VIRTIO_SND_S_OK` means success. In that case the caller receives
/// `(consumed_len, latency)`; any other status becomes `Error::IOBufferError`.
struct BufferReleaseMsg {
    /// virtio-snd status code for the transfer.
    status: u32,
    /// Latency value reported alongside the completion.
    latency: u32,
    /// Number of bytes of the buffer that were consumed.
    consumed_len: usize,
}
821 
822 #[repr(C)]
823 #[derive(Copy, Clone, AsBytes, FromBytes)]
824 struct IoTransferMsg {
825     io_xfer: virtio_snd_pcm_xfer,
826     buffer_offset: u32,
827     buffer_len: u32,
828 }
829 
830 impl IoTransferMsg {
new(stream_id: u32, buffer_offset: usize, buffer_len: usize) -> IoTransferMsg831     fn new(stream_id: u32, buffer_offset: usize, buffer_len: usize) -> IoTransferMsg {
832         IoTransferMsg {
833             io_xfer: virtio_snd_pcm_xfer {
834                 stream_id: stream_id.into(),
835             },
836             buffer_offset: buffer_offset as u32,
837             buffer_len: buffer_len as u32,
838         }
839     }
840 }
841 
842 #[repr(C)]
843 #[derive(Copy, Clone, Default, AsBytes, FromBytes)]
844 struct IoStatusMsg {
845     status: virtio_snd_pcm_status,
846     buffer_offset: u32,
847     consumed_len: u32,
848 }
849