#!/usr/bin/python3
#
# Copyright (c) 2019 Collabora, Ltd.
#
# SPDX-License-Identifier: Apache-2.0
#
# Author(s): Ryan Pavlik <ryan.pavlik@collabora.com>
#
# Purpose: This script checks some "business logic" in the XML registry.

import re
import sys
from pathlib import Path

from check_spec_links import VulkanEntityDatabase as OrigEntityDatabase
from reg import Registry
from spec_tools.consistency_tools import XMLChecker
from spec_tools.util import findNamedElem, getElemName, getElemType
from apiconventions import APIConventions

# Extensions whose enumerant prefix is NOT simply the upper-cased extension
# name. Maps the extension name to the prefix its enumerants actually use.
EXTENSION_ENUM_NAME_SPELLING_CHANGE = {
    'VK_EXT_swapchain_colorspace': 'VK_EXT_SWAPCHAIN_COLOR_SPACE',
}

# Extensions whose names end in digits that are part of the name rather than
# a version number, so the digits must not be split off as a separate word.
EXTENSION_NAME_VERSION_EXCEPTIONS = (
    # AMD
    'VK_AMD_gpu_shader_int16',
    # EXT
    'VK_EXT_index_type_uint8',
    'VK_EXT_shader_image_atomic_int64',
    'VK_EXT_video_decode_h264',
    'VK_EXT_video_decode_h265',
    'VK_EXT_video_encode_h264',
    'VK_EXT_video_encode_h265',
    # KHR
    'VK_KHR_external_fence_win32',
    'VK_KHR_external_memory_win32',
    'VK_KHR_external_semaphore_win32',
    'VK_KHR_shader_atomic_int64',
    'VK_KHR_shader_float16_int8',
    'VK_KHR_spirv_1_4',
    # NV
    'VK_NV_external_memory_win32',
    # Reserved placeholders
    'VK_RESERVED_do_not_use_146',
    'VK_RESERVED_do_not_use_94',
)

# Exceptions to pointer parameter naming rules
# Keyed by (entity name, type, name).
# Only membership of the keys is tested; the values are unused.
CHECK_PARAM_POINTER_NAME_EXCEPTIONS = {
    ('vkGetDrmDisplayEXT', 'VkDisplayKHR', 'display'): None,
}

# Exceptions to pNext member requiring an optional attribute
CHECK_MEMBER_PNEXT_OPTIONAL_EXCEPTIONS = (
    'VkVideoEncodeInfoKHR',
    'VkVideoEncodeRateControlLayerInfoKHR',
)


def get_extension_commands(reg):
    """Return the set of command names required by any extension in reg."""
    return {cmd.get("name")
            for ext in reg.extensions
            for cmd in ext.findall("./require/command[@name]")}


def get_enum_value_names(reg, enum_type):
    """Return the set of value names declared in the enum group enum_type."""
    result_elem = reg.groupdict[enum_type].elem
    return {val.get("name") for val in result_elem.findall("./enum[@name]")}


# Regular expression matching an extension name ending in a (possible) version number
EXTNAME_RE = re.compile(r'(?P<base>(\w+[A-Za-z]))(?P<version>\d+)')

# NOTE(review): DESTROY_PREFIX is not referenced elsewhere in this file.
DESTROY_PREFIX = "vkDestroy"
TYPEENUM = "VkStructureType"

# The parent of the directory containing this script.
SPECIFICATION_DIR = Path(__file__).parent.parent
REVISION_RE = re.compile(r' *[*] Revision (?P<num>[1-9][0-9]*),.*')


def get_extension_source(extname):
    """Return the path (as a string) of the appendix source file for extname."""
    fn = '{}.txt'.format(extname)
    return str(SPECIFICATION_DIR / 'appendices' / fn)


def _extension_enum_prefix(name):
    """Return (prefix, versioned) for the enumerants of extension `name`.

    prefix is the stem of the expected <prefix>_SPEC_VERSION and
    <prefix>_EXTENSION_NAME enumerants. versioned is True when a trailing
    version number was split off as a separate word, for example
    VK_KHR_maintenance2 -> VK_KHR_MAINTENANCE_2. Names listed in
    EXTENSION_ENUM_NAME_SPELLING_CHANGE use the mapped spelling. Names that do
    not look versioned, or that are in EXTENSION_NAME_VERSION_EXCEPTIONS, are
    simply upper-cased."""
    if name in EXTENSION_ENUM_NAME_SPELLING_CHANGE:
        return EXTENSION_ENUM_NAME_SPELLING_CHANGE[name], False

    matches = EXTNAME_RE.fullmatch(name)
    if matches is None or name in EXTENSION_NAME_VERSION_EXCEPTIONS:
        # The usual case: either the name does not look versioned, or it
        # does but is on the exception list.
        return name.upper(), False

    # A versioned extension name: treat the version number as a separate word.
    return matches.group('base').upper() + '_' + matches.group('version'), True


class EntityDatabase(OrigEntityDatabase):
    """Entity database that keeps every extension, including disabled ones.

    It loads xml/vk.xml with lxml when lxml can be imported."""

    # Override base class method to not exclude 'disabled' extensions
    def getExclusionSet(self):
        """Return a set of "support=" attribute strings that should not be included in the database.

        Called only during construction."""

        return set()

    def makeRegistry(self):
        """Build the Registry from xml/vk.xml, preferring lxml when installed.

        Falls back to the base-class implementation if lxml is unavailable."""
        try:
            import lxml.etree as etree
        except ImportError:
            etree = None
        if etree is None:
            return super().makeRegistry()

        registryFile = str(SPECIFICATION_DIR / 'xml' / 'vk.xml')
        registry = Registry()
        registry.filename = registryFile
        registry.loadElementTree(etree.parse(registryFile))
        return registry


class Checker(XMLChecker):
    """Vulkan-specific registry consistency checks layered on XMLChecker."""

    def __init__(self):
        manual_types_to_codes = {
            # These are hard-coded "manual" return codes:
            # the codes of the value (string, list, or tuple)
            # are available for a command if-and-only-if
            # the key type is passed as an input.
            "VkFormat": "VK_ERROR_FORMAT_NOT_SUPPORTED"
        }
        forward_only = {
            # Like the above, but these are only valid in the
            # "type implies return code" direction
        }
        reverse_only = {
            # like the above, but these are only valid in the
            # "return code implies type or its descendant" direction
            # "XrDuration": "XR_TIMEOUT_EXPIRED"
        }
        # Some return codes are related in that only one of a set
        # may be returned by a command
        # (eg. XR_ERROR_SESSION_RUNNING and XR_ERROR_SESSION_NOT_RUNNING)
        self.exclusive_return_code_sets = tuple(
            # set(("XR_ERROR_SESSION_NOT_RUNNING", "XR_ERROR_SESSION_RUNNING")),
        )
        # Map of extension number -> [ list of extension names ].
        # This is used to report collisions (see check_extension).
        self.extension_number_reservations = {}

        conventions = APIConventions()
        db = EntityDatabase()

        self.extension_cmds = get_extension_commands(db.registry)
        self.return_codes = get_enum_value_names(db.registry, 'VkResult')
        self.structure_types = get_enum_value_names(db.registry, TYPEENUM)

        # Dict of entity name to a list of messages to suppress.
        # (Exclude any context data and "Warning:"/"Error:")
        # Keys are entity names, values are tuples or lists of message text to suppress.
        suppressions = {}

        # Initialize superclass
        super().__init__(entity_db=db, conventions=conventions,
                         manual_types_to_codes=manual_types_to_codes,
                         forward_only_types_to_codes=forward_only,
                         reverse_only_types_to_codes=reverse_only,
                         suppressions=suppressions)

    def check_command_return_codes_basic(self, name, info,
                                         successcodes, errorcodes):
        """Check a command's return codes for consistency.

        Called on every command.

        Reports any code that is not a VkResult value. Also requires
        VK_INCOMPLETE in successcodes exactly when the command looks like an
        array enumeration, i.e. it has a non-const pointer parameter whose
        name starts with 'p' and ends in Count or Size."""
        # Disabled check, kept for reference: all extension commands should
        # be able to return the code for using an extension that was not
        # enabled.
        # if name in self.extension_cmds and UNSUPPORTED not in errorcodes:
        #     self.record_error("Missing expected return code",
        #                       UNSUPPORTED,
        #                       "implied due to being an extension command")

        codes = successcodes.union(errorcodes)

        # Check that all return codes are recognized.
        unrecognized = codes - self.return_codes
        if unrecognized:
            self.record_error("Unrecognized return code(s):",
                              unrecognized)

        def is_count_output(param_name, elt):
            # Must end with Count or Size, be a pointer (detected by the
            # 'p' naming convention), and not be const.
            if not (param_name.endswith('Count') or param_name.endswith('Size')):
                return False
            if not param_name.startswith('p'):
                return False
            # Qualifiers such as "const" sit inside <param>: before its first
            # child, or in the tails of its children. elt.tail is the text
            # *after* </param>, so it never contains them.
            decl_text = (elt.text or '') + ''.join(child.tail or '' for child in elt)
            return 'const' not in decl_text

        params = [(getElemName(elt), elt) for elt in info.elem.findall('param')]
        count_names = [param_name
                       for param_name, elt in params
                       if is_count_output(param_name, elt)]
        if count_names:
            if len(count_names) > 1:
                # Report instead of asserting, so one odd command does not
                # abort the whole run (asserts are also stripped under -O).
                self.record_error("More than one apparent count output parameter:",
                                  count_names)
            if 'VK_INCOMPLETE' not in successcodes:
                self.record_error(
                    "Apparent enumeration of an array without VK_INCOMPLETE in successcodes.")

        elif 'VK_INCOMPLETE' in successcodes:
            self.record_error(
                "VK_INCOMPLETE in successcodes of command that is apparently not an array enumeration.")

    def check_param(self, param):
        """Check a member of a struct or a param of a function.

        Called from check_params.

        Requires the name to start with one 'p' per '*' that follows <type>.
        (pNext's optional="true" attribute is checked once per struct in
        check_type, so it is not re-checked here.)"""
        super().check_param(param)

        if not self.is_api_type(param):
            return

        param_name = getElemName(param)
        type_elem = param.find('type')

        # Make sure the number of leading "p" matches the pointer count.
        pointer_count = (type_elem.tail or '').count('*')
        if not pointer_count:
            return

        prefix = 'p' * pointer_count
        if param_name.startswith(prefix):
            return

        param_text = "".join(param.itertext())
        message = "Apparently incorrect pointer-related name prefix for {} - expected it to start with '{}'".format(
            param_text, prefix)
        if (self.entity, type_elem.text, param_name) in CHECK_PARAM_POINTER_NAME_EXCEPTIONS:
            self.record_warning('(Allowed exception)', message, elem=param)
        else:
            self.record_error(message, elem=param)

    def check_type(self, name, info, category):
        """Check a type's XML data for consistency.

        Called from check.

        For a struct with a VkStructureType member: that member must be
        unique, and its values= must be a known VkStructureType value. Its
        pNext member, if any, must be optional="true". For a bitmask whose
        name contains 'Flags': a require= attribute, if present, must name
        the matching FlagBits type."""
        elem = info.elem
        members = elem.findall("member")
        type_elts = [elt for elt in members if getElemType(elt) == TYPEENUM]

        if category == 'struct' and type_elts:
            if len(type_elts) > 1:
                self.record_error(
                    "Have more than one member of type", TYPEENUM)
            else:
                val = type_elts[0].get('values')
                if val and val not in self.structure_types:
                    self.record_error("Unknown structure type constant", val)

            # Check the pointer chain member, if present, and ensure that its
            # 'optional' attribute is set to 'true'.
            next_name = self.conventions.nextpointer_member_name
            next_member = findNamedElem(members, next_name)
            if next_member is not None and next_member.get('optional') != 'true':
                message = '{}.{} member is missing \'optional="true"\' attribute'.format(name, next_name)
                if name in CHECK_MEMBER_PNEXT_OPTIONAL_EXCEPTIONS:
                    self.record_warning('(Allowed exception)', message)
                else:
                    self.record_error(message)

        elif category == "bitmask" and 'Flags' in name:
            expected_require = name.replace('Flags', 'FlagBits')
            require = elem.get('require')
            if require is not None and expected_require != require:
                self.record_error("Unexpected require attribute value:",
                                  "got", require,
                                  "but expected", expected_require)
        super().check_type(name, info, category)

    def check_extension(self, name, info):
        """Check an extension's XML data for consistency.

        Called from check.

        The extension number must not be claimed by another extension. The
        <prefix>_SPEC_VERSION and <prefix>_EXTENSION_NAME enumerants must
        exist, and the name enumerant's value must be the quoted extension
        name. For extensions whose supported= list includes this API, the
        SPEC_VERSION value is compared with the spec text's revision
        history."""
        elem = info.elem
        enums = elem.findall('./require/enum[@name]')

        # Record this extension's number reservation, reporting any other
        # extensions that already claimed the same number.
        ext_number = elem.get('number')
        reservations = self.extension_number_reservations.setdefault(ext_number, [])
        if reservations:
            self.record_error('Extension number {} has more than one reservation: {}, {}'.format(
                ext_number, name, ', '.join(reservations)))
        reservations.append(name)

        ext_enum_name, ext_versioned_name = _extension_enum_prefix(name)

        # Look for the expected SPEC_VERSION token name
        version_name = "{}_SPEC_VERSION".format(ext_enum_name)
        version_elem = findNamedElem(enums, version_name)

        if version_elem is None:
            # Did not find a SPEC_VERSION enum matching the extension name
            if ext_versioned_name:
                suffix = ('\n'
                          'Make sure that trailing version numbers in extension names are treated\n'
                          'as separate words in extension enumerant names. If this is an extension\n'
                          'whose name ends in a number which is not a version, such as "...h264"\n'
                          'or "...int16", add it to EXTENSION_NAME_VERSION_EXCEPTIONS in\n'
                          'scripts/xml_consistency.py.')
            else:
                suffix = ''
            self.record_error('Missing version enum {}{}'.format(version_name, suffix))
        elif self.conventions.xml_api_name in (elem.get('supported') or '').split(','):
            # Only extensions supported for this API are compared with the
            # spec text; unsupported / disabled ones (or ones with no
            # supported= attribute) skip this check.
            self._check_extension_revision(version_elem, get_extension_source(name))

        name_define = "{}_EXTENSION_NAME".format(ext_enum_name)
        name_elem = findNamedElem(enums, name_define)
        if name_elem is None:
            self.record_error("Missing name enum", name_define)
        else:
            # Note: etree handles the XML entities here and turns &quot; back into "
            expected_name = '"{}"'.format(name)
            name_val = name_elem.get('value')
            if name_val != expected_name:
                self.record_error("Incorrect name enum: expected", expected_name,
                                  "got", name_val)

        # NOTE(review): this passes the element, not `info`; the base-class
        # signature is not visible here, so it is left as-is.
        super().check_extension(name, elem)

    def _check_extension_revision(self, version_elem, fn):
        """Compare the XML SPEC_VERSION value with the revision history in fn.

        fn is the extension's appendix source file. An error is recorded if
        the highest '* Revision N, ...' line disagrees with the XML value. A
        warning is recorded if no such line is found."""
        revisions = []
        with open(fn, 'r', encoding='utf-8') as fp:
            for line in fp:
                match = REVISION_RE.match(line.rstrip())
                if match:
                    revisions.append(int(match.group('num')))

        ver_from_xml = version_elem.get('value')
        if revisions:
            ver_from_text = str(max(revisions))
            if ver_from_xml != ver_from_text:
                self.record_error("Version enum mismatch: spec text indicates", ver_from_text,
                                  "but XML says", ver_from_xml)
        elif ver_from_xml == '1':
            self.record_warning(
                "Cannot find version history in spec text - make sure it has lines starting exactly like '* Revision 1, ....'",
                filename=fn)
        else:
            self.record_warning("Cannot find version history in spec text, but XML reports a non-1 version number", ver_from_xml,
                                " - make sure the spec text has lines starting exactly like '* Revision 1, ....'",
                                filename=fn)

    def check_format(self):
        """Check the <format> elements against the VkFormat enumerants.

        Called from check.

        Every <format> must match a non-alias VkFormat <enum>. It must not
        have exactly one <plane>, and its chroma attribute, if present, must
        be 420, 422 or 444. Conversely, every non-alias VkFormat <enum> other
        than VK_FORMAT_UNDEFINED must have a <format>, except the 3D ASTC
        formats listed below."""

        astc3d_formats = frozenset((
            'VK_FORMAT_ASTC_3x3x3_UNORM_BLOCK_EXT',
            'VK_FORMAT_ASTC_3x3x3_SRGB_BLOCK_EXT',
            'VK_FORMAT_ASTC_3x3x3_SFLOAT_BLOCK_EXT',
            'VK_FORMAT_ASTC_4x3x3_UNORM_BLOCK_EXT',
            'VK_FORMAT_ASTC_4x3x3_SRGB_BLOCK_EXT',
            'VK_FORMAT_ASTC_4x3x3_SFLOAT_BLOCK_EXT',
            'VK_FORMAT_ASTC_4x4x3_UNORM_BLOCK_EXT',
            'VK_FORMAT_ASTC_4x4x3_SRGB_BLOCK_EXT',
            'VK_FORMAT_ASTC_4x4x3_SFLOAT_BLOCK_EXT',
            'VK_FORMAT_ASTC_4x4x4_UNORM_BLOCK_EXT',
            'VK_FORMAT_ASTC_4x4x4_SRGB_BLOCK_EXT',
            'VK_FORMAT_ASTC_4x4x4_SFLOAT_BLOCK_EXT',
            'VK_FORMAT_ASTC_5x4x4_UNORM_BLOCK_EXT',
            'VK_FORMAT_ASTC_5x4x4_SRGB_BLOCK_EXT',
            'VK_FORMAT_ASTC_5x4x4_SFLOAT_BLOCK_EXT',
            'VK_FORMAT_ASTC_5x5x4_UNORM_BLOCK_EXT',
            'VK_FORMAT_ASTC_5x5x4_SRGB_BLOCK_EXT',
            'VK_FORMAT_ASTC_5x5x4_SFLOAT_BLOCK_EXT',
            'VK_FORMAT_ASTC_5x5x5_UNORM_BLOCK_EXT',
            'VK_FORMAT_ASTC_5x5x5_SRGB_BLOCK_EXT',
            'VK_FORMAT_ASTC_5x5x5_SFLOAT_BLOCK_EXT',
            'VK_FORMAT_ASTC_6x5x5_UNORM_BLOCK_EXT',
            'VK_FORMAT_ASTC_6x5x5_SRGB_BLOCK_EXT',
            'VK_FORMAT_ASTC_6x5x5_SFLOAT_BLOCK_EXT',
            'VK_FORMAT_ASTC_6x6x5_UNORM_BLOCK_EXT',
            'VK_FORMAT_ASTC_6x6x5_SRGB_BLOCK_EXT',
            'VK_FORMAT_ASTC_6x6x5_SFLOAT_BLOCK_EXT',
            'VK_FORMAT_ASTC_6x6x6_UNORM_BLOCK_EXT',
            'VK_FORMAT_ASTC_6x6x6_SRGB_BLOCK_EXT',
            'VK_FORMAT_ASTC_6x6x6_SFLOAT_BLOCK_EXT',
        ))
        valid_chroma = ("420", "422", "444")

        # Build the list of formats from the VkFormat <enums>. Only named
        # <enum> children are considered, so other child nodes (such as XML
        # comment nodes, which lxml's default parser keeps) are never
        # mistaken for formats.
        format_enums = [enum
                        for enum in self.reg.groupdict["VkFormat"].elem.findall("enum[@name]")
                        if enum.get("alias") is None
                        and enum.get("name") != "VK_FORMAT_UNDEFINED"]
        enum_formats = {enum.get("name") for enum in format_enums}

        found_formats = set()
        for name, info in self.reg.formatsdict.items():
            found_formats.add(name)
            self.set_error_context(entity=name, elem=info.elem)

            if name not in enum_formats:
                self.record_error("The <format> has no matching <enum> for", name)

            # Check never just 1 plane
            if len(info.elem.findall("plane")) == 1:
                self.record_error("The <format> has only 1 <plane> for", name)

            chroma = info.elem.get("chroma")
            if chroma and chroma not in valid_chroma:
                self.record_error("The <format> has an invalid chroma value for", name)

        # Check the other way: report enums whose <format> is missing
        for enum in format_enums:
            name = enum.get("name")
            if name not in found_formats and name not in astc3d_formats:
                self.set_error_context(entity=name, elem=enum)
                self.record_error("The <enum> has no matching <format> for ", name)

        super().check_format()


if __name__ == "__main__":

    ckr = Checker()
    ckr.check()

    if ckr.fail:
        sys.exit(1)