• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright (c) 2018 The Android Open Source Project
2# Copyright (c) 2018 Google Inc.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16from .vulkantypes import VulkanType, VulkanTypeInfo, VulkanCompoundType, VulkanAPI
17from collections import OrderedDict
18from copy import copy
19
20import os
21import sys
22
# Class capturing a .cpp file and a .h file (a "C++ module").
#
# Typical use: fill in the pre/postamble strings, call begin() to open the
# output file(s) and write the preambles, stream text in with appendHeader()
# and appendImpl(), then call end() to write the postambles and close.
class Module(object):

    def __init__(self, directory, basename, customAbsDir = None, suppress = False, implOnly = False):
        """Describe one generated C++ module.

        Args:
            directory: Subdirectory (relative to the ``globalDir`` passed to
                begin()) holding the files; also used in build-file entries.
            basename: File name without extension (".h"/".cpp" are appended).
            customAbsDir: When truthy, files are written directly into this
                directory and build-file entries list only the bare file name.
            suppress: When True, begin()/appendHeader()/appendImpl()/end()
                do nothing.
            implOnly: When True, only the .cpp file is written.
        """
        self.directory = directory
        self.basename = basename

        self.headerPreamble = ""
        self.implPreamble = ""

        self.headerPostamble = ""
        self.implPostamble = ""

        # Replaced by open file objects in begin() (header stays None when
        # implOnly is set).
        self.headerFileHandle = ""
        self.implFileHandle = ""

        self.customAbsDir = customAbsDir

        self.suppress = suppress

        self.implOnly = implOnly

    def _srcEntry(self, suffix):
        # With a custom absolute directory only the bare file name is listed
        # (unindented); otherwise the indented directory-relative path.
        if self.customAbsDir:
            return self.basename + suffix
        return "    " + os.path.join(self.directory, self.basename) + suffix

    def getMakefileSrcEntry(self):
        """Return this module's .cpp entry for a Makefile source list."""
        return self._srcEntry(".cpp \\\n")

    def getCMakeSrcEntry(self):
        """Return this module's .cpp entry for a CMake source list."""
        return self._srcEntry(".cpp ")

    def begin(self, globalDir):
        """Open the output file(s) and write the preamble(s).

        Files go into ``customAbsDir`` if set, else ``globalDir/directory``.
        The target directory is not created here; it must already exist.
        """
        if self.suppress:
            return

        if self.customAbsDir:
            absDir = self.customAbsDir
        else:
            absDir = os.path.join(globalDir, self.directory)

        filename = os.path.join(absDir, self.basename)

        fpHeader = None
        if not self.implOnly:
            fpHeader = open(filename + ".h", "w", encoding="utf-8")

        try:
            fpImpl = open(filename + ".cpp", "w", encoding="utf-8")
        except BaseException:
            # Don't leak the already-opened header if the .cpp can't be opened.
            if fpHeader is not None:
                fpHeader.close()
            raise

        self.headerFileHandle = fpHeader
        self.implFileHandle = fpImpl

        if not self.implOnly:
            self.headerFileHandle.write(self.headerPreamble)

        self.implFileHandle.write(self.implPreamble)

    def appendHeader(self, toAppend):
        """Write ``toAppend`` to the .h file (no-op if suppressed or implOnly)."""
        if self.suppress:
            return

        if not self.implOnly:
            self.headerFileHandle.write(toAppend)

    def appendImpl(self, toAppend):
        """Write ``toAppend`` to the .cpp file (no-op if suppressed)."""
        if self.suppress:
            return

        self.implFileHandle.write(toAppend)

    def end(self):
        """Write the postamble(s) and close the file(s) opened by begin()."""
        if self.suppress:
            return

        try:
            if not self.implOnly:
                self.headerFileHandle.write(self.headerPostamble)

            self.implFileHandle.write(self.implPostamble)
        finally:
            # Close even if a postamble write failed.
            if not self.implOnly:
                self.headerFileHandle.close()

            self.implFileHandle.close()
114
# Class capturing a single generated .proto (protobuf definition) file.
#
# Set preamble/postamble, call begin() to open the file, append() text,
# then end() to write the postamble and close it.
class Proto(object):

    def __init__(self, directory, basename, customAbsDir = None, suppress = False):
        """Describe one generated .proto file.

        Args:
            directory: Subdirectory relative to the ``globalDir`` given to
                begin(); also used in build-file entries.
            basename: File name without the ".proto" extension.
            customAbsDir: When truthy, the file is written directly into this
                directory and build-file entries list only the bare name.
            suppress: When True, begin()/append()/end() do nothing.
        """
        self.directory = directory
        self.basename = basename
        self.customAbsDir = customAbsDir

        # Text written right after opening / right before closing the file.
        self.preamble = ""
        self.postamble = ""

        self.suppress = suppress

    def getMakefileSrcEntry(self):
        """Return this file's entry for a Makefile source list."""
        if not self.customAbsDir:
            relPath = os.path.join(self.directory, self.basename)
            return "    " + relPath + ".proto \\\n"
        return self.basename + ".proto \\\n"

    def getCMakeSrcEntry(self):
        """Return this file's entry for a CMake source list."""
        if not self.customAbsDir:
            relPath = os.path.join(self.directory, self.basename)
            return "    " + relPath + ".proto "
        return self.basename + ".proto "

    def begin(self, globalDir):
        """Open the .proto file for writing and emit the preamble.

        The target directory must already exist; it is not created here.
        """
        if self.suppress:
            return

        targetDir = self.customAbsDir or os.path.join(globalDir, self.directory)
        protoPath = os.path.join(targetDir, self.basename) + ".proto"
        self.protoFileHandle = open(protoPath, "w", encoding="utf-8")
        self.protoFileHandle.write(self.preamble)

    def append(self, toAppend):
        """Write ``toAppend`` to the open .proto file (no-op if suppressed)."""
        if not self.suppress:
            self.protoFileHandle.write(toAppend)

    def end(self):
        """Emit the postamble and close the file (no-op if suppressed)."""
        if not self.suppress:
            self.protoFileHandle.write(self.postamble)
            self.protoFileHandle.close()
173
class CodeGen(object):
    """Accumulating emitter for generated C/C++ source text.

    Emitted text is appended to ``self.code`` and retrieved (and cleared)
    with ``swapCode()``. ``indentLevel`` selects the four-space indentation
    that the emit helpers prepend, and ``gensymCounter`` is a stack holding
    one counter per open block, which ``var()`` uses to build variable names.
    """

    # Stream-method suffix for each supported primitive encoding size (bytes).
    _STREAM_SUFFIX_BY_SIZE = {1: "Byte", 2: "Be16", 4: "Be32", 8: "Be64"}

    # C storage type for each supported primitive encoding size (bytes).
    _STORAGE_TYPE_BY_SIZE = {1: "uint8_t", 2: "uint16_t", 4: "uint32_t", 8: "uint64_t"}

    def __init__(self,):
        self.code = ""
        self.indentLevel = 0
        self.gensymCounter = [-1]

    def var(self, prefix="cgen_var"):
        """Return a new generated variable name, e.g. ``cgen_var_0``.

        Bumps the innermost block's counter; the suffix joins every
        non-negative counter on the stack with underscores.
        """
        self.gensymCounter[-1] += 1
        suffix = "_".join(str(i) for i in self.gensymCounter if i >= 0)
        return "%s_%s" % (prefix, suffix)

    def swapCode(self,):
        """Return the accumulated code and reset the buffer to empty."""
        res = self.code
        self.code = ""
        return res

    def indent(self, extra=0):
        """Return four spaces per level for the current level plus ``extra``."""
        return "    " * (self.indentLevel + extra)

    def incrIndent(self,):
        self.indentLevel += 1

    def decrIndent(self,):
        # Never drop below zero.
        if self.indentLevel > 0:
            self.indentLevel -= 1

    def beginBlock(self, bracketPrint=True):
        """Open a scope: optionally emit "{", indent, push a gensym counter."""
        if bracketPrint:
            self.code += self.indent() + "{\n"
        self.indentLevel += 1
        self.gensymCounter.append(-1)

    def endBlock(self, bracketPrint=True):
        """Close the scope opened by beginBlock(), optionally emitting "}"."""
        self.indentLevel -= 1
        if bracketPrint:
            self.code += self.indent() + "}\n"
        del self.gensymCounter[-1]

    def beginIf(self, cond):
        self.code += self.indent() + "if (" + cond + ")\n"
        self.beginBlock()

    def beginElse(self, cond = None):
        # With a condition this emits "else if (cond)", otherwise "else".
        if cond is not None:
            self.code += self.indent() + "else if (" + cond + ")\n"
        else:
            self.code += self.indent() + "else\n"
        self.beginBlock()

    def endElse(self):
        self.endBlock()

    def endIf(self):
        self.endBlock()

    def beginSwitch(self, switchvar):
        self.code += self.indent() + "switch (" + switchvar + ")\n"
        self.beginBlock()

    # NOTE: the case/default helpers emit the label without a trailing
    # newline, so whatever is emitted next follows the label on its line.
    def switchCase(self, switchval, blocked = False):
        self.code += self.indent() + "case %s:" % switchval
        self.beginBlock(bracketPrint = blocked)

    def switchCaseBreak(self, switchval, blocked = False):
        self.code += self.indent() + "case %s:" % switchval
        self.endBlock(bracketPrint = blocked)

    def switchCaseDefault(self, blocked = False):
        self.code += self.indent() + "default:"
        self.beginBlock(bracketPrint = blocked)

    def endSwitch(self):
        self.endBlock()

    def beginWhile(self, cond):
        self.code += self.indent() + "while (" + cond + ")\n"
        self.beginBlock()

    def endWhile(self):
        self.endBlock()

    def beginFor(self, initial, condition, increment):
        self.code += \
            self.indent() + "for (" + \
            "; ".join([initial, condition, increment]) + \
            ")\n"
        self.beginBlock()

    def endFor(self):
        self.endBlock()

    def beginLoop(self, loopVarType, loopVar, loopInit, loopBound):
        """Emit ``for (T v = init; v < bound; ++v)`` and open its block."""
        self.beginFor(
            "%s %s = %s" % (loopVarType, loopVar, loopInit),
            "%s < %s" % (loopVar, loopBound),
            "++%s" % (loopVar))

    def endLoop(self):
        self.endBlock()

    def stmt(self, code):
        """Emit ``code`` as an indented statement terminated by ";"."""
        self.code += self.indent() + code + ";\n"

    def line(self, code):
        """Emit ``code`` as an indented line (no ";")."""
        self.code += self.indent() + code + "\n"

    def leftline(self, code):
        """Emit ``code`` at column 0."""
        self.code += code + "\n"

    def makeCallExpr(self, funcName, parameters):
        return funcName + "(%s)" % (", ".join(parameters))

    def funcCall(self, lhs, funcName, parameters):
        """Emit ``[lhs = ]funcName(params);``."""
        res = self.indent()

        if lhs is not None:
            res += lhs + " = "

        res += self.makeCallExpr(funcName, parameters) + ";\n"
        self.code += res

    def funcCallRet(self, _lhs, funcName, parameters):
        """Emit ``return funcName(params);`` (``_lhs`` is ignored)."""
        res = self.indent()
        res += "return " + self.makeCallExpr(funcName, parameters) + ";\n"
        self.code += res

    def _makePtrSpec(self, vulkanType):
        # Pointer declarator: "" for values; "* const*" (const) or "**" for a
        # pointer-to-const-pointer plus one "*" per extra level; otherwise
        # one "*" per indirection level.
        levels = vulkanType.pointerIndirectionLevels
        if levels == 0:
            return ""
        if vulkanType.isPointerToConstPointer:
            ptrSpec = "* const*" if vulkanType.isConst else "**"
            if levels > 2:
                ptrSpec += "*" * (levels - 2)
            return ptrSpec
        return "*" * levels

    # Given a VulkanType object, generate a C type declaration
    # with optional parameter name:
    # [const] [typename][*][const*] [paramName]
    def makeCTypeDecl(self, vulkanType, useParamName=True):
        constness = "const " if vulkanType.isConst else ""
        ptrSpec = self._makePtrSpec(vulkanType)

        if useParamName and (vulkanType.paramName is not None):
            paramStr = " " + vulkanType.paramName
        else:
            paramStr = ""

        return "%s%s%s%s" % (constness, vulkanType.typeName, ptrSpec, paramStr)

    def makeRichCTypeDecl(self, vulkanType, useParamName=True):
        """Like makeCTypeDecl(), plus a ``[N]`` suffix for static arrays."""
        decl = self.makeCTypeDecl(vulkanType, useParamName=useParamName)
        if vulkanType.staticArrExpr:
            decl += "[%s]" % vulkanType.staticArrExpr
        return decl

    # Given a VulkanAPI object, generate the C function prototype:
    # <returntype> <funcname>(<parameters>)
    def makeFuncProto(self, vulkanApi, useParamName=True):
        protoBegin = "%s %s" % (self.makeCTypeDecl(
            vulkanApi.retType, useParamName=False), vulkanApi.name)

        # The first parameter is always indented four spaces; the rest use
        # the current indentation plus one level.
        paramSeparator = ",\n" + self.indent(1)
        paramDecls = [self.makeRichCTypeDecl(param, useParamName=useParamName)
                      for param in vulkanApi.parameters]

        return protoBegin + "(\n    %s)" % paramSeparator.join(paramDecls)

    def makeFuncAlias(self, nameDst, nameSrc):
        return "DEFINE_ALIAS_FUNCTION({}, {});\n\n".format(nameSrc, nameDst)

    def makeFuncDecl(self, vulkanApi):
        return self.makeFuncProto(vulkanApi) + ";\n\n"

    def makeFuncImpl(self, vulkanApi, codegenFunc):
        """Return prototype + body emitted by ``codegenFunc(self)``.

        Any code already accumulated in this CodeGen is discarded first.
        """
        self.swapCode()

        self.line(self.makeFuncProto(vulkanApi))
        self.beginBlock()
        codegenFunc(self)
        self.endBlock()

        return self.swapCode() + "\n"

    def emitFuncImpl(self, vulkanApi, codegenFunc):
        """Append prototype + body emitted by ``codegenFunc(self)``."""
        self.line(self.makeFuncProto(vulkanApi))
        self.beginBlock()
        codegenFunc(self)
        self.endBlock()

    def makeStructAccess(self,
                         vulkanType,
                         structVarName,
                         asPtr=True,
                         structAsPtr=True,
                         accessIndex=None):
        """Return an expression accessing member ``vulkanType`` of a struct.

        Prefixes "&" when a pointer is requested and the member is not
        already accessible as one; appends ``+ accessIndex`` if given.
        """
        deref = "->" if structAsPtr else "."

        indexExpr = (" + %s" % accessIndex) if accessIndex else ""

        addrOfExpr = "" if vulkanType.accessibleAsPointer() or (
            not asPtr) else "&"

        return "%s%s%s%s%s" % (addrOfExpr, structVarName, deref,
                               vulkanType.paramName, indexExpr)

    def makeRawLengthAccess(self, vulkanType):
        """Return ``(lengthExpr, guardExpr)`` from the raw length attribute."""
        lenExpr = vulkanType.getLengthExpression()

        if not lenExpr:
            return None, None

        if lenExpr == "null-terminated":
            return "strlen(%s)" % vulkanType.paramName, None

        return lenExpr, None

    def makeLengthAccessFromStruct(self,
                                   structInfo,
                                   vulkanType,
                                   structVarName,
                                   asPtr=True):
        """Return ``(lengthExpr, guardExpr)`` for a struct member array.

        ``guardExpr`` is the struct variable that must be non-null before
        the length expression may be evaluated.
        """
        deref = "->" if asPtr else "."

        # Special cases first, mostly where the registry length is a
        # latexmath expression rather than a plain member name:
        # (struct, field, length member, postprocess of the member access).
        specialCases = (
            ("VkShaderModuleCreateInfo", "pCode", "codeSize",
             lambda expr: "(%s / 4)" % expr),
            ("VkPipelineMultisampleStateCreateInfo", "pSampleMask",
             "rasterizationSamples",
             lambda expr: "(((%s) + 31) / 32)" % expr),
            ("VkAccelerationStructureVersionInfoKHR", "pVersionData", "",
             lambda _: "2*VK_UUID_SIZE"),
        )
        for structName, field, lenExprMember, postprocess in specialCases:
            if (structInfo.name, vulkanType.paramName) == (structName, field):
                expr = "%s%s%s" % (structVarName, deref, lenExprMember)
                return postprocess(expr), "%s" % structVarName

        lenExpr = vulkanType.getLengthExpression()

        if not lenExpr:
            return None, None

        lenAccessGuardExpr = "%s" % structVarName

        if lenExpr == "null-terminated":
            return "strlen(%s%s%s)" % (structVarName, deref,
                                       vulkanType.paramName), lenAccessGuardExpr

        if not structInfo.getMember(lenExpr):
            return self.makeRawLengthAccess(vulkanType)

        return "%s%s%s" % (structVarName, deref, lenExpr), lenAccessGuardExpr

    def makeLengthAccessFromApi(self, api, vulkanType):
        """Return ``(lengthExpr, guardExpr)`` for an API parameter array.

        ``guardExpr`` is the pointer that must be non-null before the length
        expression may be evaluated, or None when no guard is needed.
        """
        lenExpr = vulkanType.getLengthExpression()

        if not lenExpr:
            return None, None

        # "structVar::member" lengths name a member of a struct parameter.
        if "::" in lenExpr:
            structVarName, memberVarName = lenExpr.split("::")
            return "%s->%s" % (structVarName, memberVarName), "%s" % structVarName

        if lenExpr == "null-terminated":
            # paramName is a string attribute, not a method.
            return "strlen(%s)" % vulkanType.paramName, None

        lenExprInfo = api.getParameter(lenExpr)

        if not lenExprInfo:
            return self.makeRawLengthAccess(vulkanType)

        # A length passed by pointer must be dereferenced (and null-checked).
        deref = "*" if lenExprInfo.pointerIndirectionLevels > 0 else ""
        lenAccessGuardExpr = "%s" % lenExpr if deref else None
        return "(%s(%s))" % (deref, lenExpr), lenAccessGuardExpr

    def accessParameter(self, param, asPtr=True):
        if asPtr:
            if param.pointerIndirectionLevels > 0:
                return param.paramName
            else:
                return "&%s" % param.paramName
        else:
            return param.paramName

    def sizeofExpr(self, vulkanType):
        return "sizeof(%s)" % (
            self.makeCTypeDecl(vulkanType, useParamName=False))

    def generalAccess(self,
                      vulkanType,
                      parentVarName=None,
                      asPtr=True,
                      structAsPtr=True):
        """Return an access expression for ``vulkanType`` given its parent.

        Raises:
            TypeError: if the parent is neither None, a VulkanCompoundType,
                nor a VulkanAPI.
        """
        def accessAsParameter():
            target = vulkanType if parentVarName is None \
                else vulkanType.withModifiedName(parentVarName)
            return self.accessParameter(target, asPtr=asPtr)

        if vulkanType.parent is None:
            return accessAsParameter()

        if isinstance(vulkanType.parent, VulkanCompoundType):
            return self.makeStructAccess(
                vulkanType, parentVarName, asPtr=asPtr, structAsPtr=structAsPtr)

        if isinstance(vulkanType.parent, VulkanAPI):
            return accessAsParameter()

        # os.abort() takes no arguments, so the message is raised instead.
        raise TypeError("Could not find a way to access Vulkan type %s %s" %
                        (vulkanType.typeName, vulkanType.paramName))

    def makeLengthAccess(self, vulkanType, parentVarName="parent"):
        """Return ``(lengthExpr, guardExpr)`` for ``vulkanType``.

        Raises:
            TypeError: if the parent is neither None, a VulkanCompoundType,
                nor a VulkanAPI.
        """
        if vulkanType.parent is None:
            return self.makeRawLengthAccess(vulkanType)

        if isinstance(vulkanType.parent, VulkanCompoundType):
            return self.makeLengthAccessFromStruct(
                vulkanType.parent, vulkanType, parentVarName, asPtr=True)

        if isinstance(vulkanType.parent, VulkanAPI):
            return self.makeLengthAccessFromApi(vulkanType.parent, vulkanType)

        raise TypeError("Could not find a way to access length of Vulkan type %s %s" %
                        (vulkanType.typeName, vulkanType.paramName))

    def generalLengthAccess(self, vulkanType, parentVarName="parent"):
        return self.makeLengthAccess(vulkanType, parentVarName)[0]

    def generalLengthAccessGuard(self, vulkanType, parentVarName="parent"):
        return self.makeLengthAccess(vulkanType, parentVarName)[1]

    def vkApiCall(self, api, customPrefix="", customParameters=None, retVarDecl=True):
        """Emit a call to ``customPrefix + api.name``.

        For non-void APIs the result is assigned to the API's return
        variable, which is declared (zero-initialized) when ``retVarDecl``.
        Returns ``(retTypeName, retVar)``; ``retVar`` is None for void.
        """
        callLhs = None

        retTypeName = api.getRetTypeExpr()
        retVar = None

        if retTypeName != "void":
            retVar = api.getRetVarExpr()
            if retVarDecl:
                self.stmt("%s %s = (%s)0" % (retTypeName, retVar, retTypeName))
            callLhs = retVar

        if customParameters is None:
            parameters = [p.paramName for p in api.parameters]
        else:
            parameters = customParameters
        self.funcCall(callLhs, customPrefix + api.name, parameters)

        return (retTypeName, retVar)

    def makeCheckVkSuccess(self, expr):
        return "((%s) == VK_SUCCESS)" % expr

    def makeReinterpretCast(self, varName, typeName, const=True):
        return "reinterpret_cast<%s%s*>(%s)" % \
               ("const " if const else "", typeName, varName)

    def validPrimitive(self, typeInfo, typeName):
        size = typeInfo.getPrimitiveEncodingSize(typeName)
        return size is not None

    def _primitiveStreamMethod(self, typeInfo, typeName, prefix):
        # prefix + size suffix (e.g. "putBe32"), or None when the type is not
        # a primitive or its size has no stream method.
        if not self.validPrimitive(typeInfo, typeName):
            return None

        size = typeInfo.getPrimitiveEncodingSize(typeName)
        suffix = self._STREAM_SUFFIX_BY_SIZE.get(size)

        if suffix:
            return prefix + suffix

        return None

    def makePrimitiveStreamMethod(self, typeInfo, typeName, direction="write"):
        """Return the stream method name ("put*"/"get*") for a primitive."""
        prefix = "put" if direction == "write" else "get"
        return self._primitiveStreamMethod(typeInfo, typeName, prefix)

    def makePrimitiveStreamMethodInPlace(self, typeInfo, typeName, direction="write"):
        """Return the in-place byte-order method ("to*"/"from*")."""
        prefix = "to" if direction == "write" else "from"
        return self._primitiveStreamMethod(typeInfo, typeName, prefix)

    def streamPrimitive(self, typeInfo, streamVar, accessExpr, accessType, direction="write"):
        """Emit code streaming one primitive (or pointer, as 64-bit) value.

        Aborts the process if ``accessType`` is a non-pointer, non-primitive.
        """
        accessTypeName = accessType.typeName

        if accessType.pointerIndirectionLevels == 0 and not self.validPrimitive(typeInfo, accessTypeName):
            print("Tried to stream a non-primitive type: %s" % accessTypeName)
            os.abort()

        needPtrCast = accessType.pointerIndirectionLevels > 0

        if needPtrCast:
            # Pointers always travel as 64-bit integers.
            streamSize = 8
            streamMethod = "putBe64" if direction == "write" else "getBe64"
        else:
            streamSize = typeInfo.getPrimitiveEncodingSize(accessTypeName)
            streamMethod = self.makePrimitiveStreamMethod(
                typeInfo, accessTypeName, direction=direction)

        # Allocated even when reading so generated names stay stable.
        streamStorageVar = self.var()

        ptrCast = "(uintptr_t)" if needPtrCast else ""

        if direction == "read":
            accessCast = self.makeRichCTypeDecl(accessType, useParamName=False)
            self.stmt("%s = (%s)%s%s->%s()" %
                      (accessExpr,
                       accessCast,
                       ptrCast,
                       streamVar,
                       streamMethod))
        else:
            streamStorageVarType = self._STORAGE_TYPE_BY_SIZE[streamSize]
            self.stmt("%s %s = (%s)%s%s" %
                      (streamStorageVarType, streamStorageVar,
                       streamStorageVarType, ptrCast, accessExpr))
            self.stmt("%s->%s(%s)" %
                      (streamVar, streamMethod, streamStorageVar))

    def memcpyPrimitive(self, typeInfo, streamVar, accessExpr, accessType, direction="write"):
        """Emit memcpy-based code moving one primitive to/from ``streamVar``.

        Aborts the process if ``accessType`` is a non-pointer, non-primitive.
        """
        accessTypeName = accessType.typeName

        if accessType.pointerIndirectionLevels == 0 and not self.validPrimitive(typeInfo, accessTypeName):
            print("Tried to stream a non-primitive type: %s" % accessTypeName)
            os.abort()

        needPtrCast = accessType.pointerIndirectionLevels > 0

        if needPtrCast:
            streamSize = 8
            streamMethod = "toBe64" if direction == "write" else "fromBe64"
        else:
            streamSize = typeInfo.getPrimitiveEncodingSize(accessTypeName)
            streamMethod = self.makePrimitiveStreamMethodInPlace(
                typeInfo, accessTypeName, direction=direction)

        # Allocated even when reading so generated names stay stable.
        streamStorageVar = self.var()

        if direction == "read":
            accessCast = self.makeRichCTypeDecl(
                accessType.getForNonConstAccess(), useParamName=False)
            self.stmt("memcpy((%s*)&%s, %s, %s)" %
                      (accessCast,
                       accessExpr,
                       streamVar,
                       str(streamSize)))
            self.stmt("android::base::Stream::%s((uint8_t*)&%s)" % (
                streamMethod,
                accessExpr))
        else:
            streamStorageVarType = self._STORAGE_TYPE_BY_SIZE[streamSize]
            ptrCast = "(uintptr_t)" if needPtrCast else ""
            self.stmt("%s %s = (%s)%s%s" %
                      (streamStorageVarType, streamStorageVar,
                       streamStorageVarType, ptrCast, accessExpr))
            self.stmt("memcpy(%s, &%s, %s)" %
                      (streamVar, streamStorageVar, str(streamSize)))
            self.stmt("android::base::Stream::%s((uint8_t*)%s)" % (
                streamMethod,
                streamVar))

    def countPrimitive(self, typeInfo, accessType):
        """Return the encoded size in bytes (8 for any pointer).

        Aborts the process if ``accessType`` is a non-pointer, non-primitive.
        """
        accessTypeName = accessType.typeName

        if accessType.pointerIndirectionLevels == 0 and not self.validPrimitive(typeInfo, accessTypeName):
            print("Tried to count a non-primitive type: %s" % accessTypeName)
            os.abort()

        if accessType.pointerIndirectionLevels > 0:
            return 8

        return typeInfo.getPrimitiveEncodingSize(accessTypeName)
774
# Class to wrap a Vulkan API call.
#
# The user gives a generic callback, |codegenDef|,
# that takes a CodeGen object and a VulkanAPI object as arguments.
# codegenDef uses CodeGen along with the VulkanAPI object
# to generate the function body.
class VulkanAPIWrapper(object):

    def __init__(self,
                 customApiPrefix,
                 extraParameters=None,
                 returnTypeOverride=None,
                 codegenDef=None):
        """
        Args:
            customApiPrefix: String prepended to the wrapped API's name.
            extraParameters: Optional list of parameters placed before the
                API's own parameters.
            returnTypeOverride: Optional replacement for the API's retType.
            codegenDef: Callback ``(CodeGen, VulkanAPI)`` that emits the
                function body; may also be set later via setCodegenDef().
        """
        self.customApiPrefix = customApiPrefix
        self.extraParameters = extraParameters
        self.returnTypeOverride = returnTypeOverride

        self.codegen = CodeGen()

        self.definitionFunc = codegenDef

        # Private helper. Stored as an instance attribute that takes the
        # wrapper explicitly, matching the ``self.makeApi(self, ...)`` calls.
        def makeApiFunc(self, typeInfo, apiName):
            """Return a customized shallow copy of ``typeInfo.apis[apiName]``.

            Raises:
                TypeError: if extraParameters is set but is not a list.
            """
            customApi = copy(typeInfo.apis[apiName])
            customApi.name = self.customApiPrefix + customApi.name
            if self.extraParameters is not None:
                if not isinstance(self.extraParameters, list):
                    # os.abort() takes no arguments, so raise the message.
                    raise TypeError(
                        "Type of extra parameters to custom API not valid. Expected list, got %s" % type(
                            self.extraParameters))
                # Concatenation builds a new list, so the parameter list
                # shared with the original API object is left untouched.
                customApi.parameters = \
                    self.extraParameters + customApi.parameters

            if self.returnTypeOverride is not None:
                customApi.retType = self.returnTypeOverride
            return customApi

        self.makeApi = makeApiFunc

    def setCodegenDef(self, codegenDefFunc):
        self.definitionFunc = codegenDefFunc

    def makeDecl(self, typeInfo, apiName):
        """Return the customized API's prototype followed by ";\\n\\n"."""
        return self.codegen.makeFuncProto(
            self.makeApi(self, typeInfo, apiName)) + ";\n\n"

    def makeDefinition(self, typeInfo, apiName, isStatic=False):
        """Return the customized API's full definition (prototype + body).

        Exits the process with status 1 if no codegen callback is set.
        Any code pending in ``self.codegen`` is discarded.
        """
        vulkanApi = self.makeApi(self, typeInfo, apiName)

        self.codegen.swapCode()
        self.codegen.beginBlock()

        if self.definitionFunc is None:
            print("ERROR: No definition found for (%s, %s)" %
                  (vulkanApi.name, self.customApiPrefix))
            sys.exit(1)

        self.definitionFunc(self.codegen, vulkanApi)

        self.codegen.endBlock()

        return ("static " if isStatic else "") + self.codegen.makeFuncProto(
            vulkanApi) + "\n" + self.codegen.swapCode() + "\n"
840
841# Base class for wrapping all Vulkan API objects.  These work with Vulkan
842# Registry generators and have gen* triggers.  They tend to contain
843# VulkanAPIWrapper objects to make it easier to generate the code.
844class VulkanWrapperGenerator(object):
845
846    def __init__(self, module, typeInfo):
847        self.module = module
848        self.typeInfo = typeInfo
849        self.extensionStructTypes = OrderedDict()
850
851    def onBegin(self):
852        pass
853
854    def onEnd(self):
855        pass
856
857    def onBeginFeature(self, featureName):
858        pass
859
860    def onEndFeature(self):
861        pass
862
863    def onGenType(self, typeInfo, name, alias):
864        category = self.typeInfo.categoryOf(name)
865        if category in ["struct", "union"] and not alias:
866            structInfo = self.typeInfo.structs[name]
867            if structInfo.structExtendsExpr:
868                self.extensionStructTypes[name] = structInfo
869        pass
870
871    def onGenStruct(self, typeInfo, name, alias):
872        pass
873
874    def onGenGroup(self, groupinfo, groupName, alias=None):
875        pass
876
877    def onGenEnum(self, enuminfo, name, alias):
878        pass
879
880    def onGenCmd(self, cmdinfo, name, alias):
881        pass
882
    # Some of the Vulkan structure types below correspond to more than one
    # Vulkan struct, due to a conflict between different Vulkan registries.
    # To get the correct Vulkan struct for such a type, the type of its "root"
    # struct has to be checked as well.
    #
    # Shape: {extension sType: {root sType or "default": struct name}}.
    # emitForEachStructExtension consults this only when it is given a
    # rootTypeVar: for an sType listed here it emits a nested switch on the
    # root type with one case per key ("default" becomes the `default:`
    # label). Every struct name listed must be present in
    # self.extensionStructTypes at generation time, otherwise that lookup
    # raises KeyError.
    ROOT_TYPE_MAPPING = {
        # Under a VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO root this sType is
        # treated as VkImportColorBufferGOOGLE.
        "VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FRAGMENT_DENSITY_MAP_FEATURES_EXT": {
            "VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2": "VkPhysicalDeviceFragmentDensityMapFeaturesEXT",
            "VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO": "VkPhysicalDeviceFragmentDensityMapFeaturesEXT",
            "VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO": "VkImportColorBufferGOOGLE",
            "default": "VkPhysicalDeviceFragmentDensityMapFeaturesEXT",
        },
        # Under a VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO root this sType is
        # treated as VkImportPhysicalAddressGOOGLE.
        "VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FRAGMENT_DENSITY_MAP_PROPERTIES_EXT": {
            "VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2": "VkPhysicalDeviceFragmentDensityMapPropertiesEXT",
            "VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO": "VkImportPhysicalAddressGOOGLE",
            "default": "VkPhysicalDeviceFragmentDensityMapPropertiesEXT",
        },
        # Under a VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO root this sType is
        # treated as VkImportBufferGOOGLE.
        "VK_STRUCTURE_TYPE_RENDER_PASS_FRAGMENT_DENSITY_MAP_CREATE_INFO_EXT": {
            "VK_STRUCTURE_TYPE_RENDER_PASS_CREATE_INFO": "VkRenderPassFragmentDensityMapCreateInfoEXT",
            "VK_STRUCTURE_TYPE_RENDER_PASS_CREATE_INFO_2": "VkRenderPassFragmentDensityMapCreateInfoEXT",
            "VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO": "VkImportBufferGOOGLE",
            "default": "VkRenderPassFragmentDensityMapCreateInfoEXT",
        },
    }
906
907    def emitForEachStructExtension(self, cgen, retType, triggerVar, forEachFunc, autoBreak=True, defaultEmit=None, nullEmit=None, rootTypeVar=None):
908        def readStructType(structTypeName, structVarName, cgen):
909            cgen.stmt("uint32_t %s = (uint32_t)%s(%s)" % \
910                (structTypeName, "goldfish_vk_struct_type", structVarName))
911
912        def castAsStruct(varName, typeName, const=True):
913            return "reinterpret_cast<%s%s*>(%s)" % \
914                   ("const " if const else "", typeName, varName)
915
916        def doDefaultReturn(cgen):
917            if retType.typeName == "void":
918                cgen.stmt("return")
919            else:
920                cgen.stmt("return (%s)0" % retType.typeName)
921
922        cgen.beginIf("!%s" % triggerVar.paramName)
923        if nullEmit is None:
924            doDefaultReturn(cgen)
925        else:
926            nullEmit(cgen)
927        cgen.endIf()
928
929        readStructType("structType", triggerVar.paramName, cgen)
930
931        cgen.line("switch(structType)")
932        cgen.beginBlock()
933
934        currFeature = None
935
936        for ext in self.extensionStructTypes.values():
937            if not currFeature:
938                cgen.leftline("#ifdef %s" % ext.feature)
939                currFeature = ext.feature
940
941            if currFeature and ext.feature != currFeature:
942                cgen.leftline("#endif")
943                cgen.leftline("#ifdef %s" % ext.feature)
944                currFeature = ext.feature
945
946            enum = ext.structEnumExpr
947            cgen.line("case %s:" % enum)
948            cgen.beginBlock()
949
950            if rootTypeVar is not None and enum in VulkanWrapperGenerator.ROOT_TYPE_MAPPING:
951                cgen.line("switch(%s)" % rootTypeVar.paramName)
952                cgen.beginBlock()
953                kv = VulkanWrapperGenerator.ROOT_TYPE_MAPPING[enum]
954                for k in kv:
955                    v = self.extensionStructTypes[kv[k]]
956                    if k == "default":
957                        cgen.line("%s:" % k)
958                    else:
959                        cgen.line("case %s:" % k)
960                    cgen.beginBlock()
961                    castedAccess = castAsStruct(
962                        triggerVar.paramName, v.name, const=triggerVar.isConst)
963                    forEachFunc(v, castedAccess, cgen)
964                    cgen.line("break;")
965                    cgen.endBlock()
966                cgen.endBlock()
967            else:
968                castedAccess = castAsStruct(
969                    triggerVar.paramName, ext.name, const=triggerVar.isConst)
970                forEachFunc(ext, castedAccess, cgen)
971
972            if autoBreak:
973                cgen.stmt("break")
974            cgen.endBlock()
975
976        if currFeature:
977            cgen.leftline("#endif")
978
979        cgen.line("default:")
980        cgen.beginBlock()
981        if defaultEmit is None:
982            doDefaultReturn(cgen)
983        else:
984            defaultEmit(cgen)
985        cgen.endBlock()
986
987        cgen.endBlock()
988
989    def emitForEachStructExtensionGeneral(self, cgen, forEachFunc, doFeatureIfdefs=False):
990        currFeature = None
991
992        for (i, ext) in enumerate(self.extensionStructTypes.values()):
993            if doFeatureIfdefs:
994                if not currFeature:
995                    cgen.leftline("#ifdef %s" % ext.feature)
996                    currFeature = ext.feature
997
998                if currFeature and ext.feature != currFeature:
999                    cgen.leftline("#endif")
1000                    cgen.leftline("#ifdef %s" % ext.feature)
1001                    currFeature = ext.feature
1002
1003            forEachFunc(i, ext, cgen)
1004
1005        if doFeatureIfdefs:
1006            if currFeature:
1007                cgen.leftline("#endif")
1008