• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2021-2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
18from __future__ import annotations
19import asyncio
20import datetime
21import enum
22import functools
23from importlib import resources
24import json
25import os
26import logging
27import pathlib
28from typing import Optional, List, cast
29import weakref
30import struct
31
32import ctypes
33import wasmtime
34import wasmtime.loader
35import liblc3  # type: ignore
36import logging
37
38import click
39import aiohttp.web
40
41import bumble
42from bumble.core import AdvertisingData
43from bumble.colors import color
44from bumble.device import Device, DeviceConfiguration, AdvertisingParameters
45from bumble.transport import open_transport
46from bumble.profiles import bap
47from bumble.hci import Address, CodecID, CodingFormat, HCI_IsoDataPacket
48
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
52logger = logging.getLogger(__name__)
53
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
57DEFAULT_UI_PORT = 7654
58
59
def _sink_pac_record() -> bap.PacRecord:
    """Build the LC3 PAC record published as this device's sink capability.

    The record declares 8/16/24/32/48 kHz sampling, 10000 us frames, 1 or 2
    audio channels, 26 to 240 octets per codec frame and up to 2 codec frames
    per SDU.
    """
    codec = CodingFormat(CodecID.LC3)

    sampling_frequencies = (
        bap.SupportedSamplingFrequency.FREQ_8000
        | bap.SupportedSamplingFrequency.FREQ_16000
        | bap.SupportedSamplingFrequency.FREQ_24000
        | bap.SupportedSamplingFrequency.FREQ_32000
        | bap.SupportedSamplingFrequency.FREQ_48000
    )
    frame_durations = bap.SupportedFrameDuration.DURATION_10000_US_SUPPORTED

    capabilities = bap.CodecSpecificCapabilities(
        supported_sampling_frequencies=sampling_frequencies,
        supported_frame_durations=frame_durations,
        supported_audio_channel_count=[1, 2],
        min_octets_per_codec_frame=26,
        max_octets_per_codec_frame=240,
        supported_max_codec_frames_per_sdu=2,
    )
    return bap.PacRecord(
        coding_format=codec,
        codec_specific_capabilities=capabilities,
    )
80
81
def _source_pac_record() -> bap.PacRecord:
    """Build the LC3 PAC record published as this device's source capability.

    The record declares 8/16/24/32/48 kHz sampling, 10000 us frames, a single
    audio channel, 30 to 100 octets per codec frame and 1 codec frame per SDU.
    """
    codec = CodingFormat(CodecID.LC3)

    sampling_frequencies = (
        bap.SupportedSamplingFrequency.FREQ_8000
        | bap.SupportedSamplingFrequency.FREQ_16000
        | bap.SupportedSamplingFrequency.FREQ_24000
        | bap.SupportedSamplingFrequency.FREQ_32000
        | bap.SupportedSamplingFrequency.FREQ_48000
    )
    frame_durations = bap.SupportedFrameDuration.DURATION_10000_US_SUPPORTED

    capabilities = bap.CodecSpecificCapabilities(
        supported_sampling_frequencies=sampling_frequencies,
        supported_frame_durations=frame_durations,
        supported_audio_channel_count=[1],
        min_octets_per_codec_frame=30,
        max_octets_per_codec_frame=100,
        supported_max_codec_frames_per_sdu=1,
    )
    return bap.PacRecord(
        coding_format=codec,
        codec_specific_capabilities=capabilities,
    )
102
103
# -----------------------------------------------------------------------------
# WASM - liblc3
# -----------------------------------------------------------------------------
# Store object exposed by wasmtime.loader; it is passed to every memory call
# made on the liblc3 module below.
store = wasmtime.loader.store
_memory = cast(wasmtime.Memory, liblc3.memory)

# Record the current end of the module's memory *before* growing it: the
# freshly added space starting at STACK_POINTER is what this file uses for
# codec state and I/O buffers.
STACK_POINTER = _memory.data_len(store)
_memory.grow(store, 1)

# Expose the (grown) wasm linear memory as a byte array, so that slices of
# `memory` read and write the module's memory directly.
_memory_array_type = ctypes.c_ubyte * _memory.data_len(store)
_memory_base_address = ctypes.addressof(
    _memory.data_ptr(store).contents  # type: ignore
)
memory = _memory_array_type.from_address(_memory_base_address)
115
116
@enum.unique
class Liblc3PcmFormat(enum.IntEnum):
    """PCM sample format selector passed to the liblc3 encode/decode calls.

    NOTE(review): the numeric values are expected to mirror liblc3's
    `enum lc3_pcm_format` -- confirm against the lc3.h the WASM module was
    built from.
    """

    S16 = 0
    S24 = 1
    S24_3LE = 2
    FLOAT = 3
122
123
# PCM format used on the host side of both the encoders and the decoders.
DEFAULT_PCM_SAMPLE_RATE = 48000
DEFAULT_PCM_FORMAT = Liblc3PcmFormat.S16
DEFAULT_PCM_BYTES_PER_SAMPLE = 2

# Size of one codec state, as reported by liblc3 for 10000 us frames at
# 48000 Hz.
MAX_DECODER_SIZE = liblc3.lc3_decoder_size(10000, 48000)
MAX_ENCODER_SIZE = liblc3.lc3_encoder_size(10000, 48000)

# Number of codec state slots reserved per direction, and number of bytes
# set aside for the decode I/O buffer.
_CODEC_STATE_SLOTS = 2
_DECODE_BUFFER_SIZE = 8192

# Layout of the scratch area that starts at STACK_POINTER:
#   decoder states | encoder states | decode buffer | encode buffer
DECODER_STACK_POINTER = STACK_POINTER
ENCODER_STACK_POINTER = DECODER_STACK_POINTER + _CODEC_STATE_SLOTS * MAX_DECODER_SIZE
DECODE_BUFFER_STACK_POINTER = (
    ENCODER_STACK_POINTER + _CODEC_STATE_SLOTS * MAX_ENCODER_SIZE
)
ENCODE_BUFFER_STACK_POINTER = DECODE_BUFFER_STACK_POINTER + _DECODE_BUFFER_SIZE


# Per-channel codec handles returned by liblc3, filled in by setup_encoders()
# and setup_decoders().
encoders: List[int] = []
decoders: List[int] = []
138
139
def setup_encoders(
    sample_rate_hz: int, frame_duration_us: int, num_channels: int
) -> None:
    """Create one LC3 encoder per channel and store the handles in `encoders`.

    Each encoder takes PCM at DEFAULT_PCM_SAMPLE_RATE as input and keeps its
    state in its own MAX_ENCODER_SIZE-byte slot starting at
    ENCODER_STACK_POINTER.

    Args:
        sample_rate_hz: LC3 sampling rate, in Hz.
        frame_duration_us: LC3 frame duration, in microseconds.
        num_channels: number of encoders to set up.
    """
    logger.info(
        f"setup_encoders {sample_rate_hz}Hz {frame_duration_us}us {num_channels}channels"
    )

    handles: List[int] = []
    for channel in range(num_channels):
        state_address = ENCODER_STACK_POINTER + MAX_ENCODER_SIZE * channel
        handle = liblc3.lc3_setup_encoder(
            frame_duration_us,
            sample_rate_hz,
            DEFAULT_PCM_SAMPLE_RATE,  # Input sample rate
            state_address,
        )
        handles.append(handle)

    # Replace the first `num_channels` entries, leaving any others in place
    encoders[:num_channels] = handles
155
156
def setup_decoders(
    sample_rate_hz: int, frame_duration_us: int, num_channels: int
) -> None:
    """Create one LC3 decoder per channel and store the handles in `decoders`.

    Each decoder produces PCM at DEFAULT_PCM_SAMPLE_RATE and keeps its state
    in its own MAX_DECODER_SIZE-byte slot starting at DECODER_STACK_POINTER.

    Args:
        sample_rate_hz: LC3 sampling rate, in Hz.
        frame_duration_us: LC3 frame duration, in microseconds.
        num_channels: number of decoders to set up.
    """
    logger.info(
        f"setup_decoders {sample_rate_hz}Hz {frame_duration_us}us {num_channels}channels"
    )

    handles: List[int] = []
    for channel in range(num_channels):
        state_address = DECODER_STACK_POINTER + MAX_DECODER_SIZE * channel
        handle = liblc3.lc3_setup_decoder(
            frame_duration_us,
            sample_rate_hz,
            DEFAULT_PCM_SAMPLE_RATE,  # Output sample rate
            state_address,
        )
        handles.append(handle)

    # Replace the first `num_channels` entries, leaving any others in place
    decoders[:num_channels] = handles
172
173
def decode(
    frame_duration_us: int,
    num_channels: int,
    input_bytes: bytes,
) -> bytes:
    """Decode one SDU of LC3 frames into interleaved PCM.

    The SDU is split into `num_channels` equally sized codec frames, and
    frame `i` is decoded with `decoders[i]` (see setup_decoders()).

    Args:
        frame_duration_us: LC3 frame duration, in microseconds.
        num_channels: number of codec frames (one per channel) in the SDU.
        input_bytes: the SDU payload.

    Returns:
        The decoded samples in DEFAULT_PCM_FORMAT at DEFAULT_PCM_SAMPLE_RATE,
        with the channels interleaved, or b'' when `input_bytes` is empty.
    """
    if not input_bytes:
        return b''

    input_buffer_offset = DECODE_BUFFER_STACK_POINTER
    input_buffer_size = len(input_bytes)
    input_bytes_per_frame = input_buffer_size // num_channels

    # Copy into wasm
    memory[input_buffer_offset : input_buffer_offset + input_buffer_size] = input_bytes  # type: ignore

    # The PCM output is written right after the encoded input
    output_buffer_offset = input_buffer_offset + input_buffer_size
    output_buffer_size = (
        liblc3.lc3_frame_samples(frame_duration_us, DEFAULT_PCM_SAMPLE_RATE)
        * DEFAULT_PCM_BYTES_PER_SAMPLE
        * num_channels
    )

    for i in range(num_channels):
        # Channel i starts i samples into the output buffer and advances by
        # `num_channels` samples, which interleaves the channels.
        res = liblc3.lc3_decode(
            decoders[i],
            input_buffer_offset + input_bytes_per_frame * i,
            input_bytes_per_frame,
            DEFAULT_PCM_FORMAT,
            output_buffer_offset + i * DEFAULT_PCM_BYTES_PER_SAMPLE,
            num_channels,  # Stride
        )

        if res != 0:
            # Report through the module logger (not the root logger) so the
            # record carries this module's name like every other message here.
            logger.error("Parsing failed, res=%s", res)

    # Extract decoded data from the output buffer
    return bytes(
        memory[output_buffer_offset : output_buffer_offset + output_buffer_size]
    )
213
214
def encode(
    sdu_length: int,
    num_channels: int,
    stride: int,
    input_bytes: bytes,
) -> bytes:
    """Encode one frame of interleaved PCM into an SDU of LC3 frames.

    Channel `i` is encoded with `encoders[i]` (see setup_encoders()) into its
    own `sdu_length // num_channels`-byte slice of the SDU.

    Args:
        sdu_length: total size of the SDU to produce, in bytes.
        num_channels: number of channels to encode.
        stride: sample stride passed through to `lc3_encode` for each channel.
        input_bytes: PCM samples in DEFAULT_PCM_FORMAT; channel `i` is read
            starting `i` samples into this buffer.

    Returns:
        The `sdu_length`-byte SDU, or b'' when `input_bytes` is empty.
    """
    if not input_bytes:
        return b''

    input_buffer_offset = ENCODE_BUFFER_STACK_POINTER
    input_buffer_size = len(input_bytes)

    # Copy into wasm
    memory[input_buffer_offset : input_buffer_offset + input_buffer_size] = input_bytes  # type: ignore

    # The encoded output is written right after the PCM input
    output_buffer_offset = input_buffer_offset + input_buffer_size
    output_buffer_size = sdu_length
    output_frame_size = output_buffer_size // num_channels

    for i in range(num_channels):
        res = liblc3.lc3_encode(
            encoders[i],
            DEFAULT_PCM_FORMAT,
            input_buffer_offset + DEFAULT_PCM_BYTES_PER_SAMPLE * i,
            stride,
            output_frame_size,
            output_buffer_offset + output_frame_size * i,
        )

        if res != 0:
            # Report through the module logger, with a message that names the
            # operation that actually failed.
            logger.error("Encoding failed, res=%s", res)

    # Extract encoded data from the output buffer
    return bytes(
        memory[output_buffer_offset : output_buffer_offset + output_buffer_size]
    )
251
252
async def lc3_source_task(
    filename: str,
    sdu_length: int,
    frame_duration_us: int,
    device: Device,
    cis_handle: int,
) -> None:
    """Encode a PCM WAV file to LC3 and send it as ISO SDUs, one per frame.

    The file must start with a 44-byte RIFF/WAVE header (the format fields
    are read at fixed offsets) and contain PCM sampled at
    DEFAULT_PCM_SAMPLE_RATE with DEFAULT_PCM_BYTES_PER_SAMPLE bytes per
    sample. The task returns once the end of the file is reached; a trailing
    partial frame is padded with silence.

    Args:
        filename: path of the WAV file to stream.
        sdu_length: size in bytes of each SDU to produce.
        frame_duration_us: LC3 frame duration, in microseconds; also the
            interval between two SDUs.
        device: device whose host is used to send the ISO packets.
        cis_handle: connection handle of the CIS to send on.

    Raises:
        ValueError: if the file is not a WAVE file in the expected format.
    """
    with open(filename, 'rb') as f:
        header = f.read(44)
        # Explicit checks rather than `assert`, which is stripped under -O
        if len(header) < 44 or header[8:12] != b'WAVE':
            raise ValueError(f'{filename} is not a WAVE file')

        pcm_num_channel, pcm_sample_rate, _byte_rate, _block_align, bits_per_sample = (
            struct.unpack("<HIIHH", header[22:36])
        )
        if pcm_num_channel < 1:
            raise ValueError(f'{filename}: invalid channel count {pcm_num_channel}')
        if pcm_sample_rate != DEFAULT_PCM_SAMPLE_RATE:
            raise ValueError(
                f'{filename}: unsupported sample rate {pcm_sample_rate}, '
                f'expected {DEFAULT_PCM_SAMPLE_RATE}'
            )
        if bits_per_sample != DEFAULT_PCM_BYTES_PER_SAMPLE * 8:
            raise ValueError(
                f'{filename}: unsupported sample size {bits_per_sample} bits, '
                f'expected {DEFAULT_PCM_BYTES_PER_SAMPLE * 8}'
            )

        # One frame holds the samples of *every* channel: encode() reads each
        # channel with a stride of `pcm_num_channel` samples, so it needs the
        # whole interleaved frame in its input buffer.
        frame_bytes = (
            liblc3.lc3_frame_samples(frame_duration_us, DEFAULT_PCM_SAMPLE_RATE)
            * DEFAULT_PCM_BYTES_PER_SAMPLE
            * pcm_num_channel
        )
        packet_sequence_number = 0

        while True:
            next_round = datetime.datetime.now() + datetime.timedelta(
                microseconds=frame_duration_us
            )
            pcm_data = f.read(frame_bytes)
            if not pcm_data:
                # End of file: stop instead of sending empty SDUs forever
                logger.info('end of input file reached, stopping LC3 source')
                return
            if len(pcm_data) < frame_bytes:
                # Pad the last, partial frame with silence so that the encoder
                # does not read stale data left over from the previous frame.
                pcm_data = pcm_data.ljust(frame_bytes, b'\x00')
            sdu = encode(sdu_length, pcm_num_channel, pcm_num_channel, pcm_data)

            iso_packet = HCI_IsoDataPacket(
                connection_handle=cis_handle,
                data_total_length=sdu_length + 4,
                packet_sequence_number=packet_sequence_number,
                pb_flag=0b10,
                packet_status_flag=0,
                iso_sdu_length=sdu_length,
                iso_sdu_fragment=sdu,
            )
            device.host.send_hci_packet(iso_packet)
            # The sequence number is a 16-bit field in the ISO data packet
            packet_sequence_number = (packet_sequence_number + 1) & 0xFFFF
            sleep_time = next_round - datetime.datetime.now()
            await asyncio.sleep(sleep_time.total_seconds())
296
297
# -----------------------------------------------------------------------------
class UiServer:
    """HTTP server for the UI page and its WebSocket channel.

    The page is served from the `bumble.apps.lea_unicast` package resources.
    The `/channel` WebSocket carries JSON messages of the form
    `{"type": ..., "params": {...}}` in both directions, plus binary audio
    packets from the server to the UI.
    """

    # Weak reference, so that the server does not keep its owner alive
    speaker: weakref.ReferenceType[Speaker]
    port: int

    def __init__(self, speaker: Speaker, port: int) -> None:
        self.speaker = weakref.ref(speaker)
        self.port = port
        # WebSocket used to reach the UI, or None when no UI is connected
        self.channel_socket = None

    async def start_http(self) -> None:
        """Start the UI HTTP server."""

        app = aiohttp.web.Application()
        app.add_routes(
            [
                aiohttp.web.get('/', self.get_static),
                aiohttp.web.get('/index.html', self.get_static),
                aiohttp.web.get('/channel', self.get_channel),
            ]
        )

        runner = aiohttp.web.AppRunner(app)
        await runner.setup()
        site = aiohttp.web.TCPSite(runner, 'localhost', self.port)
        print('UI HTTP server at ' + color(f'http://127.0.0.1:{self.port}', 'green'))
        await site.start()

    async def get_static(self, request):
        """Serve a static file from the package resources.

        The content type is derived from the file extension, and `/` is an
        alias for `/index.html`.
        """
        path = request.path
        if path == '/':
            path = '/index.html'
        if path.endswith('.html'):
            content_type = 'text/html'
        elif path.endswith('.js'):
            content_type = 'text/javascript'
        elif path.endswith('.css'):
            content_type = 'text/css'
        elif path.endswith('.svg'):
            content_type = 'image/svg+xml'
        else:
            content_type = 'text/plain'
        text = (
            resources.files("bumble.apps.lea_unicast")
            .joinpath(pathlib.Path(path).relative_to('/'))
            .read_text(encoding="utf-8")
        )
        return aiohttp.web.Response(text=text, content_type=content_type)

    async def get_channel(self, request):
        """Handle a WebSocket connection from the UI until it is closed."""
        ws = aiohttp.web.WebSocketResponse()
        await ws.prepare(request)

        # Process messages until the socket is closed.
        self.channel_socket = ws
        try:
            async for message in ws:
                if message.type == aiohttp.WSMsgType.TEXT:
                    logger.debug(f'<<< received message: {message.data}')
                    await self.on_message(message.data)
                elif message.type == aiohttp.WSMsgType.ERROR:
                    logger.debug(
                        f'channel connection closed with exception {ws.exception()}'
                    )
        finally:
            # Always forget the socket, even when a handler raised, so that
            # nothing keeps sending on a dead connection. Leave the attribute
            # alone if a newer connection has already replaced this one.
            if self.channel_socket is ws:
                self.channel_socket = None
            logger.debug('--- channel connection closed')

        return ws

    async def on_message(self, message_str: str):
        """Parse a JSON message and dispatch it to `on_<type>_message`.

        Messages whose type has no matching handler are ignored.
        """
        # Parse the message as JSON
        message = json.loads(message_str)

        # Dispatch the message
        message_type = message['type']
        message_params = message.get('params', {})
        # Default to None: without it, getattr raises AttributeError for an
        # unknown type and the check below can never be False.
        handler = getattr(self, f'on_{message_type}_message', None)
        if handler:
            await handler(**message_params)

    async def on_hello_message(self):
        """Answer the UI's `hello` with version, codec and stream state.

        A `connection` message follows when the speaker reports a connection.
        """
        speaker = self.speaker()
        if speaker is None:
            # The speaker has been garbage collected, nothing to report
            return

        # The Speaker class in this file does not define `codec`,
        # `stream_state` or `connection`; read them defensively so that a
        # `hello` from the UI cannot raise and tear down the channel.
        stream_state = getattr(speaker, 'stream_state', None)
        await self.send_message(
            'hello',
            bumble_version=bumble.__version__,
            codec=getattr(speaker, 'codec', None),
            streamState=stream_state.name if stream_state is not None else None,
        )
        if connection := getattr(speaker, 'connection', None):
            await self.send_message(
                'connection',
                peer_address=connection.peer_address.to_string(False),
                peer_name=connection.peer_name,
            )

    async def send_message(self, message_type: str, **kwargs) -> None:
        """Send a JSON message to the UI, if one is connected."""
        if self.channel_socket is None:
            return

        message = {'type': message_type, 'params': kwargs}
        await self.channel_socket.send_json(message)

    async def send_audio(self, data: bytes) -> None:
        """Send a binary audio packet to the UI, if one is connected.

        Failures are logged and otherwise ignored.
        """
        if self.channel_socket is None:
            return

        try:
            await self.channel_socket.send_bytes(data)
        except Exception as error:
            logger.warning(f'exception while sending audio packet: {error}')
407
408
# -----------------------------------------------------------------------------
class Speaker:
    """LE Audio unicast application with a sink ASE and a source ASE.

    Audio received on the sink ASE is decoded and forwarded to the UI server;
    the source ASE streams the input file given at construction time.
    """

    def __init__(
        self,
        device_config_path: Optional[str],
        ui_port: int,
        transport: str,
        lc3_input_file_path: str,
    ):
        # Launch parameters; the device itself is only created in run()
        self.lc3_input_file_path = lc3_input_file_path
        self.transport = transport
        self.device_config_path = device_config_path

        # Create an HTTP server for the UI
        self.ui_server = UiServer(speaker=self, port=ui_port)

    def _load_device_config(self) -> DeviceConfiguration:
        """Load the device configuration file, or build the default one."""
        if self.device_config_path:
            return DeviceConfiguration.from_file(self.device_config_path)

        return DeviceConfiguration(
            name="Bumble LE Headphone",
            class_of_device=0x244418,
            keystore="JsonKeyStore",
            advertising_interval_min=25,
            advertising_interval_max=25,
            address=Address('F1:F2:F3:F4:F5:F6'),
        )

    def _register_audio_services(self) -> bap.AudioStreamControlService:
        """Add the PACS and ASCS services to the device and return the ASCS."""
        pacs = bap.PublishedAudioCapabilitiesService(
            supported_source_context=bap.ContextType(0xFFFF),
            available_source_context=bap.ContextType(0xFFFF),
            supported_sink_context=bap.ContextType(0xFFFF),  # All context types
            available_sink_context=bap.ContextType(0xFFFF),  # All context types
            sink_audio_locations=(
                bap.AudioLocation.FRONT_LEFT | bap.AudioLocation.FRONT_RIGHT
            ),
            sink_pac=[_sink_pac_record()],
            source_audio_locations=bap.AudioLocation.FRONT_LEFT,
            source_pac=[_source_pac_record()],
        )
        self.device.add_service(pacs)

        # One sink ASE (id 1) and one source ASE (id 2)
        ascs = bap.AudioStreamControlService(
            self.device, sink_ase_id=[1], source_ase_id=[2]
        )
        self.device.add_service(ascs)
        return ascs

    @staticmethod
    def _build_advertising_data(device_config: DeviceConfiguration) -> bytes:
        """Build the advertising payload: name, flags, PACS UUID, unicast data."""
        name_field = (
            AdvertisingData.COMPLETE_LOCAL_NAME,
            bytes(device_config.name, 'utf-8'),
        )
        flags_field = (
            AdvertisingData.FLAGS,
            bytes([AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG]),
        )
        services_field = (
            AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
            bytes(bap.PublishedAudioCapabilitiesService.UUID),
        )
        base_data = bytes(AdvertisingData([name_field, flags_field, services_field]))
        return base_data + bytes(bap.UnicastServerAdvertisingData())

    def _on_pdu(self, pdu: HCI_IsoDataPacket, ase: bap.AseStateMachine) -> None:
        """Decode an incoming ISO packet and forward the PCM to the UI."""
        codec_config = ase.codec_specific_configuration
        assert isinstance(codec_config, bap.CodecSpecificConfiguration)
        pcm = decode(
            codec_config.frame_duration.us,
            codec_config.audio_channel_allocation.channel_count,
            pdu.iso_sdu_fragment,
        )
        self.device.abort_on('disconnection', self.ui_server.send_audio(pcm))

    def _on_ase_state_change(self, ase: bap.AseStateMachine) -> None:
        """Set up the codecs or start streaming when an ASE changes state."""
        state = ase.state

        if state == bap.AseStateMachine.State.STREAMING:
            codec_config = ase.codec_specific_configuration
            assert isinstance(codec_config, bap.CodecSpecificConfiguration)
            cis_link = ase.cis_link
            assert cis_link

            if ase.role != bap.AudioRole.SOURCE:
                # Sink: decode every packet received on the CIS
                cis_link.sink = functools.partial(self._on_pdu, ase=ase)
                return

            # Source: stream the input file until the CIS disconnects
            sdu_length = (
                codec_config.codec_frames_per_sdu * codec_config.octets_per_codec_frame
            )
            cis_link.abort_on(
                'disconnection',
                lc3_source_task(
                    filename=self.lc3_input_file_path,
                    sdu_length=sdu_length,
                    frame_duration_us=codec_config.frame_duration.us,
                    device=self.device,
                    cis_handle=cis_link.handle,
                ),
            )
            return

        if state == bap.AseStateMachine.State.CODEC_CONFIGURED:
            codec_config = ase.codec_specific_configuration
            assert isinstance(codec_config, bap.CodecSpecificConfiguration)
            # Encoders for the source direction, decoders for the sink
            if ase.role == bap.AudioRole.SOURCE:
                setup_codecs = setup_encoders
            else:
                setup_codecs = setup_decoders
            setup_codecs(
                codec_config.sampling_frequency.hz,
                codec_config.frame_duration.us,
                codec_config.audio_channel_allocation.channel_count,
            )

    async def run(self) -> None:
        """Start the UI server, bring up the device, and advertise.

        Returns when the HCI transport source terminates.
        """
        await self.ui_server.start_http()

        async with await open_transport(self.transport) as hci_transport:
            # Create a device
            device_config = self._load_device_config()
            device_config.le_enabled = True
            device_config.cis_enabled = True
            self.device = Device.from_config_with_hci(
                device_config, hci_transport.source, hci_transport.sink
            )

            ascs = self._register_audio_services()
            advertising_data = self._build_advertising_data(device_config)

            # Follow the state of every ASE
            for ase in ascs.ase_state_machines.values():
                ase.on(
                    'state_change',
                    functools.partial(self._on_ase_state_change, ase=ase),
                )

            await self.device.power_on()
            advertising_parameters = AdvertisingParameters(
                primary_advertising_interval_min=100,
                primary_advertising_interval_max=100,
            )
            await self.device.create_advertising_set(
                advertising_data=advertising_data,
                auto_restart=True,
                advertising_parameters=advertising_parameters,
            )

            await hci_transport.source.terminated
549
550
# Command line: [--ui-port HTTP_PORT] [--device-config FILENAME] TRANSPORT LC3_FILE
@click.command()
@click.option(
    '--ui-port',
    'ui_port',
    metavar='HTTP_PORT',
    default=DEFAULT_UI_PORT,
    show_default=True,
    help='HTTP port for the UI server',
)
@click.option(
    '--device-config',
    metavar='FILENAME',
    help='Device configuration file',
)
@click.argument('transport')
@click.argument('lc3_file')
def speaker(ui_port: int, device_config: str, transport: str, lc3_file: str) -> None:
    """Run the speaker."""

    # Build the application from the command line values, then run it to
    # completion on a fresh event loop.
    app = Speaker(
        device_config_path=device_config,
        ui_port=ui_port,
        transport=transport,
        lc3_input_file_path=lc3_file,
    )
    asyncio.run(app.run())
567
568
# -----------------------------------------------------------------------------
def main():
    """Configure logging from BUMBLE_LOGLEVEL and run the `speaker` command."""
    # The level name defaults to WARNING and is matched case-insensitively
    log_level = os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper()
    logging.basicConfig(level=log_level)
    speaker()


# -----------------------------------------------------------------------------
if __name__ == "__main__":
    main()  # pylint: disable=no-value-for-parameter
578