• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/python3
2#
3# Copyright (c) 2019 Collabora, Ltd.
4#
5# SPDX-License-Identifier: Apache-2.0
6#
7# Author(s):    Ryan Pavlik <ryan.pavlik@collabora.com>
8#
9# Purpose:      This script checks some "business logic" in the XML registry.
10
11import re
12import sys
13from pathlib import Path
14
15from check_spec_links import VulkanEntityDatabase as OrigEntityDatabase
16from reg import Registry
17from spec_tools.consistency_tools import XMLChecker
18from spec_tools.util import findNamedElem, getElemName, getElemType
19from apiconventions import APIConventions
20
# These are extensions which do not follow the usual naming conventions,
# specifying the alternate convention they follow.
# Maps extension name -> the prefix used for that extension's enumerant
# names (the part in front of _SPEC_VERSION / _EXTENSION_NAME) in place of
# the upper-cased extension name.
EXTENSION_ENUM_NAME_SPELLING_CHANGE = {
    'VK_EXT_swapchain_colorspace': 'VK_EXT_SWAPCHAIN_COLOR_SPACE',
}

# These are extensions whose names *look* like they end in version numbers,
# but do not.
# A name listed here is simply upper-cased when deriving enumerant names,
# instead of having its trailing digits split off as a separate word.
EXTENSION_NAME_VERSION_EXCEPTIONS = (
    'VK_AMD_gpu_shader_int16',
    'VK_EXT_index_type_uint8',
    'VK_EXT_shader_image_atomic_int64',
    'VK_EXT_video_decode_h264',
    'VK_EXT_video_decode_h265',
    'VK_EXT_video_encode_h264',
    'VK_EXT_video_encode_h265',
    'VK_KHR_external_fence_win32',
    'VK_KHR_external_memory_win32',
    'VK_KHR_external_semaphore_win32',
    'VK_KHR_shader_atomic_int64',
    'VK_KHR_shader_float16_int8',
    'VK_KHR_spirv_1_4',
    'VK_NV_external_memory_win32',
    'VK_RESERVED_do_not_use_146',
    'VK_RESERVED_do_not_use_94',
)

# Exceptions to pointer parameter naming rules
# Keyed by (entity name, type, name).
# Only key membership is tested (the values are unused): a listed parameter
# that breaks the naming rule is reported as a warning instead of an error.
CHECK_PARAM_POINTER_NAME_EXCEPTIONS = {
    ('vkGetDrmDisplayEXT', 'VkDisplayKHR', 'display') : None,
}

# Exceptions to pNext member requiring an optional attribute.
# For a struct named here, a pNext member lacking optional="true" is
# reported as a warning instead of an error.
CHECK_MEMBER_PNEXT_OPTIONAL_EXCEPTIONS = (
    'VkVideoEncodeInfoKHR',
    'VkVideoEncodeRateControlLayerInfoKHR',
)
59
def get_extension_commands(reg):
    """Return the set of command names required by any extension.

    reg -- registry object whose ``extensions`` attribute is an iterable of
        extension XML elements.

    Only ``<command>`` elements that sit directly inside a ``<require>``
    child and carry a ``name`` attribute are collected.
    """
    found = set()
    for extension in reg.extensions:
        found.update(
            command.get("name")
            for command in extension.findall("./require/command[@name]"))
    return found
66
67
def get_enum_value_names(reg, enum_type):
    """Return the set of value names declared directly under an enum group.

    reg -- registry object; ``reg.groupdict[enum_type].elem`` must be the
        XML element of the enum group.
    enum_type -- name of the enum group to look up, e.g. 'VkResult'.

    Only direct ``<enum>`` children that have a ``name`` attribute count.
    """
    group_elem = reg.groupdict[enum_type].elem
    return {value.get("name") for value in group_elem.findall("./enum[@name]")}
74
75
# Regular expression matching an extension name ending in a (possible) version number.
# When applied with fullmatch(), 'base' is everything up to and including
# the last letter in front of the trailing digits, and 'version' is that
# trailing run of digits.
EXTNAME_RE = re.compile(r'(?P<base>(\w+[A-Za-z]))(?P<version>\d+)')

# Name prefix for destroy commands (not referenced by the code visible here).
DESTROY_PREFIX = "vkDestroy"
# Name of the enum type whose values are used as structure type constants.
TYPEENUM = "VkStructureType"

# Parent of the directory that contains this script.
SPECIFICATION_DIR = Path(__file__).parent.parent
# Matches a revision history line such as ' * Revision 3, ...': optional
# leading spaces, a literal '*', then the revision number (no leading
# zero) captured as 'num', followed by a comma.
REVISION_RE = re.compile(r' *[*] Revision (?P<num>[1-9][0-9]*),.*')
84
85
def get_extension_source(extname):
    """Return, as a string, the path of the appendix text file for *extname*.

    The path is ``appendices/<extname>.txt`` below SPECIFICATION_DIR. The
    file is not checked for existence.
    """
    appendix = SPECIFICATION_DIR.joinpath('appendices', '{}.txt'.format(extname))
    return str(appendix)
89
90
class EntityDatabase(OrigEntityDatabase):
    """Entity database variant used by the XML consistency checks.

    Differs from the base class in two ways: nothing is excluded on the
    basis of "support=" attribute strings, and the registry XML is parsed
    with lxml whenever that package can be imported.
    """

    def getExclusionSet(self):
        """Return a set of "support=" attribute strings that should not be included in the database.

        Overrides the base class method so that 'disabled' extensions are
        not excluded: the returned set is always empty.

        Called only during construction."""
        return set()

    def makeRegistry(self):
        """Create and load the Registry, preferring lxml as the XML parser.

        When lxml cannot be imported, the base class implementation is
        used instead.
        """
        try:
            import lxml.etree as etree
        except ImportError:
            etree = None

        if etree is None:
            return super().makeRegistry()

        xml_path = str(SPECIFICATION_DIR / 'xml/vk.xml')
        reg = Registry()
        reg.filename = xml_path
        reg.loadElementTree(etree.parse(xml_path))
        return reg
115
116
117class Checker(XMLChecker):
118    def __init__(self):
119        manual_types_to_codes = {
120            # These are hard-coded "manual" return codes:
121            # the codes of the value (string, list, or tuple)
122            # are available for a command if-and-only-if
123            # the key type is passed as an input.
124            "VkFormat": "VK_ERROR_FORMAT_NOT_SUPPORTED"
125        }
126        forward_only = {
127            # Like the above, but these are only valid in the
128            # "type implies return code" direction
129        }
130        reverse_only = {
131            # like the above, but these are only valid in the
132            # "return code implies type or its descendant" direction
133            # "XrDuration": "XR_TIMEOUT_EXPIRED"
134        }
135        # Some return codes are related in that only one of a set
136        # may be returned by a command
137        # (eg. XR_ERROR_SESSION_RUNNING and XR_ERROR_SESSION_NOT_RUNNING)
138        self.exclusive_return_code_sets = tuple(
139            # set(("XR_ERROR_SESSION_NOT_RUNNING", "XR_ERROR_SESSION_RUNNING")),
140        )
141        # Map of extension number -> [ list of extension names ]
142        self.extension_number_reservations = {
143        }
144
145        # This is used to report collisions.
146        conventions = APIConventions()
147        db = EntityDatabase()
148
149        self.extension_cmds = get_extension_commands(db.registry)
150        self.return_codes = get_enum_value_names(db.registry, 'VkResult')
151        self.structure_types = get_enum_value_names(db.registry, TYPEENUM)
152
153        # Dict of entity name to a list of messages to suppress. (Exclude any context data and "Warning:"/"Error:")
154        # Keys are entity names, values are tuples or lists of message text to suppress.
155        suppressions = {}
156
157        # Initialize superclass
158        super().__init__(entity_db=db, conventions=conventions,
159                         manual_types_to_codes=manual_types_to_codes,
160                         forward_only_types_to_codes=forward_only,
161                         reverse_only_types_to_codes=reverse_only,
162                         suppressions=suppressions)
163
164    def check_command_return_codes_basic(self, name, info,
165                                         successcodes, errorcodes):
166        """Check a command's return codes for consistency.
167
168        Called on every command."""
169        # Check that all extension commands can return the code associated
170        # with trying to use an extension that was not enabled.
171        # if name in self.extension_cmds and UNSUPPORTED not in errorcodes:
172        #     self.record_error("Missing expected return code",
173        #                       UNSUPPORTED,
174        #                       "implied due to being an extension command")
175
176        codes = successcodes.union(errorcodes)
177
178        # Check that all return codes are recognized.
179        unrecognized = codes - self.return_codes
180        if unrecognized:
181            self.record_error("Unrecognized return code(s):",
182                              unrecognized)
183
184        elem = info.elem
185        params = [(getElemName(elt), elt) for elt in elem.findall('param')]
186
187        def is_count_output(name, elt):
188            # Must end with Count or Size,
189            # not be const,
190            # and be a pointer (detected by naming convention)
191            return (name.endswith('Count') or name.endswith('Size')) \
192                and (elt.tail is None or 'const' not in elt.tail) \
193                and (name.startswith('p'))
194
195        countParams = [elt
196                       for name, elt in params
197                       if is_count_output(name, elt)]
198        if countParams:
199            assert(len(countParams) == 1)
200            if 'VK_INCOMPLETE' not in successcodes:
201                self.record_error(
202                    "Apparent enumeration of an array without VK_INCOMPLETE in successcodes.")
203
204        elif 'VK_INCOMPLETE' in successcodes:
205            self.record_error(
206                "VK_INCOMPLETE in successcodes of command that is apparently not an array enumeration.")
207
208    def check_param(self, param):
209        """Check a member of a struct or a param of a function.
210
211        Called from check_params."""
212        super().check_param(param)
213
214        if not self.is_api_type(param):
215            return
216
217        param_text = "".join(param.itertext())
218        param_name = getElemName(param)
219
220        # Make sure the number of leading "p" matches the pointer count.
221        pointercount = param.find('type').tail
222        if pointercount:
223            pointercount = pointercount.count('*')
224        if pointercount:
225            prefix = 'p' * pointercount
226            if not param_name.startswith(prefix):
227                param_type = param.find('type').text
228                message = "Apparently incorrect pointer-related name prefix for {} - expected it to start with '{}'".format(
229                    param_text, prefix)
230                if (self.entity, param_type, param_name) in CHECK_PARAM_POINTER_NAME_EXCEPTIONS:
231                    self.record_warning('(Allowed exception)', message, elem=param)
232                else:
233                    self.record_error(message, elem=param)
234
235        # Make sure pNext members have optional="true" attributes
236        if param_name == self.conventions.nextpointer_member_name:
237            optional = param.get('optional')
238            if optional is None or optional != 'true':
239                message = '{}.pNext member is missing \'optional="true"\' attribute'.format(self.entity)
240                if self.entity in CHECK_MEMBER_PNEXT_OPTIONAL_EXCEPTIONS:
241                    self.record_warning('(Allowed exception)', message, elem=param)
242                else:
243                    self.record_error(message, elem=param)
244
245    def check_type(self, name, info, category):
246        """Check a type's XML data for consistency.
247
248        Called from check."""
249
250        elem = info.elem
251        type_elts = [elt
252                     for elt in elem.findall("member")
253                     if getElemType(elt) == TYPEENUM]
254        if category == 'struct' and type_elts:
255            if len(type_elts) > 1:
256                self.record_error(
257                    "Have more than one member of type", TYPEENUM)
258            else:
259                type_elt = type_elts[0]
260                val = type_elt.get('values')
261                if val and val not in self.structure_types:
262                    self.record_error("Unknown structure type constant", val)
263
264            # Check the pointer chain member, if present.
265            next_name = self.conventions.nextpointer_member_name
266            next_member = findNamedElem(info.elem.findall('member'), next_name)
267            if next_member is not None:
268                # Ensure that the 'optional' attribute is set to 'true'
269                optional = next_member.get('optional')
270                if optional is None or optional != 'true':
271                    message = '{}.{} member is missing \'optional="true"\' attribute'.format(name, next_name)
272                    if name in CHECK_MEMBER_PNEXT_OPTIONAL_EXCEPTIONS:
273                        self.record_warning('(Allowed exception)', message)
274                    else:
275                        self.record_error(message)
276
277        elif category == "bitmask":
278            if 'Flags' in name:
279                expected_require = name.replace('Flags', 'FlagBits')
280                require = info.elem.get('require')
281                if require is not None and expected_require != require:
282                    self.record_error("Unexpected require attribute value:",
283                                      "got", require,
284                                      "but expected", expected_require)
285        super().check_type(name, info, category)
286
287    def check_extension(self, name, info):
288        """Check an extension's XML data for consistency.
289
290        Called from check."""
291        elem = info.elem
292        enums = elem.findall('./require/enum[@name]')
293
294        # Look for other extensions using that number
295        # Keep track of this extension number reservation
296        ext_number = elem.get('number')
297        if ext_number in self.extension_number_reservations:
298            conflicts = self.extension_number_reservations[ext_number]
299            self.record_error('Extension number {} has more than one reservation: {}, {}'.format(
300                ext_number, name, ', '.join(conflicts)))
301            self.extension_number_reservations[ext_number].append(name)
302        else:
303            self.extension_number_reservations[ext_number] = [ name ]
304
305        # If extension name is not on the exception list and matches the
306        # versioned-extension pattern, map the extension name to the version
307        # name with the version as a separate word. Otherwise just map it to
308        # the upper-case version of the extension name.
309
310        matches = EXTNAME_RE.fullmatch(name)
311        ext_versioned_name = False
312        if name in EXTENSION_ENUM_NAME_SPELLING_CHANGE:
313            ext_enum_name = EXTENSION_ENUM_NAME_SPELLING_CHANGE.get(name)
314        elif matches is None or name in EXTENSION_NAME_VERSION_EXCEPTIONS:
315            # This is the usual case, either a name that does not look
316            # versioned, or one that does but is on the exception list.
317            ext_enum_name = name.upper()
318        else:
319            # This is a versioned extension name.
320            # Treat the version number as a separate word.
321            base = matches.group('base')
322            version = matches.group('version')
323            ext_enum_name = base.upper() + '_' + version
324            # Keep track of this case
325            ext_versioned_name = True
326
327        # Look for the expected SPEC_VERSION token name
328        version_name = "{}_SPEC_VERSION".format(ext_enum_name)
329        version_elem = findNamedElem(enums, version_name)
330
331        if version_elem is None:
332            # Did not find a SPEC_VERSION enum matching the extension name
333            if ext_versioned_name:
334                suffix = '\n\
335    Make sure that trailing version numbers in extension names are treated\n\
336    as separate words in extension enumerant names. If this is an extension\n\
337    whose name ends in a number which is not a version, such as "...h264"\n\
338    or "...int16", add it to EXTENSION_NAME_VERSION_EXCEPTIONS in\n\
339    scripts/xml_consistency.py.'
340            else:
341                suffix = ''
342            self.record_error('Missing version enum {}{}'.format(version_name, suffix))
343        elif self.conventions.xml_api_name in info.elem.get('supported').split(','):
344            # Skip unsupported / disabled extensions for these checks
345
346            fn = get_extension_source(name)
347            revisions = []
348            with open(fn, 'r', encoding='utf-8') as fp:
349                for line in fp:
350                    line = line.rstrip()
351                    match = REVISION_RE.match(line)
352                    if match:
353                        revisions.append(int(match.group('num')))
354            ver_from_xml = version_elem.get('value')
355            if revisions:
356                ver_from_text = str(max(revisions))
357                if ver_from_xml != ver_from_text:
358                    self.record_error("Version enum mismatch: spec text indicates", ver_from_text,
359                                      "but XML says", ver_from_xml)
360            else:
361                if ver_from_xml == '1':
362                    self.record_warning(
363                        "Cannot find version history in spec text - make sure it has lines starting exactly like '* Revision 1, ....'",
364                        filename=fn)
365                else:
366                    self.record_warning("Cannot find version history in spec text, but XML reports a non-1 version number", ver_from_xml,
367                                        " - make sure the spec text has lines starting exactly like '* Revision 1, ....'",
368                                        filename=fn)
369
370        name_define = "{}_EXTENSION_NAME".format(ext_enum_name)
371        name_elem = findNamedElem(enums, name_define)
372        if name_elem is None:
373            self.record_error("Missing name enum", name_define)
374        else:
375            # Note: etree handles the XML entities here and turns &quot; back into "
376            expected_name = '"{}"'.format(name)
377            name_val = name_elem.get('value')
378            if name_val != expected_name:
379                self.record_error("Incorrect name enum: expected", expected_name,
380                                  "got", name_val)
381
382        super().check_extension(name, elem)
383
384    def check_format(self):
385        """Check an extension's XML data for consistency.
386
387        Called from check."""
388
389        astc3d_formats = [
390                'VK_FORMAT_ASTC_3x3x3_UNORM_BLOCK_EXT',
391                'VK_FORMAT_ASTC_3x3x3_SRGB_BLOCK_EXT',
392                'VK_FORMAT_ASTC_3x3x3_SFLOAT_BLOCK_EXT',
393                'VK_FORMAT_ASTC_4x3x3_UNORM_BLOCK_EXT',
394                'VK_FORMAT_ASTC_4x3x3_SRGB_BLOCK_EXT',
395                'VK_FORMAT_ASTC_4x3x3_SFLOAT_BLOCK_EXT',
396                'VK_FORMAT_ASTC_4x4x3_UNORM_BLOCK_EXT',
397                'VK_FORMAT_ASTC_4x4x3_SRGB_BLOCK_EXT',
398                'VK_FORMAT_ASTC_4x4x3_SFLOAT_BLOCK_EXT',
399                'VK_FORMAT_ASTC_4x4x4_UNORM_BLOCK_EXT',
400                'VK_FORMAT_ASTC_4x4x4_SRGB_BLOCK_EXT',
401                'VK_FORMAT_ASTC_4x4x4_SFLOAT_BLOCK_EXT',
402                'VK_FORMAT_ASTC_5x4x4_UNORM_BLOCK_EXT',
403                'VK_FORMAT_ASTC_5x4x4_SRGB_BLOCK_EXT',
404                'VK_FORMAT_ASTC_5x4x4_SFLOAT_BLOCK_EXT',
405                'VK_FORMAT_ASTC_5x5x4_UNORM_BLOCK_EXT',
406                'VK_FORMAT_ASTC_5x5x4_SRGB_BLOCK_EXT',
407                'VK_FORMAT_ASTC_5x5x4_SFLOAT_BLOCK_EXT',
408                'VK_FORMAT_ASTC_5x5x5_UNORM_BLOCK_EXT',
409                'VK_FORMAT_ASTC_5x5x5_SRGB_BLOCK_EXT',
410                'VK_FORMAT_ASTC_5x5x5_SFLOAT_BLOCK_EXT',
411                'VK_FORMAT_ASTC_6x5x5_UNORM_BLOCK_EXT',
412                'VK_FORMAT_ASTC_6x5x5_SRGB_BLOCK_EXT',
413                'VK_FORMAT_ASTC_6x5x5_SFLOAT_BLOCK_EXT',
414                'VK_FORMAT_ASTC_6x6x5_UNORM_BLOCK_EXT',
415                'VK_FORMAT_ASTC_6x6x5_SRGB_BLOCK_EXT',
416                'VK_FORMAT_ASTC_6x6x5_SFLOAT_BLOCK_EXT',
417                'VK_FORMAT_ASTC_6x6x6_UNORM_BLOCK_EXT',
418                'VK_FORMAT_ASTC_6x6x6_SRGB_BLOCK_EXT',
419                'VK_FORMAT_ASTC_6x6x6_SFLOAT_BLOCK_EXT'
420        ]
421
422        # Need to build list of formats from rest of <enums>
423        enum_formats = []
424        for enum in self.reg.groupdict["VkFormat"].elem:
425            if enum.get("alias") is None and enum.get("name") != "VK_FORMAT_UNDEFINED":
426                enum_formats.append(enum.get("name"))
427
428        found_formats = []
429        for name, info in self.reg.formatsdict.items():
430            found_formats.append(name)
431            self.set_error_context(entity=name, elem=info.elem)
432
433            if name not in enum_formats:
434                self.record_error("The <format> has no matching <enum> for", name)
435
436            # Check never just 1 plane
437            plane_elems = info.elem.findall("plane")
438            if len(plane_elems) == 1:
439                self.record_error("The <format> has only 1 <plane> for", name)
440
441            valid_chroma = ["420", "422", "444"]
442            if info.elem.get("chroma") and info.elem.get("chroma") not in valid_chroma:
443                self.record_error("The <format> has chroma is not a valid value for", name)
444
445        # Re-loop to check the other way if the <format> is missing
446        for enum in self.reg.groupdict["VkFormat"].elem:
447            name = enum.get("name")
448            if enum.get("alias") is None and name != "VK_FORMAT_UNDEFINED":
449                if name not in found_formats and name not in astc3d_formats:
450                    self.set_error_context(entity=name, elem=enum)
451                    self.record_error("The <enum> has no matching <format> for ", name)
452
453        super().check_format()
454
455
if __name__ == "__main__":
    # Run all consistency checks; exit with status 1 if any check failed.
    checker = Checker()
    checker.check()

    if checker.fail:
        sys.exit(1)
463