# Copyright 2021-2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import datetime
import enum
import functools
from importlib import resources
import json
import os
import logging
import pathlib
from typing import Optional, List, cast
import weakref
import struct

import ctypes
import wasmtime
import wasmtime.loader
import liblc3  # type: ignore
import logging

import click
import aiohttp.web

import bumble
from bumble.core import AdvertisingData
from bumble.colors import color
from bumble.device import Device, DeviceConfiguration, AdvertisingParameters
from bumble.transport import open_transport
from bumble.profiles import bap
from bumble.hci import Address, CodecID, CodingFormat, HCI_IsoDataPacket

# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)

# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
DEFAULT_UI_PORT = 7654


def _sink_pac_record() -> bap.PacRecord:
    """Build the PAC record advertised for the sink (audio input) direction.

    LC3, 8/16/24/32/48 kHz, 10 ms frames, 1 or 2 channels, 26-240 octets per
    codec frame, up to 2 codec frames per SDU.
    """
    return bap.PacRecord(
        coding_format=CodingFormat(CodecID.LC3),
        codec_specific_capabilities=bap.CodecSpecificCapabilities(
            supported_sampling_frequencies=(
                bap.SupportedSamplingFrequency.FREQ_8000
                | bap.SupportedSamplingFrequency.FREQ_16000
                | bap.SupportedSamplingFrequency.FREQ_24000
                | bap.SupportedSamplingFrequency.FREQ_32000
                | bap.SupportedSamplingFrequency.FREQ_48000
            ),
            supported_frame_durations=(
                bap.SupportedFrameDuration.DURATION_10000_US_SUPPORTED
            ),
            supported_audio_channel_count=[1, 2],
            min_octets_per_codec_frame=26,
            max_octets_per_codec_frame=240,
            supported_max_codec_frames_per_sdu=2,
        ),
    )


def _source_pac_record() -> bap.PacRecord:
    """Build the PAC record advertised for the source (audio output) direction.

    LC3, 8/16/24/32/48 kHz, 10 ms frames, a single channel, 30-100 octets per
    codec frame, one codec frame per SDU.
    """
    return bap.PacRecord(
        coding_format=CodingFormat(CodecID.LC3),
        codec_specific_capabilities=bap.CodecSpecificCapabilities(
            supported_sampling_frequencies=(
                bap.SupportedSamplingFrequency.FREQ_8000
                | bap.SupportedSamplingFrequency.FREQ_16000
                | bap.SupportedSamplingFrequency.FREQ_24000
                | bap.SupportedSamplingFrequency.FREQ_32000
                | bap.SupportedSamplingFrequency.FREQ_48000
            ),
            supported_frame_durations=(
                bap.SupportedFrameDuration.DURATION_10000_US_SUPPORTED
            ),
            supported_audio_channel_count=[1],
            min_octets_per_codec_frame=30,
            max_octets_per_codec_frame=100,
            supported_max_codec_frames_per_sdu=1,
        ),
    )


# -----------------------------------------------------------------------------
# WASM - liblc3
# -----------------------------------------------------------------------------
store = wasmtime.loader.store
_memory = cast(wasmtime.Memory, liblc3.memory)
# The scratch area used by this module starts at the end of the memory as it
# was before growing it; everything below that offset is left untouched.
STACK_POINTER = _memory.data_len(store)
_memory.grow(store, 1)
# Mapping wasmtime memory to linear address
memory = (ctypes.c_ubyte * _memory.data_len(store)).from_address(
    ctypes.addressof(_memory.data_ptr(store).contents)  # type: ignore
)


class Liblc3PcmFormat(enum.IntEnum):
    """PCM sample format identifiers passed to lc3_encode / lc3_decode."""

    S16 = 0
    S24 = 1
    S24_3LE = 2
    FLOAT = 3


MAX_DECODER_SIZE = liblc3.lc3_decoder_size(10000, 48000)
MAX_ENCODER_SIZE = liblc3.lc3_encoder_size(10000, 48000)

# Layout of the scratch area, in order:
#   2 decoder states | 2 encoder states | decode I/O buffer (8192 bytes) |
#   encode I/O buffer (remainder)
DECODER_STACK_POINTER = STACK_POINTER
ENCODER_STACK_POINTER = DECODER_STACK_POINTER + MAX_DECODER_SIZE * 2
DECODE_BUFFER_STACK_POINTER = ENCODER_STACK_POINTER + MAX_ENCODER_SIZE * 2
ENCODE_BUFFER_STACK_POINTER = DECODE_BUFFER_STACK_POINTER + 8192
DEFAULT_PCM_SAMPLE_RATE = 48000
DEFAULT_PCM_FORMAT = Liblc3PcmFormat.S16
DEFAULT_PCM_BYTES_PER_SAMPLE = 2


# Per-channel encoder / decoder handles returned by liblc3 setup calls.
encoders: List[int] = []
decoders: List[int] = []


def setup_encoders(
    sample_rate_hz: int, frame_duration_us: int, num_channels: int
) -> None:
    """(Re)create one LC3 encoder per channel in the reserved WASM area.

    Args:
        sample_rate_hz: sample rate of the encoded stream.
        frame_duration_us: LC3 frame duration, in microseconds.
        num_channels: number of encoders to set up; they replace the first
            `num_channels` entries of the module-level `encoders` list.
    """
    logger.info(
        f"setup_encoders {sample_rate_hz}Hz {frame_duration_us}us {num_channels}channels"
    )
    encoders[:num_channels] = [
        liblc3.lc3_setup_encoder(
            frame_duration_us,
            sample_rate_hz,
            DEFAULT_PCM_SAMPLE_RATE,  # Input sample rate
            ENCODER_STACK_POINTER + MAX_ENCODER_SIZE * i,
        )
        for i in range(num_channels)
    ]


def setup_decoders(
    sample_rate_hz: int, frame_duration_us: int, num_channels: int
) -> None:
    """(Re)create one LC3 decoder per channel in the reserved WASM area.

    Args:
        sample_rate_hz: sample rate of the encoded stream.
        frame_duration_us: LC3 frame duration, in microseconds.
        num_channels: number of decoders to set up; they replace the first
            `num_channels` entries of the module-level `decoders` list.
    """
    logger.info(
        f"setup_decoders {sample_rate_hz}Hz {frame_duration_us}us {num_channels}channels"
    )
    decoders[:num_channels] = [
        liblc3.lc3_setup_decoder(
            frame_duration_us,
            sample_rate_hz,
            DEFAULT_PCM_SAMPLE_RATE,  # Output sample rate
            DECODER_STACK_POINTER + MAX_DECODER_SIZE * i,
        )
        for i in range(num_channels)
    ]


def decode(
    frame_duration_us: int,
    num_channels: int,
    input_bytes: bytes,
) -> bytes:
    """Decode one SDU of LC3 frames into interleaved PCM.

    The SDU is split evenly into `num_channels` frames, each decoded by the
    matching entry of `decoders` (see `setup_decoders`).

    Args:
        frame_duration_us: LC3 frame duration, in microseconds.
        num_channels: number of channels carried by the SDU.
        input_bytes: the raw SDU payload.

    Returns:
        The output buffer contents: one frame of PCM per channel, using
        `DEFAULT_PCM_BYTES_PER_SAMPLE` bytes per sample, or `b''` when
        `input_bytes` is empty.
    """
    if not input_bytes:
        return b''

    input_buffer_offset = DECODE_BUFFER_STACK_POINTER
    input_buffer_size = len(input_bytes)
    input_bytes_per_frame = input_buffer_size // num_channels

    # Copy into wasm
    memory[input_buffer_offset : input_buffer_offset + input_buffer_size] = input_bytes  # type: ignore

    # The PCM output is written right after the encoded input.
    output_buffer_offset = input_buffer_offset + input_buffer_size
    output_buffer_size = (
        liblc3.lc3_frame_samples(frame_duration_us, DEFAULT_PCM_SAMPLE_RATE)
        * DEFAULT_PCM_BYTES_PER_SAMPLE
        * num_channels
    )

    for i in range(num_channels):
        # Channel i starts one sample further into the output and advances by
        # `num_channels` samples, so the channels end up interleaved.
        res = liblc3.lc3_decode(
            decoders[i],
            input_buffer_offset + input_bytes_per_frame * i,
            input_bytes_per_frame,
            DEFAULT_PCM_FORMAT,
            output_buffer_offset + i * DEFAULT_PCM_BYTES_PER_SAMPLE,
            num_channels,  # Stride
        )

        if res != 0:
            logger.error(f"Decoding failed, res={res}")

    # Extract decoded data from the output buffer
    return bytes(
        memory[output_buffer_offset : output_buffer_offset + output_buffer_size]
    )


def encode(
    sdu_length: int,
    num_channels: int,
    stride: int,
    input_bytes: bytes,
) -> bytes:
    """Encode one frame of PCM into an SDU of LC3 frames.

    The SDU is split evenly into `num_channels` frames, each produced by the
    matching entry of `encoders` (see `setup_encoders`).

    Args:
        sdu_length: size of the SDU to produce, in bytes.
        num_channels: number of channels to encode.
        stride: distance, in samples, between two consecutive samples of the
            same channel in `input_bytes`.
        input_bytes: PCM input, `DEFAULT_PCM_BYTES_PER_SAMPLE` bytes per
            sample.

    Returns:
        `sdu_length` bytes read back from the output buffer, or `b''` when
        `input_bytes` is empty.
    """
    if not input_bytes:
        return b''

    input_buffer_offset = ENCODE_BUFFER_STACK_POINTER
    input_buffer_size = len(input_bytes)

    # Copy into wasm
    memory[input_buffer_offset : input_buffer_offset + input_buffer_size] = input_bytes  # type: ignore

    # The encoded output is written right after the PCM input.
    output_buffer_offset = input_buffer_offset + input_buffer_size
    output_buffer_size = sdu_length
    output_frame_size = output_buffer_size // num_channels

    for i in range(num_channels):
        res = liblc3.lc3_encode(
            encoders[i],
            DEFAULT_PCM_FORMAT,
            input_buffer_offset + DEFAULT_PCM_BYTES_PER_SAMPLE * i,
            stride,
            output_frame_size,
            output_buffer_offset + output_frame_size * i,
        )

        if res != 0:
            logger.error(f"Encoding failed, res={res}")

    # Extract encoded data from the output buffer
    return bytes(
        memory[output_buffer_offset : output_buffer_offset + output_buffer_size]
    )


async def lc3_source_task(
    filename: str,
    sdu_length: int,
    frame_duration_us: int,
    device: Device,
    cis_handle: int,
) -> None:
    """Stream a WAV file over a CIS, one LC3-encoded SDU per frame duration.

    The file must be a WAV file sampled at `DEFAULT_PCM_SAMPLE_RATE` with
    `DEFAULT_PCM_BYTES_PER_SAMPLE` bytes per sample. The task returns once the
    file has been fully sent.

    Args:
        filename: path of the WAV file to stream.
        sdu_length: size of each SDU, in bytes.
        frame_duration_us: LC3 frame duration, in microseconds; also the
            pacing interval between two packets.
        device: device whose host is used to send the ISO packets.
        cis_handle: connection handle of the CIS to send on.
    """
    with open(filename, 'rb') as f:
        header = f.read(44)
        assert header[8:12] == b'WAVE'

        pcm_num_channel, pcm_sample_rate, _byte_rate, _block_align, bits_per_sample = (
            struct.unpack("<HIIHH", header[22:36])
        )
        assert pcm_sample_rate == DEFAULT_PCM_SAMPLE_RATE
        assert bits_per_sample == DEFAULT_PCM_BYTES_PER_SAMPLE * 8

        # NOTE(review): this is the size of one frame for a single channel,
        # while the channel count from the header is used as the stride below;
        # files with more than one channel look unsupported - confirm.
        frame_bytes = (
            liblc3.lc3_frame_samples(frame_duration_us, DEFAULT_PCM_SAMPLE_RATE)
            * DEFAULT_PCM_BYTES_PER_SAMPLE
        )
        packet_sequence_number = 0

        while True:
            next_round = datetime.datetime.now() + datetime.timedelta(
                microseconds=frame_duration_us
            )
            pcm_data = f.read(frame_bytes)
            if not pcm_data:
                # End of file: stop here rather than keep sending packets that
                # declare `sdu_length` bytes but carry an empty fragment.
                logger.info(f"end of {filename}, stopping LC3 source")
                break
            if len(pcm_data) < frame_bytes:
                # Pad the last frame with silence so that the encoder does not
                # read leftovers of the previous frame from the WASM buffer.
                pcm_data = pcm_data.ljust(frame_bytes, b'\x00')
            sdu = encode(sdu_length, pcm_num_channel, pcm_num_channel, pcm_data)

            iso_packet = HCI_IsoDataPacket(
                connection_handle=cis_handle,
                data_total_length=sdu_length + 4,
                packet_sequence_number=packet_sequence_number,
                pb_flag=0b10,
                packet_status_flag=0,
                iso_sdu_length=sdu_length,
                iso_sdu_fragment=sdu,
            )
            device.host.send_hci_packet(iso_packet)
            packet_sequence_number += 1
            # Sleep for whatever is left of this frame period.
            sleep_time = next_round - datetime.datetime.now()
            await asyncio.sleep(sleep_time.total_seconds())


# -----------------------------------------------------------------------------
class UiServer:
    """HTTP + WebSocket server backing the browser UI.

    Serves the static page packaged under `bumble.apps.lea_unicast` and
    exchanges JSON messages / binary audio with it over the `/channel`
    WebSocket. Only the most recently opened channel is used for sending.
    """

    speaker: weakref.ReferenceType[Speaker]
    port: int
    # WebSocket of the currently connected UI, or None when there is none.
    channel_socket: Optional[aiohttp.web.WebSocketResponse]

    def __init__(self, speaker: Speaker, port: int) -> None:
        # Weak reference: the speaker owns this server, not the other way round.
        self.speaker = weakref.ref(speaker)
        self.port = port
        self.channel_socket = None

    async def start_http(self) -> None:
        """Start the UI HTTP server."""

        app = aiohttp.web.Application()
        app.add_routes(
            [
                aiohttp.web.get('/', self.get_static),
                aiohttp.web.get('/index.html', self.get_static),
                aiohttp.web.get('/channel', self.get_channel),
            ]
        )

        runner = aiohttp.web.AppRunner(app)
        await runner.setup()
        site = aiohttp.web.TCPSite(runner, 'localhost', self.port)
        print('UI HTTP server at ' + color(f'http://127.0.0.1:{self.port}', 'green'))
        await site.start()

    async def get_static(self, request):
        """Serve a static text resource packaged with the app.

        The content type is derived from the file extension, defaulting to
        `text/plain`; `/` is served as `/index.html`.
        """
        path = request.path
        if path == '/':
            path = '/index.html'
        if path.endswith('.html'):
            content_type = 'text/html'
        elif path.endswith('.js'):
            content_type = 'text/javascript'
        elif path.endswith('.css'):
            content_type = 'text/css'
        elif path.endswith('.svg'):
            content_type = 'image/svg+xml'
        else:
            content_type = 'text/plain'
        text = (
            resources.files("bumble.apps.lea_unicast")
            .joinpath(pathlib.Path(path).relative_to('/'))
            .read_text(encoding="utf-8")
        )
        return aiohttp.web.Response(text=text, content_type=content_type)

    async def get_channel(self, request):
        """Handle the `/channel` WebSocket for as long as it stays open."""
        ws = aiohttp.web.WebSocketResponse()
        await ws.prepare(request)

        # Process messages until the socket is closed.
        self.channel_socket = ws
        async for message in ws:
            if message.type == aiohttp.WSMsgType.TEXT:
                logger.debug(f'<<< received message: {message.data}')
                await self.on_message(message.data)
            elif message.type == aiohttp.WSMsgType.ERROR:
                logger.debug(
                    f'channel connection closed with exception {ws.exception()}'
                )

        self.channel_socket = None
        logger.debug('--- channel connection closed')

        return ws

    async def on_message(self, message_str: str):
        """Parse a JSON message and dispatch it to `on_<type>_message`.

        The message `params` object, if any, is passed to the handler as
        keyword arguments. Messages with no matching handler are logged and
        ignored.
        """
        # Parse the message as JSON
        message = json.loads(message_str)

        # Dispatch the message
        message_type = message['type']
        message_params = message.get('params', {})
        # Default to None so that an unknown type does not raise AttributeError.
        handler = getattr(self, f'on_{message_type}_message', None)
        if handler is None:
            logger.warning(f'no handler for message type: {message_type}')
            return
        await handler(**message_params)

    async def on_hello_message(self):
        """Answer a `hello` with the app state and, if any, the connection."""
        # NOTE(review): `codec`, `stream_state` and `connection` are not set by
        # the Speaker class as defined in this file - confirm where they are
        # expected to come from.
        await self.send_message(
            'hello',
            bumble_version=bumble.__version__,
            codec=self.speaker().codec,
            streamState=self.speaker().stream_state.name,
        )
        if connection := self.speaker().connection:
            await self.send_message(
                'connection',
                peer_address=connection.peer_address.to_string(False),
                peer_name=connection.peer_name,
            )

    async def send_message(self, message_type: str, **kwargs) -> None:
        """Send a `{'type': ..., 'params': ...}` JSON message to the UI.

        Does nothing when no UI is connected.
        """
        if self.channel_socket is None:
            return

        message = {'type': message_type, 'params': kwargs}
        await self.channel_socket.send_json(message)

    async def send_audio(self, data: bytes) -> None:
        """Send a binary audio packet to the UI.

        Does nothing when no UI is connected; send errors are logged and
        otherwise ignored.
        """
        if self.channel_socket is None:
            return

        try:
            await self.channel_socket.send_bytes(data)
        except Exception as error:
            logger.warning(f'exception while sending audio packet: {error}')


# -----------------------------------------------------------------------------
class Speaker:
    """LE Audio unicast server exposing one sink ASE and one source ASE.

    Audio received on the sink ASE is decoded and forwarded to the browser UI;
    the source ASE streams the content of a WAV file.
    """

    def __init__(
        self,
        device_config_path: Optional[str],
        ui_port: int,
        transport: str,
        lc3_input_file_path: str,
    ):
        self.device_config_path = device_config_path
        self.transport = transport
        self.lc3_input_file_path = lc3_input_file_path

        # Create an HTTP server for the UI
        self.ui_server = UiServer(speaker=self, port=ui_port)

    async def run(self) -> None:
        """Start the UI server and the device, then run until the transport ends."""
        await self.ui_server.start_http()

        async with await open_transport(self.transport) as hci_transport:
            # Create a device
            if self.device_config_path:
                device_config = DeviceConfiguration.from_file(self.device_config_path)
            else:
                device_config = DeviceConfiguration(
                    name="Bumble LE Headphone",
                    class_of_device=0x244418,
                    keystore="JsonKeyStore",
                    advertising_interval_min=25,
                    advertising_interval_max=25,
                    address=Address('F1:F2:F3:F4:F5:F6'),
                )

            # LE and CIS support are required, whatever the configuration says.
            device_config.le_enabled = True
            device_config.cis_enabled = True
            self.device = Device.from_config_with_hci(
                device_config, hci_transport.source, hci_transport.sink
            )

            self.device.add_service(
                bap.PublishedAudioCapabilitiesService(
                    supported_source_context=bap.ContextType(0xFFFF),
                    available_source_context=bap.ContextType(0xFFFF),
                    supported_sink_context=bap.ContextType(0xFFFF),  # All context types
                    available_sink_context=bap.ContextType(0xFFFF),  # All context types
                    sink_audio_locations=(
                        bap.AudioLocation.FRONT_LEFT | bap.AudioLocation.FRONT_RIGHT
                    ),
                    sink_pac=[_sink_pac_record()],
                    source_audio_locations=bap.AudioLocation.FRONT_LEFT,
                    source_pac=[_source_pac_record()],
                )
            )

            ascs = bap.AudioStreamControlService(
                self.device, sink_ase_id=[1], source_ase_id=[2]
            )
            self.device.add_service(ascs)

            advertising_data = bytes(
                AdvertisingData(
                    [
                        (
                            AdvertisingData.COMPLETE_LOCAL_NAME,
                            bytes(device_config.name, 'utf-8'),
                        ),
                        (
                            AdvertisingData.FLAGS,
                            bytes([AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG]),
                        ),
                        (
                            AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
                            bytes(bap.PublishedAudioCapabilitiesService.UUID),
                        ),
                    ]
                )
            ) + bytes(bap.UnicastServerAdvertisingData())

            def on_pdu(pdu: HCI_IsoDataPacket, ase: bap.AseStateMachine):
                # Decode the received SDU and forward the PCM to the UI.
                codec_config = ase.codec_specific_configuration
                assert isinstance(codec_config, bap.CodecSpecificConfiguration)
                pcm = decode(
                    codec_config.frame_duration.us,
                    codec_config.audio_channel_allocation.channel_count,
                    pdu.iso_sdu_fragment,
                )
                self.device.abort_on('disconnection', self.ui_server.send_audio(pcm))

            def on_ase_state_change(ase: bap.AseStateMachine) -> None:
                if ase.state == bap.AseStateMachine.State.STREAMING:
                    # Start producing (source) or consuming (sink) audio.
                    codec_config = ase.codec_specific_configuration
                    assert isinstance(codec_config, bap.CodecSpecificConfiguration)
                    assert ase.cis_link
                    if ase.role == bap.AudioRole.SOURCE:
                        ase.cis_link.abort_on(
                            'disconnection',
                            lc3_source_task(
                                filename=self.lc3_input_file_path,
                                sdu_length=(
                                    codec_config.codec_frames_per_sdu
                                    * codec_config.octets_per_codec_frame
                                ),
                                frame_duration_us=codec_config.frame_duration.us,
                                device=self.device,
                                cis_handle=ase.cis_link.handle,
                            ),
                        )
                    else:
                        ase.cis_link.sink = functools.partial(on_pdu, ase=ase)
                elif ase.state == bap.AseStateMachine.State.CODEC_CONFIGURED:
                    # The codec parameters are known: set up the codec instances.
                    codec_config = ase.codec_specific_configuration
                    assert isinstance(codec_config, bap.CodecSpecificConfiguration)
                    if ase.role == bap.AudioRole.SOURCE:
                        setup_encoders(
                            codec_config.sampling_frequency.hz,
                            codec_config.frame_duration.us,
                            codec_config.audio_channel_allocation.channel_count,
                        )
                    else:
                        setup_decoders(
                            codec_config.sampling_frequency.hz,
                            codec_config.frame_duration.us,
                            codec_config.audio_channel_allocation.channel_count,
                        )

            for ase in ascs.ase_state_machines.values():
                # Bind `ase` now; the loop variable changes on each iteration.
                ase.on('state_change', functools.partial(on_ase_state_change, ase=ase))

            await self.device.power_on()
            await self.device.create_advertising_set(
                advertising_data=advertising_data,
                auto_restart=True,
                advertising_parameters=AdvertisingParameters(
                    primary_advertising_interval_min=100,
                    primary_advertising_interval_max=100,
                ),
            )

            await hci_transport.source.terminated


@click.command()
@click.option(
    '--ui-port',
    'ui_port',
    metavar='HTTP_PORT',
    default=DEFAULT_UI_PORT,
    show_default=True,
    help='HTTP port for the UI server',
)
@click.option('--device-config', metavar='FILENAME', help='Device configuration file')
@click.argument('transport')
@click.argument('lc3_file')
def speaker(
    ui_port: int, device_config: Optional[str], transport: str, lc3_file: str
) -> None:
    """Run the speaker."""

    asyncio.run(Speaker(device_config, ui_port, transport, lc3_file).run())


# -----------------------------------------------------------------------------
def main():
    """Configure logging from the BUMBLE_LOGLEVEL variable and run the CLI."""
    logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
    speaker()


# -----------------------------------------------------------------------------
if __name__ == "__main__":
    main()  # pylint: disable=no-value-for-parameter