• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
#!/usr/bin/python3
#
# Copyright (c) 2019 Collabora, Ltd.
#
# SPDX-License-Identifier: Apache-2.0
#
# Author(s):    Ryan Pavlik <ryan.pavlik@collabora.com>
#
# Purpose:      This script checks some "business logic" in the XML registry.

11import re
12import sys
13from pathlib import Path
14
15from check_spec_links import VulkanEntityDatabase as OrigEntityDatabase
16from reg import Registry
17from spec_tools.consistency_tools import XMLChecker
18from spec_tools.util import findNamedElem, getElemName, getElemType
19from vkconventions import VulkanConventions as APIConventions
20
# Extensions whose enumerant names do not follow the usual convention of
# upper-casing the extension name. Maps extension name -> the prefix that
# its *_SPEC_VERSION / *_EXTENSION_NAME enums actually use.
EXTENSION_ENUM_NAME_SPELLING_CHANGE = dict(
    VK_EXT_swapchain_colorspace='VK_EXT_SWAPCHAIN_COLOR_SPACE',
)

# Extensions whose names end in digits that are *not* a version number
# (codec names, integer widths, "win32", ...). For these the trailing digits
# must not be split off as a separate word when deriving enum names.
EXTENSION_NAME_VERSION_EXCEPTIONS = (
    'VK_AMD_gpu_shader_int16',
    'VK_EXT_index_type_uint8',
    'VK_EXT_shader_image_atomic_int64',
    'VK_EXT_video_decode_h264',
    'VK_EXT_video_decode_h265',
    'VK_EXT_video_encode_h264',
    'VK_EXT_video_encode_h265',
    'VK_KHR_external_fence_win32',
    'VK_KHR_external_memory_win32',
    'VK_KHR_external_semaphore_win32',
    'VK_KHR_shader_atomic_int64',
    'VK_KHR_shader_float16_int8',
    'VK_KHR_spirv_1_4',
    'VK_NV_external_memory_win32',
    'VK_RESERVED_do_not_use_146',
    'VK_RESERVED_do_not_use_94',
)

# Allowed exceptions to the pointer-parameter "p" prefix naming rule.
# Only membership is tested; each key is (entity name, type, name) and
# every value is None.
CHECK_PARAM_POINTER_NAME_EXCEPTIONS = dict.fromkeys([
    ('vkGetDrmDisplayEXT', 'VkDisplayKHR', 'display'),
])

# Structures whose pNext member may lack optional="true" (downgraded to a warning).
CHECK_MEMBER_PNEXT_OPTIONAL_EXCEPTIONS = ('VkVideoEncodeInfoKHR',)

def get_extension_commands(reg):
    """Return the set of command names required by any extension.

    reg -- registry object whose ``extensions`` attribute is an iterable of
           <extension> XML elements.
    Only <command> elements directly under a <require> and carrying a
    ``name`` attribute are collected.
    """
    return {
        command.get("name")
        for extension in reg.extensions
        for command in extension.findall("./require/command[@name]")
    }


def get_enum_value_names(reg, enum_type):
    """Return the set of enumerant names defined in group *enum_type*.

    reg       -- registry object with a ``groupdict`` mapping group names to
                 info objects whose ``elem`` is the <enums> element.
    enum_type -- the group name, e.g. 'VkResult'.
    Raises KeyError if *enum_type* is not in ``reg.groupdict``.
    """
    group_elem = reg.groupdict[enum_type].elem
    return {value.get("name") for value in group_elem.findall("./enum[@name]")}


# Plain string constants.
DESTROY_PREFIX = "vkDestroy"
# Name of the enum type used by structure-type members.
TYPEENUM = "VkStructureType"

# With fullmatch(): an extension name ending in a (possible) version number,
# e.g. "VK_KHR_maintenance2" -> base "VK_KHR_maintenance", version "2".
# The base must end in a letter, so names like "..._1_4" do not match.
EXTNAME_RE = re.compile(r'(?P<base>(\w+[A-Za-z]))(?P<version>\d+)')

# Directory that contains this script's directory; appendices/ and xml/
# are resolved relative to it.
SPECIFICATION_DIR = Path(__file__).parent.parent

# A revision-history line such as "  * Revision 3, 2019-05-01 ...",
# capturing the positive revision number as group "num".
REVISION_RE = re.compile(r' *[*] Revision (?P<num>[1-9][0-9]*),.*')


def get_extension_source(extname):
    """Return, as a string, the path of the appendix text file for *extname*
    (SPECIFICATION_DIR/appendices/<extname>.txt)."""
    appendix_path = SPECIFICATION_DIR.joinpath('appendices', '{}.txt'.format(extname))
    return str(appendix_path)


class EntityDatabase(OrigEntityDatabase):
    """Entity database that excludes no extensions by support status.

    When lxml is available, it also loads SPECIFICATION_DIR/xml/vk.xml
    through lxml.
    """

    # Override base class method to not exclude 'disabled' extensions
    def getExclusionSet(self):
        """Return a set of "support=" attribute strings that should not be included in the database.

        Called only during construction. Always empty here, so nothing is
        excluded."""
        return set()

    def makeRegistry(self):
        """Build the Registry, parsing vk.xml with lxml if it can be imported.

        Without lxml, defer entirely to the base class implementation."""
        try:
            import lxml.etree as etree
        except ImportError:
            return super().makeRegistry()

        vk_xml_path = str(SPECIFICATION_DIR / 'xml/vk.xml')
        registry = Registry()
        registry.filename = vk_xml_path
        registry.loadElementTree(etree.parse(vk_xml_path))
        return registry


class Checker(XMLChecker):
    """Vulkan-specific consistency checks on the XML registry.

    Adds checks on top of the generic XMLChecker for:
    - command return codes;
    - pointer-parameter naming;
    - pNext ``optional`` attributes;
    - structure-type values and bitmask bit-type references;
    - extension numbering, naming and versioning.
    """

    def __init__(self):
        manual_types_to_codes = {
            # These are hard-coded "manual" return codes:
            # the codes of the value (string, list, or tuple)
            # are available for a command if-and-only-if
            # the key type is passed as an input.
            "VkFormat": "VK_ERROR_FORMAT_NOT_SUPPORTED"
        }
        forward_only = {
            # Like the above, but these are only valid in the
            # "type implies return code" direction
        }
        reverse_only = {
            # Like the above, but these are only valid in the
            # "return code implies type or its descendant" direction
            # "XrDuration": "XR_TIMEOUT_EXPIRED"
        }
        # Some return codes are related in that only one of a set
        # may be returned by a command
        # (eg. XR_ERROR_SESSION_RUNNING and XR_ERROR_SESSION_NOT_RUNNING).
        # None are configured here, so this is an empty tuple.
        self.exclusive_return_code_sets = tuple(
            # set(("XR_ERROR_SESSION_NOT_RUNNING", "XR_ERROR_SESSION_RUNNING")),
        )
        # Map of extension number -> [ list of extension names ].
        # Filled in by check_extension; this is used to report collisions.
        self.extension_number_reservations = {}

        conventions = APIConventions()
        db = EntityDatabase()

        # Lookup sets derived once from the registry, used by the checks below.
        self.extension_cmds = get_extension_commands(db.registry)
        self.return_codes = get_enum_value_names(db.registry, 'VkResult')
        self.structure_types = get_enum_value_names(db.registry, TYPEENUM)

        # Dict of entity name to a list of messages to suppress. (Exclude any context data and "Warning:"/"Error:")
        # Keys are entity names, values are tuples or lists of message text to suppress.
        suppressions = {}

        # Initialize superclass
        super().__init__(entity_db=db, conventions=conventions,
                         manual_types_to_codes=manual_types_to_codes,
                         forward_only_types_to_codes=forward_only,
                         reverse_only_types_to_codes=reverse_only,
                         suppressions=suppressions)

    def check_command_return_codes_basic(self, name, info,
                                         successcodes, errorcodes):
        """Check a command's return codes for consistency.

        Called on every command.

        name         -- the command name.
        info         -- registry info object whose ``elem`` holds the <param>s.
        successcodes -- set-like collection of success-code names.
        errorcodes   -- set-like collection of error-code names.
        """
        # Check that all extension commands can return the code associated
        # with trying to use an extension that wasn't enabled.
        # if name in self.extension_cmds and UNSUPPORTED not in errorcodes:
        #     self.record_error("Missing expected return code",
        #                       UNSUPPORTED,
        #                       "implied due to being an extension command")

        codes = successcodes.union(errorcodes)

        # Check that all return codes are recognized.
        unrecognized = codes - self.return_codes
        if unrecognized:
            self.record_error("Unrecognized return code(s):",
                              unrecognized)

        params = [(getElemName(elt), elt) for elt in info.elem.findall('param')]

        def is_count_output(param_name, elt):
            # Must end with Count or Size,
            # not be const,
            # and be a pointer (detected by naming convention).
            # A leading "const" qualifier is part of the <param> element's own
            # text, before its <type> child. elt.tail is the text *after*
            # </param>, so the previous tail-based test never detected const.
            leading_text = elt.text or ''
            return (param_name.endswith(('Count', 'Size'))
                    and 'const' not in leading_text
                    and param_name.startswith('p'))

        count_params = [param_name
                        for param_name, elt in params
                        if is_count_output(param_name, elt)]
        if count_params:
            # Report instead of asserting. One unusual command should not
            # abort the whole run, and asserts are stripped under "python -O".
            if len(count_params) > 1:
                self.record_error(
                    "Apparent enumeration of an array with more than one count/size output parameter:",
                    ', '.join(count_params))
            if 'VK_INCOMPLETE' not in successcodes:
                self.record_error(
                    "Apparent enumeration of an array without VK_INCOMPLETE in successcodes.")

        elif 'VK_INCOMPLETE' in successcodes:
            self.record_error(
                "VK_INCOMPLETE in successcodes of command that is apparently not an array enumeration.")

    def check_param(self, param):
        """Check a member of a struct or a param of a function.

        Called from check_params.

        For API types, the name must start with one 'p' per '*' following
        the <type> child. A pNext member must carry optional="true".
        Listed exceptions are downgraded to warnings.
        """
        super().check_param(param)

        if not self.is_api_type(param):
            return

        param_text = "".join(param.itertext())
        param_name = getElemName(param)
        type_elem = param.find('type')

        # Make sure the number of leading "p" matches the pointer count.
        # The '*'s follow the <type> child, i.e. they are in its tail text.
        type_tail = type_elem.tail
        pointer_count = type_tail.count('*') if type_tail else 0
        if pointer_count:
            prefix = 'p' * pointer_count
            if not param_name.startswith(prefix):
                param_type = type_elem.text
                message = "Apparently incorrect pointer-related name prefix for {} - expected it to start with '{}'".format(
                    param_text, prefix)
                if (self.entity, param_type, param_name) in CHECK_PARAM_POINTER_NAME_EXCEPTIONS:
                    self.record_warning('(Allowed exception)', message, elem=param)
                else:
                    self.record_error(message, elem=param)

        # Make sure pNext members have optional="true" attributes.
        # NOTE(review): check_type performs an equivalent check on a struct's
        # pNext member. If the base class also routes struct members through
        # check_param, a missing attribute is reported twice; confirm.
        if param_name == self.conventions.nextpointer_member_name:
            if param.get('optional') != 'true':
                message = '{}.pNext member is missing \'optional="true"\' attribute'.format(self.entity)
                if self.entity in CHECK_MEMBER_PNEXT_OPTIONAL_EXCEPTIONS:
                    self.record_warning('(Allowed exception)', message, elem=param)
                else:
                    self.record_error(message, elem=param)

    def check_type(self, name, info, category):
        """Check a type's XML data for consistency.

        Called from check.

        Structs with a TYPEENUM member are checked for two things:
        - the member is unique and its ``values`` is a known constant;
        - any pNext member is marked optional="true".

        Bitmask types with 'Flags' in the name must name the matching
        'FlagBits' type in ``requires`` or ``bitvalues``.
        """
        elem = info.elem
        type_elts = [elt
                     for elt in elem.findall("member")
                     if getElemType(elt) == TYPEENUM]
        if category == 'struct' and type_elts:
            if len(type_elts) > 1:
                self.record_error(
                    "Have more than one member of type", TYPEENUM)
            else:
                type_elt = type_elts[0]
                val = type_elt.get('values')
                if val and val not in self.structure_types:
                    self.record_error("Unknown structure type constant", val)

            # Check the pointer chain member, if present.
            next_name = self.conventions.nextpointer_member_name
            next_member = findNamedElem(elem.findall('member'), next_name)
            if next_member is not None:
                # Ensure that the 'optional' attribute is set to 'true'
                if next_member.get('optional') != 'true':
                    message = '{}.{} member is missing \'optional="true"\' attribute'.format(name, next_name)
                    if name in CHECK_MEMBER_PNEXT_OPTIONAL_EXCEPTIONS:
                        self.record_warning('(Allowed exception)', message)
                    else:
                        self.record_error(message)

        elif category == "bitmask":
            if 'Flags' in name:
                expected_bits = name.replace('Flags', 'FlagBits')
                # The registry names a bitmask's bit type in "requires"
                # (32-bit flags) or "bitvalues" (64-bit flags). There is no
                # "require" attribute, so the old lookup never matched.
                for attr in ('requires', 'bitvalues'):
                    actual = elem.get(attr)
                    if actual is not None and actual != expected_bits:
                        self.record_error("Unexpected {} attribute value:".format(attr),
                                          "got", actual,
                                          "but expected", expected_bits)
        super().check_type(name, info, category)

    @staticmethod
    def _extension_enum_prefix(name):
        """Return (enum_prefix, is_versioned) for extension *name*.

        An explicit spelling change wins. Otherwise a name that looks
        versioned (and is not a listed exception) gets its trailing digits
        split off as a separate word. Anything else is simply upper-cased.
        """
        if name in EXTENSION_ENUM_NAME_SPELLING_CHANGE:
            return EXTENSION_ENUM_NAME_SPELLING_CHANGE[name], False
        matches = EXTNAME_RE.fullmatch(name)
        if matches is None or name in EXTENSION_NAME_VERSION_EXCEPTIONS:
            # The usual case: either a name that doesn't look versioned, or
            # one that does but is on the exception list.
            return name.upper(), False
        # A versioned extension name: treat the version number as a separate word.
        return matches.group('base').upper() + '_' + matches.group('version'), True

    def _check_revision_history(self, name, version_elem):
        """Compare the SPEC_VERSION value in the XML with the spec text.

        The spec-text version is the highest "* Revision N," line in the
        extension's appendix. A mismatch is recorded as an error; missing
        history is recorded as a warning; a missing appendix file is
        recorded as an error instead of aborting the run.
        """
        fn = get_extension_source(name)
        revisions = []
        try:
            with open(fn, 'r', encoding='utf-8') as fp:
                for line in fp:
                    match = REVISION_RE.match(line.rstrip())
                    if match:
                        revisions.append(int(match.group('num')))
        except FileNotFoundError:
            self.record_error("Cannot find spec text for extension", name,
                              "- expected it at", fn)
            return

        ver_from_xml = version_elem.get('value')
        if revisions:
            ver_from_text = str(max(revisions))
            if ver_from_xml != ver_from_text:
                self.record_error("Version enum mismatch: spec text indicates", ver_from_text,
                                  "but XML says", ver_from_xml)
        elif ver_from_xml == '1':
            self.record_warning(
                "Cannot find version history in spec text - make sure it has lines starting exactly like '* Revision 1, ....'",
                filename=fn)
        else:
            self.record_warning("Cannot find version history in spec text, but XML reports a non-1 version number", ver_from_xml,
                                " - make sure the spec text has lines starting exactly like '* Revision 1, ....'",
                                filename=fn)

    def check_extension(self, name, info):
        """Check an extension's XML data for consistency.

        Called from check.

        Checks performed:
        - the extension number is unique;
        - the <name>_SPEC_VERSION and <name>_EXTENSION_NAME enums exist;
        - the extension-name enum's value is the quoted extension name;
        - for extensions supported by this API, the version agrees with the
          appendix's revision history.
        """
        elem = info.elem
        enums = elem.findall('./require/enum[@name]')

        # Look for other extensions using that number, then keep track of
        # this extension number reservation.
        ext_number = elem.get('number')
        reserved_by = self.extension_number_reservations.setdefault(ext_number, [])
        if reserved_by:
            self.record_error('Extension number {} has more than one reservation: {}, {}'.format(
                ext_number, name, ', '.join(reserved_by)))
        reserved_by.append(name)

        ext_enum_name, ext_versioned_name = self._extension_enum_prefix(name)

        # Look for the expected SPEC_VERSION token name
        version_name = "{}_SPEC_VERSION".format(ext_enum_name)
        version_elem = findNamedElem(enums, version_name)

        if version_elem is None:
            # Did not find a SPEC_VERSION enum matching the extension name
            if ext_versioned_name:
                suffix = '\n\
    Make sure that trailing version numbers in extension names are treated\n\
    as separate words in extension enumerant names. If this is an extension\n\
    whose name ends in a number which is not a version, such as "...h264"\n\
    or "...int16", add it to EXTENSION_NAME_VERSION_EXCEPTIONS in\n\
    scripts/xml_consistency.py.'
            else:
                suffix = ''
            self.record_error('Missing version enum {}{}'.format(version_name, suffix))
        elif self.conventions.xml_api_name in (elem.get('supported') or '').split(','):
            # "supported" may be a comma-separated API list (e.g.
            # "vulkan,vulkansc"). Only extensions supported by this API are
            # checked against the spec text; unsupported / disabled ones are
            # skipped.
            self._check_revision_history(name, version_elem)

        name_define = "{}_EXTENSION_NAME".format(ext_enum_name)
        name_elem = findNamedElem(enums, name_define)
        if name_elem is None:
            self.record_error("Missing name enum", name_define)
        else:
            # Note: etree handles the XML entities here and turns &quot; back into "
            expected_name = '"{}"'.format(name)
            name_val = name_elem.get('value')
            if name_val != expected_name:
                self.record_error("Incorrect name enum: expected", expected_name,
                                  "got", name_val)

        # NOTE(review): this passes the <extension> Element, whereas
        # check_type forwards `info`. Confirm against the signature of
        # XMLChecker.check_extension.
        super().check_extension(name, elem)


if __name__ == "__main__":
    # Run every check, then exit with status 1 if any check failed.
    checker = Checker()
    checker.check()
    if checker.fail:
        sys.exit(1)