• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "host/frontend/webrtc/audio_handler.h"
18 
19 #include <algorithm>
20 #include <chrono>
21 
22 #include <android-base/logging.h>
23 #include <rtc_base/time_utils.h>
24 
25 namespace cuttlefish {
26 namespace {
27 
28 const virtio_snd_jack_info JACKS[] = {};
29 constexpr uint32_t NUM_JACKS = sizeof(JACKS) / sizeof(JACKS[0]);
30 
31 const virtio_snd_chmap_info CHMAPS[] = {{
32     .hdr = { .hda_fn_nid = Le32(0), },
33     .direction = (uint8_t) AudioStreamDirection::VIRTIO_SND_D_OUTPUT,
34     .channels = 2,
35     .positions = {
36         (uint8_t) AudioChannelMap::VIRTIO_SND_CHMAP_FL,
37         (uint8_t) AudioChannelMap::VIRTIO_SND_CHMAP_FR
38     },
39 }, {
40     .hdr = { .hda_fn_nid = Le32(0), },
41     .direction = (uint8_t) AudioStreamDirection::VIRTIO_SND_D_INPUT,
42     .channels = 2,
43     .positions = {
44         (uint8_t) AudioChannelMap::VIRTIO_SND_CHMAP_FL,
45         (uint8_t) AudioChannelMap::VIRTIO_SND_CHMAP_FR
46     },
47 }};
48 constexpr uint32_t NUM_CHMAPS = sizeof(CHMAPS) / sizeof(CHMAPS[0]);
49 
GetVirtioSndPcmInfo(AudioStreamDirection direction,int streamId)50 virtio_snd_pcm_info GetVirtioSndPcmInfo(AudioStreamDirection direction,
51                                         int streamId) {
52   return {
53       .hdr =
54           {
55               .hda_fn_nid = Le32(streamId),
56           },
57       .features = Le32(0),
58       // webrtc's api is quite primitive and doesn't allow for many different
59       // formats: It only takes the bits_per_sample as a parameter and assumes
60       // the underlying format to be one of the following:
61       .formats = Le64((((uint64_t)1)
62                        << (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_S8) |
63                       (((uint64_t)1)
64                        << (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_S16) |
65                       (((uint64_t)1)
66                        << (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_S24) |
67                       (((uint64_t)1)
68                        << (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_S32)),
69       .rates = Le64((((uint64_t)1)
70                      << (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_5512) |
71                     (((uint64_t)1)
72                      << (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_8000) |
73                     (((uint64_t)1)
74                      << (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_11025) |
75                     (((uint64_t)1)
76                      << (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_16000) |
77                     (((uint64_t)1)
78                      << (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_22050) |
79                     (((uint64_t)1)
80                      << (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_32000) |
81                     (((uint64_t)1)
82                      << (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_44100) |
83                     (((uint64_t)1)
84                      << (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_48000) |
85                     (((uint64_t)1)
86                      << (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_64000) |
87                     (((uint64_t)1)
88                      << (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_88200) |
89                     (((uint64_t)1)
90                      << (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_96000) |
91                     (((uint64_t)1)
92                      << (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_176400) |
93                     (((uint64_t)1)
94                      << (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_192000) |
95                     (((uint64_t)1)
96                      << (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_384000)),
97       .direction = (uint8_t)direction,
98       .channels_min = 1,
99       .channels_max = 2,
100   };
101 }
102 
// Number of input (capture) streams. Input streams take the lowest stream ids;
// output streams are numbered starting at this value.
constexpr uint32_t NUM_INPUT_STREAMS{1};
104 
105 class CvdAudioFrameBuffer : public webrtc_streaming::AudioFrameBuffer {
106  public:
CvdAudioFrameBuffer(const uint8_t * buffer,int bits_per_sample,int sample_rate,int channels,int frames)107   CvdAudioFrameBuffer(const uint8_t* buffer, int bits_per_sample,
108                       int sample_rate, int channels, int frames)
109       : buffer_(buffer),
110         bits_per_sample_(bits_per_sample),
111         sample_rate_(sample_rate),
112         channels_(channels),
113         frames_(frames) {}
114 
bits_per_sample() const115   int bits_per_sample() const override { return bits_per_sample_; }
116 
sample_rate() const117   int sample_rate() const override { return sample_rate_; }
118 
channels() const119   int channels() const override { return channels_; }
120 
frames() const121   int frames() const override { return frames_; }
122 
data() const123   const uint8_t* data() const override { return buffer_; }
124 
125  private:
126   const uint8_t* buffer_;
127   int bits_per_sample_;
128   int sample_rate_;
129   int channels_;
130   int frames_;
131 };
132 
BitsPerSample(uint8_t virtio_format)133 int BitsPerSample(uint8_t virtio_format) {
134   switch (virtio_format) {
135     /* analog formats (width / physical width) */
136     case (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_IMA_ADPCM:
137       /*  4 /  4 bits */
138       return 4;
139     case (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_MU_LAW:
140       /*  8 /  8 bits */
141       return 8;
142     case (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_A_LAW:
143       /*  8 /  8 bits */
144       return 8;
145     case (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_S8:
146       /*  8 /  8 bits */
147       return 8;
148     case (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_U8:
149       /*  8 /  8 bits */
150       return 8;
151     case (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_S16:
152       /* 16 / 16 bits */
153       return 16;
154     case (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_U16:
155       /* 16 / 16 bits */
156       return 16;
157     case (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_S18_3:
158       /* 18 / 24 bits */
159       return 24;
160     case (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_U18_3:
161       /* 18 / 24 bits */
162       return 24;
163     case (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_S20_3:
164       /* 20 / 24 bits */
165       return 24;
166     case (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_U20_3:
167       /* 20 / 24 bits */
168       return 24;
169     case (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_S24_3:
170       /* 24 / 24 bits */
171       return 24;
172     case (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_U24_3:
173       /* 24 / 24 bits */
174       return 24;
175     case (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_S20:
176       /* 20 / 32 bits */
177       return 32;
178     case (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_U20:
179       /* 20 / 32 bits */
180       return 32;
181     case (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_S24:
182       /* 24 / 32 bits */
183       return 32;
184     case (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_U24:
185       /* 24 / 32 bits */
186       return 32;
187     case (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_S32:
188       /* 32 / 32 bits */
189       return 32;
190     case (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_U32:
191       /* 32 / 32 bits */
192       return 32;
193     case (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_FLOAT:
194       /* 32 / 32 bits */
195       return 32;
196     case (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_FLOAT64:
197       /* 64 / 64 bits */
198       return 64;
199     /* digital formats (width / physical width) */
200     case (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_DSD_U8:
201       /*  8 /  8 bits */
202       return 8;
203     case (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_DSD_U16:
204       /* 16 / 16 bits */
205       return 16;
206     case (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_DSD_U32:
207       /* 32 / 32 bits */
208       return 32;
209     case (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_IEC958_SUBFRAME:
210       /* 32 / 32 bits */
211       return 32;
212     default:
213       LOG(ERROR) << "Unknown virtio-snd audio format: " << virtio_format;
214       return -1;
215   }
216 }
217 
SampleRate(uint8_t virtio_rate)218 int SampleRate(uint8_t virtio_rate) {
219   switch (virtio_rate) {
220     case (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_5512:
221       return 5512;
222     case (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_8000:
223       return 8000;
224     case (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_11025:
225       return 11025;
226     case (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_16000:
227       return 16000;
228     case (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_22050:
229       return 22050;
230     case (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_32000:
231       return 32000;
232     case (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_44100:
233       return 44100;
234     case (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_48000:
235       return 48000;
236     case (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_64000:
237       return 64000;
238     case (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_88200:
239       return 88200;
240     case (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_96000:
241       return 96000;
242     case (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_176400:
243       return 176400;
244     case (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_192000:
245       return 192000;
246     case (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_384000:
247       return 384000;
248     default:
249       LOG(ERROR) << "Unknown virtio-snd sample rate: " << virtio_rate;
250       return -1;
251   }
252 }
253 
254 }  // namespace
255 
AudioHandler(std::unique_ptr<AudioServer> audio_server,std::vector<std::shared_ptr<webrtc_streaming::AudioSink>> audio_sinks,std::shared_ptr<webrtc_streaming::AudioSource> audio_source)256 AudioHandler::AudioHandler(
257     std::unique_ptr<AudioServer> audio_server,
258     std::vector<std::shared_ptr<webrtc_streaming::AudioSink>> audio_sinks,
259     std::shared_ptr<webrtc_streaming::AudioSource> audio_source)
260     : audio_sinks_(std::move(audio_sinks)),
261       audio_server_(std::move(audio_server)),
262       stream_descs_(audio_sinks_.size() + NUM_INPUT_STREAMS),
263       audio_source_(audio_source) {
264   streams_ = std::vector<virtio_snd_pcm_info>(stream_descs_.size());
265   streams_[0] =
266       GetVirtioSndPcmInfo(AudioStreamDirection::VIRTIO_SND_D_INPUT, 0);
267   for (int i = 0; i < audio_sinks_.size(); i++) {
268     int stream_id = NUM_INPUT_STREAMS + i;
269     streams_[stream_id] =
270         GetVirtioSndPcmInfo(AudioStreamDirection::VIRTIO_SND_D_OUTPUT, i);
271   }
272 }
273 
Start()274 void AudioHandler::Start() {
275   server_thread_ = std::thread([this]() { Loop(); });
276 }
277 
Loop()278 [[noreturn]] void AudioHandler::Loop() {
279   for (;;) {
280     auto audio_client = audio_server_->AcceptClient(
281         streams_.size(), NUM_JACKS, NUM_CHMAPS, 262144 /* tx_shm_len */,
282         262144 /* rx_shm_len */);
283     CHECK(audio_client) << "Failed to create audio client connection instance";
284 
285     std::thread playback_thread([this, &audio_client]() {
286       while (audio_client->ReceivePlayback(*this)) {
287       }
288     });
289     std::thread capture_thread([this, &audio_client]() {
290       while (audio_client->ReceiveCapture(*this)) {
291       }
292     });
293     // Wait for the client to do something
294     while (audio_client->ReceiveCommands(*this)) {
295     }
296     playback_thread.join();
297     capture_thread.join();
298   }
299 }
300 
StreamsInfo(StreamInfoCommand & cmd)301 void AudioHandler::StreamsInfo(StreamInfoCommand& cmd) {
302   if (cmd.start_id() >= streams_.size() ||
303       cmd.start_id() + cmd.count() > streams_.size()) {
304     cmd.Reply(AudioStatus::VIRTIO_SND_S_BAD_MSG, {});
305     return;
306   }
307   std::vector<virtio_snd_pcm_info> stream_info(
308       &streams_[cmd.start_id()], &streams_[0] + cmd.start_id() + cmd.count());
309   cmd.Reply(AudioStatus::VIRTIO_SND_S_OK, stream_info);
310 }
311 
SetStreamParameters(StreamSetParamsCommand & cmd)312 void AudioHandler::SetStreamParameters(StreamSetParamsCommand& cmd) {
313   if (cmd.stream_id() >= streams_.size()) {
314     cmd.Reply(AudioStatus::VIRTIO_SND_S_BAD_MSG);
315     return;
316   }
317   const auto& stream_info = streams_[cmd.stream_id()];
318   auto bits_per_sample = BitsPerSample(cmd.format());
319   auto sample_rate = SampleRate(cmd.rate());
320   auto channels = cmd.channels();
321   if (bits_per_sample < 0 || sample_rate < 0 ||
322       channels < stream_info.channels_min ||
323       channels > stream_info.channels_max) {
324     cmd.Reply(AudioStatus::VIRTIO_SND_S_BAD_MSG);
325     return;
326   }
327   {
328     std::lock_guard<std::mutex> lock(stream_descs_[cmd.stream_id()].mtx);
329     stream_descs_[cmd.stream_id()].bits_per_sample = bits_per_sample;
330     stream_descs_[cmd.stream_id()].sample_rate = sample_rate;
331     stream_descs_[cmd.stream_id()].channels = channels;
332     auto len10ms = (channels * (sample_rate / 100) * bits_per_sample) / 8;
333     stream_descs_[cmd.stream_id()].buffer.Reset(len10ms);
334   }
335   cmd.Reply(AudioStatus::VIRTIO_SND_S_OK);
336 }
337 
PrepareStream(StreamControlCommand & cmd)338 void AudioHandler::PrepareStream(StreamControlCommand& cmd) {
339   if (cmd.stream_id() >= streams_.size()) {
340     cmd.Reply(AudioStatus::VIRTIO_SND_S_BAD_MSG);
341     return;
342   }
343   cmd.Reply(AudioStatus::VIRTIO_SND_S_OK);
344 }
345 
ReleaseStream(StreamControlCommand & cmd)346 void AudioHandler::ReleaseStream(StreamControlCommand& cmd) {
347   if (cmd.stream_id() >= streams_.size()) {
348     cmd.Reply(AudioStatus::VIRTIO_SND_S_BAD_MSG);
349     return;
350   }
351   cmd.Reply(AudioStatus::VIRTIO_SND_S_OK);
352 }
353 
StartStream(StreamControlCommand & cmd)354 void AudioHandler::StartStream(StreamControlCommand& cmd) {
355   if (cmd.stream_id() >= streams_.size()) {
356     cmd.Reply(AudioStatus::VIRTIO_SND_S_BAD_MSG);
357     return;
358   }
359   stream_descs_[cmd.stream_id()].active = true;
360   cmd.Reply(AudioStatus::VIRTIO_SND_S_OK);
361 }
362 
StopStream(StreamControlCommand & cmd)363 void AudioHandler::StopStream(StreamControlCommand& cmd) {
364   if (cmd.stream_id() >= streams_.size()) {
365     cmd.Reply(AudioStatus::VIRTIO_SND_S_BAD_MSG);
366     return;
367   }
368   stream_descs_[cmd.stream_id()].active = false;
369   cmd.Reply(AudioStatus::VIRTIO_SND_S_OK);
370 }
371 
ChmapsInfo(ChmapInfoCommand & cmd)372 void AudioHandler::ChmapsInfo(ChmapInfoCommand& cmd) {
373   if (cmd.start_id() >= NUM_CHMAPS ||
374       cmd.start_id() + cmd.count() > NUM_CHMAPS) {
375     cmd.Reply(AudioStatus::VIRTIO_SND_S_BAD_MSG, {});
376     return;
377   }
378   std::vector<virtio_snd_chmap_info> chmap_info(
379       &CHMAPS[cmd.start_id()], &CHMAPS[cmd.start_id()] + cmd.count());
380   cmd.Reply(AudioStatus::VIRTIO_SND_S_OK, chmap_info);
381 }
382 
JacksInfo(JackInfoCommand & cmd)383 void AudioHandler::JacksInfo(JackInfoCommand& cmd) {
384   if (cmd.start_id() >= NUM_JACKS ||
385       cmd.start_id() + cmd.count() > NUM_JACKS) {
386     cmd.Reply(AudioStatus::VIRTIO_SND_S_BAD_MSG, {});
387     return;
388   }
389   std::vector<virtio_snd_jack_info> jack_info(
390       &JACKS[cmd.start_id()], &JACKS[cmd.start_id()] + cmd.count());
391   cmd.Reply(AudioStatus::VIRTIO_SND_S_OK, jack_info);
392 }
393 
OnPlaybackBuffer(TxBuffer buffer)394 void AudioHandler::OnPlaybackBuffer(TxBuffer buffer) {
395   auto stream_id = buffer.stream_id();
396   auto& stream_desc = stream_descs_[stream_id];
397   {
398     std::lock_guard<std::mutex> lock(stream_desc.mtx);
399     auto& holding_buffer = stream_descs_[stream_id].buffer;
400     // Invalid or capture streams shouldn't send tx buffers
401     if (stream_id >= streams_.size() || IsCapture(stream_id)) {
402       buffer.SendStatus(AudioStatus::VIRTIO_SND_S_BAD_MSG, 0, 0);
403       return;
404     }
405     // A buffer may be received for an inactive stream if we were slow to
406     // process it and the other side stopped the stream. Quietly ignore it in
407     // that case
408     if (!stream_desc.active) {
409       buffer.SendStatus(AudioStatus::VIRTIO_SND_S_OK, 0, buffer.len());
410       return;
411     }
412     auto sink_id = stream_id - NUM_INPUT_STREAMS;
413     if (sink_id >= audio_sinks_.size()) {
414       LOG(ERROR) << "Audio sink for stream id " << stream_id
415                  << " does not exist";
416       buffer.SendStatus(AudioStatus::VIRTIO_SND_S_BAD_MSG, 0, 0);
417       return;
418     }
419     auto audio_sink = audio_sinks_[sink_id];
420     // Webrtc will silently ignore any buffer with a length different than 10ms,
421     // so we must split any buffer bigger than that and temporarily store any
422     // remaining frames that are less than that size.
423     auto current_time = rtc::TimeMillis();
424     // The timestamp of the first 10ms chunk to be sent so that the last one
425     // will have the current time
426     auto base_time =
427         current_time - ((buffer.len() - 1) / holding_buffer.buffer.size()) * 10;
428     // number of frames in a 10 ms buffer
429     const int frames = stream_desc.sample_rate / 100;
430     size_t pos = 0;
431     while (pos < buffer.len()) {
432       if (holding_buffer.empty() &&
433           buffer.len() - pos >= holding_buffer.buffer.size()) {
434         // Avoid the extra copy into holding buffer
435         // This casts away volatility of the pointer, necessary because the
436         // webrtc api doesn't expect volatile memory. This should be safe though
437         // because webrtc will use the contents of the buffer before returning
438         // and only then we release it.
439         CvdAudioFrameBuffer audio_frame_buffer(
440             const_cast<const uint8_t*>(&buffer.get()[pos]),
441             stream_desc.bits_per_sample, stream_desc.sample_rate,
442             stream_desc.channels, frames);
443         // Multiple output streams are mixed on the client side.
444         audio_sink->OnFrame(audio_frame_buffer, base_time);
445         pos += holding_buffer.buffer.size();
446       } else {
447         pos += holding_buffer.Add(buffer.get() + pos, buffer.len() - pos);
448         if (holding_buffer.full()) {
449           auto buffer_ptr = const_cast<const uint8_t*>(holding_buffer.data());
450           CvdAudioFrameBuffer audio_frame_buffer(
451               buffer_ptr, stream_desc.bits_per_sample, stream_desc.sample_rate,
452               stream_desc.channels, frames);
453           audio_sink->OnFrame(audio_frame_buffer, base_time);
454           holding_buffer.count = 0;
455         }
456       }
457       base_time += 10;
458     }
459   }
460   buffer.SendStatus(AudioStatus::VIRTIO_SND_S_OK, 0, buffer.len());
461 }
462 
OnCaptureBuffer(RxBuffer buffer)463 void AudioHandler::OnCaptureBuffer(RxBuffer buffer) {
464   auto stream_id = buffer.stream_id();
465   auto& stream_desc = stream_descs_[stream_id];
466   {
467     std::lock_guard<std::mutex> lock(stream_desc.mtx);
468     // Invalid or playback streams shouldn't send rx buffers
469     if (stream_id >= streams_.size() || !IsCapture(stream_id)) {
470       LOG(ERROR) << "Received capture buffers on playback stream " << stream_id;
471       buffer.SendStatus(AudioStatus::VIRTIO_SND_S_BAD_MSG, 0, 0);
472       return;
473     }
474     // A buffer may be received for an inactive stream if we were slow to
475     // process it and the other side stopped the stream. Quietly ignore it in
476     // that case
477     if (!stream_desc.active) {
478       buffer.SendStatus(AudioStatus::VIRTIO_SND_S_OK, 0, buffer.len());
479       return;
480     }
481     const auto bytes_per_sample = stream_desc.bits_per_sample / 8;
482     const auto samples_per_channel = stream_desc.sample_rate / 100;
483     const auto bytes_per_request =
484         samples_per_channel * bytes_per_sample * stream_desc.channels;
485     bool muted = false;
486     size_t bytes_read = 0;
487     auto& holding_buffer = stream_descs_[stream_id].buffer;
488     auto rx_buffer = const_cast<uint8_t*>(buffer.get());
489     if (!holding_buffer.empty()) {
490       // Consume any bytes remaining from previous requests
491       bytes_read += holding_buffer.Take(rx_buffer + bytes_read,
492                                         buffer.len() - bytes_read);
493     }
494     while (buffer.len() - bytes_read >= bytes_per_request) {
495       // Skip the holding buffer in as many reads as possible to avoid the extra
496       // copies
497       auto write_pos = rx_buffer + bytes_read;
498       auto res = audio_source_->GetMoreAudioData(
499           write_pos, bytes_per_sample, samples_per_channel,
500           stream_desc.channels, stream_desc.sample_rate, muted);
501       if (res < 0) {
502         // This is likely a recoverable error, log the error but don't let the
503         // VMM know about it so that it doesn't crash.
504         LOG(ERROR) << "Failed to receive audio data from client";
505         break;
506       }
507       if (muted) {
508         // The source is muted, just fill the buffer with zeros and return
509         memset(rx_buffer + bytes_read, 0, buffer.len() - bytes_read);
510         bytes_read = buffer.len();
511         break;
512       }
513       auto bytes_received = res * bytes_per_sample * stream_desc.channels;
514       bytes_read += bytes_received;
515     }
516     if (bytes_read < buffer.len()) {
517       // There is some buffer left to fill, but it's less than 10ms, read into
518       // holding buffer to ensure the remainder is kept around for future reads
519       auto write_pos = holding_buffer.data();
520       // Holding buffer is the exact size we need to read into and is emptied
521       // before we try to read into it.
522       CHECK(holding_buffer.freeCapacity() >= bytes_per_request)
523           << "Buffer too small for receiving audio";
524       auto res = audio_source_->GetMoreAudioData(
525           write_pos, bytes_per_sample, samples_per_channel,
526           stream_desc.channels, stream_desc.sample_rate, muted);
527       if (res < 0) {
528         // This is likely a recoverable error, log the error but don't let the
529         // VMM know about it so that it doesn't crash.
530         LOG(ERROR) << "Failed to receive audio data from client";
531       } else if (muted) {
532         // The source is muted, just fill the buffer with zeros and return
533         memset(rx_buffer + bytes_read, 0, buffer.len() - bytes_read);
534         bytes_read = buffer.len();
535       } else {
536         auto bytes_received = res * bytes_per_sample * stream_desc.channels;
537         holding_buffer.count += bytes_received;
538         bytes_read += holding_buffer.Take(rx_buffer + bytes_read,
539                                           buffer.len() - bytes_read);
540         // If the entire buffer is not full by now there is a bug above
541         // somewhere
542         CHECK(bytes_read == buffer.len()) << "Failed to read entire buffer";
543       }
544     }
545   }
546   buffer.SendStatus(AudioStatus::VIRTIO_SND_S_OK, 0, buffer.len());
547 }
548 
Reset(size_t size)549 void AudioHandler::HoldingBuffer::Reset(size_t size) {
550   buffer.resize(size);
551   count = 0;
552 }
553 
Add(const volatile uint8_t * data,size_t max_len)554 size_t AudioHandler::HoldingBuffer::Add(const volatile uint8_t* data,
555                                         size_t max_len) {
556   auto added_len = std::min(max_len, buffer.size() - count);
557   std::copy(data, data + added_len, &buffer[count]);
558   count += added_len;
559   return added_len;
560 }
561 
Take(uint8_t * dst,size_t len)562 size_t AudioHandler::HoldingBuffer::Take(uint8_t* dst, size_t len) {
563   auto n = std::min(len, count);
564   std::copy(buffer.begin(), buffer.begin() + n, dst);
565   std::copy(buffer.begin() + n, buffer.begin() + count, buffer.begin());
566   count -= n;
567   return n;
568 }
569 
empty() const570 bool AudioHandler::HoldingBuffer::empty() const { return count == 0; }
571 
full() const572 bool AudioHandler::HoldingBuffer::full() const {
573   return count == buffer.size();
574 }
575 
freeCapacity() const576 size_t AudioHandler::HoldingBuffer::freeCapacity() const {
577   return buffer.size() - count;
578 }
579 
data()580 uint8_t* AudioHandler::HoldingBuffer::data() { return buffer.data(); }
581 
IsCapture(uint32_t stream_id) const582 bool AudioHandler::IsCapture(uint32_t stream_id) const {
583   CHECK(stream_id < streams_.size()) << "Invalid stream id: " << stream_id;
584   return streams_[stream_id].direction ==
585          (uint8_t)AudioStreamDirection::VIRTIO_SND_D_INPUT;
586 }
587 
588 }  // namespace cuttlefish
589