• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2023 Google LLC
2# SPDX-License-Identifier: MIT
3from .vulkantypes import VulkanType, VulkanTypeInfo, VulkanCompoundType, VulkanAPI
4from collections import OrderedDict
5from copy import copy
6from pathlib import Path, PurePosixPath
7
8import os
9import sys
10import shutil
11import subprocess
12import re
13
14# Class capturing a single file
15
16
class SingleFileModule(object):
    """A single generated output file, written between begin() and end().

    The output path is ``<dir>/<basename><suffix>`` where ``<dir>`` is
    ``customAbsDir`` when given, otherwise ``<globalDir>/<directory>`` with
    ``globalDir`` supplied to begin(). When ``suppress`` is true, begin(),
    append() and end() are all no-ops and no file is touched.
    """

    def __init__(self, suffix, directory, basename, customAbsDir=None, suppress=False):
        self.directory = directory
        self.basename = basename
        self.customAbsDir = customAbsDir
        self.suffix = suffix
        # File object opened by begin(); stays None until then.
        self.file = None

        # Text written immediately after opening / immediately before closing.
        self.preamble = ""
        self.postamble = ""

        self.suppress = suppress

    def begin(self, globalDir):
        """Open the output file for writing (UTF-8) and emit the preamble.

        Args:
            globalDir: root directory that ``self.directory`` is joined onto;
                ignored when ``customAbsDir`` was given.
        """
        if self.suppress:
            return

        if self.customAbsDir:
            absDir = self.customAbsDir
        else:
            absDir = os.path.join(globalDir, self.directory)

        # Create subdirectory, if needed. An empty path means "current
        # directory", which os.makedirs() would reject, so skip it then.
        if absDir:
            os.makedirs(absDir, exist_ok=True)

        filename = os.path.join(absDir, self.basename)

        self.file = open(filename + self.suffix, "w", encoding="utf-8")
        self.file.write(self.preamble)

    def append(self, toAppend):
        """Write ``toAppend`` to the open file; begin() must have been called."""
        if self.suppress:
            return

        self.file.write(toAppend)

    def end(self):
        """Emit the postamble and close the file.

        ``self.file`` is intentionally left referencing the closed file
        object so its ``name`` can still be read afterwards.
        """
        if self.suppress:
            return

        self.file.write(self.postamble)
        self.file.close()

    def getMakefileSrcEntry(self):
        """Return this file's Makefile source-list entry (none by default)."""
        return ""

    def getCMakeSrcEntry(self):
        """Return this file's CMake source-list entry (none by default)."""
        return ""
63
64# Class capturing a .cpp file and a .h file (a "C++ module")
65
66
class Module(object):
    """A generated C++ module: a ``.h`` and a ``.cpp`` file sharing one basename.

    Each file is backed by its own SingleFileModule. ``implOnly`` suppresses
    the header, ``headerOnly`` suppresses the implementation, and ``suppress``
    suppresses both. After end() every file that was actually written has its
    empty ``#ifdef``/``#endif`` pairs stripped and is run through clang-format
    (unless the GFXSTREAM_NO_CLANG_FMT environment variable is set).
    """

    # An #ifdef directly followed by its #endif, each optionally trailed by a
    # // comment. Such blocks are mainly introduced by extensions with no
    # functions or variables.
    _EMPTY_IFDEF_PATTERN = r"#ifdef\s+(\w+)\s*(?://.*)?\s*\n\s*#endif\s*(?://.*)?\s*"

    def __init__(
            self, directory, basename, customAbsDir=None, suppress=False, implOnly=False,
            headerOnly=False, suppressFeatureGuards=False):
        self._headerFileModule = SingleFileModule(
            ".h", directory, basename, customAbsDir, suppress or implOnly)
        self._implFileModule = SingleFileModule(
            ".cpp", directory, basename, customAbsDir, suppress or headerOnly)

        self._headerOnly = headerOnly
        self._implOnly = implOnly

        self.directory = directory
        self.basename = basename
        self._customAbsDir = customAbsDir

        # Only stored here; not read anywhere else in this class.
        self.suppressFeatureGuards = suppressFeatureGuards

    @property
    def suppress(self):
        # The two files can have different suppression states, so there is no
        # single value to report.
        raise AttributeError("suppress is write only")

    @suppress.setter
    def suppress(self, value: bool):
        # implOnly / headerOnly keep their file suppressed regardless of value.
        self._headerFileModule.suppress = self._implOnly or value
        self._implFileModule.suppress = self._headerOnly or value

    @property
    def headerPreamble(self) -> str:
        return self._headerFileModule.preamble

    @headerPreamble.setter
    def headerPreamble(self, value: str):
        self._headerFileModule.preamble = value

    @property
    def headerPostamble(self) -> str:
        return self._headerFileModule.postamble

    @headerPostamble.setter
    def headerPostamble(self, value: str):
        self._headerFileModule.postamble = value

    @property
    def implPreamble(self) -> str:
        return self._implFileModule.preamble

    @implPreamble.setter
    def implPreamble(self, value: str):
        self._implFileModule.preamble = value

    @property
    def implPostamble(self) -> str:
        return self._implFileModule.postamble

    @implPostamble.setter
    def implPostamble(self, value: str):
        self._implFileModule.postamble = value

    def getMakefileSrcEntry(self):
        """Return the ``.cpp`` line for a Makefile source list (with line continuation)."""
        if self._customAbsDir:
            return self.basename + ".cpp \\\n"
        joined = os.path.join(self.directory, self.basename)
        return "    " + joined + ".cpp \\\n"

    def getCMakeSrcEntry(self):
        """Return the ``.cpp`` entry for a CMake source list, always with ``/`` separators."""
        if self._customAbsDir:
            return "\n" + self.basename + ".cpp "
        # PurePosixPath keeps forward slashes even when generating on Windows.
        joined = PurePosixPath(Path(self.directory) / Path(self.basename))
        return "\n    " + str(joined) + ".cpp "

    def begin(self, globalDir):
        """Open both files (those not suppressed) under ``globalDir``."""
        self._headerFileModule.begin(globalDir)
        self._implFileModule.begin(globalDir)

    def appendHeader(self, toAppend):
        self._headerFileModule.append(toAppend)

    def appendImpl(self, toAppend):
        self._implFileModule.append(toAppend)

    @staticmethod
    def _removeEmptyIfdefs(filename: Path):
        """Rewrite ``filename`` in place with empty #ifdef blocks removed."""
        # The files are written as UTF-8 by SingleFileModule, so read and
        # write them back as UTF-8 rather than with the locale default.
        with open(filename, 'r', encoding="utf-8") as file:
            content = file.read()

        modified_content = re.sub(Module._EMPTY_IFDEF_PATTERN, "", content)

        with open(filename, 'w', encoding="utf-8") as file:
            file.write(modified_content)

    @staticmethod
    def _formatFile(filename: Path):
        """Run ``clang-format -i --style=file`` on ``filename``.

        Does nothing when GFXSTREAM_NO_CLANG_FMT is set in the environment.

        Raises:
            RuntimeError: clang-format is not on PATH, or it exited non-zero.
        """
        if "GFXSTREAM_NO_CLANG_FMT" in os.environ:
            return
        clang_format_command = shutil.which('clang-format')
        if clang_format_command is None:
            raise RuntimeError(
                "clang-format was not found on PATH "
                "(set GFXSTREAM_NO_CLANG_FMT to skip formatting)")
        # Deliberately not wrapped in `assert`: asserts are stripped under
        # `python -O`, which would silently skip the formatter altogether.
        status = subprocess.call(
            [clang_format_command, "-i", "--style=file", str(filename.resolve())])
        if status != 0:
            raise RuntimeError(
                "clang-format exited with status %d for %s" % (status, filename))

    def end(self):
        """Close both files, then post-process each one that was written."""
        self._headerFileModule.end()
        self._implFileModule.end()

        # Header first, then implementation; suppressed files were never opened.
        for fileModule in (self._headerFileModule, self._implFileModule):
            if fileModule.suppress:
                continue
            filename = Path(fileModule.file.name)
            self._removeEmptyIfdefs(filename)
            self._formatFile(filename)
194
195
class PyScript(SingleFileModule):
    """A generated Python script: a SingleFileModule with the ``.py`` suffix."""

    def __init__(self, directory, basename, customAbsDir=None, suppress=False):
        super().__init__(
            suffix=".py",
            directory=directory,
            basename=basename,
            customAbsDir=customAbsDir,
            suppress=suppress)
199
200
201# Class capturing a .proto protobuf definition file
class Proto(SingleFileModule):
    """A generated protobuf definition: a SingleFileModule with the ``.proto`` suffix.

    Unlike the base class, it contributes entries to the Makefile and CMake
    source lists.
    """

    def __init__(self, directory, basename, customAbsDir=None, suppress=False):
        super().__init__(".proto", directory, basename, customAbsDir, suppress)

    def getMakefileSrcEntry(self):
        """Return the ``.proto`` line for a Makefile source list (with line continuation)."""
        if self.customAbsDir:
            return self.basename + ".proto \\\n"
        joined = os.path.join(self.directory, self.basename)
        return "    " + joined + ".proto \\\n"

    def getCMakeSrcEntry(self):
        """Return the ``.proto`` entry for a CMake source list, always with ``/`` separators."""
        if self.customAbsDir:
            return "\n" + self.basename + ".proto "

        # Same construction as Module.getCMakeSrcEntry: PurePosixPath keeps
        # forward slashes even when the generator runs on Windows, where
        # os.path.join would produce backslashes.
        joined = PurePosixPath(Path(self.directory) / Path(self.basename))
        return "\n    " + str(joined) + ".proto "
225
class CodeGen(object):
    """Accumulates generated C/C++ source text with block-aware indentation.

    The emitting methods (stmt, line, beginIf, funcCall, ...) append to
    ``self.code`` at the current indentation level. The ``make*`` methods
    return strings without appending (makeFuncImpl additionally resets
    ``self.code``). swapCode() hands back the accumulated text and clears
    the buffer.
    """

    # Encoding size in bytes -> unsigned C type used as staging storage.
    _STORAGE_TYPE_BY_SIZE = {
        1: "uint8_t",
        2: "uint16_t",
        4: "uint32_t",
        8: "uint64_t",
    }

    # Encoding size in bytes -> suffix of the stream put/get/to/from method.
    _STREAM_SUFFIX_BY_SIZE = {
        1: "Byte",
        2: "Be16",
        4: "Be32",
        8: "Be64",
    }

    # (struct name, member name) -> (length member, postprocess). Length
    # expressions for these members are not a plain member name (mostly
    # when latexmath is involved), so they are spelled out by hand.
    _STRUCT_LENGTH_SPECIAL_CASES = {
        ("VkShaderModuleCreateInfo", "pCode"):
            ("codeSize", lambda expr: "(%s / 4)" % expr),
        ("VkPipelineMultisampleStateCreateInfo", "pSampleMask"):
            ("rasterizationSamples", lambda expr: "(((%s) + 31) / 32)" % expr),
        ("VkAccelerationStructureVersionInfoKHR", "pVersionData"):
            ("", lambda _: "2*VK_UUID_SIZE"),
    }

    def __init__(self,):
        # Text generated so far; drained by swapCode().
        self.code = ""
        # Current nesting depth; one level is four spaces.
        self.indentLevel = 0
        # One counter per open block, innermost last. -1 means no name has
        # been handed out in that block yet.
        self.gensymCounter = [-1]

    @staticmethod
    def _abort(message):
        """Print ``message`` to stderr and abort the process.

        The message is flushed explicitly because os.abort() terminates the
        process without flushing Python's buffered streams.
        """
        print(message, file=sys.stderr, flush=True)
        os.abort()

    def var(self, prefix="cgen_var"):
        """Return a new variable name, unique within the current block nesting.

        The name is ``prefix`` followed by the counters of every enclosing
        block that has issued at least one name, joined with underscores.
        """
        self.gensymCounter[-1] += 1
        res = "%s_%s" % (prefix, '_'.join(str(i) for i in self.gensymCounter if i >= 0))
        return res

    def swapCode(self,):
        """Return the accumulated code and reset the buffer to empty."""
        res = "%s" % self.code
        self.code = ""
        return res

    def indent(self, extra=0):
        """Return the whitespace for the current level plus ``extra`` levels."""
        return "    " * (self.indentLevel + extra)

    def incrIndent(self,):
        self.indentLevel += 1

    def decrIndent(self,):
        # Never go below zero.
        if self.indentLevel > 0:
            self.indentLevel -= 1

    def beginBlock(self, bracketPrint=True):
        """Open a block: optionally emit ``{``, indent, and start a new var() scope."""
        if bracketPrint:
            self.code += self.indent() + "{\n"
        self.indentLevel += 1
        self.gensymCounter.append(-1)

    def endBlock(self, bracketPrint=True):
        """Close a block: dedent, optionally emit ``}``, and drop the var() scope."""
        self.indentLevel -= 1
        if bracketPrint:
            self.code += self.indent() + "}\n"
        del self.gensymCounter[-1]

    def beginIf(self, cond):
        self.code += self.indent() + "if (" + cond + ")\n"
        self.beginBlock()

    def beginElse(self, cond=None):
        """Emit ``else`` (or ``else if (cond)`` when ``cond`` is given) and open a block."""
        if cond is not None:
            self.code += \
                self.indent() + \
                "else if (" + cond + ")\n"
        else:
            self.code += self.indent() + "else\n"
        self.beginBlock()

    def endElse(self):
        self.endBlock()

    def endIf(self):
        self.endBlock()

    def beginSwitch(self, switchvar):
        self.code += self.indent() + "switch (" + switchvar + ")\n"
        self.beginBlock()

    def switchCase(self, switchval, blocked=False):
        """Emit ``case <switchval>:`` and open its block (braced only if ``blocked``)."""
        self.code += self.indent() + "case %s:" % switchval
        self.beginBlock(bracketPrint=blocked)

    def switchCaseBreak(self, switchval, blocked=False):
        # NOTE(review): this emits another ``case`` label and then *closes* a
        # block; the intent is not clear from this file, so the behavior is
        # kept exactly as it was.
        self.code += self.indent() + "case %s:" % switchval
        self.endBlock(bracketPrint=blocked)

    def switchCaseDefault(self, blocked=False):
        """Emit ``default:`` and open its block (braced only if ``blocked``)."""
        # There is no switch value to format in here; the label is constant.
        self.code += self.indent() + "default:"
        self.beginBlock(bracketPrint=blocked)

    def endSwitch(self):
        self.endBlock()

    def beginWhile(self, cond):
        self.code += self.indent() + "while (" + cond + ")\n"
        self.beginBlock()

    def endWhile(self):
        self.endBlock()

    def beginFor(self, initial, condition, increment):
        self.code += \
            self.indent() + "for (" + \
            "; ".join([initial, condition, increment]) + \
            ")\n"
        self.beginBlock()

    def endFor(self):
        self.endBlock()

    def beginLoop(self, loopVarType, loopVar, loopInit, loopBound):
        """Emit ``for (T v = init; v < bound; ++v)`` and open its block."""
        self.beginFor(
            "%s %s = %s" % (loopVarType, loopVar, loopInit),
            "%s < %s" % (loopVar, loopBound),
            "++%s" % (loopVar))

    def endLoop(self):
        self.endBlock()

    def stmt(self, code):
        """Append ``code`` as an indented statement terminated by ``;``."""
        self.code += self.indent() + code + ";\n"

    def line(self, code):
        """Append ``code`` as an indented line."""
        self.code += self.indent() + code + "\n"

    def leftline(self, code):
        """Append ``code`` as a line with no indentation."""
        self.code += code + "\n"

    def makeCallExpr(self, funcName, parameters):
        return funcName + "(%s)" % (", ".join(parameters))

    def funcCall(self, lhs, funcName, parameters):
        """Emit a call statement, assigned to ``lhs`` unless ``lhs`` is None."""
        res = self.indent()

        if lhs is not None:
            res += lhs + " = "

        res += self.makeCallExpr(funcName, parameters) + ";\n"
        self.code += res

    def funcCallRet(self, _lhs, funcName, parameters):
        """Emit ``return <call>;``. ``_lhs`` is accepted but unused."""
        res = self.indent()
        res += "return " + self.makeCallExpr(funcName, parameters) + ";\n"
        self.code += res

    # Given a VulkanType object, generate a C type declaration
    # with optional parameter name:
    # [const] [typename][*][const*] [paramName]
    def makeCTypeDecl(self, vulkanType, useParamName=True):
        constness = "const " if vulkanType.isConst else ""
        typeName = vulkanType.typeName
        levels = vulkanType.pointerIndirectionLevels

        if levels == 0:
            ptrSpec = ""
        elif vulkanType.isPointerToConstPointer:
            ptrSpec = "* const*" if vulkanType.isConst else "**"
            if levels > 2:
                ptrSpec += "*" * (levels - 2)
        else:
            ptrSpec = "*" * levels

        if useParamName and (vulkanType.paramName is not None):
            paramStr = (" " + vulkanType.paramName)
        else:
            paramStr = ""

        return "%s%s%s%s" % (constness, typeName, ptrSpec, paramStr)

    def makeRichCTypeDecl(self, vulkanType, useParamName=True):
        """Like makeCTypeDecl, with ``[<staticArrExpr>]`` appended when present."""
        if vulkanType.staticArrExpr:
            staticArrInfo = "[%s]" % vulkanType.staticArrExpr
        else:
            staticArrInfo = ""

        return self.makeCTypeDecl(vulkanType, useParamName=useParamName) + staticArrInfo

    # Given a VulkanAPI object, generate the C function prototype:
    # <returntype> <funcname>(<parameters>)
    def makeFuncProto(self, vulkanApi, useParamName=True):
        protoBegin = "%s %s" % (self.makeCTypeDecl(
            vulkanApi.retType, useParamName=False), vulkanApi.name)

        # One parameter per line; static array extents follow the name.
        argDecls = [
            self.makeRichCTypeDecl(param, useParamName=useParamName)
            for param in vulkanApi.parameters
        ]
        separator = ",\n%s" % self.indent(1)
        protoParams = "(\n    %s)" % separator.join(argDecls)

        return protoBegin + protoParams

    def makeFuncAlias(self, nameDst, nameSrc):
        return "DEFINE_ALIAS_FUNCTION({}, {})\n\n".format(nameSrc, nameDst)

    def makeFuncDecl(self, vulkanApi):
        return self.makeFuncProto(vulkanApi) + ";\n\n"

    def makeFuncImpl(self, vulkanApi, codegenFunc):
        """Return a full function definition whose body is produced by ``codegenFunc(self)``.

        Any code accumulated before the call is discarded.
        """
        self.swapCode()

        self.line(self.makeFuncProto(vulkanApi))
        self.beginBlock()
        codegenFunc(self)
        self.endBlock()

        return self.swapCode() + "\n"

    def emitFuncImpl(self, vulkanApi, codegenFunc):
        """Append a full function definition whose body is produced by ``codegenFunc(self)``."""
        self.line(self.makeFuncProto(vulkanApi))
        self.beginBlock()
        codegenFunc(self)
        self.endBlock()

    def makeStructAccess(self,
                         vulkanType,
                         structVarName,
                         asPtr=True,
                         structAsPtr=True,
                         accessIndex=None):
        """Return an expression accessing member ``vulkanType`` of ``structVarName``.

        ``structAsPtr`` selects ``->`` over ``.``. With ``asPtr``, ``&`` is
        prepended unless the member is already accessible as a pointer.
        A truthy ``accessIndex`` is appended as `` + <accessIndex>``.
        """
        deref = "->" if structAsPtr else "."

        indexExpr = (" + %s" % accessIndex) if accessIndex else ""

        addrOfExpr = "" if vulkanType.accessibleAsPointer() or (
            not asPtr) else "&"

        return "%s%s%s%s%s" % (addrOfExpr, structVarName, deref,
                               vulkanType.paramName, indexExpr)

    def makeRawLengthAccess(self, vulkanType):
        """Return ``(length expression, None)`` without reference to any parent.

        Returns ``(None, None)`` when the type has no length expression.
        """
        lenExpr = vulkanType.getLengthExpression()

        if not lenExpr:
            return None, None

        if lenExpr == "null-terminated":
            return "strlen(%s)" % vulkanType.paramName, None

        return lenExpr, None

    def makeLengthAccessFromStruct(self,
                                   structInfo,
                                   vulkanType,
                                   structVarName,
                                   asPtr=True):
        """Return ``(length expression, guard expression)`` for a struct member.

        The guard is the struct variable that the length expression
        dereferences. Returns ``(None, None)`` when the member has no length.
        """
        deref = "->" if asPtr else "."
        lenAccessGuardExpr = "%s" % structVarName

        # Handle special cases first
        special = self._STRUCT_LENGTH_SPECIAL_CASES.get(
            (structInfo.name, vulkanType.paramName))
        if special is not None:
            lenExprMember, postprocess = special
            expr = "%s%s%s" % (structVarName, deref, lenExprMember)
            return postprocess(expr), lenAccessGuardExpr

        lenExpr = vulkanType.getLengthExpression()

        if not lenExpr:
            return None, None

        if lenExpr == "null-terminated":
            return "strlen(%s%s%s)" % (structVarName, deref,
                                       vulkanType.paramName), lenAccessGuardExpr

        # The length is not a member of this struct: use the raw expression.
        if not structInfo.getMember(lenExpr):
            return self.makeRawLengthAccess(vulkanType)

        return "%s%s%s" % (structVarName, deref, lenExpr), lenAccessGuardExpr

    def makeLengthAccessFromApi(self, api, vulkanType):
        """Return ``(length expression, guard expression)`` for an API parameter.

        The guard is the pointer that must be dereferenced to obtain the
        length, or None when nothing is dereferenced. Returns ``(None, None)``
        when the parameter has no length expression.
        """
        lenExpr = vulkanType.getLengthExpression()

        if not lenExpr:
            return None, None

        # Handle special cases first: "param::member" names a member of
        # another (pointer) parameter.
        if "::" in lenExpr:
            structVarName, memberVarName = lenExpr.split("::")
            lenAccessGuardExpr = "%s" % (structVarName)
            return "%s->%s" % (structVarName, memberVarName), lenAccessGuardExpr

        lenExprInfo = api.getParameter(lenExpr)

        # The length is not one of the API's parameters: use the raw expression.
        if not lenExprInfo:
            return self.makeRawLengthAccess(vulkanType)

        if lenExpr == "null-terminated":
            # paramName is a plain string attribute, not a method.
            return "strlen(%s)" % vulkanType.paramName, None

        deref = "*" if lenExprInfo.pointerIndirectionLevels > 0 else ""
        lenAccessGuardExpr = "%s" % lenExpr if deref else None
        return "(%s(%s))" % (deref, lenExpr), lenAccessGuardExpr

    def accessParameter(self, param, asPtr=True):
        """Return the parameter name, prefixed with ``&`` when a pointer is wanted but it is not one."""
        if asPtr:
            if param.pointerIndirectionLevels > 0:
                return param.paramName
            else:
                return "&%s" % param.paramName
        else:
            return param.paramName

    def sizeofExpr(self, vulkanType):
        return "sizeof(%s)" % (
            self.makeCTypeDecl(vulkanType, useParamName=False))

    def generalAccess(self,
                      vulkanType,
                      parentVarName=None,
                      asPtr=True,
                      structAsPtr=True):
        """Return an access expression for ``vulkanType`` based on its parent kind.

        Struct members are accessed through ``parentVarName``; parentless
        types and API parameters are accessed by name, with ``parentVarName``
        (when given) replacing that name. Aborts the process for any other
        parent kind.
        """
        parent = vulkanType.parent

        if isinstance(parent, VulkanCompoundType) and parent is not None:
            return self.makeStructAccess(
                vulkanType, parentVarName, asPtr=asPtr, structAsPtr=structAsPtr)

        if parent is None or isinstance(parent, VulkanAPI):
            if parentVarName is None:
                return self.accessParameter(vulkanType, asPtr=asPtr)
            else:
                return self.accessParameter(vulkanType.withModifiedName(parentVarName), asPtr=asPtr)

        self._abort("Could not find a way to access Vulkan type %s" %
                    vulkanType.typeName)

    def makeLengthAccess(self, vulkanType, parentVarName="parent"):
        """Return ``(length expression, guard expression)`` based on the parent kind.

        Aborts the process when the parent is neither None, a struct, nor an API.
        """
        if vulkanType.parent is None:
            return self.makeRawLengthAccess(vulkanType)

        if isinstance(vulkanType.parent, VulkanCompoundType):
            return self.makeLengthAccessFromStruct(
                vulkanType.parent, vulkanType, parentVarName, asPtr=True)

        if isinstance(vulkanType.parent, VulkanAPI):
            return self.makeLengthAccessFromApi(vulkanType.parent, vulkanType)

        self._abort("Could not find a way to access length of Vulkan type %s" %
                    vulkanType.typeName)

    def generalLengthAccess(self, vulkanType, parentVarName="parent"):
        return self.makeLengthAccess(vulkanType, parentVarName)[0]

    def generalLengthAccessGuard(self, vulkanType, parentVarName="parent"):
        return self.makeLengthAccess(vulkanType, parentVarName)[1]

    def vkApiCall(self, api, customPrefix="", globalStatePrefix="",
                  customParameters=None, checkForDeviceLost=False,
                  checkForOutOfMemory=False, checkDispatcher=None):
        """Emit a call to ``customPrefix + api.name`` and return ``(retTypeName, retVar)``.

        For non-void APIs a result variable is declared first with a default
        value. ``checkDispatcher``, when given, is a condition the call is
        wrapped in. ``customParameters`` replaces the API's own parameter
        names as call arguments. For VkResult APIs, ``checkForDeviceLost`` and
        ``checkForOutOfMemory`` append checks that call helpers prefixed with
        ``globalStatePrefix``. ``retVar`` is None for void APIs.
        """
        callLhs = None

        retTypeName = api.getRetTypeExpr()
        retVar = None

        if retTypeName != "void":
            retVar = api.getRetVarExpr()
            defaultReturn = "(%s)0" % retTypeName
            if retTypeName == "VkResult":
                # TODO: return a valid error code based on the call
                # This is used to handle invalid dispatcher and snapshot states
                deviceLostFunctions = ["vkQueueSubmit",
                                       "vkQueueWaitIdle",
                                       "vkWaitForFences"]
                defaultReturn = "VK_ERROR_OUT_OF_HOST_MEMORY"
                # Compare by name: the list holds API names, not API objects.
                if api.name in deviceLostFunctions:
                    defaultReturn = "VK_ERROR_DEVICE_LOST"
            self.stmt("%s %s = %s" % (retTypeName, retVar, defaultReturn))
            callLhs = retVar

        if (checkDispatcher):
            self.beginIf(checkDispatcher)

        if customParameters is None:
            callParameters = [p.paramName for p in api.parameters]
        else:
            callParameters = customParameters
        self.funcCall(callLhs, customPrefix + api.name, callParameters)

        if (checkDispatcher):
            self.endIf()

        if retTypeName == "VkResult" and checkForDeviceLost:
            self.stmt("if ((%s) == VK_ERROR_DEVICE_LOST) %sDeviceLost()" % (callLhs, globalStatePrefix))

        if retTypeName == "VkResult" and checkForOutOfMemory:
            if api.name == "vkAllocateMemory":
                self.stmt(
                    "%sCheckOutOfMemory(%s, opcode, context, std::make_optional<uint64_t>(pAllocateInfo->allocationSize))"
                    % (globalStatePrefix, callLhs))
            else:
                self.stmt(
                    "%sCheckOutOfMemory(%s, opcode, context)"
                    % (globalStatePrefix, callLhs))

        return (retTypeName, retVar)

    def makeCheckVkSuccess(self, expr):
        return "((%s) == VK_SUCCESS)" % expr

    def makeReinterpretCast(self, varName, typeName, const=True):
        return "reinterpret_cast<%s%s*>(%s)" % \
               ("const " if const else "", typeName, varName)

    def validPrimitive(self, typeInfo, typeName):
        """Return True when ``typeInfo`` reports an encoding size for ``typeName``."""
        size = typeInfo.getPrimitiveEncodingSize(typeName)
        return size is not None

    def _makeSizedStreamMethod(self, typeInfo, typeName, prefix):
        """Return ``prefix`` + size suffix for ``typeName``, or None if it has none."""
        if not self.validPrimitive(typeInfo, typeName):
            return None

        size = typeInfo.getPrimitiveEncodingSize(typeName)
        suffix = self._STREAM_SUFFIX_BY_SIZE.get(size)
        if not suffix:
            return None

        return prefix + suffix

    def makePrimitiveStreamMethod(self, typeInfo, typeName, direction="write"):
        """Return the stream method name (``putBe32``, ``getByte``, ...) or None."""
        prefix = "put" if direction == "write" else "get"
        return self._makeSizedStreamMethod(typeInfo, typeName, prefix)

    def makePrimitiveStreamMethodInPlace(self, typeInfo, typeName, direction="write"):
        """Return the in-place conversion method name (``toBe32``, ``fromByte``, ...) or None."""
        prefix = "to" if direction == "write" else "from"
        return self._makeSizedStreamMethod(typeInfo, typeName, prefix)

    def _primitiveStreamSize(self, typeInfo, accessType, action):
        """Return ``(size in bytes, isPointer)`` for ``accessType``.

        Pointers always take 8 bytes. Aborts the process when a non-pointer
        type has no primitive encoding size; ``action`` only words the
        message.
        """
        accessTypeName = accessType.typeName

        if accessType.pointerIndirectionLevels == 0 and not self.validPrimitive(typeInfo, accessTypeName):
            self._abort("Tried to %s a non-primitive type: %s" % (action, accessTypeName))

        if accessType.pointerIndirectionLevels > 0:
            return 8, True

        return typeInfo.getPrimitiveEncodingSize(accessTypeName), False

    def streamPrimitive(self, typeInfo, streamVar, accessExpr, accessType, direction="write"):
        """Emit code moving one primitive (or pointer) between ``accessExpr`` and ``streamVar``.

        ``direction`` "read" assigns from the stream's get method; anything
        else copies the value to a staging variable and passes it to the
        stream's put method.
        """
        streamSize, needPtrCast = self._primitiveStreamSize(typeInfo, accessType, "stream")

        if needPtrCast:
            streamMethod = "putBe64" if direction == "write" else "getBe64"
        else:
            streamMethod = self.makePrimitiveStreamMethod(
                typeInfo, accessType.typeName, direction=direction)

        # Always consume a name, even when reading, so that names handed out
        # later do not depend on the direction.
        streamStorageVar = self.var()

        ptrCast = "(uintptr_t)" if needPtrCast else ""

        if direction == "read":
            accessCast = self.makeRichCTypeDecl(accessType, useParamName=False)
            self.stmt("%s = (%s)%s%s->%s()" %
                      (accessExpr,
                       accessCast,
                       ptrCast,
                       streamVar,
                       streamMethod))
        else:
            streamStorageVarType = self._STORAGE_TYPE_BY_SIZE[streamSize]
            self.stmt("%s %s = (%s)%s%s" %
                      (streamStorageVarType, streamStorageVar,
                       streamStorageVarType, ptrCast, accessExpr))
            self.stmt("%s->%s(%s)" %
                      (streamVar, streamMethod, streamStorageVar))

    def memcpyPrimitive(self, typeInfo, streamVar, accessExpr, accessType, variant, direction="write"):
        """Emit code memcpy-ing one primitive (or pointer) to/from the buffer ``streamVar``.

        After the memcpy, the stream's in-place conversion method is applied
        to the destination. ``variant`` "guest" selects the
        ``gfxstream::aemu`` namespace; anything else selects
        ``android::base``.
        """
        streamSize, needPtrCast = self._primitiveStreamSize(typeInfo, accessType, "stream")

        if needPtrCast:
            streamMethod = "toBe64" if direction == "write" else "fromBe64"
        else:
            streamMethod = self.makePrimitiveStreamMethodInPlace(
                typeInfo, accessType.typeName, direction=direction)

        # Always consume a name, even when reading, so that names handed out
        # later do not depend on the direction.
        streamStorageVar = self.var()

        ptrCast = "(uintptr_t)" if needPtrCast else ""
        if variant == "guest":
            streamNamespace = "gfxstream::aemu"
        else:
            streamNamespace = "android::base"

        if direction == "read":
            # The destination is written through, so cast to its non-const form.
            accessCast = self.makeRichCTypeDecl(
                accessType.getForNonConstAccess(), useParamName=False)
            self.stmt("memcpy((%s*)&%s, %s, %s)" %
                      (accessCast,
                       accessExpr,
                       streamVar,
                       str(streamSize)))
            self.stmt("%s::Stream::%s((uint8_t*)&%s)" % (
                streamNamespace,
                streamMethod,
                accessExpr))
        else:
            streamStorageVarType = self._STORAGE_TYPE_BY_SIZE[streamSize]
            self.stmt("%s %s = (%s)%s%s" %
                      (streamStorageVarType, streamStorageVar,
                       streamStorageVarType, ptrCast, accessExpr))
            self.stmt("memcpy(%s, &%s, %s)" %
                      (streamVar, streamStorageVar, str(streamSize)))
            self.stmt("%s::Stream::%s((uint8_t*)%s)" % (
                streamNamespace,
                streamMethod,
                streamVar))

    def countPrimitive(self, typeInfo, accessType):
        """Return the number of bytes ``accessType`` occupies in the stream (8 for pointers)."""
        streamSize, _ = self._primitiveStreamSize(typeInfo, accessType, "count")
        return streamSize
862
863# Class to wrap a Vulkan API call.
864#
865# The user gives a generic callback, |codegenDef|,
866# that takes a CodeGen object and a VulkanAPI object as arguments.
867# codegenDef uses CodeGen along with the VulkanAPI object
868# to generate the function body.
class VulkanAPIWrapper(object):
    """Emits a renamed / re-parameterized variant of a Vulkan API call.

    For an API looked up in typeInfo.apis, the wrapper builds a modified
    shallow copy (prefixed name, optional extra leading parameters, optional
    return type override) and uses an internal CodeGen to produce either its
    declaration or its full definition. The function body is written by a
    user-supplied callback that receives (codegen, vulkanApi).
    """

    def __init__(self,
                 customApiPrefix,
                 extraParameters=None,
                 returnTypeOverride=None,
                 codegenDef=None):
        """Set up the wrapper.

        customApiPrefix: string prepended to the name of every wrapped API.
        extraParameters: None, or a list of parameters placed in front of
            the API's own parameters.
        returnTypeOverride: None, or a replacement for the API's retType.
        codegenDef: None, or the callback (codegen, vulkanApi) that emits
            the function body; it can also be set later via setCodegenDef().
        """
        self.customApiPrefix = customApiPrefix
        self.extraParameters = extraParameters
        self.returnTypeOverride = returnTypeOverride

        self.codegen = CodeGen()

        self.definitionFunc = codegenDef

    def _makeApi(self, typeInfo, apiName):
        """Return a customized shallow copy of typeInfo.apis[apiName].

        The registry entry itself is not modified: the name, parameter list
        and return type are rebound on the copy only.

        Raises TypeError if extraParameters is set to something other than
        a list.
        """
        customApi = copy(typeInfo.apis[apiName])
        customApi.name = self.customApiPrefix + customApi.name

        if self.extraParameters is not None:
            if not isinstance(self.extraParameters, list):
                # os.abort() accepts no arguments, so passing the message to
                # it only produced an unrelated TypeError. Raise the same
                # exception type, but with the intended diagnostic.
                raise TypeError(
                    "Type of extra parameters to custom API not valid. "
                    "Expected list, got %s" % type(self.extraParameters))
            customApi.parameters = \
                self.extraParameters + customApi.parameters

        if self.returnTypeOverride is not None:
            customApi.retType = self.returnTypeOverride
        return customApi

    # Kept under its historical name and calling convention, which passes the
    # wrapper explicitly: wrapper.makeApi(wrapper, typeInfo, apiName).
    makeApi = staticmethod(_makeApi)

    def setCodegenDef(self, codegenDefFunc):
        """Install the callback (codegen, vulkanApi) that emits the body."""
        self.definitionFunc = codegenDefFunc

    def makeDecl(self, typeInfo, apiName):
        """Return the prototype of the customized API followed by ';'."""
        return self.codegen.makeFuncProto(
            self.makeApi(self, typeInfo, apiName)) + ";\n\n"

    def makeDefinition(self, typeInfo, apiName, isStatic=False):
        """Return the full definition (prototype plus body) of the API.

        The body is whatever the definition callback emits into the
        wrapper's CodeGen between beginBlock() and endBlock(). When
        |isStatic| is true the prototype is prefixed with "static ".

        Prints an error and exits with status 1 if no definition callback
        has been provided.
        """
        vulkanApi = self.makeApi(self, typeInfo, apiName)

        # Check before touching the CodeGen so a failure does not leave it
        # with a swapped buffer and an unterminated block.
        if self.definitionFunc is None:
            print("ERROR: No definition found for (%s, %s)" %
                  (vulkanApi.name, self.customApiPrefix))
            sys.exit(1)

        self.codegen.swapCode()
        self.codegen.beginBlock()

        self.definitionFunc(self.codegen, vulkanApi)

        self.codegen.endBlock()

        return ("static " if isStatic else "") + self.codegen.makeFuncProto(
            vulkanApi) + "\n" + self.codegen.swapCode() + "\n"
928
929# Base class for wrapping all Vulkan API objects.  These work with Vulkan
930# Registry generators and have gen* triggers.  They tend to contain
931# VulkanAPIWrapper objects to make it easier to generate the code.
class VulkanWrapperGenerator(object):
    """Base generator driven by on* callbacks.

    Every hook is a no-op here except onGenType(), which records each
    non-alias struct/union whose struct info has a truthy structExtendsExpr.
    The emitForEachStructExtension* helpers then emit code once per recorded
    struct, in the order the structs were recorded.
    """

    def __init__(self, module: Module, typeInfo: VulkanTypeInfo):
        self.module: Module = module
        self.typeInfo: VulkanTypeInfo = typeInfo
        # Recorded extension structs, keyed by type name (insertion-ordered).
        self.extensionStructTypes = OrderedDict()

    def onBegin(self):
        """No-op hook."""

    def onEnd(self):
        """No-op hook."""

    def onBeginFeature(self, featureName, featureType):
        """No-op hook."""

    def onFeatureNewCmd(self, cmdName):
        """No-op hook."""

    def onEndFeature(self):
        """No-op hook."""

    def onGenType(self, typeInfo, name, alias):
        """Record |name| if it is a non-alias struct/union extension struct.

        The lookup goes through self.typeInfo; the |typeInfo| argument is
        not consulted.
        """
        category = self.typeInfo.categoryOf(name)
        if alias or category not in ("struct", "union"):
            return

        info = self.typeInfo.structs[name]
        if info.structExtendsExpr:
            self.extensionStructTypes[name] = info

    def onGenStruct(self, typeInfo, name, alias):
        """No-op hook."""

    def onGenGroup(self, groupinfo, groupName, alias=None):
        """No-op hook."""

    def onGenEnum(self, enuminfo, name, alias):
        """No-op hook."""

    def onGenCmd(self, cmdinfo, name, alias):
        """No-op hook."""

    # The structure types below may each correspond to more than one Vulkan
    # struct because of a conflict between different Vulkan registries. To
    # pick the right struct, the structure type of the "root" struct is
    # checked as well: outer key = extension structure type, inner key =
    # root structure type (or "default"), value = struct name to use.
    ROOT_TYPE_MAPPING = {
        "VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FRAGMENT_DENSITY_MAP_FEATURES_EXT": {
            "VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2": "VkPhysicalDeviceFragmentDensityMapFeaturesEXT",
            "VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO": "VkPhysicalDeviceFragmentDensityMapFeaturesEXT",
            "VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO": "VkImportColorBufferGOOGLE",
            "default": "VkPhysicalDeviceFragmentDensityMapFeaturesEXT",
        },
        "VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FRAGMENT_DENSITY_MAP_PROPERTIES_EXT": {
            "VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2": "VkPhysicalDeviceFragmentDensityMapPropertiesEXT",
            "VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO": "VkCreateBlobGOOGLE",
            "default": "VkPhysicalDeviceFragmentDensityMapPropertiesEXT",
        },
        "VK_STRUCTURE_TYPE_RENDER_PASS_FRAGMENT_DENSITY_MAP_CREATE_INFO_EXT": {
            "VK_STRUCTURE_TYPE_RENDER_PASS_CREATE_INFO": "VkRenderPassFragmentDensityMapCreateInfoEXT",
            "VK_STRUCTURE_TYPE_RENDER_PASS_CREATE_INFO_2": "VkRenderPassFragmentDensityMapCreateInfoEXT",
            "VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO": "VkImportBufferGOOGLE",
            "default": "VkRenderPassFragmentDensityMapCreateInfoEXT",
        },
    }

    @staticmethod
    def _switchFeatureGuard(cgen, activeFeature, feature):
        """Emit the preprocessor lines needed to move to |feature|'s guard.

        Opens "#ifdef <feature>" when no guard is active; closes the active
        guard and opens a new one when the feature differs; emits nothing
        when the feature is unchanged. Returns the feature now active.
        """
        if not activeFeature:
            cgen.leftline("#ifdef %s" % feature)
            return feature

        if feature != activeFeature:
            cgen.leftline("#endif")
            cgen.leftline("#ifdef %s" % feature)
            return feature

        return activeFeature

    def emitForEachStructExtension(self, cgen, retType, triggerVar, forEachFunc, autoBreak=True, defaultEmit=None, nullEmit=None, rootTypeVar=None):
        """Emit a switch over the struct type of |triggerVar|.

        Emitted code, in order:
          * an if-block for a null |triggerVar| whose body comes from
            |nullEmit|(cgen), or a default return when |nullEmit| is None;
          * a read of the struct type via goldfish_vk_struct_type();
          * one case per recorded extension struct, wrapped in an #ifdef on
            the struct's feature and, when the enum element carries a
            "protect" attribute, an additional #ifdef on that macro. The
            case body is produced by forEachFunc(structInfo, castExpr, cgen)
            where castExpr is |triggerVar| reinterpret_cast to the struct;
          * a default case whose body comes from |defaultEmit|(cgen), or a
            default return when |defaultEmit| is None.

        When |rootTypeVar| is given and the struct's enum is listed in
        ROOT_TYPE_MAPPING, the case body is instead a nested switch on
        |rootTypeVar| with one forEachFunc invocation per mapped struct.

        The default return is "return" for a void |retType| and
        "return (<type>)0" otherwise. A "break" statement closes each
        outer case when |autoBreak| is true.
        """
        def emitFallbackReturn(out):
            if retType.typeName == "void":
                out.stmt("return")
            else:
                out.stmt("return (%s)0" % retType.typeName)

        def structCast(typeName):
            qualifier = "const " if triggerVar.isConst else ""
            return "reinterpret_cast<%s%s*>(%s)" % \
                   (qualifier, typeName, triggerVar.paramName)

        # Null trigger: nothing to dispatch on.
        cgen.beginIf("!%s" % triggerVar.paramName)
        (emitFallbackReturn if nullEmit is None else nullEmit)(cgen)
        cgen.endIf()

        cgen.stmt("uint32_t %s = (uint32_t)%s(%s)" % \
            ("structType", "goldfish_vk_struct_type", triggerVar.paramName))

        cgen.line("switch(structType)")
        cgen.beginBlock()

        rootMappings = VulkanWrapperGenerator.ROOT_TYPE_MAPPING
        activeFeature = None

        for structInfo in self.extensionStructTypes.values():
            activeFeature = self._switchFeatureGuard(
                cgen, activeFeature, structInfo.feature)

            structEnum = structInfo.structEnumExpr

            # Optional per-enum protection macro.
            protectMacro = None
            if structEnum in self.typeInfo.enumElem:
                protectMacro = self.typeInfo.enumElem[structEnum].get("protect", default=None)
            if protectMacro is not None:
                cgen.leftline("#ifdef %s" % protectMacro)

            cgen.line("case %s:" % structEnum)
            cgen.beginBlock()

            if rootTypeVar is None or structEnum not in rootMappings:
                forEachFunc(structInfo, structCast(structInfo.name), cgen)
            else:
                # Ambiguous structure type: disambiguate on the root type.
                cgen.line("switch(%s)" % rootTypeVar.paramName)
                cgen.beginBlock()
                for rootEnum, structName in rootMappings[structEnum].items():
                    rootStruct = self.extensionStructTypes[structName]
                    if rootEnum == "default":
                        label = "%s:" % rootEnum
                    else:
                        label = "case %s:" % rootEnum
                    cgen.line(label)
                    cgen.beginBlock()
                    forEachFunc(rootStruct, structCast(rootStruct.name), cgen)
                    cgen.line("break;")
                    cgen.endBlock()
                cgen.endBlock()

            if autoBreak:
                cgen.stmt("break")
            cgen.endBlock()

            if protectMacro is not None:
                cgen.leftline("#endif // %s" % protectMacro)

        # Close the last feature guard, if one was opened.
        if activeFeature:
            cgen.leftline("#endif")

        cgen.line("default:")
        cgen.beginBlock()
        (emitFallbackReturn if defaultEmit is None else defaultEmit)(cgen)
        cgen.endBlock()

        cgen.endBlock()

    def emitForEachStructExtensionGeneral(self, cgen, forEachFunc, doFeatureIfdefs=False):
        """Call forEachFunc(index, structInfo, cgen) per recorded struct.

        When |doFeatureIfdefs| is true, the calls are additionally wrapped
        in "#ifdef <feature>" / "#endif" lines that change whenever the
        struct's feature changes.
        """
        activeFeature = None

        for index, structInfo in enumerate(self.extensionStructTypes.values()):
            if doFeatureIfdefs:
                activeFeature = self._switchFeatureGuard(
                    cgen, activeFeature, structInfo.feature)

            forEachFunc(index, structInfo, cgen)

        if doFeatureIfdefs and activeFeature:
            cgen.leftline("#endif")
1108