• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright (c) 2018 The Android Open Source Project
2# Copyright (c) 2018 Google Inc.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16from .vulkantypes import VulkanType, VulkanTypeInfo, VulkanCompoundType, VulkanAPI
17from collections import OrderedDict
18from copy import copy
19
20import os
21import sys
22
23# Class capturing a .cpp file and a .h file (a "C++ module")
class Module(object):
    """A generated C++ module: one ``.cpp`` file plus, unless ``implOnly``, a ``.h`` file.

    Lifecycle: set the ``*Preamble``/``*Postamble`` strings, call
    ``begin(globalDir)`` to create/open the output files and write the
    preambles, stream content through ``appendHeader``/``appendImpl``, then
    call ``end()`` to write the postambles and close the files.  When
    ``suppress`` is true, ``begin``, the ``append*`` methods and ``end`` do
    nothing.
    """

    def __init__(self, directory, basename, customAbsDir = None, suppress = False, implOnly = False):
        # Output subdirectory (relative to begin()'s globalDir) and file stem.
        self.directory = directory
        self.basename = basename

        # Text written right after opening / right before closing each file.
        self.headerPreamble = ""
        self.implPreamble = ""

        self.headerPostamble = ""
        self.implPostamble = ""

        # Replaced by open file objects in begin(); the header handle is set
        # to None there when implOnly is true.
        self.headerFileHandle = ""
        self.implFileHandle = ""

        # When set, files are written here instead of <globalDir>/<directory>.
        self.customAbsDir = customAbsDir

        self.suppress = suppress

        # Only emit the .cpp file; header output is dropped.
        self.implOnly = implOnly

    def _makeSrcEntry(self, suffix):
        """Build a build-file source-list entry ending in ``suffix``.

        With ``customAbsDir`` the bare file name is used; otherwise the
        indented ``directory/basename`` path, with forward slashes so the
        entry is valid even when generated on Windows.
        """
        if self.customAbsDir:
            return self.basename + suffix
        joined = os.path.join(self.directory, self.basename).replace("\\", "/")
        return "    " + joined + suffix

    def getMakefileSrcEntry(self):
        """Return this module's ``.cpp`` entry for a Makefile source list."""
        return self._makeSrcEntry(".cpp \\\n")

    def getCMakeSrcEntry(self):
        """Return this module's ``.cpp`` entry for a CMake source list."""
        return self._makeSrcEntry(".cpp ")

    def begin(self, globalDir):
        """Create the output directory if needed, open the files and write the preambles.

        Raises:
            OSError: if the directory cannot be created or a file cannot be
                opened (an already-opened header is closed first).
        """
        if self.suppress:
            return

        if self.customAbsDir:
            absDir = self.customAbsDir
        else:
            absDir = os.path.join(globalDir, self.directory)

        # Create the output directory, if needed ("" means the current directory).
        if absDir:
            os.makedirs(absDir, exist_ok=True)

        filename = os.path.join(absDir, self.basename)

        fpHeader = None
        if not self.implOnly:
            fpHeader = open(filename + ".h", "w", encoding="utf-8")

        try:
            fpImpl = open(filename + ".cpp", "w", encoding="utf-8")
        except OSError:
            # Do not leak the header handle when the .cpp cannot be opened.
            if fpHeader is not None:
                fpHeader.close()
            raise

        self.headerFileHandle = fpHeader
        self.implFileHandle = fpImpl

        if not self.implOnly:
            self.headerFileHandle.write(self.headerPreamble)

        self.implFileHandle.write(self.implPreamble)

    def appendHeader(self, toAppend):
        """Write ``toAppend`` to the header (ignored when ``implOnly``)."""
        if self.suppress:
            return

        if not self.implOnly:
            self.headerFileHandle.write(toAppend)

    def appendImpl(self, toAppend):
        """Write ``toAppend`` to the ``.cpp`` file."""
        if self.suppress:
            return

        self.implFileHandle.write(toAppend)

    def end(self):
        """Write the postambles and close the files opened by ``begin``."""
        if self.suppress:
            return

        if not self.implOnly:
            self.headerFileHandle.write(self.headerPostamble)

        self.implFileHandle.write(self.implPostamble)

        if not self.implOnly:
            self.headerFileHandle.close()

        self.implFileHandle.close()
114
115# Class capturing a .proto protobuf definition file
class Proto(object):
    """A generated ``.proto`` protobuf definition file.

    Lifecycle: set ``preamble``/``postamble``, call ``begin(globalDir)`` to
    create/open the file and write the preamble, stream content with
    ``append``, then call ``end()`` to write the postamble and close it.
    When ``suppress`` is true all of these do nothing.
    """

    def __init__(self, directory, basename, customAbsDir = None, suppress = False):
        # Output subdirectory (relative to begin()'s globalDir) and file stem.
        self.directory = directory
        self.basename = basename
        # When set, the file is written here instead of <globalDir>/<directory>.
        self.customAbsDir = customAbsDir

        self.preamble = ""
        self.postamble = ""

        self.suppress = suppress

        # Open file object once begin() has run.
        self.protoFileHandle = None

    def _makeSrcEntry(self, suffix):
        """Build a build-file source-list entry ending in ``suffix``.

        With ``customAbsDir`` the bare file name is used; otherwise the
        indented ``directory/basename`` path with forward slashes.
        """
        if self.customAbsDir:
            return self.basename + suffix
        joined = os.path.join(self.directory, self.basename).replace("\\", "/")
        return "    " + joined + suffix

    def getMakefileSrcEntry(self):
        """Return this file's entry for a Makefile source list."""
        return self._makeSrcEntry(".proto \\\n")

    def getCMakeSrcEntry(self):
        """Return this file's entry for a CMake source list."""
        return self._makeSrcEntry(".proto ")

    def begin(self, globalDir):
        """Create the output directory if needed, open the file and write the preamble."""
        if self.suppress:
            return

        if self.customAbsDir:
            absDir = self.customAbsDir
        else:
            absDir = os.path.join(globalDir, self.directory)

        # Create the output directory, if needed ("" means the current directory).
        if absDir:
            os.makedirs(absDir, exist_ok=True)

        filename = os.path.join(absDir, self.basename)

        self.protoFileHandle = open(filename + ".proto", "w", encoding="utf-8")
        self.protoFileHandle.write(self.preamble)

    def append(self, toAppend):
        """Write ``toAppend`` to the ``.proto`` file."""
        if self.suppress:
            return

        self.protoFileHandle.write(toAppend)

    def end(self):
        """Write the postamble and close the file."""
        if self.suppress:
            return

        self.protoFileHandle.write(self.postamble)
        self.protoFileHandle.close()
173
class CodeGen(object):
    """Incrementally builds C/C++ source text.

    Text accumulates in ``self.code`` and is handed back (and cleared) by
    ``swapCode``.  ``indentLevel`` is the current nesting depth; emitted
    lines are prefixed with four spaces per level.  ``gensymCounter`` keeps
    one counter per open block so that ``var()`` hands out names unique
    within the current nesting.
    """

    # Stream-method suffix for each supported primitive encoding size (bytes).
    _STREAM_METHOD_SUFFIXES = {1: "Byte", 2: "Be16", 4: "Be32", 8: "Be64"}

    # Unsigned C type used to stage a primitive of each encoding size (bytes).
    _STORAGE_TYPES = {1: "uint8_t", 2: "uint16_t", 4: "uint32_t", 8: "uint64_t"}

    # Struct members whose registry length expression cannot be used verbatim
    # (mostly latexmath): (struct name, member, length member, postprocess).
    _STRUCT_LENGTH_SPECIAL_CASES = (
        ("VkShaderModuleCreateInfo", "pCode", "codeSize",
         lambda expr: "(%s / 4)" % expr),
        ("VkPipelineMultisampleStateCreateInfo", "pSampleMask",
         "rasterizationSamples",
         lambda expr: "(((%s) + 31) / 32)" % expr),
        ("VkAccelerationStructureVersionInfoKHR", "pVersionData", "",
         lambda _: "2*VK_UUID_SIZE"),
    )

    def __init__(self):
        self.code = ""
        self.indentLevel = 0
        # One counter per open block (innermost last); -1 means "no name yet".
        self.gensymCounter = [-1]

    def _fatal(self, message):
        """Print ``message`` and abort the process (os.abort takes no arguments)."""
        print(message)
        os.abort()

    # ---- Buffer, naming and indentation ------------------------------------

    def var(self, prefix="cgen_var"):
        """Return a fresh name such as ``cgen_var_0`` or ``cgen_var_1_2``.

        The innermost block's counter is bumped, then the non-negative
        counters of all open blocks are joined with ``_``.
        """
        self.gensymCounter[-1] += 1
        res = "%s_%s" % (prefix, '_'.join(str(i) for i in self.gensymCounter if i >= 0))
        return res

    def swapCode(self):
        """Return the accumulated code and reset the buffer to empty."""
        res = self.code
        self.code = ""
        return res

    def indent(self, extra=0):
        """Return the indentation for the current level plus ``extra`` levels."""
        return "    " * (self.indentLevel + extra)

    def incrIndent(self):
        self.indentLevel += 1

    def decrIndent(self):
        """Decrease the indentation level, never going below zero."""
        if self.indentLevel > 0:
            self.indentLevel -= 1

    # ---- Block structure ---------------------------------------------------

    def beginBlock(self, bracketPrint=True):
        """Open a scope: optionally emit ``{``, indent, and push a gensym counter."""
        if bracketPrint:
            self.code += self.indent() + "{\n"
        self.indentLevel += 1
        self.gensymCounter.append(-1)

    def endBlock(self, bracketPrint=True):
        """Close a scope: dedent, optionally emit ``}``, and pop its gensym counter."""
        self.indentLevel -= 1
        if bracketPrint:
            self.code += self.indent() + "}\n"
        del self.gensymCounter[-1]

    def beginIf(self, cond):
        self.code += self.indent() + "if (" + cond + ")\n"
        self.beginBlock()

    def beginElse(self, cond = None):
        """Emit ``else`` (or ``else if (cond)``) and open its block."""
        if cond is not None:
            self.code += self.indent() + "else if (" + cond + ")\n"
        else:
            self.code += self.indent() + "else\n"
        self.beginBlock()

    def endElse(self):
        self.endBlock()

    def endIf(self):
        self.endBlock()

    def beginSwitch(self, switchvar):
        self.code += self.indent() + "switch (" + switchvar + ")\n"
        self.beginBlock()

    def switchCase(self, switchval, blocked = False):
        """Emit ``case switchval:`` and open a scope (braced only if ``blocked``).

        No newline follows the label, so the next emitted text continues on
        the same line.
        """
        self.code += self.indent() + "case %s:" % switchval
        self.beginBlock(bracketPrint = blocked)

    def switchCaseBreak(self, switchval, blocked = False):
        """Emit ``case switchval:`` and then close a scope (braced only if ``blocked``)."""
        self.code += self.indent() + "case %s:" % switchval
        self.endBlock(bracketPrint = blocked)

    def switchCaseDefault(self, blocked = False):
        """Emit ``default:`` and open a scope (braced only if ``blocked``)."""
        self.code += self.indent() + "default:"
        self.beginBlock(bracketPrint = blocked)

    def endSwitch(self):
        self.endBlock()

    def beginWhile(self, cond):
        self.code += self.indent() + "while (" + cond + ")\n"
        self.beginBlock()

    def endWhile(self):
        self.endBlock()

    def beginFor(self, initial, condition, increment):
        self.code += self.indent() + "for (" + "; ".join([initial, condition, increment]) + ")\n"
        self.beginBlock()

    def endFor(self):
        self.endBlock()

    def beginLoop(self, loopVarType, loopVar, loopInit, loopBound):
        """Open ``for (T v = init; v < bound; ++v)``."""
        self.beginFor(
            "%s %s = %s" % (loopVarType, loopVar, loopInit),
            "%s < %s" % (loopVar, loopBound),
            "++%s" % (loopVar))

    def endLoop(self):
        self.endBlock()

    # ---- Lines and calls ---------------------------------------------------

    def stmt(self, code):
        """Emit ``code;`` at the current indentation."""
        self.code += self.indent() + code + ";\n"

    def line(self, code):
        """Emit ``code`` at the current indentation."""
        self.code += self.indent() + code + "\n"

    def leftline(self, code):
        """Emit ``code`` with no indentation."""
        self.code += code + "\n"

    def makeCallExpr(self, funcName, parameters):
        return funcName + "(%s)" % (", ".join(parameters))

    def funcCall(self, lhs, funcName, parameters):
        """Emit ``[lhs = ]funcName(parameters);``."""
        assignment = "" if lhs is None else lhs + " = "
        self.code += self.indent() + assignment + self.makeCallExpr(funcName, parameters) + ";\n"

    def funcCallRet(self, _lhs, funcName, parameters):
        """Emit ``return funcName(parameters);`` (``_lhs`` is ignored)."""
        self.code += self.indent() + "return " + self.makeCallExpr(funcName, parameters) + ";\n"

    # ---- Declarations and prototypes ---------------------------------------

    # Given a VulkanType object, generate a C type declaration
    # with optional parameter name:
    # [const] [typename][*][const*] [paramName]
    def makeCTypeDecl(self, vulkanType, useParamName=True):
        constness = "const " if vulkanType.isConst else ""
        levels = vulkanType.pointerIndirectionLevels

        if levels == 0:
            ptrSpec = ""
        elif vulkanType.isPointerToConstPointer:
            ptrSpec = "* const*" if vulkanType.isConst else "**"
            if levels > 2:
                ptrSpec += "*" * (levels - 2)
        else:
            ptrSpec = "*" * levels

        if useParamName and (vulkanType.paramName is not None):
            paramStr = " " + vulkanType.paramName
        else:
            paramStr = ""

        return "%s%s%s%s" % (constness, vulkanType.typeName, ptrSpec, paramStr)

    def makeRichCTypeDecl(self, vulkanType, useParamName=True):
        """Like ``makeCTypeDecl`` but appends ``[staticArrExpr]`` for fixed-size arrays."""
        decl = self.makeCTypeDecl(vulkanType, useParamName=useParamName)
        if vulkanType.staticArrExpr:
            decl += "[%s]" % vulkanType.staticArrExpr
        return decl

    # Given a VulkanAPI object, generate the C function protype:
    # <returntype> <funcname>(<parameters>)
    def makeFuncProto(self, vulkanApi, useParamName=True):
        protoBegin = "%s %s" % (
            self.makeCTypeDecl(vulkanApi.retType, useParamName=False),
            vulkanApi.name)
        # The first parameter is always indented by four spaces; the rest
        # follow the current indentation plus one level.
        separator = ",\n%s" % self.indent(1)
        params = separator.join(
            self.makeRichCTypeDecl(param, useParamName=useParamName)
            for param in vulkanApi.parameters)
        return protoBegin + "(\n    %s)" % params

    def makeFuncAlias(self, nameDst, nameSrc):
        return "DEFINE_ALIAS_FUNCTION({}, {})\n\n".format(nameSrc, nameDst)

    def makeFuncDecl(self, vulkanApi):
        return self.makeFuncProto(vulkanApi) + ";\n\n"

    def makeFuncImpl(self, vulkanApi, codegenFunc):
        """Return a full function definition; any pending code is discarded first."""
        self.swapCode()

        self.line(self.makeFuncProto(vulkanApi))
        self.beginBlock()
        codegenFunc(self)
        self.endBlock()

        return self.swapCode() + "\n"

    def emitFuncImpl(self, vulkanApi, codegenFunc):
        """Append a full function definition to the current code buffer."""
        self.line(self.makeFuncProto(vulkanApi))
        self.beginBlock()
        codegenFunc(self)
        self.endBlock()

    # ---- Member and length access ------------------------------------------

    def makeStructAccess(self,
                         vulkanType,
                         structVarName,
                         asPtr=True,
                         structAsPtr=True,
                         accessIndex=None):
        """Return ``[&]struct->member[ + index]`` (``.`` when not ``structAsPtr``).

        ``&`` is added only when ``asPtr`` is set and the member is not
        already accessible as a pointer.
        """
        deref = "->" if structAsPtr else "."
        indexExpr = (" + %s" % accessIndex) if accessIndex else ""
        addrOfExpr = "" if vulkanType.accessibleAsPointer() or (not asPtr) else "&"
        return "%s%s%s%s%s" % (addrOfExpr, structVarName, deref,
                               vulkanType.paramName, indexExpr)

    def makeRawLengthAccess(self, vulkanType):
        """Return ``(lengthExpr, None)`` using the type's raw length expression."""
        lenExpr = vulkanType.getLengthExpression()

        if not lenExpr:
            return None, None

        if lenExpr == "null-terminated":
            return "strlen(%s)" % vulkanType.paramName, None

        return lenExpr, None

    def makeLengthAccessFromStruct(self,
                                   structInfo,
                                   vulkanType,
                                   structVarName,
                                   asPtr=True):
        """Return ``(lengthExpr, guardExpr)`` for a member of struct ``structVarName``.

        ``guardExpr`` is the struct variable, which must be non-null before
        the length expression is evaluated; ``(None, None)`` means no length.
        """
        deref = "->" if asPtr else "."

        # Handle special cases first (mostly when latexmath is involved).
        for structName, field, lenMember, postprocess in self._STRUCT_LENGTH_SPECIAL_CASES:
            if (structInfo.name, vulkanType.paramName) == (structName, field):
                expr = "%s%s%s" % (structVarName, deref, lenMember)
                return postprocess(expr), "%s" % structVarName

        lenExpr = vulkanType.getLengthExpression()
        if not lenExpr:
            return None, None

        lenAccessGuardExpr = "%s" % structVarName
        if lenExpr == "null-terminated":
            return "strlen(%s%s%s)" % (structVarName, deref,
                                       vulkanType.paramName), lenAccessGuardExpr

        if not structInfo.getMember(lenExpr):
            return self.makeRawLengthAccess(vulkanType)

        return "%s%s%s" % (structVarName, deref, lenExpr), lenAccessGuardExpr

    def makeLengthAccessFromApi(self, api, vulkanType):
        """Return ``(lengthExpr, guardExpr)`` for a parameter of API ``api``.

        ``guardExpr`` names a pointer that must be non-null before the length
        expression is evaluated, or is None when no guard is needed.
        """
        lenExpr = vulkanType.getLengthExpression()
        if not lenExpr:
            return None, None

        # "struct::member" refers to a member of another (pointer) parameter.
        if "::" in lenExpr:
            structVarName, memberVarName = lenExpr.split("::")
            return "%s->%s" % (structVarName, memberVarName), "%s" % (structVarName)

        if lenExpr == "null-terminated":
            return "strlen(%s)" % vulkanType.paramName, None

        lenExprInfo = api.getParameter(lenExpr)
        if not lenExprInfo:
            return self.makeRawLengthAccess(vulkanType)

        # Length passed by pointer (e.g. pCount) must be dereferenced and guarded.
        deref = "*" if lenExprInfo.pointerIndirectionLevels > 0 else ""
        lenAccessGuardExpr = "%s" % lenExpr if deref else None
        return "(%s(%s))" % (deref, lenExpr), lenAccessGuardExpr

    def accessParameter(self, param, asPtr=True):
        """Return the parameter name, prefixed with ``&`` when a pointer is wanted but it is a value."""
        if asPtr and param.pointerIndirectionLevels <= 0:
            return "&%s" % param.paramName
        return param.paramName

    def sizeofExpr(self, vulkanType):
        return "sizeof(%s)" % (
            self.makeCTypeDecl(vulkanType, useParamName=False))

    def generalAccess(self,
                      vulkanType,
                      parentVarName=None,
                      asPtr=True,
                      structAsPtr=True):
        """Return an access expression for ``vulkanType`` based on its parent.

        Aborts the process when the parent kind is not recognised.
        """
        def accessAsParameter():
            target = vulkanType if parentVarName is None else vulkanType.withModifiedName(parentVarName)
            return self.accessParameter(target, asPtr=asPtr)

        if vulkanType.parent is None:
            return accessAsParameter()

        if isinstance(vulkanType.parent, VulkanCompoundType):
            return self.makeStructAccess(
                vulkanType, parentVarName, asPtr=asPtr, structAsPtr=structAsPtr)

        if isinstance(vulkanType.parent, VulkanAPI):
            return accessAsParameter()

        self._fatal("Could not find a way to access Vulkan type %s" %
                    vulkanType.typeName)

    def makeLengthAccess(self, vulkanType, parentVarName="parent"):
        """Return ``(lengthExpr, guardExpr)`` for ``vulkanType`` based on its parent.

        Aborts the process when the parent kind is not recognised.
        """
        if vulkanType.parent is None:
            return self.makeRawLengthAccess(vulkanType)

        if isinstance(vulkanType.parent, VulkanCompoundType):
            return self.makeLengthAccessFromStruct(
                vulkanType.parent, vulkanType, parentVarName, asPtr=True)

        if isinstance(vulkanType.parent, VulkanAPI):
            return self.makeLengthAccessFromApi(vulkanType.parent, vulkanType)

        self._fatal("Could not find a way to access length of Vulkan type %s" %
                    vulkanType.typeName)

    def generalLengthAccess(self, vulkanType, parentVarName="parent"):
        return self.makeLengthAccess(vulkanType, parentVarName)[0]

    def generalLengthAccessGuard(self, vulkanType, parentVarName="parent"):
        return self.makeLengthAccess(vulkanType, parentVarName)[1]

    # ---- API calls -----------------------------------------------------------

    def vkApiCall(self, api, customPrefix="", customParameters=None, retVarDecl=True, retVarAssign=True):
        """Emit a call to ``customPrefix + api.name``.

        For non-void APIs the return variable is optionally declared
        (zero-initialised) and/or assigned.  Returns ``(retTypeName, retVar)``,
        with ``retVar`` None for void APIs.
        """
        callLhs = None
        retVar = None
        retTypeName = api.getRetTypeExpr()

        if retTypeName != "void":
            retVar = api.getRetVarExpr()
            if retVarDecl:
                self.stmt("%s %s = (%s)0" % (retTypeName, retVar, retTypeName))
            if retVarAssign:
                callLhs = retVar

        if customParameters is None:
            parameters = [p.paramName for p in api.parameters]
        else:
            parameters = customParameters
        self.funcCall(callLhs, customPrefix + api.name, parameters)

        return (retTypeName, retVar)

    def makeCheckVkSuccess(self, expr):
        return "((%s) == VK_SUCCESS)" % expr

    def makeReinterpretCast(self, varName, typeName, const=True):
        return "reinterpret_cast<%s%s*>(%s)" % \
               ("const " if const else "", typeName, varName)

    # ---- Primitive streaming -----------------------------------------------

    def validPrimitive(self, typeInfo, typeName):
        """True when ``typeInfo`` knows an encoding size for ``typeName``."""
        return typeInfo.getPrimitiveEncodingSize(typeName) is not None

    def _primitiveMethodName(self, typeInfo, typeName, prefix):
        """Return ``prefix`` + size suffix (e.g. ``putBe32``), or None if unsupported."""
        if not self.validPrimitive(typeInfo, typeName):
            return None
        suffix = self._STREAM_METHOD_SUFFIXES.get(
            typeInfo.getPrimitiveEncodingSize(typeName))
        if suffix is None:
            return None
        return prefix + suffix

    def makePrimitiveStreamMethod(self, typeInfo, typeName, direction="write"):
        """Return the stream put*/get* method name for ``typeName``, or None."""
        prefix = "put" if direction == "write" else "get"
        return self._primitiveMethodName(typeInfo, typeName, prefix)

    def makePrimitiveStreamMethodInPlace(self, typeInfo, typeName, direction="write"):
        """Return the in-place to*/from* byte-swap method name for ``typeName``, or None."""
        prefix = "to" if direction == "write" else "from"
        return self._primitiveMethodName(typeInfo, typeName, prefix)

    def _storageTypeForSize(self, size, typeName):
        """Return the unsigned staging type for ``size`` bytes; abort if unsupported."""
        storageType = self._STORAGE_TYPES.get(size)
        if storageType is None:
            self._fatal("Unsupported primitive encoding size %s for type: %s" %
                        (size, typeName))
        return storageType

    def streamPrimitive(self, typeInfo, streamVar, accessExpr, accessType, direction="write"):
        """Emit code that writes/reads one primitive (or pointer) via ``streamVar``.

        Pointers are transferred as 64-bit integers.  Aborts on non-primitive
        value types.
        """
        accessTypeName = accessType.typeName

        if accessType.pointerIndirectionLevels == 0 and not self.validPrimitive(typeInfo, accessTypeName):
            self._fatal("Tried to stream a non-primitive type: %s" % accessTypeName)

        if accessType.pointerIndirectionLevels > 0:
            streamStorageVarType = "uint64_t"
            ptrCast = "(uintptr_t)"
            streamMethod = "putBe64" if direction == "write" else "getBe64"
        else:
            streamSize = typeInfo.getPrimitiveEncodingSize(accessTypeName)
            streamStorageVarType = self._storageTypeForSize(streamSize, accessTypeName)
            ptrCast = ""
            streamMethod = self.makePrimitiveStreamMethod(
                typeInfo, accessTypeName, direction=direction)

        # A staging name is reserved on both paths, so the gensym counter
        # advances even when reading.
        streamStorageVar = self.var()

        if direction == "read":
            accessCast = self.makeRichCTypeDecl(accessType, useParamName=False)
            self.stmt("%s = (%s)%s%s->%s()" %
                      (accessExpr, accessCast, ptrCast, streamVar, streamMethod))
        else:
            self.stmt("%s %s = (%s)%s%s" %
                      (streamStorageVarType, streamStorageVar,
                       streamStorageVarType, ptrCast, accessExpr))
            self.stmt("%s->%s(%s)" %
                      (streamVar, streamMethod, streamStorageVar))

    def memcpyPrimitive(self, typeInfo, streamVar, accessExpr, accessType, direction="write"):
        """Emit memcpy-based code that copies one primitive to/from buffer ``streamVar``.

        The copied bytes are then byte-swapped in place with
        ``android::base::Stream::to*/from*``.  Aborts on non-primitive value types.
        """
        accessTypeName = accessType.typeName

        if accessType.pointerIndirectionLevels == 0 and not self.validPrimitive(typeInfo, accessTypeName):
            self._fatal("Tried to stream a non-primitive type: %s" % accessTypeName)

        if accessType.pointerIndirectionLevels > 0:
            streamSize = 8
            streamStorageVarType = "uint64_t"
            ptrCast = "(uintptr_t)"
            streamMethod = "toBe64" if direction == "write" else "fromBe64"
        else:
            streamSize = typeInfo.getPrimitiveEncodingSize(accessTypeName)
            streamStorageVarType = self._storageTypeForSize(streamSize, accessTypeName)
            ptrCast = ""
            streamMethod = self.makePrimitiveStreamMethodInPlace(
                typeInfo, accessTypeName, direction=direction)

        # A staging name is reserved on both paths, so the gensym counter
        # advances even when reading.
        streamStorageVar = self.var()

        if direction == "read":
            # Reads write into the target, so drop any const qualifier.
            accessCast = self.makeRichCTypeDecl(
                accessType.getForNonConstAccess(), useParamName=False)
            self.stmt("memcpy((%s*)&%s, %s, %s)" %
                      (accessCast, accessExpr, streamVar, str(streamSize)))
            self.stmt("android::base::Stream::%s((uint8_t*)&%s)" % (
                streamMethod, accessExpr))
        else:
            self.stmt("%s %s = (%s)%s%s" %
                      (streamStorageVarType, streamStorageVar,
                       streamStorageVarType, ptrCast, accessExpr))
            self.stmt("memcpy(%s, &%s, %s)" %
                      (streamVar, streamStorageVar, str(streamSize)))
            self.stmt("android::base::Stream::%s((uint8_t*)%s)" % (
                streamMethod, streamVar))

    def countPrimitive(self, typeInfo, accessType):
        """Return the encoded size in bytes of one value of ``accessType`` (8 for pointers)."""
        accessTypeName = accessType.typeName

        if accessType.pointerIndirectionLevels == 0 and not self.validPrimitive(typeInfo, accessTypeName):
            self._fatal("Tried to count a non-primitive type: %s" % accessTypeName)

        if accessType.pointerIndirectionLevels > 0:
            return 8
        return typeInfo.getPrimitiveEncodingSize(accessTypeName)
775
776# Class to wrap a Vulkan API call.
777#
778# The user gives a generic callback, |codegenDef|,
779# that takes a CodeGen object and a VulkanAPI object as arguments.
780# codegenDef uses CodeGen along with the VulkanAPI object
781# to generate the function body.
class VulkanAPIWrapper(object):
    """Generates declarations and definitions of custom variants of Vulkan APIs.

    A variant is the registry API renamed with ``customApiPrefix``, with
    ``extraParameters`` (a list) prepended to its parameters and, when
    given, its return type replaced by ``returnTypeOverride``.  The body is
    produced by ``codegenDef(codegen, vulkanApi)``.
    """

    def __init__(self,
                 customApiPrefix,
                 extraParameters=None,
                 returnTypeOverride=None,
                 codegenDef=None):
        self.customApiPrefix = customApiPrefix
        self.extraParameters = extraParameters
        self.returnTypeOverride = returnTypeOverride

        self.codegen = CodeGen()

        self.definitionFunc = codegenDef

    @staticmethod
    def makeApi(wrapper, typeInfo, apiName):
        """Return a shallow copy of ``typeInfo.apis[apiName]`` customised by ``wrapper``.

        Called as ``self.makeApi(self, typeInfo, apiName)``.  The registry's
        API object is not modified.  Aborts the process if
        ``wrapper.extraParameters`` is set but is not a list.
        """
        customApi = copy(typeInfo.apis[apiName])
        customApi.name = wrapper.customApiPrefix + customApi.name
        if wrapper.extraParameters is not None:
            if not isinstance(wrapper.extraParameters, list):
                # os.abort() takes no message, so print it first.
                print("Type of extra parameters to custom API not valid. Expected list, got %s" % type(
                    wrapper.extraParameters))
                os.abort()
            # Builds a new list, so the original API's parameters are untouched.
            customApi.parameters = \
                wrapper.extraParameters + customApi.parameters

        if wrapper.returnTypeOverride is not None:
            customApi.retType = wrapper.returnTypeOverride
        return customApi

    def setCodegenDef(self, codegenDefFunc):
        self.definitionFunc = codegenDefFunc

    def makeDecl(self, typeInfo, apiName):
        """Return the ``;``-terminated prototype of the customised API."""
        return self.codegen.makeFuncProto(
            self.makeApi(self, typeInfo, apiName)) + ";\n\n"

    def makeDefinition(self, typeInfo, apiName, isStatic=False):
        """Return the full definition of the customised API.

        Exits with status 1 when no codegen definition function is set.
        """
        vulkanApi = self.makeApi(self, typeInfo, apiName)

        self.codegen.swapCode()
        self.codegen.beginBlock()

        if self.definitionFunc is None:
            print("ERROR: No definition found for (%s, %s)" %
                  (vulkanApi.name, self.customApiPrefix))
            sys.exit(1)

        self.definitionFunc(self.codegen, vulkanApi)

        self.codegen.endBlock()

        return ("static " if isStatic else "") + self.codegen.makeFuncProto(
            vulkanApi) + "\n" + self.codegen.swapCode() + "\n"
841
842# Base class for wrapping all Vulkan API objects.  These work with Vulkan
843# Registry generators and have gen* triggers.  They tend to contain
844# VulkanAPIWrapper objects to make it easier to generate the code.
845class VulkanWrapperGenerator(object):
846
847    def __init__(self, module, typeInfo):
848        self.module = module
849        self.typeInfo = typeInfo
850        self.extensionStructTypes = OrderedDict()
851
852    def onBegin(self):
853        pass
854
855    def onEnd(self):
856        pass
857
858    def onBeginFeature(self, featureName):
859        pass
860
861    def onEndFeature(self):
862        pass
863
864    def onGenType(self, typeInfo, name, alias):
865        category = self.typeInfo.categoryOf(name)
866        if category in ["struct", "union"] and not alias:
867            structInfo = self.typeInfo.structs[name]
868            if structInfo.structExtendsExpr:
869                self.extensionStructTypes[name] = structInfo
870        pass
871
872    def onGenStruct(self, typeInfo, name, alias):
873        pass
874
875    def onGenGroup(self, groupinfo, groupName, alias=None):
876        pass
877
878    def onGenEnum(self, enuminfo, name, alias):
879        pass
880
881    def onGenCmd(self, cmdinfo, name, alias):
882        pass
883
    # Below Vulkan structure types may correspond to multiple Vulkan structs
    # due to a conflict between different Vulkan registries. In order to get
    # the correct Vulkan struct type, we need to check the type of its "root"
    # struct as well.
    #
    # Layout: {ambiguous sType: {root struct sType: struct name, ...}}.
    # The "default" entry is presumably used when the root sType is not
    # listed (the consumer is not visible here — confirm in
    # emitForEachStructExtension). The *GOOGLE structs appear only under
    # VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO roots.
    ROOT_TYPE_MAPPING = {
        "VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FRAGMENT_DENSITY_MAP_FEATURES_EXT": {
            "VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2": "VkPhysicalDeviceFragmentDensityMapFeaturesEXT",
            "VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO": "VkPhysicalDeviceFragmentDensityMapFeaturesEXT",
            "VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO": "VkImportColorBufferGOOGLE",
            "default": "VkPhysicalDeviceFragmentDensityMapFeaturesEXT",
        },
        "VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FRAGMENT_DENSITY_MAP_PROPERTIES_EXT": {
            "VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2": "VkPhysicalDeviceFragmentDensityMapPropertiesEXT",
            "VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO": "VkImportPhysicalAddressGOOGLE",
            "default": "VkPhysicalDeviceFragmentDensityMapPropertiesEXT",
        },
        "VK_STRUCTURE_TYPE_RENDER_PASS_FRAGMENT_DENSITY_MAP_CREATE_INFO_EXT": {
            "VK_STRUCTURE_TYPE_RENDER_PASS_CREATE_INFO": "VkRenderPassFragmentDensityMapCreateInfoEXT",
            "VK_STRUCTURE_TYPE_RENDER_PASS_CREATE_INFO_2": "VkRenderPassFragmentDensityMapCreateInfoEXT",
            "VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO": "VkImportBufferGOOGLE",
            "default": "VkRenderPassFragmentDensityMapCreateInfoEXT",
        },
    }
907
908    def emitForEachStructExtension(self, cgen, retType, triggerVar, forEachFunc, autoBreak=True, defaultEmit=None, nullEmit=None, rootTypeVar=None):
909        def readStructType(structTypeName, structVarName, cgen):
910            cgen.stmt("uint32_t %s = (uint32_t)%s(%s)" % \
911                (structTypeName, "goldfish_vk_struct_type", structVarName))
912
913        def castAsStruct(varName, typeName, const=True):
914            return "reinterpret_cast<%s%s*>(%s)" % \
915                   ("const " if const else "", typeName, varName)
916
917        def doDefaultReturn(cgen):
918            if retType.typeName == "void":
919                cgen.stmt("return")
920            else:
921                cgen.stmt("return (%s)0" % retType.typeName)
922
923        cgen.beginIf("!%s" % triggerVar.paramName)
924        if nullEmit is None:
925            doDefaultReturn(cgen)
926        else:
927            nullEmit(cgen)
928        cgen.endIf()
929
930        readStructType("structType", triggerVar.paramName, cgen)
931
932        cgen.line("switch(structType)")
933        cgen.beginBlock()
934
935        currFeature = None
936
937        for ext in self.extensionStructTypes.values():
938            if not currFeature:
939                cgen.leftline("#ifdef %s" % ext.feature)
940                currFeature = ext.feature
941
942            if currFeature and ext.feature != currFeature:
943                cgen.leftline("#endif")
944                cgen.leftline("#ifdef %s" % ext.feature)
945                currFeature = ext.feature
946
947            enum = ext.structEnumExpr
948            cgen.line("case %s:" % enum)
949            cgen.beginBlock()
950
951            if rootTypeVar is not None and enum in VulkanWrapperGenerator.ROOT_TYPE_MAPPING:
952                cgen.line("switch(%s)" % rootTypeVar.paramName)
953                cgen.beginBlock()
954                kv = VulkanWrapperGenerator.ROOT_TYPE_MAPPING[enum]
955                for k in kv:
956                    v = self.extensionStructTypes[kv[k]]
957                    if k == "default":
958                        cgen.line("%s:" % k)
959                    else:
960                        cgen.line("case %s:" % k)
961                    cgen.beginBlock()
962                    castedAccess = castAsStruct(
963                        triggerVar.paramName, v.name, const=triggerVar.isConst)
964                    forEachFunc(v, castedAccess, cgen)
965                    cgen.line("break;")
966                    cgen.endBlock()
967                cgen.endBlock()
968            else:
969                castedAccess = castAsStruct(
970                    triggerVar.paramName, ext.name, const=triggerVar.isConst)
971                forEachFunc(ext, castedAccess, cgen)
972
973            if autoBreak:
974                cgen.stmt("break")
975            cgen.endBlock()
976
977        if currFeature:
978            cgen.leftline("#endif")
979
980        cgen.line("default:")
981        cgen.beginBlock()
982        if defaultEmit is None:
983            doDefaultReturn(cgen)
984        else:
985            defaultEmit(cgen)
986        cgen.endBlock()
987
988        cgen.endBlock()
989
990    def emitForEachStructExtensionGeneral(self, cgen, forEachFunc, doFeatureIfdefs=False):
991        currFeature = None
992
993        for (i, ext) in enumerate(self.extensionStructTypes.values()):
994            if doFeatureIfdefs:
995                if not currFeature:
996                    cgen.leftline("#ifdef %s" % ext.feature)
997                    currFeature = ext.feature
998
999                if currFeature and ext.feature != currFeature:
1000                    cgen.leftline("#endif")
1001                    cgen.leftline("#ifdef %s" % ext.feature)
1002                    currFeature = ext.feature
1003
1004            forEachFunc(i, ext, cgen)
1005
1006        if doFeatureIfdefs:
1007            if currFeature:
1008                cgen.leftline("#endif")
1009