• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright (c) 2018 The Android Open Source Project
2# Copyright (c) 2018 Google Inc.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16from .vulkantypes import VulkanType, VulkanTypeInfo, VulkanCompoundType, VulkanAPI
17from collections import OrderedDict
18from copy import copy
19from pathlib import Path, PurePosixPath
20
21import os
22import sys
23import shutil
24import subprocess
25
26# Class capturing a single file
27
28
class SingleFileModule(object):
    """A single generated output file (e.g. a .h, .cpp, .py or .proto file).

    Output is written in three phases: ``begin`` opens the file and writes
    ``preamble``, ``append`` writes body text, and ``end`` writes
    ``postamble`` and closes the file. When ``suppress`` is true every phase
    is a no-op and no file is created.
    """

    def __init__(self, suffix, directory, basename, customAbsDir=None, suppress=False):
        # Directory relative to the global output directory given to begin().
        self.directory = directory
        self.basename = basename
        # When set, the file is written here instead of globalDir/directory.
        self.customAbsDir = customAbsDir
        # Extension appended to basename, including the dot (e.g. ".h").
        self.suffix = suffix
        # Open file handle between begin() and end(); None before begin().
        self.file = None

        self.preamble = ""
        self.postamble = ""

        self.suppress = suppress

    def _outputDir(self, globalDir):
        """Return the directory this file is written into."""
        if self.customAbsDir:
            return self.customAbsDir
        return os.path.join(globalDir, self.directory)

    def begin(self, globalDir):
        """Open the output file below *globalDir* and write the preamble.

        The output directory is created first if it does not exist yet.
        """
        if self.suppress:
            return

        absDir = self._outputDir(globalDir)
        # Create the subdirectory if needed so generating into a fresh output
        # tree does not fail. An empty path means "current directory".
        if absDir:
            os.makedirs(absDir, exist_ok=True)

        filename = os.path.join(absDir, self.basename)

        self.file = open(filename + self.suffix, "w", encoding="utf-8")
        self.file.write(self.preamble)

    def append(self, toAppend):
        """Write *toAppend* to the open file (no-op when suppressed)."""
        if self.suppress:
            return

        self.file.write(toAppend)

    def end(self):
        """Write the postamble and close the file (no-op when suppressed)."""
        if self.suppress:
            return

        self.file.write(self.postamble)
        self.file.close()

    def getMakefileSrcEntry(self):
        """Return this file's Makefile source-list entry (none by default)."""
        return ""

    def getCMakeSrcEntry(self):
        """Return this file's CMake source-list entry (none by default)."""
        return ""
75
76# Class capturing a .cpp file and a .h file (a "C++ module")
77
78
class Module(object):
    """A C++ module: a generated ``.h`` header plus a ``.cpp`` implementation.

    ``implOnly`` suppresses the header, ``headerOnly`` suppresses the
    implementation, and ``suppress`` disables both. ``end()`` closes both
    files and then formats every file that was written with clang-format.
    """

    def __init__(
            self, directory, basename, customAbsDir=None, suppress=False, implOnly=False,
            headerOnly=False, suppressFeatureGuards=False):
        self._headerFileModule = SingleFileModule(
            ".h", directory, basename, customAbsDir, suppress or implOnly)
        self._implFileModule = SingleFileModule(
            ".cpp", directory, basename, customAbsDir, suppress or headerOnly)

        self._headerOnly = headerOnly
        self._implOnly = implOnly

        self.directory = directory
        self.basename = basename
        self._customAbsDir = customAbsDir

        # Stored for callers; not used within this class.
        self.suppressFeatureGuards = suppressFeatureGuards

    @property
    def suppress(self):
        raise AttributeError("suppress is write only")

    @suppress.setter
    def suppress(self, value: bool):
        # headerOnly / implOnly keep their file suppressed regardless of value.
        self._headerFileModule.suppress = self._implOnly or value
        self._implFileModule.suppress = self._headerOnly or value

    @property
    def headerPreamble(self) -> str:
        return self._headerFileModule.preamble

    @headerPreamble.setter
    def headerPreamble(self, value: str):
        self._headerFileModule.preamble = value

    @property
    def headerPostamble(self) -> str:
        return self._headerFileModule.postamble

    @headerPostamble.setter
    def headerPostamble(self, value: str):
        self._headerFileModule.postamble = value

    @property
    def implPreamble(self) -> str:
        return self._implFileModule.preamble

    @implPreamble.setter
    def implPreamble(self, value: str):
        self._implFileModule.preamble = value

    @property
    def implPostamble(self) -> str:
        return self._implFileModule.postamble

    @implPostamble.setter
    def implPostamble(self, value: str):
        self._implFileModule.postamble = value

    def getMakefileSrcEntry(self):
        """Return the .cpp entry for a Makefile source list."""
        if self._customAbsDir:
            return self.basename + ".cpp \\\n"
        joined = os.path.join(self.directory, self.basename)
        return "    " + joined + ".cpp \\\n"

    def getCMakeSrcEntry(self):
        """Return the .cpp entry for a CMake source list (forward slashes)."""
        if self._customAbsDir:
            return "\n" + self.basename + ".cpp "
        joined = PurePosixPath(Path(self.directory) / Path(self.basename))
        return "\n    " + str(joined) + ".cpp "

    def begin(self, globalDir):
        """Open the header and implementation files under *globalDir*."""
        self._headerFileModule.begin(globalDir)
        self._implFileModule.begin(globalDir)

    def appendHeader(self, toAppend):
        self._headerFileModule.append(toAppend)

    def appendImpl(self, toAppend):
        self._implFileModule.append(toAppend)

    def end(self):
        """Close both files, then run ``clang-format -i --style=file`` on each one written.

        Raises:
            RuntimeError: if a file needs formatting but ``clang-format`` is
                not on PATH, or clang-format exits with a non-zero status.
        """
        self._headerFileModule.end()
        self._implFileModule.end()

        toFormat = [
            Path(fileModule.file.name)
            for fileModule in (self._headerFileModule, self._implFileModule)
            if not fileModule.suppress
        ]
        if not toFormat:
            return

        clangFormatCommand = shutil.which("clang-format")
        if clangFormatCommand is None:
            raise RuntimeError(
                "clang-format was not found on PATH; it is required to format generated sources")

        for filename in toFormat:
            # Deliberately not inside an assert: asserts (including their side
            # effects) are stripped under `python -O`, which skipped formatting.
            returnCode = subprocess.call(
                [clangFormatCommand, "-i", "--style=file", str(filename.resolve())])
            if returnCode != 0:
                raise RuntimeError(
                    "clang-format failed (exit code %d) on %s" % (returnCode, filename))
181
182
class PyScript(SingleFileModule):
    """One generated Python source file, always written with a ``.py`` suffix."""

    def __init__(self, directory, basename, customAbsDir=None, suppress=False):
        super().__init__(
            ".py",
            directory,
            basename,
            customAbsDir=customAbsDir,
            suppress=suppress,
        )
186
187
188# Class capturing a .proto protobuf definition file
class Proto(SingleFileModule):
    """One generated protobuf definition file, written with a ``.proto`` suffix."""

    def __init__(self, directory, basename, customAbsDir=None, suppress=False):
        super().__init__(".proto", directory, basename, customAbsDir, suppress)

    def getMakefileSrcEntry(self):
        """Return this file's entry for a Makefile source list.

        Files in a custom absolute directory are listed by bare file name;
        others are indented and joined with their relative directory.
        """
        if self.customAbsDir:
            return self.basename + ".proto \\\n"
        joined = os.path.join(self.directory, self.basename)
        return "    " + joined + ".proto \\\n"

    def getCMakeSrcEntry(self):
        """Return this file's entry for a CMake source list.

        The relative path always uses forward slashes (as Module.getCMakeSrcEntry
        does), since CMake treats backslashes in source lists as escapes.
        """
        if self.customAbsDir:
            return "\n" + self.basename + ".proto "
        joined = PurePosixPath(Path(self.directory) / Path(self.basename))
        return "\n    " + str(joined) + ".proto "
212
class CodeGen(object):
    """Accumulates generated C/C++ source text.

    Emitted text is appended to ``self.code``; ``swapCode()`` returns it and
    resets the buffer. ``indentLevel`` counts four-space indentation steps.
    ``gensymCounter`` holds one counter per open block: ``var()`` bumps the
    innermost counter and builds a name from all non-negative counters, e.g.
    ``cgen_var_0`` at top level and ``cgen_var_0_0`` inside a nested block.
    """

    # Stream method suffix for each primitive encoding size, in bytes.
    _STREAM_SUFFIX_BY_SIZE = {1: "Byte", 2: "Be16", 4: "Be32", 8: "Be64"}

    # Unsigned C storage type used to hold a primitive of each encoding size.
    _STORAGE_TYPE_BY_SIZE = {1: "uint8_t", 2: "uint16_t", 4: "uint32_t", 8: "uint64_t"}

    # Array lengths the registry cannot give as a plain member name (mostly
    # latexmath). Maps (struct name, field name) to (length member name,
    # function turning "<structVar><deref><member>" into the length expression).
    _STRUCT_LENGTH_SPECIAL_CASES = {
        ("VkShaderModuleCreateInfo", "pCode"):
            ("codeSize", lambda expr: "(%s / 4)" % expr),
        ("VkPipelineMultisampleStateCreateInfo", "pSampleMask"):
            ("rasterizationSamples", lambda expr: "(((%s) + 31) / 32)" % expr),
        ("VkAccelerationStructureVersionInfoKHR", "pVersionData"):
            ("", lambda _: "2*VK_UUID_SIZE"),
    }

    def __init__(self):
        self.code = ""
        self.indentLevel = 0
        self.gensymCounter = [-1]

    def var(self, prefix="cgen_var"):
        """Return a fresh generated variable name, e.g. ``cgen_var_0_1``."""
        self.gensymCounter[-1] += 1
        suffix = "_".join(str(i) for i in self.gensymCounter if i >= 0)
        return "%s_%s" % (prefix, suffix)

    def swapCode(self):
        """Return the accumulated code and reset the buffer to empty."""
        res = self.code
        self.code = ""
        return res

    def indent(self, extra=0):
        """Return indentation for the current level plus *extra* levels."""
        return "    " * (self.indentLevel + extra)

    def incrIndent(self):
        """Increase the indent level without emitting anything."""
        self.indentLevel += 1

    def decrIndent(self):
        """Decrease the indent level, never going below zero."""
        if self.indentLevel > 0:
            self.indentLevel -= 1

    def beginBlock(self, bracketPrint=True):
        """Open a scope: optionally emit "{", indent, push a gensym counter."""
        if bracketPrint:
            self.code += self.indent() + "{\n"
        self.indentLevel += 1
        self.gensymCounter.append(-1)

    def endBlock(self, bracketPrint=True):
        """Close the innermost scope, optionally emitting "}"."""
        self.indentLevel -= 1
        if bracketPrint:
            self.code += self.indent() + "}\n"
        del self.gensymCounter[-1]

    def beginIf(self, cond):
        self.code += self.indent() + "if (" + cond + ")\n"
        self.beginBlock()

    def beginElse(self, cond=None):
        """Open an ``else`` block, or ``else if (cond)`` when *cond* is given."""
        if cond is not None:
            self.code += self.indent() + "else if (" + cond + ")\n"
        else:
            self.code += self.indent() + "else\n"
        self.beginBlock()

    def endElse(self):
        self.endBlock()

    def endIf(self):
        self.endBlock()

    def beginSwitch(self, switchvar):
        self.code += self.indent() + "switch (" + switchvar + ")\n"
        self.beginBlock()

    def switchCase(self, switchval, blocked=False):
        """Emit a ``case`` label and open its scope (braced only if *blocked*)."""
        self.code += self.indent() + "case %s:" % switchval
        self.beginBlock(bracketPrint=blocked)

    def switchCaseBreak(self, switchval, blocked=False):
        # NOTE(review): emits another `case` label (not `break;`) and then
        # closes the innermost scope — confirm this is what callers expect.
        self.code += self.indent() + "case %s:" % switchval
        self.endBlock(bracketPrint=blocked)

    def switchCaseDefault(self, blocked=False):
        """Emit a ``default:`` label and open its scope (braced only if *blocked*)."""
        # Previously formatted with an undefined `switchval`, raising NameError.
        self.code += self.indent() + "default:"
        self.beginBlock(bracketPrint=blocked)

    def endSwitch(self):
        self.endBlock()

    def beginWhile(self, cond):
        self.code += self.indent() + "while (" + cond + ")\n"
        self.beginBlock()

    def endWhile(self):
        self.endBlock()

    def beginFor(self, initial, condition, increment):
        header = "; ".join([initial, condition, increment])
        self.code += self.indent() + "for (" + header + ")\n"
        self.beginBlock()

    def endFor(self):
        self.endBlock()

    def beginLoop(self, loopVarType, loopVar, loopInit, loopBound):
        """Open ``for (T v = init; v < bound; ++v)``."""
        self.beginFor(
            "%s %s = %s" % (loopVarType, loopVar, loopInit),
            "%s < %s" % (loopVar, loopBound),
            "++%s" % (loopVar))

    def endLoop(self):
        self.endBlock()

    def stmt(self, code):
        """Emit *code* as an indented statement terminated by ";"."""
        self.code += self.indent() + code + ";\n"

    def line(self, code):
        """Emit *code* as an indented line."""
        self.code += self.indent() + code + "\n"

    def leftline(self, code):
        """Emit *code* as a line with no indentation."""
        self.code += code + "\n"

    def makeCallExpr(self, funcName, parameters):
        return funcName + "(%s)" % (", ".join(parameters))

    def funcCall(self, lhs, funcName, parameters):
        """Emit ``[lhs = ]funcName(params);``."""
        assignment = "" if lhs is None else lhs + " = "
        self.code += self.indent() + assignment + self.makeCallExpr(funcName, parameters) + ";\n"

    def funcCallRet(self, _lhs, funcName, parameters):
        """Emit ``return funcName(params);`` (the lhs argument is ignored)."""
        self.code += self.indent() + "return " + self.makeCallExpr(funcName, parameters) + ";\n"

    def makeCTypeDecl(self, vulkanType, useParamName=True):
        """Return a C declaration ``[const ][typename][*][const*] [paramName]``.

        A pointer-to-const-pointer renders as ``* const*`` when the type is
        const (``**`` otherwise), followed by any further ``*`` levels.
        """
        constness = "const " if vulkanType.isConst else ""
        levels = vulkanType.pointerIndirectionLevels

        if levels == 0:
            ptrSpec = ""
        elif vulkanType.isPointerToConstPointer:
            ptrSpec = "* const*" if vulkanType.isConst else "**"
            if levels > 2:
                ptrSpec += "*" * (levels - 2)
        else:
            ptrSpec = "*" * levels

        if useParamName and vulkanType.paramName is not None:
            paramStr = " " + vulkanType.paramName
        else:
            paramStr = ""

        return "%s%s%s%s" % (constness, vulkanType.typeName, ptrSpec, paramStr)

    def makeRichCTypeDecl(self, vulkanType, useParamName=True):
        """Like makeCTypeDecl, plus a ``[N]`` suffix for static arrays."""
        staticArrInfo = "[%s]" % vulkanType.staticArrExpr if vulkanType.staticArrExpr else ""
        return self.makeCTypeDecl(vulkanType, useParamName=useParamName) + staticArrInfo

    def makeFuncProto(self, vulkanApi, useParamName=True):
        """Return the C prototype ``<returntype> <funcname>(<parameters>)``."""
        protoBegin = "%s %s" % (
            self.makeCTypeDecl(vulkanApi.retType, useParamName=False), vulkanApi.name)
        separator = ",\n%s" % self.indent(1)
        paramDecls = separator.join(
            self.makeRichCTypeDecl(param, useParamName=useParamName)
            for param in vulkanApi.parameters)
        return protoBegin + "(\n    %s)" % paramDecls

    def makeFuncAlias(self, nameDst, nameSrc):
        return "DEFINE_ALIAS_FUNCTION({}, {})\n\n".format(nameSrc, nameDst)

    def makeFuncDecl(self, vulkanApi):
        return self.makeFuncProto(vulkanApi) + ";\n\n"

    def makeFuncImpl(self, vulkanApi, codegenFunc):
        """Return a full function definition whose body *codegenFunc* emits.

        Any code accumulated before this call is discarded.
        """
        self.swapCode()

        self.line(self.makeFuncProto(vulkanApi))
        self.beginBlock()
        codegenFunc(self)
        self.endBlock()

        return self.swapCode() + "\n"

    def emitFuncImpl(self, vulkanApi, codegenFunc):
        """Append a function definition whose body *codegenFunc* emits."""
        self.line(self.makeFuncProto(vulkanApi))
        self.beginBlock()
        codegenFunc(self)
        self.endBlock()

    def makeStructAccess(self,
                         vulkanType,
                         structVarName,
                         asPtr=True,
                         structAsPtr=True,
                         accessIndex=None):
        """Return an expression accessing member *vulkanType* of *structVarName*.

        A leading "&" is added when *asPtr* is set and the member is not
        already accessible as a pointer.
        """
        deref = "->" if structAsPtr else "."
        indexExpr = (" + %s" % accessIndex) if accessIndex else ""
        addrOfExpr = "" if vulkanType.accessibleAsPointer() or not asPtr else "&"

        return "%s%s%s%s%s" % (addrOfExpr, structVarName, deref,
                               vulkanType.paramName, indexExpr)

    def makeRawLengthAccess(self, vulkanType):
        """Return (length expression, guard) from the type's own length expression."""
        lenExpr = vulkanType.getLengthExpression()

        if not lenExpr:
            return None, None

        if lenExpr == "null-terminated":
            return "strlen(%s)" % vulkanType.paramName, None

        return lenExpr, None

    def makeLengthAccessFromStruct(self,
                                   structInfo,
                                   vulkanType,
                                   structVarName,
                                   asPtr=True):
        """Return (length expression, guard expression) for a struct member.

        The guard is the struct variable itself. Returns (None, None) when
        the member has no length.
        """
        deref = "->" if asPtr else "."

        special = self._STRUCT_LENGTH_SPECIAL_CASES.get(
            (structInfo.name, vulkanType.paramName))
        if special is not None:
            lenMember, postprocess = special
            expr = "%s%s%s" % (structVarName, deref, lenMember)
            return postprocess(expr), "%s" % structVarName

        lenExpr = vulkanType.getLengthExpression()

        if not lenExpr:
            return None, None

        lenAccessGuardExpr = "%s" % structVarName

        if lenExpr == "null-terminated":
            return "strlen(%s%s%s)" % (structVarName, deref,
                                       vulkanType.paramName), lenAccessGuardExpr

        # Length is not a sibling member: fall back to the raw expression.
        if not structInfo.getMember(lenExpr):
            return self.makeRawLengthAccess(vulkanType)

        return "%s%s%s" % (structVarName, deref, lenExpr), lenAccessGuardExpr

    def makeLengthAccessFromApi(self, api, vulkanType):
        """Return (length expression, guard expression) for an API parameter.

        Returns (None, None) when the parameter has no length.
        """
        lenExpr = vulkanType.getLengthExpression()

        if not lenExpr:
            return None, None

        # "struct::member" names a member of another parameter.
        if "::" in lenExpr:
            structVarName, memberVarName = lenExpr.split("::")
            return "%s->%s" % (structVarName, memberVarName), "%s" % structVarName

        if lenExpr == "null-terminated":
            # paramName is a string attribute; it was previously called.
            return "strlen(%s)" % vulkanType.paramName, None

        lenExprInfo = api.getParameter(lenExpr)

        if not lenExprInfo:
            return self.makeRawLengthAccess(vulkanType)

        # A length passed by pointer is dereferenced and guarded by that pointer.
        deref = "*" if lenExprInfo.pointerIndirectionLevels > 0 else ""
        lenAccessGuardExpr = "%s" % lenExpr if deref else None
        return "(%s(%s))" % (deref, lenExpr), lenAccessGuardExpr

    def accessParameter(self, param, asPtr=True):
        """Return the parameter name, taking its address if *asPtr* and it is not a pointer."""
        if not asPtr:
            return param.paramName
        if param.pointerIndirectionLevels > 0:
            return param.paramName
        return "&%s" % param.paramName

    def sizeofExpr(self, vulkanType):
        return "sizeof(%s)" % (
            self.makeCTypeDecl(vulkanType, useParamName=False))

    def generalAccess(self,
                      vulkanType,
                      parentVarName=None,
                      asPtr=True,
                      structAsPtr=True):
        """Return an expression accessing *vulkanType* according to its parent.

        Aborts the process if the parent kind is not recognized.
        """
        def accessByName():
            target = vulkanType if parentVarName is None \
                else vulkanType.withModifiedName(parentVarName)
            return self.accessParameter(target, asPtr=asPtr)

        if vulkanType.parent is None:
            return accessByName()

        if isinstance(vulkanType.parent, VulkanCompoundType):
            return self.makeStructAccess(
                vulkanType, parentVarName, asPtr=asPtr, structAsPtr=structAsPtr)

        if isinstance(vulkanType.parent, VulkanAPI):
            return accessByName()

        # os.abort() takes no arguments, so report the problem first.
        print("Could not find a way to access Vulkan type %s" %
              vulkanType.typeName)
        os.abort()

    def makeLengthAccess(self, vulkanType, parentVarName="parent"):
        """Return (length expression, guard expression) according to the type's parent.

        Aborts the process if the parent kind is not recognized.
        """
        if vulkanType.parent is None:
            return self.makeRawLengthAccess(vulkanType)

        if isinstance(vulkanType.parent, VulkanCompoundType):
            return self.makeLengthAccessFromStruct(
                vulkanType.parent, vulkanType, parentVarName, asPtr=True)

        if isinstance(vulkanType.parent, VulkanAPI):
            return self.makeLengthAccessFromApi(vulkanType.parent, vulkanType)

        print("Could not find a way to access length of Vulkan type %s" %
              vulkanType.typeName)
        os.abort()

    def generalLengthAccess(self, vulkanType, parentVarName="parent"):
        return self.makeLengthAccess(vulkanType, parentVarName)[0]

    def generalLengthAccessGuard(self, vulkanType, parentVarName="parent"):
        return self.makeLengthAccess(vulkanType, parentVarName)[1]

    def vkApiCall(self, api, customPrefix="", globalStatePrefix="", customParameters=None, checkForDeviceLost=False, checkForOutOfMemory=False):
        """Emit a call to ``customPrefix + api.name``.

        A non-void result is stored in a zero-initialized variable declared
        just before the call. Optional VkResult checks call
        ``<globalStatePrefix>DeviceLost()`` / ``<globalStatePrefix>CheckOutOfMemory(...)``.

        Returns:
            (return type expression, result variable name or None).
        """
        callLhs = None

        retTypeName = api.getRetTypeExpr()
        retVar = None

        if retTypeName != "void":
            retVar = api.getRetVarExpr()
            self.stmt("%s %s = (%s)0" % (retTypeName, retVar, retTypeName))
            callLhs = retVar

        if customParameters is None:
            self.funcCall(
                callLhs, customPrefix + api.name, [p.paramName for p in api.parameters])
        else:
            self.funcCall(
                callLhs, customPrefix + api.name, customParameters)

        if retTypeName == "VkResult" and checkForDeviceLost:
            self.stmt("if ((%s) == VK_ERROR_DEVICE_LOST) %sDeviceLost()" % (callLhs, globalStatePrefix))

        if retTypeName == "VkResult" and checkForOutOfMemory:
            if api.name == "vkAllocateMemory":
                self.stmt(
                    "%sCheckOutOfMemory(%s, opcode, context, std::make_optional<uint64_t>(pAllocateInfo->allocationSize))"
                    % (globalStatePrefix, callLhs))
            else:
                self.stmt(
                    "%sCheckOutOfMemory(%s, opcode, context)"
                    % (globalStatePrefix, callLhs))

        return (retTypeName, retVar)

    def makeCheckVkSuccess(self, expr):
        return "((%s) == VK_SUCCESS)" % expr

    def makeReinterpretCast(self, varName, typeName, const=True):
        return "reinterpret_cast<%s%s*>(%s)" % \
               ("const " if const else "", typeName, varName)

    def validPrimitive(self, typeInfo, typeName):
        """Return True if *typeInfo* knows an encoding size for *typeName*."""
        return typeInfo.getPrimitiveEncodingSize(typeName) is not None

    def _primitiveMethodName(self, typeInfo, typeName, prefix):
        """Return *prefix* + size suffix (e.g. "putBe32"), or None if unsupported."""
        if not self.validPrimitive(typeInfo, typeName):
            return None
        suffix = self._STREAM_SUFFIX_BY_SIZE.get(
            typeInfo.getPrimitiveEncodingSize(typeName))
        return prefix + suffix if suffix else None

    def makePrimitiveStreamMethod(self, typeInfo, typeName, direction="write"):
        """Return the stream method (put*/get*) for a primitive type, or None."""
        prefix = "put" if direction == "write" else "get"
        return self._primitiveMethodName(typeInfo, typeName, prefix)

    def makePrimitiveStreamMethodInPlace(self, typeInfo, typeName, direction="write"):
        """Return the in-place conversion method (to*/from*) for a primitive type, or None."""
        prefix = "to" if direction == "write" else "from"
        return self._primitiveMethodName(typeInfo, typeName, prefix)

    def _emitStorageVarDecl(self, storageType, storageVar, ptrCast, accessExpr, accessTypeName):
        """Emit ``<storageType> <storageVar> = (<storageType>)<ptrCast><accessExpr>;``."""
        if storageType is None:
            # Previously this path failed with an unbound local variable.
            print("Unsupported primitive encoding size for type: %s" % accessTypeName)
            os.abort()
        self.stmt("%s %s = (%s)%s%s" %
                  (storageType, storageVar, storageType, ptrCast, accessExpr))

    def streamPrimitive(self, typeInfo, streamVar, accessExpr, accessType, direction="write"):
        """Emit code that writes or reads a primitive through ``streamVar->put*/get*``.

        Pointers are sent as 64-bit integers. Aborts on non-primitive types.
        """
        accessTypeName = accessType.typeName

        if accessType.pointerIndirectionLevels == 0 and not self.validPrimitive(typeInfo, accessTypeName):
            print("Tried to stream a non-primitive type: %s" % accessTypeName)
            os.abort()

        isPointer = accessType.pointerIndirectionLevels > 0

        if isPointer:
            streamStorageVarType = "uint64_t"
            streamMethod = "putBe64" if direction == "write" else "getBe64"
        else:
            streamSize = typeInfo.getPrimitiveEncodingSize(accessTypeName)
            streamStorageVarType = self._STORAGE_TYPE_BY_SIZE.get(streamSize)
            streamMethod = self.makePrimitiveStreamMethod(
                typeInfo, accessTypeName, direction=direction)

        # Consumed in both directions so generated names stay in step.
        streamStorageVar = self.var()

        accessCast = self.makeRichCTypeDecl(accessType, useParamName=False)
        ptrCast = "(uintptr_t)" if isPointer else ""

        if direction == "read":
            self.stmt("%s = (%s)%s%s->%s()" %
                      (accessExpr, accessCast, ptrCast, streamVar, streamMethod))
        else:
            self._emitStorageVarDecl(
                streamStorageVarType, streamStorageVar, ptrCast, accessExpr, accessTypeName)
            self.stmt("%s->%s(%s)" %
                      (streamVar, streamMethod, streamStorageVar))

    def memcpyPrimitive(self, typeInfo, streamVar, accessExpr, accessType, direction="write"):
        """Emit code that memcpy's a primitive to/from the buffer ``streamVar``.

        The bytes are then converted in place via
        ``android::base::Stream::to*/from*``. Pointers are copied as 64-bit
        integers. Aborts on non-primitive types.
        """
        accessTypeName = accessType.typeName

        if accessType.pointerIndirectionLevels == 0 and not self.validPrimitive(typeInfo, accessTypeName):
            print("Tried to stream a non-primitive type: %s" % accessTypeName)
            os.abort()

        isPointer = accessType.pointerIndirectionLevels > 0

        if isPointer:
            streamSize = 8
            streamStorageVarType = "uint64_t"
            streamMethod = "toBe64" if direction == "write" else "fromBe64"
        else:
            streamSize = typeInfo.getPrimitiveEncodingSize(accessTypeName)
            streamStorageVarType = self._STORAGE_TYPE_BY_SIZE.get(streamSize)
            streamMethod = self.makePrimitiveStreamMethodInPlace(
                typeInfo, accessTypeName, direction=direction)

        # Consumed in both directions so generated names stay in step.
        streamStorageVar = self.var()

        if direction == "read":
            # Cast away const so the destination can be written.
            accessCast = self.makeRichCTypeDecl(
                accessType.getForNonConstAccess(), useParamName=False)
            self.stmt("memcpy((%s*)&%s, %s, %s)" %
                      (accessCast, accessExpr, streamVar, str(streamSize)))
            self.stmt("android::base::Stream::%s((uint8_t*)&%s)" % (
                streamMethod, accessExpr))
        else:
            ptrCast = "(uintptr_t)" if isPointer else ""
            self._emitStorageVarDecl(
                streamStorageVarType, streamStorageVar, ptrCast, accessExpr, accessTypeName)
            self.stmt("memcpy(%s, &%s, %s)" %
                      (streamVar, streamStorageVar, str(streamSize)))
            self.stmt("android::base::Stream::%s((uint8_t*)%s)" % (
                streamMethod, streamVar))

    def countPrimitive(self, typeInfo, accessType):
        """Return the encoded size in bytes of a primitive (8 for pointers)."""
        accessTypeName = accessType.typeName

        if accessType.pointerIndirectionLevels == 0 and not self.validPrimitive(typeInfo, accessTypeName):
            print("Tried to count a non-primitive type: %s" % accessTypeName)
            os.abort()

        if accessType.pointerIndirectionLevels > 0:
            return 8
        return typeInfo.getPrimitiveEncodingSize(accessTypeName)
825
826# Class to wrap a Vulkan API call.
827#
828# The user gives a generic callback, |codegenDef|,
829# that takes a CodeGen object and a VulkanAPI object as arguments.
830# codegenDef uses CodeGen along with the VulkanAPI object
831# to generate the function body.
class VulkanAPIWrapper(object):
    """Generates declarations and definitions for a prefixed copy of a Vulkan API.

    The wrapped API is a shallow copy of ``typeInfo.apis[apiName]``:
    - its name is prefixed with ``customApiPrefix``;
    - ``extraParameters`` (a list) are prepended to its parameters;
    - its return type is optionally replaced by ``returnTypeOverride``.

    The function body is produced by ``codegenDef(codegen, vulkanApi)``.
    """

    def __init__(self,
                 customApiPrefix,
                 extraParameters=None,
                 returnTypeOverride=None,
                 codegenDef=None):
        self.customApiPrefix = customApiPrefix
        # Must be a list when given; checked when the API copy is built.
        self.extraParameters = extraParameters
        self.returnTypeOverride = returnTypeOverride

        self.codegen = CodeGen()

        self.definitionFunc = codegenDef

        def makeApiFunc(self, typeInfo, apiName):
            """Return the customized shallow copy of ``typeInfo.apis[apiName]``."""
            customApi = copy(typeInfo.apis[apiName])
            customApi.name = self.customApiPrefix + customApi.name
            if self.extraParameters is not None:
                if not isinstance(self.extraParameters, list):
                    # os.abort() takes no arguments, so report the problem first.
                    print(
                        "Type of extra parameters to custom API not valid. Expected list, got %s" % type(
                            self.extraParameters))
                    os.abort()
                # List concatenation builds a new list, leaving the source API untouched.
                customApi.parameters = \
                    self.extraParameters + customApi.parameters

            if self.returnTypeOverride is not None:
                customApi.retType = self.returnTypeOverride
            return customApi

        # Stored as a plain function (not a bound method), so it is invoked
        # as self.makeApi(self, typeInfo, apiName).
        self.makeApi = makeApiFunc

    def setCodegenDef(self, codegenDefFunc):
        """Set the callback that emits the function body."""
        self.definitionFunc = codegenDefFunc

    def makeDecl(self, typeInfo, apiName):
        """Return the wrapped API's prototype terminated by ``;``."""
        return self.codegen.makeFuncProto(
            self.makeApi(self, typeInfo, apiName)) + ";\n\n"

    def makeDefinition(self, typeInfo, apiName, isStatic=False):
        """Return the wrapped API's full definition, optionally ``static``.

        Exits the process with status 1 if no definition callback is set.
        """
        vulkanApi = self.makeApi(self, typeInfo, apiName)

        # Check before touching the code buffer so a failure leaves it intact.
        if self.definitionFunc is None:
            print("ERROR: No definition found for (%s, %s)" %
                  (vulkanApi.name, self.customApiPrefix))
            sys.exit(1)

        # Discard any previously accumulated code before emitting the body.
        self.codegen.swapCode()
        self.codegen.beginBlock()
        self.definitionFunc(self.codegen, vulkanApi)
        self.codegen.endBlock()

        storage = "static " if isStatic else ""
        return storage + self.codegen.makeFuncProto(
            vulkanApi) + "\n" + self.codegen.swapCode() + "\n"
891
892# Base class for wrapping all Vulkan API objects.  These work with Vulkan
893# Registry generators and have gen* triggers.  They tend to contain
894# VulkanAPIWrapper objects to make it easier to generate the code.
895class VulkanWrapperGenerator(object):
896
897    def __init__(self, module: Module, typeInfo: VulkanTypeInfo):
898        self.module: Module = module
899        self.typeInfo: VulkanTypeInfo = typeInfo
900        self.extensionStructTypes = OrderedDict()
901
902    def onBegin(self):
903        pass
904
905    def onEnd(self):
906        pass
907
908    def onBeginFeature(self, featureName, featureType):
909        pass
910
911    def onFeatureNewCmd(self, cmdName):
912        pass
913
914    def onEndFeature(self):
915        pass
916
917    def onGenType(self, typeInfo, name, alias):
918        category = self.typeInfo.categoryOf(name)
919        if category in ["struct", "union"] and not alias:
920            structInfo = self.typeInfo.structs[name]
921            if structInfo.structExtendsExpr:
922                self.extensionStructTypes[name] = structInfo
923        pass
924
925    def onGenStruct(self, typeInfo, name, alias):
926        pass
927
928    def onGenGroup(self, groupinfo, groupName, alias=None):
929        pass
930
931    def onGenEnum(self, enuminfo, name, alias):
932        pass
933
934    def onGenCmd(self, cmdinfo, name, alias):
935        pass
936
    # Below Vulkan structure types may correspond to multiple Vulkan structs
    # due to a conflict between different Vulkan registries. In order to get
    # the correct Vulkan struct type, we need to check the type of its "root"
    # struct as well.
    #
    # Shape: {extension sType enum: {root sType enum or "default": struct name}}.
    # emitForEachStructExtension (when given a rootTypeVar) emits a nested
    # C switch on the root type for every extension enum listed here:
    #   - each inner key becomes a "case <key>:" label, except "default",
    #     which becomes the "default:" label;
    #   - each inner value is looked up in self.extensionStructTypes, so it
    #     must be a recorded extension struct name (otherwise KeyError).
    # NOTE(review): the *GOOGLE structs presumably come from the goldfish /
    # gfxstream registry and reuse these enum values — confirm.
    ROOT_TYPE_MAPPING = {
        "VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FRAGMENT_DENSITY_MAP_FEATURES_EXT": {
            "VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2": "VkPhysicalDeviceFragmentDensityMapFeaturesEXT",
            "VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO": "VkPhysicalDeviceFragmentDensityMapFeaturesEXT",
            "VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO": "VkImportColorBufferGOOGLE",
            "default": "VkPhysicalDeviceFragmentDensityMapFeaturesEXT",
        },
        "VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FRAGMENT_DENSITY_MAP_PROPERTIES_EXT": {
            "VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2": "VkPhysicalDeviceFragmentDensityMapPropertiesEXT",
            "VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO": "VkCreateBlobGOOGLE",
            "default": "VkPhysicalDeviceFragmentDensityMapPropertiesEXT",
        },
        "VK_STRUCTURE_TYPE_RENDER_PASS_FRAGMENT_DENSITY_MAP_CREATE_INFO_EXT": {
            "VK_STRUCTURE_TYPE_RENDER_PASS_CREATE_INFO": "VkRenderPassFragmentDensityMapCreateInfoEXT",
            "VK_STRUCTURE_TYPE_RENDER_PASS_CREATE_INFO_2": "VkRenderPassFragmentDensityMapCreateInfoEXT",
            "VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO": "VkImportBufferGOOGLE",
            "default": "VkRenderPassFragmentDensityMapCreateInfoEXT",
        },
    }
960
961    def emitForEachStructExtension(self, cgen, retType, triggerVar, forEachFunc, autoBreak=True, defaultEmit=None, nullEmit=None, rootTypeVar=None):
962        def readStructType(structTypeName, structVarName, cgen):
963            cgen.stmt("uint32_t %s = (uint32_t)%s(%s)" % \
964                (structTypeName, "goldfish_vk_struct_type", structVarName))
965
966        def castAsStruct(varName, typeName, const=True):
967            return "reinterpret_cast<%s%s*>(%s)" % \
968                   ("const " if const else "", typeName, varName)
969
970        def doDefaultReturn(cgen):
971            if retType.typeName == "void":
972                cgen.stmt("return")
973            else:
974                cgen.stmt("return (%s)0" % retType.typeName)
975
976        cgen.beginIf("!%s" % triggerVar.paramName)
977        if nullEmit is None:
978            doDefaultReturn(cgen)
979        else:
980            nullEmit(cgen)
981        cgen.endIf()
982
983        readStructType("structType", triggerVar.paramName, cgen)
984
985        cgen.line("switch(structType)")
986        cgen.beginBlock()
987
988        currFeature = None
989
990        for ext in self.extensionStructTypes.values():
991            if not currFeature:
992                cgen.leftline("#ifdef %s" % ext.feature)
993                currFeature = ext.feature
994
995            if currFeature and ext.feature != currFeature:
996                cgen.leftline("#endif")
997                cgen.leftline("#ifdef %s" % ext.feature)
998                currFeature = ext.feature
999
1000            enum = ext.structEnumExpr
1001            cgen.line("case %s:" % enum)
1002            cgen.beginBlock()
1003
1004            if rootTypeVar is not None and enum in VulkanWrapperGenerator.ROOT_TYPE_MAPPING:
1005                cgen.line("switch(%s)" % rootTypeVar.paramName)
1006                cgen.beginBlock()
1007                kv = VulkanWrapperGenerator.ROOT_TYPE_MAPPING[enum]
1008                for k in kv:
1009                    v = self.extensionStructTypes[kv[k]]
1010                    if k == "default":
1011                        cgen.line("%s:" % k)
1012                    else:
1013                        cgen.line("case %s:" % k)
1014                    cgen.beginBlock()
1015                    castedAccess = castAsStruct(
1016                        triggerVar.paramName, v.name, const=triggerVar.isConst)
1017                    forEachFunc(v, castedAccess, cgen)
1018                    cgen.line("break;")
1019                    cgen.endBlock()
1020                cgen.endBlock()
1021            else:
1022                castedAccess = castAsStruct(
1023                    triggerVar.paramName, ext.name, const=triggerVar.isConst)
1024                forEachFunc(ext, castedAccess, cgen)
1025
1026            if autoBreak:
1027                cgen.stmt("break")
1028            cgen.endBlock()
1029
1030        if currFeature:
1031            cgen.leftline("#endif")
1032
1033        cgen.line("default:")
1034        cgen.beginBlock()
1035        if defaultEmit is None:
1036            doDefaultReturn(cgen)
1037        else:
1038            defaultEmit(cgen)
1039        cgen.endBlock()
1040
1041        cgen.endBlock()
1042
1043    def emitForEachStructExtensionGeneral(self, cgen, forEachFunc, doFeatureIfdefs=False):
1044        currFeature = None
1045
1046        for (i, ext) in enumerate(self.extensionStructTypes.values()):
1047            if doFeatureIfdefs:
1048                if not currFeature:
1049                    cgen.leftline("#ifdef %s" % ext.feature)
1050                    currFeature = ext.feature
1051
1052                if currFeature and ext.feature != currFeature:
1053                    cgen.leftline("#endif")
1054                    cgen.leftline("#ifdef %s" % ext.feature)
1055                    currFeature = ext.feature
1056
1057            forEachFunc(i, ext, cgen)
1058
1059        if doFeatureIfdefs:
1060            if currFeature:
1061                cgen.leftline("#endif")
1062