• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3 -i
2#
3# Copyright 2023-2025 The Khronos Group Inc.
4#
5# SPDX-License-Identifier: Apache-2.0
6
7import pickle
8import os
9import tempfile
10from vulkan_object import (VulkanObject,
11    Extension, Version, Handle, Param, Queues, CommandScope, Command,
12    EnumField, Enum, Flag, Bitmask, Member, Struct,
13    FormatComponent, FormatPlane, Format,
14    SyncSupport, SyncEquivalent, SyncStage, SyncAccess, SyncPipelineStage, SyncPipeline,
15    SpirvEnables, Spirv)
16
17# These live in the Vulkan-Docs repo, but are pulled in via the
18# Vulkan-Headers/registry folder
19from generator import OutputGenerator, GeneratorOptions, write
20from vkconventions import VulkanConventions
21
# Shared API style/naming convention object; passed as `conventions` to every
# BaseGeneratorOptions instance
vulkanConventions = VulkanConventions()
24
25# Helpers to keep things cleaner
def splitIfGet(elem, name):
    """Return attribute `name` of `elem` split on commas, or None.

    None is returned when the attribute is missing or is the empty string.
    `elem` only needs a dict-like `.get()` (an XML Element, its `.attrib`, a dict).
    """
    value = elem.get(name)  # look the attribute up once
    if value is None or value == '':
        return None
    return value.split(',')
28
def textIfFind(elem, name):
    """Return the text of the first child of `elem` matching `name`, or None if there is no such child."""
    child = elem.find(name)  # search once instead of twice
    return child.text if child is not None else None
31
def intIfGet(elem, name):
    """Parse attribute `name` of `elem` as an int, or return None when it is absent.

    Base 0 is used, so prefixed literals such as "0x10" are accepted.
    """
    raw = elem.get(name)
    if raw is None:
        return None
    return int(raw, 0)
34
def boolGet(elem, name) -> bool:
    """Return True only when attribute `name` of `elem` is exactly the string "true"."""
    # A missing attribute yields None, which never equals "true", so no separate None check is needed
    return elem.get(name) == "true"
37
def getQueues(elem) -> Queues:
    """Translate the comma-separated `queues` attribute of `elem` into a Queues bitmask.

    Returns 0 when the attribute is missing or empty; unrecognized tokens are ignored.
    """
    # XML token -> Queues flag
    tokenToQueue = (
        ('transfer', Queues.TRANSFER),
        ('graphics', Queues.GRAPHICS),
        ('compute', Queues.COMPUTE),
        ('protected', Queues.PROTECTED),
        ('sparse_binding', Queues.SPARSE_BINDING),
        ('opticalflow', Queues.OPTICAL_FLOW),
        ('decode', Queues.DECODE),
        ('encode', Queues.ENCODE),
    )
    queues = 0
    tokens = splitIfGet(elem, 'queues')
    if tokens is None:
        return queues
    for token, queueBit in tokenToQueue:
        if token in tokens:
            queues |= queueBit
    return queues
51
# Shared objects used by Sync elements that do not provide their own support/equivalent
# data. endFile() fills in their queues, stages and accesses once the
# VkPipelineStageFlagBits2 / VkAccessFlagBits2 bitmasks have been parsed.
maxSyncSupport = SyncSupport(None, None, True)
maxSyncEquivalent = SyncEquivalent(None, None, True)
55
56# Helpers to set GeneratorOptions options globally
def SetOutputFileName(fileName: str) -> None:
    """Set the module-level default output file name.

    BaseGeneratorOptions uses it whenever no customFileName is given.
    NOTE(review): no module-level default for globalFileName is visible here, so this
    presumably must be called before such a BaseGeneratorOptions is constructed.
    """
    global globalFileName
    globalFileName = fileName
60
def SetOutputDirectory(directory: str) -> None:
    """Set the module-level default output directory.

    BaseGeneratorOptions uses it whenever no customDirectory is given.
    NOTE(review): no module-level default for globalDirectory is visible here.
    """
    global globalDirectory
    globalDirectory = directory
64
def SetTargetApiName(apiname: str) -> None:
    """Set the module-level target API name (e.g. 'vulkan' or 'vulkansc').

    Used as `apiname`/`defaultExtensions` by BaseGeneratorOptions when no customApiName
    is given, and copied to BaseGenerator.targetApiName on construction.
    """
    global globalApiName
    globalApiName = apiname
68
def SetMergedApiNames(names: str) -> None:
    """Set the module-level value passed as `mergeApiNames` to every BaseGeneratorOptions.

    NOTE(review): no module-level default for mergedApiNames is visible here.
    """
    global mergedApiNames
    mergedApiNames = names
72
# When True, BaseGenerator.endFile() pickles the parsed VulkanObject into a per-process
# file (vkobject_<pid>) in the system temp directory
cachingEnabled = False
def EnableCaching() -> None:
    """Turn on pickling of the parsed VulkanObject at the end of BaseGenerator.endFile()."""
    global cachingEnabled
    cachingEnabled = True
77
78# This class is a container for any source code, data, or other behavior that is necessary to
79# customize the generator script for a specific target API variant (e.g. Vulkan SC). As such,
80# all of these API-specific interfaces and their use in the generator script are part of the
81# contract between this repository and its downstream users. Changing or removing any of these
82# interfaces or their use in the generator script will have downstream effects and thus
83# should be avoided unless absolutely necessary.
class APISpecific:
    # Version object factory method
    @staticmethod
    def createApiVersion(targetApiName: str, name: str) -> Version:
        """Build the Version object for core feature `name` in the given target API.

        The quoted name is `name` wrapped in double quotes; the API name swaps the
        'VK_' prefix for 'VK_API_' (and, for Vulkan SC, 'VKSC_' for 'VKSC_API_').
        Any target API other than 'vulkan' or 'vulkansc' yields None.
        """
        if targetApiName != 'vulkan' and targetApiName != 'vulkansc':
            return None

        apiMacroName = name.replace('VK_', 'VK_API_')
        if targetApiName == 'vulkansc':
            # Vulkan SC versions additionally use the VKSC_ prefix
            apiMacroName = apiMacroName.replace('VKSC_', 'VKSC_API_')
        quotedName = f'"{name}"'
        return Version(name, quotedName, apiMacroName)
102
103
104# This Generator Option is used across all generators.
# After years of use, it has become clear that most of the options are the same for every generator (file),
# as it is easier to modify the few things that differ on a per-file basis
class BaseGeneratorOptions(GeneratorOptions):
    """GeneratorOptions shared by every generator.

    Each custom* argument, when truthy, overrides the module-level default set by
    SetOutputFileName / SetOutputDirectory / SetTargetApiName. The API name is also
    used as the default-extensions pattern.
    """
    def __init__(self,
                 customFileName = None,
                 customDirectory = None,
                 customApiName = None):
        fileName = customFileName or globalFileName
        directory = customDirectory or globalDirectory
        apiName = customApiName or globalApiName
        GeneratorOptions.__init__(self,
                conventions = vulkanConventions,
                filename = fileName,
                directory = directory,
                apiname = apiName,
                mergeApiNames = mergedApiNames,
                defaultExtensions = apiName,
                emitExtensions = '.*',
                emitSpirv = '.*',
                emitFormats = '.*')
        # Read by the generator.py script when it builds C declarations
        self.apicall         = 'VKAPI_ATTR '
        self.apientry        = 'VKAPI_CALL '
        self.apientryp       = 'VKAPI_PTR *'
        self.alignFuncParam  = 48
127
128#
129# This object handles all the parsing from reg.py generator scripts in the Vulkan-Headers
130# It will grab all the data and form it into a single object the rest of the generators will use
131class BaseGenerator(OutputGenerator):
132    def __init__(self):
133        OutputGenerator.__init__(self, None, None, None)
134        self.vk = VulkanObject()
135        self.targetApiName = globalApiName
136
137        # reg.py has a `self.featureName` but this is nicer because
138        # it will be either the Version or Extension object
139        self.currentExtension = None
140        self.currentVersion = None
141
142        # Will map alias to promoted name
143        #   ex. ['VK_FILTER_CUBIC_IMG' : 'VK_FILTER_CUBIC_EXT']
144        # When generating any code, there is no reason so use the old name
145        self.enumAliasMap = dict()
146        self.enumFieldAliasMap = dict()
147        self.bitmaskAliasMap = dict()
148        self.flagAliasMap = dict()
149        self.structAliasMap = dict()
150
151    def write(self, data):
152        # Prevents having to check before writing
153        if data is not None and data != "":
154            write(data, file=self.outFile)
155
156
157    def beginFile(self, genOpts):
158        OutputGenerator.beginFile(self, genOpts)
159        self.filename = genOpts.filename
160
161        # No gen*() command to get these, so do it manually
162        for platform in self.registry.tree.findall('platforms/platform'):
163            self.vk.platforms[platform.get('name')] = platform.get('protect')
164
165        for tags in self.registry.tree.findall('tags'):
166            for tag in tags.findall('tag'):
167                self.vk.vendorTags.append(tag.get('name'))
168
169        # No way known to get this from the XML
170        self.vk.queueBits[Queues.TRANSFER]       = 'VK_QUEUE_TRANSFER_BIT'
171        self.vk.queueBits[Queues.GRAPHICS]       = 'VK_QUEUE_GRAPHICS_BIT'
172        self.vk.queueBits[Queues.COMPUTE]        = 'VK_QUEUE_COMPUTE_BIT'
173        self.vk.queueBits[Queues.PROTECTED]      = 'VK_QUEUE_PROTECTED_BIT'
174        self.vk.queueBits[Queues.SPARSE_BINDING] = 'VK_QUEUE_SPARSE_BINDING_BIT'
175        self.vk.queueBits[Queues.OPTICAL_FLOW]   = 'VK_QUEUE_OPTICAL_FLOW_BIT_NV'
176        self.vk.queueBits[Queues.DECODE]         = 'VK_QUEUE_VIDEO_DECODE_BIT_KHR'
177        self.vk.queueBits[Queues.ENCODE]         = 'VK_QUEUE_VIDEO_ENCODE_BIT_KHR'
178
179    # This function should be overloaded
180    def generate(self):
181        print("WARNING: This should not be called from the child class")
182        return
183
184    # This function is dense, it does all the magic to set the right extensions dependencies!
185    #
    # The issue is that if 2 extensions expose a command, genCmd() will only
    # show one of the extensions; at endFile() we can finally go through
    # and update which things depend on which extensions
189    #
190    # self.featureDictionary is built for use in the reg.py framework
191    # Details found in Vulkan-Docs/scripts/scriptgenerator.py
192    def applyExtensionDependency(self):
193        for extension in self.vk.extensions.values():
194            # dict.key() can be None, so need to double loop
195            dict = self.featureDictionary[extension.name]['command']
196
197            # "required" == None
198            #         or
199            #  an additional feature dependency, which is a boolean expression of
200            #  one or more extension and/or core version names
201            for required in dict:
202                for commandName in dict[required]:
203                    # Skip commands removed in the target API
204                    # This check is needed because parts of the base generator code bypass the
205                    # dependency resolution logic in the registry tooling and thus the generator
206                    # may attempt to generate code for commands which are not supported in the
207                    # target API variant, thus this check needs to happen even if any specific
208                    # target API variant may not specifically need it
209                    if not commandName in self.vk.commands:
210                        continue
211
212                    command = self.vk.commands[commandName]
213                    # Make sure list is unique
214                    command.extensions.extend([extension] if extension not in command.extensions else [])
215                    extension.commands.extend([command] if command not in extension.commands else [])
216
217            # While genGroup() will call twice with aliased value, it does not provide all the information we need
218            dict = self.featureDictionary[extension.name]['enumconstant']
219            for required in dict:
220                # group can be a Enum or Bitmask
221                for group in dict[required]:
222                    if group in self.vk.enums:
223                        if group not in extension.enumFields:
224                            extension.enumFields[group] = [] # Dict needs init
225                        enum = self.vk.enums[group]
226                        # Need to convert all alias so they match what is in EnumField
227                        enumList = list(map(lambda x: x if x not in self.enumFieldAliasMap else self.enumFieldAliasMap[x], dict[required][group]))
228
229                        for enumField in [x for x in enum.fields if x.name in enumList]:
230                            # Make sure list is unique
231                            enum.fieldExtensions.extend([extension] if extension not in enum.fieldExtensions else [])
232                            enumField.extensions.extend([extension] if extension not in enumField.extensions else [])
233                            extension.enumFields[group].extend([enumField] if enumField not in extension.enumFields[group] else [])
234                    if group in self.vk.bitmasks:
235                        if group not in extension.flags:
236                            extension.flags[group] = [] # Dict needs init
237                        bitmask = self.vk.bitmasks[group]
238                        # Need to convert all alias so they match what is in Flags
239                        flagList = list(map(lambda x: x if x not in self.flagAliasMap else self.flagAliasMap[x], dict[required][group]))
240
241                        for flags in [x for x in bitmask.flags if x.name in flagList]:
242                            # Make sure list is unique
243                            bitmask.flagExtensions.extend([extension] if extension not in bitmask.flagExtensions else [])
244                            flags.extensions.extend([extension] if extension not in flags.extensions else [])
245                            extension.flags[group].extend([flags] if flags not in extension.flags[group] else [])
246
247        # Need to do 'enum'/'bitmask' after 'enumconstant' has applied everything so we can add implicit extensions
248        #
249        # Sometimes two extensions enable an Enum, but the newer extension version has extra flags allowed
250        # This information seems to be implicit, so need to update it here
251        # Go through each Flag and append the Enum extension to it
252        #
253        # ex. VkAccelerationStructureTypeKHR where GENERIC_KHR is not allowed with just VK_NV_ray_tracing
254        # This only works because the values are aliased as well, making the KHR a superset enum
255        for extension in self.vk.extensions.values():
256            dict = self.featureDictionary[extension.name]['enum']
257            for required in dict:
258                for group in dict[required]:
259                    for enumName in dict[required][group]:
260                        isAlias = enumName in self.enumAliasMap
261                        enumName = self.enumAliasMap[enumName] if isAlias else enumName
262                        if enumName in self.vk.enums:
263                            enum = self.vk.enums[enumName]
264                            enum.extensions.extend([extension] if extension not in enum.extensions else [])
265                            extension.enums.extend([enum] if enum not in extension.enums else [])
266                            # Update fields with implicit base extension
267                            if isAlias:
268                                continue
269                            enum.fieldExtensions.extend([extension] if extension not in enum.fieldExtensions else [])
270                            for enumField in [x for x in enum.fields if (not x.extensions or (x.extensions and all(e in enum.extensions for e in x.extensions)))]:
271                                enumField.extensions.extend([extension] if extension not in enumField.extensions else [])
272                                if enumName not in extension.enumFields:
273                                    extension.enumFields[enumName] = [] # Dict needs init
274                                extension.enumFields[enumName].extend([enumField] if enumField not in extension.enumFields[enumName] else [])
275
276            dict = self.featureDictionary[extension.name]['bitmask']
277            for required in dict:
278                for group in dict[required]:
279                    for bitmaskName in dict[required][group]:
280                        bitmaskName = bitmaskName.replace('Flags', 'FlagBits') # Works since Flags is not repeated in name
281                        isAlias = bitmaskName in self.bitmaskAliasMap
282                        bitmaskName = self.bitmaskAliasMap[bitmaskName] if isAlias else bitmaskName
283                        if bitmaskName in self.vk.bitmasks:
284                            bitmask = self.vk.bitmasks[bitmaskName]
285                            bitmask.extensions.extend([extension] if extension not in bitmask.extensions else [])
286                            extension.bitmasks.extend([bitmask] if bitmask not in extension.bitmasks else [])
287                            # Update flags with implicit base extension
288                            if isAlias:
289                                continue
290                            bitmask.flagExtensions.extend([extension] if extension not in bitmask.flagExtensions else [])
291                            for flag in [x for x in bitmask.flags if (not x.extensions or (x.extensions and all(e in bitmask.extensions for e in x.extensions)))]:
292                                flag.extensions.extend([extension] if extension not in flag.extensions else [])
293                                if bitmaskName not in extension.flags:
294                                    extension.flags[bitmaskName] = [] # Dict needs init
295                                extension.flags[bitmaskName].extend([flag] if flag not in extension.flags[bitmaskName] else [])
296
297        # Some structs (ex VkAttachmentSampleCountInfoAMD) can have multiple alias pointing to same extension
298        for extension in self.vk.extensions.values():
299            dict = self.featureDictionary[extension.name]['struct']
300            for required in dict:
301                for group in dict[required]:
302                    for structName in dict[required][group]:
303                        isAlias = structName in self.structAliasMap
304                        structName = self.structAliasMap[structName] if isAlias else structName
305                        # An EXT struct can alias a KHR struct,
306                        # that in turns aliaes a core struct
307                        # => Try to propagate aliasing, it can safely result in a no-op
308                        isAlias = structName in self.structAliasMap
309                        structName = self.structAliasMap[structName] if isAlias else structName
310                        if structName in self.vk.structs:
311                            struct = self.vk.structs[structName]
312                            struct.extensions.extend([extension] if extension not in struct.extensions else [])
313
314        # While we update struct alias inside other structs, the command itself might have the struct as a first level param.
315        # We use this time to update params to have the promoted name
316        # Example - https://github.com/KhronosGroup/Vulkan-ValidationLayers/issues/9322
317        for command in self.vk.commands.values():
318            for member in command.params:
319                if member.type in self.structAliasMap:
320                    member.type = self.structAliasMap[member.type]
321
322    def endFile(self):
323        # This is the point were reg.py has ran, everything is collected
324        # We do some post processing now
325        self.applyExtensionDependency()
326
327        # Use structs and commands to find which things are returnedOnly
328        for struct in [x for x in self.vk.structs.values() if not x.returnedOnly]:
329            for enum in [self.vk.enums[x.type] for x in struct.members if x.type in self.vk.enums]:
330                enum.returnedOnly = False
331            for bitmask in [self.vk.bitmasks[x.type] for x in struct.members if x.type in self.vk.bitmasks]:
332                bitmask.returnedOnly = False
333            for bitmask in [self.vk.bitmasks[x.type.replace('Flags', 'FlagBits')] for x in struct.members if x.type.replace('Flags', 'FlagBits') in self.vk.bitmasks]:
334                bitmask.returnedOnly = False
335        for command in self.vk.commands.values():
336            for enum in [self.vk.enums[x.type] for x in command.params if x.type in self.vk.enums]:
337                enum.returnedOnly = False
338            for bitmask in [self.vk.bitmasks[x.type] for x in command.params if x.type in self.vk.bitmasks]:
339                bitmask.returnedOnly = False
340            for bitmask in [self.vk.bitmasks[x.type.replace('Flags', 'FlagBits')] for x in command.params if x.type.replace('Flags', 'FlagBits') in self.vk.bitmasks]:
341                bitmask.returnedOnly = False
342
343        # Turn handle parents into pointers to classes
344        for handle in [x for x in self.vk.handles.values() if x.parent is not None]:
345            handle.parent = self.vk.handles[handle.parent]
346        # search up parent chain to see if instance or device
347        for handle in [x for x in self.vk.handles.values()]:
348            next_parent = handle.parent
349            while (not handle.instance and not handle.device):
350                handle.instance = next_parent.name == 'VkInstance'
351                handle.device = next_parent.name == 'VkDevice'
352                next_parent = next_parent.parent
353
354        maxSyncSupport.queues = Queues.ALL
355        maxSyncSupport.stages = self.vk.bitmasks['VkPipelineStageFlagBits2'].flags
356        maxSyncEquivalent.accesses = self.vk.bitmasks['VkAccessFlagBits2'].flags
357        maxSyncEquivalent.stages = self.vk.bitmasks['VkPipelineStageFlagBits2'].flags
358
359        # All inherited generators should run from here
360        self.generate()
361
362        if cachingEnabled:
363            cachePath = os.path.join(tempfile.gettempdir(), f'vkobject_{os.getpid()}')
364            if not os.path.isfile(cachePath):
365                cacheFile = open(cachePath, 'wb')
366                pickle.dump(self.vk, cacheFile)
367                cacheFile.close()
368
369        # This should not have to do anything but call into OutputGenerator
370        OutputGenerator.endFile(self)
371
372    #
373    # Bypass the entire processing and load in the VkObject data
    # Still need to handle the beginFile/endFile for reg.py
    def generateFromCache(self, cacheVkObjectData, genOpts):
        """Run generate() against an already-built VulkanObject instead of parsing the registry.

        Only OutputGenerator.beginFile/endFile are called. BaseGenerator's registry
        scanning in beginFile and the post-processing in endFile are skipped, so
        `cacheVkObjectData` is presumably a fully processed object, such as the one
        endFile pickles when caching is enabled (TODO confirm with the caller).
        """
        OutputGenerator.beginFile(self, genOpts)
        self.filename = genOpts.filename
        self.vk = cacheVkObjectData
        self.generate()
        OutputGenerator.endFile(self)
381
382    #
383    # Processing point at beginning of each extension definition
384    def beginFeature(self, interface, emit):
385        OutputGenerator.beginFeature(self, interface, emit)
386        platform = interface.get('platform')
387        self.featureExtraProtec = self.vk.platforms[platform] if platform in self.vk.platforms else None
388        protect = self.vk.platforms[platform] if platform in self.vk.platforms else None
389        name = interface.get('name')
390
391        if interface.tag == 'extension':
392            instance = interface.get('type') == 'instance'
393            device = not instance
394            depends = interface.get('depends')
395            vendorTag = interface.get('author')
396            platform = interface.get('platform')
397            provisional = boolGet(interface, 'provisional')
398            promotedto = interface.get('promotedto')
399            deprecatedby = interface.get('deprecatedby')
400            obsoletedby = interface.get('obsoletedby')
401            specialuse = splitIfGet(interface, 'specialuse')
402            # Not sure if better way to get this info
403            specVersion = self.featureDictionary[name]['enumconstant'][None][None][0]
404            nameString = self.featureDictionary[name]['enumconstant'][None][None][1]
405
406            self.currentExtension = Extension(name, nameString, specVersion, instance, device, depends, vendorTag,
407                                            platform, protect, provisional, promotedto, deprecatedby,
408                                            obsoletedby, specialuse)
409            self.vk.extensions[name] = self.currentExtension
410        else: # version
411            number = interface.get('number')
412            if number != '1.0':
413                self.currentVersion = APISpecific.createApiVersion(self.targetApiName, name)
414                self.vk.versions[name] = self.currentVersion
415
    def endFeature(self):
        """Finish the current feature/extension.

        Clears currentExtension/currentVersion so that items generated afterwards
        (genCmd, genGroup, ...) are not attributed to it.
        """
        OutputGenerator.endFeature(self)
        self.currentExtension = None
        self.currentVersion = None
420
421    #
422    # All <command> from XML
423    def genCmd(self, cmdinfo, name, alias):
424        OutputGenerator.genCmd(self, cmdinfo, name, alias)
425
426        params = []
427        for param in cmdinfo.elem.findall('param'):
428            paramName = param.find('name').text
429            paramType = textIfFind(param, 'type')
430            paramAlias = param.get('alias')
431
432            cdecl = self.makeCParamDecl(param, 0)
433            pointer = '*' in cdecl or paramType.startswith('PFN_')
434            paramConst = 'const' in cdecl
435            fixedSizeArray = [x[:-1] for x in cdecl.split('[') if x.endswith(']')]
436
437            paramNoautovalidity = boolGet(param, 'noautovalidity')
438
439            nullTerminated = False
440            length = param.get('altlen') if param.get('altlen') is not None else param.get('len')
441            if length:
442                # we will either find it like "null-terminated" or "enabledExtensionCount,null-terminated"
443                # This finds both
444                nullTerminated = 'null-terminated' in length
445                length = length.replace(',null-terminated', '') if 'null-terminated' in length else length
446                length = None if length == 'null-terminated' else length
447
448            if fixedSizeArray and not length:
449                length = ','.join(fixedSizeArray)
450
451            # See Member::optional code for details of this
452            optionalValues = splitIfGet(param, 'optional')
453            optional = optionalValues is not None and optionalValues[0].lower() == "true"
454            optionalPointer = optionalValues is not None and len(optionalValues) > 1 and optionalValues[1].lower() == "true"
455
456            # externsync will be 'true' or expression
457            # if expression, it should be same as 'true'
458            externSync = boolGet(param, 'externsync')
459            externSyncPointer = None if externSync else splitIfGet(param, 'externsync')
460            if not externSync and externSyncPointer is not None:
461                externSync = True
462
463            params.append(Param(paramName, paramAlias, paramType, paramNoautovalidity,
464                                paramConst, length, nullTerminated, pointer, fixedSizeArray,
465                                optional, optionalPointer,
466                                externSync, externSyncPointer, cdecl))
467
468        attrib = cmdinfo.elem.attrib
469        alias = attrib.get('alias')
470        tasks = splitIfGet(attrib, 'tasks')
471
472        queues = getQueues(attrib)
473        successcodes = splitIfGet(attrib, 'successcodes')
474        errorcodes = splitIfGet(attrib, 'errorcodes')
475        cmdbufferlevel = attrib.get('cmdbufferlevel')
476        primary = cmdbufferlevel is not None and 'primary' in cmdbufferlevel
477        secondary = cmdbufferlevel is not None and 'secondary' in cmdbufferlevel
478
479        renderpass = attrib.get('renderpass')
480        renderpass = CommandScope.NONE if renderpass is None else getattr(CommandScope, renderpass.upper())
481        videocoding = attrib.get('videocoding')
482        videocoding = CommandScope.NONE if videocoding is None else getattr(CommandScope, videocoding.upper())
483
484        protoElem = cmdinfo.elem.find('proto')
485        returnType = textIfFind(protoElem, 'type')
486
487        decls = self.makeCDecls(cmdinfo.elem)
488        cPrototype = decls[0]
489        cFunctionPointer = decls[1]
490
491        protect = self.currentExtension.protect if self.currentExtension is not None else None
492
493        # These coammds have no way from the XML to detect they would be an instance command
494        specialInstanceCommand = ['vkCreateInstance', 'vkEnumerateInstanceExtensionProperties','vkEnumerateInstanceLayerProperties', 'vkEnumerateInstanceVersion']
495        instance = len(params) > 0 and (params[0].type == 'VkInstance' or params[0].type == 'VkPhysicalDevice' or name in specialInstanceCommand)
496        device = not instance
497
498        implicitElem = cmdinfo.elem.find('implicitexternsyncparams')
499        implicitExternSyncParams = [x.text for x in implicitElem.findall('param')] if implicitElem else []
500
501        self.vk.commands[name] = Command(name, alias, protect, [], self.currentVersion,
502                                         returnType, params, instance, device,
503                                         tasks, queues, successcodes, errorcodes,
504                                         primary, secondary, renderpass, videocoding,
505                                         implicitExternSyncParams, cPrototype, cFunctionPointer)
506
507    #
508    # List the enum for the commands
509    # TODO - Seems empty groups like `VkDeviceDeviceMemoryReportCreateInfoEXT` do not show up in here
510    def genGroup(self, groupinfo, groupName, alias):
511        # There can be case where the Enum/Bitmask is in a protect, but the individual
512        # fields also have their own protect
513        groupProtect = self.currentExtension.protect if hasattr(self.currentExtension, 'protect') and self.currentExtension.protect is not None else None
514        enumElem = groupinfo.elem
515        bitwidth = 32 if enumElem.get('bitwidth') is None else int(enumElem.get('bitwidth'))
516        fields = []
517        if enumElem.get('type') == "enum":
518            if alias is not None:
519                self.enumAliasMap[groupName] = alias
520                return
521
522            for elem in enumElem.findall('enum'):
523                fieldName = elem.get('name')
524
525                if elem.get('alias') is not None:
526                    self.enumFieldAliasMap[fieldName] = elem.get('alias')
527                    continue
528
529                negative = elem.get('dir') is not None
530                protect = elem.get('protect')
531
532                # Some values have multiple extensions (ex VK_DESCRIPTOR_UPDATE_TEMPLATE_TYPE_PUSH_DESCRIPTORS_KHR)
533                # genGroup() lists them twice
534                if next((x for x in fields if x.name == fieldName), None) is None:
535                    fields.append(EnumField(fieldName, negative, protect, []))
536
537            self.vk.enums[groupName] = Enum(groupName, groupProtect, bitwidth, True, fields, [], [])
538
539        else: # "bitmask"
540            if alias is not None:
541                self.bitmaskAliasMap[groupName] = alias
542                return
543
544            for elem in enumElem.findall('enum'):
545                flagName = elem.get('name')
546
547                if elem.get('alias') is not None:
548                    self.flagAliasMap[flagName] = elem.get('alias')
549                    continue
550
551                flagMultiBit = False
552                flagZero = False
553                flagValue = intIfGet(elem, 'bitpos')
554                if flagValue is None:
555                    flagValue = intIfGet(elem, 'value')
556                    flagMultiBit = flagValue != 0
557                    flagZero = flagValue == 0
558                protect = elem.get('protect')
559
560                # Some values have multiple extensions (ex VK_TOOL_PURPOSE_DEBUG_REPORTING_BIT_EXT)
561                # genGroup() lists them twice
562                if next((x for x in fields if x.name == flagName), None) is None:
563                    fields.append(Flag(flagName, protect, flagValue, flagMultiBit, flagZero, []))
564
565            flagName = groupName.replace('FlagBits', 'Flags')
566            self.vk.bitmasks[groupName] = Bitmask(groupName, flagName, groupProtect, bitwidth, True, fields, [], [])
567
568    def genType(self, typeInfo, typeName, alias):
569        OutputGenerator.genType(self, typeInfo, typeName, alias)
570        typeElem = typeInfo.elem
571        protect = self.currentExtension.protect if hasattr(self.currentExtension, 'protect') and self.currentExtension.protect is not None else None
572        category = typeElem.get('category')
573        if (category == 'struct' or category == 'union'):
574            extension = [self.currentExtension] if self.currentExtension is not None else []
575            if alias is not None:
576                self.structAliasMap[typeName] = alias
577                return
578
579            union = category == 'union'
580
581            returnedOnly = boolGet(typeElem, 'returnedonly')
582            allowDuplicate = boolGet(typeElem, 'allowduplicate')
583
584            extends = splitIfGet(typeElem, 'structextends')
585            extendedBy = self.registry.validextensionstructs[typeName] if len(self.registry.validextensionstructs[typeName]) > 0 else None
586
587            membersElem = typeInfo.elem.findall('.//member')
588            members = []
589            sType = None
590
591            for member in membersElem:
592                for comment in member.findall('comment'):
593                    member.remove(comment)
594
595                name = textIfFind(member, 'name')
596                type = textIfFind(member, 'type')
597                sType = member.get('values') if member.get('values') is not None else sType
598                externSync = boolGet(member, 'externsync')
599                noautovalidity = boolGet(member, 'noautovalidity')
600                limittype = member.get('limittype')
601
602                nullTerminated = False
603                length = member.get('altlen') if member.get('altlen') is not None else member.get('len')
604                if length:
605                    # we will either find it like "null-terminated" or "enabledExtensionCount,null-terminated"
606                    # This finds both
607                    nullTerminated = 'null-terminated' in length
608                    length = length.replace(',null-terminated', '') if 'null-terminated' in length else length
609                    length = None if length == 'null-terminated' else length
610
611                cdecl = self.makeCParamDecl(member, 0)
612                pointer = '*' in cdecl or type.startswith('PFN_')
613                const = 'const' in cdecl
614                # Some structs like VkTransformMatrixKHR have a 2D array
615                fixedSizeArray = [x[:-1] for x in cdecl.split('[') if x.endswith(']')]
616
617                if fixedSizeArray and not length:
618                    length = ','.join(fixedSizeArray)
619
620                # if a pointer, this can be a something like:
621                #     optional="true,false" for ppGeometries
622                #     optional="false,true" for pPhysicalDeviceCount
623                # the first is if the variable itself is optional
624                # the second is the value of the pointer is optional;
625                optionalValues = splitIfGet(member, 'optional')
626                optional = optionalValues is not None and optionalValues[0].lower() == "true"
627                optionalPointer = optionalValues is not None and len(optionalValues) > 1 and optionalValues[1].lower() == "true"
628
629                members.append(Member(name, type, noautovalidity, limittype,
630                                      const, length, nullTerminated, pointer, fixedSizeArray,
631                                      optional, optionalPointer,
632                                      externSync, cdecl))
633
634            self.vk.structs[typeName] = Struct(typeName, extension, self.currentVersion, protect, members,
635                                               union, returnedOnly, sType, allowDuplicate, extends, extendedBy)
636
637        elif category == 'handle':
638            if alias is not None:
639                return
640            type = typeElem.get('objtypeenum')
641
642            # will resolve these later, the VulkanObjectType does not list things in dependent order
643            parent = typeElem.get('parent')
644            instance = typeName == 'VkInstance'
645            device = typeName == 'VkDevice'
646
647            dispatchable = typeElem.find('type').text == 'VK_DEFINE_HANDLE'
648
649            self.vk.handles[typeName] = Handle(typeName, type, protect, parent, instance, device, dispatchable)
650
651        elif category == 'define':
652            if typeName == 'VK_HEADER_VERSION':
653                self.vk.headerVersion = typeElem.find('name').tail.strip()
654
655        else:
656            # not all categories are used
657            #   'group'/'enum'/'bitmask' are routed to genGroup instead
658            #   'basetype'/'include' are only for headers
659            #   'funcpointer` ignore until needed
660            return
661
662    def genSpirv(self, spirvinfo, spirvName, alias):
663        OutputGenerator.genSpirv(self, spirvinfo, spirvName, alias)
664        spirvElem = spirvinfo.elem
665        name = spirvElem.get('name')
666        extension = True if spirvElem.tag == 'spirvextension' else False
667        capability = not extension
668
669        enables = []
670        for elem in spirvElem:
671            version = elem.attrib.get('version')
672            extensionEnable = elem.attrib.get('extension')
673            struct = elem.attrib.get('struct')
674            feature = elem.attrib.get('feature')
675            requires = elem.attrib.get('requires')
676            propertyEnable = elem.attrib.get('property')
677            member = elem.attrib.get('member')
678            value = elem.attrib.get('value')
679            enables.append(SpirvEnables(version, extensionEnable, struct, feature,
680                                        requires, propertyEnable, member, value))
681
682        self.vk.spirv.append(Spirv(name, extension, capability, enables))
683
684    def genFormat(self, format, formatinfo, alias):
685        OutputGenerator.genFormat(self, format, formatinfo, alias)
686        formatElem = format.elem
687        name = formatElem.get('name')
688
689        components = []
690        for component in formatElem.iterfind('component'):
691            type = component.get('name')
692            bits = component.get('bits')
693            numericFormat = component.get('numericFormat')
694            planeIndex = intIfGet(component, 'planeIndex')
695            components.append(FormatComponent(type, bits, numericFormat, planeIndex))
696
697        planes = []
698        for plane in formatElem.iterfind('plane'):
699            index = int(plane.get('index'))
700            widthDivisor = int(plane.get('widthDivisor'))
701            heightDivisor = int(plane.get('heightDivisor'))
702            compatible = plane.get('compatible')
703            planes.append(FormatPlane(index, widthDivisor, heightDivisor, compatible))
704
705        className = formatElem.get('class')
706        blockSize = int(formatElem.get('blockSize'))
707        texelsPerBlock = int(formatElem.get('texelsPerBlock'))
708        blockExtent = splitIfGet(formatElem, 'blockExtent')
709        packed = intIfGet(formatElem, 'packed')
710        chroma = formatElem.get('chroma')
711        compressed = formatElem.get('compressed')
712        spirvImageFormat = formatElem.find('spirvimageformat')
713        if spirvImageFormat is not None:
714            spirvImageFormat = spirvImageFormat.get('name')
715
716        self.vk.formats[name] = Format(name, className, blockSize, texelsPerBlock,
717                                       blockExtent, packed, chroma, compressed,
718                                       components, planes, spirvImageFormat)
719
720    def genSyncStage(self, sync):
721        OutputGenerator.genSyncStage(self, sync)
722        syncElem = sync.elem
723
724        support = maxSyncSupport
725        supportElem = syncElem.find('syncsupport')
726        if supportElem is not None:
727            queues = getQueues(supportElem)
728            stageNames = splitIfGet(supportElem, 'stage')
729            stages = [x for x in self.vk.bitmasks['VkPipelineStageFlagBits2'].flags if x.name in stageNames] if stageNames is not None else None
730            support = SyncSupport(queues, stages, False)
731
732        equivalent = maxSyncEquivalent
733        equivalentElem = syncElem.find('syncequivalent')
734        if equivalentElem is not None:
735            stageNames = splitIfGet(equivalentElem, 'stage')
736            stages = [x for x in self.vk.bitmasks['VkPipelineStageFlagBits2'].flags if x.name in stageNames] if stageNames is not None else None
737            accessNames = splitIfGet(equivalentElem, 'access')
738            accesses = [x for x in self.vk.bitmasks['VkAccessFlagBits2'].flags if x.name in accessNames] if accessNames is not None else None
739            equivalent = SyncEquivalent(stages, accesses, False)
740
741        flagName = syncElem.get('name')
742        flag = [x for x in self.vk.bitmasks['VkPipelineStageFlagBits2'].flags if x.name == flagName]
743        # This check is needed because not all API variants have VK_KHR_synchronization2
744        if flag:
745            self.vk.syncStage.append(SyncStage(flag[0], support, equivalent))
746
747    def genSyncAccess(self, sync):
748        OutputGenerator.genSyncAccess(self, sync)
749        syncElem = sync.elem
750
751        support = maxSyncSupport
752        supportElem = syncElem.find('syncsupport')
753        if supportElem is not None:
754            queues = getQueues(supportElem)
755            stageNames = splitIfGet(supportElem, 'stage')
756            stages = [x for x in self.vk.bitmasks['VkPipelineStageFlagBits2'].flags if x.name in stageNames] if stageNames is not None else None
757            support = SyncSupport(queues, stages, False)
758
759        equivalent = maxSyncEquivalent
760        equivalentElem = syncElem.find('syncequivalent')
761        if equivalentElem is not None:
762            stageNames = splitIfGet(equivalentElem, 'stage')
763            stages = [x for x in self.vk.bitmasks['VkPipelineStageFlagBits2'].flags if x.name in stageNames] if stageNames is not None else None
764            accessNames = splitIfGet(equivalentElem, 'access')
765            accesses = [x for x in self.vk.bitmasks['VkAccessFlagBits2'].flags if x.name in accessNames] if accessNames is not None else None
766            equivalent = SyncEquivalent(stages, accesses, False)
767
768        flagName = syncElem.get('name')
769        flag = [x for x in self.vk.bitmasks['VkAccessFlagBits2'].flags if x.name == flagName]
770        # This check is needed because not all API variants have VK_KHR_synchronization2
771        if flag:
772            self.vk.syncAccess.append(SyncAccess(flag[0], support, equivalent))
773
774    def genSyncPipeline(self, sync):
775        OutputGenerator.genSyncPipeline(self, sync)
776        syncElem = sync.elem
777        name = syncElem.get('name')
778        depends = splitIfGet(syncElem, 'depends')
779        stages = []
780        for stageElem in syncElem.findall('syncpipelinestage'):
781            order = stageElem.get('order')
782            before = stageElem.get('before')
783            after = stageElem.get('after')
784            value = stageElem.text
785            stages.append(SyncPipelineStage(order, before, after, value))
786
787        self.vk.syncPipeline.append(SyncPipeline(name, depends, stages))
788