• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/python3
2#
3# Copyright (c) 2019 Collabora, Ltd.
4#
5# SPDX-License-Identifier: Apache-2.0
6#
7# Author(s):    Ryan Pavlik <ryan.pavlik@collabora.com>
8#
9# Purpose:      This script checks some "business logic" in the XML registry.
10
11import re
12import sys
13from pathlib import Path
14
15from check_spec_links import VulkanEntityDatabase as OrigEntityDatabase
16from reg import Registry
17from spec_tools.consistency_tools import XMLChecker
18from spec_tools.util import findNamedElem, getElemName, getElemType
19from vkconventions import VulkanConventions as APIConventions
20
# Most extensions spell their meta-enums (the *_SPEC_VERSION and
# *_EXTENSION_NAME defines) with simply an upper-cased copy of the extension
# name. The extensions listed here are exceptions: their enum prefix inserts
# an underscore that the extension name lacks (e.g. "COLORSPACE" ->
# "COLOR_SPACE", or "...PROPERTIES2" -> "...PROPERTIES_2").
# Maps extension name -> enum-name prefix actually used for it.
EXTENSION_ENUM_NAME_SPELLING_CHANGE = dict((
    ('VK_EXT_swapchain_colorspace', 'VK_EXT_SWAPCHAIN_COLOR_SPACE'),
    ('VK_KHR_get_physical_device_properties2', 'VK_KHR_GET_PHYSICAL_DEVICE_PROPERTIES_2'),
    ('VK_KHR_get_display_properties2', 'VK_KHR_GET_DISPLAY_PROPERTIES_2'),
    ('VK_KHR_get_surface_capabilities2', 'VK_KHR_GET_SURFACE_CAPABILITIES_2'),
    ('VK_KHR_create_renderpass2', 'VK_KHR_CREATE_RENDERPASS_2'),
    ('VK_KHR_bind_memory2', 'VK_KHR_BIND_MEMORY_2'),
    ('VK_KHR_get_memory_requirements2', 'VK_KHR_GET_MEMORY_REQUIREMENTS_2'),
    ('VK_AMD_shader_core_properties2', 'VK_AMD_SHADER_CORE_PROPERTIES_2'),
    ('VK_INTEL_shader_integer_functions2', 'VK_INTEL_SHADER_INTEGER_FUNCTIONS_2'),
))
34
35
def get_extension_commands(reg):
    """Return the set of command names required by any extension in *reg*.

    Every ``./require/command`` child (with a ``name`` attribute) of each
    element in ``reg.extensions`` contributes its name.
    """
    return {
        command.get("name")
        for extension in reg.extensions
        for command in extension.findall("./require/command[@name]")
    }
42
43
def get_enum_value_names(reg, enum_type):
    """Return the set of value names declared in the enum group *enum_type*.

    The group is looked up by indexing ``reg.groupdict`` directly, so an
    unknown group raises whatever that mapping raises (KeyError for a dict).
    Only direct ``<enum>`` children that carry a ``name`` attribute count.
    """
    group_elem = reg.groupdict[enum_type].elem
    matching = group_elem.findall("./enum[@name]")
    return set(value.get("name") for value in matching)
50
51
# Command-name prefix "vkDestroy" (not referenced elsewhere in the code
# visible in this file).
DESTROY_PREFIX = "vkDestroy"
# Enum type of the structure-type member; check_type looks for members of
# this type, and its values form Checker.structure_types.
TYPEENUM = "VkStructureType"


# Root of the specification tree: the directory expected to contain
# "xml/vk.xml" and "appendices/". Resolve first: before Python 3.9 the
# __main__ module's __file__ can be a bare relative name such as
# "xml_consistency.py", and then .parent.parent would be "." (the script's
# own directory) instead of the spec root.
SPECIFICATION_DIR = Path(__file__).resolve().parent.parent
# Matches a version-history bullet in an extension appendix, e.g.
# "  * Revision 3, 2019-01-01 (Author)", capturing the revision number
# (a positive integer without leading zeros) as group "num".
REVISION_RE = re.compile(r' *[*] Revision (?P<num>[1-9][0-9]*),.*')
58
59
def get_extension_source(extname):
    """Return, as a string, the path of the spec appendix for *extname*.

    The appendix is expected at ``<SPECIFICATION_DIR>/appendices/<extname>.txt``.
    """
    appendix_name = '{}.txt'.format(extname)
    appendix_path = SPECIFICATION_DIR.joinpath('appendices', appendix_name)
    return str(appendix_path)
63
64
class EntityDatabase(OrigEntityDatabase):
    """Entity database that parses the spec's vk.xml with lxml when possible.

    When lxml cannot be imported, registry construction is left entirely to
    OrigEntityDatabase.
    """

    def makeRegistry(self):
        """Build and return the Registry for ``<SPECIFICATION_DIR>/xml/vk.xml``.

        Uses ``lxml.etree.parse`` to load the file when lxml is importable;
        otherwise returns ``super().makeRegistry()`` unchanged.
        """
        try:
            import lxml.etree as etree
        except ImportError:
            etree = None

        if etree is None:
            # No lxml: defer to the parent implementation.
            return super().makeRegistry()

        xml_path = str(SPECIFICATION_DIR / 'xml/vk.xml')
        result = Registry()
        result.filename = xml_path
        result.loadElementTree(etree.parse(xml_path))
        return result
81
82
class Checker(XMLChecker):
    """Vulkan-specific "business logic" checks layered on XMLChecker.

    Adds checks of command return codes, pointer-parameter name prefixes,
    sType members, bitmask ``requires`` attributes, and each extension's
    SPEC_VERSION / EXTENSION_NAME enums.
    """

    def __init__(self):
        """Load the registry via EntityDatabase and configure XMLChecker."""
        manual_types_to_codes = {
            # These are hard-coded "manual" return codes:
            # the codes of the value (string, list, or tuple)
            # are available for a command if-and-only-if
            # the key type is passed as an input.
            "VkFormat": "VK_ERROR_FORMAT_NOT_SUPPORTED"
        }
        forward_only = {
            # Like the above, but these are only valid in the
            # "type implies return code" direction
        }
        reverse_only = {
            # Like the above, but these are only valid in the
            # "return code implies type or its descendant" direction
            # (OpenXR example: "XrDuration": "XR_TIMEOUT_EXPIRED")
        }
        # Some return codes are related in that only one of a set
        # may be returned by a command (OpenXR example:
        # XR_ERROR_SESSION_RUNNING and XR_ERROR_SESSION_NOT_RUNNING).
        # None are configured here, so this is an empty tuple.
        self.exclusive_return_code_sets = tuple(
            # set(("XR_ERROR_SESSION_NOT_RUNNING", "XR_ERROR_SESSION_RUNNING")),
        )

        conventions = APIConventions()
        db = EntityDatabase()

        # Name sets gathered once from the registry for use by the checks.
        self.extension_cmds = get_extension_commands(db.registry)
        self.return_codes = get_enum_value_names(db.registry, 'VkResult')
        self.structure_types = get_enum_value_names(db.registry, TYPEENUM)

        # Dict of entity name to a list of messages to suppress. (Exclude any context data and "Warning:"/"Error:")
        # Keys are entity names, values are tuples or lists of message text to suppress.
        suppressions = {}

        # Initialize superclass
        super().__init__(entity_db=db, conventions=conventions,
                         manual_types_to_codes=manual_types_to_codes,
                         forward_only_types_to_codes=forward_only,
                         reverse_only_types_to_codes=reverse_only,
                         suppressions=suppressions)

    def check_command_return_codes_basic(self, name, info,
                                         successcodes, errorcodes):
        """Check a command's return codes for consistency.

        Called on every command.

        Records an error when a listed code is not a VkResult value, when
        the command looks like an array enumeration (a non-const "p..." param
        ending in "Count" or "Size") but lacks VK_INCOMPLETE in its success
        codes, when it has VK_INCOMPLETE without looking like one, or when
        more than one such count parameter is found.
        """
        # Check that all extension commands can return the code associated
        # with trying to use an extension that wasn't enabled.
        # if name in self.extension_cmds and UNSUPPORTED not in errorcodes:
        #     self.record_error("Missing expected return code",
        #                       UNSUPPORTED,
        #                       "implied due to being an extension command")

        codes = successcodes.union(errorcodes)

        # Check that all return codes are recognized.
        unrecognized = codes - self.return_codes
        if unrecognized:
            self.record_error("Unrecognized return code(s):",
                              unrecognized)

        params = [(getElemName(elt), elt)
                  for elt in info.elem.findall('param')]

        def is_count_output(param_name, elt):
            """Guess whether *elt* is the count output of an array enumeration."""
            # Must end with Count or Size, not be const, and be a pointer
            # (detected by the "p" naming convention). A leading "const"
            # appears before <type>, i.e. in elt.text; elt.tail is the text
            # *after* </param> and can never contain it.
            return (param_name.endswith(('Count', 'Size'))
                    and 'const' not in (elt.text or '')
                    and param_name.startswith('p'))

        count_params = [param_name
                        for param_name, elt in params
                        if is_count_output(param_name, elt)]
        if count_params:
            if len(count_params) > 1:
                # Report rather than assert: an assert would abort the whole
                # run on one odd command (or be stripped under "python -O").
                self.record_error(
                    "Found more than one apparent array-count output parameter:",
                    ", ".join(count_params))
            if 'VK_INCOMPLETE' not in successcodes:
                self.record_error(
                    "Apparent enumeration of an array without VK_INCOMPLETE in successcodes.")

        elif 'VK_INCOMPLETE' in successcodes:
            self.record_error(
                "VK_INCOMPLETE in successcodes of command that is apparently not an array enumeration.")

    def check_param(self, param):
        """Check a member of a struct or a param of a function.

        Called from check_params.

        For API-typed params, records an error when the name does not start
        with one "p" per "*" that follows the <type> element.
        """
        super().check_param(param)

        if not self.is_api_type(param):
            return

        type_elem = param.find('type')
        if type_elem is None:
            # Nothing to count pointers against.
            return

        param_text = "".join(param.itertext())
        param_name = getElemName(param)

        # Make sure the number of leading "p" matches the pointer count.
        # Pointer declarators ("*", "* const*", ...) follow </type>.
        pointer_count = (type_elem.tail or '').count('*')
        if pointer_count:
            prefix = 'p' * pointer_count
            if not param_name.startswith(prefix):
                self.record_error("Apparently incorrect pointer-related name prefix for",
                                  param_text, "- expected it to start with", prefix,
                                  elem=param)

    def check_type(self, name, info, category):
        """Check a type's XML data for consistency.

        Called from check.

        Structs: at most one TYPEENUM member, and its ``values`` attribute (if
        any) must be a known structure-type constant. Bitmasks named "...Flags..."
        must, when they declare a required type, require the matching
        "...FlagBits..." type.
        """
        elem = info.elem
        if category == 'struct':
            type_elts = [elt
                         for elt in elem.findall("member")
                         if getElemType(elt) == TYPEENUM]
            if len(type_elts) > 1:
                self.record_error(
                    "Have more than one member of type", TYPEENUM)
            elif type_elts:
                val = type_elts[0].get('values')
                if val and val not in self.structure_types:
                    self.record_error("Unknown structure type constant", val)
        elif category == "bitmask" and 'Flags' in name:
            expected_require = name.replace('Flags', 'FlagBits')
            # The registry schema names this attribute "requires"; reading
            # "require" alone always yielded None and disabled the check.
            # Fall back to "require" in case an older file spells it so.
            require = elem.get('requires', elem.get('require'))
            if require is not None and expected_require != require:
                self.record_error("Unexpected require attribute value:",
                                  "got", require,
                                  "but expected", expected_require)
        super().check_type(name, info, category)

    def _check_version_history(self, name, ver_from_xml):
        """Compare an extension's XML SPEC_VERSION with its appendix history.

        Takes the highest "* Revision N," entry (REVISION_RE) in the appendix
        returned by get_extension_source(name). Records an error on mismatch,
        or a warning when the appendix or its history entries are missing.
        """
        fn = get_extension_source(name)
        revisions = []
        try:
            with open(fn, 'r', encoding='utf-8') as fp:
                for line in fp:
                    match = REVISION_RE.match(line.rstrip())
                    if match:
                        revisions.append(int(match.group('num')))
        except FileNotFoundError:
            # Report instead of crashing the whole run on one extension.
            self.record_warning(
                "Cannot find spec text for extension, so cannot check its version history",
                filename=fn)
            return

        if revisions:
            ver_from_text = str(max(revisions))
            if ver_from_xml != ver_from_text:
                self.record_error("Version enum mismatch: spec text indicates", ver_from_text,
                                  "but XML says", ver_from_xml)
        elif ver_from_xml == '1':
            self.record_warning(
                "Cannot find version history in spec text - make sure it has lines starting exactly like '* Revision 1, ....'",
                filename=fn)
        else:
            self.record_warning("Cannot find version history in spec text, but XML reports a non-1 version number", ver_from_xml,
                                " - make sure the spec text has lines starting exactly like '* Revision 1, ....'",
                                filename=fn)

    def check_extension(self, name, info):
        """Check an extension's XML data for consistency.

        Called from check.

        Requires <PREFIX>_SPEC_VERSION and <PREFIX>_EXTENSION_NAME enums,
        where PREFIX is the upper-cased extension name unless overridden in
        EXTENSION_ENUM_NAME_SPELLING_CHANGE. The version is compared with the
        spec appendix; the name enum must be the quoted extension name.
        """
        elem = info.elem
        enums = elem.findall('./require/enum[@name]')

        # Get the way it's spelled in enum names
        ext_enum_name = EXTENSION_ENUM_NAME_SPELLING_CHANGE.get(
            name, name.upper())
        version_name = "{}_SPEC_VERSION".format(ext_enum_name)
        version_elem = findNamedElem(enums, version_name)
        if version_elem is None:
            self.record_error("Missing version enum", version_name)
        else:
            self._check_version_history(name, version_elem.get('value'))

        name_define = "{}_EXTENSION_NAME".format(ext_enum_name)
        name_elem = findNamedElem(enums, name_define)
        if name_elem is None:
            self.record_error("Missing name enum", name_define)
        else:
            # Note: etree handles the XML entities here and turns &quot; back into "
            expected_name = '"{}"'.format(name)
            name_val = name_elem.get('value')
            if name_val != expected_name:
                self.record_error("Incorrect name enum: expected", expected_name,
                                  "got", name_val)

        # Pass the info object (not its element), matching this override's
        # own signature and how check_type chains to its superclass.
        super().check_extension(name, info)
273
274
if __name__ == "__main__":
    # Run every check; signal failure to the caller through exit status 1.
    checker = Checker()
    checker.check()

    if checker.fail:
        sys.exit(1)
282