• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright (c) 2018 The Android Open Source Project
2# Copyright (c) 2018 Google Inc.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15from typing import Dict, Optional, List, Set, Union
16from xml.etree.ElementTree import Element
17
18from generator import noneStr
19
20from copy import copy
21from string import whitespace
22
# Holds information about core Vulkan objects
# and the API calls that are used to create/destroy each one.
class HandleInfo(object):
    """Create/destroy API metadata for one Vulkan handle type.

    Attributes:
        name: the handle type name, e.g. "VkImage".
        createApis: a single API name (str) or a list of API names that
            create this handle.
        destroyApis: a single API name (str), a list of API names, or None
            when no destroy API is tracked for this handle.
    """

    def __init__(self, name, createApis, destroyApis):
        self.name = name
        self.createApis = createApis
        self.destroyApis = destroyApis

    @staticmethod
    def _apiListContains(apis, apiName):
        """Return True if apiName is one of `apis` (a str, a list of str, or None).

        A single str is compared for equality: using `in` on a str would do
        substring matching, so e.g. "vkCreateImage" would wrongly match
        "vkCreateImageView".
        """
        if apis is None:
            return False
        if isinstance(apis, str):
            return apiName == apis
        return apiName in apis

    def isCreateApi(self, apiName):
        """Return True if apiName is one of this handle's create APIs."""
        return self._apiListContains(self.createApis, apiName)

    def isDestroyApi(self, apiName):
        """Return True if apiName is one of this handle's destroy APIs."""
        return self._apiListContains(self.destroyApis, apiName)
38
# Core Vulkan handle types, split by dispatchability.
# VulkanType.isDispatchableHandleType / isNonDispatchableHandleType test
# membership in these two lists.
DISPATCHABLE_HANDLE_TYPES = [
    "VkInstance",
    "VkPhysicalDevice",
    "VkDevice",
    "VkQueue",
    "VkCommandBuffer",
]

NON_DISPATCHABLE_HANDLE_TYPES = [
    "VkDeviceMemory",
    "VkBuffer",
    "VkBufferView",
    "VkImage",
    "VkImageView",
    "VkShaderModule",
    "VkDescriptorPool",
    "VkDescriptorSetLayout",
    "VkDescriptorSet",
    "VkSampler",
    "VkPipeline",
    "VkPipelineLayout",
    "VkRenderPass",
    "VkFramebuffer",
    "VkPipelineCache",
    "VkCommandPool",
    "VkFence",
    "VkSemaphore",
    "VkEvent",
    "VkQueryPool",
    "VkSamplerYcbcrConversion",
    "VkSamplerYcbcrConversionKHR",
    "VkDescriptorUpdateTemplate",
    "VkSurfaceKHR",
    "VkSwapchainKHR",
    "VkDisplayKHR",
    "VkDisplayModeKHR",
    "VkObjectTableNVX",
    "VkIndirectCommandsLayoutNVX",
    "VkValidationCacheEXT",
    "VkDebugReportCallbackEXT",
    "VkDebugUtilsMessengerEXT",
    "VkAccelerationStructureNV",
    "VkIndirectCommandsLayoutNV",
    "VkAccelerationStructureKHR",
]

# Handle types whose create/destroy APIs do not follow the
# vkCreate<Name> / vkDestroy<Name> convention; their HandleInfo entries
# are spelled out individually when HANDLE_INFO is built.
CUSTOM_HANDLE_CREATE_TYPES = [
    "VkPhysicalDevice",
    "VkQueue",
    "VkPipeline",
    "VkDeviceMemory",
    "VkDescriptorSet",
    "VkCommandBuffer",
    "VkRenderPass",
]

# Union of all handle types above, de-duplicated and sorted by name.
HANDLE_TYPES = sorted(
    set(DISPATCHABLE_HANDLE_TYPES)
    | set(NON_DISPATCHABLE_HANDLE_TYPES)
    | set(CUSTOM_HANDLE_CREATE_TYPES))
97
# Maps each handle type name in HANDLE_TYPES to its HandleInfo.
HANDLE_INFO = {}

for h in HANDLE_TYPES:
    if h not in CUSTOM_HANDLE_CREATE_TYPES:
        # Conventional handles: vkCreate<Name> / vkDestroy<Name>, where
        # <Name> is the type name without its "Vk" prefix.
        HANDLE_INFO[h] = HandleInfo(h, "vkCreate" + h[2:], "vkDestroy" + h[2:])
    elif h == "VkPhysicalDevice":
        # Obtained by enumeration; no destroy API is recorded.
        HANDLE_INFO[h] = HandleInfo(
            "VkPhysicalDevice", "vkEnumeratePhysicalDevices", None)
    elif h == "VkQueue":
        # Obtained from the device; no destroy API is recorded.
        HANDLE_INFO[h] = HandleInfo("VkQueue", "vkGetDeviceQueue", None)
    elif h == "VkPipeline":
        HANDLE_INFO[h] = HandleInfo(
            "VkPipeline",
            ["vkCreateGraphicsPipelines", "vkCreateComputePipelines"],
            "vkDestroyPipeline")
    elif h == "VkDeviceMemory":
        HANDLE_INFO[h] = HandleInfo(
            "VkDeviceMemory",
            "vkAllocateMemory",
            ["vkFreeMemory", "vkFreeMemorySyncGOOGLE"])
    elif h == "VkDescriptorSet":
        HANDLE_INFO[h] = HandleInfo(
            "VkDescriptorSet", "vkAllocateDescriptorSets", "vkFreeDescriptorSets")
    elif h == "VkCommandBuffer":
        HANDLE_INFO[h] = HandleInfo(
            "VkCommandBuffer", "vkAllocateCommandBuffers", "vkFreeCommandBuffers")
    elif h == "VkRenderPass":
        HANDLE_INFO[h] = HandleInfo(
            "VkRenderPass",
            ["vkCreateRenderPass", "vkCreateRenderPass2", "vkCreateRenderPass2KHR"],
            "vkDestroyRenderPass")
139
# API names excluded from generation (not referenced in this part of the
# file; presumably consumed by the generators — confirm at use sites).
EXCLUDED_APIS = [
    "vkEnumeratePhysicalDeviceGroups",
]

# Type names VulkanTypeInfo.isNonAbiPortableType always reports as portable.
EXPLICITLY_ABI_PORTABLE_TYPES = [
    "VkResult",
    "VkBool32",
    "VkSampleMask",
    "VkFlags",
    "VkDeviceSize",
]

# Type names VulkanTypeInfo.isNonAbiPortableType always reports as
# non-portable.
EXPLICITLY_ABI_NON_PORTABLE_TYPES = [
    "size_t",
]

# Registry categories whose types are treated as non-ABI-portable.
NON_ABI_PORTABLE_TYPE_CATEGORIES = [
    "handle",
    "funcpointer",
]

# XML attribute names read by makeVulkanTypeFromXMLTag into
# VulkanType.deviceMemoryAttrib / deviceMemoryVal.
DEVICE_MEMORY_INFO_KEYS = [
    "devicememoryhandle",
    "devicememoryoffset",
    "devicememorysize",
    "devicememorytypeindex",
    "devicememorytypebits",
]

TRIVIAL_TRANSFORMED_TYPES = [
    "VkPhysicalDeviceExternalImageFormatInfo",
    "VkPhysicalDeviceExternalBufferInfo",
    "VkExternalMemoryImageCreateInfo",
    "VkExternalMemoryBufferCreateInfo",
    "VkExportMemoryAllocateInfo",
    "VkExternalImageFormatProperties",
    "VkExternalBufferProperties",
]

NON_TRIVIAL_TRANSFORMED_TYPES = [
    "VkExternalMemoryProperties",
    "VkImageCreateInfo",
]

# All type names for which `isTransformed` gets set (trivial ones first).
TRANSFORMED_TYPES = [*TRIVIAL_TRANSFORMED_TYPES, *NON_TRIVIAL_TRANSFORMED_TYPES]
185
# Holds information about a Vulkan type instance (i.e., not a type definition).
# Type instances are used as struct field definitions or function parameters,
# to be later fed to code generation.
# VulkanType instances can be constructed in two ways:
# 1. From an XML tag with <type> / <param> tags in vk.xml,
#    using makeVulkanTypeFromXMLTag
# 2. User-defined instances with makeVulkanTypeSimple.
class VulkanType(object):
    """One use of a Vulkan type: a struct member, API parameter or return type."""

    def __init__(self):
        # Enclosing VulkanCompoundType / VulkanAPI, set by their initCopies().
        self.parent: Optional[VulkanType] = None
        self.typeName: str = ""

        # True when typeName is listed in TRANSFORMED_TYPES.
        self.isTransformed = False

        self.paramName: Optional[str] = None

        self.lenExpr: Optional[str] = None  # Value of the `len` attribute in the spec
        self.isOptional: bool = False
        self.optionalStr: Optional[str] = None  # Value of the `optional` attribute in the spec

        self.isConst = False

        # "" means it's not a static array, otherwise this is the total size of
        # all elements. e.g. staticArrExpr of "x[3][2][8]" will be "((3)*(2)*(8))".
        self.staticArrExpr = ""
        # "" means it's not a static array, otherwise it's the raw expression
        # of static array size, which can be one-dimensional or multi-dimensional.
        self.rawStaticArrExpr = ""

        self.pointerIndirectionLevels = 0  # 0 means not pointer
        self.isPointerToConstPointer = False

        self.primitiveEncodingSize = None

        self.deviceMemoryInfoParameterIndices = None

        # Annotations
        # Environment annotation for binding current
        # variables to sub-structures
        self.binds = {}

        # Device memory annotations
        # self.deviceMemoryAttrib/Val store the device memory info attribute
        # name (one of DEVICE_MEMORY_INFO_KEYS) and its value from the XML.
        self.deviceMemoryAttrib = None
        self.deviceMemoryVal = None

        # Filter annotations
        self.filterVar = None
        self.filterVals = None
        self.filterFunc = None
        self.filterOtherwise = None

        # Stream feature
        self.streamFeature = None

        # All other annotations
        self.attribs = {}

        self.nonDispatchableHandleCreate = False
        self.nonDispatchableHandleDestroy = False
        self.dispatchHandle = False
        self.dispatchableHandleCreate = False
        self.dispatchableHandleDestroy = False

    def __str__(self,):
        typeDesc = (self.typeName + ("*" * self.pointerIndirectionLevels) +
                    ("ptr2constptr" if self.isPointerToConstPointer else ""))
        constDesc = "const" if self.isConst else "nonconst"
        return ("(vulkantype %s %s paramName %s len %s optional? %s "
                "staticArrExpr %s)") % (
            typeDesc, constDesc, self.paramName, self.lenExpr,
            self.isOptional, self.staticArrExpr)

    def isString(self):
        """True for a single-level `char` pointer."""
        return self.pointerIndirectionLevels == 1 and (self.typeName == "char")

    def isArrayOfStrings(self):
        """True for a `char` pointer-to-const-pointer."""
        return self.isPointerToConstPointer and (self.typeName == "char")

    def primEncodingSize(self):
        return self.primitiveEncodingSize

    # Utility functions to make codegen life easier.
    def getLengthExpression(self):
        """Return the element-count expression, or None if unknown.

        The static array size takes precedence over the `len` attribute.
        """
        if self.staticArrExpr != "":
            return self.staticArrExpr
        if self.lenExpr:
            return self.lenExpr
        return None

    def accessibleAsPointer(self):
        """True if this can be passed to functions expecting T* (static array or pointer)."""
        if self.staticArrExpr != "":
            return True
        if self.pointerIndirectionLevels > 0:
            return True
        return False

    def possiblyOutput(self,):
        """Rough guess at whether this is an output: a non-const pointer.

        Good for inferring which things need to be marshaled in versus
        marshaled out for Vulkan API calls.
        """
        return self.pointerIndirectionLevels > 0 and (not self.isConst)

    def isVoidWithNoSize(self,):
        return self.typeName == "void" and self.pointerIndirectionLevels == 0

    def getCopy(self,):
        """Return a shallow copy (dict/list attributes are shared)."""
        return copy(self)

    def getTransformed(self, isConstChoice=None, ptrIndirectionChoice=None):
        """Return a copy with const-ness and/or pointer depth overridden."""
        res = self.getCopy()

        if isConstChoice is not None:
            res.isConst = isConstChoice
        if ptrIndirectionChoice is not None:
            res.pointerIndirectionLevels = ptrIndirectionChoice

        return res

    def getWithCustomName(self):
        # NOTE(review): despite its name, this takes no name and is identical
        # to getForAddressAccess (adds one pointer level) — confirm intent.
        return self.getTransformed(
            ptrIndirectionChoice=self.pointerIndirectionLevels + 1)

    def getForAddressAccess(self):
        """Return a copy with one more level of pointer indirection."""
        return self.getTransformed(
            ptrIndirectionChoice=self.pointerIndirectionLevels + 1)

    def getForValueAccess(self):
        """Return a copy with one less level of pointer indirection.

        A `void*` is first treated as `uint8_t*`, so it dereferences to
        `uint8_t`.
        """
        if self.typeName == "void" and self.pointerIndirectionLevels == 1:
            asUint8Type = self.getCopy()
            asUint8Type.typeName = "uint8_t"
            return asUint8Type.getForValueAccess()
        return self.getTransformed(
            ptrIndirectionChoice=self.pointerIndirectionLevels - 1)

    def getForNonConstAccess(self):
        return self.getTransformed(isConstChoice=False)

    def withModifiedName(self, newName):
        """Return a copy whose paramName is newName."""
        res = self.getCopy()
        res.paramName = newName
        return res

    def isNextPointer(self):
        return self.paramName == "pNext"

    def isSigned(self):
        return self.typeName in ["int", "int8_t", "int16_t", "int32_t", "int64_t"]

    def isEnum(self, typeInfo):
        """True if typeInfo registers typeName under the "enum" category."""
        return typeInfo.categoryOf(self.typeName) == "enum"

    def isBitmask(self, typeInfo):
        """True if typeInfo registers typeName under the "bitmask" category."""
        # Previously compared against "enum", which made this identical to
        # isEnum().
        return typeInfo.categoryOf(self.typeName) == "bitmask"

    # Only deals with 'core' handle types here.
    def isDispatchableHandleType(self):
        return self.typeName in DISPATCHABLE_HANDLE_TYPES

    def isNonDispatchableHandleType(self):
        return self.typeName in NON_DISPATCHABLE_HANDLE_TYPES

    def isHandleType(self):
        return self.isDispatchableHandleType() or \
               self.isNonDispatchableHandleType()

    def isCreatedBy(self, api):
        """True if `api` (anything with a `.name`) creates handles of this type.

        Checks HANDLE_INFO; a name ending in "KHR" that does not match
        directly is retried with the suffix stripped. The GOOGLE
        ...WithRequirements entry points are special-cased for VkImage and
        VkBuffer.
        """
        if self.typeName in HANDLE_INFO.keys():
            nonKhrRes = HANDLE_INFO[self.typeName].isCreateApi(api.name)
            if nonKhrRes:
                return True
            if len(api.name) > 3 and "KHR" == api.name[-3:]:
                return HANDLE_INFO[self.typeName].isCreateApi(api.name[:-3])

        if self.typeName == "VkImage" and api.name == "vkCreateImageWithRequirementsGOOGLE":
            return True

        if self.typeName == "VkBuffer" and api.name == "vkCreateBufferWithRequirementsGOOGLE":
            return True

        return False

    def isDestroyedBy(self, api):
        """True if `api` destroys handles of this type (same KHR fallback as isCreatedBy)."""
        if self.typeName in HANDLE_INFO.keys():
            nonKhrRes = HANDLE_INFO[self.typeName].isDestroyApi(api.name)
            if nonKhrRes:
                return True
            if len(api.name) > 3 and "KHR" == api.name[-3:]:
                return HANDLE_INFO[self.typeName].isDestroyApi(api.name[:-3])

        return False

    def isSimpleValueType(self, typeInfo):
        """True for a non-compound, non-string, non-array, non-pointer type."""
        if typeInfo.isCompoundType(self.typeName):
            return False
        if self.isString() or self.isArrayOfStrings():
            return False
        if self.staticArrExpr or self.pointerIndirectionLevels > 0:
            return False
        return True

    def getStructEnumExpr(self,):
        return None

    def getPrintFormatSpecifier(self):
        """Return a printf-style format specifier for this type, or None.

        Pointers and core handle types use '%p'; `*Flags` types not in the
        table use '%d'.
        """
        # NOTE(review): unsigned types map to signed/long specifiers (e.g.
        # uint32_t -> '%d', uint64_t -> '%ld'); confirm this is acceptable
        # for the generated code before changing it.
        kKnownTypePrintFormatSpecifiers = {
            'float': '%f',
            'int': '%d',
            'int32_t': '%d',
            'size_t': '%ld',
            'uint16_t': '%d',
            'uint32_t': '%d',
            'uint64_t': '%ld',
            'VkBool32': '%d',
            'VkDeviceSize': '%ld',
            'VkFormat': '%d',
            'VkImageLayout': '%d',
        }

        if self.pointerIndirectionLevels > 0 or self.isHandleType():
            return '%p'

        if self.typeName in kKnownTypePrintFormatSpecifiers:
            return kKnownTypePrintFormatSpecifiers[self.typeName]

        if self.typeName.endswith('Flags'):
            # Based on `typedef uint32_t VkFlags;`
            return '%d'

        return None

    def isOptionalPointer(self) -> bool:
        """True for an optional pointer other than pNext."""
        return self.isOptional and \
               self.pointerIndirectionLevels > 0 and \
               (not self.isNextPointer())
431
432
# Minimal S-expression reader, from https://gist.github.com/pib/240957
class Atom(object):
    """A bare symbol read from an S-expression; its repr is its text."""

    def __init__(self, name):
        self.name = name

    def __repr__(self,):
        return self.name


def parse_sexp(sexp):
    """Parse S-expression text into nested Python lists.

    Lists become Python lists, double-quoted strings become str (with
    backslash escaping the next character), atoms starting with a digit are
    evaluated as number literals, other atoms become Atom, and a leading
    single quote wraps the next form as [('quote',), form].

    Returns:
        The list of top-level forms, e.g. "(eq a 1)" -> [[eq, a, 1]].
    """
    atom_end = set('()"\'') | set(whitespace)
    stack, i, length = [[]], 0, len(sexp)

    def close_atom():
        # Move the atom on top of the stack into its enclosing list.
        atom = stack.pop()
        if atom.name[0].isdigit():
            # NOTE(review): eval() of registry attribute text; acceptable only
            # while that input (vk.xml annotations) is trusted.
            stack[-1].append(eval(atom.name))
        else:
            stack[-1].append(atom)
        if stack[-1][0] == ('quote',):
            stack[-2].append(stack.pop())

    while i < length:
        c = sexp[i]

        reading = type(stack[-1])
        if reading == list:
            if c == '(':
                stack.append([])
            elif c == ')':
                stack[-2].append(stack.pop())
                if stack[-1][0] == ('quote',):
                    stack[-2].append(stack.pop())
            elif c == '"':
                stack.append('')
            elif c == "'":
                stack.append([('quote',)])
            elif c in whitespace:
                pass
            else:
                stack.append(Atom(c))
        elif reading == str:
            if c == '"':
                stack[-2].append(stack.pop())
                if stack[-1][0] == ('quote',):
                    stack[-2].append(stack.pop())
            elif c == '\\':
                i += 1
                stack[-1] += sexp[i]
            else:
                stack[-1] += c
        elif reading == Atom:
            if c in atom_end:
                close_atom()
                # Re-process the delimiter in the enclosing list context.
                continue
            stack[-1] = Atom(stack[-1].name + c)
        i += 1

    # An atom that runs to the end of the input has no delimiter after it;
    # close it here so it lands in its enclosing list instead of being
    # returned on its own.
    if type(stack[-1]) == Atom:
        close_atom()

    return stack.pop()
476
class FuncExprVal(object):
    """A leaf value (atom, number or string) in a parsed expression tree."""

    def __init__(self, val):
        self.val = val

    def __repr__(self,):
        return repr(self.val)
482
class FuncExpr(object):
    """An application `(name arg...)` in a parsed expression tree."""

    def __init__(self, name, args):
        self.name = name
        self.args = args

    def __repr__(self,):
        if not self.args:
            return f"({self.name!r})"
        rendered = " ".join(repr(arg) for arg in self.args)
        return f"({self.name!r} {rendered})"
492
class FuncLambda(object):
    """A `(lambda (params...) body)` form in a parsed expression tree."""

    def __init__(self, vs, body):
        self.vs = vs
        self.body = body

    def __repr__(self,):
        params = " ".join(repr(v) for v in self.vs)
        return f"(L ({params}) {self.body!r})"
499
class FuncLambdaParam(object):
    """A typed lambda parameter, rendered as `name : type`."""

    def __init__(self, name, typ):
        self.name = name
        self.typ = typ

    def __repr__(self,):
        return f"{self.name} : {self.typ}"
506
def parse_func_expr(parsed_sexp):
    """Convert parse_sexp output into a FuncLambda / FuncExpr / FuncExprVal tree.

    Args:
        parsed_sexp: the list returned by parse_sexp; it must hold exactly
            one top-level form.

    Returns:
        FuncLambda for `(lambda ((name type) ...) body)`, FuncExpr for any
        other list `(head arg...)`, and FuncExprVal for a bare value.

    Raises:
        RuntimeError: if parsed_sexp does not contain exactly one form.
    """
    if len(parsed_sexp) != 1:
        message = "Error: parsed # expressions != 1: %d" % (len(parsed_sexp))
        print(message)
        # A bare `raise` here had no active exception to re-raise; raise the
        # same exception type explicitly with a useful message.
        raise RuntimeError(message)

    e = parsed_sexp[0]

    def parse_lambda_param(param):
        # param is [Atom name, Atom type].
        return FuncLambdaParam(param[0].name, param[1].name)

    def parse_one(exp):
        if type(exp) is list:
            if "lambda" == exp[0].__repr__():
                return FuncLambda(list(map(parse_lambda_param, exp[1])), parse_one(exp[2]))
            return FuncExpr(exp[0], list(map(parse_one, exp[1:])))
        return FuncExprVal(exp)

    return parse_one(e)
527
def parseFilterFuncExpr(expr):
    """Parse a `filterFunc` attribute string into an expression tree.

    The parsed tree is echoed to stdout for debugging.
    """
    sexp = parse_sexp(expr)
    parsed = parse_func_expr(sexp)
    print("parseFilterFuncExpr: parsed %s" % parsed)
    return parsed
532
def parseLetBodyExpr(expr):
    """Parse the body of a struct `let` binding into an expression tree.

    The parsed tree is echoed to stdout for debugging.
    """
    sexp = parse_sexp(expr)
    parsed = parse_func_expr(sexp)
    print("parseLetBodyExpr: parsed %s" % parsed)
    return parsed
537
538
def makeVulkanTypeFromXMLTag(typeInfo, tag: Element) -> VulkanType:
    """Build a VulkanType from a registry XML element describing a member/param.

    The element's text and children supply const-ness, the type name, pointer
    depth, parameter name and static array size. Its attributes supply the
    length, optionality and the codegen annotations (binds, device memory
    keys, filter attributes). All attributes are also copied into `attribs`.

    Args:
        typeInfo: VulkanTypeInfo used to look up the primitive encoding size
            of the parsed type name.
        tag: the XML element to convert.

    Returns:
        A newly constructed VulkanType.
    """
    res = VulkanType()

    # Process the length expression: only the first comma-separated entry
    # of `len` is kept.
    lenAttr = tag.attrib.get("len")
    if lenAttr is not None:
        res.lenExpr = lenAttr.split(",")[0]

    # Calculate static array expression: either an <enum> child naming the
    # size, or bracketed size(s) in the text following <name>.
    nametag = tag.find("name")
    enumtag = tag.find("enum")

    if enumtag is not None:
        res.staticArrExpr = enumtag.text
    elif nametag is not None:
        res.rawStaticArrExpr = noneStr(nametag.tail)

        dimensions = res.rawStaticArrExpr.count('[')
        if dimensions == 1:
            res.staticArrExpr = res.rawStaticArrExpr[1:-1]
        elif dimensions > 1:
            # "[3][2][8]" -> "((3)*(2)*(8))"
            arraySizes = res.rawStaticArrExpr[1:-1].split('][')
            res.staticArrExpr = '(' + \
                '*'.join(f'({size})' for size in arraySizes) + ')'

    # Determine const from the text preceding the first child element.
    beforeTypePart = noneStr(tag.text)

    if "const" in beforeTypePart:
        res.isConst = True

    # Calculate type and pointer info
    for elem in tag:
        if elem.tag == "name":
            res.paramName = elem.text
        if elem.tag == "type":
            duringTypePart = noneStr(elem.text)
            afterTypePart = noneStr(elem.tail)
            res.typeName = duringTypePart

            if res.typeName in TRANSFORMED_TYPES:
                res.isTransformed = True

            # Each '*' after </type> adds one level of indirection. Only the
            # two-level case with an inner const (e.g. "* const*") is
            # recognized as a pointer to const pointer.
            res.pointerIndirectionLevels += afterTypePart.count("*")
            if "const" in afterTypePart and res.pointerIndirectionLevels == 2:
                res.isPointerToConstPointer = True

    # Calculate optionality (based on validitygenerator.py)
    optionalAttr = tag.attrib.get("optional")
    if optionalAttr is not None:
        res.isOptional = True
        res.optionalStr = optionalAttr

    # If no validity is being generated, it usually means that
    # validity is complex and not absolute, so let's say yes.
    if tag.attrib.get("noautovalidity") is not None:
        res.isOptional = True

    # If this is a structure extension, it is optional.
    if tag.attrib.get("structextends") is not None:
        res.isOptional = True

    # If this is a pNext pointer, it is optional.
    if res.paramName == "pNext":
        res.isOptional = True

    res.primitiveEncodingSize = typeInfo.getPrimitiveEncodingSize(res.typeName)

    # Annotations: Environment binds, "a:b, c:d" -> {"a": "b", "c": "d"}
    bindsAttr = tag.attrib.get("binds")
    if bindsAttr is not None:
        bindPairsSplit = (pair.strip().split(":") for pair in bindsAttr.split(","))
        res.binds = {sp[0].strip(): sp[1].strip() for sp in bindPairsSplit}

    # Annotations: Device memory (first matching key wins)
    for k in DEVICE_MEMORY_INFO_KEYS:
        if tag.attrib.get(k) is not None:
            res.deviceMemoryAttrib = k
            res.deviceMemoryVal = tag.attrib.get(k)
            break

    # Annotations: Filters
    filterVarAttr = tag.attrib.get("filterVar")
    if filterVarAttr is not None:
        res.filterVar = filterVarAttr.strip()

    filterValsAttr = tag.attrib.get("filterVals")
    if filterValsAttr is not None:
        res.filterVals = [v.strip() for v in filterValsAttr.strip().split(",")]
        print("Filtervals: %s" % res.filterVals)

    filterFuncAttr = tag.attrib.get("filterFunc")
    if filterFuncAttr is not None:
        res.filterFunc = parseFilterFuncExpr(filterFuncAttr)

    # Stored on filterOtherwise, the attribute VulkanType declares for it
    # (it was previously written to a misspelled `Otherwise` attribute).
    filterOtherwiseAttr = tag.attrib.get("filterOtherwise")
    if filterOtherwiseAttr is not None:
        res.filterOtherwise = filterOtherwiseAttr

    # store all other attribs here
    res.attribs = dict(tag.attrib)

    return res
651
652
def makeVulkanTypeSimple(isConst,
                         typeName,
                         ptrIndirectionLevels,
                         paramName=None):
    """Build a VulkanType directly from its basic properties (no XML).

    Args:
        isConst: whether the type is const-qualified.
        typeName: the base type name, e.g. "uint32_t".
        ptrIndirectionLevels: number of pointer levels (0 for a value).
        paramName: optional parameter/member name.

    Returns:
        A new VulkanType; it is never a pointer-to-const-pointer and has no
        primitive encoding size.
    """
    vkType = VulkanType()
    vkType.typeName = typeName
    vkType.isConst = isConst
    vkType.pointerIndirectionLevels = ptrIndirectionLevels
    vkType.isPointerToConstPointer = False
    vkType.paramName = paramName
    vkType.primitiveEncodingSize = None
    return vkType
667
# A class for holding the parameter indices corresponding to various
# attributes about a VkDeviceMemory, such as the handle, size, offset, etc.
class DeviceMemoryInfoParameterIndices(object):
    """Positions of the parameters/members describing one VkDeviceMemory.

    Each field holds the index of the parameter carrying the corresponding
    device memory annotation, or None when there is none.
    """

    def __init__(self, handle, offset, size, typeIndex, typeBits):
        (self.handle, self.offset, self.size,
         self.typeIndex, self.typeBits) = (handle, offset, size, typeIndex, typeBits)
677
# initializes DeviceMemoryInfoParameterIndices for each
# abstract VkDeviceMemory encountered over |parameters|
def initDeviceMemoryInfoParameterIndices(parameters):
    """Group device memory annotations by their value.

    Args:
        parameters: sequence of VulkanType; each may carry a
            deviceMemoryAttrib (one of DEVICE_MEMORY_INFO_KEYS) and a
            deviceMemoryVal identifying which VkDeviceMemory it describes.

    Returns:
        dict mapping each deviceMemoryVal to a DeviceMemoryInfoParameterIndices
        holding the parameter indices for that memory, or None if no
        parameter carries a recognized device memory annotation.
    """
    # Which DeviceMemoryInfoParameterIndices field each annotation fills in.
    fieldForAttrib = {
        "devicememoryhandle": "handle",
        "devicememoryoffset": "offset",
        "devicememorysize": "size",
        "devicememorytypeindex": "typeIndex",
        "devicememorytypebits": "typeBits",
    }

    deviceMemoryInfoById = {}

    for (i, p) in enumerate(parameters):
        a = p.deviceMemoryAttrib
        # Skip unannotated parameters and unrecognized annotations (the
        # latter previously caused a KeyError).
        if a not in DEVICE_MEMORY_INFO_KEYS:
            continue

        info = deviceMemoryInfoById.get(p.deviceMemoryVal)
        if info is None:
            info = DeviceMemoryInfoParameterIndices(None, None, None, None, None)
            deviceMemoryInfoById[p.deviceMemoryVal] = info

        field = fieldForAttrib.get(a)
        if field is not None:
            setattr(info, field, i)

    if not deviceMemoryInfoById:
        return None

    return deviceMemoryInfoById
717
# Classes for describing aggregate types (unions, structs) and API calls.
class VulkanCompoundType(object):
    """A struct or union definition and its members.

    Args:
        name: the type name (stored as both `name` and `typeName`).
        members: the member VulkanTypes, in declaration order.
        isUnion: True for a union.
        structEnumExpr: value returned by getStructEnumExpr().
        structExtendsExpr: the `structextends` expression, if any.
        feature: the feature the type was declared under, if any.
        initialEnv: environment dict for the struct's `exists` / `let`
            bindings; a fresh empty dict when omitted.
        optional: stored as `optionalStr`.
    """

    def __init__(self, name: str, members: List[VulkanType], isUnion=False, structEnumExpr=None, structExtendsExpr=None, feature=None, initialEnv=None, optional=None):
        self.name: str = name
        self.typeName: str = name
        self.members: List[VulkanType] = members
        # Fresh dict per instance: a `{}` default would be one object shared
        # by every struct created without an explicit environment.
        self.environment = {} if initialEnv is None else initialEnv
        self.isUnion = isUnion
        self.structEnumExpr = structEnumExpr
        self.structExtendsExpr = structExtendsExpr
        self.feature = feature
        self.deviceMemoryInfoParameterIndices = initDeviceMemoryInfoParameterIndices(self.members)
        self.isTransformed = name in TRANSFORMED_TYPES
        self.copy = None
        self.optionalStr = optional

    def initCopies(self):
        """Set `copy` to self and point every member's `parent` at it."""
        self.copy = self

        for m in self.members:
            m.parent = self.copy

    def getMember(self, memberName) -> Optional[VulkanType]:
        """Return the member whose paramName is memberName, or None."""
        for m in self.members:
            if m.paramName == memberName:
                return m
        return None

    def getStructEnumExpr(self,):
        return self.structEnumExpr
749
class VulkanAPI(object):
    """A Vulkan API command: its name, return type and parameters.

    Args:
        name: the command name.
        retType: VulkanType of the return value.
        parameters: the parameter VulkanTypes, in order.
        origName: name before any renaming; defaults to `name` when falsy.
    """

    # List[...] (not list[...]) for consistency with the rest of this file;
    # the annotation is evaluated at definition time.
    def __init__(self, name: str, retType: VulkanType, parameters: List[VulkanType], origName=None):
        self.name: str = name
        self.origName = name
        self.retType: VulkanType = retType
        self.parameters: List[VulkanType] = parameters

        self.deviceMemoryInfoParameterIndices = initDeviceMemoryInfoParameterIndices(self.parameters)

        self.copy = None

        # NOTE(review): TRANSFORMED_TYPES lists type names, so this is
        # presumably always False for API names — confirm.
        self.isTransformed = name in TRANSFORMED_TYPES

        if origName:
            self.origName = origName

    def initCopies(self):
        """Set `copy` to self and point every parameter's `parent` at it."""
        self.copy = self

        for m in self.parameters:
            m.parent = self.copy

    def getCopy(self,):
        """Return a shallow copy (the parameter list is shared)."""
        return copy(self)

    def getParameter(self, parameterName):
        """Return the parameter whose paramName is parameterName, or None."""
        for p in self.parameters:
            if p.paramName == parameterName:
                return p
        return None

    def withModifiedName(self, newName):
        """Return a new VulkanAPI named newName with the same return type and parameters.

        The result's origName is newName (self.origName is not carried over),
        and the parameter list object is shared.
        """
        res = VulkanAPI(newName, self.retType, self.parameters)
        return res

    def getRetVarExpr(self):
        """Return the name of the return-value variable, or None for void."""
        if self.retType.typeName == "void":
            return None
        return "%s_%s_return" % (self.name, self.retType.typeName)

    def getRetTypeExpr(self):
        return self.retType.typeName

    def withCustomParameters(self, customParams):
        """Return a shallow copy using customParams as the parameter list.

        NOTE(review): deviceMemoryInfoParameterIndices is copied from self,
        not recomputed for customParams.
        """
        res = self.getCopy()
        res.parameters = customParams
        return res

    def withCustomReturnType(self, retType):
        """Return a shallow copy with retType as the return type."""
        res = self.getCopy()
        res.retType = retType
        return res
803
# Whether or not special handling of virtual elements
# such as VkDeviceMemory is needed.
def vulkanTypeNeedsTransform(structOrApi):
    """Return True if structOrApi has device memory info parameter indices."""
    return structOrApi.deviceMemoryInfoParameterIndices is not None
808
def vulkanTypeGetNeededTransformTypes(structOrApi):
    """Return the transform kinds needed by structOrApi.

    Currently this is ["devicememory"] when it has device memory info
    parameter indices, and [] otherwise.
    """
    res = []
    if structOrApi.deviceMemoryInfoParameterIndices is not None:
        res.append("devicememory")
    return res
814
def vulkanTypeforEachSubType(structOrApi, f):
    """Call f(index, subType) for each struct member or API parameter.

    Args:
        structOrApi: a VulkanCompoundType (iterates `members`) or a
            VulkanAPI (iterates `parameters`).
        f: callable taking (index, VulkanType).

    Raises:
        TypeError: if structOrApi is neither a VulkanCompoundType nor a
            VulkanAPI.
    """
    if isinstance(structOrApi, VulkanCompoundType):
        toLoop = structOrApi.members
    elif isinstance(structOrApi, VulkanAPI):
        toLoop = structOrApi.parameters
    else:
        raise TypeError(
            "vulkanTypeforEachSubType: expected VulkanCompoundType or VulkanAPI, got %s"
            % type(structOrApi).__name__)

    for (i, x) in enumerate(toLoop):
        f(i, x)
824
825# Parses everything about Vulkan types into a Python readable format.
826class VulkanTypeInfo(object):
827
828    def __init__(self, generator):
829        self.generator = generator
830        self.categories: Set[str] = set([])
831
832        # Tracks what Vulkan type is part of what category.
833        self.typeCategories: Dict[str, str] = {}
834
835        # Tracks the primitive encoding size for each type, if applicable.
836        self.encodingSizes: Dict[str, Optional[int]] = {}
837
838        self.structs: Dict[str, VulkanCompoundType] = {}
839        self.apis: Dict[str, VulkanAPI] = {}
840
841        # Maps bitmask types to the enum type used for the flags
842        # E.g. "VkImageAspectFlags" -> "VkImageAspectFlagBits"
843        self.bitmasks: Dict[str, str] = {}
844
845        # Maps all enum names to their values.
846        # For aliases, the value is the name of the canonical enum
847        self.enumValues: Dict[str, Union[int, str]] = {}
848
849        self.feature = None
850
851    def initType(self, name: str, category: str):
852        self.categories.add(category)
853        self.typeCategories[name] = category
854        self.encodingSizes[name] = self.setPrimitiveEncodingSize(name)
855
856    def categoryOf(self, name):
857        return self.typeCategories[name]
858
859    def getPrimitiveEncodingSize(self, name):
860        return self.encodingSizes[name]
861
862    # Queries relating to categories of Vulkan types.
863    def isHandleType(self, name):
864        return self.typeCategories.get(name) == "handle"
865
866    def isCompoundType(self, name: str):
867        return self.typeCategories.get(name) in ["struct", "union"]
868
869    # Gets the best size in bytes
870    # for encoding/decoding a particular Vulkan type.
871    # If not applicable, returns None.
872    def setPrimitiveEncodingSize(self, name: str) -> Optional[int]:
873        baseEncodingSizes = {
874            "void": 8,
875            "char": 1,
876            "float": 4,
877            "uint8_t": 1,
878            "uint16_t": 2,
879            "uint32_t": 4,
880            "uint64_t": 8,
881            "int": 4,
882            "int8_t": 1,
883            "int16_t": 2,
884            "int32_t": 4,
885            "int64_t": 8,
886            "size_t": 8,
887            "ssize_t": 8,
888            "VkBool32": 4,
889        }
890
891        if name in baseEncodingSizes:
892            return baseEncodingSizes[name]
893
894        category = self.typeCategories[name]
895
896        if category in [None, "api", "include", "define", "struct", "union"]:
897            return None
898
899        # Handles are pointers so they must be 8 bytes. Basetype includes VkDeviceSize which is 8 bytes.
900        if category in ["handle", "basetype", "funcpointer"]:
901            return 8
902
903        if category in ["enum", "bitmask"]:
904            return 4
905
906    def isNonAbiPortableType(self, typeName):
907        if typeName in EXPLICITLY_ABI_PORTABLE_TYPES:
908            return False
909
910        if typeName in EXPLICITLY_ABI_NON_PORTABLE_TYPES:
911            return True
912
913        category = self.typeCategories[typeName]
914        return category in NON_ABI_PORTABLE_TYPE_CATEGORIES
915
916    def onBeginFeature(self, featureName, featureType):
917        self.feature = featureName
918
919    def onEndFeature(self):
920        self.feature = None
921
922    def onGenType(self, typeinfo, name, alias):
923        category = typeinfo.elem.get("category")
924        self.initType(name, category)
925
926        if category in ["struct", "union"]:
927            self.onGenStruct(typeinfo, name, alias)
928
929        if category == "bitmask":
930            self.bitmasks[name] = typeinfo.elem.get("requires")
931
932    def onGenStruct(self, typeinfo, typeName, alias):
933        if not alias:
934            members: List[VulkanType] = []
935
936            structExtendsExpr = typeinfo.elem.get("structextends")
937
938            structEnumExpr = None
939
940            initialEnv = {}
941            envStr = typeinfo.elem.get("exists")
942            if envStr != None:
943                comma_separated = envStr.split(",")
944                name_type_pairs = map(lambda cs: tuple(map(lambda t: t.strip(), cs.split(":"))), comma_separated)
945                for (name, typ) in name_type_pairs:
946                    initialEnv[name] = {
947                        "type" : typ,
948                        "binding" : None,
949                        "structmember" : False,
950                        "body" : None,
951                    }
952
953            letenvStr = typeinfo.elem.get("let")
954            if letenvStr != None:
955                comma_separated = letenvStr.split(",")
956                name_body_pairs = map(lambda cs: tuple(map(lambda t: t.strip(), cs.split(":"))), comma_separated)
957                for (name, body) in name_body_pairs:
958                    initialEnv[name] = {
959                        "type" : "uint32_t",
960                        "binding" : name,
961                        "structmember" : False,
962                        "body" : parseLetBodyExpr(body)
963                    }
964
965            for member in typeinfo.elem.findall(".//member"):
966                vulkanType = makeVulkanTypeFromXMLTag(self, member)
967                initialEnv[vulkanType.paramName] = {
968                    "type": vulkanType.typeName,
969                    "binding": vulkanType.paramName,
970                    "structmember": True,
971                    "body": None,
972                }
973                members.append(vulkanType)
974                if vulkanType.typeName == "VkStructureType" and \
975                   member.get("values"):
976                   structEnumExpr = member.get("values")
977
978            self.structs[typeName] = \
979                VulkanCompoundType(
980                    typeName,
981                    members,
982                    isUnion = self.categoryOf(typeName) == "union",
983                    structEnumExpr = structEnumExpr,
984                    structExtendsExpr = structExtendsExpr,
985                    feature = self.feature,
986                    initialEnv = initialEnv,
987                    optional = typeinfo.elem.get("optional", None))
988            self.structs[typeName].initCopies()
989
990    def onGenGroup(self, groupinfo, groupName, _alias=None):
991        self.initType(groupName, "enum")
992        enums = groupinfo.elem.findall("enum")
993        for enum in enums:
994            intVal, strVal = self.generator.enumToValue(enum, True)
995            self.enumValues[enum.get('name')] = intVal if intVal is not None else strVal
996
997
998    def onGenEnum(self, enuminfo, name: str, alias):
999        self.initType(name, "enum")
1000        value: str = enuminfo.elem.get("value")
1001        if value and value.isdigit():
1002            self.enumValues[name] = int(value)
1003        elif value and value[0] == '"' and value[-1] == '"':
1004            self.enumValues[name] = value[1:-1]
1005        elif alias is not None:
1006            self.enumValues[name] = alias
1007        else:
1008            # There's about a dozen cases of using the bitwise NOT operator (e.g.: `(~0U)`, `(~0ULL)`)
1009            # to concisely represent large values. Just ignore them for now.
1010            # In the future, we can add a lookup table to convert these to int
1011            return
1012
1013    def onGenCmd(self, cmdinfo, name, _alias):
1014        self.initType(name, "api")
1015
1016        proto = cmdinfo.elem.find("proto")
1017        params = cmdinfo.elem.findall("param")
1018
1019        self.apis[name] = \
1020            VulkanAPI(
1021                name,
1022                makeVulkanTypeFromXMLTag(self, proto),
1023                list(map(lambda p: makeVulkanTypeFromXMLTag(self, p),
1024                         params)))
1025        self.apis[name].initCopies()
1026
1027    def onEnd(self):
1028        pass
1029
def hasNullOptionalStringFeature(forEachType):
    """Return True if the visitor implements all three null-optional-string hooks.

    The hooks are onCheckWithNullOptionalStringFeature,
    endCheckWithNullOptionalStringFeature and
    finalCheckWithNullOptionalStringFeature.
    """
    requiredHooks = (
        "onCheckWithNullOptionalStringFeature",
        "endCheckWithNullOptionalStringFeature",
        "finalCheckWithNullOptionalStringFeature",
    )
    return all(hasattr(forEachType, hook) for hook in requiredHooks)
1034
1035
def iterateVulkanType(typeInfo: VulkanTypeInfo, vulkanType: VulkanType, forEachType):
    """Visit `vulkanType` by calling the matching hook(s) on the visitor `forEachType`.

    Returns False, without calling the visitor, for a pointer-to-const-pointer
    that is not an array of strings. Otherwise it registers `typeInfo` on the
    visitor, dispatches on the first matching kind below, and returns True:
      1. compound type that is not a pNext pointer -> onCompoundType
      2. string                                    -> onString
      3. array of strings                          -> onStringArray
      4. static array                              -> onStaticArr
      5. pNext pointer                             -> onStructExtension
      6. any other pointer                         -> onPointer
      7. anything else                             -> onValue

    When vulkanType.isOptionalPointer() is true, the hooks for kinds 1, 2, 5
    and 6 are bracketed by onCheck/endCheck. An optional string on a visitor
    that has all three *WithNullOptionalStringFeature hooks gets a different
    sequence instead: onCheckWith..., onString, endCheckWith..., onString,
    finalCheckWith....
    """
    if not vulkanType.isArrayOfStrings() and vulkanType.isPointerToConstPointer:
        return False

    forEachType.registerTypeInfo(typeInfo)

    needCheck = vulkanType.isOptionalPointer()

    def visitGuarded(visit):
        # Wrap a single visitor call in onCheck/endCheck for optional pointers.
        if needCheck:
            forEachType.onCheck(vulkanType)
        visit()
        if needCheck:
            forEachType.endCheck(vulkanType)

    if typeInfo.isCompoundType(vulkanType.typeName) and not vulkanType.isNextPointer():
        visitGuarded(lambda: forEachType.onCompoundType(vulkanType))
    elif vulkanType.isString():
        if needCheck and hasNullOptionalStringFeature(forEachType):
            forEachType.onCheckWithNullOptionalStringFeature(vulkanType)
            forEachType.onString(vulkanType)
            forEachType.endCheckWithNullOptionalStringFeature(vulkanType)
            forEachType.onString(vulkanType)
            forEachType.finalCheckWithNullOptionalStringFeature(vulkanType)
        else:
            visitGuarded(lambda: forEachType.onString(vulkanType))
    elif vulkanType.isArrayOfStrings():
        forEachType.onStringArray(vulkanType)
    elif vulkanType.staticArrExpr:
        forEachType.onStaticArr(vulkanType)
    elif vulkanType.isNextPointer():
        visitGuarded(lambda: forEachType.onStructExtension(vulkanType))
    elif vulkanType.pointerIndirectionLevels > 0:
        visitGuarded(lambda: forEachType.onPointer(vulkanType))
    else:
        forEachType.onValue(vulkanType)

    return True
1096
class VulkanTypeIterator:
    """Holds the registerTypeInfo hook that iterateVulkanType calls on its visitor.

    `typeInfo` is None until registerTypeInfo is called.
    """

    def __init__(self):
        self.typeInfo = None

    def registerTypeInfo(self, typeInfo):
        """Remember `typeInfo` for use by later visitor callbacks."""
        self.typeInfo = typeInfo
1103
def vulkanTypeGetStructFieldLengthInfo(structInfo, vulkanType):
    """Describe how to compute the element count of struct member `vulkanType`.

    Returns a dict with these keys:
      - "structName": name of the struct
      - "field": name of the member (its paramName)
      - "lenExpr": the length expression
      - "postprocess": callable that maps a length-expression string to the
        final count expression
    Returns None when the member has no length expression.

    Two members whose count is not directly the value of another field are
    special-cased:
      - VkShaderModuleCreateInfo.pCode: count is codeSize / 4.
      - VkPipelineMultisampleStateCreateInfo.pSampleMask: count is
        (rasterizationSamples + 31) / 32.
    """
    def getSpecialCaseVulkanStructFieldLength(structInfo, vulkanType):
        cases = [
            {
                "structName": "VkShaderModuleCreateInfo",
                "field": "pCode",
                "lenExpr": "codeSize",
                "postprocess": lambda expr: "(%s / 4)" % expr
            },
            {
                "structName": "VkPipelineMultisampleStateCreateInfo",
                "field": "pSampleMask",
                "lenExpr": "rasterizationSamples",
                "postprocess": lambda expr: "(((%s) + 31) / 32)" % expr
            },
        ]

        for c in cases:
            if (structInfo.name, vulkanType.paramName) == (c["structName"], c["field"]):
                return c

        return None

    specialCaseAccess = getSpecialCaseVulkanStructFieldLength(structInfo, vulkanType)

    if specialCaseAccess is not None:
        return specialCaseAccess

    lenExpr = vulkanType.getLengthExpression()

    if lenExpr is None:
        return None

    return {
        "structName": structInfo.name,
        # The member name, matching the "field" convention of the special
        # cases above (previously this held the member's type name).
        "field": vulkanType.paramName,
        "lenExpr": lenExpr,
        "postprocess": lambda expr: expr}
1142
1143
class VulkanTypeProtobufInfo(object):
    """Protobuf-oriented description of one Vulkan type (optionally a struct member).

    Every instance gets these attributes: needsMessage, isRepeatedString,
    isString, lengthInfo, protobufType, origTypeCategory and isExtensionStruct.

    For compound types (needsMessage is True), construction stops there.
    Otherwise protobufType (a protobuf scalar type name) and protobufCType
    (the matching C type) are derived from the registered category and the
    type name.
    """

    # Protobuf scalar type for known "basetype" names; other basetypes fall
    # back to uint64.
    _BASETYPE_PROTOBUF_TYPES = {
        "VkFlags" : "uint32",
        "VkBool32" : "uint32",
        "VkDeviceSize" : "uint64",
        "VkSampleMask" : "uint32",
    }

    # Protobuf scalar type for names with no registered category; unknown
    # names fall back to uint64.
    _OTHER_PROTOBUF_TYPES = {
        "void" : "uint64",
        "char" : "uint8",
        "size_t" : "uint64",
        "float" : "float",
        "uint8_t" : "uint32",
        "uint16_t" : "uint32",
        "int32_t" : "int32",
        "uint32_t" : "uint32",
        "uint64_t" : "uint64",
        "VkDeviceSize" : "uint64",
        "VkSampleMask" : "uint32",
    }

    # Protobuf scalar type -> C type.
    _PROTOBUF_C_TYPES = {
        "uint8" : "uint8_t",
        "uint32" : "uint32_t",
        "int32" : "int32_t",
        "uint64" : "uint64_t",
        "float" : "float",
        "string" : "const char*",
    }

    def __init__(self, typeInfo, structInfo, vulkanType):
        self.needsMessage = typeInfo.isCompoundType(vulkanType.typeName)
        self.isRepeatedString = vulkanType.isArrayOfStrings()
        self.isString = vulkanType.isString() or (
            vulkanType.typeName == "char" and (vulkanType.staticArrExpr != ""))

        if structInfo is not None:
            self.lengthInfo = vulkanTypeGetStructFieldLengthInfo(
                structInfo, vulkanType)
        else:
            self.lengthInfo = vulkanType.getLengthExpression()

        self.protobufType = None
        self.origTypeCategory = typeInfo.categoryOf(vulkanType.typeName)

        # A pNext-style extension chain: a `void` pointer member named "pNext".
        self.isExtensionStruct = \
            vulkanType.typeName == "void" and \
            vulkanType.pointerIndirectionLevels > 0 and \
            vulkanType.paramName == "pNext"

        if self.needsMessage:
            return

        category = self.origTypeCategory
        if category in ("enum", "bitmask"):
            self.protobufType = "uint32"
        elif category in ("funcpointer", "handle", "define"):
            self.protobufType = "uint64"
        elif category == "basetype":
            # Basetypes missing from the table (e.g. VkDeviceAddress, VkFlags64)
            # fall back to uint64 instead of raising KeyError.
            self.protobufType = self._BASETYPE_PROTOBUF_TYPES.get(
                vulkanType.typeName, "uint64")
        elif category is None:
            self.protobufType = self._OTHER_PROTOBUF_TYPES.get(
                vulkanType.typeName, "uint64")

        # Other categories leave protobufType as None, which raises KeyError here.
        self.protobufCType = self._PROTOBUF_C_TYPES[self.protobufType]
1215
1216