1 // Copyright 2020 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 use std::collections::HashMap;
6 use std::collections::VecDeque;
7 use std::fs::File;
8 use std::io::Error as IOError;
9 use std::io::ErrorKind as IOErrorKind;
10 use std::io::Seek;
11 use std::io::SeekFrom;
12 use std::path::Path;
13 use std::path::PathBuf;
14 use std::sync::mpsc::channel;
15 use std::sync::mpsc::Receiver;
16 use std::sync::mpsc::RecvError;
17 use std::sync::mpsc::Sender;
18 use std::sync::Arc;
19
20 use base::error;
21 use base::AsRawDescriptor;
22 use base::Error as BaseError;
23 use base::Event;
24 use base::EventToken;
25 use base::FromRawDescriptor;
26 use base::IntoRawDescriptor;
27 use base::MemoryMapping;
28 use base::MemoryMappingBuilder;
29 use base::MmapError;
30 use base::RawDescriptor;
31 use base::SafeDescriptor;
32 use base::ScmSocket;
33 use base::UnixSeqpacket;
34 use base::VolatileMemory;
35 use base::VolatileMemoryError;
36 use base::VolatileSlice;
37 use base::WaitContext;
38 use base::WorkerThread;
39 use remain::sorted;
40 use serde::Deserialize;
41 use serde::Serialize;
42 use sync::Mutex;
43 use thiserror::Error as ThisError;
44 use zerocopy::FromBytes;
45 use zerocopy::Immutable;
46 use zerocopy::IntoBytes;
47 use zerocopy::KnownLayout;
48
49 use crate::virtio::snd::constants::*;
50 use crate::virtio::snd::layout::*;
51 use crate::virtio::snd::vios_backend::streams::StreamState;
52
53 pub type Result<T> = std::result::Result<T, Error>;
54
55 #[sorted]
56 #[derive(ThisError, Debug)]
57 pub enum Error {
58 #[error("Error memory mapping client_shm: {0}")]
59 BaseMmapError(BaseError),
60 #[error("Sender was dropped without sending buffer status, the recv thread may have exited")]
61 BufferStatusSenderLost(RecvError),
62 #[error("Command failed with status {0}")]
63 CommandFailed(u32),
64 #[error("Error duplicating file descriptor: {0}")]
65 DupError(BaseError),
66 #[error("Failed to create Recv event: {0}")]
67 EventCreateError(BaseError),
68 #[error("Failed to dup Recv event: {0}")]
69 EventDupError(BaseError),
70 #[error("Failed to signal event: {0}")]
71 EventWriteError(BaseError),
72 #[error("Failed to get size of tx shared memory: {0}")]
73 FileSizeError(IOError),
74 #[error("Error accessing guest's shared memory: {0}")]
75 GuestMmapError(MmapError),
76 #[error("No jack with id {0}")]
77 InvalidJackId(u32),
78 #[error("No stream with id {0}")]
79 InvalidStreamId(u32),
80 #[error("IO buffer operation failed: status = {0}")]
81 IOBufferError(u32),
82 #[error("No PCM streams available")]
83 NoStreamsAvailable,
84 #[error("Insuficient space for the new buffer in the queue's buffer area")]
85 OutOfSpace,
86 #[error("Platform not supported")]
87 PlatformNotSupported,
88 #[error("{0}")]
89 ProtocolError(ProtocolErrorKind),
90 #[error("Failed to connect to VioS server {1}: {0:?}")]
91 ServerConnectionError(IOError, PathBuf),
92 #[error("Failed to communicate with VioS server: {0:?}")]
93 ServerError(IOError),
94 #[error("Failed to communicate with VioS server: {0:?}")]
95 ServerIOError(IOError),
96 #[error("Error accessing VioS server's shared memory: {0}")]
97 ServerMmapError(MmapError),
98 #[error("Failed to duplicate UnixSeqpacket: {0}")]
99 UnixSeqpacketDupError(IOError),
100 #[error("Unsupported frame rate: {0}")]
101 UnsupportedFrameRate(u32),
102 #[error("Error accessing volatile memory: {0}")]
103 VolatileMemoryError(VolatileMemoryError),
104 #[error("Failed to create Recv thread's WaitContext: {0}")]
105 WaitContextCreateError(BaseError),
106 #[error("Error waiting for events")]
107 WaitError(BaseError),
108 #[error("Invalid operation for stream direction: {0}")]
109 WrongDirection(u8),
110 #[error("Set saved params should only be used while restoring the device")]
111 WrongSetParams,
112 }
113
/// Ways in which a message received from the server can violate the expected protocol. Wrapped
/// by `Error::ProtocolError`, which displays the message of the inner kind unchanged.
#[derive(ThisError, Debug)]
pub enum ProtocolErrorKind {
    /// The configuration message did not have the size of `VioSConfig`; carries the received
    /// size.
    #[error("The server sent a config of the wrong size: {0}")]
    UnexpectedConfigSize(usize),
    /// The number of descriptors attached to the configuration message was wrong.
    #[error("Received {1} file descriptors from the server, expected {0}")]
    UnexpectedNumberOfFileDescriptors(usize, usize), // expected, received
    /// The version in the server's configuration differs from the client's; carries the server's
    /// version.
    #[error("Server's version ({0}) doesn't match client's")]
    VersionMismatch(u32),
    /// A message other than the configuration had an unexpected length.
    #[error("Received a msg with an unexpected size: expected {0}, received {1}")]
    UnexpectedMessageSize(usize, usize), // expected, received
}
125
/// The client for the VioS backend
///
/// Uses a protocol equivalent to virtio-snd over a shared memory file and a unix socket for
/// notifications. It's thread safe, it can be encapsulated in an Arc smart pointer and shared
/// between threads.
pub struct VioSClient {
    // Several fields below are wrapped in mutexes. These mutexes should almost never be held
    // simultaneously. If at some point they have to, the locking order should match the order in
    // which they are declared here.
    //
    // Configuration received from the server when the connection was established.
    config: VioSConfig,
    // Item descriptions cached from the server's replies to the info requests.
    jacks: Vec<virtio_snd_jack_info>,
    streams: Vec<virtio_snd_pcm_info>,
    chmaps: Vec<virtio_snd_chmap_info>,
    // The control socket is used from multiple threads to send and wait for a reply, which needs
    // to happen atomically, hence the need for a mutex instead of just sharing clones of the
    // socket.
    control_socket: Mutex<UnixSeqpacket>,
    // Socket on which the recv thread receives virtio_snd_event messages.
    event_socket: UnixSeqpacket,
    // These are thread safe and don't require locking
    tx: IoBufferQueue,
    rx: IoBufferQueue,
    // This is accessed by the recv_thread and whatever thread processes the events
    events: Arc<Mutex<VecDeque<virtio_snd_event>>>,
    // Signaled by the recv thread each time it queues an event in `events`.
    event_notifier: Event,
    // These are accessed by the recv_thread and the stream threads. They map a buffer offset to
    // the channel on which the thread waiting for that buffer's status gets notified.
    tx_subscribers: Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>>,
    rx_subscribers: Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>>,
    // Flags shared with the recv thread (e.g. whether received events should be queued).
    recv_thread_state: Arc<Mutex<ThreadFlags>>,
    // Handle of the recv thread; None while the thread is not running.
    recv_thread: Mutex<Option<WorkerThread<Result<()>>>>,
    // Params are required to be stored for snapshot/restore. On restore, we don't have the params
    // locally available as the VM is started anew, so they need to be restored.
    params: HashMap<u32, virtio_snd_pcm_set_params>,
}
158
/// Serializable copy of the `VioSClient` state that takes part in snapshot/restore: the server
/// configuration, the cached jack/stream/chmap descriptions and the stream parameters saved per
/// stream id. Sockets, shared memory and the recv thread are not part of it.
#[derive(Serialize, Deserialize)]
pub struct VioSClientSnapshot {
    config: VioSConfig,
    jacks: Vec<virtio_snd_jack_info>,
    streams: Vec<virtio_snd_pcm_info>,
    chmaps: Vec<virtio_snd_chmap_info>,
    // Keyed by stream id.
    params: HashMap<u32, virtio_snd_pcm_set_params>,
}
167
168 impl VioSClient {
169 /// Create a new client given the path to the audio server's socket.
try_new<P: AsRef<Path>>(server: P) -> Result<VioSClient>170 pub fn try_new<P: AsRef<Path>>(server: P) -> Result<VioSClient> {
171 let client_socket = ScmSocket::try_from(
172 UnixSeqpacket::connect(server.as_ref())
173 .map_err(|e| Error::ServerConnectionError(e, server.as_ref().into()))?,
174 )
175 .map_err(|e| Error::ServerConnectionError(e, server.as_ref().into()))?;
176 let mut config: VioSConfig = Default::default();
177 const NUM_FDS: usize = 5;
178 let (recv_size, mut safe_fds) = client_socket
179 .recv_with_fds(config.as_mut_bytes(), NUM_FDS)
180 .map_err(Error::ServerError)?;
181
182 if recv_size != std::mem::size_of::<VioSConfig>() {
183 return Err(Error::ProtocolError(
184 ProtocolErrorKind::UnexpectedConfigSize(recv_size),
185 ));
186 }
187
188 if config.version != VIOS_VERSION {
189 return Err(Error::ProtocolError(ProtocolErrorKind::VersionMismatch(
190 config.version,
191 )));
192 }
193
194 fn pop<T: FromRawDescriptor>(
195 safe_fds: &mut Vec<SafeDescriptor>,
196 expected: usize,
197 received: usize,
198 ) -> Result<T> {
199 // SAFETY:
200 // Safe because we transfer ownership from the SafeDescriptor to T
201 unsafe {
202 Ok(T::from_raw_descriptor(
203 safe_fds
204 .pop()
205 .ok_or(Error::ProtocolError(
206 ProtocolErrorKind::UnexpectedNumberOfFileDescriptors(
207 expected, received,
208 ),
209 ))?
210 .into_raw_descriptor(),
211 ))
212 }
213 }
214
215 let fd_count = safe_fds.len();
216 let rx_shm_file = pop::<File>(&mut safe_fds, NUM_FDS, fd_count)?;
217 let tx_shm_file = pop::<File>(&mut safe_fds, NUM_FDS, fd_count)?;
218 let rx_socket = pop::<UnixSeqpacket>(&mut safe_fds, NUM_FDS, fd_count)?;
219 let tx_socket = pop::<UnixSeqpacket>(&mut safe_fds, NUM_FDS, fd_count)?;
220 let event_socket = pop::<UnixSeqpacket>(&mut safe_fds, NUM_FDS, fd_count)?;
221
222 if !safe_fds.is_empty() {
223 return Err(Error::ProtocolError(
224 ProtocolErrorKind::UnexpectedNumberOfFileDescriptors(NUM_FDS, fd_count),
225 ));
226 }
227
228 let tx_subscribers: Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>> =
229 Arc::new(Mutex::new(HashMap::new()));
230 let rx_subscribers: Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>> =
231 Arc::new(Mutex::new(HashMap::new()));
232 let recv_thread_state = Arc::new(Mutex::new(ThreadFlags {
233 reporting_events: false,
234 }));
235
236 let mut client = VioSClient {
237 config,
238 jacks: Vec::new(),
239 streams: Vec::new(),
240 chmaps: Vec::new(),
241 control_socket: Mutex::new(client_socket.into_inner()),
242 event_socket,
243 tx: IoBufferQueue::new(tx_socket, tx_shm_file)?,
244 rx: IoBufferQueue::new(rx_socket, rx_shm_file)?,
245 events: Arc::new(Mutex::new(VecDeque::new())),
246 event_notifier: Event::new().map_err(Error::EventCreateError)?,
247 tx_subscribers,
248 rx_subscribers,
249 recv_thread_state,
250 recv_thread: Mutex::new(None),
251 params: HashMap::new(),
252 };
253 client.request_and_cache_info()?;
254 Ok(client)
255 }
256
257 /// Get the number of jacks
num_jacks(&self) -> u32258 pub fn num_jacks(&self) -> u32 {
259 self.config.jacks
260 }
261
262 /// Get the number of pcm streams
num_streams(&self) -> u32263 pub fn num_streams(&self) -> u32 {
264 self.config.streams
265 }
266
267 /// Get the number of channel maps
num_chmaps(&self) -> u32268 pub fn num_chmaps(&self) -> u32 {
269 self.config.chmaps
270 }
271
272 /// Get the configuration information on a jack
jack_info(&self, idx: u32) -> Option<virtio_snd_jack_info>273 pub fn jack_info(&self, idx: u32) -> Option<virtio_snd_jack_info> {
274 self.jacks.get(idx as usize).copied()
275 }
276
277 /// Get the configuration information on a pcm stream
stream_info(&self, idx: u32) -> Option<virtio_snd_pcm_info>278 pub fn stream_info(&self, idx: u32) -> Option<virtio_snd_pcm_info> {
279 self.streams.get(idx as usize).cloned()
280 }
281
282 /// Get the configuration information on a channel map
chmap_info(&self, idx: u32) -> Option<virtio_snd_chmap_info>283 pub fn chmap_info(&self, idx: u32) -> Option<virtio_snd_chmap_info> {
284 self.chmaps.get(idx as usize).copied()
285 }
286
287 /// Starts the background thread that receives release messages from the server. If the thread
288 /// was already started this function does nothing.
289 /// This thread must be started prior to attempting any stream IO operation or the calling
290 /// thread would block.
start_bg_thread(&self) -> Result<()>291 pub fn start_bg_thread(&self) -> Result<()> {
292 if self.recv_thread.lock().is_some() {
293 return Ok(());
294 }
295 let tx_socket = self.tx.try_clone_socket()?;
296 let rx_socket = self.rx.try_clone_socket()?;
297 let event_socket = self
298 .event_socket
299 .try_clone()
300 .map_err(Error::UnixSeqpacketDupError)?;
301 let mut opt = self.recv_thread.lock();
302 // The lock on recv_thread was released above to avoid holding more than one lock at a time
303 // while duplicating the fds. So we have to check the condition again.
304 if opt.is_none() {
305 *opt = Some(spawn_recv_thread(
306 self.tx_subscribers.clone(),
307 self.rx_subscribers.clone(),
308 self.event_notifier
309 .try_clone()
310 .map_err(Error::EventDupError)?,
311 self.events.clone(),
312 self.recv_thread_state.clone(),
313 tx_socket,
314 rx_socket,
315 event_socket,
316 ));
317 }
318 Ok(())
319 }
320
321 /// Stops the background thread.
stop_bg_thread(&self) -> Result<()>322 pub fn stop_bg_thread(&self) -> Result<()> {
323 if let Some(recv_thread) = self.recv_thread.lock().take() {
324 recv_thread.stop()?;
325 }
326 Ok(())
327 }
328
329 /// Gets an Event object that will trigger every time an event is received from the server
get_event_notifier(&self) -> Result<Event>330 pub fn get_event_notifier(&self) -> Result<Event> {
331 // Let the background thread know that there is at least one consumer of events
332 self.recv_thread_state.lock().reporting_events = true;
333 self.event_notifier
334 .try_clone()
335 .map_err(Error::EventDupError)
336 }
337
338 /// Retrieves one event. Callers should have received a notification through the event notifier
339 /// before calling this function.
pop_event(&self) -> Option<virtio_snd_event>340 pub fn pop_event(&self) -> Option<virtio_snd_event> {
341 self.events.lock().pop_front()
342 }
343
344 /// Remap a jack. This should only be called if the jack announces support for the operation
345 /// through the features field in the corresponding virtio_snd_jack_info struct.
remap_jack(&self, jack_id: u32, association: u32, sequence: u32) -> Result<()>346 pub fn remap_jack(&self, jack_id: u32, association: u32, sequence: u32) -> Result<()> {
347 if jack_id >= self.config.jacks {
348 return Err(Error::InvalidJackId(jack_id));
349 }
350 let msg = virtio_snd_jack_remap {
351 hdr: virtio_snd_jack_hdr {
352 hdr: virtio_snd_hdr {
353 code: VIRTIO_SND_R_JACK_REMAP.into(),
354 },
355 jack_id: jack_id.into(),
356 },
357 association: association.into(),
358 sequence: sequence.into(),
359 };
360 let control_socket_lock = self.control_socket.lock();
361 send_cmd(&control_socket_lock, msg)
362 }
363
364 /// Configures a stream with the given parameters.
set_stream_parameters( &mut self, stream_id: u32, params: VioSStreamParams, ) -> Result<()>365 pub fn set_stream_parameters(
366 &mut self,
367 stream_id: u32,
368 params: VioSStreamParams,
369 ) -> Result<()> {
370 self.streams
371 .get(stream_id as usize)
372 .ok_or(Error::InvalidStreamId(stream_id))?;
373 let raw_params: virtio_snd_pcm_set_params = (stream_id, params).into();
374 // Old value is not needed and is dropped
375 let _ = self.params.insert(stream_id, raw_params);
376 let control_socket_lock = self.control_socket.lock();
377 send_cmd(&control_socket_lock, raw_params)
378 }
379
380 /// Configures a stream with the given parameters.
set_stream_parameters_raw( &mut self, raw_params: virtio_snd_pcm_set_params, ) -> Result<()>381 pub fn set_stream_parameters_raw(
382 &mut self,
383 raw_params: virtio_snd_pcm_set_params,
384 ) -> Result<()> {
385 let stream_id = raw_params.hdr.stream_id.to_native();
386 // Old value is not needed and is dropped
387 let _ = self.params.insert(stream_id, raw_params);
388 self.streams
389 .get(stream_id as usize)
390 .ok_or(Error::InvalidStreamId(stream_id))?;
391 let control_socket_lock = self.control_socket.lock();
392 send_cmd(&control_socket_lock, raw_params)
393 }
394
395 /// Send the PREPARE_STREAM command to the server.
prepare_stream(&self, stream_id: u32) -> Result<()>396 pub fn prepare_stream(&self, stream_id: u32) -> Result<()> {
397 self.common_stream_op(stream_id, VIRTIO_SND_R_PCM_PREPARE)
398 }
399
400 /// Send the RELEASE_STREAM command to the server.
release_stream(&self, stream_id: u32) -> Result<()>401 pub fn release_stream(&self, stream_id: u32) -> Result<()> {
402 self.common_stream_op(stream_id, VIRTIO_SND_R_PCM_RELEASE)
403 }
404
405 /// Send the START_STREAM command to the server.
start_stream(&self, stream_id: u32) -> Result<()>406 pub fn start_stream(&self, stream_id: u32) -> Result<()> {
407 self.common_stream_op(stream_id, VIRTIO_SND_R_PCM_START)
408 }
409
410 /// Send the STOP_STREAM command to the server.
stop_stream(&self, stream_id: u32) -> Result<()>411 pub fn stop_stream(&self, stream_id: u32) -> Result<()> {
412 self.common_stream_op(stream_id, VIRTIO_SND_R_PCM_STOP)
413 }
414
415 /// Send audio frames to the server. Blocks the calling thread until the server acknowledges
416 /// the data.
inject_audio_data<R, Cb: FnOnce(VolatileSlice) -> R>( &self, stream_id: u32, size: usize, callback: Cb, ) -> Result<(u32, R)>417 pub fn inject_audio_data<R, Cb: FnOnce(VolatileSlice) -> R>(
418 &self,
419 stream_id: u32,
420 size: usize,
421 callback: Cb,
422 ) -> Result<(u32, R)> {
423 if self
424 .streams
425 .get(stream_id as usize)
426 .ok_or(Error::InvalidStreamId(stream_id))?
427 .direction
428 != VIRTIO_SND_D_OUTPUT
429 {
430 return Err(Error::WrongDirection(VIRTIO_SND_D_OUTPUT));
431 }
432 self.streams
433 .get(stream_id as usize)
434 .ok_or(Error::InvalidStreamId(stream_id))?;
435 let dst_offset = self.tx.allocate_buffer(size)?;
436 let buffer_slice = self.tx.buffer_at(dst_offset, size)?;
437 let ret = callback(buffer_slice);
438 // Register to receive the status before sending the buffer to the server
439 let (sender, receiver): (Sender<BufferReleaseMsg>, Receiver<BufferReleaseMsg>) = channel();
440 self.tx_subscribers.lock().insert(dst_offset, sender);
441 self.tx.send_buffer(stream_id, dst_offset, size)?;
442 let (_, latency) = await_status(receiver)?;
443 Ok((latency, ret))
444 }
445
446 /// Request audio frames from the server. It blocks until the data is available.
request_audio_data<R, Cb: FnOnce(&VolatileSlice) -> R>( &self, stream_id: u32, size: usize, callback: Cb, ) -> Result<(u32, R)>447 pub fn request_audio_data<R, Cb: FnOnce(&VolatileSlice) -> R>(
448 &self,
449 stream_id: u32,
450 size: usize,
451 callback: Cb,
452 ) -> Result<(u32, R)> {
453 if self
454 .streams
455 .get(stream_id as usize)
456 .ok_or(Error::InvalidStreamId(stream_id))?
457 .direction
458 != VIRTIO_SND_D_INPUT
459 {
460 return Err(Error::WrongDirection(VIRTIO_SND_D_INPUT));
461 }
462 let src_offset = self.rx.allocate_buffer(size)?;
463 // Register to receive the status before sending the buffer to the server
464 let (sender, receiver): (Sender<BufferReleaseMsg>, Receiver<BufferReleaseMsg>) = channel();
465 self.rx_subscribers.lock().insert(src_offset, sender);
466 self.rx.send_buffer(stream_id, src_offset, size)?;
467 // Make sure no mutexes are held while awaiting for the buffer to be written to
468 let (recv_size, latency) = await_status(receiver)?;
469 let buffer_slice = self.rx.buffer_at(src_offset, recv_size)?;
470 Ok((latency, callback(&buffer_slice)))
471 }
472
473 /// Get a list of file descriptors used by the implementation.
keep_rds(&self) -> Vec<RawDescriptor>474 pub fn keep_rds(&self) -> Vec<RawDescriptor> {
475 let control_desc = self.control_socket.lock().as_raw_descriptor();
476 let event_desc = self.event_socket.as_raw_descriptor();
477 let event_notifier = self.event_notifier.as_raw_descriptor();
478 let mut ret = vec![control_desc, event_desc, event_notifier];
479 ret.append(&mut self.tx.keep_rds());
480 ret.append(&mut self.rx.keep_rds());
481 ret
482 }
483
common_stream_op(&self, stream_id: u32, op: u32) -> Result<()>484 fn common_stream_op(&self, stream_id: u32, op: u32) -> Result<()> {
485 self.streams
486 .get(stream_id as usize)
487 .ok_or(Error::InvalidStreamId(stream_id))?;
488 let msg = virtio_snd_pcm_hdr {
489 hdr: virtio_snd_hdr { code: op.into() },
490 stream_id: stream_id.into(),
491 };
492 let control_socket_lock = self.control_socket.lock();
493 send_cmd(&control_socket_lock, msg)
494 }
495
request_and_cache_info(&mut self) -> Result<()>496 fn request_and_cache_info(&mut self) -> Result<()> {
497 self.request_and_cache_jacks_info()?;
498 self.request_and_cache_streams_info()?;
499 self.request_and_cache_chmaps_info()?;
500 Ok(())
501 }
502
request_info<T: IntoBytes + FromBytes + Default + Copy + Clone>( &self, req_code: u32, count: usize, ) -> Result<Vec<T>>503 fn request_info<T: IntoBytes + FromBytes + Default + Copy + Clone>(
504 &self,
505 req_code: u32,
506 count: usize,
507 ) -> Result<Vec<T>> {
508 let info_size = std::mem::size_of::<T>();
509 let status_size = std::mem::size_of::<virtio_snd_hdr>();
510 let req = virtio_snd_query_info {
511 hdr: virtio_snd_hdr {
512 code: req_code.into(),
513 },
514 start_id: 0u32.into(),
515 count: (count as u32).into(),
516 size: (std::mem::size_of::<virtio_snd_query_info>() as u32).into(),
517 };
518 let control_socket_lock = self.control_socket.lock();
519 seq_socket_send(&control_socket_lock, req.as_bytes())?;
520 let reply = control_socket_lock
521 .recv_as_vec()
522 .map_err(Error::ServerIOError)?;
523 let mut status: virtio_snd_hdr = Default::default();
524 status
525 .as_mut_bytes()
526 .copy_from_slice(&reply[0..status_size]);
527 if status.code.to_native() != VIRTIO_SND_S_OK {
528 return Err(Error::CommandFailed(status.code.to_native()));
529 }
530 if reply.len() != status_size + count * info_size {
531 return Err(Error::ProtocolError(
532 ProtocolErrorKind::UnexpectedMessageSize(count * info_size, reply.len()),
533 ));
534 }
535 Ok(reply[status_size..]
536 .chunks(info_size)
537 .map(|info_buffer| T::read_from_bytes(info_buffer).unwrap())
538 .collect())
539 }
540
request_and_cache_jacks_info(&mut self) -> Result<()>541 fn request_and_cache_jacks_info(&mut self) -> Result<()> {
542 let num_jacks = self.config.jacks as usize;
543 if num_jacks == 0 {
544 return Ok(());
545 }
546 self.jacks = self.request_info(VIRTIO_SND_R_JACK_INFO, num_jacks)?;
547 Ok(())
548 }
549
request_and_cache_streams_info(&mut self) -> Result<()>550 fn request_and_cache_streams_info(&mut self) -> Result<()> {
551 let num_streams = self.config.streams as usize;
552 if num_streams == 0 {
553 return Ok(());
554 }
555 self.streams = self.request_info(VIRTIO_SND_R_PCM_INFO, num_streams)?;
556 Ok(())
557 }
558
request_and_cache_chmaps_info(&mut self) -> Result<()>559 fn request_and_cache_chmaps_info(&mut self) -> Result<()> {
560 let num_chmaps = self.config.chmaps as usize;
561 if num_chmaps == 0 {
562 return Ok(());
563 }
564 self.chmaps = self.request_info(VIRTIO_SND_R_CHMAP_INFO, num_chmaps)?;
565 Ok(())
566 }
567
snapshot(&self) -> VioSClientSnapshot568 pub fn snapshot(&self) -> VioSClientSnapshot {
569 VioSClientSnapshot {
570 config: self.config,
571 jacks: self.jacks.clone(),
572 streams: self.streams.clone(),
573 chmaps: self.chmaps.clone(),
574 params: self.params.clone(),
575 }
576 }
577
578 // Function called `restore` to signify it will happen as part of the snapshot/restore flow. No
579 // data is actually restored in the case of VioSClient.
restore(&mut self, data: VioSClientSnapshot) -> anyhow::Result<()>580 pub fn restore(&mut self, data: VioSClientSnapshot) -> anyhow::Result<()> {
581 anyhow::ensure!(
582 data.config == self.config,
583 "config doesn't match on restore: expected: {:?}, got: {:?}",
584 data.config,
585 self.config
586 );
587 self.jacks = data.jacks;
588 self.streams = data.streams;
589 self.chmaps = data.chmaps;
590 self.params = data.params;
591 Ok(())
592 }
593
restore_stream(&mut self, stream_id: u32, state: StreamState) -> Result<()>594 pub fn restore_stream(&mut self, stream_id: u32, state: StreamState) -> Result<()> {
595 if let Some(params) = self.params.get(&stream_id).cloned() {
596 self.set_stream_parameters_raw(params)?;
597 }
598 match state {
599 StreamState::Started => {
600 // If state != prepared, start will always fail.
601 // As such, it is fine to only print the first error without returning, as the
602 // second action will then fail.
603 if let Err(e) = self.prepare_stream(stream_id) {
604 error!("failed to prepare stream: {}", e);
605 };
606 self.start_stream(stream_id)
607 }
608 StreamState::Prepared => self.prepare_stream(stream_id),
609 // Nothing to do here
610 _ => Ok(()),
611 }
612 }
613 }
614
/// Flags shared between the client and its recv thread.
#[derive(Clone, Copy)]
struct ThreadFlags {
    // Set once an event notifier has been handed out. While false, the recv thread drops the
    // events it receives instead of queueing them.
    reporting_events: bool,
}
619
/// Identifies which of the sources the recv thread waits on became ready.
#[derive(EventToken)]
enum Token {
    // The event passed to the worker thread's closure; makes the recv loop finish.
    Notification,
    // A buffer status message can be read from the tx socket.
    TxBufferMsg,
    // A buffer status message can be read from the rx socket.
    RxBufferMsg,
    // An event message can be read from the event socket.
    EventMsg,
}
627
recv_buffer_status_msg( socket: &UnixSeqpacket, subscribers: &Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>>, ) -> Result<()>628 fn recv_buffer_status_msg(
629 socket: &UnixSeqpacket,
630 subscribers: &Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>>,
631 ) -> Result<()> {
632 let mut msg: IoStatusMsg = Default::default();
633 let size = socket
634 .recv(msg.as_mut_bytes())
635 .map_err(Error::ServerIOError)?;
636 if size != std::mem::size_of::<IoStatusMsg>() {
637 return Err(Error::ProtocolError(
638 ProtocolErrorKind::UnexpectedMessageSize(std::mem::size_of::<IoStatusMsg>(), size),
639 ));
640 }
641 let mut status = msg.status.status.into();
642 if status == u32::MAX {
643 // Anyone waiting for this would continue to wait for as long as status is
644 // u32::MAX
645 status -= 1;
646 }
647 let latency = msg.status.latency_bytes.into();
648 let offset = msg.buffer_offset as usize;
649 let consumed_len = msg.consumed_len as usize;
650 let promise_opt = subscribers.lock().remove(&offset);
651 match promise_opt {
652 None => error!(
653 "Received an unexpected buffer status message: {}. This is a BUG!!",
654 offset
655 ),
656 Some(sender) => {
657 if let Err(e) = sender.send(BufferReleaseMsg {
658 status,
659 latency,
660 consumed_len,
661 }) {
662 error!("Failed to notify waiting thread: {:?}", e);
663 }
664 }
665 }
666 Ok(())
667 }
668
recv_event(socket: &UnixSeqpacket) -> Result<virtio_snd_event>669 fn recv_event(socket: &UnixSeqpacket) -> Result<virtio_snd_event> {
670 let mut msg: virtio_snd_event = Default::default();
671 let size = socket
672 .recv(msg.as_mut_bytes())
673 .map_err(Error::ServerIOError)?;
674 if size != std::mem::size_of::<virtio_snd_event>() {
675 return Err(Error::ProtocolError(
676 ProtocolErrorKind::UnexpectedMessageSize(std::mem::size_of::<virtio_snd_event>(), size),
677 ));
678 }
679 Ok(msg)
680 }
681
spawn_recv_thread( tx_subscribers: Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>>, rx_subscribers: Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>>, event_notifier: Event, event_queue: Arc<Mutex<VecDeque<virtio_snd_event>>>, state: Arc<Mutex<ThreadFlags>>, tx_socket: UnixSeqpacket, rx_socket: UnixSeqpacket, event_socket: UnixSeqpacket, ) -> WorkerThread<Result<()>>682 fn spawn_recv_thread(
683 tx_subscribers: Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>>,
684 rx_subscribers: Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>>,
685 event_notifier: Event,
686 event_queue: Arc<Mutex<VecDeque<virtio_snd_event>>>,
687 state: Arc<Mutex<ThreadFlags>>,
688 tx_socket: UnixSeqpacket,
689 rx_socket: UnixSeqpacket,
690 event_socket: UnixSeqpacket,
691 ) -> WorkerThread<Result<()>> {
692 WorkerThread::start("shm_vios", move |event| {
693 let wait_ctx: WaitContext<Token> = WaitContext::build_with(&[
694 (&tx_socket, Token::TxBufferMsg),
695 (&rx_socket, Token::RxBufferMsg),
696 (&event_socket, Token::EventMsg),
697 (&event, Token::Notification),
698 ])
699 .map_err(Error::WaitContextCreateError)?;
700 let mut running = true;
701 while running {
702 let events = wait_ctx.wait().map_err(Error::WaitError)?;
703 for evt in events {
704 match evt.token {
705 Token::TxBufferMsg => recv_buffer_status_msg(&tx_socket, &tx_subscribers)?,
706 Token::RxBufferMsg => recv_buffer_status_msg(&rx_socket, &rx_subscribers)?,
707 Token::EventMsg => {
708 let evt = recv_event(&event_socket)?;
709 let state_cpy = *state.lock();
710 if state_cpy.reporting_events {
711 event_queue.lock().push_back(evt);
712 event_notifier.signal().map_err(Error::EventWriteError)?;
713 } // else just drop the events
714 }
715 Token::Notification => {
716 // Just consume the notification and check for termination on the next
717 // iteration
718 if let Err(e) = event.wait() {
719 error!("Failed to consume notification from recv thread: {:?}", e);
720 }
721 running = false;
722 }
723 }
724 }
725 }
726 Ok(())
727 })
728 }
729
await_status(promise: Receiver<BufferReleaseMsg>) -> Result<(usize, u32)>730 fn await_status(promise: Receiver<BufferReleaseMsg>) -> Result<(usize, u32)> {
731 let BufferReleaseMsg {
732 status,
733 latency,
734 consumed_len,
735 } = promise.recv().map_err(Error::BufferStatusSenderLost)?;
736 if status == VIRTIO_SND_S_OK {
737 Ok((consumed_len, latency))
738 } else {
739 Err(Error::IOBufferError(status))
740 }
741 }
742
/// One direction (tx or rx) of buffer exchange with the server: a socket used to announce
/// buffers plus the shared memory file the buffers live in.
struct IoBufferQueue {
    // Socket on which buffers are announced to the server.
    socket: UnixSeqpacket,
    // Shared memory file backing `mmap`.
    file: File,
    // Mapping of `file`, from which the buffer slices are taken.
    mmap: MemoryMapping,
    // Size in bytes of the shared memory file, and therefore of the buffer area.
    size: usize,
    // Offset at which the next buffer is placed, unless it doesn't fit there.
    next: Mutex<usize>,
}
750
751 impl IoBufferQueue {
new(socket: UnixSeqpacket, mut file: File) -> Result<IoBufferQueue>752 fn new(socket: UnixSeqpacket, mut file: File) -> Result<IoBufferQueue> {
753 let size = file.seek(SeekFrom::End(0)).map_err(Error::FileSizeError)? as usize;
754
755 let mmap = MemoryMappingBuilder::new(size)
756 .from_file(&file)
757 .build()
758 .map_err(Error::ServerMmapError)?;
759
760 Ok(IoBufferQueue {
761 socket,
762 file,
763 mmap,
764 size,
765 next: Mutex::new(0),
766 })
767 }
768
allocate_buffer(&self, size: usize) -> Result<usize>769 fn allocate_buffer(&self, size: usize) -> Result<usize> {
770 if size > self.size {
771 return Err(Error::OutOfSpace);
772 }
773 let mut next_lock = self.next.lock();
774 let offset = if size > self.size - *next_lock {
775 // Can't fit the new buffer at the end of the area, so put it at the beginning
776 0
777 } else {
778 *next_lock
779 };
780 *next_lock = offset + size;
781 Ok(offset)
782 }
783
buffer_at(&self, offset: usize, len: usize) -> Result<VolatileSlice>784 fn buffer_at(&self, offset: usize, len: usize) -> Result<VolatileSlice> {
785 self.mmap
786 .get_slice(offset, len)
787 .map_err(Error::VolatileMemoryError)
788 }
789
try_clone_socket(&self) -> Result<UnixSeqpacket>790 fn try_clone_socket(&self) -> Result<UnixSeqpacket> {
791 self.socket
792 .try_clone()
793 .map_err(Error::UnixSeqpacketDupError)
794 }
795
send_buffer(&self, stream_id: u32, offset: usize, size: usize) -> Result<()>796 fn send_buffer(&self, stream_id: u32, offset: usize, size: usize) -> Result<()> {
797 let msg = IoTransferMsg::new(stream_id, offset, size);
798 seq_socket_send(&self.socket, msg.as_bytes())
799 }
800
keep_rds(&self) -> Vec<RawDescriptor>801 fn keep_rds(&self) -> Vec<RawDescriptor> {
802 vec![
803 self.file.as_raw_descriptor(),
804 self.socket.as_raw_descriptor(),
805 ]
806 }
807 }
808
/// Groups the parameters used to configure a stream prior to using it.
///
/// Each field is copied into the field of the same name of `virtio_snd_pcm_set_params` when the
/// request is built.
pub struct VioSStreamParams {
    pub buffer_bytes: u32,
    pub period_bytes: u32,
    pub features: u32,
    pub channels: u8,
    pub format: u8,
    pub rate: u8,
}
818
819 impl From<(u32, VioSStreamParams)> for virtio_snd_pcm_set_params {
from(val: (u32, VioSStreamParams)) -> Self820 fn from(val: (u32, VioSStreamParams)) -> Self {
821 virtio_snd_pcm_set_params {
822 hdr: virtio_snd_pcm_hdr {
823 hdr: virtio_snd_hdr {
824 code: VIRTIO_SND_R_PCM_SET_PARAMS.into(),
825 },
826 stream_id: val.0.into(),
827 },
828 buffer_bytes: val.1.buffer_bytes.into(),
829 period_bytes: val.1.period_bytes.into(),
830 features: val.1.features.into(),
831 channels: val.1.channels,
832 format: val.1.format,
833 rate: val.1.rate,
834 padding: 0u8,
835 }
836 }
837 }
838
send_cmd<T: Immutable + IntoBytes>(control_socket: &UnixSeqpacket, data: T) -> Result<()>839 fn send_cmd<T: Immutable + IntoBytes>(control_socket: &UnixSeqpacket, data: T) -> Result<()> {
840 seq_socket_send(control_socket, data.as_bytes())?;
841 recv_cmd_status(control_socket)
842 }
843
recv_cmd_status(control_socket: &UnixSeqpacket) -> Result<()>844 fn recv_cmd_status(control_socket: &UnixSeqpacket) -> Result<()> {
845 let mut status: virtio_snd_hdr = Default::default();
846 control_socket
847 .recv(status.as_mut_bytes())
848 .map_err(Error::ServerIOError)?;
849 if status.code.to_native() == VIRTIO_SND_S_OK {
850 Ok(())
851 } else {
852 Err(Error::CommandFailed(status.code.to_native()))
853 }
854 }
855
seq_socket_send(socket: &UnixSeqpacket, data: &[u8]) -> Result<()>856 fn seq_socket_send(socket: &UnixSeqpacket, data: &[u8]) -> Result<()> {
857 loop {
858 let send_res = socket.send(data);
859 if let Err(e) = send_res {
860 match e.kind() {
861 // Retry if interrupted
862 IOErrorKind::Interrupted => continue,
863 _ => return Err(Error::ServerIOError(e)),
864 }
865 }
866 // Success
867 break;
868 }
869 Ok(())
870 }
871
// Protocol version of this client. A server whose configuration carries a different version is
// rejected when the connection is established.
const VIOS_VERSION: u32 = 2;
873
/// Configuration message read from the server right after connecting. It is received as raw
/// bytes, hence the fixed `repr(C)` layout.
#[repr(C)]
#[derive(
    Copy,
    Clone,
    Default,
    FromBytes,
    Immutable,
    IntoBytes,
    KnownLayout,
    Serialize,
    Deserialize,
    PartialEq,
    Eq,
    Debug,
)]
struct VioSConfig {
    // Protocol version of the server; must be equal to VIOS_VERSION.
    version: u32,
    // Number of jacks, pcm streams and channel maps to request information about.
    jacks: u32,
    streams: u32,
    chmaps: u32,
}
895
/// What the recv thread passes to the thread waiting for the status of a buffer.
struct BufferReleaseMsg {
    // Status reported for the buffer; anything other than VIRTIO_SND_S_OK is turned into an
    // error by the waiting side.
    status: u32,
    // Value of the `latency_bytes` field of the status message.
    latency: u32,
    // Value of the `consumed_len` field of the status message.
    consumed_len: usize,
}
901
/// Message sent on a queue's socket to announce a buffer to the server. It is sent as raw
/// bytes, hence the fixed `repr(C)` layout.
#[repr(C)]
#[derive(Copy, Clone, FromBytes, Immutable, IntoBytes, KnownLayout)]
struct IoTransferMsg {
    // Carries the id of the stream the buffer belongs to.
    io_xfer: virtio_snd_pcm_xfer,
    // Offset of the buffer within the queue's shared memory.
    buffer_offset: u32,
    // Length of the buffer in bytes.
    buffer_len: u32,
}
909
910 impl IoTransferMsg {
new(stream_id: u32, buffer_offset: usize, buffer_len: usize) -> IoTransferMsg911 fn new(stream_id: u32, buffer_offset: usize, buffer_len: usize) -> IoTransferMsg {
912 IoTransferMsg {
913 io_xfer: virtio_snd_pcm_xfer {
914 stream_id: stream_id.into(),
915 },
916 buffer_offset: buffer_offset as u32,
917 buffer_len: buffer_len as u32,
918 }
919 }
920 }
921
/// Message received on a queue's socket with the status of a previously announced buffer. It is
/// received as raw bytes, hence the fixed `repr(C)` layout.
#[repr(C)]
#[derive(Copy, Clone, Default, FromBytes, Immutable, IntoBytes, KnownLayout)]
struct IoStatusMsg {
    // Status and latency reported for the buffer.
    status: virtio_snd_pcm_status,
    // Offset of the buffer the status is about; used to find the thread waiting for it.
    buffer_offset: u32,
    // Length reported by the server for the buffer.
    consumed_len: u32,
}
929