• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright (c) 2018 The Android Open Source Project
# Copyright (c) 2018 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16from .vulkantypes import VulkanType, VulkanTypeInfo, VulkanCompoundType, VulkanAPI
17from collections import OrderedDict
18from copy import copy
19from pathlib import Path, PurePosixPath
20
21import os
22import sys
23import shutil
24import subprocess
25
# Class capturing a single generated file


class SingleFileModule(object):
    """A single generated output file (e.g. ``.h``, ``.cpp``, ``.py``, ``.proto``).

    Output is written in three phases: ``begin()`` opens the file and writes
    ``preamble``, ``append()`` writes body text, and ``end()`` writes
    ``postamble`` and closes the file.  When ``suppress`` is true every phase
    is a no-op and no file (or directory) is created.
    """

    def __init__(self, suffix, directory, basename, customAbsDir=None, suppress=False):
        """
        Args:
            suffix: Extension appended to ``basename``, including the dot.
            directory: Output subdirectory, joined onto the ``globalDir``
                passed to ``begin()``.  Ignored when ``customAbsDir`` is set.
            basename: File name without the suffix.
            customAbsDir: Optional directory used instead of
                ``globalDir``/``directory``.
            suppress: If true, nothing is written.
        """
        self.directory = directory
        self.basename = basename
        self.customAbsDir = customAbsDir
        self.suffix = suffix
        # Open file object between begin() and end(); kept after end() so
        # callers can still read ``file.name``.
        self.file = None

        self.preamble = ""
        self.postamble = ""

        self.suppress = suppress

    def begin(self, globalDir):
        """Create the output directory if needed, open (truncate) the file, write the preamble."""
        if self.suppress:
            return

        if self.customAbsDir:
            absDir = self.customAbsDir
        else:
            absDir = os.path.join(globalDir, self.directory)

        # Create the subdirectory, if needed, so generating into a fresh output
        # tree does not fail with FileNotFoundError.  An empty path means the
        # current directory, which always exists.
        if absDir:
            os.makedirs(absDir, exist_ok=True)

        filename = os.path.join(absDir, self.basename)

        self.file = open(filename + self.suffix, "w", encoding="utf-8")
        self.file.write(self.preamble)

    def append(self, toAppend):
        """Write ``toAppend`` to the open file (requires a prior ``begin()``)."""
        if self.suppress:
            return

        self.file.write(toAppend)

    def end(self):
        """Write the postamble and close the file."""
        if self.suppress:
            return

        self.file.write(self.postamble)
        self.file.close()

    def getMakefileSrcEntry(self):
        """Return this file's Makefile source-list entry; none by default."""
        return ""

    def getCMakeSrcEntry(self):
        """Return this file's CMake source-list entry; none by default."""
        return ""
75
# Class capturing a .cpp file and a .h file (a "C++ module")


class Module(object):
    """A generated C++ module: a ``.h`` header plus a ``.cpp`` implementation.

    Both files share ``directory``/``basename``.  ``implOnly`` suppresses the
    header and ``headerOnly`` suppresses the implementation.  ``end()`` runs
    ``clang-format -i --style=file`` over every file that was written.
    """

    def __init__(
            self, directory, basename, customAbsDir=None, suppress=False, implOnly=False,
            headerOnly=False, suppressFeatureGuards=False):
        self._headerFileModule = SingleFileModule(
            ".h", directory, basename, customAbsDir, suppress or implOnly)
        self._implFileModule = SingleFileModule(
            ".cpp", directory, basename, customAbsDir, suppress or headerOnly)

        self._headerOnly = headerOnly
        self._implOnly = implOnly

        self.directory = directory
        self.basename = basename
        self._customAbsDir = customAbsDir

        # Stored for generators to consult; not read within this class.
        self.suppressFeatureGuards = suppressFeatureGuards

    @property
    def suppress(self):
        """Write-only flag; reading it raises AttributeError."""
        raise AttributeError("suppress is write only")

    @suppress.setter
    def suppress(self, value: bool):
        # headerOnly/implOnly keep suppressing their file regardless of value.
        self._headerFileModule.suppress = self._implOnly or value
        self._implFileModule.suppress = self._headerOnly or value

    @property
    def headerPreamble(self) -> str:
        return self._headerFileModule.preamble

    @headerPreamble.setter
    def headerPreamble(self, value: str):
        self._headerFileModule.preamble = value

    @property
    def headerPostamble(self) -> str:
        return self._headerFileModule.postamble

    @headerPostamble.setter
    def headerPostamble(self, value: str):
        self._headerFileModule.postamble = value

    @property
    def implPreamble(self) -> str:
        return self._implFileModule.preamble

    @implPreamble.setter
    def implPreamble(self, value: str):
        self._implFileModule.preamble = value

    @property
    def implPostamble(self) -> str:
        return self._implFileModule.postamble

    @implPostamble.setter
    def implPostamble(self, value: str):
        self._implFileModule.postamble = value

    def getMakefileSrcEntry(self):
        """Return the ``.cpp`` entry for a Makefile source list (backslash-continued)."""
        if self._customAbsDir:
            return self.basename + ".cpp \\\n"
        dirName = self.directory
        baseName = self.basename
        joined = os.path.join(dirName, baseName)
        return "    " + joined + ".cpp \\\n"

    def getCMakeSrcEntry(self):
        """Return the ``.cpp`` entry for a CMake source list, always with '/' separators."""
        if self._customAbsDir:
            return "\n" + self.basename + ".cpp "
        dirName = Path(self.directory)
        baseName = Path(self.basename)
        joined = PurePosixPath(dirName / baseName)
        return "\n    " + str(joined) + ".cpp "

    def begin(self, globalDir):
        """Open both output files under ``globalDir`` (suppressed ones are skipped)."""
        self._headerFileModule.begin(globalDir)
        self._implFileModule.begin(globalDir)

    def appendHeader(self, toAppend):
        self._headerFileModule.append(toAppend)

    def appendImpl(self, toAppend):
        self._implFileModule.append(toAppend)

    def end(self):
        """Close both files, then clang-format every file that was written.

        Raises:
            RuntimeError: if a file needs formatting but ``clang-format`` is not
                on ``PATH``, or if ``clang-format`` exits with a non-zero status.
        """
        self._headerFileModule.end()
        self._implFileModule.end()

        # Header first, then implementation.
        written = [
            Path(fileModule.file.name)
            for fileModule in (self._headerFileModule, self._implFileModule)
            if not fileModule.suppress
        ]
        if not written:
            # Nothing was generated, so clang-format is not required.
            return

        clangFormat = shutil.which("clang-format")
        if clangFormat is None:
            raise RuntimeError(
                "clang-format was not found on PATH; it is required to format generated code")

        for filename in written:
            # Explicit check rather than `assert`: under `python -O` an assert
            # would skip running clang-format altogether.
            status = subprocess.call(
                [clangFormat, "-i", "--style=file", str(filename.resolve())])
            if status != 0:
                raise RuntimeError(
                    "clang-format exited with status %d on %s" % (status, filename))
181
182
class PyScript(SingleFileModule):
    """A single generated Python (``.py``) source file."""

    def __init__(self, directory, basename, customAbsDir=None, suppress=False):
        # Only the extension is fixed; every other argument is forwarded as-is.
        super(PyScript, self).__init__(
            suffix=".py",
            directory=directory,
            basename=basename,
            customAbsDir=customAbsDir,
            suppress=suppress,
        )
186
187
# Class capturing a .proto protobuf definition file
class Proto(SingleFileModule):
    """A single generated protobuf definition (``.proto``) file."""

    def __init__(self, directory, basename, customAbsDir=None, suppress=False):
        super().__init__(".proto", directory, basename, customAbsDir, suppress)

    def getMakefileSrcEntry(self):
        """Return the ``.proto`` entry for a Makefile source list (backslash-continued)."""
        if self.customAbsDir:
            return self.basename + ".proto \\\n"
        joined = os.path.join(self.directory, self.basename)
        return "    " + joined + ".proto \\\n"

    def getCMakeSrcEntry(self):
        """Return the ``.proto`` entry for a CMake source list.

        The path is rendered with '/' separators (matching
        ``Module.getCMakeSrcEntry``) because a backslash is an escape
        character in CMake.
        """
        if self.customAbsDir:
            return "\n" + self.basename + ".proto "
        joined = PurePosixPath(Path(self.directory) / Path(self.basename))
        return "\n    " + str(joined) + ".proto "
212
class CodeGen(object):
    """Builder for generated C/C++ source text.

    Text accumulates in ``self.code``.  Indentation is four spaces per level
    (``indentLevel``).  ``gensymCounter`` is a stack with one counter per open
    block: ``var()`` bumps the innermost counter and joins all non-negative
    counters into the name, so temporaries stay unique across nested blocks.
    """

    # Struct members whose length is not a plain member reference in the
    # registry (mostly latexmath expressions), as
    # (struct name, member name, length member, postprocess(lengthExpr)).
    _STRUCT_LENGTH_SPECIAL_CASES = (
        ("VkShaderModuleCreateInfo", "pCode", "codeSize",
         lambda expr: "(%s / 4)" % expr),
        ("VkPipelineMultisampleStateCreateInfo", "pSampleMask", "rasterizationSamples",
         lambda expr: "(((%s) + 31) / 32)" % expr),
        ("VkAccelerationStructureVersionInfoKHR", "pVersionData", "",
         lambda _: "2*VK_UUID_SIZE"),
    )

    # Primitive encoding size in bytes -> suffix of the stream method name.
    _STREAM_METHOD_SUFFIXES = {1: "Byte", 2: "Be16", 4: "Be32", 8: "Be64"}

    # Primitive encoding size in bytes -> unsigned C type used for temporaries.
    _STORAGE_TYPES = {1: "uint8_t", 2: "uint16_t", 4: "uint32_t", 8: "uint64_t"}

    def __init__(self):
        self.code = ""
        self.indentLevel = 0
        # -1 means no name has been issued at that depth yet; negative
        # counters are omitted from generated names.
        self.gensymCounter = [-1]

    def var(self, prefix="cgen_var"):
        """Return a fresh temporary name such as ``cgen_var_3`` or ``cgen_var_1_0``."""
        self.gensymCounter[-1] += 1
        suffix = "_".join(str(i) for i in self.gensymCounter if i >= 0)
        return "%s_%s" % (prefix, suffix)

    def swapCode(self):
        """Return the accumulated code and reset the buffer to empty."""
        res = self.code
        self.code = ""
        return res

    def indent(self, extra=0):
        """Return the whitespace for the current level plus ``extra`` levels."""
        return "    " * (self.indentLevel + extra)

    def incrIndent(self):
        self.indentLevel += 1

    def decrIndent(self):
        # Never drops below zero.
        if self.indentLevel > 0:
            self.indentLevel -= 1

    def beginBlock(self, bracketPrint=True):
        """Open a scope: optionally emit ``{``, indent, and push a name counter."""
        if bracketPrint:
            self.code += self.indent() + "{\n"
        self.indentLevel += 1
        self.gensymCounter.append(-1)

    def endBlock(self, bracketPrint=True):
        """Close the innermost scope: dedent, optionally emit ``}``, pop its name counter."""
        self.indentLevel -= 1
        if bracketPrint:
            self.code += self.indent() + "}\n"
        del self.gensymCounter[-1]

    def beginIf(self, cond):
        self.code += self.indent() + "if (" + cond + ")\n"
        self.beginBlock()

    def beginElse(self, cond=None):
        """Open an ``else`` block, or ``else if (cond)`` when ``cond`` is given."""
        if cond is not None:
            self.code += self.indent() + "else if (" + cond + ")\n"
        else:
            self.code += self.indent() + "else\n"
        self.beginBlock()

    def endElse(self):
        self.endBlock()

    def endIf(self):
        self.endBlock()

    def beginSwitch(self, switchvar):
        self.code += self.indent() + "switch (" + switchvar + ")\n"
        self.beginBlock()

    def switchCase(self, switchval, blocked=False):
        """Emit ``case switchval:`` and open a scope (braced only if ``blocked``)."""
        self.code += self.indent() + "case %s:" % switchval
        self.beginBlock(bracketPrint=blocked)

    def switchCaseBreak(self, switchval, blocked=False):
        # NOTE(review): emits a case label and then *closes* a scope; callers
        # presumably pair it with a scope opened earlier — confirm intent.
        self.code += self.indent() + "case %s:" % switchval
        self.endBlock(bracketPrint=blocked)

    def switchCaseDefault(self, blocked=False):
        """Emit ``default:`` and open a scope (braced only if ``blocked``)."""
        # The default label takes no value; formatting it with the undefined
        # name `switchval` used to raise NameError.
        self.code += self.indent() + "default:"
        self.beginBlock(bracketPrint=blocked)

    def endSwitch(self):
        self.endBlock()

    def beginWhile(self, cond):
        self.code += self.indent() + "while (" + cond + ")\n"
        self.beginBlock()

    def endWhile(self):
        self.endBlock()

    def beginFor(self, initial, condition, increment):
        self.code += self.indent() + "for (" + "; ".join([initial, condition, increment]) + ")\n"
        self.beginBlock()

    def endFor(self):
        self.endBlock()

    def beginLoop(self, loopVarType, loopVar, loopInit, loopBound):
        """Open ``for (loopVarType loopVar = loopInit; loopVar < loopBound; ++loopVar)``."""
        self.beginFor(
            "%s %s = %s" % (loopVarType, loopVar, loopInit),
            "%s < %s" % (loopVar, loopBound),
            "++%s" % loopVar)

    def endLoop(self):
        self.endBlock()

    def stmt(self, code):
        """Emit an indented statement terminated by ``;``."""
        self.code += self.indent() + code + ";\n"

    def line(self, code):
        """Emit an indented line with no ``;`` appended."""
        self.code += self.indent() + code + "\n"

    def leftline(self, code):
        """Emit a line at column 0, ignoring the current indentation."""
        self.code += code + "\n"

    def makeCallExpr(self, funcName, parameters):
        return funcName + "(%s)" % ", ".join(parameters)

    def funcCall(self, lhs, funcName, parameters):
        """Emit ``[lhs = ]funcName(parameters);``."""
        res = self.indent()
        if lhs is not None:
            res += lhs + " = "
        res += self.makeCallExpr(funcName, parameters) + ";\n"
        self.code += res

    def funcCallRet(self, _lhs, funcName, parameters):
        """Emit ``return funcName(parameters);`` (``_lhs`` is ignored)."""
        self.code += self.indent() + "return " + self.makeCallExpr(funcName, parameters) + ";\n"

    # Given a VulkanType object, generate a C type declaration
    # with optional parameter name:
    # [const] [typename][*][const*] [paramName]
    def makeCTypeDecl(self, vulkanType, useParamName=True):
        constness = "const " if vulkanType.isConst else ""
        levels = vulkanType.pointerIndirectionLevels

        if levels == 0:
            ptrSpec = ""
        elif vulkanType.isPointerToConstPointer:
            ptrSpec = "* const*" if vulkanType.isConst else "**"
            if levels > 2:
                ptrSpec += "*" * (levels - 2)
        else:
            ptrSpec = "*" * levels

        if useParamName and vulkanType.paramName is not None:
            paramStr = " " + vulkanType.paramName
        else:
            paramStr = ""

        return "%s%s%s%s" % (constness, vulkanType.typeName, ptrSpec, paramStr)

    def makeRichCTypeDecl(self, vulkanType, useParamName=True):
        """Like ``makeCTypeDecl``, plus a ``[staticArrExpr]`` suffix for static arrays."""
        decl = self.makeCTypeDecl(vulkanType, useParamName=useParamName)
        if vulkanType.staticArrExpr:
            decl += "[%s]" % vulkanType.staticArrExpr
        return decl

    # Given a VulkanAPI object, generate the C function prototype:
    # <returntype> <funcname>(<parameters>)
    def makeFuncProto(self, vulkanApi, useParamName=True):
        protoBegin = "%s %s" % (
            self.makeCTypeDecl(vulkanApi.retType, useParamName=False), vulkanApi.name)
        paramDecls = [
            self.makeRichCTypeDecl(param, useParamName=useParamName)
            for param in vulkanApi.parameters
        ]
        protoParams = "(\n    %s)" % (",\n%s" % self.indent(1)).join(paramDecls)
        return protoBegin + protoParams

    def makeFuncAlias(self, nameDst, nameSrc):
        return "DEFINE_ALIAS_FUNCTION({}, {})\n\n".format(nameSrc, nameDst)

    def makeFuncDecl(self, vulkanApi):
        return self.makeFuncProto(vulkanApi) + ";\n\n"

    def makeFuncImpl(self, vulkanApi, codegenFunc):
        """Return a full function definition as text; discards any pending code."""
        self.swapCode()
        self.emitFuncImpl(vulkanApi, codegenFunc)
        return self.swapCode() + "\n"

    def emitFuncImpl(self, vulkanApi, codegenFunc):
        """Append a function definition whose body is produced by ``codegenFunc(self)``."""
        self.line(self.makeFuncProto(vulkanApi))
        self.beginBlock()
        codegenFunc(self)
        self.endBlock()

    def makeStructAccess(self, vulkanType, structVarName, asPtr=True, structAsPtr=True,
                         accessIndex=None):
        """Return an expression for member ``vulkanType.paramName`` of ``structVarName``.

        ``&`` is prepended when a pointer is requested but the member is not
        accessible as one; ``+ accessIndex`` is appended for a truthy index.
        """
        deref = "->" if structAsPtr else "."
        indexExpr = (" + %s" % accessIndex) if accessIndex else ""
        addrOfExpr = "" if vulkanType.accessibleAsPointer() or not asPtr else "&"
        return "%s%s%s%s%s" % (addrOfExpr, structVarName, deref,
                               vulkanType.paramName, indexExpr)

    def makeRawLengthAccess(self, vulkanType):
        """Return ``(lengthExpr, None)`` from the type's own length expression.

        ``strlen(paramName)`` for null-terminated strings; ``(None, None)``
        when there is no length expression.
        """
        lenExpr = vulkanType.getLengthExpression()

        if not lenExpr:
            return None, None

        if lenExpr == "null-terminated":
            return "strlen(%s)" % vulkanType.paramName, None

        return lenExpr, None

    def makeLengthAccessFromStruct(self, structInfo, vulkanType, structVarName, asPtr=True):
        """Return ``(lengthExpr, guardExpr)`` for a struct member's length.

        ``guardExpr`` is the struct variable itself (presumably checked by
        callers before the length is read).
        """
        deref = "->" if asPtr else "."

        # Special cases first (mostly lengths given as latexmath in the registry).
        for structName, field, lenMember, postprocess in self._STRUCT_LENGTH_SPECIAL_CASES:
            if (structInfo.name, vulkanType.paramName) == (structName, field):
                expr = "%s%s%s" % (structVarName, deref, lenMember)
                return postprocess(expr), "%s" % structVarName

        lenExpr = vulkanType.getLengthExpression()
        if not lenExpr:
            return None, None

        lenAccessGuardExpr = "%s" % structVarName

        if lenExpr == "null-terminated":
            return "strlen(%s%s%s)" % (structVarName, deref,
                                       vulkanType.paramName), lenAccessGuardExpr

        # Length is not a sibling member: fall back to the raw expression.
        if not structInfo.getMember(lenExpr):
            return self.makeRawLengthAccess(vulkanType)

        return "%s%s%s" % (structVarName, deref, lenExpr), lenAccessGuardExpr

    def makeLengthAccessFromApi(self, api, vulkanType):
        """Return ``(lengthExpr, guardExpr)`` for an API parameter's length."""
        lenExpr = vulkanType.getLengthExpression()

        if not lenExpr:
            return None, None

        # "a::b" names member b of parameter a, accessed through a pointer.
        if "::" in lenExpr:
            structVarName, memberVarName = lenExpr.split("::")
            return "%s->%s" % (structVarName, memberVarName), "%s" % structVarName

        # Checked before the parameter lookup; this branch used to call the
        # string `paramName` and raise TypeError.
        if lenExpr == "null-terminated":
            return "strlen(%s)" % vulkanType.paramName, None

        lenExprInfo = api.getParameter(lenExpr)
        if not lenExprInfo:
            return self.makeRawLengthAccess(vulkanType)

        # Counts passed by pointer are dereferenced and guarded by that pointer.
        deref = "*" if lenExprInfo.pointerIndirectionLevels > 0 else ""
        lenAccessGuardExpr = "%s" % lenExpr if deref else None
        return "(%s(%s))" % (deref, lenExpr), lenAccessGuardExpr

    def accessParameter(self, param, asPtr=True):
        """Return ``param.paramName``, prefixed with ``&`` when a pointer is requested for a non-pointer."""
        if asPtr and param.pointerIndirectionLevels <= 0:
            return "&%s" % param.paramName
        return param.paramName

    def sizeofExpr(self, vulkanType):
        return "sizeof(%s)" % self.makeCTypeDecl(vulkanType, useParamName=False)

    def generalAccess(self, vulkanType, parentVarName=None, asPtr=True, structAsPtr=True):
        """Return an access expression for ``vulkanType`` based on its parent kind.

        Raises:
            TypeError: if the parent is neither None, a struct, nor an API.
        """
        def accessAsParameter():
            named = vulkanType if parentVarName is None else vulkanType.withModifiedName(parentVarName)
            return self.accessParameter(named, asPtr=asPtr)

        parent = vulkanType.parent

        if parent is None:
            return accessAsParameter()

        if isinstance(parent, VulkanCompoundType):
            return self.makeStructAccess(
                vulkanType, parentVarName, asPtr=asPtr, structAsPtr=structAsPtr)

        if isinstance(parent, VulkanAPI):
            return accessAsParameter()

        # os.abort() accepts no message, so the old call raised an unrelated
        # TypeError; raise a descriptive one instead.
        raise TypeError("Could not find a way to access Vulkan type %s %s" %
                        (vulkanType.typeName, vulkanType.paramName))

    def makeLengthAccess(self, vulkanType, parentVarName="parent"):
        """Return ``(lengthExpr, guardExpr)`` for ``vulkanType`` based on its parent kind.

        Raises:
            TypeError: if the parent is neither None, a struct, nor an API.
        """
        parent = vulkanType.parent

        if parent is None:
            return self.makeRawLengthAccess(vulkanType)

        if isinstance(parent, VulkanCompoundType):
            return self.makeLengthAccessFromStruct(
                parent, vulkanType, parentVarName, asPtr=True)

        if isinstance(parent, VulkanAPI):
            return self.makeLengthAccessFromApi(parent, vulkanType)

        raise TypeError("Could not find a way to access length of Vulkan type %s %s" %
                        (vulkanType.typeName, vulkanType.paramName))

    def generalLengthAccess(self, vulkanType, parentVarName="parent"):
        return self.makeLengthAccess(vulkanType, parentVarName)[0]

    def generalLengthAccessGuard(self, vulkanType, parentVarName="parent"):
        return self.makeLengthAccess(vulkanType, parentVarName)[1]

    def vkApiCall(self, api, customPrefix="", globalStatePrefix="", customParameters=None,
                  checkForDeviceLost=False):
        """Emit a call to ``customPrefix + api.name``; return ``(retTypeName, retVar)``.

        A non-void result is stored in a zero-initialized local.  With
        ``checkForDeviceLost`` and a ``VkResult`` return, also emits a call to
        ``<globalStatePrefix>DeviceLost()`` on ``VK_ERROR_DEVICE_LOST``.
        """
        retTypeName = api.getRetTypeExpr()
        retVar = None

        if retTypeName != "void":
            retVar = api.getRetVarExpr()
            self.stmt("%s %s = (%s)0" % (retTypeName, retVar, retTypeName))

        if customParameters is None:
            parameters = [p.paramName for p in api.parameters]
        else:
            parameters = customParameters
        self.funcCall(retVar, customPrefix + api.name, parameters)

        if retTypeName == "VkResult" and checkForDeviceLost:
            self.stmt("if ((%s) == VK_ERROR_DEVICE_LOST) %sDeviceLost()" % (retVar, globalStatePrefix))

        return (retTypeName, retVar)

    def makeCheckVkSuccess(self, expr):
        return "((%s) == VK_SUCCESS)" % expr

    def makeReinterpretCast(self, varName, typeName, const=True):
        return "reinterpret_cast<%s%s*>(%s)" % ("const " if const else "", typeName, varName)

    def validPrimitive(self, typeInfo, typeName):
        """Return True if ``typeInfo`` knows an encoding size for ``typeName``."""
        return typeInfo.getPrimitiveEncodingSize(typeName) is not None

    def _primitiveStreamSuffix(self, typeInfo, typeName):
        """Return the Byte/Be16/Be32/Be64 method suffix for ``typeName``, or None."""
        if not self.validPrimitive(typeInfo, typeName):
            return None
        return self._STREAM_METHOD_SUFFIXES.get(typeInfo.getPrimitiveEncodingSize(typeName))

    def makePrimitiveStreamMethod(self, typeInfo, typeName, direction="write"):
        """Return ``put<Suffix>`` (write) or ``get<Suffix>`` (otherwise), or None."""
        suffix = self._primitiveStreamSuffix(typeInfo, typeName)
        if suffix is None:
            return None
        return ("put" if direction == "write" else "get") + suffix

    def makePrimitiveStreamMethodInPlace(self, typeInfo, typeName, direction="write"):
        """Return ``to<Suffix>`` (write) or ``from<Suffix>`` (otherwise), or None."""
        suffix = self._primitiveStreamSuffix(typeInfo, typeName)
        if suffix is None:
            return None
        return ("to" if direction == "write" else "from") + suffix

    def streamPrimitive(self, typeInfo, streamVar, accessExpr, accessType, direction="write"):
        """Emit code moving one primitive (or pointer) through ``streamVar->`` methods.

        Pointers travel as 8 bytes through a ``uintptr_t`` cast.  Prints an
        error and aborts the process for a non-pointer, non-primitive type.
        """
        accessTypeName = accessType.typeName
        isPointer = accessType.pointerIndirectionLevels > 0

        if not isPointer and not self.validPrimitive(typeInfo, accessTypeName):
            print("Tried to stream a non-primitive type: %s" % accessTypeName)
            os.abort()

        if isPointer:
            streamSize = 8
            streamMethod = "putBe64" if direction == "write" else "getBe64"
        else:
            streamSize = typeInfo.getPrimitiveEncodingSize(accessTypeName)
            streamMethod = self.makePrimitiveStreamMethod(
                typeInfo, accessTypeName, direction=direction)

        # A temporary name is allocated in both directions; skipping it for
        # reads would renumber every later temporary.
        streamStorageVar = self.var()
        accessCast = self.makeRichCTypeDecl(accessType, useParamName=False)
        ptrCast = "(uintptr_t)" if isPointer else ""

        if direction == "read":
            self.stmt("%s = (%s)%s%s->%s()" %
                      (accessExpr, accessCast, ptrCast, streamVar, streamMethod))
        else:
            streamStorageVarType = self._STORAGE_TYPES[streamSize]
            self.stmt("%s %s = (%s)%s%s" %
                      (streamStorageVarType, streamStorageVar,
                       streamStorageVarType, ptrCast, accessExpr))
            self.stmt("%s->%s(%s)" % (streamVar, streamMethod, streamStorageVar))

    def memcpyPrimitive(self, typeInfo, streamVar, accessExpr, accessType, direction="write"):
        """Emit memcpy-based code moving one primitive (or pointer) via the ``streamVar`` buffer.

        Byte order is converted in place with ``android::base::Stream::to*/from*``.
        Prints an error and aborts the process for a non-pointer,
        non-primitive type.
        """
        accessTypeName = accessType.typeName
        isPointer = accessType.pointerIndirectionLevels > 0

        if not isPointer and not self.validPrimitive(typeInfo, accessTypeName):
            print("Tried to stream a non-primitive type: %s" % accessTypeName)
            os.abort()

        if isPointer:
            streamSize = 8
            streamMethod = "toBe64" if direction == "write" else "fromBe64"
        else:
            streamSize = typeInfo.getPrimitiveEncodingSize(accessTypeName)
            streamMethod = self.makePrimitiveStreamMethodInPlace(
                typeInfo, accessTypeName, direction=direction)

        # Allocated in both directions to keep temporary numbering stable.
        streamStorageVar = self.var()
        ptrCast = "(uintptr_t)" if isPointer else ""

        if direction == "read":
            accessCast = self.makeRichCTypeDecl(
                accessType.getForNonConstAccess(), useParamName=False)
            self.stmt("memcpy((%s*)&%s, %s, %s)" %
                      (accessCast, accessExpr, streamVar, str(streamSize)))
            self.stmt("android::base::Stream::%s((uint8_t*)&%s)" % (streamMethod, accessExpr))
        else:
            streamStorageVarType = self._STORAGE_TYPES[streamSize]
            self.stmt("%s %s = (%s)%s%s" %
                      (streamStorageVarType, streamStorageVar,
                       streamStorageVarType, ptrCast, accessExpr))
            self.stmt("memcpy(%s, &%s, %s)" % (streamVar, streamStorageVar, str(streamSize)))
            self.stmt("android::base::Stream::%s((uint8_t*)%s)" % (streamMethod, streamVar))

    def countPrimitive(self, typeInfo, accessType):
        """Return the encoded byte size of ``accessType`` (8 for any pointer).

        Prints an error and aborts the process for a non-pointer,
        non-primitive type.
        """
        accessTypeName = accessType.typeName

        if accessType.pointerIndirectionLevels > 0:
            return 8

        if not self.validPrimitive(typeInfo, accessTypeName):
            print("Tried to count a non-primitive type: %s" % accessTypeName)
            os.abort()

        return typeInfo.getPrimitiveEncodingSize(accessTypeName)
815
# Class to wrap a Vulkan API call.
#
# The user supplies a generic callback, |codegenDef|, that takes a CodeGen
# object and a VulkanAPI object as arguments and uses the CodeGen together
# with the VulkanAPI object to generate the function body.
class VulkanAPIWrapper(object):
    """Generates a customized declaration/definition of a registry Vulkan API.

    The generated function is a shallow copy of the registry API with
    ``customApiPrefix`` prepended to its name, ``extraParameters`` (a list)
    prepended to its parameters, and its return type replaced by
    ``returnTypeOverride`` when given.  The body comes from the callback
    ``codegenDef(codegen, vulkanApi)``.
    """

    def __init__(self,
                 customApiPrefix,
                 extraParameters=None,
                 returnTypeOverride=None,
                 codegenDef=None):
        self.customApiPrefix = customApiPrefix
        self.extraParameters = extraParameters
        self.returnTypeOverride = returnTypeOverride

        self.codegen = CodeGen()

        self.definitionFunc = codegenDef

        # Kept as a plain function attribute for compatibility: callers invoke
        # it as ``self.makeApi(self, typeInfo, apiName)``.
        self.makeApi = VulkanAPIWrapper._makeCustomApi

    def _makeCustomApi(self, typeInfo, apiName):
        """Return the customized shallow copy of ``typeInfo.apis[apiName]``.

        Raises:
            TypeError: if ``extraParameters`` is set but is not a list.
        """
        customApi = copy(typeInfo.apis[apiName])
        customApi.name = self.customApiPrefix + customApi.name

        if self.extraParameters is not None:
            if not isinstance(self.extraParameters, list):
                # os.abort() accepts no message (it raised an unrelated
                # TypeError); raise a descriptive TypeError instead.
                raise TypeError(
                    "Type of extra parameters to custom API not valid. Expected list, got %s" %
                    type(self.extraParameters))
            # `+` builds a new list, so the registry API's parameters are not mutated.
            customApi.parameters = self.extraParameters + customApi.parameters

        if self.returnTypeOverride is not None:
            customApi.retType = self.returnTypeOverride
        return customApi

    def setCodegenDef(self, codegenDefFunc):
        self.definitionFunc = codegenDefFunc

    def makeDecl(self, typeInfo, apiName):
        """Return the customized function declaration (prototype followed by ``;``)."""
        return self.codegen.makeFuncProto(
            self.makeApi(self, typeInfo, apiName)) + ";\n\n"

    def makeDefinition(self, typeInfo, apiName, isStatic=False):
        """Return the customized function definition, optionally ``static``.

        Prints an error and exits with status 1 if no codegen callback is set.
        """
        vulkanApi = self.makeApi(self, typeInfo, apiName)

        # Checked before touching the codegen buffer so its state is not left
        # half-built.
        if self.definitionFunc is None:
            print("ERROR: No definition found for (%s, %s)" %
                  (vulkanApi.name, self.customApiPrefix))
            sys.exit(1)

        self.codegen.swapCode()
        self.codegen.beginBlock()
        self.definitionFunc(self.codegen, vulkanApi)
        self.codegen.endBlock()

        return ("static " if isStatic else "") + self.codegen.makeFuncProto(
            vulkanApi) + "\n" + self.codegen.swapCode() + "\n"
881
882# Base class for wrapping all Vulkan API objects.  These work with Vulkan
883# Registry generators and have gen* triggers.  They tend to contain
884# VulkanAPIWrapper objects to make it easier to generate the code.
885class VulkanWrapperGenerator(object):
886
887    def __init__(self, module: Module, typeInfo: VulkanTypeInfo):
888        self.module: Module = module
889        self.typeInfo: VulkanTypeInfo = typeInfo
890        self.extensionStructTypes = OrderedDict()
891
892    def onBegin(self):
893        pass
894
895    def onEnd(self):
896        pass
897
898    def onBeginFeature(self, featureName, featureType):
899        pass
900
901    def onFeatureNewCmd(self, cmdName):
902        pass
903
904    def onEndFeature(self):
905        pass
906
907    def onGenType(self, typeInfo, name, alias):
908        category = self.typeInfo.categoryOf(name)
909        if category in ["struct", "union"] and not alias:
910            structInfo = self.typeInfo.structs[name]
911            if structInfo.structExtendsExpr:
912                self.extensionStructTypes[name] = structInfo
913        pass
914
915    def onGenStruct(self, typeInfo, name, alias):
916        pass
917
918    def onGenGroup(self, groupinfo, groupName, alias=None):
919        pass
920
921    def onGenEnum(self, enuminfo, name, alias):
922        pass
923
924    def onGenCmd(self, cmdinfo, name, alias):
925        pass
926
    # Each Vulkan structure type below may correspond to multiple Vulkan
    # structs because of conflicts between different Vulkan registries. To
    # pick the correct Vulkan struct type, the type of its "root" struct has
    # to be checked as well.
    #
    # Outer key: the extension struct's sType enum. Inner mapping: the root
    # struct's sType (or "default" for any other root) -> name of the struct
    # type to cast to. emitForEachStructExtension() emits one nested case per
    # inner entry, in dict order, and looks each struct name up in
    # self.extensionStructTypes (a missing name raises KeyError there).
    ROOT_TYPE_MAPPING = {
        "VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FRAGMENT_DENSITY_MAP_FEATURES_EXT": {
            "VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2": "VkPhysicalDeviceFragmentDensityMapFeaturesEXT",
            "VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO": "VkPhysicalDeviceFragmentDensityMapFeaturesEXT",
            "VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO": "VkImportColorBufferGOOGLE",
            "default": "VkPhysicalDeviceFragmentDensityMapFeaturesEXT",
        },
        "VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FRAGMENT_DENSITY_MAP_PROPERTIES_EXT": {
            "VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2": "VkPhysicalDeviceFragmentDensityMapPropertiesEXT",
            "VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO": "VkImportPhysicalAddressGOOGLE",
            "default": "VkPhysicalDeviceFragmentDensityMapPropertiesEXT",
        },
        "VK_STRUCTURE_TYPE_RENDER_PASS_FRAGMENT_DENSITY_MAP_CREATE_INFO_EXT": {
            "VK_STRUCTURE_TYPE_RENDER_PASS_CREATE_INFO": "VkRenderPassFragmentDensityMapCreateInfoEXT",
            "VK_STRUCTURE_TYPE_RENDER_PASS_CREATE_INFO_2": "VkRenderPassFragmentDensityMapCreateInfoEXT",
            "VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO": "VkImportBufferGOOGLE",
            "default": "VkRenderPassFragmentDensityMapCreateInfoEXT",
        },
    }
950
    def emitForEachStructExtension(self, cgen, retType, triggerVar, forEachFunc, autoBreak=True, defaultEmit=None, nullEmit=None, rootTypeVar=None):
        """Emit a C++ switch over the sType of an extension struct pointer.

        The generated code:
          1. handles a null |triggerVar| with nullEmit(cgen), or by default
             returns early (`return` for void, `return (T)0` otherwise);
          2. reads the struct type into `structType` via
             goldfish_vk_struct_type();
          3. switches on `structType` with one case per entry of
             self.extensionStructTypes. Each case reinterpret_casts
             |triggerVar| to the concrete struct pointer and calls
             forEachFunc(structInfo, castedAccess, cgen) to emit the case
             body. Consecutive cases sharing the same feature are wrapped in
             one `#ifdef <feature>` ... `#endif` region;
          4. emits a `default:` case using defaultEmit(cgen), or the same
             early return as in step 1.

        Args:
            cgen: code emitter (presumably a CodeGen); stmt, line, leftline,
                beginBlock, endBlock, beginIf and endIf are used.
            retType: return type of the generated function; only .typeName
                is read.
            triggerVar: parameter holding the struct pointer; .paramName and
                .isConst are read. The cast target is `const T*` when
                .isConst is truthy, otherwise `T*`.
            forEachFunc: callback (structInfo, castedAccess, cgen) that emits
                the body of one case.
            autoBreak: if true, emit `break` after every case body.
            defaultEmit: optional callback(cgen) for the `default:` case.
            nullEmit: optional callback(cgen) for the null-pointer branch.
            rootTypeVar: optional parameter holding the root struct's sType.
                When given, sTypes listed in ROOT_TYPE_MAPPING get a nested
                switch on it that selects which struct to cast to.
        """
        # Emits: uint32_t <structTypeName> = (uint32_t)goldfish_vk_struct_type(<structVarName>)
        def readStructType(structTypeName, structVarName, cgen):
            cgen.stmt("uint32_t %s = (uint32_t)%s(%s)" % \
                (structTypeName, "goldfish_vk_struct_type", structVarName))

        # Returns the text of a reinterpret_cast of |varName| to |typeName|*.
        def castAsStruct(varName, typeName, const=True):
            return "reinterpret_cast<%s%s*>(%s)" % \
                   ("const " if const else "", typeName, varName)

        # Emits an early return suitable for |retType|.
        def doDefaultReturn(cgen):
            if retType.typeName == "void":
                cgen.stmt("return")
            else:
                cgen.stmt("return (%s)0" % retType.typeName)

        cgen.beginIf("!%s" % triggerVar.paramName)
        if nullEmit is None:
            doDefaultReturn(cgen)
        else:
            nullEmit(cgen)
        cgen.endIf()

        readStructType("structType", triggerVar.paramName, cgen)

        cgen.line("switch(structType)")
        cgen.beginBlock()

        # Feature of the currently open #ifdef region, if any.
        currFeature = None

        for ext in self.extensionStructTypes.values():
            if not currFeature:
                cgen.leftline("#ifdef %s" % ext.feature)
                currFeature = ext.feature

            # Feature changed: close the previous region and open a new one.
            if currFeature and ext.feature != currFeature:
                cgen.leftline("#endif")
                cgen.leftline("#ifdef %s" % ext.feature)
                currFeature = ext.feature

            enum = ext.structEnumExpr
            cgen.line("case %s:" % enum)
            cgen.beginBlock()

            if rootTypeVar is not None and enum in VulkanWrapperGenerator.ROOT_TYPE_MAPPING:
                # Ambiguous sType: disambiguate with a nested switch on the
                # root struct's sType. Every mapped struct name must be in
                # self.extensionStructTypes, otherwise this raises KeyError.
                cgen.line("switch(%s)" % rootTypeVar.paramName)
                cgen.beginBlock()
                kv = VulkanWrapperGenerator.ROOT_TYPE_MAPPING[enum]
                for k in kv:
                    v = self.extensionStructTypes[kv[k]]
                    if k == "default":
                        cgen.line("%s:" % k)
                    else:
                        cgen.line("case %s:" % k)
                    cgen.beginBlock()
                    castedAccess = castAsStruct(
                        triggerVar.paramName, v.name, const=triggerVar.isConst)
                    forEachFunc(v, castedAccess, cgen)
                    # Always leave the nested switch; the outer `break` (if
                    # autoBreak) is emitted after the nested switch below.
                    cgen.line("break;")
                    cgen.endBlock()
                cgen.endBlock()
            else:
                castedAccess = castAsStruct(
                    triggerVar.paramName, ext.name, const=triggerVar.isConst)
                forEachFunc(ext, castedAccess, cgen)

            if autoBreak:
                cgen.stmt("break")
            cgen.endBlock()

        # Close the last open #ifdef region before the default case.
        if currFeature:
            cgen.leftline("#endif")

        cgen.line("default:")
        cgen.beginBlock()
        if defaultEmit is None:
            doDefaultReturn(cgen)
        else:
            defaultEmit(cgen)
        cgen.endBlock()

        cgen.endBlock()
1032
1033    def emitForEachStructExtensionGeneral(self, cgen, forEachFunc, doFeatureIfdefs=False):
1034        currFeature = None
1035
1036        for (i, ext) in enumerate(self.extensionStructTypes.values()):
1037            if doFeatureIfdefs:
1038                if not currFeature:
1039                    cgen.leftline("#ifdef %s" % ext.feature)
1040                    currFeature = ext.feature
1041
1042                if currFeature and ext.feature != currFeature:
1043                    cgen.leftline("#endif")
1044                    cgen.leftline("#ifdef %s" % ext.feature)
1045                    currFeature = ext.feature
1046
1047            forEachFunc(i, ext, cgen)
1048
1049        if doFeatureIfdefs:
1050            if currFeature:
1051                cgen.leftline("#endif")
1052