• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright (c) 2018 The Android Open Source Project
2# Copyright (c) 2018 Google Inc.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15from typing import Dict, Optional, List, Set, Union
16from xml.etree.ElementTree import Element
17
18from generator import noneStr
19
20from copy import copy
21from string import whitespace
22
# Holds information about core Vulkan objects
# and the API calls that are used to create/destroy each one.
class HandleInfo(object):
    """Create/destroy API names for one Vulkan handle type.

    ``createApis`` and ``destroyApis`` may each be a single API name (str)
    or a list of API names; ``destroyApis`` is ``None`` when the handle has
    no destroy call.
    """

    def __init__(self, name, createApis, destroyApis):
        self.name = name
        self.createApis = createApis
        self.destroyApis = destroyApis

    @staticmethod
    def _matches(apiName, apis):
        """Return True if ``apiName`` equals ``apis`` (str) or is one of ``apis``.

        A plain ``in`` test is deliberately not used on strings: it performs
        substring matching, e.g. ``"vkCreateImage" in "vkCreateImageView"``.
        """
        if apis is None:
            return False
        if isinstance(apis, str):
            return apiName == apis
        return apiName in apis

    def isCreateApi(self, apiName):
        """True if ``apiName`` is one of this handle's create APIs."""
        return self._matches(apiName, self.createApis)

    def isDestroyApi(self, apiName):
        """True if ``apiName`` is one of this handle's destroy APIs."""
        return self._matches(apiName, self.destroyApis)
38
# Core dispatchable handle types (checked by VulkanType.isDispatchableHandleType).
DISPATCHABLE_HANDLE_TYPES = [
    "VkInstance",
    "VkPhysicalDevice",
    "VkDevice",
    "VkQueue",
    "VkCommandBuffer",
]

# Core non-dispatchable handle types (checked by
# VulkanType.isNonDispatchableHandleType).
NON_DISPATCHABLE_HANDLE_TYPES = [
    "VkDeviceMemory",
    "VkBuffer",
    "VkBufferView",
    "VkImage",
    "VkImageView",
    "VkShaderModule",
    "VkDescriptorPool",
    "VkDescriptorSetLayout",
    "VkDescriptorSet",
    "VkSampler",
    "VkPipeline",
    "VkPipelineLayout",
    "VkRenderPass",
    "VkFramebuffer",
    "VkPipelineCache",
    "VkCommandPool",
    "VkFence",
    "VkSemaphore",
    "VkEvent",
    "VkQueryPool",
    "VkSamplerYcbcrConversion",
    "VkSamplerYcbcrConversionKHR",
    "VkDescriptorUpdateTemplate",
    "VkSurfaceKHR",
    "VkSwapchainKHR",
    "VkDisplayKHR",
    "VkDisplayModeKHR",
    "VkObjectTableNVX",
    "VkIndirectCommandsLayoutNVX",
    "VkValidationCacheEXT",
    "VkDebugReportCallbackEXT",
    "VkDebugUtilsMessengerEXT",
    "VkAccelerationStructureNV",
    "VkIndirectCommandsLayoutNV",
    "VkAccelerationStructureKHR",
]

# Handle types whose create/destroy APIs do not follow the default
# vkCreate<Name> / vkDestroy<Name> naming; HANDLE_INFO lists them explicitly.
CUSTOM_HANDLE_CREATE_TYPES = [
    "VkPhysicalDevice",
    "VkQueue",
    "VkPipeline",
    "VkDeviceMemory",
    "VkDescriptorSet",
    "VkCommandBuffer",
    "VkRenderPass",
]

# Every known handle type, de-duplicated and sorted by name.
HANDLE_TYPES = sorted(set(DISPATCHABLE_HANDLE_TYPES +
                          NON_DISPATCHABLE_HANDLE_TYPES +
                          CUSTOM_HANDLE_CREATE_TYPES))
97
# Explicit create/destroy APIs for the handle types in
# CUSTOM_HANDLE_CREATE_TYPES, whose entry points do not follow the default
# vkCreate<Name> / vkDestroy<Name> naming pattern.
_CUSTOM_HANDLE_INFO = {
    "VkPhysicalDevice": HandleInfo(
        "VkPhysicalDevice",
        "vkEnumeratePhysicalDevices",
        None),
    "VkQueue": HandleInfo(
        "VkQueue",
        ["vkGetDeviceQueue", "vkGetDeviceQueue2"],
        None),
    "VkPipeline": HandleInfo(
        "VkPipeline",
        ["vkCreateGraphicsPipelines", "vkCreateComputePipelines"],
        "vkDestroyPipeline"),
    "VkDeviceMemory": HandleInfo(
        "VkDeviceMemory",
        "vkAllocateMemory",
        ["vkFreeMemory", "vkFreeMemorySyncGOOGLE"]),
    "VkDescriptorSet": HandleInfo(
        "VkDescriptorSet",
        "vkAllocateDescriptorSets",
        "vkFreeDescriptorSets"),
    "VkCommandBuffer": HandleInfo(
        "VkCommandBuffer",
        "vkAllocateCommandBuffers",
        "vkFreeCommandBuffers"),
    "VkRenderPass": HandleInfo(
        "VkRenderPass",
        ["vkCreateRenderPass", "vkCreateRenderPass2", "vkCreateRenderPass2KHR"],
        "vkDestroyRenderPass"),
}

# Maps every name in HANDLE_TYPES to the HandleInfo describing its
# create/destroy APIs.
HANDLE_INFO = {}

for h in HANDLE_TYPES:
    if h in CUSTOM_HANDLE_CREATE_TYPES:
        # Fail loudly. A custom type without an entry would otherwise be
        # missing from HANDLE_INFO, and VulkanType.isCreatedBy/isDestroyedBy
        # would quietly return False for it.
        if h not in _CUSTOM_HANDLE_INFO:
            raise KeyError(
                f"{h} is listed in CUSTOM_HANDLE_CREATE_TYPES but has no "
                f"entry in _CUSTOM_HANDLE_INFO")
        HANDLE_INFO[h] = _CUSTOM_HANDLE_INFO[h]
    else:
        HANDLE_INFO[h] = HandleInfo(h, "vkCreate" + h[2:], "vkDestroy" + h[2:])
140
# API entry points excluded from processing (consumers are outside this
# section of the file — presumably the code generators; confirm).
EXCLUDED_APIS = ["vkEnumeratePhysicalDeviceGroups"]

# Types that VulkanTypeInfo.isNonAbiPortableType always reports as portable.
EXPLICITLY_ABI_PORTABLE_TYPES = [
    "VkResult", "VkBool32", "VkSampleMask", "VkFlags", "VkDeviceSize",
]

# Types that VulkanTypeInfo.isNonAbiPortableType always reports as non-portable.
EXPLICITLY_ABI_NON_PORTABLE_TYPES = ["size_t"]

# Type categories that VulkanTypeInfo.isNonAbiPortableType reports as
# non-portable when a type is on neither explicit list above.
NON_ABI_PORTABLE_TYPE_CATEGORIES = ["handle", "funcpointer"]

# XML attribute names copied into VulkanType.deviceMemoryAttrib /
# deviceMemoryVal by makeVulkanTypeFromXMLTag (first match wins).
DEVICE_MEMORY_INFO_KEYS = [
    "devicememoryhandle", "devicememoryoffset", "devicememorysize",
    "devicememorytypeindex", "devicememorytypebits",
]

# Types whose instances get isTransformed = True. The split into "trivial"
# and "non-trivial" is consumed outside this section (confirm its meaning there).
TRIVIAL_TRANSFORMED_TYPES = [
    "VkPhysicalDeviceExternalImageFormatInfo",
    "VkPhysicalDeviceExternalBufferInfo",
    "VkExternalMemoryImageCreateInfo",
    "VkExternalMemoryBufferCreateInfo",
    "VkExportMemoryAllocateInfo",
    "VkExternalImageFormatProperties",
    "VkExternalBufferProperties",
]

NON_TRIVIAL_TRANSFORMED_TYPES = ["VkExternalMemoryProperties", "VkImageCreateInfo"]

TRANSFORMED_TYPES = [*TRIVIAL_TRANSFORMED_TYPES, *NON_TRIVIAL_TRANSFORMED_TYPES]
186
# Holds information about a Vulkan type instance (i.e., not a type definition).
# Type instances are used as struct field definitions or function parameters,
# to be later fed to code generation.
# VulkanType instances can be constructed in two ways:
# 1. From an XML tag with <type> / <param> tags in vk.xml,
#    using makeVulkanTypeFromXMLTag
# 2. User-defined instances with makeVulkanTypeSimple.
class VulkanType(object):

    # printf conversion specifiers for scalar types with a known C type.
    # NOTE(review): several entries (e.g. uint32_t -> '%d', uint64_t and
    # size_t -> '%ld') use signed or platform-dependent specifiers; confirm
    # against the generated C code before changing them.
    _KNOWN_PRINT_FORMAT_SPECIFIERS = {
        'float': '%f',
        'int': '%d',
        'int32_t': '%d',
        'size_t': '%ld',
        'uint16_t': '%d',
        'uint32_t': '%d',
        'uint64_t': '%ld',
        'VkBool32': '%d',
        'VkDeviceSize': '%ld',
        'VkFormat': '%d',
        'VkImageLayout': '%d',
    }

    def __init__(self):
        self.parent: Optional[VulkanType] = None
        self.typeName: str = ""

        # Set when typeName is one of TRANSFORMED_TYPES.
        self.isTransformed = False

        self.paramName: Optional[str] = None

        self.lenExpr: Optional[str] = None  # Value of the `len` attribute in the spec
        self.isOptional: bool = False
        self.optionalStr: Optional[str] = None  # Value of the `optional` attribute in the spec

        self.isConst = False

        # "" means it's not a static array, otherwise this is the total size of
        # all elements. e.g. staticArrExpr of "x[3][2][8]" will be "((3)*(2)*(8))".
        self.staticArrExpr = ""
        # "" means it's not a static array, otherwise it's the raw expression
        # of static array size, which can be one-dimensional or multi-dimensional.
        self.rawStaticArrExpr = ""

        self.pointerIndirectionLevels = 0  # 0 means not pointer
        self.isPointerToConstPointer = False

        # Encoding size in bytes looked up from VulkanTypeInfo, or None.
        self.primitiveEncodingSize = None

        self.deviceMemoryInfoParameterIndices = None

        # Annotations
        # Environment annotation for binding current
        # variables to sub-structures
        self.binds = {}

        # Device memory annotations

        # self.deviceMemoryAttrib/Val stores
        # device memory info attributes from the XML.
        # devicememoryhandle
        # devicememoryoffset
        # devicememorysize
        # devicememorytypeindex
        # devicememorytypebits
        self.deviceMemoryAttrib = None
        self.deviceMemoryVal = None

        # Filter annotations
        self.filterVar = None
        self.filterVals = None
        self.filterFunc = None
        self.filterOtherwise = None

        # Stream feature
        self.streamFeature = None

        # All other annotations
        self.attribs = {}

        self.nonDispatchableHandleCreate = False
        self.nonDispatchableHandleDestroy = False
        self.dispatchHandle = False
        self.dispatchableHandleCreate = False
        self.dispatchableHandleDestroy = False

    def __str__(self,):
        typeDesc = (self.typeName + "*" * self.pointerIndirectionLevels +
                    ("ptr2constptr" if self.isPointerToConstPointer else ""))
        constDesc = "const" if self.isConst else "nonconst"
        return ("(vulkantype %s %s paramName %s len %s optional? %s "
                "staticArrExpr %s)") % (
            typeDesc, constDesc, self.paramName, self.lenExpr,
            self.isOptional, self.staticArrExpr)

    def isString(self):
        """True for a single-level ``char`` pointer."""
        return self.pointerIndirectionLevels == 1 and (self.typeName == "char")

    def isArrayOfStrings(self):
        """True for a pointer-to-const-pointer of ``char``."""
        return self.isPointerToConstPointer and (self.typeName == "char")

    def primEncodingSize(self):
        return self.primitiveEncodingSize

    # Utility functions to make codegen life easier.
    # This method derives the correct "count" expression if possible.
    # Otherwise, returns None.
    def getLengthExpression(self):
        if self.staticArrExpr != "":
            return self.staticArrExpr
        if self.lenExpr:
            return self.lenExpr
        return None

    # Can we just pass this to functions expecting T*
    def accessibleAsPointer(self):
        return self.staticArrExpr != "" or self.pointerIndirectionLevels > 0

    # Rough attempt to infer where a type could be an output.
    # Good for inferring which things need to be marshaled in
    # versus marshaled out for Vulkan API calls
    def possiblyOutput(self,):
        return self.pointerIndirectionLevels > 0 and (not self.isConst)

    def isVoidWithNoSize(self,):
        return self.typeName == "void" and self.pointerIndirectionLevels == 0

    def getCopy(self,):
        """Shallow copy (dict-valued annotations are shared with the copy)."""
        return copy(self)

    def getTransformed(self, isConstChoice=None, ptrIndirectionChoice=None):
        """Copy with const-ness and/or pointer depth overridden when given."""
        res = self.getCopy()

        if isConstChoice is not None:
            res.isConst = isConstChoice
        if ptrIndirectionChoice is not None:
            res.pointerIndirectionLevels = ptrIndirectionChoice

        return res

    def getWithCustomName(self):
        # NOTE(review): despite its name this does not touch paramName; it is
        # identical to getForAddressAccess (adds one pointer level). Confirm
        # the intent before relying on it.
        return self.getTransformed(
            ptrIndirectionChoice=self.pointerIndirectionLevels + 1)

    def getForAddressAccess(self):
        """Copy with one more pointer level (type of ``&x``)."""
        return self.getTransformed(
            ptrIndirectionChoice=self.pointerIndirectionLevels + 1)

    def getForValueAccess(self):
        """Copy with one pointer level removed (type of ``*x``).

        ``void*`` is treated as ``uint8_t*`` so that it dereferences to a byte.
        """
        if self.typeName == "void" and self.pointerIndirectionLevels == 1:
            asUint8Type = self.getCopy()
            asUint8Type.typeName = "uint8_t"
            return asUint8Type.getForValueAccess()
        return self.getTransformed(
            ptrIndirectionChoice=self.pointerIndirectionLevels - 1)

    def getForNonConstAccess(self):
        return self.getTransformed(isConstChoice=False)

    def withModifiedName(self, newName):
        res = self.getCopy()
        res.paramName = newName
        return res

    def isNextPointer(self):
        return self.paramName == "pNext"

    def isSigned(self):
        return self.typeName in ["int", "int8_t", "int16_t", "int32_t", "int64_t"]

    def isEnum(self, typeInfo):
        return typeInfo.categoryOf(self.typeName) == "enum"

    def isBitmask(self, typeInfo):
        # Bitmask (flags) types are registered under the "bitmask" category
        # by VulkanTypeInfo.onGenType, not "enum".
        return typeInfo.categoryOf(self.typeName) == "bitmask"

    # Only deals with 'core' handle types here.
    def isDispatchableHandleType(self):
        return self.typeName in DISPATCHABLE_HANDLE_TYPES

    def isNonDispatchableHandleType(self):
        return self.typeName in NON_DISPATCHABLE_HANDLE_TYPES

    def isHandleType(self):
        return self.isDispatchableHandleType() or \
               self.isNonDispatchableHandleType()

    def isCreatedBy(self, api):
        """True if ``api`` creates handles of this type.

        For an API name ending in ``KHR`` the name with the suffix stripped is
        also tried, e.g. vkCreateDescriptorUpdateTemplateKHR is checked as
        vkCreateDescriptorUpdateTemplate.
        """
        info = HANDLE_INFO.get(self.typeName)
        if info is not None:
            if info.isCreateApi(api.name):
                return True
            if len(api.name) > 3 and api.name.endswith("KHR"):
                return info.isCreateApi(api.name[:-3])

        if self.typeName == "VkImage" and api.name == "vkCreateImageWithRequirementsGOOGLE":
            return True

        if self.typeName == "VkBuffer" and api.name == "vkCreateBufferWithRequirementsGOOGLE":
            return True

        return False

    def isDestroyedBy(self, api):
        """True if ``api`` destroys handles of this type (``KHR`` stripped as in isCreatedBy)."""
        info = HANDLE_INFO.get(self.typeName)
        if info is not None:
            if info.isDestroyApi(api.name):
                return True
            if len(api.name) > 3 and api.name.endswith("KHR"):
                return info.isDestroyApi(api.name[:-3])

        return False

    def isSimpleValueType(self, typeInfo):
        """True for a non-compound, non-string, non-array, non-pointer value."""
        if typeInfo.isCompoundType(self.typeName):
            return False
        if self.isString() or self.isArrayOfStrings():
            return False
        if self.staticArrExpr or self.pointerIndirectionLevels > 0:
            return False
        return True

    def getStructEnumExpr(self,):
        return None

    def getPrintFormatSpecifier(self):
        """printf specifier for this type, or None if unknown."""
        if self.pointerIndirectionLevels > 0 or self.isHandleType():
            return '%p'

        if self.typeName in self._KNOWN_PRINT_FORMAT_SPECIFIERS:
            return self._KNOWN_PRINT_FORMAT_SPECIFIERS[self.typeName]

        if self.typeName.endswith('Flags'):
            # Based on `typedef uint32_t VkFlags;`
            return '%d'

        return None

    def isOptionalPointer(self) -> bool:
        return self.isOptional and \
               self.pointerIndirectionLevels > 0 and \
               (not self.isNextPointer())
432
433
# Minimal S-expression reader adapted from https://gist.github.com/pib/240957.
# It parses expression-valued XML attributes such as `filterFunc`.
class Atom(object):
    """A bare (unquoted, non-numeric) symbol read from an S-expression."""

    def __init__(self, name):
        self.name = name

    def __repr__(self,):
        return self.name


def parse_sexp(sexp):
    """Parse the S-expression text ``sexp``.

    Parenthesised lists become Python lists, double-quoted strings become
    ``str`` (a backslash escapes the next character), tokens starting with a
    digit are evaluated as Python numbers, and any other token becomes an
    Atom. ``'x`` is read as ``[('quote',), x]``.

    Returns:
        A list holding every top-level expression of ``sexp``.
    """
    atom_end = set('()"\'') | set(whitespace)
    stack, i, length = [[]], 0, len(sexp)

    def close_quote_if_needed():
        # A quoted form is a list headed by ('quote',); once its expression
        # has been appended, fold the whole form into the enclosing list.
        if stack[-1][0] == ('quote',):
            stack[-2].append(stack.pop())

    def finish_atom():
        atom = stack.pop()
        if atom.name[0].isdigit():
            # NOTE(review): eval() is only acceptable because the input is
            # presumably trusted vk.xml annotation text; never feed untrusted
            # input to this parser.
            stack[-1].append(eval(atom.name))
        else:
            stack[-1].append(atom)
        close_quote_if_needed()

    while i < length:
        c = sexp[i]
        top = stack[-1]

        if isinstance(top, list):
            if c == '(':
                stack.append([])
            elif c == ')':
                stack[-2].append(stack.pop())
                close_quote_if_needed()
            elif c == '"':
                stack.append('')
            elif c == "'":
                stack.append([('quote',)])
            elif c in whitespace:
                pass
            else:
                stack.append(Atom(c))
        elif isinstance(top, str):
            if c == '"':
                stack[-2].append(stack.pop())
                close_quote_if_needed()
            elif c == '\\':
                i += 1
                stack[-1] += sexp[i]
            else:
                stack[-1] += c
        elif isinstance(top, Atom):
            if c in atom_end:
                finish_atom()
                # Re-examine the delimiter in the enclosing context.
                continue
            stack[-1] = Atom(top.name + c)
        i += 1

    # An atom that runs to the end of the input has no delimiter after it;
    # flush it so the result is always the list of top-level expressions.
    if isinstance(stack[-1], Atom):
        finish_atom()

    return stack.pop()
477
class FuncExprVal(object):
    """Leaf of a function-expression tree (an Atom, number or string)."""

    def __init__(self, val):
        self.val = val

    def __repr__(self,):
        return repr(self.val)


class FuncExpr(object):
    """Application of ``name`` to the sub-expressions in ``args``."""

    def __init__(self, name, args):
        self.name = name
        self.args = args

    def __repr__(self,):
        if not self.args:
            return "(%s)" % repr(self.name)
        return "(%s %s)" % (repr(self.name), " ".join(repr(a) for a in self.args))


class FuncLambda(object):
    """A lambda with parameters ``vs`` (FuncLambdaParam list) and ``body``."""

    def __init__(self, vs, body):
        self.vs = vs
        self.body = body

    def __repr__(self,):
        params = " ".join(repr(v) for v in self.vs)
        return "(L (%s) %s)" % (params, repr(self.body))


class FuncLambdaParam(object):
    """A lambda parameter: a name and a type name."""

    def __init__(self, name, typ):
        self.name = name
        self.typ = typ

    def __repr__(self,):
        return "%s : %s" % (self.name, self.typ)


def parse_func_expr(parsed_sexp):
    """Convert parse_sexp output into a FuncExpr / FuncLambda / FuncExprVal tree.

    A list whose head prints as ``lambda`` becomes a FuncLambda, shaped as
    ``(lambda ((name type) ...) body)``; any other list becomes a FuncExpr
    call; everything else becomes a FuncExprVal.

    Raises:
        ValueError: if ``parsed_sexp`` does not hold exactly one expression.
    """
    if len(parsed_sexp) != 1:
        # A bare `raise` here would fail with "No active exception to
        # reraise"; raise a real error that carries the diagnostic.
        raise ValueError(
            "Error: parsed # expressions != 1: %d" % (len(parsed_sexp)))

    def parse_lambda_param(e):
        return FuncLambdaParam(e[0].name, e[1].name)

    def parse_one(exp):
        if not isinstance(exp, list):
            return FuncExprVal(exp)
        if "lambda" == exp[0].__repr__():
            return FuncLambda(list(map(parse_lambda_param, exp[1])), parse_one(exp[2]))
        return FuncExpr(exp[0], [parse_one(arg) for arg in exp[1:]])

    return parse_one(parsed_sexp[0])
528
def _parseAndLogFuncExpr(expr, label):
    """Parse ``expr`` via parse_sexp + parse_func_expr and print the tree."""
    parsed = parse_func_expr(parse_sexp(expr))
    print("%s: parsed %s" % (label, parsed))
    return parsed


def parseFilterFuncExpr(expr):
    """Parse a `filterFunc` attribute expression into a function-expression tree."""
    return _parseAndLogFuncExpr(expr, "parseFilterFuncExpr")


def parseLetBodyExpr(expr):
    """Parse the body of a `let` binding into a function-expression tree."""
    return _parseAndLogFuncExpr(expr, "parseLetBodyExpr")
538
539
def makeVulkanTypeFromXMLTag(typeInfo, tag: Element) -> VulkanType:
    """Build a VulkanType from a vk.xml `<member>` / `<param>` element.

    Args:
        typeInfo: VulkanTypeInfo used to look up the primitive encoding size
            of the element's type name.
        tag: the XML element. Its text and children (`const`, `<type>`,
            pointer stars, `<name>`, array suffix or `<enum>`) give the C
            declaration. Its attributes give spec metadata (`len`,
            `optional`, ...) and codegen annotations (`binds`, the
            DEVICE_MEMORY_INFO_KEYS, `filterVar`/`filterVals`/`filterFunc`/
            `filterOtherwise`).

    Returns:
        A new VulkanType. All attributes of ``tag`` are also copied into
        ``attribs``.
    """
    res = VulkanType()
    attrib = tag.attrib

    # Length expression: only the first comma-separated entry of `len` is kept.
    lenAttr = attrib.get("len")
    if lenAttr is not None:
        res.lenExpr = lenAttr.split(",")[0]

    # Static array size: either an <enum> constant, or a "[N]" / "[N][M]..."
    # suffix in the text that follows <name>.
    nametag = tag.find("name")
    enumtag = tag.find("enum")

    if enumtag is not None:
        res.staticArrExpr = enumtag.text
    elif nametag is not None:
        res.rawStaticArrExpr = noneStr(nametag.tail)

        dimensions = res.rawStaticArrExpr.count('[')
        if dimensions == 1:
            res.staticArrExpr = res.rawStaticArrExpr[1:-1]
        elif dimensions > 1:
            # "[3][2][8]" -> "((3)*(2)*(8))", the total element count.
            arraySizes = res.rawStaticArrExpr[1:-1].split('][')
            res.staticArrExpr = '(' + \
                '*'.join(f'({size})' for size in arraySizes) + ')'

    # `const` in the text before <type> marks the type as const.
    if "const" in noneStr(tag.text):
        res.isConst = True

    # Parameter name, type name and pointer depth.
    for elem in tag:
        if elem.tag == "name":
            res.paramName = elem.text
        if elem.tag == "type":
            afterTypePart = noneStr(elem.tail)
            res.typeName = noneStr(elem.text)

            if res.typeName in TRANSFORMED_TYPES:
                res.isTransformed = True

            # Count the '*' after the type. This only handles depth 2 with
            # an optional inner `const` (e.g. `char* const*`), which is
            # recorded as a pointer to a const pointer.
            res.pointerIndirectionLevels += afterTypePart.count("*")
            if "const" in afterTypePart and res.pointerIndirectionLevels == 2:
                res.isPointerToConstPointer = True

    # Calculate optionality (based on validitygenerator.py).
    optionalAttr = attrib.get("optional")
    if optionalAttr is not None:
        res.isOptional = True
        res.optionalStr = optionalAttr

    # No auto-validity usually means validity is complex and not absolute,
    # so say yes; structure extensions and pNext pointers are optional too.
    if (attrib.get("noautovalidity") is not None
            or attrib.get("structextends") is not None
            or res.paramName == "pNext"):
        res.isOptional = True

    res.primitiveEncodingSize = typeInfo.getPrimitiveEncodingSize(res.typeName)

    # Annotations: environment binds, "a:b, c:d" -> {"a": "b", "c": "d"}.
    bindsAttr = attrib.get("binds")
    if bindsAttr is not None:
        splitPairs = (pair.strip().split(":") for pair in bindsAttr.split(","))
        res.binds = {p[0].strip(): p[1].strip() for p in splitPairs}

    # Annotations: device memory (the first matching key wins).
    for key in DEVICE_MEMORY_INFO_KEYS:
        value = attrib.get(key)
        if value is not None:
            res.deviceMemoryAttrib = key
            res.deviceMemoryVal = value
            break

    # Annotations: filters.
    filterVar = attrib.get("filterVar")
    if filterVar is not None:
        res.filterVar = filterVar.strip()

    filterVals = attrib.get("filterVals")
    if filterVals is not None:
        res.filterVals = [v.strip() for v in filterVals.strip().split(",")]
        print("Filtervals: %s" % res.filterVals)

    filterFunc = attrib.get("filterFunc")
    if filterFunc is not None:
        res.filterFunc = parseFilterFuncExpr(filterFunc)

    filterOtherwise = attrib.get("filterOtherwise")
    if filterOtherwise is not None:
        # Stored on `filterOtherwise`, the field VulkanType declares for this
        # annotation, so consumers reading it actually see the value.
        res.filterOtherwise = filterOtherwise

    # Store all other attribs here.
    res.attribs = dict(attrib)

    return res
652
653
def makeVulkanTypeSimple(isConst,
                         typeName,
                         ptrIndirectionLevels,
                         paramName=None):
    """Build a VulkanType by hand instead of from a vk.xml tag.

    Args:
        isConst: whether the type is const-qualified.
        typeName: the C type name, e.g. "uint32_t".
        ptrIndirectionLevels: number of pointer levels (0 for a plain value).
        paramName: optional parameter/field name.

    Returns:
        A VulkanType with no primitive encoding size, not marked as a
        pointer to a const pointer.
    """
    vulkanType = VulkanType()
    vulkanType.isConst = isConst
    vulkanType.typeName = typeName
    vulkanType.paramName = paramName
    vulkanType.pointerIndirectionLevels = ptrIndirectionLevels
    vulkanType.isPointerToConstPointer = False
    vulkanType.primitiveEncodingSize = None
    return vulkanType
668
# A class for holding the parameter indices corresponding to various
# attributes about a VkDeviceMemory, such as the handle, size, offset, etc.
class DeviceMemoryInfoParameterIndices(object):
    """Indices of the parameters/members describing one VkDeviceMemory.

    Each field holds a position in the owning parameter/member list, or
    None when that piece of information is absent.
    """

    def __init__(self, handle, offset, size, typeIndex, typeBits):
        self.handle, self.offset, self.size = handle, offset, size
        self.typeIndex, self.typeBits = typeIndex, typeBits
678
# initializes DeviceMemoryInfoParameterIndices for each
# abstract VkDeviceMemory encountered over |parameters|
def initDeviceMemoryInfoParameterIndices(parameters):
    """Group device-memory-annotated parameters by the memory they describe.

    Each parameter whose ``deviceMemoryAttrib`` is one of
    DEVICE_MEMORY_INFO_KEYS gets a record keyed by its ``deviceMemoryVal``.
    The parameter's index is then stored in that record's field matching
    the attribute (handle / offset / size / typeIndex / typeBits).

    Returns:
        A dict of ``deviceMemoryVal`` -> DeviceMemoryInfoParameterIndices,
        or None when no parameter carries a device-memory annotation.
    """
    fieldByAttrib = {
        "devicememoryhandle": "handle",
        "devicememoryoffset": "offset",
        "devicememorysize": "size",
        "devicememorytypeindex": "typeIndex",
        "devicememorytypebits": "typeBits",
    }

    annotated = [(index, param) for index, param in enumerate(parameters)
                 if param.deviceMemoryAttrib]

    # Pass 1: one empty record per distinct deviceMemoryVal.
    infoById = {}
    for _, param in annotated:
        if param.deviceMemoryAttrib in DEVICE_MEMORY_INFO_KEYS:
            infoById[param.deviceMemoryVal] = DeviceMemoryInfoParameterIndices(
                None, None, None, None, None)

    # Pass 2: record where each annotated parameter sits.
    for index, param in annotated:
        info = infoById[param.deviceMemoryVal]
        field = fieldByAttrib.get(param.deviceMemoryAttrib)
        if field is not None:
            setattr(info, field, index)

    return infoById or None
718
# Classes for describing aggregate types (unions, structs) and API calls.
class VulkanCompoundType(object):
    """A struct or union definition together with its member VulkanTypes.

    Args:
        name: type name (also stored as ``typeName``).
        members: member VulkanTypes in declaration order.
        isUnion: True for a union, False for a struct.
        structEnumExpr: value returned by getStructEnumExpr().
        structExtendsExpr: stored as-is.
        feature: stored as-is.
        initialEnv: annotation environment dict, used as-is when given;
            a fresh empty dict is created per instance when omitted.
        optional: stored as ``optionalStr``.
    """

    def __init__(self, name: str, members: List[VulkanType], isUnion=False, structEnumExpr=None, structExtendsExpr=None, feature=None, initialEnv=None, optional=None):
        self.name: str = name
        self.typeName: str = name
        self.members: List[VulkanType] = members
        # A `{}` default would be one dict shared by every instance built
        # without an explicit environment; create a new one per instance.
        self.environment = {} if initialEnv is None else initialEnv
        self.isUnion = isUnion
        self.structEnumExpr = structEnumExpr
        self.structExtendsExpr = structExtendsExpr
        self.feature = feature
        self.deviceMemoryInfoParameterIndices = initDeviceMemoryInfoParameterIndices(self.members)
        self.isTransformed = name in TRANSFORMED_TYPES
        # Set by initCopies().
        self.copy = None
        self.optionalStr = optional

    def initCopies(self):
        """Point ``copy`` at self and make self the parent of every member."""
        self.copy = self

        for m in self.members:
            m.parent = self.copy

    def getMember(self, memberName) -> Optional[VulkanType]:
        """Return the member whose paramName is ``memberName``, or None."""
        for m in self.members:
            if m.paramName == memberName:
                return m
        return None

    def getStructEnumExpr(self,):
        return self.structEnumExpr
750
class VulkanAPI(object):
    """A Vulkan API entry point: its name, return type and parameter list.

    ``origName`` defaults to ``name`` when not given (or falsy).
    """

    def __init__(self, name: str, retType: VulkanType, parameters, origName=None):
        self.name: str = name
        self.origName = origName or name
        self.retType: VulkanType = retType
        self.parameters: List[VulkanType] = parameters

        self.deviceMemoryInfoParameterIndices = initDeviceMemoryInfoParameterIndices(self.parameters)

        # Set by initCopies().
        self.copy = None

        self.isTransformed = name in TRANSFORMED_TYPES

    def initCopies(self):
        """Point ``copy`` at self and make self the parent of every parameter."""
        self.copy = self
        for param in self.parameters:
            param.parent = self.copy

    def getCopy(self,):
        """Shallow copy; the parameter list is shared with the original."""
        return copy(self)

    def getParameter(self, parameterName):
        """Return the parameter named ``parameterName``, or None."""
        return next(
            (param for param in self.parameters if param.paramName == parameterName),
            None)

    def withModifiedName(self, newName):
        # NOTE(review): origName is not forwarded, so the result's origName
        # becomes newName — confirm this is intended.
        return VulkanAPI(newName, self.retType, self.parameters)

    def getRetVarExpr(self):
        """Name of the local that holds the return value, or None for void."""
        retTypeName = self.retType.typeName
        if retTypeName == "void":
            return None
        return "%s_%s_return" % (self.name, retTypeName)

    def getRetTypeExpr(self):
        return self.retType.typeName

    def withCustomParameters(self, customParams):
        """Shallow copy with ``parameters`` replaced by ``customParams``."""
        clone = self.getCopy()
        clone.parameters = customParams
        return clone

    def withCustomReturnType(self, retType):
        """Shallow copy with ``retType`` replaced."""
        clone = self.getCopy()
        clone.retType = retType
        return clone
804
# Whether or not special handling of virtual elements
# such as VkDeviceMemory is needed.
def vulkanTypeNeedsTransform(structOrApi):
    """True if ``structOrApi`` has device-memory parameter annotations."""
    return structOrApi.deviceMemoryInfoParameterIndices is not None


def vulkanTypeGetNeededTransformTypes(structOrApi):
    """List the transforms ``structOrApi`` needs; currently only "devicememory"."""
    return ["devicememory"] if vulkanTypeNeedsTransform(structOrApi) else []
815
def vulkanTypeforEachSubType(structOrApi, f):
    """Call ``f(index, subType)`` for each struct member or API parameter, in order.

    Args:
        structOrApi: a VulkanCompoundType (its members are visited) or a
            VulkanAPI (its parameters are visited).
        f: callback taking the zero-based index and the VulkanType.

    Raises:
        TypeError: if ``structOrApi`` is neither kind.
    """
    if isinstance(structOrApi, VulkanCompoundType):
        subTypes = structOrApi.members
    elif isinstance(structOrApi, VulkanAPI):
        subTypes = structOrApi.parameters
    else:
        # Report the bad argument instead of failing inside enumerate(None).
        raise TypeError(
            "vulkanTypeforEachSubType: expected VulkanCompoundType or "
            "VulkanAPI, got %s" % type(structOrApi).__name__)

    for i, subType in enumerate(subTypes):
        f(i, subType)
825
826# Parses everything about Vulkan types into a Python readable format.
827class VulkanTypeInfo(object):
828
829    def __init__(self, generator):
830        self.generator = generator
831        self.categories: Set[str] = set([])
832
833        # Tracks what Vulkan type is part of what category.
834        self.typeCategories: Dict[str, str] = {}
835
836        # Tracks the primitive encoding size for each type, if applicable.
837        self.encodingSizes: Dict[str, Optional[int]] = {}
838
839        self.structs: Dict[str, VulkanCompoundType] = {}
840        self.apis: Dict[str, VulkanAPI] = {}
841
842        # Maps bitmask types to the enum type used for the flags
843        # E.g. "VkImageAspectFlags" -> "VkImageAspectFlagBits"
844        self.bitmasks: Dict[str, str] = {}
845
846        # Maps all enum names to their values.
847        # For aliases, the value is the name of the canonical enum
848        self.enumValues: Dict[str, Union[int, str]] = {}
849
850        self.feature = None
851
852    def initType(self, name: str, category: str):
853        self.categories.add(category)
854        self.typeCategories[name] = category
855        self.encodingSizes[name] = self.setPrimitiveEncodingSize(name)
856
857    def categoryOf(self, name):
858        return self.typeCategories[name]
859
860    def getPrimitiveEncodingSize(self, name):
861        return self.encodingSizes[name]
862
863    # Queries relating to categories of Vulkan types.
864    def isHandleType(self, name):
865        return self.typeCategories.get(name) == "handle"
866
867    def isCompoundType(self, name: str):
868        return self.typeCategories.get(name) in ["struct", "union"]
869
870    # Gets the best size in bytes
871    # for encoding/decoding a particular Vulkan type.
872    # If not applicable, returns None.
873    def setPrimitiveEncodingSize(self, name: str) -> Optional[int]:
874        baseEncodingSizes = {
875            "void": 8,
876            "char": 1,
877            "float": 4,
878            "uint8_t": 1,
879            "uint16_t": 2,
880            "uint32_t": 4,
881            "uint64_t": 8,
882            "int": 4,
883            "int8_t": 1,
884            "int16_t": 2,
885            "int32_t": 4,
886            "int64_t": 8,
887            "size_t": 8,
888            "ssize_t": 8,
889            "VkBool32": 4,
890        }
891
892        if name in baseEncodingSizes:
893            return baseEncodingSizes[name]
894
895        category = self.typeCategories[name]
896
897        if category in [None, "api", "include", "define", "struct", "union"]:
898            return None
899
900        # Handles are pointers so they must be 8 bytes. Basetype includes VkDeviceSize which is 8 bytes.
901        if category in ["handle", "basetype", "funcpointer"]:
902            return 8
903
904        if category in ["enum", "bitmask"]:
905            return 4
906
907    def isNonAbiPortableType(self, typeName):
908        if typeName in EXPLICITLY_ABI_PORTABLE_TYPES:
909            return False
910
911        if typeName in EXPLICITLY_ABI_NON_PORTABLE_TYPES:
912            return True
913
914        category = self.typeCategories[typeName]
915        return category in NON_ABI_PORTABLE_TYPE_CATEGORIES
916
917    def onBeginFeature(self, featureName, featureType):
918        self.feature = featureName
919
920    def onEndFeature(self):
921        self.feature = None
922
923    def onGenType(self, typeinfo, name, alias):
924        category = typeinfo.elem.get("category")
925        self.initType(name, category)
926
927        if category in ["struct", "union"]:
928            self.onGenStruct(typeinfo, name, alias)
929
930        if category == "bitmask":
931            self.bitmasks[name] = typeinfo.elem.get("requires")
932
933    def onGenStruct(self, typeinfo, typeName, alias):
934        if not alias:
935            members: List[VulkanType] = []
936
937            structExtendsExpr = typeinfo.elem.get("structextends")
938
939            structEnumExpr = None
940
941            initialEnv = {}
942            envStr = typeinfo.elem.get("exists")
943            if envStr != None:
944                comma_separated = envStr.split(",")
945                name_type_pairs = map(lambda cs: tuple(map(lambda t: t.strip(), cs.split(":"))), comma_separated)
946                for (name, typ) in name_type_pairs:
947                    initialEnv[name] = {
948                        "type" : typ,
949                        "binding" : None,
950                        "structmember" : False,
951                        "body" : None,
952                    }
953
954            letenvStr = typeinfo.elem.get("let")
955            if letenvStr != None:
956                comma_separated = letenvStr.split(",")
957                name_body_pairs = map(lambda cs: tuple(map(lambda t: t.strip(), cs.split(":"))), comma_separated)
958                for (name, body) in name_body_pairs:
959                    initialEnv[name] = {
960                        "type" : "uint32_t",
961                        "binding" : name,
962                        "structmember" : False,
963                        "body" : parseLetBodyExpr(body)
964                    }
965
966            for member in typeinfo.elem.findall(".//member"):
967                vulkanType = makeVulkanTypeFromXMLTag(self, member)
968                initialEnv[vulkanType.paramName] = {
969                    "type": vulkanType.typeName,
970                    "binding": vulkanType.paramName,
971                    "structmember": True,
972                    "body": None,
973                }
974                members.append(vulkanType)
975                if vulkanType.typeName == "VkStructureType" and \
976                   member.get("values"):
977                   structEnumExpr = member.get("values")
978
979            self.structs[typeName] = \
980                VulkanCompoundType(
981                    typeName,
982                    members,
983                    isUnion = self.categoryOf(typeName) == "union",
984                    structEnumExpr = structEnumExpr,
985                    structExtendsExpr = structExtendsExpr,
986                    feature = self.feature,
987                    initialEnv = initialEnv,
988                    optional = typeinfo.elem.get("optional", None))
989            self.structs[typeName].initCopies()
990
991    def onGenGroup(self, groupinfo, groupName, _alias=None):
992        self.initType(groupName, "enum")
993        enums = groupinfo.elem.findall("enum")
994        for enum in enums:
995            intVal, strVal = self.generator.enumToValue(enum, True)
996            self.enumValues[enum.get('name')] = intVal if intVal is not None else strVal
997
998
999    def onGenEnum(self, enuminfo, name: str, alias):
1000        self.initType(name, "enum")
1001        value: str = enuminfo.elem.get("value")
1002        if value and value.isdigit():
1003            self.enumValues[name] = int(value)
1004        elif value and value[0] == '"' and value[-1] == '"':
1005            self.enumValues[name] = value[1:-1]
1006        elif alias is not None:
1007            self.enumValues[name] = alias
1008        else:
1009            # There's about a dozen cases of using the bitwise NOT operator (e.g.: `(~0U)`, `(~0ULL)`)
1010            # to concisely represent large values. Just ignore them for now.
1011            # In the future, we can add a lookup table to convert these to int
1012            return
1013
1014    def onGenCmd(self, cmdinfo, name, _alias):
1015        self.initType(name, "api")
1016
1017        proto = cmdinfo.elem.find("proto")
1018        params = cmdinfo.elem.findall("param")
1019
1020        self.apis[name] = \
1021            VulkanAPI(
1022                name,
1023                makeVulkanTypeFromXMLTag(self, proto),
1024                list(map(lambda p: makeVulkanTypeFromXMLTag(self, p),
1025                         params)))
1026        self.apis[name].initCopies()
1027
1028    def onEnd(self):
1029        pass
1030
def hasNullOptionalStringFeature(forEachType):
    """Return True if `forEachType` defines all three null-optional-string hooks.

    The hooks are onCheckWithNullOptionalStringFeature,
    endCheckWithNullOptionalStringFeature and
    finalCheckWithNullOptionalStringFeature; any missing one yields False.
    """
    requiredHooks = (
        "onCheckWithNullOptionalStringFeature",
        "endCheckWithNullOptionalStringFeature",
        "finalCheckWithNullOptionalStringFeature",
    )
    return all(hasattr(forEachType, hook) for hook in requiredHooks)
1035
1036
def iterateVulkanType(typeInfo: VulkanTypeInfo, vulkanType: VulkanType, forEachType):
    """Dispatch `vulkanType` to the matching handler on the visitor `forEachType`.

    Returns False, without touching the visitor, when the type is a
    pointer-to-const-pointer that is not an array of strings. Otherwise it
    calls forEachType.registerTypeInfo(typeInfo), then the first matching
    handler below, and returns True:

      * onCompoundType    - struct/union types that are not next-pointers
      * onString          - strings
      * onStringArray     - arrays of strings
      * onStaticArr       - types with a staticArrExpr
      * onStructExtension - next-pointers (vulkanType.isNextPointer())
      * onPointer         - any other pointer
      * onValue           - everything else

    When vulkanType.isOptionalPointer() is true, onCompoundType, onString,
    onStructExtension and onPointer are wrapped in onCheck/endCheck. For an
    optional string on a visitor that provides all three
    *WithNullOptionalStringFeature hooks, the sequence is instead:
    onCheck..., onString, endCheck..., onString, finalCheck....
    """
    if not vulkanType.isArrayOfStrings() and vulkanType.isPointerToConstPointer:
        return False

    forEachType.registerTypeInfo(typeInfo)

    needCheck = vulkanType.isOptionalPointer()

    def visitWithOptionalCheck(handlerName):
        # Call the named handler, bracketed by onCheck/endCheck when optional.
        if needCheck:
            forEachType.onCheck(vulkanType)
        getattr(forEachType, handlerName)(vulkanType)
        if needCheck:
            forEachType.endCheck(vulkanType)

    if typeInfo.isCompoundType(vulkanType.typeName) and not vulkanType.isNextPointer():
        visitWithOptionalCheck("onCompoundType")
    elif vulkanType.isString():
        if needCheck and hasNullOptionalStringFeature(forEachType):
            forEachType.onCheckWithNullOptionalStringFeature(vulkanType)
            forEachType.onString(vulkanType)
            forEachType.endCheckWithNullOptionalStringFeature(vulkanType)
            forEachType.onString(vulkanType)
            forEachType.finalCheckWithNullOptionalStringFeature(vulkanType)
        else:
            visitWithOptionalCheck("onString")
    elif vulkanType.isArrayOfStrings():
        forEachType.onStringArray(vulkanType)
    elif vulkanType.staticArrExpr:
        forEachType.onStaticArr(vulkanType)
    elif vulkanType.isNextPointer():
        visitWithOptionalCheck("onStructExtension")
    elif vulkanType.pointerIndirectionLevels > 0:
        visitWithOptionalCheck("onPointer")
    else:
        forEachType.onValue(vulkanType)

    return True
1097
class VulkanTypeIterator(object):
    """Minimal base for visitors passed to iterateVulkanType.

    It only keeps the most recent object handed to registerTypeInfo in
    `self.typeInfo`, which starts out as None.
    """

    def __init__(self):
        self.typeInfo = None

    def registerTypeInfo(self, typeInfo):
        """Store `typeInfo` for later use by the visitor."""
        self.typeInfo = typeInfo
1104
def vulkanTypeGetStructFieldLengthInfo(structInfo, vulkanType):
    """Describe how to compute the element count of a struct member.

    Args:
        structInfo: the enclosing struct; only its `.name` is read.
        vulkanType: the member; its `.paramName` and getLengthExpression()
            are read.

    Returns:
        None when the member has no length expression. Otherwise a dict:
          "structName": structInfo.name,
          "field": the member's name (vulkanType.paramName),
          "lenExpr": the expression the count is derived from,
          "postprocess": callable mapping a lenExpr string to the final
              count expression.

    Two members whose count is derived from another field with a conversion
    are special-cased: VkShaderModuleCreateInfo.pCode (codeSize / 4) and
    VkPipelineMultisampleStateCreateInfo.pSampleMask
    (ceil(rasterizationSamples / 32)).
    """
    # Built per call so every caller gets fresh dicts it may safely mutate.
    specialCases = {
        ("VkShaderModuleCreateInfo", "pCode"): {
            "structName": "VkShaderModuleCreateInfo",
            "field": "pCode",
            "lenExpr": "codeSize",
            "postprocess": lambda expr: "(%s / 4)" % expr,
        },
        ("VkPipelineMultisampleStateCreateInfo", "pSampleMask"): {
            "structName": "VkPipelineMultisampleStateCreateInfo",
            "field": "pSampleMask",
            "lenExpr": "rasterizationSamples",
            "postprocess": lambda expr: "(((%s) + 31) / 32)" % expr,
        },
    }

    specialCase = specialCases.get((structInfo.name, vulkanType.paramName))
    if specialCase is not None:
        return specialCase

    lenExpr = vulkanType.getLengthExpression()
    if lenExpr is None:
        return None

    return {
        "structName": structInfo.name,
        # The member's name, matching the special cases above (previously
        # this mistakenly held the member's type name).
        "field": vulkanType.paramName,
        "lenExpr": lenExpr,
        "postprocess": lambda expr: expr,
    }
1143
1144
class VulkanTypeProtobufInfo(object):
    """Protobuf-related metadata derived from one Vulkan type or struct member.

    Attributes set by the constructor:
      needsMessage: typeInfo.isCompoundType(typeName). When true, no scalar
          protobuf type is chosen and protobufCType is not set.
      isRepeatedString: vulkanType.isArrayOfStrings().
      isString: vulkanType.isString(), or a `char` with a static array size.
      lengthInfo: for a struct member (structInfo given), the result of
          vulkanTypeGetStructFieldLengthInfo; otherwise
          vulkanType.getLengthExpression().
      protobufType: protobuf scalar type name, or None for compound types.
      protobufCType: the C type matching protobufType.
      origTypeCategory: typeInfo.categoryOf(typeName).
      isExtensionStruct: True for a `void` pointer member named pNext.

    Raises KeyError (from the C-type table) if the type's category yields no
    scalar protobuf type, e.g. an "include" category.
    """

    # Registry "basetype" typedefs. VkDeviceAddress and VkFlags64 are uint64_t
    # basetypes in newer Vulkan registries; they used to raise KeyError here.
    # Unlisted basetypes fall back to uint64, like uncategorized types below.
    _BASETYPE_PROTOBUF_TYPES = {
        "VkFlags": "uint32",
        "VkBool32": "uint32",
        "VkDeviceSize": "uint64",
        "VkSampleMask": "uint32",
        "VkDeviceAddress": "uint64",
        "VkFlags64": "uint64",
    }

    # Types with no registry category (C primitives etc.); unlisted ones map
    # to uint64.
    _UNCATEGORIZED_PROTOBUF_TYPES = {
        "void": "uint64",
        "char": "uint8",
        "size_t": "uint64",
        "float": "float",
        "uint8_t": "uint32",
        "uint16_t": "uint32",
        "int32_t": "int32",
        "uint32_t": "uint32",
        "uint64_t": "uint64",
        "VkDeviceSize": "uint64",
        "VkSampleMask": "uint32",
    }

    # Protobuf scalar type -> C type used to hold it.
    _PROTOBUF_C_TYPES = {
        "uint8": "uint8_t",
        "uint32": "uint32_t",
        "int32": "int32_t",
        "uint64": "uint64_t",
        "float": "float",
        "string": "const char*",
    }

    def __init__(self, typeInfo, structInfo, vulkanType):
        typeName = vulkanType.typeName

        self.needsMessage = typeInfo.isCompoundType(typeName)
        self.isRepeatedString = vulkanType.isArrayOfStrings()
        self.isString = vulkanType.isString() or (
            typeName == "char" and (vulkanType.staticArrExpr != ""))

        if structInfo is not None:
            self.lengthInfo = vulkanTypeGetStructFieldLengthInfo(
                structInfo, vulkanType)
        else:
            self.lengthInfo = vulkanType.getLengthExpression()

        self.protobufType = None
        self.origTypeCategory = typeInfo.categoryOf(typeName)

        self.isExtensionStruct = (
            typeName == "void"
            and vulkanType.pointerIndirectionLevels > 0
            and vulkanType.paramName == "pNext")

        # Structs/unions become nested messages; no scalar type applies.
        if self.needsMessage:
            return

        category = self.origTypeCategory
        if category in ("enum", "bitmask"):
            self.protobufType = "uint32"
        elif category in ("funcpointer", "handle", "define"):
            self.protobufType = "uint64"
        elif category == "basetype":
            self.protobufType = self._BASETYPE_PROTOBUF_TYPES.get(
                typeName, "uint64")
        elif category is None:
            self.protobufType = self._UNCATEGORIZED_PROTOBUF_TYPES.get(
                typeName, "uint64")

        self.protobufCType = self._PROTOBUF_C_TYPES[self.protobufType]
1216
1217