• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import asyncio
16import grpc
17import logging
18import struct
19
20from avatar.bumble_server import utils
21from bumble.core import AdvertisingData
22from bumble.device import Connection, Connection as BumbleConnection, Device
23from bumble.gatt import (
24    GATT_ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC,
25    GATT_ASHA_AUDIO_STATUS_CHARACTERISTIC,
26    GATT_ASHA_LE_PSM_OUT_CHARACTERISTIC,
27    GATT_ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC,
28    GATT_ASHA_SERVICE,
29    GATT_ASHA_VOLUME_CHARACTERISTIC,
30    Characteristic,
31    CharacteristicValue,
32    TemplateService,
33)
34from bumble.l2cap import Channel
35from bumble.utils import AsyncRunner
36from google.protobuf.empty_pb2 import Empty  # pytype: disable=pyi-error
37from pandora_experimental.asha_grpc_aio import AshaServicer
38from pandora_experimental.asha_pb2 import CaptureAudioRequest, CaptureAudioResponse, RegisterRequest
39from typing import AsyncGenerator, List, Optional
40
41
class AshaGattService(TemplateService):
    """Server-side GATT implementation of the ASHA (Audio Streaming for Hearing Aids) service.

    The service exposes five characteristics (ReadOnlyProperties, AudioControlPoint,
    AudioStatus, Volume and LE_PSM_OUT) and registers an L2CAP CoC server on which audio
    data is received. Activity is reported to listeners through ``self.emit``:

    * ``"volume"`` ``(connection, volume)`` -- the Volume characteristic was written.
    * ``"start"`` ``(connection, params)`` -- AudioControlPoint START; ``params`` is a dict
      with ``codec``, ``audiotype``, ``volume`` and ``otherstate`` keys.
    * ``"stop"`` ``(connection)`` -- AudioControlPoint STOP.
    * ``"read_only_properties"`` ``(connection, value)`` -- ReadOnlyProperties was read.
    * ``"le_psm_out"`` ``(connection, psm)`` -- LE_PSM_OUT was read.
    * ``"data"`` ``(connection, data)`` -- audio bytes received on the CoC channel.

    Every received audio payload is also appended to ``self.audio_out_data``.
    """

    # TODO: update bumble and remove this when complete
    UUID = GATT_ASHA_SERVICE
    # AudioControlPoint opcodes.
    OPCODE_START = 1
    OPCODE_STOP = 2
    OPCODE_STATUS = 3
    # Minimum AudioControlPoint payload sizes, opcode byte included.
    START_PAYLOAD_LENGTH = 5  # opcode, codec, audio type, volume, other state
    STATUS_PAYLOAD_LENGTH = 2  # opcode, connected
    # Human-readable labels for the START audio-type byte (used for logging only).
    AUDIO_TYPES = ("Unknown", "Ringtone", "Phone Call", "Media")
    PROTOCOL_VERSION = 0x01
    RESERVED_FOR_FUTURE_USE = [0x00, 0x00]
    FEATURE_MAP = [0x01]  # [LE CoC audio output streaming supported]
    SUPPORTED_CODEC_ID = [0x02, 0x01]  # Codec IDs [G.722 at 16 kHz]
    RENDER_DELAY = [0x00, 0x00]

    def __init__(self, capability: int, hisyncid: List[int], device: Device, psm: int = 0) -> None:
        """Build the ASHA characteristics and register the L2CAP CoC server.

        Args:
            capability: DeviceCapabilities byte, reported in ReadOnlyProperties and in the
                advertising data.
            hisyncid: HiSyncId bytes. All of them are reported in ReadOnlyProperties; only
                the first 4 are put in the advertising data.
            device: Bumble device on which the L2CAP server is registered and through which
                AudioStatus notifications are sent.
            psm: Requested L2CAP PSM. 0 lets the device pick a free one; a non-zero value is
                mainly for testing. The PSM actually registered is stored in ``self.psm``.
        """
        self.hisyncid = hisyncid
        self.capability = capability  # Device Capabilities [Left, Monaural]
        self.device = device
        self.audio_out_data = b""
        self.psm: int = psm  # a non-zero psm is mainly for testing purpose

        logger = logging.getLogger(__name__)

        # Handler for volume control
        def on_volume_write(connection: Connection, value: bytes) -> None:
            # Reject an empty write instead of raising IndexError inside the GATT handler.
            if not value:
                logger.warning("--- VOLUME Write: empty payload ignored")
                return
            logger.info(f"--- VOLUME Write:{value[0]}")
            self.emit("volume", connection, value[0])

        # Handler for audio control commands
        def on_audio_control_point_write(connection: Connection, value: bytes) -> None:
            logger.debug(f"type {type(value)}")
            logger.info(f"--- AUDIO CONTROL POINT Write:{value.hex()}")
            # Malformed (empty / truncated) commands are logged and ignored: nothing is
            # emitted and no AudioStatus notification is sent for them.
            if not value:
                logger.warning("### AUDIO CONTROL POINT: empty payload ignored")
                return
            opcode = value[0]
            if opcode == AshaGattService.OPCODE_START:
                if len(value) < AshaGattService.START_PAYLOAD_LENGTH:
                    logger.warning(f"### START: payload too short ({len(value)} bytes), ignored")
                    return
                codec, audio_type_id, volume, other_state = value[1 : AshaGattService.START_PAYLOAD_LENGTH]
                # An out-of-range audio type must not abort the START handling.
                if audio_type_id < len(AshaGattService.AUDIO_TYPES):
                    audio_type = AshaGattService.AUDIO_TYPES[audio_type_id]
                else:
                    audio_type = f"Unknown({audio_type_id})"
                logger.info(
                    f"### START: codec={codec}, "
                    f"audio_type={audio_type}, "
                    f"volume={volume}, "
                    f"otherstate={other_state}"
                )
                self.emit(
                    "start",
                    connection,
                    {
                        "codec": codec,
                        "audiotype": audio_type_id,
                        "volume": volume,
                        "otherstate": other_state,
                    },
                )
            elif opcode == AshaGattService.OPCODE_STOP:
                logger.info("### STOP")
                self.emit("stop", connection)
            elif opcode == AshaGattService.OPCODE_STATUS:
                if len(value) < AshaGattService.STATUS_PAYLOAD_LENGTH:
                    logger.warning(f"### STATUS: payload too short ({len(value)} bytes), ignored")
                    return
                logger.info(f"### STATUS: connected={value[1]}")

            # OPCODE_STATUS does not need audio status point update
            if opcode != AshaGattService.OPCODE_STATUS:
                AsyncRunner.spawn(device.notify_subscribers(self.audio_status_characteristic, force=True))  # type: ignore[no-untyped-call]

        def on_read_only_properties_read(connection: Connection) -> bytes:
            # Layout: version, capability, HiSyncId, feature map, render delay,
            # reserved (2 bytes), supported codec IDs.
            value = (
                bytes(
                    [
                        AshaGattService.PROTOCOL_VERSION,  # Version
                        self.capability,
                    ]
                )
                + bytes(self.hisyncid)
                + bytes(AshaGattService.FEATURE_MAP)
                + bytes(AshaGattService.RENDER_DELAY)
                + bytes(AshaGattService.RESERVED_FOR_FUTURE_USE)
                + bytes(AshaGattService.SUPPORTED_CODEC_ID)
            )
            self.emit("read_only_properties", connection, value)
            return value

        def on_le_psm_out_read(connection: Connection) -> bytes:
            # The PSM is reported as a little-endian unsigned 16-bit value.
            self.emit("le_psm_out", connection, self.psm)
            return struct.pack("<H", self.psm)

        self.read_only_properties_characteristic = Characteristic(
            GATT_ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC,
            Characteristic.READ,
            Characteristic.READABLE,
            CharacteristicValue(read=on_read_only_properties_read),  # type: ignore[no-untyped-call]
        )

        self.audio_control_point_characteristic = Characteristic(
            GATT_ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC,
            Characteristic.WRITE | Characteristic.WRITE_WITHOUT_RESPONSE,
            Characteristic.WRITEABLE,
            CharacteristicValue(write=on_audio_control_point_write),  # type: ignore[no-untyped-call]
        )
        self.audio_status_characteristic = Characteristic(
            GATT_ASHA_AUDIO_STATUS_CHARACTERISTIC,
            Characteristic.READ | Characteristic.NOTIFY,
            Characteristic.READABLE,
            bytes([0]),
        )
        self.volume_characteristic = Characteristic(
            GATT_ASHA_VOLUME_CHARACTERISTIC,
            Characteristic.WRITE_WITHOUT_RESPONSE,
            Characteristic.WRITEABLE,
            CharacteristicValue(write=on_volume_write),  # type: ignore[no-untyped-call]
        )

        # Register an L2CAP CoC server
        def on_coc(channel: Channel) -> None:
            def on_data(data: bytes) -> None:
                logger.debug(f"data received:{data.hex()}")

                self.emit("data", channel.connection, data)
                self.audio_out_data += data

            channel.sink = on_data  # type: ignore[no-untyped-call]

        # let the server find a free PSM
        self.psm = self.device.register_l2cap_channel_server(self.psm, on_coc, 8)  # type: ignore[no-untyped-call]
        self.le_psm_out_characteristic = Characteristic(
            GATT_ASHA_LE_PSM_OUT_CHARACTERISTIC,
            Characteristic.READ,
            Characteristic.READABLE,
            CharacteristicValue(read=on_le_psm_out_read),  # type: ignore[no-untyped-call]
        )

        characteristics = [
            self.read_only_properties_characteristic,
            self.audio_control_point_characteristic,
            self.audio_status_characteristic,
            self.volume_characteristic,
            self.le_psm_out_characteristic,
        ]

        super().__init__(characteristics)  # type: ignore[no-untyped-call]

    def get_advertising_data(self) -> bytes:
        """Return serialized advertising data carrying the ASHA 16-bit service data.

        The service data is the ASHA service UUID followed by the protocol version, the
        capability byte and the first 4 bytes of the HiSyncId.
        """
        # Advertisement only uses 4 least significant bytes of the HiSyncId.
        return bytes(
            AdvertisingData(
                [
                    (
                        AdvertisingData.SERVICE_DATA_16_BIT_UUID,
                        bytes(GATT_ASHA_SERVICE)
                        + bytes(
                            [
                                AshaGattService.PROTOCOL_VERSION,
                                self.capability,
                            ]
                        )
                        + bytes(self.hisyncid[:4]),
                    ),
                ]
            )
        )
197
198
class AshaService(AshaServicer):
    """Pandora ``Asha`` gRPC servicer backed by an :class:`AshaGattService`.

    ``Register`` adds an ASHA GATT service to the device. ``CaptureAudio`` streams the
    audio data that service receives for one connection.
    """

    device: Device
    asha_service: Optional[AshaGattService]

    def __init__(self, device: Device) -> None:
        self.log = utils.BumbleServerLoggerAdapter(logging.getLogger(), {"service_name": "Asha", "device": device})
        self.device = device
        # Populated by Register(); CaptureAudio() requires it.
        self.asha_service = None

    @utils.rpc
    async def Register(self, request: RegisterRequest, context: grpc.ServicerContext) -> Empty:
        """Create the ASHA GATT service described by ``request`` and add it to the device."""
        self.log.info("Register")
        # asha service from bumble profile
        self.asha_service = AshaGattService(request.capability, request.hisyncid, self.device)
        self.device.add_service(self.asha_service)  # type: ignore[no-untyped-call]
        return Empty()

    @utils.rpc
    async def CaptureAudio(
        self, request: CaptureAudioRequest, context: grpc.ServicerContext
    ) -> AsyncGenerator[CaptureAudioResponse, None]:
        """Stream the audio data received on the ASHA CoC channel for one connection.

        The connection handle is decoded big-endian from ``request.connection.cookie``.
        The stream ends if an empty payload is received.

        Raises:
            RuntimeError: if the connection handle is unknown, or if ``Register`` has not
                been called yet.
        """
        connection_handle = int.from_bytes(request.connection.cookie.value, "big")
        self.log.info(f"CaptureAudioData connection_handle:{connection_handle}")

        if not (connection := self.device.lookup_connection(connection_handle)):
            raise RuntimeError(f"Unknown connection for connection_handle:{connection_handle}")

        # Keep a local reference so the listener is removed from the same service it was
        # added to, even if Register() replaces self.asha_service while streaming.
        asha_service = self.asha_service
        if asha_service is None:
            raise RuntimeError("ASHA service is not registered, call Register first")

        queue: asyncio.Queue[bytes] = asyncio.Queue()

        def on_data(asha_connection: BumbleConnection, data: bytes) -> None:
            # Only forward audio that belongs to the requested connection.
            if asha_connection == connection:
                queue.put_nowait(data)

        asha_service.on("data", on_data)

        try:
            # A falsy (empty) payload terminates the stream.
            while data := await queue.get():
                yield CaptureAudioResponse(data=data)
        finally:
            asha_service.remove_listener("data", on_data)
239