• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright (c) 2018 The Android Open Source Project
2# Copyright (c) 2018 Google Inc.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16from generator import noneStr
17
18from copy import copy
19from string import whitespace
20
# Holds information about core Vulkan objects
# and the API calls that are used to create/destroy each one.
class HandleInfo(object):
    """Describes a core Vulkan handle type and the APIs that create/destroy it.

    Args:
        name: handle type name, e.g. "VkBuffer".
        createApis: a single API name (str) or a list of API names.
        destroyApis: a single API name (str), a list of API names, or None
            when no destroy API is recorded for this handle.
    """

    def __init__(self, name, createApis, destroyApis):
        self.name = name
        self.createApis = createApis
        self.destroyApis = destroyApis

    @staticmethod
    def _apiMatches(apiName, apis):
        # `apiName in apis` is a *substring* test when apis is a str, which
        # would e.g. report "vkCreateBuffer" as creating VkBufferView
        # ("vkCreateBufferView"). Compare whole names instead.
        if apis is None:
            return False
        if isinstance(apis, str):
            return apiName == apis
        return apiName in apis

    def isCreateApi(self, apiName):
        """Return True if apiName is one of this handle's create APIs."""
        return HandleInfo._apiMatches(apiName, self.createApis)

    def isDestroyApi(self, apiName):
        """Return True if apiName is one of this handle's destroy APIs."""
        return HandleInfo._apiMatches(apiName, self.destroyApis)
36
# Core Vulkan handle types treated as dispatchable.
DISPATCHABLE_HANDLE_TYPES = [
    "VkInstance", "VkPhysicalDevice", "VkDevice", "VkQueue", "VkCommandBuffer",
]

# Core (and some extension) Vulkan handle types treated as non-dispatchable.
NON_DISPATCHABLE_HANDLE_TYPES = [
    "VkDeviceMemory", "VkBuffer", "VkBufferView", "VkImage", "VkImageView",
    "VkShaderModule", "VkDescriptorPool", "VkDescriptorSetLayout",
    "VkDescriptorSet", "VkSampler", "VkPipeline", "VkPipelineLayout",
    "VkRenderPass", "VkFramebuffer", "VkPipelineCache", "VkCommandPool",
    "VkFence", "VkSemaphore", "VkEvent", "VkQueryPool",
    "VkSamplerYcbcrConversion", "VkSamplerYcbcrConversionKHR",
    "VkDescriptorUpdateTemplate", "VkSurfaceKHR", "VkSwapchainKHR",
    "VkDisplayKHR", "VkDisplayModeKHR", "VkObjectTableNVX",
    "VkIndirectCommandsLayoutNVX", "VkValidationCacheEXT",
    "VkDebugReportCallbackEXT", "VkDebugUtilsMessengerEXT",
    "VkAccelerationStructureNV", "VkIndirectCommandsLayoutNV",
    "VkAccelerationStructureKHR",
]

# Handle types whose create/destroy APIs do not follow the
# vkCreate<Name>/vkDestroy<Name> naming pattern.
CUSTOM_HANDLE_CREATE_TYPES = [
    "VkPhysicalDevice", "VkQueue", "VkPipeline",
    "VkDeviceMemory", "VkDescriptorSet", "VkCommandBuffer",
]

# Every known handle type, de-duplicated and alphabetically sorted.
HANDLE_TYPES = sorted(
    set(DISPATCHABLE_HANDLE_TYPES)
    | set(NON_DISPATCHABLE_HANDLE_TYPES)
    | set(CUSTOM_HANDLE_CREATE_TYPES))
94
# (createApis, destroyApis) for the handle types in CUSTOM_HANDLE_CREATE_TYPES.
# A destroy entry of None means no destroy API is recorded for that handle.
_CUSTOM_HANDLE_LIFETIME_APIS = {
    "VkPhysicalDevice": ("vkEnumeratePhysicalDevices", None),
    "VkQueue": ("vkGetDeviceQueue", None),
    "VkPipeline": (["vkCreateGraphicsPipelines", "vkCreateComputePipelines"],
                   "vkDestroyPipeline"),
    "VkDeviceMemory": ("vkAllocateMemory",
                       ["vkFreeMemory", "vkFreeMemorySyncGOOGLE"]),
    "VkDescriptorSet": ("vkAllocateDescriptorSets", "vkFreeDescriptorSets"),
    "VkCommandBuffer": ("vkAllocateCommandBuffers", "vkFreeCommandBuffers"),
}

# Handle type name -> HandleInfo describing its create/destroy APIs.
HANDLE_INFO = {}

for h in HANDLE_TYPES:
    if h not in CUSTOM_HANDLE_CREATE_TYPES:
        # Default convention: strip "Vk" and prefix vkCreate / vkDestroy.
        HANDLE_INFO[h] = HandleInfo(h, "vkCreate" + h[2:], "vkDestroy" + h[2:])
    elif h in _CUSTOM_HANDLE_LIFETIME_APIS:
        createApis, destroyApis = _CUSTOM_HANDLE_LIFETIME_APIS[h]
        HANDLE_INFO[h] = HandleInfo(h, createApis, destroyApis)
130
# API names listed as excluded (presumably skipped by the code generators —
# confirm at the use sites).
EXCLUDED_APIS = ["vkEnumeratePhysicalDeviceGroups"]

# Type names that VulkanTypeInfo.isNonAbiPortableType always treats as
# portable / non-portable, regardless of their category.
EXPLICITLY_ABI_PORTABLE_TYPES = [
    "VkResult", "VkBool32", "VkSampleMask", "VkFlags", "VkDeviceSize",
]
EXPLICITLY_ABI_NON_PORTABLE_TYPES = ["size_t"]

# Type categories whose members are considered non-ABI-portable.
NON_ABI_PORTABLE_TYPE_CATEGORIES = ["handle", "funcpointer"]

# XML attribute names annotating VkDeviceMemory-related parameters/members;
# makeVulkanTypeFromXMLTag checks them in this order.
DEVICE_MEMORY_INFO_KEYS = [
    "devicememoryhandle", "devicememoryoffset", "devicememorysize",
    "devicememorytypeindex", "devicememorytypebits",
]

# Type names whose instances/definitions are flagged isTransformed.
# NOTE(review): what distinguishes "trivial" from "non-trivial" is decided by
# consumers outside this view — confirm before moving entries between lists.
TRIVIAL_TRANSFORMED_TYPES = [
    "VkPhysicalDeviceExternalImageFormatInfo",
    "VkPhysicalDeviceExternalBufferInfo",
    "VkExternalMemoryImageCreateInfo",
    "VkExternalMemoryBufferCreateInfo",
    "VkExportMemoryAllocateInfo",
    "VkExternalImageFormatProperties",
    "VkExternalBufferProperties",
]
NON_TRIVIAL_TRANSFORMED_TYPES = ["VkExternalMemoryProperties", "VkImageCreateInfo"]

TRANSFORMED_TYPES = [*TRIVIAL_TRANSFORMED_TYPES, *NON_TRIVIAL_TRANSFORMED_TYPES]
176
# Holds information about a Vulkan type instance (i.e., not a type definition).
# Type instances are used as struct field definitions or function parameters,
# to be later fed to code generation.
# VulkanType instances can be constructed in two ways:
# 1. From a <member> / <param> / <proto> tag in vk.xml,
#    using makeVulkanTypeFromXMLTag.
# 2. User-defined instances with makeVulkanTypeSimple.
class VulkanType(object):

    def __init__(self):
        # Owning VulkanCompoundType / VulkanAPI (assigned by their initCopies()).
        self.parent = None
        self.typeName = ""

        # Set by makeVulkanTypeFromXMLTag when typeName is in TRANSFORMED_TYPES.
        self.isTransformed = False

        self.paramName = None

        # First comma-separated entry of the XML "len" attribute, if any.
        self.lenExpr = None
        self.isOptional = False
        # Raw XML "optional" attribute value, if any.
        self.optionalStr = None

        self.isConst = False

        # "" means it's not a static array, otherwise this is the total size of
        # all elements. e.g. staticArrExpr of "x[3][2][8]" will be "((3)*(2)*(8))".
        self.staticArrExpr = ""
        # "" means it's not a static array, otherwise it's the raw expression
        # of static array size, which can be one-dimensional or multi-dimensional.
        self.rawStaticArrExpr = ""

        self.pointerIndirectionLevels = 0  # 0 means not pointer
        self.isPointerToConstPointer = False

        self.primitiveEncodingSize = None

        self.deviceMemoryInfoParameterIndices = None

        # Annotations
        # Environment annotation for binding current
        # variables to sub-structures
        self.binds = {}

        # Device memory annotations: which devicememory* XML attribute this
        # element carries (deviceMemoryAttrib) and its value (deviceMemoryVal).
        self.deviceMemoryAttrib = None
        self.deviceMemoryVal = None

        # Filter annotations
        self.filterVar = None
        self.filterVals = None
        self.filterFunc = None
        self.filterOtherwise = None

        # Stream feature
        self.streamFeature = None

        # All other annotations (a copy of every XML attribute of the tag).
        self.attribs = {}

    def __str__(self):
        typeStr = self.typeName + "*" * self.pointerIndirectionLevels
        if self.isPointerToConstPointer:
            typeStr += "ptr2constptr"
        return ("(vulkantype %s %s paramName %s len %s optional? %s "
                "staticArrExpr %s %s)") % (
            typeStr,
            "const" if self.isConst else "nonconst",
            self.paramName,
            self.lenExpr,
            self.isOptional,
            self.staticArrExpr,
            # There is no `staticArrCount` attribute (referencing it made
            # str() raise AttributeError); show the raw "[..]" text instead.
            self.rawStaticArrExpr)

    def isString(self):
        """True for char* (single indirection)."""
        return self.pointerIndirectionLevels == 1 and self.typeName == "char"

    def isArrayOfStrings(self):
        """True for char with a pointer-to-const-pointer shape."""
        return self.isPointerToConstPointer and self.typeName == "char"

    def primEncodingSize(self):
        return self.primitiveEncodingSize

    # Utility functions to make codegen life easier.
    def getLengthExpression(self):
        """Return the element-count expression, if known.

        The static array size wins; otherwise lenExpr is returned as-is (it
        may be a non-expression XML value such as "null-terminated");
        otherwise None.
        """
        if self.staticArrExpr != "":
            return self.staticArrExpr
        if self.lenExpr:
            return self.lenExpr
        return None

    def accessibleAsPointer(self):
        """Can this be passed to functions expecting T*?"""
        return self.staticArrExpr != "" or self.pointerIndirectionLevels > 0

    def possiblyOutput(self):
        """Rough guess whether this is an output: a non-const pointer."""
        return self.pointerIndirectionLevels > 0 and not self.isConst

    def isVoidWithNoSize(self):
        return self.typeName == "void" and self.pointerIndirectionLevels == 0

    def getCopy(self):
        """Shallow copy (annotation dicts/lists are shared with the original)."""
        return copy(self)

    def getTransformed(self, isConstChoice=None, ptrIndirectionChoice=None):
        """Copy, optionally overriding constness and pointer indirection."""
        res = self.getCopy()

        if isConstChoice is not None:
            res.isConst = isConstChoice
        if ptrIndirectionChoice is not None:
            res.pointerIndirectionLevels = ptrIndirectionChoice

        return res

    def getWithCustomName(self):
        # NOTE(review): identical to getForAddressAccess (adds one pointer
        # level, does not touch paramName) — confirm this is intended.
        return self.getTransformed(
            ptrIndirectionChoice=self.pointerIndirectionLevels + 1)

    def getForAddressAccess(self):
        """Copy with one more level of pointer indirection."""
        return self.getTransformed(
            ptrIndirectionChoice=self.pointerIndirectionLevels + 1)

    def getForValueAccess(self):
        """Copy with one less level of pointer indirection.

        void* is dereferenced as uint8_t.
        """
        if self.typeName == "void" and self.pointerIndirectionLevels == 1:
            asUint8Type = self.getCopy()
            asUint8Type.typeName = "uint8_t"
            return asUint8Type.getForValueAccess()
        return self.getTransformed(
            ptrIndirectionChoice=self.pointerIndirectionLevels - 1)

    def getForNonConstAccess(self):
        return self.getTransformed(isConstChoice=False)

    def withModifiedName(self, newName):
        res = self.getCopy()
        res.paramName = newName
        return res

    def isNextPointer(self):
        return self.paramName == "pNext"

    # Only deals with 'core' handle types here.
    def isDispatchableHandleType(self):
        return self.typeName in DISPATCHABLE_HANDLE_TYPES

    def isNonDispatchableHandleType(self):
        return self.typeName in NON_DISPATCHABLE_HANDLE_TYPES

    def isHandleType(self):
        return self.isDispatchableHandleType() or \
               self.isNonDispatchableHandleType()

    def isCreatedBy(self, api):
        """True if api creates a handle of this type.

        A "KHR"-suffixed API also matches via its unsuffixed name.
        """
        if self.typeName in HANDLE_INFO:
            info = HANDLE_INFO[self.typeName]
            if info.isCreateApi(api.name):
                return True
            if len(api.name) > 3 and api.name.endswith("KHR"):
                return info.isCreateApi(api.name[:-3])

        if self.typeName == "VkImage" and api.name == "vkCreateImageWithRequirementsGOOGLE":
            return True

        if self.typeName == "VkBuffer" and api.name == "vkCreateBufferWithRequirementsGOOGLE":
            return True

        return False

    def isDestroyedBy(self, api):
        """True if api destroys a handle of this type (KHR suffix tolerated)."""
        if self.typeName in HANDLE_INFO:
            info = HANDLE_INFO[self.typeName]
            if info.isDestroyApi(api.name):
                return True
            if len(api.name) > 3 and api.name.endswith("KHR"):
                return info.isDestroyApi(api.name[:-3])

        return False

    def isSimpleValueType(self, typeInfo):
        """True unless this is a struct/union, string(s), static array or pointer."""
        if typeInfo.isCompoundType(self.typeName):
            return False
        if self.isString() or self.isArrayOfStrings():
            return False
        if self.staticArrExpr or self.pointerIndirectionLevels > 0:
            return False
        return True

    def getStructEnumExpr(self):
        return None
377
# Minimal S-expression reader, adapted from https://gist.github.com/pib/240957
class Atom(object):
    """A bare symbol (not a list, string or number) in a parsed S-expression."""

    def __init__(self, name):
        self.name = name

    def __repr__(self,):
        return self.name


def parse_sexp(sexp):
    """Parse sexp into a list of its top-level items.

    Parenthesized forms become lists, "double-quoted" text becomes str
    (backslash escapes the next character), tokens starting with a digit are
    evaluated to numbers, other tokens become Atom, and 'x is read as
    [('quote',), x].
    """
    atom_end = set('()"\'') | set(whitespace)
    stack, i, length = [[]], 0, len(sexp)

    def finish_quote():
        # A completed datum that followed ' closes its enclosing quote form.
        if stack[-1][0] == ('quote',):
            stack[-2].append(stack.pop())

    def finish_atom():
        atom = stack.pop()
        if atom.name[0].isdigit():
            # NOTE(review): eval() of a digit-leading token. Input is assumed
            # to be trusted annotation text (filterFunc / let attributes);
            # never feed untrusted data to this parser.
            stack[-1].append(eval(atom.name))
        else:
            stack[-1].append(atom)
        finish_quote()

    while i < length:
        c = sexp[i]
        top = stack[-1]

        if isinstance(top, list):
            if c == '(':
                stack.append([])
            elif c == ')':
                stack[-2].append(stack.pop())
                finish_quote()
            elif c == '"':
                stack.append('')
            elif c == "'":
                stack.append([('quote',)])
            elif c in whitespace:
                pass
            else:
                stack.append(Atom(c))
        elif isinstance(top, str):
            if c == '"':
                stack[-2].append(stack.pop())
                finish_quote()
            elif c == '\\':
                i += 1
                stack[-1] += sexp[i]
            else:
                stack[-1] += c
        elif isinstance(top, Atom):
            if c in atom_end:
                finish_atom()
                # Re-examine c: it is also a delimiter/opener in its own right.
                continue
            stack[-1] = Atom(top.name + c)
        i += 1

    # Input may end in the middle of an atom (e.g. "foo" or "(a) 42"); flush
    # it so the result is the top-level list rather than a dangling Atom.
    if isinstance(stack[-1], Atom):
        finish_atom()

    return stack.pop()
421
class FuncExprVal(object):
    """Leaf of a parsed function expression (an Atom, number or string)."""

    def __init__(self, val):
        self.val = val

    def __repr__(self):
        return repr(self.val)
427
class FuncExpr(object):
    """Application node: `name` applied to the parsed sub-expressions `args`."""

    def __init__(self, name, args):
        self.name = name
        self.args = args

    def __repr__(self):
        if len(self.args) == 0:
            return f"({self.name!r})"
        rendered = " ".join(repr(arg) for arg in self.args)
        return f"({self.name!r} {rendered})"
437
class FuncLambda(object):
    """Lambda node: parameter list `vs` and a single `body` expression."""

    def __init__(self, vs, body):
        self.vs = vs
        self.body = body

    def __repr__(self):
        params = " ".join(repr(v) for v in self.vs)
        return f"(L ({params}) {self.body!r})"
444
class FuncLambdaParam(object):
    """A lambda parameter: its name and its type name."""

    def __init__(self, name, typ):
        self.name = name
        self.typ = typ

    def __repr__(self):
        return f"{self.name} : {self.typ}"
451
def parse_func_expr(parsed_sexp):
    """Convert parse_sexp output into a FuncExpr / FuncLambda / FuncExprVal tree.

    parsed_sexp must hold exactly one top-level expression.
    (lambda ((name type) ...) body) becomes a FuncLambda, any other list
    (head arg ...) becomes FuncExpr(head, [args]), and anything else is
    wrapped in FuncExprVal.

    Raises:
        ValueError: if parsed_sexp does not contain exactly one expression.
    """
    if len(parsed_sexp) != 1:
        # A bare `raise` here has no active exception and only produces
        # "RuntimeError: No active exception to reraise"; raise a real error.
        raise ValueError(
            "Error: parsed # expressions != 1: %d" % len(parsed_sexp))

    def parse_lambda_param(e):
        # (name type) -> FuncLambdaParam(name, type)
        return FuncLambdaParam(e[0].name, e[1].name)

    def parse_one(exp):
        if not isinstance(exp, list):
            return FuncExprVal(exp)
        if repr(exp[0]) == "lambda":
            # (lambda (params...) body)
            return FuncLambda([parse_lambda_param(p) for p in exp[1]],
                              parse_one(exp[2]))
        return FuncExpr(exp[0], [parse_one(arg) for arg in exp[1:]])

    return parse_one(parsed_sexp[0])
472
def parseFilterFuncExpr(expr):
    """Parse a filterFunc annotation string into a function-expression tree.

    The parsed tree is also printed (debug output).
    """
    parsed = parse_func_expr(parse_sexp(expr))
    print("parseFilterFuncExpr: parsed %s" % (parsed,))
    return parsed
477
def parseLetBodyExpr(expr):
    """Parse the body of a `let` annotation into a function-expression tree.

    The parsed tree is also printed (debug output).
    """
    parsed = parse_func_expr(parse_sexp(expr))
    print("parseLetBodyExpr: parsed %s" % (parsed,))
    return parsed
482
def makeVulkanTypeFromXMLTag(typeInfo, tag):
    """Build a VulkanType from a <member>, <param> or <proto> XML element.

    Args:
        typeInfo: VulkanTypeInfo; its getPrimitiveEncodingSize() must already
            know the element's <type> name.
        tag: ElementTree-style element (uses .attrib, .find, .text, .tail and
            iteration over children).

    Returns:
        A new VulkanType populated from the element's text and attributes.
    """
    res = VulkanType()

    # Length expression: only the first comma-separated entry of "len" is kept.
    if tag.attrib.get("len") is not None:
        res.lenExpr = tag.attrib.get("len").split(",")[0]

    # Static array size.
    nametag = tag.find("name")
    enumtag = tag.find("enum")

    if enumtag is not None:
        # Size given by an <enum> child.
        res.staticArrExpr = enumtag.text
    elif nametag is not None:
        # Text after <name>, e.g. "[4]" or "[3][2]" (via noneStr).
        res.rawStaticArrExpr = noneStr(nametag.tail)

        dimensions = res.rawStaticArrExpr.count('[')
        if dimensions == 1:
            res.staticArrExpr = res.rawStaticArrExpr[1:-1]
        elif dimensions > 1:
            # "[3][2][8]" -> "((3)*(2)*(8))": total number of elements.
            arraySizes = res.rawStaticArrExpr[1:-1].split('][')
            res.staticArrExpr = '(' + \
                '*'.join(f'({size})' for size in arraySizes) + ')'

    # Const-ness comes from the text before the <type> child.
    if "const" in noneStr(tag.text):
        res.isConst = True

    # Type name and pointer info.
    for elem in tag:
        if elem.tag == "name":
            res.paramName = elem.text
        if elem.tag == "type":
            duringTypePart = noneStr(elem.text)
            afterTypePart = noneStr(elem.tail)
            res.typeName = duringTypePart

            if res.typeName in TRANSFORMED_TYPES:
                res.isTransformed = True

            # Each '*' after the type adds one pointer level. "const" between
            # stars is only recognized for two levels (T* const*).
            res.pointerIndirectionLevels += afterTypePart.count("*")
            if "const" in afterTypePart and res.pointerIndirectionLevels == 2:
                res.isPointerToConstPointer = True

    # Calculate optionality (based on validitygenerator.py)
    if tag.attrib.get("optional") is not None:
        res.isOptional = True
        res.optionalStr = tag.attrib.get("optional")

    # If no validity is being generated, it usually means that
    # validity is complex and not absolute, so let's say yes.
    if tag.attrib.get("noautovalidity") is not None:
        res.isOptional = True

    # If this is a structure extension, it is optional.
    if tag.attrib.get("structextends") is not None:
        res.isOptional = True

    # If this is a pNext pointer, it is optional.
    if res.paramName == "pNext":
        res.isOptional = True

    res.primitiveEncodingSize = \
        typeInfo.getPrimitiveEncodingSize(res.typeName)

    # Annotations: environment binds, "a: x, b: y" -> {"a": "x", "b": "y"}.
    if tag.attrib.get("binds") is not None:
        bindPairs = map(lambda x: x.strip(), tag.attrib.get("binds").split(","))
        bindPairsSplit = map(lambda p: p.split(":"), bindPairs)
        res.binds = dict(map(lambda sp: (sp[0].strip(), sp[1].strip()), bindPairsSplit))

    # Annotations: device memory (first matching key wins).
    for k in DEVICE_MEMORY_INFO_KEYS:
        if tag.attrib.get(k) is not None:
            res.deviceMemoryAttrib = k
            res.deviceMemoryVal = tag.attrib.get(k)
            break

    # Annotations: filters
    if tag.attrib.get("filterVar") is not None:
        res.filterVar = tag.attrib.get("filterVar").strip()

    if tag.attrib.get("filterVals") is not None:
        res.filterVals = \
            list(map(lambda v: v.strip(),
                     tag.attrib.get("filterVals").strip().split(",")))
        print("Filtervals: %s" % res.filterVals)

    if tag.attrib.get("filterFunc") is not None:
        res.filterFunc = parseFilterFuncExpr(tag.attrib.get("filterFunc"))

    if tag.attrib.get("filterOtherwise") is not None:
        # Must be stored on `filterOtherwise` (the field VulkanType declares);
        # it was previously written to a stray `Otherwise` attribute.
        res.filterOtherwise = tag.attrib.get("filterOtherwise")

    # store all other attribs here
    res.attribs = dict(tag.attrib)

    return res
598
599
def makeVulkanTypeSimple(isConst,
                         typeName,
                         ptrIndirectionLevels,
                         paramName=None):
    """Create a VulkanType directly from its basic properties (no XML).

    The result is never a pointer-to-const-pointer and carries no primitive
    encoding size.
    """
    simpleType = VulkanType()

    simpleType.isConst = isConst
    simpleType.typeName = typeName
    simpleType.paramName = paramName
    simpleType.pointerIndirectionLevels = ptrIndirectionLevels
    simpleType.isPointerToConstPointer = False
    simpleType.primitiveEncodingSize = None

    return simpleType
614
class DeviceMemoryInfoParameterIndices(object):
    """Positions of the elements describing one abstract VkDeviceMemory.

    Each field is the index, within a parameter/member list, of the element
    annotated with the matching devicememory* attribute (handle, offset,
    size, typeIndex, typeBits), or None when there is no such element.
    """

    def __init__(self, handle, offset, size, typeIndex, typeBits):
        (self.handle, self.offset, self.size,
         self.typeIndex, self.typeBits) = (handle, offset, size,
                                           typeIndex, typeBits)
624
def initDeviceMemoryInfoParameterIndices(parameters):
    """Group device-memory-annotated parameters by their annotation value.

    Args:
        parameters: sequence of VulkanType whose deviceMemoryAttrib /
            deviceMemoryVal may be set (iterated twice).

    Returns:
        dict mapping each deviceMemoryVal to a DeviceMemoryInfoParameterIndices
        holding the positions of the parameters annotated with it, or None
        when no parameter carries a DEVICE_MEMORY_INFO_KEYS annotation.
    """
    fieldForAttrib = {
        "devicememoryhandle": "handle",
        "devicememoryoffset": "offset",
        "devicememorysize": "size",
        "devicememorytypeindex": "typeIndex",
        "devicememorytypebits": "typeBits",
    }

    # Pass 1: one empty record per annotation value.
    infoByVal = {}
    for param in parameters:
        attrib = param.deviceMemoryAttrib
        if attrib and attrib in DEVICE_MEMORY_INFO_KEYS:
            infoByVal[param.deviceMemoryVal] = \
                DeviceMemoryInfoParameterIndices(None, None, None, None, None)

    # Pass 2: record each annotated parameter's position in its record.
    for position, param in enumerate(parameters):
        attrib = param.deviceMemoryAttrib
        if not attrib:
            continue
        record = infoByVal[param.deviceMemoryVal]
        fieldName = fieldForAttrib.get(attrib)
        if fieldName is not None:
            setattr(record, fieldName, position)

    return infoByVal if infoByVal else None
667
668# Classes for describing aggregate types (unions, structs) and API calls.
class VulkanCompoundType(object):
    """A struct or union type definition together with its members.

    Args:
        name: type name (stored as both .name and .typeName).
        members: list of VulkanType, one per member.
        isUnion: True for unions.
        structEnumExpr: VkStructureType value expression for this struct, if
            any (returned by getStructEnumExpr()).
        structExtendsExpr: raw "structextends" value, if any.
        feature: feature name active when the type was declared.
        initialEnv: dict of annotation-environment bindings; a fresh empty
            dict is used when omitted.
        optional: raw "optional" value, stored as .optionalStr.
    """

    def __init__(self, name, members, isUnion=False, structEnumExpr=None, structExtendsExpr=None, feature=None, initialEnv=None, optional=None):
        self.name = name
        self.typeName = name
        self.members = members
        # A `{}` default would be a single dict shared by every instance
        # created without initialEnv; mutating one environment would then
        # leak into all of them.
        self.environment = {} if initialEnv is None else initialEnv

        self.isUnion = isUnion
        self.structEnumExpr = structEnumExpr
        self.structExtendsExpr = structExtendsExpr
        self.feature = feature

        self.deviceMemoryInfoParameterIndices = \
            initDeviceMemoryInfoParameterIndices(self.members)

        self.isTransformed = name in TRANSFORMED_TYPES

        self.copy = None

        self.optionalStr = optional

    def initCopies(self):
        """Set .copy to self and point every member's .parent at it."""
        self.copy = self

        for m in self.members:
            m.parent = self.copy

    def getMember(self, memberName):
        """Return the first member named memberName, or None."""
        return next((m for m in self.members if m.paramName == memberName),
                    None)

    def getStructEnumExpr(self):
        return self.structEnumExpr
705
class VulkanAPI(object):
    """A Vulkan command: its name, return type (VulkanType) and parameters.

    origName defaults to name when not supplied (or supplied falsy).
    """

    def __init__(self, name, retType, parameters, origName=None):
        self.name = name
        self.origName = origName if origName else name
        self.retType = retType
        self.parameters = parameters

        self.deviceMemoryInfoParameterIndices = \
            initDeviceMemoryInfoParameterIndices(parameters)

        self.copy = None

        self.isTransformed = name in TRANSFORMED_TYPES

    def initCopies(self):
        """Set .copy to self and point every parameter's .parent at it."""
        self.copy = self
        for param in self.parameters:
            param.parent = self.copy

    def getCopy(self):
        """Shallow copy (the parameter list is shared)."""
        return copy(self)

    def getParameter(self, parameterName):
        """Return the first parameter named parameterName, or None."""
        return next((p for p in self.parameters if p.paramName == parameterName),
                    None)

    def withModifiedName(self, newName):
        # NOTE(review): the result's origName becomes newName and
        # initCopies() is not called on it — confirm this is intended.
        return VulkanAPI(newName, self.retType, self.parameters)

    def getRetVarExpr(self):
        """Name of the local holding the return value, or None for void."""
        retTypeName = self.retType.typeName
        return None if retTypeName == "void" else \
            "%s_%s_return" % (self.name, retTypeName)

    def getRetTypeExpr(self):
        return self.retType.typeName

    def withCustomParameters(self, customParams):
        """Shallow copy with the parameter list replaced."""
        variant = self.getCopy()
        variant.parameters = customParams
        return variant

    def withCustomReturnType(self, retType):
        """Shallow copy with the return type replaced."""
        variant = self.getCopy()
        variant.retType = retType
        return variant
760
# Whether or not special handling of virtual elements
# such as VkDeviceMemory is needed.
def vulkanTypeNeedsTransform(structOrApi):
    """Return True if structOrApi has device-memory parameter indices."""
    return structOrApi.deviceMemoryInfoParameterIndices is not None
765
def vulkanTypeGetNeededTransformTypes(structOrApi):
    """Return the transform kinds structOrApi needs.

    Currently only "devicememory", reported when device-memory parameter
    indices are present.
    """
    res = []
    if structOrApi.deviceMemoryInfoParameterIndices is not None:
        res.append("devicememory")
    return res
771
def vulkanTypeforEachSubType(structOrApi, f):
    """Call f(index, subType) for each struct member or API parameter.

    Raises:
        TypeError: if structOrApi is neither a VulkanCompoundType nor a
            VulkanAPI.
    """
    if isinstance(structOrApi, VulkanCompoundType):
        subTypes = structOrApi.members
    elif isinstance(structOrApi, VulkanAPI):
        subTypes = structOrApi.parameters
    else:
        # Fail with a clear message rather than an opaque enumerate(None) error.
        raise TypeError(
            "vulkanTypeforEachSubType: expected VulkanCompoundType or "
            "VulkanAPI, got %s" % type(structOrApi).__name__)

    for i, subType in enumerate(subTypes):
        f(i, subType)
781
# Parses everything about Vulkan types into a Python readable format.
class VulkanTypeInfo(object):
    """Registry of Vulkan types, structs/unions and commands.

    Filled in through the onBeginFeature / onEndFeature / onGen* / onEnd
    callbacks (presumably driven by the vk.xml output generator — confirm).
    """

    # Encoding sizes (bytes) of built-in C types; these take precedence over
    # the category-based sizes in setPrimitiveEncodingSize.
    _BASE_ENCODING_SIZES = {
        "void": 8,
        "char": 1,
        "float": 4,
        "uint8_t": 1,
        "uint16_t": 2,
        "uint32_t": 4,
        "uint64_t": 8,
        "size_t": 8,
        "ssize_t": 8,
    }

    def __init__(self):
        # Every category name seen so far.
        self.categories = set()

        # Tracks what Vulkan type is part of what category.
        self.typeCategories = {}

        # Tracks the primitive encoding size for each type,
        # if applicable.
        self.encodingSizes = {}

        # Struct/union name -> VulkanCompoundType.
        self.structs = {}
        # Command name -> VulkanAPI.
        self.apis = {}

        # Feature currently being processed (None outside a feature).
        self.feature = None

    def initType(self, name, category):
        """Record name under category and compute its encoding size."""
        self.categories.add(category)
        self.typeCategories[name] = category
        self.encodingSizes[name] = self.setPrimitiveEncodingSize(name)

    def categoryOf(self, name):
        return self.typeCategories[name]

    def getPrimitiveEncodingSize(self, name):
        return self.encodingSizes[name]

    # Queries relating to categories of Vulkan types.
    def isHandleType(self, name):
        return self.typeCategories.get(name) == "handle"

    def isCompoundType(self, name):
        return self.typeCategories.get(name) in ["struct", "union"]

    def setPrimitiveEncodingSize(self, name):
        """Return the best size in bytes for encoding/decoding type `name`.

        Despite the name this only computes the value (initType stores it).
        Returns None when not applicable. The type must already have a
        category recorded unless it is a built-in C type.
        """
        if name in self._BASE_ENCODING_SIZES:
            return self._BASE_ENCODING_SIZES[name]

        category = self.typeCategories[name]

        if category in [None, "api", "bitmask", "include", "define", "struct", "union"]:
            return None

        # Must be 8---handles are pointers and basetype includes VkDeviceSize
        # which is 8 bytes
        if category in ["handle", "basetype", "funcpointer"]:
            return 8

        # Most of the time, enums are only 4 bytes, but this is
        # vague enough to be the source of a future headache, and
        # it's easy to just stream 8 bytes there anyway.
        if category == "enum":
            return 8

        # Any other category has no primitive encoding.
        return None

    def isNonAbiPortableType(self, typeName):
        """True if typeName's representation may differ across ABIs."""
        if typeName in EXPLICITLY_ABI_PORTABLE_TYPES:
            return False

        if typeName in EXPLICITLY_ABI_NON_PORTABLE_TYPES:
            return True

        category = self.typeCategories[typeName]
        return category in NON_ABI_PORTABLE_TYPE_CATEGORIES

    def onBeginFeature(self, featureName):
        self.feature = featureName

    def onEndFeature(self):
        self.feature = None

    def onGenType(self, typeinfo, name, alias):
        category = typeinfo.elem.get("category")
        self.initType(name, category)

        if category in ["struct", "union"]:
            self.onGenStruct(typeinfo, name, alias)

    @staticmethod
    def _splitAnnotationPairs(attrValue):
        # "a: x, b: y" -> [("a", "x"), ("b", "y")]
        return [tuple(part.strip() for part in entry.split(":"))
                for entry in attrValue.split(",")]

    def onGenStruct(self, typeinfo, typeName, alias):
        """Build and register a VulkanCompoundType for a non-alias struct/union."""
        if alias:
            return

        members = []

        structExtendsExpr = typeinfo.elem.get("structextends")

        structEnumExpr = None

        # Names visible to annotation expressions inside this struct.
        initialEnv = {}

        envStr = typeinfo.elem.get("exists")
        if envStr is not None:
            for (name, typ) in self._splitAnnotationPairs(envStr):
                initialEnv[name] = {
                    "type": typ,
                    "binding": None,
                    "structmember": False,
                    "body": None,
                }

        letenvStr = typeinfo.elem.get("let")
        if letenvStr is not None:
            for (name, body) in self._splitAnnotationPairs(letenvStr):
                initialEnv[name] = {
                    "type": "uint32_t",
                    "binding": name,
                    "structmember": False,
                    "body": parseLetBodyExpr(body),
                }

        for member in typeinfo.elem.findall(".//member"):
            vulkanType = makeVulkanTypeFromXMLTag(self, member)
            initialEnv[vulkanType.paramName] = {
                "type": vulkanType.typeName,
                "binding": vulkanType.paramName,
                "structmember": True,
                "body": None,
            }
            members.append(vulkanType)
            # The VkStructureType member's "values" attribute is recorded as
            # this struct's structEnumExpr.
            if vulkanType.typeName == "VkStructureType" and \
               member.get("values"):
                structEnumExpr = member.get("values")

        self.structs[typeName] = \
            VulkanCompoundType(
                typeName,
                members,
                isUnion=self.categoryOf(typeName) == "union",
                structEnumExpr=structEnumExpr,
                structExtendsExpr=structExtendsExpr,
                feature=self.feature,
                initialEnv=initialEnv,
                optional=typeinfo.elem.get("optional", None))
        self.structs[typeName].initCopies()

    def onGenGroup(self, _groupinfo, groupName, _alias=None):
        self.initType(groupName, "enum")

    def onGenEnum(self, _enuminfo, name, _alias):
        self.initType(name, "enum")

    def onGenCmd(self, cmdinfo, name, _alias):
        """Register a command as type category "api" and build its VulkanAPI."""
        self.initType(name, "api")

        proto = cmdinfo.elem.find("proto")
        params = cmdinfo.elem.findall("param")

        retType = makeVulkanTypeFromXMLTag(self, proto)
        paramTypes = [makeVulkanTypeFromXMLTag(self, p) for p in params]

        self.apis[name] = VulkanAPI(name, retType, paramTypes)
        self.apis[name].initCopies()

    def onEnd(self):
        pass
962
def hasNullOptionalStringFeature(forEachType):
    """Report whether a visitor supports null-optional-string feature hooks.

    Returns True only when `forEachType` has all three of
    onCheckWithNullOptionalStringFeature,
    endCheckWithNullOptionalStringFeature and
    finalCheckWithNullOptionalStringFeature; False otherwise.
    """
    requiredHooks = (
        "onCheckWithNullOptionalStringFeature",
        "endCheckWithNullOptionalStringFeature",
        "finalCheckWithNullOptionalStringFeature",
    )
    return all(hasattr(forEachType, hook) for hook in requiredHooks)
967
968# General function to iterate over a vulkan type and call code that processes
969# each of its sub-components, if any.
def iterateVulkanType(typeInfo, vulkanType, forEachType):
    """Dispatch `vulkanType` to exactly one visitor hook on `forEachType`.

    Returns False, without calling anything on `forEachType`, when the type
    is a pointer to a const pointer that is not an array of strings.

    Otherwise calls ``forEachType.registerTypeInfo(typeInfo)`` and then one
    of the following, tested in this order, and returns True:
      onCompoundType    - compound type that is not a next pointer
      onString          - string
      onStringArray     - array of strings
      onStaticArr       - static array (non-empty staticArrExpr)
      onStructExtension - next pointer
      onPointer         - any other pointer
      onValue           - everything else

    Optional pointers that are not next pointers are bracketed with
    onCheck/endCheck. For such optional strings, when the visitor defines
    all three *WithNullOptionalStringFeature hooks, those hooks are used
    instead and onString is emitted twice: once after
    onCheckWithNullOptionalStringFeature and once after
    endCheckWithNullOptionalStringFeature.
    """
    if not vulkanType.isArrayOfStrings() and vulkanType.isPointerToConstPointer:
        return False

    forEachType.registerTypeInfo(typeInfo)

    # Next pointers are excluded here, so the guard around
    # onStructExtension below never actually fires.
    needCheck = (vulkanType.isOptional
                 and vulkanType.pointerIndirectionLevels > 0
                 and not vulkanType.isNextPointer())

    def visitGuarded(hookName):
        # Look the hook up only after onCheck, mirroring a direct call.
        if needCheck:
            forEachType.onCheck(vulkanType)
        getattr(forEachType, hookName)(vulkanType)
        if needCheck:
            forEachType.endCheck(vulkanType)

    if typeInfo.isCompoundType(vulkanType.typeName) and not vulkanType.isNextPointer():
        visitGuarded("onCompoundType")
        return True

    if vulkanType.isString():
        if needCheck and hasNullOptionalStringFeature(forEachType):
            forEachType.onCheckWithNullOptionalStringFeature(vulkanType)
            forEachType.onString(vulkanType)
            forEachType.endCheckWithNullOptionalStringFeature(vulkanType)
            forEachType.onString(vulkanType)
            forEachType.finalCheckWithNullOptionalStringFeature(vulkanType)
        else:
            visitGuarded("onString")
    elif vulkanType.isArrayOfStrings():
        forEachType.onStringArray(vulkanType)
    elif vulkanType.staticArrExpr:
        forEachType.onStaticArr(vulkanType)
    elif vulkanType.isNextPointer():
        visitGuarded("onStructExtension")
    elif vulkanType.pointerIndirectionLevels > 0:
        visitGuarded("onPointer")
    else:
        forEachType.onValue(vulkanType)

    return True
1036
class VulkanTypeIterator(object):
    """Minimal visitor base providing the registerTypeInfo hook.

    iterateVulkanType calls ``registerTypeInfo(typeInfo)`` on its visitor
    before dispatching; this class keeps that value in ``self.typeInfo``
    (None until the first call).
    """

    def __init__(self):
        self.typeInfo = None

    def registerTypeInfo(self, typeInfo):
        """Store `typeInfo` on the instance, replacing any earlier value."""
        self.typeInfo = typeInfo
1043
def vulkanTypeGetStructFieldLengthInfo(structInfo, vulkanType):
    """Describe how to compute the element count of a struct member.

    Args:
        structInfo: the enclosing struct; only its ``name`` is read.
        vulkanType: the member; its ``paramName`` and
            ``getLengthExpression()`` are read.

    Returns:
        None when no special case applies and the member has no length
        expression. Otherwise a dict with keys:
          "structName"  - name of the enclosing struct,
          "field"       - name of the member (its paramName),
          "lenExpr"     - the raw length expression,
          "postprocess" - callable that turns an expression string for
                          lenExpr into the element-count expression.
    """
    # Members whose element count is not the raw `len` value:
    # pCode is sized in bytes by codeSize but is read as 4-byte words, and
    # pSampleMask holds one 32-bit word per 32 samples (rounded up).
    # The list is rebuilt per call so callers never share a mutable dict.
    specialCases = (
        {
            "structName": "VkShaderModuleCreateInfo",
            "field": "pCode",
            "lenExpr": "codeSize",
            "postprocess": lambda expr: "(%s / 4)" % expr
        },
        {
            "structName": "VkPipelineMultisampleStateCreateInfo",
            "field": "pSampleMask",
            "lenExpr": "rasterizationSamples",
            "postprocess": lambda expr: "(((%s) + 31) / 32)" % expr
        },
    )

    memberKey = (structInfo.name, vulkanType.paramName)
    for case in specialCases:
        if memberKey == (case["structName"], case["field"]):
            return case

    lenExpr = vulkanType.getLengthExpression()
    if lenExpr is None:
        return None

    return {
        "structName": structInfo.name,
        # Member name, consistent with how the special cases use "field"
        # (previously this wrongly held the member's type name).
        "field": vulkanType.paramName,
        "lenExpr": lenExpr,
        "postprocess": lambda expr: expr,
    }
1085
class VulkanTypeProtobufInfo(object):
    """Protobuf-oriented classification of a single Vulkan type.

    Attributes set by the constructor:
      needsMessage      - typeInfo.isCompoundType(typeName); when True no
                          scalar type is chosen.
      isRepeatedString  - vulkanType.isArrayOfStrings().
      isString          - a string, or a `char` static array.
      lengthInfo        - with a structInfo: the result of
                          vulkanTypeGetStructFieldLengthInfo (dict or None);
                          without: vulkanType.getLengthExpression().
      protobufType      - protobuf scalar name, or None for message types.
      protobufCType     - C type matching protobufType, or None for
                          message types.
      origTypeCategory  - typeInfo.categoryOf(typeName).
      isExtensionStruct - True for a `void` pointer member named pNext.
    """

    # Scalar widths of Vulkan base types. Base types not listed here fall
    # back to "uint64", the same fallback used for uncategorized types.
    _BASE_TYPE_MAPPING = {
        "VkFlags": "uint32",
        "VkFlags64": "uint64",
        "VkBool32": "uint32",
        "VkDeviceSize": "uint64",
        "VkDeviceAddress": "uint64",
        "VkSampleMask": "uint32",
    }

    # Types with no registry category (C primitives and the like).
    _OTHER_TYPE_MAPPING = {
        "void": "uint64",
        "char": "uint8",
        "size_t": "uint64",
        "float": "float",
        "uint8_t": "uint32",
        "uint16_t": "uint32",
        "int32_t": "int32",
        "uint32_t": "uint32",
        "uint64_t": "uint64",
        "VkDeviceSize": "uint64",
        "VkSampleMask": "uint32",
    }

    _PROTOBUF_C_TYPE_MAPPING = {
        "uint8": "uint8_t",
        "uint32": "uint32_t",
        "int32": "int32_t",
        "uint64": "uint64_t",
        "float": "float",
        "string": "const char*",
    }

    def __init__(self, typeInfo, structInfo, vulkanType):
        typeName = vulkanType.typeName

        self.needsMessage = typeInfo.isCompoundType(typeName)
        self.isRepeatedString = vulkanType.isArrayOfStrings()
        self.isString = vulkanType.isString() or (
            typeName == "char" and vulkanType.staticArrExpr != "")

        if structInfo is not None:
            self.lengthInfo = vulkanTypeGetStructFieldLengthInfo(
                structInfo, vulkanType)
        else:
            self.lengthInfo = vulkanType.getLengthExpression()

        self.protobufType = None
        # Set up front so the attribute exists for message types too,
        # which return below before a C type is chosen.
        self.protobufCType = None
        self.origTypeCategory = typeInfo.categoryOf(typeName)

        self.isExtensionStruct = (
            typeName == "void" and
            vulkanType.pointerIndirectionLevels > 0 and
            vulkanType.paramName == "pNext")

        if self.needsMessage:
            return

        category = self.origTypeCategory
        if category in ("enum", "bitmask"):
            self.protobufType = "uint32"
        elif category in ("funcpointer", "handle", "define"):
            self.protobufType = "uint64"
        elif category == "basetype":
            self.protobufType = self._BASE_TYPE_MAPPING.get(typeName, "uint64")
        elif category is None:
            self.protobufType = self._OTHER_TYPE_MAPPING.get(typeName, "uint64")

        # Any other category leaves protobufType as None, which raises
        # KeyError here, as before.
        self.protobufCType = self._PROTOBUF_C_TYPE_MAPPING[self.protobufType]
1158
1159