• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright (c) 2018 The Android Open Source Project
2# Copyright (c) 2018 Google Inc.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16from generator import noneStr
17
18from copy import copy
19from string import whitespace
20
21# Holds information about core Vulkan objects
22# and the API calls that are used to create/destroy each one.
class HandleInfo(object):
    """Associates a Vulkan handle type with the APIs that create/destroy it.

    createApis and destroyApis may each be either a single API name (str) or
    a collection of API names. destroyApis may also be None when no destroy
    API is recorded for the handle.
    """

    def __init__(self, name, createApis, destroyApis):
        self.name = name
        self.createApis = createApis
        self.destroyApis = destroyApis

    @staticmethod
    def _matches(apiName, apis):
        # Returns True if apiName is (one of) the API name(s) in `apis`.
        if apis is None:
            return False
        if isinstance(apis, str):
            # Compare whole names only: `apiName in apis` on a str is a
            # substring test, which would make "vkCreateImage" match
            # "vkCreateImageView".
            return apiName == apis
        return apiName in apis

    def isCreateApi(self, apiName):
        """Return True if apiName is one of this handle's create APIs."""
        return self._matches(apiName, self.createApis)

    def isDestroyApi(self, apiName):
        """Return True if apiName is one of this handle's destroy APIs."""
        return self._matches(apiName, self.destroyApis)
36
# Handle types reported as dispatchable by
# VulkanType.isDispatchableHandleType().
DISPATCHABLE_HANDLE_TYPES = [
    "VkInstance",
    "VkPhysicalDevice",
    "VkDevice",
    "VkQueue",
    "VkCommandBuffer",
]

# Handle types reported as non-dispatchable by
# VulkanType.isNonDispatchableHandleType().
NON_DISPATCHABLE_HANDLE_TYPES = [
    "VkDeviceMemory",
    "VkBuffer",
    "VkBufferView",
    "VkImage",
    "VkImageView",
    "VkShaderModule",
    "VkDescriptorPool",
    "VkDescriptorSetLayout",
    "VkDescriptorSet",
    "VkSampler",
    "VkPipeline",
    "VkPipelineLayout",
    "VkRenderPass",
    "VkFramebuffer",
    "VkPipelineCache",
    "VkCommandPool",
    "VkFence",
    "VkSemaphore",
    "VkEvent",
    "VkQueryPool",
    "VkSamplerYcbcrConversion",
    "VkSamplerYcbcrConversionKHR",
    "VkDescriptorUpdateTemplate",
    "VkSurfaceKHR",
    "VkSwapchainKHR",
    "VkDisplayKHR",
    "VkDisplayModeKHR",
    "VkObjectTableNVX",
    "VkIndirectCommandsLayoutNVX",
    "VkValidationCacheEXT",
    "VkDebugReportCallbackEXT",
    "VkDebugUtilsMessengerEXT",
    "VkAccelerationStructureNV",
    "VkIndirectCommandsLayoutNV",
    "VkAccelerationStructureKHR",
]

# Handle types whose create/destroy APIs do not follow the
# vkCreate<Name> / vkDestroy<Name> naming pattern.
CUSTOM_HANDLE_CREATE_TYPES = [
    "VkPhysicalDevice",
    "VkQueue",
    "VkPipeline",
    "VkDeviceMemory",
    "VkDescriptorSet",
    "VkCommandBuffer",
]

# Every known handle type exactly once, in sorted order.
HANDLE_TYPES = sorted(set(
    DISPATCHABLE_HANDLE_TYPES
    + NON_DISPATCHABLE_HANDLE_TYPES
    + CUSTOM_HANDLE_CREATE_TYPES))
94
# Maps each handle type name to the HandleInfo describing its
# create/destroy APIs.
HANDLE_INFO = {}

# (create APIs, destroy APIs) for the handle types whose API names do not
# follow the vkCreate<Name> / vkDestroy<Name> pattern.
_CUSTOM_HANDLE_APIS = {
    "VkPhysicalDevice": ("vkEnumeratePhysicalDevices", None),
    "VkQueue": ("vkGetDeviceQueue", None),
    "VkPipeline": (
        ["vkCreateGraphicsPipelines", "vkCreateComputePipelines"],
        "vkDestroyPipeline",
    ),
    "VkDeviceMemory": (
        "vkAllocateMemory",
        ["vkFreeMemory", "vkFreeMemorySyncGOOGLE"],
    ),
    "VkDescriptorSet": ("vkAllocateDescriptorSets", "vkFreeDescriptorSets"),
    "VkCommandBuffer": ("vkAllocateCommandBuffers", "vkFreeCommandBuffers"),
}

for h in HANDLE_TYPES:
    if h not in CUSTOM_HANDLE_CREATE_TYPES:
        # Regular handles: derive the API names from the type name by
        # dropping the leading "Vk".
        HANDLE_INFO[h] = HandleInfo(h, "vkCreate" + h[2:], "vkDestroy" + h[2:])
    elif h in _CUSTOM_HANDLE_APIS:
        HANDLE_INFO[h] = HandleInfo(h, *_CUSTOM_HANDLE_APIS[h])
130
# API names singled out for exclusion.
# NOTE(review): nothing in this part of the file reads EXCLUDED_APIS;
# presumably the code generators consult it -- confirm against callers.
EXCLUDED_APIS = [
    "vkEnumeratePhysicalDeviceGroups",
]

# Type names that VulkanTypeInfo.isNonAbiPortableType() always reports as
# ABI-portable, before looking at the type's category.
EXPLICITLY_ABI_PORTABLE_TYPES = [
    "VkResult",
    "VkBool32",
    "VkSampleMask",
    "VkFlags",
    "VkDeviceSize",
]

# Type names that VulkanTypeInfo.isNonAbiPortableType() always reports as
# non-ABI-portable.
EXPLICITLY_ABI_NON_PORTABLE_TYPES = [
    "size_t"
]

# Type categories that VulkanTypeInfo.isNonAbiPortableType() treats as
# non-ABI-portable when a type is in neither explicit list above.
NON_ABI_PORTABLE_TYPE_CATEGORIES = [
    "handle",
    "funcpointer",
]

# XML attribute names that carry device-memory annotations.
# makeVulkanTypeFromXMLTag() records the first of these found on a tag
# (checked in this order) as deviceMemoryAttrib / deviceMemoryVal.
DEVICE_MEMORY_INFO_KEYS = [
    "devicememoryhandle",
    "devicememoryoffset",
    "devicememorysize",
    "devicememorytypeindex",
    "devicememorytypebits",
]

# Transformed types, trivial flavor.
# NOTE(review): the trivial / non-trivial distinction is not used in this
# part of the file; only the combined TRANSFORMED_TYPES list is read here.
TRIVIAL_TRANSFORMED_TYPES = [
    "VkPhysicalDeviceExternalImageFormatInfo",
    "VkPhysicalDeviceExternalBufferInfo",
    "VkExternalMemoryImageCreateInfo",
    "VkExternalMemoryBufferCreateInfo",
    "VkExportMemoryAllocateInfo",
    "VkExternalImageFormatProperties",
    "VkExternalBufferProperties",
]

# Transformed types, non-trivial flavor.
NON_TRIVIAL_TRANSFORMED_TYPES = [
    "VkExternalMemoryProperties",
]

# Names for which isTransformed is set to True on VulkanType,
# VulkanCompoundType and VulkanAPI instances.
TRANSFORMED_TYPES = TRIVIAL_TRANSFORMED_TYPES + NON_TRIVIAL_TRANSFORMED_TYPES
175
176# Holds information about a Vulkan type instance (i.e., not a type definition).
177# Type instances are used as struct field definitions or function parameters,
178# to be later fed to code generation.
179# VulkanType instances can be constructed in two ways:
180# 1. From an XML tag with <type> / <param> tags in vk.xml,
181#    using makeVulkanTypeFromXMLTag
182# 2. User-defined instances with makeVulkanTypeSimple.
class VulkanType(object):
    """One use of a Vulkan type (a type instance, not a type definition).

    Instances describe struct members, API parameters and API return types.
    All fields start at neutral defaults here and are filled in afterwards
    by makeVulkanTypeFromXMLTag or makeVulkanTypeSimple.
    """

    def __init__(self):
        # Enclosing struct / API; assigned by the owner's initCopies().
        self.parent = None
        self.typeName = ""

        # True when typeName is one of TRANSFORMED_TYPES.
        self.isTransformed = False

        self.paramName = None

        # Length expression taken from the XML "len" attribute, if any.
        self.lenExpr = None
        self.isOptional = False
        self.optionalStr = None

        self.isConst = False

        # "" means it's not a static array, otherwise this is the total size of
        # all elements. e.g. staticArrExpr of "x[3][2][8]" will be "((3)*(2)*(8))".
        self.staticArrExpr = ""
        # "" means it's not a static array, otherwise it's the raw expression
        # of static array size, which can be one-dimensional or multi-dimensional.
        self.rawStaticArrExpr = ""

        self.pointerIndirectionLevels = 0  # 0 means not pointer
        self.isPointerToConstPointer = False

        self.primitiveEncodingSize = None

        self.deviceMemoryInfoParameterIndices = None

        # Annotations
        # Environment annotation for binding current
        # variables to sub-structures
        self.binds = {}

        # Device memory annotations

        # self.deviceMemoryAttrib/Val store the device memory info attribute
        # found in the XML (one of DEVICE_MEMORY_INFO_KEYS) and its value.
        self.deviceMemoryAttrib = None
        self.deviceMemoryVal = None

        # Filter annotations
        self.filterVar = None
        self.filterVals = None
        self.filterFunc = None
        self.filterOtherwise = None

        # Stream feature
        self.streamFeature = None

        # All other annotations
        self.attribs = {}

    def __str__(self):
        """Return a one-line debug description of this type instance."""
        typeDesc = (self.typeName + ("*" * self.pointerIndirectionLevels) +
                    ("ptr2constptr" if self.isPointerToConstPointer else ""))
        constDesc = "const" if self.isConst else "nonconst"
        # The last slot used to read self.staticArrCount, an attribute that
        # is never defined on this class, so str() always raised
        # AttributeError. Report the raw static array expression instead.
        return ("(vulkantype %s %s paramName %s len %s optional? %s "
                "staticArrExpr %s %s)") % (
            typeDesc, constDesc, self.paramName, self.lenExpr,
            self.isOptional, self.staticArrExpr, self.rawStaticArrExpr)

    def isString(self):
        """True for a single-level pointer to char."""
        return self.pointerIndirectionLevels == 1 and (self.typeName == "char")

    def isArrayOfStrings(self):
        """True for a pointer-to-const-pointer to char."""
        return self.isPointerToConstPointer and (self.typeName == "char")

    def primEncodingSize(self):
        """Return the primitive encoding size recorded for this type."""
        return self.primitiveEncodingSize

    # Utility functions to make codegen life easier.
    def getLengthExpression(self):
        """Return the "count" expression for this type, if one is known.

        The static array size wins over the dynamic length expression.
        Returns None when neither is set.
        """
        if self.staticArrExpr != "":
            return self.staticArrExpr
        if self.lenExpr:
            return self.lenExpr
        return None

    def accessibleAsPointer(self):
        """Can we just pass this to functions expecting T*?"""
        return self.staticArrExpr != "" or self.pointerIndirectionLevels > 0

    def possiblyOutput(self):
        """Rough guess at whether this could be an output: non-const pointer.

        Good for inferring which things need to be marshaled in versus
        marshaled out for Vulkan API calls.
        """
        return self.pointerIndirectionLevels > 0 and (not self.isConst)

    def isVoidWithNoSize(self):
        """True for plain (non-pointer) void."""
        return self.typeName == "void" and self.pointerIndirectionLevels == 0

    def getCopy(self):
        """Return a shallow copy of this type instance."""
        return copy(self)

    def getTransformed(self, isConstChoice=None, ptrIndirectionChoice=None):
        """Return a shallow copy with constness and/or pointer level replaced.

        Arguments left as None keep the corresponding value of this instance.
        """
        res = self.getCopy()

        if isConstChoice is not None:
            res.isConst = isConstChoice
        if ptrIndirectionChoice is not None:
            res.pointerIndirectionLevels = ptrIndirectionChoice

        return res

    def getWithCustomName(self):
        """Return a copy with one more level of pointer indirection.

        NOTE(review): despite its name this does not touch paramName and is
        identical to getForAddressAccess -- confirm the intent with callers.
        """
        return self.getTransformed(
            ptrIndirectionChoice=self.pointerIndirectionLevels + 1)

    def getForAddressAccess(self):
        """Return a copy with one more level of pointer indirection."""
        return self.getTransformed(
            ptrIndirectionChoice=self.pointerIndirectionLevels + 1)

    def getForValueAccess(self):
        """Return a copy with one less level of pointer indirection.

        A single-level void pointer is viewed as uint8_t first, so the
        result is a plain uint8_t rather than a sizeless void.
        """
        if self.typeName == "void" and self.pointerIndirectionLevels == 1:
            asUint8Type = self.getCopy()
            asUint8Type.typeName = "uint8_t"
            return asUint8Type.getForValueAccess()
        return self.getTransformed(
            ptrIndirectionChoice=self.pointerIndirectionLevels - 1)

    def getForNonConstAccess(self):
        """Return a copy with isConst cleared."""
        return self.getTransformed(isConstChoice=False)

    def withModifiedName(self, newName):
        """Return a shallow copy whose paramName is newName."""
        res = self.getCopy()
        res.paramName = newName
        return res

    def isNextPointer(self):
        """True when this is a member/parameter named "pNext"."""
        return self.paramName == "pNext"

    # Only deals with 'core' handle types here.
    def isDispatchableHandleType(self):
        return self.typeName in DISPATCHABLE_HANDLE_TYPES

    def isNonDispatchableHandleType(self):
        return self.typeName in NON_DISPATCHABLE_HANDLE_TYPES

    def isHandleType(self):
        return self.isDispatchableHandleType() or \
               self.isNonDispatchableHandleType()

    def isCreatedBy(self, api):
        """Return True if `api` (an object with a .name) creates this handle.

        For API names ending in "KHR", the name with that suffix removed is
        also tried against the handle's create APIs.
        """
        if self.typeName in HANDLE_INFO:
            info = HANDLE_INFO[self.typeName]
            if info.isCreateApi(api.name):
                return True
            if len(api.name) > 3 and api.name.endswith("KHR"):
                return info.isCreateApi(api.name[:-3])

        if self.typeName == "VkImage" and api.name == "vkCreateImageWithRequirementsGOOGLE":
            return True

        if self.typeName == "VkBuffer" and api.name == "vkCreateBufferWithRequirementsGOOGLE":
            return True

        return False

    def isDestroyedBy(self, api):
        """Return True if `api` (an object with a .name) destroys this handle.

        For API names ending in "KHR", the name with that suffix removed is
        also tried against the handle's destroy APIs.
        """
        if self.typeName in HANDLE_INFO:
            info = HANDLE_INFO[self.typeName]
            if info.isDestroyApi(api.name):
                return True
            if len(api.name) > 3 and api.name.endswith("KHR"):
                return info.isDestroyApi(api.name[:-3])

        return False

    def isSimpleValueType(self, typeInfo):
        """True for a non-compound, non-string, non-array, non-pointer value.

        typeInfo must provide isCompoundType(typeName).
        """
        if typeInfo.isCompoundType(self.typeName):
            return False
        if self.isString() or self.isArrayOfStrings():
            return False
        if self.staticArrExpr or self.pointerIndirectionLevels > 0:
            return False
        return True

    def getStructEnumExpr(self):
        """Type instances carry no struct enum expression; always None."""
        return None
376
# S-expression atom type and parser.
# Adapted from https://gist.github.com/pib/240957
class Atom(object):
    """A bare (unquoted) symbol appearing in an S-expression."""

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        # An atom prints as its bare symbol text, without quotes.
        return self.name
384
def parse_sexp(sexp):
    """Parse the S-expression text `sexp` into nested Python lists.

    The parser is a character-driven state machine. The type of the object
    on top of `stack` selects the state: a list means "between tokens", a
    str means "inside a double-quoted string", an Atom means "inside a bare
    symbol". Parenthesized groups become lists, quoted strings become str,
    bare symbols become Atom, and atoms starting with a digit are evaluated
    into Python values.

    Returns whatever is on top of the stack when the input is exhausted;
    for input whose tokens are all terminated this is the top-level list
    of parsed expressions.
    """
    # Characters that terminate a bare atom.
    atom_end = set('()"\'') | set(whitespace)
    # stack[0] is the top-level list collecting all parsed expressions.
    stack, i, length = [[]], 0, len(sexp)
    while i < length:
        c = sexp[i]

        reading = type(stack[-1])
        if reading == list:
            if   c == '(': stack.append([])
            elif c == ')':
                # Close the current group and attach it to its parent.
                stack[-2].append(stack.pop())
                # If the parent is a quote form ('x), it is now complete:
                # close it as well.
                if stack[-1][0] == ('quote',): stack[-2].append(stack.pop())
            elif c == '"': stack.append('')
            # A quote opens a list tagged with a ('quote',) marker that
            # is closed as soon as its single operand has been read.
            elif c == "'": stack.append([('quote',)])
            elif c in whitespace: pass
            else: stack.append(Atom(c))
        elif reading == str:
            if   c == '"':
                stack[-2].append(stack.pop())
                if stack[-1][0] == ('quote',): stack[-2].append(stack.pop())
            elif c == '\\':
                # Backslash: take the following character verbatim.
                # NOTE(review): a trailing backslash indexes past the end
                # of `sexp` and raises IndexError.
                i += 1
                stack[-1] += sexp[i]
            else: stack[-1] += c
        elif reading == Atom:
            if c in atom_end:
                atom = stack.pop()
                # NOTE(review): eval() runs on any atom starting with a
                # digit. That is only safe as long as the annotation text
                # is trusted input -- confirm before reusing elsewhere.
                if atom.name[0].isdigit(): stack[-1].append(eval(atom.name))
                else: stack[-1].append(atom)
                if stack[-1][0] == ('quote',): stack[-2].append(stack.pop())
                # Do not advance i: the terminating character still has to
                # be handled in the "between tokens" state.
                continue
            else: stack[-1] = Atom(stack[-1].name + c)
        i += 1

    # NOTE(review): if the input ends in the middle of an atom or string,
    # that unfinished token (not the top-level list) is what gets returned.
    return stack.pop()
420
class FuncExprVal(object):
    """Leaf node of a function-expression tree wrapping one parsed value."""

    def __init__(self, val):
        self.val = val

    def __repr__(self):
        # Display exactly like the wrapped value does.
        return repr(self.val)
426
class FuncExpr(object):
    """Function application node: a function name plus its argument nodes."""

    def __init__(self, name, args):
        self.name = name
        self.args = args

    def __repr__(self):
        # Rendered in S-expression form: "(name)" or "(name arg1 arg2 ...)".
        head = repr(self.name)
        if len(self.args) == 0:
            return "(%s)" % (head)
        renderedArgs = " ".join(repr(arg) for arg in self.args)
        return "(%s %s)" % (head, renderedArgs)
436
class FuncLambda(object):
    """Lambda node: a list of parameters (`vs`) and a body expression."""

    def __init__(self, vs, body):
        self.vs = vs
        self.body = body

    def __repr__(self):
        # Rendered as "(L (param1 param2 ...) body)".
        renderedParams = " ".join(repr(v) for v in self.vs)
        renderedBody = repr(self.body)
        return "(L (%s) %s)" % (renderedParams, renderedBody)
443
class FuncLambdaParam(object):
    """A lambda parameter: its name and its type (`typ`)."""

    def __init__(self, name, typ):
        self.name = name
        self.typ = typ

    def __repr__(self):
        # Rendered as "name : type".
        return " : ".join((str(self.name), str(self.typ)))
450
def parse_func_expr(parsed_sexp):
    """Convert the output of parse_sexp into a function-expression tree.

    parsed_sexp must hold exactly one top-level expression. Lists whose
    head prints as "lambda" become FuncLambda (second element: parameter
    list, third element: body), other lists become FuncExpr (head: name,
    rest: arguments), and everything else becomes FuncExprVal.

    Raises:
        RuntimeError: if parsed_sexp does not hold exactly one expression.
    """
    if len(parsed_sexp) != 1:
        # This used to print the message and then execute a bare `raise`,
        # which (with no active exception) surfaces as an unrelated
        # "No active exception to reraise" RuntimeError. Keep the exception
        # type, but make it carry the actual message.
        raise RuntimeError(
            "Error: parsed # expressions != 1: %d" % (len(parsed_sexp)))

    def parse_lambda_param(param):
        # Each parameter is a (name, type) pair of atoms.
        return FuncLambdaParam(param[0].name, param[1].name)

    def parse_one(exp):
        if not isinstance(exp, list):
            return FuncExprVal(exp)
        if repr(exp[0]) == "lambda":
            return FuncLambda(
                [parse_lambda_param(p) for p in exp[1]], parse_one(exp[2]))
        return FuncExpr(exp[0], [parse_one(arg) for arg in exp[1:]])

    return parse_one(parsed_sexp[0])
471
def parseFilterFuncExpr(expr):
    """Parse a filter-function S-expression string into an expression tree.

    The parsed tree is also printed to stdout.
    """
    tokens = parse_sexp(expr)
    parsed = parse_func_expr(tokens)
    print("parseFilterFuncExpr: parsed %s" % parsed)
    return parsed
476
def parseLetBodyExpr(expr):
    """Parse the body S-expression string of a `let` binding into a tree.

    The parsed tree is also printed to stdout.
    """
    tokens = parse_sexp(expr)
    parsed = parse_func_expr(tokens)
    print("parseLetBodyExpr: parsed %s" % parsed)
    return parsed
481
def makeVulkanTypeFromXMLTag(typeInfo, tag):
    """Build a VulkanType from a member / param / proto XML element.

    Args:
        typeInfo: object providing getPrimitiveEncodingSize(typeName).
        tag: XML element exposing .attrib, .text, .find() and iteration over
            its children (presumably an xml.etree.ElementTree element --
            confirm against the caller).

    Returns:
        The populated VulkanType.
    """
    res = VulkanType()
    attrib = tag.attrib

    # Process the length expression: only the first comma-separated
    # entry of "len" is kept.
    lenAttr = attrib.get("len")
    if lenAttr is not None:
        res.lenExpr = lenAttr.split(",")[0]

    # Calculate static array expression

    nametag = tag.find("name")
    enumtag = tag.find("enum")

    if enumtag is not None:
        res.staticArrExpr = enumtag.text
    elif nametag is not None:
        # The array dimensions, if any, are the text following <name>.
        res.rawStaticArrExpr = noneStr(nametag.tail)

        dimensions = res.rawStaticArrExpr.count('[')
        if dimensions == 1:
            res.staticArrExpr = res.rawStaticArrExpr[1:-1]
        elif dimensions > 1:
            # "[3][2][8]" -> "((3)*(2)*(8))"
            arraySizes = res.rawStaticArrExpr[1:-1].split('][')
            res.staticArrExpr = '(' + \
                '*'.join(f'({size})' for size in arraySizes) + ')'

    # Determine const: look at the text preceding the <type> element.
    if "const" in noneStr(tag.text):
        res.isConst = True

    # Calculate type and pointer info

    for elem in tag:
        if elem.tag == "name":
            res.paramName = elem.text
        if elem.tag == "type":
            afterTypePart = noneStr(elem.tail)
            res.typeName = noneStr(elem.text)

            if res.typeName in TRANSFORMED_TYPES:
                res.isTransformed = True

            # Every '*' after the type name adds one indirection level.
            res.pointerIndirectionLevels += afterTypePart.count("*")
            # A pointer to const pointer is only recognized for exactly
            # two levels of indirection.
            if "const" in afterTypePart and res.pointerIndirectionLevels == 2:
                res.isPointerToConstPointer = True

    # Calculate optionality (based on validitygenerator.py)
    optionalAttr = attrib.get("optional")
    if optionalAttr is not None:
        res.isOptional = True
        res.optionalStr = optionalAttr

    # If no validity is being generated, it usually means that
    # validity is complex and not absolute, so let's say yes.
    if attrib.get("noautovalidity") is not None:
        res.isOptional = True

    # If this is a structure extension, it is optional.
    if attrib.get("structextends") is not None:
        res.isOptional = True

    # If this is a pNext pointer, it is optional.
    if res.paramName == "pNext":
        res.isOptional = True

    res.primitiveEncodingSize = \
        typeInfo.getPrimitiveEncodingSize(res.typeName)

    # Annotations: Environment binds, written as "name:value, name:value".
    bindsAttr = attrib.get("binds")
    if bindsAttr is not None:
        pairs = (pair.strip().split(":") for pair in bindsAttr.split(","))
        res.binds = {parts[0].strip(): parts[1].strip() for parts in pairs}

    # Annotations: Device memory. Only the first matching key is recorded.
    for k in DEVICE_MEMORY_INFO_KEYS:
        if attrib.get(k) is not None:
            res.deviceMemoryAttrib = k
            res.deviceMemoryVal = attrib.get(k)
            break

    # Annotations: Filters
    filterVarAttr = attrib.get("filterVar")
    if filterVarAttr is not None:
        res.filterVar = filterVarAttr.strip()

    filterValsAttr = attrib.get("filterVals")
    if filterValsAttr is not None:
        res.filterVals = [v.strip() for v in filterValsAttr.strip().split(",")]
        print("Filtervals: %s" % res.filterVals)

    filterFuncAttr = attrib.get("filterFunc")
    if filterFuncAttr is not None:
        res.filterFunc = parseFilterFuncExpr(filterFuncAttr)

    filterOtherwiseAttr = attrib.get("filterOtherwise")
    if filterOtherwiseAttr is not None:
        # This value used to be stored only on the misspelled attribute
        # `Otherwise`, leaving the declared `filterOtherwise` field None.
        res.filterOtherwise = filterOtherwiseAttr
        # Legacy spelling kept so code reading `.Otherwise` keeps working.
        res.Otherwise = filterOtherwiseAttr

    # store all other attribs here
    res.attribs = dict(attrib)

    return res
597
598
def makeVulkanTypeSimple(isConst,
                         typeName,
                         ptrIndirectionLevels,
                         paramName=None):
    """Build a user-defined VulkanType without going through XML.

    Only constness, type name, pointer level and (optionally) the parameter
    name are taken from the arguments. The type is never marked as a
    pointer-to-const-pointer and has no primitive encoding size.
    """
    simpleType = VulkanType()

    # Caller-provided description.
    simpleType.paramName = paramName
    simpleType.typeName = typeName
    simpleType.isConst = isConst
    simpleType.pointerIndirectionLevels = ptrIndirectionLevels

    # Fixed values for simple types.
    simpleType.isPointerToConstPointer = False
    simpleType.primitiveEncodingSize = None

    return simpleType
613
614# A class for holding the parameter indices corresponding to various
615# attributes about a VkDeviceMemory, such as the handle, size, offset, etc.
class DeviceMemoryInfoParameterIndices(object):
    """Parameter/member indices of the pieces describing one VkDeviceMemory.

    Each field holds the index of the parameter carrying that piece of
    information, or None when no parameter carries it.
    """

    def __init__(self, handle, offset, size, typeIndex, typeBits):
        # Where the memory lives.
        self.handle = handle
        self.offset = offset
        self.size = size
        # Memory type information.
        self.typeIndex = typeIndex
        self.typeBits = typeBits
623
624# initializes DeviceMemoryInfoParameterIndices for each
625# abstract VkDeviceMemory encountered over |parameters|
def initDeviceMemoryInfoParameterIndices(parameters):
    """Collect device-memory annotations found on `parameters`.

    Returns a dict mapping each annotated device-memory id
    (deviceMemoryVal) to a DeviceMemoryInfoParameterIndices whose fields
    hold the indices of the parameters annotated for that id, or None when
    no parameter carries a device-memory annotation.
    """
    # (annotation attribute, field on DeviceMemoryInfoParameterIndices)
    attribToField = (
        ("devicememoryhandle", "handle"),
        ("devicememoryoffset", "offset"),
        ("devicememorysize", "size"),
        ("devicememorytypeindex", "typeIndex"),
        ("devicememorytypebits", "typeBits"),
    )

    # Pass 1: create one blank record per annotated device-memory id.
    infoById = {}
    for param in parameters:
        attrib = param.deviceMemoryAttrib
        if attrib and attrib in DEVICE_MEMORY_INFO_KEYS:
            infoById[param.deviceMemoryVal] = \
                DeviceMemoryInfoParameterIndices(
                    None, None, None, None, None)

    # Pass 2: record which parameter index supplies which field.
    for index, param in enumerate(parameters):
        attrib = param.deviceMemoryAttrib
        if not attrib:
            continue

        record = infoById[param.deviceMemoryVal]
        for key, field in attribToField:
            if attrib == key:
                setattr(record, field, index)

    return infoById if infoById else None
666
667# Classes for describing aggregate types (unions, structs) and API calls.
class VulkanCompoundType(object):
    """An aggregate Vulkan type (struct or union) and its members.

    Args:
        name: type name; stored as both `name` and `typeName`.
        members: list of VulkanType, one per member.
        isUnion: True for unions, False for structs.
        structEnumExpr: struct enum expression for the type, if any.
        structExtendsExpr: value of the "structextends" attribute, if any.
        feature: feature that was active when the type was generated.
        initialEnv: environment dict stored as `environment`. When omitted,
            every instance gets its own fresh empty dict.
        optional: stored as `optionalStr`.
    """

    def __init__(self, name, members, isUnion=False, structEnumExpr=None, structExtendsExpr=None, feature=None, initialEnv=None, optional=None):
        self.name = name
        self.typeName = name
        self.members = members
        # The default used to be a literal `{}`, i.e. one dict object shared
        # by every instance constructed without initialEnv; mutating one
        # instance's environment then leaked into all the others.
        self.environment = {} if initialEnv is None else initialEnv

        self.isUnion = isUnion
        self.structEnumExpr = structEnumExpr
        self.structExtendsExpr = structExtendsExpr
        self.feature = feature

        self.deviceMemoryInfoParameterIndices = \
            initDeviceMemoryInfoParameterIndices(self.members)

        self.isTransformed = name in TRANSFORMED_TYPES

        self.copy = None

        self.optionalStr = optional

    def initCopies(self):
        """Point `copy` at this object and make it the parent of all members."""
        self.copy = self

        for m in self.members:
            m.parent = self.copy

    def getMember(self, memberName):
        """Return the first member whose paramName is memberName, or None."""
        for m in self.members:
            if m.paramName == memberName:
                return m
        return None

    def getStructEnumExpr(self):
        """Return the struct enum expression given at construction."""
        return self.structEnumExpr
704
class VulkanAPI(object):
    """A Vulkan API call: its name, return type and parameters.

    `origName` keeps the given original name, falling back to `name` when
    no (truthy) origName is supplied.
    """

    def __init__(self, name, retType, parameters, origName=None):
        self.name = name
        self.origName = origName if origName else name
        self.retType = retType
        self.parameters = parameters

        self.deviceMemoryInfoParameterIndices = \
            initDeviceMemoryInfoParameterIndices(parameters)

        self.copy = None

        self.isTransformed = name in TRANSFORMED_TYPES

    def initCopies(self):
        """Point `copy` at this object and make it the parent of all parameters."""
        self.copy = self

        for param in self.parameters:
            param.parent = self

    def getCopy(self):
        """Return a shallow copy of this API."""
        return copy(self)

    def getParameter(self, parameterName):
        """Return the first parameter named parameterName, or None."""
        matches = (p for p in self.parameters if p.paramName == parameterName)
        return next(matches, None)

    def withModifiedName(self, newName):
        """Return a new VulkanAPI named newName with the same return type
        and parameters. The new object's origName is newName as well."""
        return VulkanAPI(newName, self.retType, self.parameters)

    def getRetVarExpr(self):
        """Return the name of the variable holding the return value, or
        None for APIs returning void."""
        retTypeName = self.retType.typeName
        if retTypeName == "void":
            return None
        return "%s_%s_return" % (self.name, retTypeName)

    def getRetTypeExpr(self):
        """Return the name of the return type."""
        return self.retType.typeName

    def withCustomParameters(self, customParams):
        """Return a shallow copy using customParams as its parameters."""
        modified = self.getCopy()
        modified.parameters = customParams
        return modified

    def withCustomReturnType(self, retType):
        """Return a shallow copy using retType as its return type."""
        modified = self.getCopy()
        modified.retType = retType
        return modified
759
760# Whether or not special handling of virtual elements
761# such as VkDeviceMemory is needed.
def vulkanTypeNeedsTransform(structOrApi):
    """True when the struct/API has device-memory parameter indices recorded."""
    indices = structOrApi.deviceMemoryInfoParameterIndices
    return indices is not None
764
def vulkanTypeGetNeededTransformTypes(structOrApi):
    """List the kinds of transform the struct/API needs.

    The list holds "devicememory" when device-memory parameter indices are
    recorded on structOrApi, and is empty otherwise.
    """
    if structOrApi.deviceMemoryInfoParameterIndices is None:
        return []
    return ["devicememory"]
770
def vulkanTypeforEachSubType(structOrApi, f):
    """Call f(index, subType) for each member or parameter of structOrApi.

    Args:
        structOrApi: a VulkanCompoundType (its members are visited) or a
            VulkanAPI (its parameters are visited); subclasses are accepted.
        f: callable taking the position and the VulkanType at that position.

    Raises:
        TypeError: if structOrApi is neither a VulkanCompoundType nor a
            VulkanAPI. (This used to surface as an opaque
            "'NoneType' object is not iterable" TypeError.)
    """
    if isinstance(structOrApi, VulkanCompoundType):
        subTypes = structOrApi.members
    elif isinstance(structOrApi, VulkanAPI):
        subTypes = structOrApi.parameters
    else:
        raise TypeError(
            "expected a VulkanCompoundType or VulkanAPI, got %s" %
            type(structOrApi).__name__)

    for (i, x) in enumerate(subTypes):
        f(i, x)
780
781# Parses everything about Vulkan types into a Python readable format.
class VulkanTypeInfo(object):
    """Collects everything known about Vulkan types in Python-readable form.

    The on* methods record types, structs/unions, enums and commands as they
    are reported; the remaining methods answer queries about what has been
    recorded so far.
    """

    # Encoding sizes, in bytes, of the non-Vulkan base types.
    _BASE_ENCODING_SIZES = {
        "void" : 8,
        "char" : 1,
        "float" : 4,
        "uint8_t" : 1,
        "uint16_t" : 2,
        "uint32_t" : 4,
        "uint64_t" : 8,
        "size_t" : 8,
        "ssize_t" : 8,
    }

    # Categories whose values are streamed as 8 bytes.
    # Handles are pointers and basetype includes VkDeviceSize, which is
    # 8 bytes. Most of the time enums are only 4 bytes, but this is vague
    # enough to be the source of a future headache, and it's easy to just
    # stream 8 bytes there anyway.
    _EIGHT_BYTE_CATEGORIES = ("handle", "basetype", "funcpointer", "enum")

    def __init__(self):
        self.categories = set()

        # Tracks what Vulkan type is part of what category.
        self.typeCategories = {}

        # Tracks the primitive encoding size for each type,
        # if applicable.
        self.encodingSizes = {}

        # Struct/union name -> VulkanCompoundType.
        self.structs = {}
        # API name -> VulkanAPI.
        self.apis = {}

        # Feature currently being processed, or None outside of one.
        self.feature = None

    def initType(self, name, category):
        """Record `name` under `category` and compute its encoding size."""
        self.categories.add(category)
        self.typeCategories[name] = category
        self.encodingSizes[name] = self.setPrimitiveEncodingSize(name)

    def categoryOf(self, name):
        """Return the category recorded for `name` (KeyError if unknown)."""
        return self.typeCategories[name]

    def getPrimitiveEncodingSize(self, name):
        """Return the encoding size recorded for `name` (KeyError if unknown)."""
        return self.encodingSizes[name]

    # Queries relating to categories of Vulkan types.
    def isHandleType(self, name):
        """True if `name` is a recorded type of category "handle"."""
        return self.typeCategories.get(name) == "handle"

    def isCompoundType(self, name):
        """True if `name` is a recorded struct or union."""
        return self.typeCategories.get(name) in ("struct", "union")

    def setPrimitiveEncodingSize(self, name):
        """Return the best size in bytes for encoding/decoding type `name`.

        Despite its name this only computes the size; initType stores it.
        Returns None if not applicable. Types that are not base types must
        already have a recorded category.
        """
        if name in self._BASE_ENCODING_SIZES:
            return self._BASE_ENCODING_SIZES[name]

        category = self.typeCategories[name]
        if category in self._EIGHT_BYTE_CATEGORIES:
            return 8

        # Everything else (None, "api", "bitmask", "include", "define",
        # "struct", "union", ...) has no primitive encoding.
        return None

    def isNonAbiPortableType(self, typeName):
        """True if `typeName` is not ABI-portable.

        The explicit portable / non-portable lists are checked first; after
        that the decision is made from the type's recorded category.
        """
        if typeName in EXPLICITLY_ABI_PORTABLE_TYPES:
            return False

        if typeName in EXPLICITLY_ABI_NON_PORTABLE_TYPES:
            return True

        return self.typeCategories[typeName] in NON_ABI_PORTABLE_TYPE_CATEGORIES

    def onBeginFeature(self, featureName):
        self.feature = featureName

    def onEndFeature(self):
        self.feature = None

    def onGenType(self, typeinfo, name, alias):
        """Record a generated type; structs and unions are parsed further."""
        category = typeinfo.elem.get("category")
        self.initType(name, category)

        if category in ("struct", "union"):
            self.onGenStruct(typeinfo, name, alias)

    @staticmethod
    def _splitPairs(text):
        # "a : b, c : d" -> [("a", "b"), ("c", "d")]
        return [tuple(part.strip() for part in item.split(":"))
                for item in text.split(",")]

    def onGenStruct(self, typeinfo, typeName, alias):
        """Record a struct/union as a VulkanCompoundType; aliases are skipped."""
        if alias:
            return

        elem = typeinfo.elem
        structExtendsExpr = elem.get("structextends")
        structEnumExpr = None
        members = []
        initialEnv = {}

        # "exists": variables declared as name:type pairs, not yet bound.
        existsAttr = elem.get("exists")
        if existsAttr is not None:
            for (varName, varType) in self._splitPairs(existsAttr):
                initialEnv[varName] = {
                    "type" : varType,
                    "binding" : None,
                    "structmember" : False,
                    "body" : None,
                }

        # "let": uint32_t variables defined as name:body pairs.
        letAttr = elem.get("let")
        if letAttr is not None:
            for (varName, body) in self._splitPairs(letAttr):
                initialEnv[varName] = {
                    "type" : "uint32_t",
                    "binding" : varName,
                    "structmember" : False,
                    "body" : parseLetBodyExpr(body)
                }

        # Every member is both a struct member and an environment entry.
        for member in elem.findall(".//member"):
            vulkanType = makeVulkanTypeFromXMLTag(self, member)
            initialEnv[vulkanType.paramName] = {
                "type" : vulkanType.typeName,
                "binding" : vulkanType.paramName,
                "structmember" : True,
                "body" : None,
            }
            members.append(vulkanType)
            # The VkStructureType member's "values" attribute is used as
            # the struct enum expression.
            if vulkanType.typeName == "VkStructureType" and \
               member.get("values"):
                structEnumExpr = member.get("values")

        compound = VulkanCompoundType(
            typeName,
            members,
            isUnion = self.categoryOf(typeName) == "union",
            structEnumExpr = structEnumExpr,
            structExtendsExpr = structExtendsExpr,
            feature = self.feature,
            initialEnv = initialEnv,
            optional = elem.get("optional", None))
        self.structs[typeName] = compound
        compound.initCopies()

    def onGenGroup(self, _groupinfo, groupName, _alias=None):
        self.initType(groupName, "enum")

    def onGenEnum(self, _enuminfo, name, _alias):
        self.initType(name, "enum")

    def onGenCmd(self, cmdinfo, name, _alias):
        """Record an API call built from its <proto> and <param> elements."""
        self.initType(name, "api")

        proto = cmdinfo.elem.find("proto")
        params = cmdinfo.elem.findall("param")

        api = VulkanAPI(
            name,
            makeVulkanTypeFromXMLTag(self, proto),
            [makeVulkanTypeFromXMLTag(self, p) for p in params])
        self.apis[name] = api
        api.initCopies()

    def onEnd(self):
        pass
961
def hasNullOptionalStringFeature(forEachType):
    """Return True only if forEachType has all three null-optional-string hooks.

    The hooks checked (by attribute presence) are
    onCheckWithNullOptionalStringFeature,
    endCheckWithNullOptionalStringFeature and
    finalCheckWithNullOptionalStringFeature.
    """
    requiredHooks = (
        "onCheckWithNullOptionalStringFeature",
        "endCheckWithNullOptionalStringFeature",
        "finalCheckWithNullOptionalStringFeature",
    )
    return all(hasattr(forEachType, hook) for hook in requiredHooks)
966
967# General function to iterate over a vulkan type and call code that processes
968# each of its sub-components, if any.
def iterateVulkanType(typeInfo, vulkanType, forEachType):
    """Dispatch vulkanType to the matching handler on forEachType.

    Returns False, without calling anything on forEachType, when the type
    is a pointer-to-const-pointer that is not an array of strings.
    Otherwise registers typeInfo on forEachType, invokes exactly one of the
    handler families below, and returns True.

    Handlers are tried in this order: compound type (when it is not the
    next pointer), string, array of strings, static array, struct
    extension (next pointer), pointer, plain value. For optional pointer
    types that are not the next pointer, the handler is bracketed by
    onCheck/endCheck. Strings additionally use the three
    *WithNullOptionalStringFeature hooks when forEachType provides them.
    """
    if not vulkanType.isArrayOfStrings() and vulkanType.isPointerToConstPointer:
        return False

    forEachType.registerTypeInfo(typeInfo)

    needCheck = (
        vulkanType.isOptional
        and vulkanType.pointerIndirectionLevels > 0
        and not vulkanType.isNextPointer()
    )

    def visitGuarded(handlerName):
        # Bracket the handler with onCheck/endCheck when a check is needed.
        # The handler is looked up by name only after onCheck has run, so
        # the sequence of attribute accesses matches a straight-line call.
        if needCheck:
            forEachType.onCheck(vulkanType)
        getattr(forEachType, handlerName)(vulkanType)
        if needCheck:
            forEachType.endCheck(vulkanType)

    if typeInfo.isCompoundType(vulkanType.typeName) and \
       not vulkanType.isNextPointer():
        visitGuarded("onCompoundType")
        return True

    if vulkanType.isString():
        if needCheck and hasNullOptionalStringFeature(forEachType):
            # onString is emitted once per side of the split made by the
            # end-check hook.
            forEachType.onCheckWithNullOptionalStringFeature(vulkanType)
            forEachType.onString(vulkanType)
            forEachType.endCheckWithNullOptionalStringFeature(vulkanType)
            forEachType.onString(vulkanType)
            forEachType.finalCheckWithNullOptionalStringFeature(vulkanType)
        else:
            visitGuarded("onString")
        return True

    if vulkanType.isArrayOfStrings():
        forEachType.onStringArray(vulkanType)
    elif vulkanType.staticArrExpr:
        forEachType.onStaticArr(vulkanType)
    elif vulkanType.isNextPointer():
        visitGuarded("onStructExtension")
    elif vulkanType.pointerIndirectionLevels > 0:
        visitGuarded("onPointer")
    else:
        forEachType.onValue(vulkanType)

    return True
1035
class VulkanTypeIterator(object):
    """Holder for the type-info object supplied via registerTypeInfo.

    typeInfo starts out as None and is overwritten on every
    registerTypeInfo call.
    """

    def __init__(self):
        # No type info is known until registerTypeInfo is called.
        self.typeInfo = None

    def registerTypeInfo(self, typeInfo):
        """Store typeInfo on the instance, replacing any previous value."""
        self.typeInfo = typeInfo
1042
def vulkanTypeGetStructFieldLengthInfo(structInfo, vulkanType):
    """Describe how to compute the element count of a struct member.

    Args:
        structInfo: object whose `name` attribute is the struct's name.
        vulkanType: the member; its `paramName` and
            `getLengthExpression()` are consulted.

    Returns:
        None when the member is not special-cased and has no length
        expression. Otherwise a dict with the keys:
            "structName":  name of the enclosing struct,
            "field":       name of the member,
            "lenExpr":     raw length expression,
            "postprocess": callable mapping an expression string to the
                           expression for the actual element count.
    """
    def identity(expr):
        return expr

    # Members whose length attribute is not directly an element count,
    # keyed by (struct name, member name).
    specialCases = {
        # Per the Vulkan spec, codeSize is in bytes while pCode points to
        # 32-bit words.
        ("VkShaderModuleCreateInfo", "pCode"):
            ("codeSize", lambda expr: "(%s / 4)" % expr),
        # Per the Vulkan spec, pSampleMask holds one 32-bit word per 32
        # samples, rounded up.
        ("VkPipelineMultisampleStateCreateInfo", "pSampleMask"):
            ("rasterizationSamples",
             lambda expr: "(((%s) + 31) / 32)" % expr),
    }

    key = (structInfo.name, vulkanType.paramName)

    if key in specialCases:
        lenExpr, postprocess = specialCases[key]
    else:
        lenExpr = vulkanType.getLengthExpression()
        if lenExpr is None:
            return None
        postprocess = identity

    return {
        "structName": structInfo.name,
        # "field" is the member name, as in the special-case table; it
        # used to carry the member's type name for the generic case.
        "field": vulkanType.paramName,
        "lenExpr": lenExpr,
        "postprocess": postprocess,
    }
1084
class VulkanTypeProtobufInfo(object):
    """Protobuf-related facts derived from a single Vulkan type.

    Attributes set by the constructor:
        needsMessage: whether typeInfo reports the type as compound.
        isRepeatedString: whether the type is an array of strings.
        isString: whether the type is a string, or a `char` whose
            staticArrExpr is not the empty string.
        lengthInfo: struct-field length info when structInfo is given,
            otherwise the type's own length expression.
        origTypeCategory: typeInfo's category for the type name.
        isExtensionStruct: whether this is a `void` pointer named pNext.
        protobufType: scalar protobuf type name; stays None for compound
            types.
        protobufCType: C type matching protobufType; not set at all for
            compound types.
    """

    def __init__(self, typeInfo, structInfo, vulkanType):
        typeName = vulkanType.typeName

        self.needsMessage = typeInfo.isCompoundType(typeName)
        self.isRepeatedString = vulkanType.isArrayOfStrings()
        self.isString = vulkanType.isString() or (
            typeName == "char" and vulkanType.staticArrExpr != "")

        if structInfo is None:
            self.lengthInfo = vulkanType.getLengthExpression()
        else:
            self.lengthInfo = vulkanTypeGetStructFieldLengthInfo(
                structInfo, vulkanType)

        self.protobufType = None
        self.origTypeCategory = typeInfo.categoryOf(typeName)

        self.isExtensionStruct = (
            typeName == "void"
            and vulkanType.pointerIndirectionLevels > 0
            and vulkanType.paramName == "pNext")

        # Compound types have no scalar protobuf type, so stop here.
        if self.needsMessage:
            return

        category = self.origTypeCategory

        if category in ("enum", "bitmask"):
            self.protobufType = "uint32"
        elif category in ("funcpointer", "handle", "define"):
            self.protobufType = "uint64"
        elif category == "basetype":
            # A basetype missing from this table raises KeyError.
            self.protobufType = {
                "VkFlags": "uint32",
                "VkBool32": "uint32",
                "VkDeviceSize": "uint64",
                "VkSampleMask": "uint32",
            }[typeName]
        elif category is None:
            # Uncategorized names not listed here fall back to uint64.
            self.protobufType = {
                "void": "uint64",
                "char": "uint8",
                "size_t": "uint64",
                "float": "float",
                "uint8_t": "uint32",
                "uint16_t": "uint32",
                "int32_t": "int32",
                "uint32_t": "uint32",
                "uint64_t": "uint64",
                "VkDeviceSize": "uint64",
                "VkSampleMask": "uint32",
            }.get(typeName, "uint64")

        # Raises KeyError if no protobuf type was chosen above.
        self.protobufCType = {
            "uint8": "uint8_t",
            "uint32": "uint32_t",
            "int32": "int32_t",
            "uint64": "uint64_t",
            "float": "float",
            "string": "const char*",
        }[self.protobufType]
1157
1158