• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021-2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# -----------------------------------------------------------------------------
16# Imports
17# -----------------------------------------------------------------------------
18from __future__ import annotations
19import logging
20import struct
21from typing import Dict, List, Type
22
23from . import core
24from .colors import color
25from .core import InvalidStateError
26from .hci import HCI_Object, name_or_number, key_with_value
27
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)


# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# fmt: off
# pylint: disable=line-too-long

# Upper bound on the number of continuation round-trips followed for one query
SDP_CONTINUATION_WATCHDOG = 64

# L2CAP PSM on which the SDP server listens
SDP_PSM = 0x0001

# PDU IDs (Vol 3, Part B - 4.2)
SDP_ERROR_RESPONSE                    = 0x01
SDP_SERVICE_SEARCH_REQUEST            = 0x02
SDP_SERVICE_SEARCH_RESPONSE           = 0x03
SDP_SERVICE_ATTRIBUTE_REQUEST         = 0x04
SDP_SERVICE_ATTRIBUTE_RESPONSE        = 0x05
SDP_SERVICE_SEARCH_ATTRIBUTE_REQUEST  = 0x06
SDP_SERVICE_SEARCH_ATTRIBUTE_RESPONSE = 0x07

# PDU ID -> symbolic name, used for logging and for SDP_PDU subclass registration
SDP_PDU_NAMES = {
    SDP_ERROR_RESPONSE:                    'SDP_ERROR_RESPONSE',
    SDP_SERVICE_SEARCH_REQUEST:            'SDP_SERVICE_SEARCH_REQUEST',
    SDP_SERVICE_SEARCH_RESPONSE:           'SDP_SERVICE_SEARCH_RESPONSE',
    SDP_SERVICE_ATTRIBUTE_REQUEST:         'SDP_SERVICE_ATTRIBUTE_REQUEST',
    SDP_SERVICE_ATTRIBUTE_RESPONSE:        'SDP_SERVICE_ATTRIBUTE_RESPONSE',
    SDP_SERVICE_SEARCH_ATTRIBUTE_REQUEST:  'SDP_SERVICE_SEARCH_ATTRIBUTE_REQUEST',
    SDP_SERVICE_SEARCH_ATTRIBUTE_RESPONSE: 'SDP_SERVICE_SEARCH_ATTRIBUTE_RESPONSE'
}

# Error codes carried by SDP_ErrorResponse
SDP_INVALID_SDP_VERSION_ERROR                       = 0x0001
SDP_INVALID_SERVICE_RECORD_HANDLE_ERROR             = 0x0002
SDP_INVALID_REQUEST_SYNTAX_ERROR                    = 0x0003
SDP_INVALID_PDU_SIZE_ERROR                          = 0x0004
SDP_INVALID_CONTINUATION_STATE_ERROR                = 0x0005
SDP_INSUFFICIENT_RESOURCES_TO_SATISFY_REQUEST_ERROR = 0x0006

# Error code -> symbolic name
SDP_ERROR_NAMES = {
    SDP_INVALID_SDP_VERSION_ERROR:                       'SDP_INVALID_SDP_VERSION_ERROR',
    SDP_INVALID_SERVICE_RECORD_HANDLE_ERROR:             'SDP_INVALID_SERVICE_RECORD_HANDLE_ERROR',
    SDP_INVALID_REQUEST_SYNTAX_ERROR:                    'SDP_INVALID_REQUEST_SYNTAX_ERROR',
    SDP_INVALID_PDU_SIZE_ERROR:                          'SDP_INVALID_PDU_SIZE_ERROR',
    SDP_INVALID_CONTINUATION_STATE_ERROR:                'SDP_INVALID_CONTINUATION_STATE_ERROR',
    SDP_INSUFFICIENT_RESOURCES_TO_SATISFY_REQUEST_ERROR: 'SDP_INSUFFICIENT_RESOURCES_TO_SATISFY_REQUEST_ERROR'
}

# Offsets added to a language base attribute ID to locate localized strings
SDP_SERVICE_NAME_ATTRIBUTE_ID_OFFSET        = 0x0000
SDP_SERVICE_DESCRIPTION_ATTRIBUTE_ID_OFFSET = 0x0001
SDP_PROVIDER_NAME_ATTRIBUTE_ID_OFFSET       = 0x0002

# Universal attribute IDs
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID               = 0x0000
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID               = 0x0001
SDP_SERVICE_RECORD_STATE_ATTRIBUTE_ID                = 0x0002
SDP_SERVICE_ID_ATTRIBUTE_ID                          = 0x0003
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID            = 0x0004
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID                   = 0x0005
SDP_LANGUAGE_BASE_ATTRIBUTE_ID_LIST_ATTRIBUTE_ID     = 0x0006
SDP_SERVICE_INFO_TIME_TO_LIVE_ATTRIBUTE_ID           = 0x0007
SDP_SERVICE_AVAILABILITY_ATTRIBUTE_ID                = 0x0008
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID   = 0x0009
SDP_DOCUMENTATION_URL_ATTRIBUTE_ID                   = 0x000A
SDP_CLIENT_EXECUTABLE_URL_ATTRIBUTE_ID               = 0x000B
SDP_ICON_URL_ATTRIBUTE_ID                            = 0x000C
SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID = 0x000D

# Attribute ID -> symbolic name
SDP_ATTRIBUTE_ID_NAMES = {
    SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID:               'SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID',
    SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID:               'SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID',
    SDP_SERVICE_RECORD_STATE_ATTRIBUTE_ID:                'SDP_SERVICE_RECORD_STATE_ATTRIBUTE_ID',
    SDP_SERVICE_ID_ATTRIBUTE_ID:                          'SDP_SERVICE_ID_ATTRIBUTE_ID',
    SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID:            'SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID',
    SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID:                   'SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID',
    SDP_LANGUAGE_BASE_ATTRIBUTE_ID_LIST_ATTRIBUTE_ID:     'SDP_LANGUAGE_BASE_ATTRIBUTE_ID_LIST_ATTRIBUTE_ID',
    SDP_SERVICE_INFO_TIME_TO_LIVE_ATTRIBUTE_ID:           'SDP_SERVICE_INFO_TIME_TO_LIVE_ATTRIBUTE_ID',
    SDP_SERVICE_AVAILABILITY_ATTRIBUTE_ID:                'SDP_SERVICE_AVAILABILITY_ATTRIBUTE_ID',
    SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID:   'SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID',
    SDP_DOCUMENTATION_URL_ATTRIBUTE_ID:                   'SDP_DOCUMENTATION_URL_ATTRIBUTE_ID',
    SDP_CLIENT_EXECUTABLE_URL_ATTRIBUTE_ID:               'SDP_CLIENT_EXECUTABLE_URL_ATTRIBUTE_ID',
    SDP_ICON_URL_ATTRIBUTE_ID:                            'SDP_ICON_URL_ATTRIBUTE_ID',
    SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID: 'SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID'
}

SDP_PUBLIC_BROWSE_ROOT = core.UUID.from_16_bits(0x1002, 'PublicBrowseRoot')

# Attribute ID range covering every attribute (0x0000 through 0xFFFF), for use
# in attribute ID lists. It is a (value, value_size) tuple so callers can encode
# it as a 4-byte unsigned integer rather than the default 2-byte attribute ID.
SDP_ALL_ATTRIBUTES_RANGE = (0x0000FFFF, 4)

# fmt: on
# pylint: enable=line-too-long
# pylint: disable=invalid-name
122
123
124# -----------------------------------------------------------------------------
class DataElement:
    '''
    An SDP data element: a value preceded by a one-byte descriptor whose upper
    5 bits hold the element type and whose lower 3 bits hold the size index.

    See Bluetooth spec @ Vol 3, Part B - 3 DATA REPRESENTATION
    '''

    NIL = 0
    UNSIGNED_INTEGER = 1
    SIGNED_INTEGER = 2
    UUID = 3
    TEXT_STRING = 4
    BOOLEAN = 5
    SEQUENCE = 6
    ALTERNATIVE = 7
    URL = 8

    TYPE_NAMES = {
        NIL: 'NIL',
        UNSIGNED_INTEGER: 'UNSIGNED_INTEGER',
        SIGNED_INTEGER: 'SIGNED_INTEGER',
        UUID: 'UUID',
        TEXT_STRING: 'TEXT_STRING',
        BOOLEAN: 'BOOLEAN',
        SEQUENCE: 'SEQUENCE',
        ALTERNATIVE: 'ALTERNATIVE',
        URL: 'URL',
    }

    # Integer widths (in bytes) expressible with size indices 0 through 4
    _INTEGER_SIZES = (1, 2, 4, 8, 16)
    # struct formats for the widths struct handles natively (all but 16 bytes)
    _UNSIGNED_FORMATS = {1: 'B', 2: '>H', 4: '>I', 8: '>Q'}
    _SIGNED_FORMATS = {1: 'b', 2: '>h', 4: '>i', 8: '>q'}

    # Factories building a DataElement from the raw value bytes of each type.
    # The integer factories also receive the value size so it is preserved.
    type_constructors = {
        NIL: lambda x: DataElement(DataElement.NIL, None),
        UNSIGNED_INTEGER: lambda x, y: DataElement(
            DataElement.UNSIGNED_INTEGER,
            DataElement.unsigned_integer_from_bytes(x),
            value_size=y,
        ),
        SIGNED_INTEGER: lambda x, y: DataElement(
            DataElement.SIGNED_INTEGER,
            DataElement.signed_integer_from_bytes(x),
            value_size=y,
        ),
        # UUIDs are big-endian on the wire; core.UUID bytes are reversed
        UUID: lambda x: DataElement(
            DataElement.UUID, core.UUID.from_bytes(bytes(reversed(x)))
        ),
        TEXT_STRING: lambda x: DataElement(DataElement.TEXT_STRING, x.decode('utf8')),
        BOOLEAN: lambda x: DataElement(DataElement.BOOLEAN, x[0] == 1),
        SEQUENCE: lambda x: DataElement(
            DataElement.SEQUENCE, DataElement.list_from_bytes(x)
        ),
        ALTERNATIVE: lambda x: DataElement(
            DataElement.ALTERNATIVE, DataElement.list_from_bytes(x)
        ),
        URL: lambda x: DataElement(DataElement.URL, x.decode('utf8')),
    }

    def __init__(self, element_type, value, value_size=None):
        '''
        Create an element of `element_type` holding `value`.

        `value_size` is the encoded width in bytes. It is required for
        UNSIGNED_INTEGER and SIGNED_INTEGER; ValueError is raised otherwise.
        '''
        self.type = element_type
        self.value = value
        self.value_size = value_size
        # Serialized form. It is set when parsing (so re-serialization is a
        # byte-for-byte replica) or on first serialization. It is NOT
        # invalidated if `value` is modified afterwards.
        self.bytes = None
        if element_type in (DataElement.UNSIGNED_INTEGER, DataElement.SIGNED_INTEGER):
            if value_size is None:
                raise ValueError('integer types must have a value size specified')

    @staticmethod
    def nil() -> DataElement:
        return DataElement(DataElement.NIL, None)

    @staticmethod
    def unsigned_integer(value: int, value_size: int) -> DataElement:
        return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size)

    @staticmethod
    def unsigned_integer_8(value: int) -> DataElement:
        return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size=1)

    @staticmethod
    def unsigned_integer_16(value: int) -> DataElement:
        return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size=2)

    @staticmethod
    def unsigned_integer_32(value: int) -> DataElement:
        return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size=4)

    @staticmethod
    def signed_integer(value: int, value_size: int) -> DataElement:
        return DataElement(DataElement.SIGNED_INTEGER, value, value_size)

    @staticmethod
    def signed_integer_8(value: int) -> DataElement:
        return DataElement(DataElement.SIGNED_INTEGER, value, value_size=1)

    @staticmethod
    def signed_integer_16(value: int) -> DataElement:
        return DataElement(DataElement.SIGNED_INTEGER, value, value_size=2)

    @staticmethod
    def signed_integer_32(value: int) -> DataElement:
        return DataElement(DataElement.SIGNED_INTEGER, value, value_size=4)

    @staticmethod
    def uuid(value: core.UUID) -> DataElement:
        return DataElement(DataElement.UUID, value)

    @staticmethod
    def text_string(value: str) -> DataElement:
        return DataElement(DataElement.TEXT_STRING, value)

    @staticmethod
    def boolean(value: bool) -> DataElement:
        return DataElement(DataElement.BOOLEAN, value)

    @staticmethod
    def sequence(value: List[DataElement]) -> DataElement:
        return DataElement(DataElement.SEQUENCE, value)

    @staticmethod
    def alternative(value: List[DataElement]) -> DataElement:
        return DataElement(DataElement.ALTERNATIVE, value)

    @staticmethod
    def url(value: str) -> DataElement:
        return DataElement(DataElement.URL, value)

    @staticmethod
    def unsigned_integer_from_bytes(data):
        '''
        Decode a big-endian unsigned integer of 1, 2, 4, 8 or 16 bytes.

        Raises ValueError for any other length.
        '''
        if len(data) not in DataElement._INTEGER_SIZES:
            raise ValueError(f'invalid integer length {len(data)}')
        return int.from_bytes(data, byteorder='big')

    @staticmethod
    def signed_integer_from_bytes(data):
        '''
        Decode a big-endian two's-complement integer of 1, 2, 4, 8 or 16 bytes.

        Raises ValueError for any other length.
        '''
        if len(data) not in DataElement._INTEGER_SIZES:
            raise ValueError(f'invalid integer length {len(data)}')
        return int.from_bytes(data, byteorder='big', signed=True)

    @staticmethod
    def list_from_bytes(data):
        '''Parse consecutive serialized data elements until `data` is exhausted.'''
        elements = []
        while data:
            element = DataElement.from_bytes(data)
            elements.append(element)
            data = data[len(bytes(element)) :]
        return elements

    @staticmethod
    def parse_from_bytes(data, offset):
        '''Parse one element at `offset`; return (offset after it, element).'''
        element = DataElement.from_bytes(data[offset:])
        return offset + len(bytes(element)), element

    @staticmethod
    def from_bytes(data):
        '''
        Parse the data element at the start of `data`.

        Unknown element types yield a DataElement whose value is the raw value
        bytes. The consumed bytes are cached on the result so re-serializing
        it produces an exact replica.
        '''
        element_type = data[0] >> 3
        size_index = data[0] & 7
        value_offset = 0
        if size_index == 0:
            if element_type == DataElement.NIL:
                value_size = 0
            else:
                value_size = 1
        elif size_index == 1:
            value_size = 2
        elif size_index == 2:
            value_size = 4
        elif size_index == 3:
            value_size = 8
        elif size_index == 4:
            value_size = 16
        elif size_index == 5:
            # Size held in the following 1 byte
            value_size = data[1]
            value_offset = 1
        elif size_index == 6:
            # Size held in the following 2 bytes
            value_size = struct.unpack('>H', data[1:3])[0]
            value_offset = 2
        else:  # size_index == 7, size held in the following 4 bytes
            value_size = struct.unpack('>I', data[1:5])[0]
            value_offset = 4

        value_data = data[1 + value_offset : 1 + value_offset + value_size]
        constructor = DataElement.type_constructors.get(element_type)
        if constructor:
            if element_type in (
                DataElement.UNSIGNED_INTEGER,
                DataElement.SIGNED_INTEGER,
            ):
                result = constructor(value_data, value_size)
            else:
                result = constructor(value_data)
        else:
            result = DataElement(element_type, value_data)
        result.bytes = data[
            : 1 + value_offset + value_size
        ]  # Keep a copy so we can re-serialize to an exact replica
        return result

    def to_bytes(self):
        return bytes(self)

    def _integer_to_bytes(self, signed):
        '''
        Encode `self.value` as a big-endian integer of `self.value_size` bytes.

        Raises ValueError if the size is not 1, 2, 4, 8 or 16.
        '''
        formats = DataElement._SIGNED_FORMATS if signed else DataElement._UNSIGNED_FORMATS
        if self.value_size in formats:
            return struct.pack(formats[self.value_size], self.value)
        if self.value_size == 16:
            # struct has no 128-bit format code
            return self.value.to_bytes(16, byteorder='big', signed=signed)
        raise ValueError('invalid value_size')

    def __bytes__(self):
        '''
        Serialize the element (descriptor byte, optional length, value).

        Raises ValueError for values that cannot be encoded, including
        unsupported element types.
        '''
        # Return early if we have a cache
        if self.bytes:
            return self.bytes

        if self.type == DataElement.NIL:
            data = b''
        elif self.type == DataElement.UNSIGNED_INTEGER:
            if self.value < 0:
                raise ValueError('UNSIGNED_INTEGER cannot be negative')
            data = self._integer_to_bytes(signed=False)
        elif self.type == DataElement.SIGNED_INTEGER:
            data = self._integer_to_bytes(signed=True)
        elif self.type == DataElement.UUID:
            data = bytes(reversed(bytes(self.value)))
        elif self.type in (DataElement.TEXT_STRING, DataElement.URL):
            data = self.value.encode('utf8')
        elif self.type == DataElement.BOOLEAN:
            data = bytes([1 if self.value else 0])
        elif self.type in (DataElement.SEQUENCE, DataElement.ALTERNATIVE):
            data = b''.join([bytes(element) for element in self.value])
        else:
            # No size descriptor rule exists for other types (previously this
            # fell through to an unbound `size_index`)
            raise ValueError(f'unsupported element type {self.type}')

        size = len(data)
        size_bytes = b''
        if self.type == DataElement.NIL:
            if size != 0:
                raise ValueError('NIL must be empty')
            size_index = 0
        elif self.type in (
            DataElement.UNSIGNED_INTEGER,
            DataElement.SIGNED_INTEGER,
            DataElement.UUID,
        ):
            # Fixed-size types: the size index encodes the width directly
            if size <= 1:
                size_index = 0
            elif size == 2:
                size_index = 1
            elif size == 4:
                size_index = 2
            elif size == 8:
                size_index = 3
            elif size == 16:
                size_index = 4
            else:
                raise ValueError('invalid data size')
        elif self.type in (
            DataElement.TEXT_STRING,
            DataElement.SEQUENCE,
            DataElement.ALTERNATIVE,
            DataElement.URL,
        ):
            # Variable-size types: an explicit 1, 2 or 4 byte length follows
            if size <= 0xFF:
                size_index = 5
                size_bytes = bytes([size])
            elif size <= 0xFFFF:
                size_index = 6
                size_bytes = struct.pack('>H', size)
            elif size <= 0xFFFFFFFF:
                size_index = 7
                size_bytes = struct.pack('>I', size)
            else:
                raise ValueError('invalid data size')
        else:  # BOOLEAN
            if size != 1:
                raise ValueError('boolean must be 1 byte')
            size_index = 0

        self.bytes = bytes([self.type << 3 | size_index]) + size_bytes + data
        return self.bytes

    def to_string(self, pretty=False, indentation=0):
        '''Render as TYPE(value); `pretty` puts container items on their own lines.'''
        prefix = '  ' * indentation
        type_name = name_or_number(self.TYPE_NAMES, self.type)
        if self.type == DataElement.NIL:
            value_string = ''
        elif self.type in (DataElement.SEQUENCE, DataElement.ALTERNATIVE):
            container_separator = '\n' if pretty else ''
            element_separator = '\n' if pretty else ','
            elements = [
                element.to_string(pretty, indentation + 1 if pretty else 0)
                for element in self.value
            ]
            value_string = (
                f'[{container_separator}'
                f'{element_separator.join(elements)}'
                f'{container_separator}{prefix}]'
            )
        elif self.type in (DataElement.UNSIGNED_INTEGER, DataElement.SIGNED_INTEGER):
            value_string = f'{self.value}#{self.value_size}'
        elif isinstance(self.value, DataElement):
            value_string = self.value.to_string(pretty, indentation)
        else:
            value_string = str(self.value)
        return f'{prefix}{type_name}({value_string})'

    def __str__(self):
        return self.to_string()
456
457
458# -----------------------------------------------------------------------------
class ServiceAttribute:
    '''
    One attribute of an SDP service record: an integer attribute ID paired
    with a DataElement value.
    '''

    def __init__(self, attribute_id: int, value: DataElement) -> None:
        self.id = attribute_id
        self.value = value

    @staticmethod
    def list_from_data_elements(elements):
        '''
        Build ServiceAttribute objects from a flat list of alternating
        attribute ID / attribute value data elements.

        Pairs whose ID element is not an UNSIGNED_INTEGER are skipped with a
        warning. A trailing element without a partner is also skipped with a
        warning.
        '''
        if len(elements) % 2:
            # Previously dropped silently; malformed peer data should be visible
            logger.warning('odd number of elements in attribute list, ignoring last one')

        attribute_list = []
        for attribute_id, attribute_value in zip(elements[::2], elements[1::2]):
            if attribute_id.type != DataElement.UNSIGNED_INTEGER:
                logger.warning('attribute ID element is not an integer')
                continue
            attribute_list.append(ServiceAttribute(attribute_id.value, attribute_value))

        return attribute_list

    @staticmethod
    def find_attribute_in_list(attribute_list, attribute_id):
        '''Return the value of the first attribute with `attribute_id`, or None.'''
        return next(
            (
                attribute.value
                for attribute in attribute_list
                if attribute.id == attribute_id
            ),
            None,
        )

    @staticmethod
    def id_name(id_code):
        '''Symbolic name of a universal attribute ID, or the number itself.'''
        return name_or_number(SDP_ATTRIBUTE_ID_NAMES, id_code)

    @staticmethod
    def is_uuid_in_value(uuid, value):
        '''
        Return True if `value` is a UUID element equal to `uuid`, or a
        SEQUENCE that contains one at any nesting depth.
        '''
        if value.type == DataElement.UUID:
            return value.value == uuid

        if value.type == DataElement.SEQUENCE:
            return any(
                ServiceAttribute.is_uuid_in_value(uuid, element)
                for element in value.value
            )

        return False

    def to_string(self, with_colors=False):
        '''Render as Attribute(id=...,value=...), optionally colorizing the ID.'''
        if with_colors:
            return (
                f'Attribute(id={color(self.id_name(self.id),"magenta")},'
                f'value={self.value})'
            )

        return f'Attribute(id={self.id_name(self.id)},value={self.value})'

    def __str__(self):
        return self.to_string()
516
517
518# -----------------------------------------------------------------------------
class SDP_PDU:
    '''
    Base class for SDP PDUs. Concrete PDU types are declared with the
    `SDP_PDU.subclass(fields)` decorator, which registers them by PDU ID.

    See Bluetooth spec @ Vol 3, Part B - 4.2 PROTOCOL DATA UNIT FORMAT
    '''

    sdp_pdu_classes: Dict[int, Type[SDP_PDU]] = {}
    name = None
    pdu_id = 0

    @staticmethod
    def from_bytes(pdu):
        '''
        Parse a raw PDU. Returns an instance of the subclass registered for
        its PDU ID, or a bare SDP_PDU (named from SDP_PDU_NAMES) otherwise.
        '''
        # Header: PDU ID (1 byte), transaction ID (2), parameter length (2)
        header = struct.unpack_from('>BHH', pdu, 0)
        pdu_id, transaction_id = header[0], header[1]

        pdu_class = SDP_PDU.sdp_pdu_classes.get(pdu_id)
        if pdu_class is not None:
            # Skip the subclass constructor: field values come from the bytes
            parsed = pdu_class.__new__(pdu_class)
            SDP_PDU.__init__(parsed, pdu, transaction_id)
            if hasattr(parsed, 'fields'):
                parsed.init_from_bytes(pdu, 5)
            return parsed

        generic = SDP_PDU(pdu)
        generic.name = SDP_PDU.pdu_name(pdu_id)
        generic.pdu_id = pdu_id
        generic.transaction_id = transaction_id
        return generic

    @staticmethod
    def parse_service_record_handle_list_preceded_by_count(data, offset):
        '''
        Field parser: read the 4-byte handles at `offset`, whose count is the
        2-byte big-endian value just before `offset`. Returns (new offset, handles).
        '''
        (count,) = struct.unpack_from('>H', data, offset - 2)
        end = offset + 4 * count
        handles = [
            struct.unpack_from('>I', data, position)[0]
            for position in range(offset, end, 4)
        ]
        return end, handles

    @staticmethod
    def parse_bytes_preceded_by_length(data, offset):
        '''
        Field parser: slice the bytes at `offset` whose length is the 2-byte
        big-endian value just before `offset`. Returns (new offset, bytes).
        '''
        (length,) = struct.unpack_from('>H', data, offset - 2)
        end = offset + length
        return end, data[offset:end]

    @staticmethod
    def error_name(error_code):
        return name_or_number(SDP_ERROR_NAMES, error_code)

    @staticmethod
    def pdu_name(code):
        return name_or_number(SDP_PDU_NAMES, code)

    @staticmethod
    def subclass(fields):
        '''
        Class decorator declaring a PDU type.

        The PDU name is derived from the class name (SDP_ErrorResponse becomes
        SDP_ERROR_RESPONSE) and its ID is looked up in SDP_PDU_NAMES. `fields`
        is attached to the class, and the class is registered for from_bytes().
        Raises KeyError if the derived name is not a known PDU name.
        '''

        def register(pdu_class):
            class_name = pdu_class.__name__
            # Prefix '_' to every uppercase letter after the first 5 characters
            # (so 'SDP_' and the first word's capital are left alone)
            snake_name = class_name[:5] + ''.join(
                f'_{char}' if char.isupper() else char for char in class_name[5:]
            )

            pdu_class.name = snake_name.upper()
            pdu_class.pdu_id = key_with_value(SDP_PDU_NAMES, pdu_class.name)
            if pdu_class.pdu_id is None:
                raise KeyError(f'PDU name {pdu_class.name} not found in SDP_PDU_NAMES')
            pdu_class.fields = fields

            SDP_PDU.sdp_pdu_classes[pdu_class.pdu_id] = pdu_class
            return pdu_class

        return register

    def __init__(self, pdu=None, transaction_id=0, **kwargs):
        '''
        Wrap existing `pdu` bytes, or (when `pdu` is None) serialize a new PDU
        from the field values given as keyword arguments.
        '''
        if kwargs and hasattr(self, 'fields'):
            HCI_Object.init_from_fields(self, self.fields, kwargs)
        if pdu is None:
            parameters = HCI_Object.dict_to_bytes(kwargs, self.fields)
            header = struct.pack('>BHH', self.pdu_id, transaction_id, len(parameters))
            pdu = header + parameters
        self.pdu = pdu
        self.transaction_id = transaction_id

    def init_from_bytes(self, pdu, offset):
        return HCI_Object.init_from_bytes(self, pdu, offset, self.fields)

    def to_bytes(self):
        return self.pdu

    def __bytes__(self):
        return self.to_bytes()

    def __str__(self):
        summary = f'{color(self.name, "blue")} [TID={self.transaction_id}]'
        fields = getattr(self, 'fields', None)
        if fields:
            return summary + ':\n' + HCI_Object.format_fields(self.__dict__, fields, '  ')
        if len(self.pdu) > 1:
            return summary + f': {self.pdu.hex()}'
        return summary
621
622
623# -----------------------------------------------------------------------------
@SDP_PDU.subclass(
    [
        # 2-byte error code, displayed via SDP_ERROR_NAMES
        ('error_code', {'size': 2, 'mapper': SDP_PDU.error_name}),
    ]
)
class SDP_ErrorResponse(SDP_PDU):
    '''
    Response reporting that a request could not be served.

    See Bluetooth spec @ Vol 3, Part B - 4.4.1 SDP_ErrorResponse PDU
    '''
629
630
631# -----------------------------------------------------------------------------
@SDP_PDU.subclass(
    [
        # Data element sequence of UUIDs to match
        ('service_search_pattern', DataElement.parse_from_bytes),
        ('maximum_service_record_count', '>2'),
        # Remaining bytes
        ('continuation_state', '*'),
    ]
)
class SDP_ServiceSearchRequest(SDP_PDU):
    '''
    Request for the handles of service records matching a UUID pattern.

    See Bluetooth spec @ Vol 3, Part B - 4.5.1 SDP_ServiceSearchRequest PDU
    '''
643
644
645# -----------------------------------------------------------------------------
@SDP_PDU.subclass(
    [
        ('total_service_record_count', '>2'),
        # Must directly precede the handle list: its parser reads this count
        ('current_service_record_count', '>2'),
        ('service_record_handle_list', SDP_PDU.parse_service_record_handle_list_preceded_by_count),
        ('continuation_state', '*'),
    ]
)
class SDP_ServiceSearchResponse(SDP_PDU):
    '''
    Response carrying (part of) the matching service record handles.

    See Bluetooth spec @ Vol 3, Part B - 4.5.2 SDP_ServiceSearchResponse PDU
    '''
661
662
663# -----------------------------------------------------------------------------
@SDP_PDU.subclass(
    [
        ('service_record_handle', '>4'),
        ('maximum_attribute_byte_count', '>2'),
        # Data element sequence of attribute IDs and/or ID ranges
        ('attribute_id_list', DataElement.parse_from_bytes),
        ('continuation_state', '*'),
    ]
)
class SDP_ServiceAttributeRequest(SDP_PDU):
    '''
    Request for selected attributes of one service record.

    See Bluetooth spec @ Vol 3, Part B - 4.6.1 SDP_ServiceAttributeRequest PDU
    '''
676
677
678# -----------------------------------------------------------------------------
@SDP_PDU.subclass(
    [
        # Must directly precede attribute_list: its parser reads this length
        ('attribute_list_byte_count', '>2'),
        ('attribute_list', SDP_PDU.parse_bytes_preceded_by_length),
        ('continuation_state', '*'),
    ]
)
class SDP_ServiceAttributeResponse(SDP_PDU):
    '''
    Response carrying (part of) the serialized attribute list of one record.

    See Bluetooth spec @ Vol 3, Part B - 4.6.2 SDP_ServiceAttributeResponse PDU
    '''
690
691
692# -----------------------------------------------------------------------------
@SDP_PDU.subclass(
    [
        # Data element sequence of UUIDs to match
        ('service_search_pattern', DataElement.parse_from_bytes),
        ('maximum_attribute_byte_count', '>2'),
        # Data element sequence of attribute IDs and/or ID ranges
        ('attribute_id_list', DataElement.parse_from_bytes),
        ('continuation_state', '*'),
    ]
)
class SDP_ServiceSearchAttributeRequest(SDP_PDU):
    '''
    Combined search: request selected attributes of every matching record.

    See Bluetooth spec @ Vol 3, Part B - 4.7.1 SDP_ServiceSearchAttributeRequest PDU
    '''
705
706
707# -----------------------------------------------------------------------------
@SDP_PDU.subclass(
    [
        # Must directly precede attribute_lists: its parser reads this length
        ('attribute_lists_byte_count', '>2'),
        ('attribute_lists', SDP_PDU.parse_bytes_preceded_by_length),
        ('continuation_state', '*'),
    ]
)
class SDP_ServiceSearchAttributeResponse(SDP_PDU):
    '''
    Response carrying (part of) the serialized attribute lists of all matches.

    See Bluetooth spec @ Vol 3, Part B - 4.7.2 SDP_ServiceSearchAttributeResponse PDU
    '''
719
720
721# -----------------------------------------------------------------------------
722class Client:
723    def __init__(self, device):
724        self.device = device
725        self.pending_request = None
726        self.channel = None
727
728    async def connect(self, connection):
729        result = await self.device.l2cap_channel_manager.connect(connection, SDP_PSM)
730        self.channel = result
731
732    async def disconnect(self):
733        if self.channel:
734            await self.channel.disconnect()
735            self.channel = None
736
737    async def search_services(self, uuids):
738        if self.pending_request is not None:
739            raise InvalidStateError('request already pending')
740
741        service_search_pattern = DataElement.sequence(
742            [DataElement.uuid(uuid) for uuid in uuids]
743        )
744
745        # Request and accumulate until there's no more continuation
746        service_record_handle_list = []
747        continuation_state = bytes([0])
748        watchdog = SDP_CONTINUATION_WATCHDOG
749        while watchdog > 0:
750            response_pdu = await self.channel.send_request(
751                SDP_ServiceSearchRequest(
752                    transaction_id=0,  # Transaction ID TODO: pick a real value
753                    service_search_pattern=service_search_pattern,
754                    maximum_service_record_count=0xFFFF,
755                    continuation_state=continuation_state,
756                )
757            )
758            response = SDP_PDU.from_bytes(response_pdu)
759            logger.debug(f'<<< Response: {response}')
760            service_record_handle_list += response.service_record_handle_list
761            continuation_state = response.continuation_state
762            if len(continuation_state) == 1 and continuation_state[0] == 0:
763                break
764            logger.debug(f'continuation: {continuation_state.hex()}')
765            watchdog -= 1
766
767        return service_record_handle_list
768
769    async def search_attributes(self, uuids, attribute_ids):
770        if self.pending_request is not None:
771            raise InvalidStateError('request already pending')
772
773        service_search_pattern = DataElement.sequence(
774            [DataElement.uuid(uuid) for uuid in uuids]
775        )
776        attribute_id_list = DataElement.sequence(
777            [
778                DataElement.unsigned_integer(
779                    attribute_id[0], value_size=attribute_id[1]
780                )
781                if isinstance(attribute_id, tuple)
782                else DataElement.unsigned_integer_16(attribute_id)
783                for attribute_id in attribute_ids
784            ]
785        )
786
787        # Request and accumulate until there's no more continuation
788        accumulator = b''
789        continuation_state = bytes([0])
790        watchdog = SDP_CONTINUATION_WATCHDOG
791        while watchdog > 0:
792            response_pdu = await self.channel.send_request(
793                SDP_ServiceSearchAttributeRequest(
794                    transaction_id=0,  # Transaction ID TODO: pick a real value
795                    service_search_pattern=service_search_pattern,
796                    maximum_attribute_byte_count=0xFFFF,
797                    attribute_id_list=attribute_id_list,
798                    continuation_state=continuation_state,
799                )
800            )
801            response = SDP_PDU.from_bytes(response_pdu)
802            logger.debug(f'<<< Response: {response}')
803            accumulator += response.attribute_lists
804            continuation_state = response.continuation_state
805            if len(continuation_state) == 1 and continuation_state[0] == 0:
806                break
807            logger.debug(f'continuation: {continuation_state.hex()}')
808            watchdog -= 1
809
810        # Parse the result into attribute lists
811        attribute_lists_sequences = DataElement.from_bytes(accumulator)
812        if attribute_lists_sequences.type != DataElement.SEQUENCE:
813            logger.warning('unexpected data type')
814            return []
815
816        return [
817            ServiceAttribute.list_from_data_elements(sequence.value)
818            for sequence in attribute_lists_sequences.value
819            if sequence.type == DataElement.SEQUENCE
820        ]
821
822    async def get_attributes(self, service_record_handle, attribute_ids):
823        if self.pending_request is not None:
824            raise InvalidStateError('request already pending')
825
826        attribute_id_list = DataElement.sequence(
827            [
828                DataElement.unsigned_integer(
829                    attribute_id[0], value_size=attribute_id[1]
830                )
831                if isinstance(attribute_id, tuple)
832                else DataElement.unsigned_integer_16(attribute_id)
833                for attribute_id in attribute_ids
834            ]
835        )
836
837        # Request and accumulate until there's no more continuation
838        accumulator = b''
839        continuation_state = bytes([0])
840        watchdog = SDP_CONTINUATION_WATCHDOG
841        while watchdog > 0:
842            response_pdu = await self.channel.send_request(
843                SDP_ServiceAttributeRequest(
844                    transaction_id=0,  # Transaction ID TODO: pick a real value
845                    service_record_handle=service_record_handle,
846                    maximum_attribute_byte_count=0xFFFF,
847                    attribute_id_list=attribute_id_list,
848                    continuation_state=continuation_state,
849                )
850            )
851            response = SDP_PDU.from_bytes(response_pdu)
852            logger.debug(f'<<< Response: {response}')
853            accumulator += response.attribute_list
854            continuation_state = response.continuation_state
855            if len(continuation_state) == 1 and continuation_state[0] == 0:
856                break
857            logger.debug(f'continuation: {continuation_state.hex()}')
858            watchdog -= 1
859
860        # Parse the result into a list of attributes
861        attribute_list_sequence = DataElement.from_bytes(accumulator)
862        if attribute_list_sequence.type != DataElement.SEQUENCE:
863            logger.warning('unexpected data type')
864            return []
865
866        return ServiceAttribute.list_from_data_elements(attribute_list_sequence.value)
867
868
869# -----------------------------------------------------------------------------
class Server:
    """SDP server answering requests received on an L2CAP channel.

    Service records live in ``service_records``, keyed by service record handle.
    When a response exceeds the limit requested by the client, the unsent
    remainder is kept in ``current_response`` and delivered through the SDP
    continuation mechanism.
    """

    # Continuation state sent while part of a response is still pending
    # (``bytes([0])`` is sent once nothing is left).
    CONTINUATION_STATE = bytes([0x01, 0x43])

    def __init__(self, device):
        self.device = device
        self.service_records = {}  # Service records maps, by record handle
        self.channel = None
        # Pending partial response, or None when nothing is pending:
        # - a (total_count, remaining_handles) tuple for a Service Search
        # - the unsent serialized bytes for the two attribute requests
        self.current_response = None

    def register(self, l2cap_channel_manager):
        """Register this server for incoming connections on the SDP PSM."""
        l2cap_channel_manager.register_server(SDP_PSM, self.on_connection)

    def send_response(self, response):
        """Log and send an SDP response PDU on the current channel."""
        logger.debug(f'{color(">>> Sending SDP Response", "blue")}: {response}')
        self.channel.send_pdu(response)

    def match_services(self, search_pattern):
        """Return the service records matching a service search pattern.

        A record matches when every UUID of the pattern is found in at least one
        of its attribute values (NOTE: the value search recurses into sequences).
        An empty pattern matches nothing.

        Returns:
            A dict mapping record handles to the matching service records.
        """
        uuids = search_pattern.value
        if not uuids:
            return {}

        return {
            handle: service
            for handle, service in self.service_records.items()
            if all(
                any(
                    ServiceAttribute.is_uuid_in_value(uuid.value, attribute.value)
                    for attribute in service
                )
                for uuid in uuids
            )
        }

    def on_connection(self, channel):
        """Adopt a newly connected channel and route its PDUs to on_pdu."""
        self.channel = channel
        self.channel.sink = self.on_pdu

    def on_pdu(self, pdu):
        """Parse an incoming SDP request and dispatch it to its handler.

        Requests go to the ``on_<lowercased PDU name>`` method. A PDU that cannot
        be parsed, or that has no handler, is answered with
        SDP_INVALID_REQUEST_SYNTAX_ERROR. An exception raised by a handler is
        answered with SDP_INSUFFICIENT_RESOURCES_TO_SATISFY_REQUEST_ERROR.
        """
        try:
            sdp_pdu = SDP_PDU.from_bytes(pdu)
        except Exception as error:
            logger.warning(color(f'failed to parse SDP Request PDU: {error}', 'red'))
            self.send_response(
                SDP_ErrorResponse(
                    transaction_id=0, error_code=SDP_INVALID_REQUEST_SYNTAX_ERROR
                )
            )
            # There is no request to dispatch (sdp_pdu is unbound here)
            return

        logger.debug(f'{color("<<< Received SDP Request", "green")}: {sdp_pdu}')

        # Find the handler method
        handler_name = f'on_{sdp_pdu.name.lower()}'
        handler = getattr(self, handler_name, None)
        if handler:
            try:
                handler(sdp_pdu)
            except Exception as error:
                logger.warning(f'{color("!!! Exception in handler:", "red")} {error}')
                self.send_response(
                    SDP_ErrorResponse(
                        transaction_id=sdp_pdu.transaction_id,
                        error_code=SDP_INSUFFICIENT_RESOURCES_TO_SATISFY_REQUEST_ERROR,
                    )
                )
        else:
            logger.error(color('SDP Request not handled???', 'red'))
            self.send_response(
                SDP_ErrorResponse(
                    transaction_id=sdp_pdu.transaction_id,
                    error_code=SDP_INVALID_REQUEST_SYNTAX_ERROR,
                )
            )

    def get_next_response_payload(self, maximum_size):
        """Take the next chunk of the pending serialized response.

        Returns:
            A ``(payload, continuation_state)`` tuple where payload holds at most
            ``maximum_size`` bytes. If the whole remainder fits, the pending
            response is cleared and the continuation state is ``bytes([0])``.
            Otherwise the rest is kept for the next continuation request and
            ``CONTINUATION_STATE`` is returned.
        """
        if len(self.current_response) > maximum_size:
            payload = self.current_response[:maximum_size]
            continuation_state = Server.CONTINUATION_STATE
            self.current_response = self.current_response[maximum_size:]
        else:
            payload = self.current_response
            continuation_state = bytes([0])
            self.current_response = None

        return (payload, continuation_state)

    @staticmethod
    def get_service_attributes(service, attribute_ids):
        """Select the attributes of a service matching a list of attribute IDs.

        Each element of ``attribute_ids`` is an unsigned integer data element.
        A value of size 4 is an inclusive ID range, with the start in the upper
        16 bits and the end in the lower 16 bits. Any other size selects a
        single attribute ID.

        Returns:
            A DataElement sequence of alternating attribute ID / attribute value
            elements, sorted by attribute ID. An attribute covered by several
            requested IDs or overlapping ranges is included only once.
        """
        id_ranges = []
        for attribute_id in attribute_ids:
            if attribute_id.value_size == 4:
                # Attribute ID range
                id_ranges.append(
                    (attribute_id.value >> 16, attribute_id.value & 0xFFFF)
                )
            else:
                id_ranges.append((attribute_id.value, attribute_id.value))

        # Visit each attribute once so overlapping requests cannot duplicate it,
        # then sort the matches by attribute id
        attributes = sorted(
            (
                attribute
                for attribute in service
                if any(start <= attribute.id <= end for start, end in id_ranges)
            ),
            key=lambda attribute: attribute.id,
        )
        attribute_list = DataElement.sequence([])
        for attribute in attributes:
            attribute_list.value.append(DataElement.unsigned_integer_16(attribute.id))
            attribute_list.value.append(attribute.value)

        return attribute_list

    def on_sdp_service_search_request(self, request):
        """Answer an SDP Service Search Request (or its continuation)."""
        # Check if this is a continuation
        if len(request.continuation_state) > 1:
            # Only valid while a service search response is still pending
            if not isinstance(self.current_response, tuple):
                self.send_response(
                    SDP_ErrorResponse(
                        transaction_id=request.transaction_id,
                        error_code=SDP_INVALID_CONTINUATION_STATE_ERROR,
                    )
                )
                return
        else:
            # Cleanup any partial response leftover
            self.current_response = None

            # Find the matching services
            matching_services = self.match_services(request.service_search_pattern)
            service_record_handles = list(matching_services.keys())

            # Only return up to the maximum requested
            service_record_handles_subset = service_record_handles[
                : request.maximum_service_record_count
            ]

            # Serialize to a byte array, and remember the total count
            logger.debug(f'Service Record Handles: {service_record_handles}')
            self.current_response = (
                len(service_record_handles),
                service_record_handles_subset,
            )

        # Respond, keeping any unsent handles for later
        total_count, pending_handles = self.current_response
        service_record_handles = pending_handles[: request.maximum_service_record_count]
        remaining_handles = pending_handles[request.maximum_service_record_count :]
        if remaining_handles:
            self.current_response = (total_count, remaining_handles)
            continuation_state = Server.CONTINUATION_STATE
        else:
            # Everything has been sent: no continuation can follow
            self.current_response = None
            continuation_state = bytes([0])

        service_record_handle_list = b''.join(
            [struct.pack('>I', handle) for handle in service_record_handles]
        )
        self.send_response(
            SDP_ServiceSearchResponse(
                transaction_id=request.transaction_id,
                total_service_record_count=total_count,
                current_service_record_count=len(service_record_handles),
                service_record_handle_list=service_record_handle_list,
                continuation_state=continuation_state,
            )
        )

    def on_sdp_service_attribute_request(self, request):
        """Answer an SDP Service Attribute Request (or its continuation)."""
        # Check if this is a continuation
        if len(request.continuation_state) > 1:
            # Only valid while a serialized attribute response is pending
            if not isinstance(self.current_response, bytes):
                self.send_response(
                    SDP_ErrorResponse(
                        transaction_id=request.transaction_id,
                        error_code=SDP_INVALID_CONTINUATION_STATE_ERROR,
                    )
                )
                return
        else:
            # Cleanup any partial response leftover
            self.current_response = None

            # Check that the service exists
            service = self.service_records.get(request.service_record_handle)
            if service is None:
                self.send_response(
                    SDP_ErrorResponse(
                        transaction_id=request.transaction_id,
                        error_code=SDP_INVALID_SERVICE_RECORD_HANDLE_ERROR,
                    )
                )
                return

            # Get the attributes for the service
            attribute_list = Server.get_service_attributes(
                service, request.attribute_id_list.value
            )

            # Serialize to a byte array
            logger.debug(f'Attributes: {attribute_list}')
            self.current_response = bytes(attribute_list)

        # Respond, keeping any pending chunks for later
        attribute_list, continuation_state = self.get_next_response_payload(
            request.maximum_attribute_byte_count
        )
        self.send_response(
            SDP_ServiceAttributeResponse(
                transaction_id=request.transaction_id,
                attribute_list_byte_count=len(attribute_list),
                attribute_list=attribute_list,
                continuation_state=continuation_state,
            )
        )

    def on_sdp_service_search_attribute_request(self, request):
        """Answer an SDP Service Search Attribute Request (or its continuation)."""
        # Check if this is a continuation
        if len(request.continuation_state) > 1:
            # Only valid while a serialized attribute response is pending
            if not isinstance(self.current_response, bytes):
                self.send_response(
                    SDP_ErrorResponse(
                        transaction_id=request.transaction_id,
                        error_code=SDP_INVALID_CONTINUATION_STATE_ERROR,
                    )
                )
                # Without a pending response there is nothing to continue
                return
        else:
            # Cleanup any partial response leftover
            self.current_response = None

            # Find the matching services
            matching_services = self.match_services(
                request.service_search_pattern
            ).values()

            # Filter the required attributes
            attribute_lists = DataElement.sequence([])
            for service in matching_services:
                attribute_list = Server.get_service_attributes(
                    service, request.attribute_id_list.value
                )
                if attribute_list.value:
                    attribute_lists.value.append(attribute_list)

            # Serialize to a byte array
            logger.debug(f'Search response: {attribute_lists}')
            self.current_response = bytes(attribute_lists)

        # Respond, keeping any pending chunks for later
        attribute_lists, continuation_state = self.get_next_response_payload(
            request.maximum_attribute_byte_count
        )
        self.send_response(
            SDP_ServiceSearchAttributeResponse(
                transaction_id=request.transaction_id,
                attribute_lists_byte_count=len(attribute_lists),
                attribute_lists=attribute_lists,
                continuation_state=continuation_state,
            )
        )
1128