#!/usr/bin/python3
#
# Copyright (c) 2019 Collabora, Ltd.
#
# SPDX-License-Identifier: Apache-2.0
#
# Author(s): Ryan Pavlik <ryan.pavlik@collabora.com>
#
# Purpose: This script checks some "business logic" in the XML registry.

import re
import sys
from pathlib import Path

from check_spec_links import VulkanEntityDatabase as OrigEntityDatabase
from reg import Registry
from spec_tools.consistency_tools import XMLChecker
from spec_tools.util import findNamedElem, getElemName, getElemType
from vkconventions import VulkanConventions as APIConventions

# Extensions whose enumerant names do not simply upper-case the extension
# name. Each one maps to the enum-name prefix it actually uses (for example
# the prefix of its <prefix>_SPEC_VERSION enum).
EXTENSION_ENUM_NAME_SPELLING_CHANGE = dict(
    VK_EXT_swapchain_colorspace='VK_EXT_SWAPCHAIN_COLOR_SPACE',
)

# Extensions whose names *look* like they end in a version number, but
# whose trailing digits are really part of the name. Their enum prefix is
# just the upper-cased name, with no separate version word.
EXTENSION_NAME_VERSION_EXCEPTIONS = (
    # AMD
    'VK_AMD_gpu_shader_int16',
    # EXT
    'VK_EXT_index_type_uint8',
    'VK_EXT_shader_image_atomic_int64',
    'VK_EXT_video_decode_h264',
    'VK_EXT_video_decode_h265',
    'VK_EXT_video_encode_h264',
    'VK_EXT_video_encode_h265',
    # KHR
    'VK_KHR_external_fence_win32',
    'VK_KHR_external_memory_win32',
    'VK_KHR_external_semaphore_win32',
    'VK_KHR_shader_atomic_int64',
    'VK_KHR_shader_float16_int8',
    'VK_KHR_spirv_1_4',
    # NV
    'VK_NV_external_memory_win32',
    # Reserved placeholder names
    'VK_RESERVED_do_not_use_146',
    'VK_RESERVED_do_not_use_94',
)

# Exceptions to pointer parameter naming rules
# Keyed by (entity name, type, name).
CHECK_PARAM_POINTER_NAME_EXCEPTIONS = {
    # Only key membership is tested (in Checker.check_param); values are unused.
    ('vkGetDrmDisplayEXT', 'VkDisplayKHR', 'display'): None,
}

# Exceptions to pNext member requiring an optional attribute
CHECK_MEMBER_PNEXT_OPTIONAL_EXCEPTIONS = (
    'VkVideoEncodeInfoKHR',
)


def get_extension_commands(reg):
    """Return the set of command names required by any extension in reg.

    Only <command name="..."> elements directly under an extension's
    <require> blocks are considered."""
    return {cmd.get("name")
            for ext in reg.extensions
            for cmd in ext.findall("./require/command[@name]")}


def get_enum_value_names(reg, enum_type):
    """Return the set of enumerant names declared in the enum group enum_type.

    enum_type is looked up directly in reg.groupdict; a missing group is not
    handled here."""
    group_elem = reg.groupdict[enum_type].elem
    return {val.get("name") for val in group_elem.findall("./enum[@name]")}


# Regular expression matching an extension name ending in a (possible) version number.
# 'base' must end in a letter, so e.g. 'VK_KHR_foo2' splits into 'VK_KHR_foo' + '2'.
EXTNAME_RE = re.compile(r'(?P<base>(\w+[A-Za-z]))(?P<version>\d+)')

DESTROY_PREFIX = "vkDestroy"
TYPEENUM = "VkStructureType"

# Parent of the directory that contains this script.
SPECIFICATION_DIR = Path(__file__).parent.parent
# Matches a revision-history line such as '* Revision 2, ...' (leading spaces allowed).
REVISION_RE = re.compile(r' *[*] Revision (?P<num>[1-9][0-9]*),.*')


def get_extension_source(extname):
    """Return the path (as a string) of the appendix source file for extname."""
    fn = '{}.txt'.format(extname)
    return str(SPECIFICATION_DIR / 'appendices' / fn)


class EntityDatabase(OrigEntityDatabase):
    """Entity database that keeps every extension and prefers lxml for parsing."""

    # Override base class method to not exclude 'disabled' extensions
    def getExclusionSet(self):
        """Return a set of "support=" attribute strings that should not be included in the database.

        Called only during construction."""

        return set()

    def makeRegistry(self):
        """Build a Registry from xml/vk.xml, parsing it with lxml when available.

        Falls back to the base-class implementation if lxml cannot be imported."""
        try:
            import lxml.etree as etree
        except ImportError:
            return super().makeRegistry()

        registryFile = str(SPECIFICATION_DIR / 'xml/vk.xml')
        registry = Registry()
        registry.filename = registryFile
        registry.loadElementTree(etree.parse(registryFile))
        return registry


class Checker(XMLChecker):
    """Vulkan-specific consistency checks layered on top of XMLChecker."""

    def __init__(self):
        manual_types_to_codes = {
            # These are hard-coded "manual" return codes:
            # the codes of the value (string, list, or tuple)
            # are available for a command if-and-only-if
            # the key type is passed as an input.
            "VkFormat": "VK_ERROR_FORMAT_NOT_SUPPORTED"
        }
        forward_only = {
            # Like the above, but these are only valid in the
            # "type implies return code" direction
        }
        reverse_only = {
            # like the above, but these are only valid in the
            # "return code implies type or its descendant" direction
            # "XrDuration": "XR_TIMEOUT_EXPIRED"
        }
        # Some return codes are related in that only one of a set
        # may be returned by a command
        # (eg. XR_ERROR_SESSION_RUNNING and XR_ERROR_SESSION_NOT_RUNNING).
        # No such sets are defined in this file.
        self.exclusive_return_code_sets = tuple(
            # set(("XR_ERROR_SESSION_NOT_RUNNING", "XR_ERROR_SESSION_RUNNING")),
        )
        # Map of extension number -> [ list of extension names ].
        # Filled in by check_extension and used to report number collisions.
        self.extension_number_reservations = {}

        conventions = APIConventions()
        db = EntityDatabase()

        # Name sets derived from the registry, consulted by the checks below.
        self.extension_cmds = get_extension_commands(db.registry)
        self.return_codes = get_enum_value_names(db.registry, 'VkResult')
        self.structure_types = get_enum_value_names(db.registry, TYPEENUM)

        # Dict of entity name to a list of messages to suppress. (Exclude any context data and "Warning:"/"Error:")
        # Keys are entity names, values are tuples or lists of message text to suppress.
        suppressions = {}

        # Initialize superclass
        super().__init__(entity_db=db, conventions=conventions,
                         manual_types_to_codes=manual_types_to_codes,
                         forward_only_types_to_codes=forward_only,
                         reverse_only_types_to_codes=reverse_only,
                         suppressions=suppressions)

    def check_command_return_codes_basic(self, name, info,
                                         successcodes, errorcodes):
        """Check a command's return codes for consistency.

        Called on every command.

        Reports unrecognized return codes. Also checks that VK_INCOMPLETE is a
        success code exactly when the command looks like an array enumeration,
        i.e. it has a non-const parameter named p...Count or p...Size."""
        # Check that all extension commands can return the code associated
        # with trying to use an extension that wasn't enabled.
        # if name in self.extension_cmds and UNSUPPORTED not in errorcodes:
        #     self.record_error("Missing expected return code",
        #                       UNSUPPORTED,
        #                       "implied due to being an extension command")

        codes = successcodes.union(errorcodes)

        # Check that all return codes are recognized.
        unrecognized = codes - self.return_codes
        if unrecognized:
            self.record_error("Unrecognized return code(s):",
                              unrecognized)

        params = [(getElemName(elt), elt) for elt in info.elem.findall('param')]

        def is_count_output(param_name, param_elem):
            # Must end with Count or Size,
            # not be const,
            # and be a pointer (detected by naming convention).
            # A leading 'const' qualifier is text *before* the <type> child,
            # i.e. param_elem.text. param_elem.tail is the text after
            # </param>, so checking it never detected const.
            leading_text = param_elem.text or ''
            return (param_name.endswith(('Count', 'Size'))
                    and 'const' not in leading_text
                    and param_name.startswith('p'))

        count_param_names = [param_name
                             for param_name, param_elem in params
                             if is_count_output(param_name, param_elem)]
        if count_param_names:
            if len(count_param_names) > 1:
                # Report instead of asserting, so one unusual command neither
                # aborts the whole run nor is silently ignored under -O.
                self.record_error(
                    "Found more than one apparent count/size output parameter:",
                    ", ".join(count_param_names))
            if 'VK_INCOMPLETE' not in successcodes:
                self.record_error(
                    "Apparent enumeration of an array without VK_INCOMPLETE in successcodes.")

        elif 'VK_INCOMPLETE' in successcodes:
            self.record_error(
                "VK_INCOMPLETE in successcodes of command that is apparently not an array enumeration.")

    def check_param(self, param):
        """Check a member of a struct or a param of a function.

        Called from check_params.

        For API-typed params/members, checks two things:
        - the name starts with one 'p' per '*' after the <type> element;
        - a pNext member carries optional="true".
        Allowed exceptions are reported as warnings instead of errors."""
        super().check_param(param)

        if not self.is_api_type(param):
            return

        param_text = "".join(param.itertext())
        param_name = getElemName(param)
        type_elem = param.find('type')

        # Make sure the number of leading "p" matches the pointer count.
        # The '*'s follow </type>, so they are in the <type> element's tail.
        pointercount = (type_elem.tail or '').count('*')
        if pointercount:
            prefix = 'p' * pointercount
            if not param_name.startswith(prefix):
                message = "Apparently incorrect pointer-related name prefix for {} - expected it to start with '{}'".format(
                    param_text, prefix)
                if (self.entity, type_elem.text, param_name) in CHECK_PARAM_POINTER_NAME_EXCEPTIONS:
                    self.record_warning('(Allowed exception)', message, elem=param)
                else:
                    self.record_error(message, elem=param)

        # Make sure pNext members have optional="true" attributes
        if param_name == self.conventions.nextpointer_member_name:
            if param.get('optional') != 'true':
                message = '{}.{} member is missing \'optional="true"\' attribute'.format(
                    self.entity, param_name)
                if self.entity in CHECK_MEMBER_PNEXT_OPTIONAL_EXCEPTIONS:
                    self.record_warning('(Allowed exception)', message, elem=param)
                else:
                    self.record_error(message, elem=param)

    def check_type(self, name, info, category):
        """Check a type's XML data for consistency.

        Called from check.

        Checks three things:
        - structs with a VkStructureType member: the member is unique, its
          'values' constant exists, and the pNext member is optional;
        - bitmasks: the referenced FlagBits type matches the Flags type name;
        - then the base-class checks."""

        elem = info.elem
        type_elts = [elt
                     for elt in elem.findall("member")
                     if getElemType(elt) == TYPEENUM]
        if category == 'struct' and type_elts:
            if len(type_elts) > 1:
                self.record_error(
                    "Have more than one member of type", TYPEENUM)
            else:
                val = type_elts[0].get('values')
                if val and val not in self.structure_types:
                    self.record_error("Unknown structure type constant", val)

            # Check the pointer chain member, if present.
            # NOTE(review): check_param also reports a non-optional pNext
            # member; if it runs on struct members (as its docstring says),
            # this may duplicate that report.
            next_name = self.conventions.nextpointer_member_name
            next_member = findNamedElem(elem.findall('member'), next_name)
            if next_member is not None and next_member.get('optional') != 'true':
                message = '{}.{} member is missing \'optional="true"\' attribute'.format(name, next_name)
                if name in CHECK_MEMBER_PNEXT_OPTIONAL_EXCEPTIONS:
                    self.record_warning('(Allowed exception)', message)
                else:
                    self.record_error(message)

        elif category == "bitmask":
            if 'Flags' in name:
                expected_require = name.replace('Flags', 'FlagBits')
                # Registry bitmask types name their FlagBits type through the
                # 'requires' attribute ('bitvalues' for 64-bit flags). The
                # singular 'require' read previously is not a registry
                # attribute, so this check never fired. It is still accepted
                # last, for compatibility.
                require = next((value
                                for value in (elem.get(attr)
                                              for attr in ('requires', 'bitvalues', 'require'))
                                if value is not None),
                               None)
                if require is not None and expected_require != require:
                    self.record_error("Unexpected require attribute value:",
                                      "got", require,
                                      "but expected", expected_require)
        super().check_type(name, info, category)

    def _check_version_history(self, version_elem, fn):
        """Compare an extension's XML SPEC_VERSION value with its spec text.

        fn is the extension's appendix source file. The highest
        '* Revision N, ...' line in it must equal version_elem's 'value'.
        A missing file is recorded as an error instead of raising."""
        try:
            with open(fn, 'r', encoding='utf-8') as fp:
                revisions = [int(match.group('num'))
                             for match in (REVISION_RE.match(line.rstrip()) for line in fp)
                             if match]
        except FileNotFoundError:
            self.record_error("Cannot find extension appendix to check its version history:", fn)
            return

        ver_from_xml = version_elem.get('value')
        if revisions:
            ver_from_text = str(max(revisions))
            if ver_from_xml != ver_from_text:
                self.record_error("Version enum mismatch: spec text indicates", ver_from_text,
                                  "but XML says", ver_from_xml)
        elif ver_from_xml == '1':
            self.record_warning(
                "Cannot find version history in spec text - make sure it has lines starting exactly like '* Revision 1, ....'",
                filename=fn)
        else:
            self.record_warning("Cannot find version history in spec text, but XML reports a non-1 version number", ver_from_xml,
                                " - make sure the spec text has lines starting exactly like '* Revision 1, ....'",
                                filename=fn)

    def check_extension(self, name, info):
        """Check an extension's XML data for consistency.

        Called from check.

        Checks four things:
        - extension number collisions;
        - the presence of the <EXT>_SPEC_VERSION and <EXT>_EXTENSION_NAME enums;
        - the value of the name enum;
        - for extensions supported by this API, that SPEC_VERSION matches the
          revision history in the extension's appendix."""
        elem = info.elem
        enums = elem.findall('./require/enum[@name]')

        # Keep track of this extension number reservation, reporting any
        # other extensions that already use the same number.
        ext_number = elem.get('number')
        prior_reservations = self.extension_number_reservations.setdefault(ext_number, [])
        if prior_reservations:
            self.record_error('Extension number {} has more than one reservation: {}, {}'.format(
                ext_number, name, ', '.join(prior_reservations)))
        prior_reservations.append(name)

        # If extension name is not on the exception list and matches the
        # versioned-extension pattern, map the extension name to the version
        # name with the version as a separate word. Otherwise just map it to
        # the upper-case version of the extension name.
        matches = EXTNAME_RE.fullmatch(name)
        ext_versioned_name = False
        if name in EXTENSION_ENUM_NAME_SPELLING_CHANGE:
            ext_enum_name = EXTENSION_ENUM_NAME_SPELLING_CHANGE[name]
        elif matches is None or name in EXTENSION_NAME_VERSION_EXCEPTIONS:
            # This is the usual case, either a name that doesn't look
            # versioned, or one that does but is on the exception list.
            ext_enum_name = name.upper()
        else:
            # This is a versioned extension name.
            # Treat the version number as a separate word.
            ext_enum_name = matches.group('base').upper() + '_' + matches.group('version')
            # Keep track of this case
            ext_versioned_name = True

        # Look for the expected SPEC_VERSION token name
        version_name = "{}_SPEC_VERSION".format(ext_enum_name)
        version_elem = findNamedElem(enums, version_name)

        if version_elem is None:
            # Did not find a SPEC_VERSION enum matching the extension name
            if ext_versioned_name:
                suffix = ('\n'
                          '    Make sure that trailing version numbers in extension names are treated\n'
                          '    as separate words in extension enumerant names. If this is an extension\n'
                          '    whose name ends in a number which is not a version, such as "...h264"\n'
                          '    or "...int16", add it to EXTENSION_NAME_VERSION_EXCEPTIONS in\n'
                          '    scripts/xml_consistency.py.')
            else:
                suffix = ''
            self.record_error('Missing version enum {}{}'.format(version_name, suffix))
        elif self.conventions.xml_api_name in (elem.get('supported') or '').split(','):
            # Skip unsupported / disabled extensions for these checks.
            # 'supported' may list several comma-separated API names, so
            # test membership rather than equality.
            self._check_version_history(version_elem, get_extension_source(name))

        name_define = "{}_EXTENSION_NAME".format(ext_enum_name)
        name_elem = findNamedElem(enums, name_define)
        if name_elem is None:
            self.record_error("Missing name enum", name_define)
        else:
            # Note: etree handles the XML entities here and turns &quot; back into "
            expected_name = '"{}"'.format(name)
            name_val = name_elem.get('value')
            if name_val != expected_name:
                self.record_error("Incorrect name enum: expected", expected_name,
                                  "got", name_val)

        # NOTE(review): the element (not info) is passed to the base class,
        # unlike check_type which forwards info; confirm the base's expectation.
        super().check_extension(name, elem)


if __name__ == "__main__":

    ckr = Checker()
    ckr.check()

    if ckr.fail:
        sys.exit(1)