# Copyright 2023 Google LLC
# SPDX-License-Identifier: MIT
from .vulkantypes import VulkanType, VulkanTypeInfo, VulkanCompoundType, VulkanAPI
from collections import OrderedDict
from copy import copy
from pathlib import Path, PurePosixPath

import os
import sys
import shutil
import subprocess
import re

# Stream method suffix for each supported primitive encoding size (bytes).
_STREAM_SUFFIX_BY_SIZE = {1: "Byte", 2: "Be16", 4: "Be32", 8: "Be64"}

# Unsigned C storage type for each supported primitive encoding size (bytes).
_STORAGE_TYPE_BY_SIZE = {1: "uint8_t", 2: "uint16_t", 4: "uint32_t", 8: "uint64_t"}


def _fatal(message):
    """Report `message` on stderr and abort the process.

    os.abort() accepts no arguments and does not flush Python's buffered
    streams, so the message is written and flushed explicitly beforehand.
    """
    print(message, file=sys.stderr, flush=True)
    os.abort()


def _removeEmptyIfdefs(filename: Path):
    """Remove empty #ifdef/#endif pairs from the file at `filename`.

    Such blocks are mainly introduced by extensions with no functions or
    variables. The file is rewritten in place.
    """
    # The generated files are written as UTF-8, so read them back the same way.
    with open(filename, 'r', encoding="utf-8") as file:
        content = file.read()

    pattern = r"#ifdef\s+(\w+)\s*(?://.*)?\s*\n\s*#endif\s*(?://.*)?\s*"
    modified_content = re.sub(pattern, "", content)

    with open(filename, 'w', encoding="utf-8") as file:
        file.write(modified_content)


def _formatFile(filename: Path):
    """Run clang-format in place on `filename`.

    Does nothing when GFXSTREAM_NO_CLANG_FMT is set in the environment.

    Raises:
        RuntimeError: clang-format is not on PATH or exits with non-zero status.
    """
    if "GFXSTREAM_NO_CLANG_FMT" in os.environ:
        return

    clang_format_command = shutil.which('clang-format')
    # Explicit checks instead of assert: assert statements (and the
    # subprocess call inside them) are removed when Python runs with -O.
    if clang_format_command is None:
        raise RuntimeError("clang-format not found in PATH")

    status = subprocess.call([clang_format_command, "-i",
                              "--style=file", str(filename.resolve())])
    if status != 0:
        raise RuntimeError(
            "clang-format failed with exit status %d on %s" % (status, filename))


# Class capturing a single file


class SingleFileModule(object):
    """One generated output file: preamble, appended content, postamble.

    When `suppress` is true, begin/append/end are no-ops and no file is opened.
    """

    def __init__(self, suffix, directory, basename, customAbsDir=None, suppress=False):
        self.directory = directory
        self.basename = basename
        self.customAbsDir = customAbsDir
        self.suffix = suffix
        self.file = None

        self.preamble = ""
        self.postamble = ""

        self.suppress = suppress

    def begin(self, globalDir):
        """Open the output file for writing and emit the preamble."""
        if self.suppress:
            return

        # Resolve the output directory; customAbsDir, when set, is used as-is
        # instead of <globalDir>/<directory>. The directory is not created here.
        if self.customAbsDir:
            absDir = self.customAbsDir
        else:
            absDir = os.path.join(globalDir, self.directory)

        filename = os.path.join(absDir, self.basename)

        self.file = open(filename + self.suffix, "w", encoding="utf-8")
        self.file.write(self.preamble)

    def append(self, toAppend):
        """Write `toAppend` to the open file."""
        if self.suppress:
            return

        self.file.write(toAppend)

    def end(self):
        """Emit the postamble and close the file."""
        if self.suppress:
            return

        self.file.write(self.postamble)
        self.file.close()

    def getMakefileSrcEntry(self):
        return ""

    def getCMakeSrcEntry(self):
        return ""

# Class capturing a .cpp file and a .h file (a "C++ module")


class Module(object):
    """A .h/.cpp pair backed by two SingleFileModule objects.

    `implOnly` suppresses the header file, `headerOnly` suppresses the
    implementation file, and `suppress` suppresses both.
    """

    def __init__(
            self, directory, basename, customAbsDir=None, suppress=False, implOnly=False,
            headerOnly=False, suppressFeatureGuards=False):
        self._headerFileModule = SingleFileModule(
            ".h", directory, basename, customAbsDir, suppress or implOnly)
        self._implFileModule = SingleFileModule(
            ".cpp", directory, basename, customAbsDir, suppress or headerOnly)

        self._headerOnly = headerOnly
        self._implOnly = implOnly

        self.directory = directory
        self.basename = basename
        self._customAbsDir = customAbsDir

        self.suppressFeatureGuards = suppressFeatureGuards

    @property
    def suppress(self):
        raise AttributeError("suppress is write only")

    @suppress.setter
    def suppress(self, value: bool):
        # implOnly / headerOnly keep their respective file suppressed
        # regardless of the value assigned here.
        self._headerFileModule.suppress = self._implOnly or value
        self._implFileModule.suppress = self._headerOnly or value

    @property
    def headerPreamble(self) -> str:
        return self._headerFileModule.preamble

    @headerPreamble.setter
    def headerPreamble(self, value: str):
        self._headerFileModule.preamble = value

    @property
    def headerPostamble(self) -> str:
        return self._headerFileModule.postamble

    @headerPostamble.setter
    def headerPostamble(self, value: str):
        self._headerFileModule.postamble = value

    @property
    def implPreamble(self) -> str:
        return self._implFileModule.preamble

    @implPreamble.setter
    def implPreamble(self, value: str):
        self._implFileModule.preamble = value

    @property
    def implPostamble(self) -> str:
        return self._implFileModule.postamble

    @implPostamble.setter
    def implPostamble(self, value: str):
        self._implFileModule.postamble = value

    def getMakefileSrcEntry(self):
        """Return the Makefile source-list line for the .cpp file."""
        if self._customAbsDir:
            return self.basename + ".cpp \\\n"
        dirName = self.directory
        baseName = self.basename
        joined = os.path.join(dirName, baseName)
        return " " + joined + ".cpp \\\n"

    def getCMakeSrcEntry(self):
        """Return the CMake source-list entry for the .cpp file (POSIX separators)."""
        if self._customAbsDir:
            return "\n" + self.basename + ".cpp "
        dirName = Path(self.directory)
        baseName = Path(self.basename)
        joined = PurePosixPath(dirName / baseName)
        return "\n " + str(joined) + ".cpp "

    def begin(self, globalDir):
        self._headerFileModule.begin(globalDir)
        self._implFileModule.begin(globalDir)

    def appendHeader(self, toAppend):
        self._headerFileModule.append(toAppend)

    def appendImpl(self, toAppend):
        self._implFileModule.append(toAppend)

    def end(self):
        """Close both files, then strip empty #ifdef blocks and clang-format them.

        Raises:
            RuntimeError: clang-format is unavailable or fails (unless
                GFXSTREAM_NO_CLANG_FMT is set in the environment).
        """
        self._headerFileModule.end()
        self._implFileModule.end()

        for fileModule in (self._headerFileModule, self._implFileModule):
            if fileModule.suppress:
                continue
            filename = Path(fileModule.file.name)
            _removeEmptyIfdefs(filename)
            _formatFile(filename)


class PyScript(SingleFileModule):
    def __init__(self, directory, basename, customAbsDir=None, suppress=False):
        super().__init__(".py", directory, basename, customAbsDir, suppress)


# Class capturing a .proto protobuf definition file
class Proto(SingleFileModule):

    def __init__(self, directory, basename, customAbsDir=None, suppress=False):
        super().__init__(".proto", directory, basename, customAbsDir, suppress)

    def getMakefileSrcEntry(self):
        """Return the Makefile source-list line for the .proto file."""
        if self.customAbsDir:
            return self.basename + ".proto \\\n"
        dirName = self.directory
        baseName = self.basename
        joined = os.path.join(dirName, baseName)
        return " " + joined + ".proto \\\n"

    def getCMakeSrcEntry(self):
        """Return the CMake source-list entry for the .proto file."""
        if self.customAbsDir:
            return "\n" + self.basename + ".proto "

        dirName = self.directory
        baseName = self.basename
        joined = os.path.join(dirName, baseName)
        return "\n " + joined + ".proto "


class CodeGen(object):
    """Accumulates generated C/C++ source text in `self.code`.

    Tracks the current indentation level and a per-block stack of counters
    used by var() to produce unique temporary variable names.
    """

    def __init__(self,):
        self.code = ""
        self.indentLevel = 0
        # One counter per open block; -1 means no name generated in it yet.
        self.gensymCounter = [-1]

    def var(self, prefix="cgen_var"):
        """Return a fresh variable name unique within the current block nesting."""
        self.gensymCounter[-1] += 1
        res = "%s_%s" % (prefix, '_'.join(str(i) for i in self.gensymCounter if i >= 0))
        return res

    def swapCode(self,):
        """Return the accumulated code and reset the buffer to empty."""
        res = "%s" % self.code
        self.code = ""
        return res

    def indent(self, extra=0):
        """Return the indentation string for the current level plus `extra`."""
        return " " * (self.indentLevel + extra)

    def incrIndent(self,):
        self.indentLevel += 1

    def decrIndent(self,):
        if self.indentLevel > 0:
            self.indentLevel -= 1

    def beginBlock(self, bracketPrint=True):
        """Open a block: optionally emit '{', indent, and push a gensym counter."""
        if bracketPrint:
            self.code += self.indent() + "{\n"
        self.indentLevel += 1
        self.gensymCounter.append(-1)

    def endBlock(self, bracketPrint=True):
        """Close a block: dedent, optionally emit '}', and pop the gensym counter."""
        self.indentLevel -= 1
        if bracketPrint:
            self.code += self.indent() + "}\n"
        del self.gensymCounter[-1]

    def beginIf(self, cond):
        self.code += self.indent() + "if (" + cond + ")\n"
        self.beginBlock()

    def beginElse(self, cond=None):
        """Emit 'else' (or 'else if (cond)' when `cond` is given) and open a block."""
        if cond is not None:
            self.code += \
                self.indent() + \
                "else if (" + cond + ")\n"
        else:
            self.code += self.indent() + "else\n"
        self.beginBlock()

    def endElse(self):
        self.endBlock()

    def endIf(self):
        self.endBlock()

    def beginSwitch(self, switchvar):
        self.code += self.indent() + "switch (" + switchvar + ")\n"
        self.beginBlock()

    def switchCase(self, switchval, blocked=False):
        """Emit a 'case' label and open its block (braces only if `blocked`)."""
        self.code += self.indent() + "case %s:" % switchval
        self.beginBlock(bracketPrint=blocked)

    def switchCaseBreak(self, switchval, blocked=False):
        """Emit a 'case' label and close the current block (braces only if `blocked`)."""
        self.code += self.indent() + "case %s:" % switchval
        self.endBlock(bracketPrint=blocked)

    def switchCaseDefault(self, blocked=False):
        """Emit the 'default:' label and open its block (braces only if `blocked`)."""
        # The label has no placeholder; formatting it with an undefined
        # `switchval` used to raise NameError.
        self.code += self.indent() + "default:"
        self.beginBlock(bracketPrint=blocked)

    def endSwitch(self):
        self.endBlock()

    def beginWhile(self, cond):
        self.code += self.indent() + "while (" + cond + ")\n"
        self.beginBlock()

    def endWhile(self):
        self.endBlock()

    def beginFor(self, initial, condition, increment):
        self.code += \
            self.indent() + "for (" + \
            "; ".join([initial, condition, increment]) + \
            ")\n"
        self.beginBlock()

    def endFor(self):
        self.endBlock()

    def beginLoop(self, loopVarType, loopVar, loopInit, loopBound):
        """Emit 'for (<type> <var> = <init>; <var> < <bound>; ++<var>)' and open a block."""
        self.beginFor(
            "%s %s = %s" % (loopVarType, loopVar, loopInit),
            "%s < %s" % (loopVar, loopBound),
            "++%s" % (loopVar))

    def endLoop(self):
        self.endBlock()

    def stmt(self, code):
        """Append an indented statement terminated by ';'."""
        self.code += self.indent() + code + ";\n"

    def line(self, code):
        """Append an indented line verbatim."""
        self.code += self.indent() + code + "\n"

    def leftline(self, code):
        """Append a line with no indentation (e.g. preprocessor directives)."""
        self.code += code + "\n"

    def makeCallExpr(self, funcName, parameters):
        return funcName + "(%s)" % (", ".join(parameters))

    def funcCall(self, lhs, funcName, parameters):
        """Emit a call statement, assigning the result to `lhs` when it is not None."""
        res = self.indent()

        if lhs is not None:
            res += lhs + " = "

        res += self.makeCallExpr(funcName, parameters) + ";\n"
        self.code += res

    def funcCallRet(self, _lhs, funcName, parameters):
        """Emit 'return <call>;'. `_lhs` is accepted for signature parity and ignored."""
        res = self.indent()
        res += "return " + self.makeCallExpr(funcName, parameters) + ";\n"
        self.code += res

    # Given a VulkanType object, generate a C type declaration
    # with optional parameter name:
    # [const] [typename][*][const*] [paramName]
    def makeCTypeDecl(self, vulkanType, useParamName=True):
        constness = "const " if vulkanType.isConst else ""
        typeName = vulkanType.typeName

        if vulkanType.pointerIndirectionLevels == 0:
            ptrSpec = ""
        elif vulkanType.isPointerToConstPointer:
            ptrSpec = "* const*" if vulkanType.isConst else "**"
            if vulkanType.pointerIndirectionLevels > 2:
                ptrSpec += "*" * (vulkanType.pointerIndirectionLevels - 2)
        else:
            ptrSpec = "*" * vulkanType.pointerIndirectionLevels

        if useParamName and (vulkanType.paramName is not None):
            paramStr = (" " + vulkanType.paramName)
        else:
            paramStr = ""

        return "%s%s%s%s" % (constness, typeName, ptrSpec, paramStr)

    def makeRichCTypeDecl(self, vulkanType, useParamName=True):
        """Like makeCTypeDecl, plus a trailing '[N]' when the type is a static array."""
        decl = self.makeCTypeDecl(vulkanType, useParamName=useParamName)

        if vulkanType.staticArrExpr:
            return decl + "[%s]" % vulkanType.staticArrExpr

        return decl

    # Given a VulkanAPI object, generate the C function protype:
    # <returntype> <funcname>(<parameters>)
    def makeFuncProto(self, vulkanApi, useParamName=True):

        protoBegin = "%s %s" % (self.makeCTypeDecl(
            vulkanApi.retType, useParamName=False), vulkanApi.name)

        # One declaration per line, each continuation line indented one level.
        separator = ",\n%s" % self.indent(1)
        paramDecls = [
            self.makeRichCTypeDecl(param, useParamName=useParamName)
            for param in vulkanApi.parameters]
        protoParams = "(\n %s)" % separator.join(paramDecls)

        return protoBegin + protoParams

    def makeFuncAlias(self, nameDst, nameSrc):
        return "DEFINE_ALIAS_FUNCTION({}, {})\n\n".format(nameSrc, nameDst)

    def makeFuncDecl(self, vulkanApi):
        return self.makeFuncProto(vulkanApi) + ";\n\n"

    def makeFuncImpl(self, vulkanApi, codegenFunc):
        """Return the full function definition text.

        Any code already buffered is discarded first, and the buffer is
        empty again on return.
        """
        self.swapCode()

        self.line(self.makeFuncProto(vulkanApi))
        self.beginBlock()
        codegenFunc(self)
        self.endBlock()

        return self.swapCode() + "\n"

    def emitFuncImpl(self, vulkanApi, codegenFunc):
        """Append the full function definition to the current buffer."""
        self.line(self.makeFuncProto(vulkanApi))
        self.beginBlock()
        codegenFunc(self)
        self.endBlock()

    def makeStructAccess(self,
                         vulkanType,
                         structVarName,
                         asPtr=True,
                         structAsPtr=True,
                         accessIndex=None):
        """Return an expression accessing member `vulkanType` of `structVarName`.

        With `asPtr`, '&' is prepended unless the member is already
        accessible as a pointer. A truthy `accessIndex` is added as an offset.
        """

        deref = "->" if structAsPtr else "."

        indexExpr = (" + %s" % accessIndex) if accessIndex else ""

        addrOfExpr = "" if vulkanType.accessibleAsPointer() or (
            not asPtr) else "&"

        return "%s%s%s%s%s" % (addrOfExpr, structVarName, deref,
                               vulkanType.paramName, indexExpr)

    def makeRawLengthAccess(self, vulkanType):
        """Return (length expression, guard expression) without a parent context.

        The guard is always None here; (None, None) means no length expression.
        """
        lenExpr = vulkanType.getLengthExpression()

        if not lenExpr:
            return None, None

        if lenExpr == "null-terminated":
            return "strlen(%s)" % vulkanType.paramName, None

        return lenExpr, None

    def makeLengthAccessFromStruct(self,
                                   structInfo,
                                   vulkanType,
                                   structVarName,
                                   asPtr=True):
        """Return (length expression, guard expression) for a struct member.

        The guard expression is the struct variable that the length
        expression reads through.
        """
        # Handle special cases first
        # Mostly when latexmath is involved
        def handleSpecialCases(structInfo, vulkanType, structVarName, asPtr):
            cases = [
                {
                    "structName": "VkShaderModuleCreateInfo",
                    "field": "pCode",
                    "lenExprMember": "codeSize",
                    "postprocess": lambda expr: "(%s / 4)" % expr
                },
                {
                    "structName": "VkPipelineMultisampleStateCreateInfo",
                    "field": "pSampleMask",
                    "lenExprMember": "rasterizationSamples",
                    "postprocess": lambda expr: "(((%s) + 31) / 32)" % expr
                },
                {
                    "structName": "VkAccelerationStructureVersionInfoKHR",
                    "field": "pVersionData",
                    "lenExprMember": "",
                    "postprocess": lambda _: "2*VK_UUID_SIZE"
                },
            ]

            for c in cases:
                if (structInfo.name, vulkanType.paramName) == (c["structName"],
                                                               c["field"]):
                    deref = "->" if asPtr else "."
                    expr = "%s%s%s" % (structVarName, deref,
                                       c["lenExprMember"])
                    lenAccessGuardExpr = "%s" % structVarName
                    return c["postprocess"](expr), lenAccessGuardExpr

            return None, None

        specialCaseAccess = \
            handleSpecialCases(
                structInfo, vulkanType, structVarName, asPtr)

        if specialCaseAccess != (None, None):
            return specialCaseAccess

        lenExpr = vulkanType.getLengthExpression()

        if not lenExpr:
            return None, None

        deref = "->" if asPtr else "."
        # deref is never empty, so the guard is always the struct variable.
        lenAccessGuardExpr = "%s" % (structVarName)
        if lenExpr == "null-terminated":
            return "strlen(%s%s%s)" % (structVarName, deref,
                                       vulkanType.paramName), lenAccessGuardExpr

        if not structInfo.getMember(lenExpr):
            return self.makeRawLengthAccess(vulkanType)

        return "%s%s%s" % (structVarName, deref, lenExpr), lenAccessGuardExpr

    def makeLengthAccessFromApi(self, api, vulkanType):
        """Return (length expression, guard expression) for an API parameter."""
        # Handle special cases first
        # Mostly when :: is involved
        def handleSpecialCases(vulkanType):
            lenExpr = vulkanType.getLengthExpression()

            if lenExpr is None:
                return None, None

            if "::" in lenExpr:
                structVarName, memberVarName = lenExpr.split("::")
                lenAccessGuardExpr = "%s" % (structVarName)
                return "%s->%s" % (structVarName, memberVarName), lenAccessGuardExpr
            return None, None

        specialCaseAccess = handleSpecialCases(vulkanType)

        if specialCaseAccess != (None, None):
            return specialCaseAccess

        lenExpr = vulkanType.getLengthExpression()

        if not lenExpr:
            return None, None

        lenExprInfo = api.getParameter(lenExpr)

        if not lenExprInfo:
            return self.makeRawLengthAccess(vulkanType)

        if lenExpr == "null-terminated":
            # paramName is a string attribute, not a method.
            return "strlen(%s)" % vulkanType.paramName, None
        else:
            # Dereference the length when it is itself passed by pointer, and
            # in that case guard on the pointer being non-null.
            deref = "*" if lenExprInfo.pointerIndirectionLevels > 0 else ""
            lenAccessGuardExpr = "%s" % lenExpr if deref else None
            return "(%s(%s))" % (deref, lenExpr), lenAccessGuardExpr

    def accessParameter(self, param, asPtr=True):
        """Return the parameter name, prefixed with '&' when `asPtr` and it is not a pointer."""
        if asPtr:
            if param.pointerIndirectionLevels > 0:
                return param.paramName
            else:
                return "&%s" % param.paramName
        else:
            return param.paramName

    def sizeofExpr(self, vulkanType):
        return "sizeof(%s)" % (
            self.makeCTypeDecl(vulkanType, useParamName=False))

    def generalAccess(self,
                      vulkanType,
                      parentVarName=None,
                      asPtr=True,
                      structAsPtr=True):
        """Return an access expression for `vulkanType` based on its parent kind.

        Aborts the process when the parent is neither None, a
        VulkanCompoundType nor a VulkanAPI.
        """
        if vulkanType.parent is None:
            if parentVarName is None:
                return self.accessParameter(vulkanType, asPtr=asPtr)
            else:
                return self.accessParameter(vulkanType.withModifiedName(parentVarName), asPtr=asPtr)

        if isinstance(vulkanType.parent, VulkanCompoundType):
            return self.makeStructAccess(
                vulkanType, parentVarName, asPtr=asPtr, structAsPtr=structAsPtr)

        if isinstance(vulkanType.parent, VulkanAPI):
            if parentVarName is None:
                return self.accessParameter(vulkanType, asPtr=asPtr)
            else:
                return self.accessParameter(vulkanType.withModifiedName(parentVarName), asPtr=asPtr)

        _fatal("Could not find a way to access Vulkan type %s" %
               vulkanType.typeName)

    def makeLengthAccess(self, vulkanType, parentVarName="parent"):
        """Return (length expression, guard expression) based on the parent kind.

        Aborts the process when the parent is neither None, a
        VulkanCompoundType nor a VulkanAPI.
        """
        if vulkanType.parent is None:
            return self.makeRawLengthAccess(vulkanType)

        if isinstance(vulkanType.parent, VulkanCompoundType):
            return self.makeLengthAccessFromStruct(
                vulkanType.parent, vulkanType, parentVarName, asPtr=True)

        if isinstance(vulkanType.parent, VulkanAPI):
            return self.makeLengthAccessFromApi(vulkanType.parent, vulkanType)

        _fatal("Could not find a way to access length of Vulkan type %s" %
               vulkanType.typeName)

    def generalLengthAccess(self, vulkanType, parentVarName="parent"):
        return self.makeLengthAccess(vulkanType, parentVarName)[0]

    def generalLengthAccessGuard(self, vulkanType, parentVarName="parent"):
        return self.makeLengthAccess(vulkanType, parentVarName)[1]

    def vkApiCall(self, api, customPrefix="", globalStatePrefix="",
                  customParameters=None, checkForDeviceLost=False,
                  checkForOutOfMemory=False, checkDispatcher=None):
        """Emit a call to `customPrefix + api.name`.

        For non-void APIs a result variable is declared first with a default
        value, so it is initialized even when the `checkDispatcher` guard
        skips the call.

        Returns:
            (return type name, result variable name or None for void).
        """
        callLhs = None

        retTypeName = api.getRetTypeExpr()
        retVar = None

        if retTypeName != "void":
            retVar = api.getRetVarExpr()
            defaultReturn = "(%s)0" % retTypeName
            if retTypeName == "VkResult":
                # TODO: return a valid error code based on the call
                # This is used to handle invalid dispatcher and snapshot states
                deviceLostFunctions = ["vkQueueSubmit",
                                       "vkQueueWaitIdle",
                                       "vkWaitForFences"]
                defaultReturn = "VK_ERROR_OUT_OF_HOST_MEMORY"
                # Compare by name: `api` itself is an API object, not a string.
                if api.name in deviceLostFunctions:
                    defaultReturn = "VK_ERROR_DEVICE_LOST"
            self.stmt("%s %s = %s" % (retTypeName, retVar, defaultReturn))
            callLhs = retVar

        if (checkDispatcher):
            self.beginIf(checkDispatcher)

        if customParameters is None:
            self.funcCall(
                callLhs, customPrefix + api.name, [p.paramName for p in api.parameters])
        else:
            self.funcCall(
                callLhs, customPrefix + api.name, customParameters)

        if (checkDispatcher):
            self.endIf()

        if retTypeName == "VkResult" and checkForDeviceLost:
            self.stmt("if ((%s) == VK_ERROR_DEVICE_LOST) %sDeviceLost()" % (callLhs, globalStatePrefix))

        if retTypeName == "VkResult" and checkForOutOfMemory:
            if api.name == "vkAllocateMemory":
                self.stmt(
                    "%sCheckOutOfMemory(%s, opcode, context, std::make_optional<uint64_t>(pAllocateInfo->allocationSize))"
                    % (globalStatePrefix, callLhs))
            else:
                self.stmt(
                    "%sCheckOutOfMemory(%s, opcode, context)"
                    % (globalStatePrefix, callLhs))

        return (retTypeName, retVar)

    def makeCheckVkSuccess(self, expr):
        return "((%s) == VK_SUCCESS)" % expr

    def makeReinterpretCast(self, varName, typeName, const=True):
        return "reinterpret_cast<%s%s*>(%s)" % \
               ("const " if const else "", typeName, varName)

    def validPrimitive(self, typeInfo, typeName):
        """Return True when `typeInfo` reports a primitive encoding size for `typeName`."""
        size = typeInfo.getPrimitiveEncodingSize(typeName)
        return size is not None

    def makePrimitiveStreamMethod(self, typeInfo, typeName, direction="write"):
        """Return the put*/get* stream method name, or None for unsupported types."""
        if not self.validPrimitive(typeInfo, typeName):
            return None

        size = typeInfo.getPrimitiveEncodingSize(typeName)
        prefix = "put" if direction == "write" else "get"
        suffix = _STREAM_SUFFIX_BY_SIZE.get(size)

        if suffix:
            return prefix + suffix

        return None

    def makePrimitiveStreamMethodInPlace(self, typeInfo, typeName, direction="write"):
        """Return the to*/from* in-place conversion method name, or None if unsupported."""
        if not self.validPrimitive(typeInfo, typeName):
            return None

        size = typeInfo.getPrimitiveEncodingSize(typeName)
        prefix = "to" if direction == "write" else "from"
        suffix = _STREAM_SUFFIX_BY_SIZE.get(size)

        if suffix:
            return prefix + suffix

        return None

    def streamPrimitive(self, typeInfo, streamVar, accessExpr, accessType, direction="write"):
        """Emit code streaming one primitive (or pointer, as 64-bit) value.

        Aborts the process for non-pointer types without a primitive encoding.
        """
        accessTypeName = accessType.typeName

        if accessType.pointerIndirectionLevels == 0 and not self.validPrimitive(typeInfo, accessTypeName):
            _fatal("Tried to stream a non-primitive type: %s" % accessTypeName)

        needPtrCast = False

        if accessType.pointerIndirectionLevels > 0:
            # Pointers are always streamed as 64-bit values.
            streamStorageVarType = "uint64_t"
            needPtrCast = True
            streamMethod = "putBe64" if direction == "write" else "getBe64"
        else:
            streamSize = typeInfo.getPrimitiveEncodingSize(accessTypeName)
            streamStorageVarType = _STORAGE_TYPE_BY_SIZE.get(streamSize)
            streamMethod = self.makePrimitiveStreamMethod(
                typeInfo, accessTypeName, direction=direction)

        # Reserved in both directions so generated names stay stable.
        streamStorageVar = self.var()

        accessCast = self.makeRichCTypeDecl(accessType, useParamName=False)

        ptrCast = "(uintptr_t)" if needPtrCast else ""

        if direction == "read":
            self.stmt("%s = (%s)%s%s->%s()" %
                      (accessExpr,
                       accessCast,
                       ptrCast,
                       streamVar,
                       streamMethod))
        else:
            if streamStorageVarType is None:
                _fatal("Unsupported primitive encoding size for type: %s" % accessTypeName)
            self.stmt("%s %s = (%s)%s%s" %
                      (streamStorageVarType, streamStorageVar,
                       streamStorageVarType, ptrCast, accessExpr))
            self.stmt("%s->%s(%s)" %
                      (streamVar, streamMethod, streamStorageVar))

    def memcpyPrimitive(self, typeInfo, streamVar, accessExpr, accessType, variant, direction="write"):
        """Emit memcpy-based streaming of one primitive (or pointer) value.

        `variant` "guest" selects the gfxstream::aemu stream namespace; any
        other value selects android::base. Aborts the process for
        non-pointer types without a primitive encoding.
        """
        accessTypeName = accessType.typeName

        if accessType.pointerIndirectionLevels == 0 and not self.validPrimitive(typeInfo, accessTypeName):
            _fatal("Tried to stream a non-primitive type: %s" % accessTypeName)

        needPtrCast = False

        if accessType.pointerIndirectionLevels > 0:
            # Pointers are always streamed as 64-bit values.
            streamSize = 8
            streamStorageVarType = "uint64_t"
            needPtrCast = True
            streamMethod = "toBe64" if direction == "write" else "fromBe64"
        else:
            streamSize = typeInfo.getPrimitiveEncodingSize(accessTypeName)
            streamStorageVarType = _STORAGE_TYPE_BY_SIZE.get(streamSize)
            streamMethod = self.makePrimitiveStreamMethodInPlace(
                typeInfo, accessTypeName, direction=direction)

        # Reserved in both directions so generated names stay stable.
        streamStorageVar = self.var()

        ptrCast = "(uintptr_t)" if needPtrCast else ""
        if variant == "guest":
            streamNamespace = "gfxstream::aemu"
        else:
            streamNamespace = "android::base"

        if direction == "read":
            # The destination must be writable, so cast away constness.
            accessCast = self.makeRichCTypeDecl(
                accessType.getForNonConstAccess(), useParamName=False)
            self.stmt("memcpy((%s*)&%s, %s, %s)" %
                      (accessCast,
                       accessExpr,
                       streamVar,
                       str(streamSize)))
            self.stmt("%s::Stream::%s((uint8_t*)&%s)" % (
                streamNamespace,
                streamMethod,
                accessExpr))
        else:
            if streamStorageVarType is None:
                _fatal("Unsupported primitive encoding size for type: %s" % accessTypeName)
            self.stmt("%s %s = (%s)%s%s" %
                      (streamStorageVarType, streamStorageVar,
                       streamStorageVarType, ptrCast, accessExpr))
            self.stmt("memcpy(%s, &%s, %s)" %
                      (streamVar, streamStorageVar, str(streamSize)))
            self.stmt("%s::Stream::%s((uint8_t*)%s)" % (
                streamNamespace,
                streamMethod,
                streamVar))

    def countPrimitive(self, typeInfo, accessType):
        """Return the encoded size in bytes of `accessType` (8 for any pointer).

        Aborts the process for non-pointer types without a primitive encoding.
        """
        accessTypeName = accessType.typeName

        if accessType.pointerIndirectionLevels == 0 and not self.validPrimitive(typeInfo, accessTypeName):
            _fatal("Tried to count a non-primitive type: %s" % accessTypeName)

        if accessType.pointerIndirectionLevels > 0:
            streamSize = 8
        else:
            streamSize = typeInfo.getPrimitiveEncodingSize(accessTypeName)

        return streamSize

# Class to wrap a Vulkan API call.
#
# The user gives a generic callback, |codegenDef|,
# that takes a CodeGen object and a VulkanAPI object as arguments.
# codegenDef uses CodeGen along with the VulkanAPI object
# to generate the function body.
class VulkanAPIWrapper(object):

    def __init__(self,
                 customApiPrefix,
                 extraParameters=None,
                 returnTypeOverride=None,
                 codegenDef=None):
        self.customApiPrefix = customApiPrefix
        self.extraParameters = extraParameters
        self.returnTypeOverride = returnTypeOverride

        self.codegen = CodeGen()

        self.definitionFunc = codegenDef

        # Private function

        def makeApiFunc(self, typeInfo, apiName):
            # Work on a shallow copy so the registry entry is left untouched;
            # the parameter list below is rebuilt rather than mutated.
            customApi = copy(typeInfo.apis[apiName])
            customApi.name = self.customApiPrefix + customApi.name
            if self.extraParameters is not None:
                if isinstance(self.extraParameters, list):
                    customApi.parameters = \
                        self.extraParameters + customApi.parameters
                else:
                    _fatal(
                        "Type of extra parameters to custom API not valid. Expected list, got %s" % type(
                            self.extraParameters))

            if self.returnTypeOverride is not None:
                customApi.retType = self.returnTypeOverride
            return customApi

        # Stored as a plain attribute, so callers pass the wrapper explicitly:
        # self.makeApi(self, typeInfo, apiName).
        self.makeApi = makeApiFunc

    def setCodegenDef(self, codegenDefFunc):
        self.definitionFunc = codegenDefFunc

    def makeDecl(self, typeInfo, apiName):
        """Return the prototype declaration text for the wrapped API."""
        return self.codegen.makeFuncProto(
            self.makeApi(self, typeInfo, apiName)) + ";\n\n"

    def makeDefinition(self, typeInfo, apiName, isStatic=False):
        """Return the full function definition text for the wrapped API.

        Exits with status 1 when no definition callback has been set.
        """
        vulkanApi = self.makeApi(self, typeInfo, apiName)

        self.codegen.swapCode()
        self.codegen.beginBlock()

        if self.definitionFunc is None:
            print("ERROR: No definition found for (%s, %s)" %
                  (vulkanApi.name, self.customApiPrefix))
            sys.exit(1)

        self.definitionFunc(self.codegen, vulkanApi)

        self.codegen.endBlock()

        return ("static " if isStatic else "") + self.codegen.makeFuncProto(
            vulkanApi) + "\n" + self.codegen.swapCode() + "\n"

# Base class for wrapping all Vulkan API objects. These work with Vulkan
# Registry generators and have gen* triggers. They tend to contain
# VulkanAPIWrapper objects to make it easier to generate the code.
class VulkanWrapperGenerator(object):
    """Base generator holding a module, the type info and extension structs.

    The on* hooks are no-ops unless stated otherwise. onGenType records every
    non-alias struct/union that has a struct-extends expression; the
    emitForEach* helpers iterate over those records.
    """

    def __init__(self, module: Module, typeInfo: VulkanTypeInfo):
        self.module: Module = module
        self.typeInfo: VulkanTypeInfo = typeInfo
        # name -> struct info, in registration order
        self.extensionStructTypes = OrderedDict()

    def onBegin(self):
        pass

    def onEnd(self):
        pass

    def onBeginFeature(self, featureName, featureType):
        pass

    def onFeatureNewCmd(self, cmdName):
        pass

    def onEndFeature(self):
        pass

    def onGenType(self, typeInfo, name, alias):
        """Record `name` when it is a non-alias struct/union that extends another struct."""
        if alias:
            return
        if self.typeInfo.categoryOf(name) not in ["struct", "union"]:
            return
        info = self.typeInfo.structs[name]
        if info.structExtendsExpr:
            self.extensionStructTypes[name] = info

    def onGenStruct(self, typeInfo, name, alias):
        pass

    def onGenGroup(self, groupinfo, groupName, alias=None):
        pass

    def onGenEnum(self, enuminfo, name, alias):
        pass

    def onGenCmd(self, cmdinfo, name, alias):
        pass

    # Below Vulkan structure types may correspond to multiple Vulkan structs
    # due to a conflict between different Vulkan registries. In order to get
    # the correct Vulkan struct type, we need to check the type of its "root"
    # struct as well.
    ROOT_TYPE_MAPPING = {
        "VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FRAGMENT_DENSITY_MAP_FEATURES_EXT": {
            "VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2": "VkPhysicalDeviceFragmentDensityMapFeaturesEXT",
            "VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO": "VkPhysicalDeviceFragmentDensityMapFeaturesEXT",
            "VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO": "VkImportColorBufferGOOGLE",
            "default": "VkPhysicalDeviceFragmentDensityMapFeaturesEXT",
        },
        "VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FRAGMENT_DENSITY_MAP_PROPERTIES_EXT": {
            "VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2": "VkPhysicalDeviceFragmentDensityMapPropertiesEXT",
            "VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO": "VkCreateBlobGOOGLE",
            "default": "VkPhysicalDeviceFragmentDensityMapPropertiesEXT",
        },
        "VK_STRUCTURE_TYPE_RENDER_PASS_FRAGMENT_DENSITY_MAP_CREATE_INFO_EXT": {
            "VK_STRUCTURE_TYPE_RENDER_PASS_CREATE_INFO": "VkRenderPassFragmentDensityMapCreateInfoEXT",
            "VK_STRUCTURE_TYPE_RENDER_PASS_CREATE_INFO_2": "VkRenderPassFragmentDensityMapCreateInfoEXT",
            "VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO": "VkImportBufferGOOGLE",
            "default": "VkRenderPassFragmentDensityMapCreateInfoEXT",
        },
    }

    @staticmethod
    def _advanceFeatureGuard(cgen, activeFeature, feature):
        """Emit the #ifdef/#endif lines needed to move to `feature`.

        Returns the feature whose guard is open afterwards.
        """
        if not activeFeature:
            cgen.leftline("#ifdef %s" % feature)
            return feature
        if feature != activeFeature:
            cgen.leftline("#endif")
            cgen.leftline("#ifdef %s" % feature)
            return feature
        return activeFeature

    @staticmethod
    def _structCast(varName, typeName, const=True):
        """Return a reinterpret_cast of `varName` to a (const) `typeName` pointer."""
        qualifier = "const " if const else ""
        return "reinterpret_cast<%s%s*>(%s)" % (qualifier, typeName, varName)

    def emitForEachStructExtension(self, cgen, retType, triggerVar, forEachFunc, autoBreak=True, defaultEmit=None, nullEmit=None, rootTypeVar=None):
        """Emit a switch over the struct type of `triggerVar`, one case per extension struct.

        forEachFunc(ext, castedAccess, cgen) emits each case body. `nullEmit`
        and `defaultEmit` override the default-return code emitted for a null
        trigger and for the switch default. When `rootTypeVar` is given,
        struct types listed in ROOT_TYPE_MAPPING get a nested switch on it.
        """
        generatorCls = VulkanWrapperGenerator
        triggerName = triggerVar.paramName

        def emitDefaultReturn(out):
            if retType.typeName == "void":
                out.stmt("return")
            else:
                out.stmt("return (%s)0" % retType.typeName)

        # Null trigger: nothing to dispatch on.
        cgen.beginIf("!%s" % triggerName)
        (emitDefaultReturn if nullEmit is None else nullEmit)(cgen)
        cgen.endIf()

        cgen.stmt("uint32_t %s = (uint32_t)%s(%s)" %
                  ("structType", "goldfish_vk_struct_type", triggerName))

        cgen.line("switch(structType)")
        cgen.beginBlock()

        openFeature = None

        for ext in self.extensionStructTypes.values():
            openFeature = generatorCls._advanceFeatureGuard(
                cgen, openFeature, ext.feature)

            enum = ext.structEnumExpr
            protect = None
            if enum in self.typeInfo.enumElem:
                protect = self.typeInfo.enumElem[enum].get("protect", default=None)
            if protect is not None:
                cgen.leftline("#ifdef %s" % protect)

            cgen.line("case %s:" % enum)
            cgen.beginBlock()

            useRootSwitch = (
                rootTypeVar is not None and enum in generatorCls.ROOT_TYPE_MAPPING)
            if not useRootSwitch:
                forEachFunc(
                    ext,
                    generatorCls._structCast(
                        triggerName, ext.name, const=triggerVar.isConst),
                    cgen)
            else:
                cgen.line("switch(%s)" % rootTypeVar.paramName)
                cgen.beginBlock()
                rootMapping = generatorCls.ROOT_TYPE_MAPPING[enum]
                for rootType, structName in rootMapping.items():
                    mapped = self.extensionStructTypes[structName]
                    label = "%s:" if rootType == "default" else "case %s:"
                    cgen.line(label % rootType)
                    cgen.beginBlock()
                    forEachFunc(
                        mapped,
                        generatorCls._structCast(
                            triggerName, mapped.name, const=triggerVar.isConst),
                        cgen)
                    cgen.line("break;")
                    cgen.endBlock()
                cgen.endBlock()

            if autoBreak:
                cgen.stmt("break")
            cgen.endBlock()

            if protect is not None:
                cgen.leftline("#endif // %s" % protect)

        if openFeature:
            cgen.leftline("#endif")

        cgen.line("default:")
        cgen.beginBlock()
        (emitDefaultReturn if defaultEmit is None else defaultEmit)(cgen)
        cgen.endBlock()

        cgen.endBlock()

    def emitForEachStructExtensionGeneral(self, cgen, forEachFunc, doFeatureIfdefs=False):
        """Call forEachFunc(index, ext, cgen) for every recorded extension struct.

        With `doFeatureIfdefs`, the calls are wrapped in per-feature
        #ifdef/#endif guards.
        """
        openFeature = None

        for index, ext in enumerate(self.extensionStructTypes.values()):
            if doFeatureIfdefs:
                openFeature = VulkanWrapperGenerator._advanceFeatureGuard(
                    cgen, openFeature, ext.feature)

            forEachFunc(index, ext, cgen)

        if doFeatureIfdefs and openFeature:
            cgen.leftline("#endif")