# Copyright (c) 2018 The Android Open Source Project
# Copyright (c) 2018 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .vulkantypes import VulkanType, VulkanTypeInfo, VulkanCompoundType, VulkanAPI
from collections import OrderedDict
from copy import copy
from pathlib import Path, PurePosixPath

import os
import sys
import shutil
import subprocess


# Class capturing a single file
class SingleFileModule(object):
    """One generated output file.

    The file is opened by ``begin()``, receives text through ``append()`` and
    is closed by ``end()``. ``preamble`` is written right after the file is
    opened and ``postamble`` right before it is closed. While ``suppress`` is
    true, ``begin``/``append``/``end`` do nothing and no file is created.

    Args:
        suffix: File extension including the dot (e.g. ``".py"``).
        directory: Output subdirectory, relative to the ``globalDir`` later
            passed to ``begin()``.
        basename: File name without the suffix.
        customAbsDir: If set, used as the output directory instead of
            ``os.path.join(globalDir, directory)``.
        suppress: If true, produce no output at all.
    """

    def __init__(self, suffix, directory, basename, customAbsDir=None, suppress=False):
        self.directory = directory
        self.basename = basename
        self.customAbsDir = customAbsDir
        self.suffix = suffix
        # Open file handle; only set between begin() and end().
        self.file = None

        self.preamble = ""
        self.postamble = ""

        self.suppress = suppress

    def begin(self, globalDir):
        """Open the output file under ``globalDir`` and write the preamble."""
        if self.suppress:
            return

        if self.customAbsDir:
            absDir = self.customAbsDir
        else:
            absDir = os.path.join(globalDir, self.directory)

        # Create the output subdirectory, if needed, so open() below does not
        # fail on a fresh output tree. An empty path means the CWD.
        if absDir:
            os.makedirs(absDir, exist_ok=True)

        filename = os.path.join(absDir, self.basename)

        self.file = open(filename + self.suffix, "w", encoding="utf-8")
        self.file.write(self.preamble)

    def append(self, toAppend):
        """Write ``toAppend`` to the open file (requires a prior ``begin()``)."""
        if self.suppress:
            return

        self.file.write(toAppend)

    def end(self):
        """Write the postamble and close the file."""
        if self.suppress:
            return

        try:
            self.file.write(self.postamble)
        finally:
            # Close even if the final write fails, so the handle is not leaked.
            self.file.close()

    def getMakefileSrcEntry(self):
        """Return this file's Makefile source-list entry (none by default)."""
        return ""

    def getCMakeSrcEntry(self):
        """Return this file's CMake source-list entry (none by default)."""
        return ""
# Class capturing a .cpp file and a .h file (a "C++ module")


class Module(object):
    """A generated C++ module: a ``.h`` header paired with a ``.cpp`` file.

    Each half is backed by a SingleFileModule. ``implOnly`` suppresses the
    header and ``headerOnly`` suppresses the implementation. ``end()`` closes
    both files and then runs ``clang-format -i --style=file`` over every file
    that was not suppressed.
    """

    def __init__(
            self, directory, basename, customAbsDir=None, suppress=False, implOnly=False,
            headerOnly=False, suppressFeatureGuards=False):
        self._headerFileModule = SingleFileModule(
            ".h", directory, basename, customAbsDir, suppress or implOnly)
        self._implFileModule = SingleFileModule(
            ".cpp", directory, basename, customAbsDir, suppress or headerOnly)

        self._headerOnly = headerOnly
        self._implOnly = implOnly

        self.directory = directory
        self.basename = basename
        self._customAbsDir = customAbsDir

        self.suppressFeatureGuards = suppressFeatureGuards

    @property
    def suppress(self):
        raise AttributeError("suppress is write only")

    @suppress.setter
    def suppress(self, value: bool):
        # headerOnly/implOnly keep their half suppressed regardless of value.
        self._headerFileModule.suppress = self._implOnly or value
        self._implFileModule.suppress = self._headerOnly or value

    @property
    def headerPreamble(self) -> str:
        return self._headerFileModule.preamble

    @headerPreamble.setter
    def headerPreamble(self, value: str):
        self._headerFileModule.preamble = value

    @property
    def headerPostamble(self) -> str:
        return self._headerFileModule.postamble

    @headerPostamble.setter
    def headerPostamble(self, value: str):
        self._headerFileModule.postamble = value

    @property
    def implPreamble(self) -> str:
        return self._implFileModule.preamble

    @implPreamble.setter
    def implPreamble(self, value: str):
        self._implFileModule.preamble = value

    @property
    def implPostamble(self) -> str:
        return self._implFileModule.postamble

    @implPostamble.setter
    def implPostamble(self, value: str):
        self._implFileModule.postamble = value

    def getMakefileSrcEntry(self):
        """Return the Makefile source-list line for the ``.cpp`` file."""
        if self._customAbsDir:
            return self.basename + ".cpp \\\n"
        joined = os.path.join(self.directory, self.basename)
        return " " + joined + ".cpp \\\n"

    def getCMakeSrcEntry(self):
        """Return the CMake source-list entry for the ``.cpp`` file (POSIX separators)."""
        if self._customAbsDir:
            return "\n" + self.basename + ".cpp "
        joined = PurePosixPath(Path(self.directory) / Path(self.basename))
        return "\n " + str(joined) + ".cpp "

    def begin(self, globalDir):
        self._headerFileModule.begin(globalDir)
        self._implFileModule.begin(globalDir)

    def appendHeader(self, toAppend):
        self._headerFileModule.append(toAppend)

    def appendImpl(self, toAppend):
        self._implFileModule.append(toAppend)

    def end(self):
        """Close both files, then clang-format every file that was written.

        Raises:
            RuntimeError: if clang-format is not on PATH or exits non-zero.
        """
        self._headerFileModule.end()
        self._implFileModule.end()

        toFormat = [
            Path(fileModule.file.name)
            for fileModule in (self._headerFileModule, self._implFileModule)
            if not fileModule.suppress
        ]
        if not toFormat:
            return

        # Explicit checks instead of `assert`: under `python -O` an assert
        # wrapping subprocess.call() would skip the formatting entirely.
        clangFormatCommand = shutil.which("clang-format")
        if clangFormatCommand is None:
            raise RuntimeError("clang-format was not found on PATH")

        for filename in toFormat:
            returnCode = subprocess.call(
                [clangFormatCommand, "-i", "--style=file", str(filename.resolve())])
            if returnCode != 0:
                raise RuntimeError(
                    "clang-format failed with exit code %d on %s" % (returnCode, filename))


class PyScript(SingleFileModule):
    """A generated Python script (``.py``)."""

    def __init__(self, directory, basename, customAbsDir=None, suppress=False):
        super().__init__(".py", directory, basename, customAbsDir, suppress)


# Class capturing a .proto protobuf definition file
class Proto(SingleFileModule):
    """A generated protobuf definition file (``.proto``)."""

    def __init__(self, directory, basename, customAbsDir=None, suppress=False):
        super().__init__(".proto", directory, basename, customAbsDir, suppress)

    def getMakefileSrcEntry(self):
        """Return the Makefile source-list line for this ``.proto`` file."""
        if self.customAbsDir:
            return self.basename + ".proto \\\n"
        joined = os.path.join(self.directory, self.basename)
        return " " + joined + ".proto \\\n"

    def getCMakeSrcEntry(self):
        """Return the CMake source-list entry for this ``.proto`` file.

        Uses POSIX separators, like ``Module.getCMakeSrcEntry``.
        """
        if self.customAbsDir:
            return "\n" + self.basename + ".proto "
        joined = PurePosixPath(Path(self.directory) / Path(self.basename))
        return "\n " + str(joined) + ".proto "


class CodeGen(object):
    """Accumulates generated C/C++ source text in ``self.code``.

    Tracks the current indentation level and a stack of per-block counters.
    ``var()`` uses that stack to produce temporary variable names that are
    unique within the generated code.
    """

    # Primitive encoding size in bytes -> stream method suffix.
    _STREAM_METHOD_SUFFIXES = {1: "Byte", 2: "Be16", 4: "Be32", 8: "Be64"}
    # Primitive encoding size in bytes -> unsigned C storage type.
    _STREAM_STORAGE_TYPES = {1: "uint8_t", 2: "uint16_t", 4: "uint32_t", 8: "uint64_t"}

    def __init__(self,):
        self.code = ""
        self.indentLevel = 0
        # One counter per open block; var() bumps the innermost one.
        self.gensymCounter = [-1]

    def var(self, prefix="cgen_var"):
        """Return a fresh variable name such as ``cgen_var_3`` or ``cgen_var_1_0``."""
        self.gensymCounter[-1] += 1
        suffix = "_".join(str(i) for i in self.gensymCounter if i >= 0)
        return "%s_%s" % (prefix, suffix)

    def swapCode(self,):
        """Return the accumulated code and reset the buffer to empty."""
        res = self.code
        self.code = ""
        return res

    def indent(self, extra=0):
        return " " * (self.indentLevel + extra)

    def incrIndent(self,):
        self.indentLevel += 1

    def decrIndent(self,):
        if self.indentLevel > 0:
            self.indentLevel -= 1

    def beginBlock(self, bracketPrint=True):
        if bracketPrint:
            self.code += self.indent() + "{\n"
        self.indentLevel += 1
        self.gensymCounter.append(-1)

    def endBlock(self, bracketPrint=True):
        self.indentLevel -= 1
        if bracketPrint:
            self.code += self.indent() + "}\n"
        del self.gensymCounter[-1]

    def beginIf(self, cond):
        self.code += self.indent() + "if (" + cond + ")\n"
        self.beginBlock()

    def beginElse(self, cond=None):
        if cond is not None:
            self.code += self.indent() + "else if (" + cond + ")\n"
        else:
            self.code += self.indent() + "else\n"
        self.beginBlock()

    def endElse(self):
        self.endBlock()

    def endIf(self):
        self.endBlock()

    def beginSwitch(self, switchvar):
        self.code += self.indent() + "switch (" + switchvar + ")\n"
        self.beginBlock()

    def switchCase(self, switchval, blocked=False):
        self.code += self.indent() + "case %s:" % switchval
        self.beginBlock(bracketPrint=blocked)

    def switchCaseBreak(self, switchval, blocked=False):
        self.code += self.indent() + "case %s:" % switchval
        self.endBlock(bracketPrint=blocked)

    def switchCaseDefault(self, blocked=False):
        # Previously formatted "default:" with an undefined `switchval`,
        # which raised NameError on every call.
        self.code += self.indent() + "default:"
        self.beginBlock(bracketPrint=blocked)

    def endSwitch(self):
        self.endBlock()

    def beginWhile(self, cond):
        self.code += self.indent() + "while (" + cond + ")\n"
        self.beginBlock()

    def endWhile(self):
        self.endBlock()

    def beginFor(self, initial, condition, increment):
        self.code += \
            self.indent() + "for (" + \
            "; ".join([initial, condition, increment]) + \
            ")\n"
        self.beginBlock()

    def endFor(self):
        self.endBlock()

    def beginLoop(self, loopVarType, loopVar, loopInit, loopBound):
        """Open ``for (T v = init; v < bound; ++v)``."""
        self.beginFor(
            "%s %s = %s" % (loopVarType, loopVar, loopInit),
            "%s < %s" % (loopVar, loopBound),
            "++%s" % (loopVar))

    def endLoop(self):
        self.endBlock()

    def stmt(self, code):
        self.code += self.indent() + code + ";\n"

    def line(self, code):
        self.code += self.indent() + code + "\n"

    def leftline(self, code):
        self.code += code + "\n"

    def makeCallExpr(self, funcName, parameters):
        return funcName + "(%s)" % (", ".join(parameters))

    def funcCall(self, lhs, funcName, parameters):
        res = self.indent()

        if lhs is not None:
            res += lhs + " = "

        res += self.makeCallExpr(funcName, parameters) + ";\n"
        self.code += res

    def funcCallRet(self, _lhs, funcName, parameters):
        res = self.indent()
        res += "return " + self.makeCallExpr(funcName, parameters) + ";\n"
        self.code += res

    # Given a VulkanType object, generate a C type declaration
    # with optional parameter name:
    # [const] [typename][*][const*] [paramName]
    def makeCTypeDecl(self, vulkanType, useParamName=True):
        constness = "const " if vulkanType.isConst else ""
        typeName = vulkanType.typeName

        if vulkanType.pointerIndirectionLevels == 0:
            ptrSpec = ""
        elif vulkanType.isPointerToConstPointer:
            ptrSpec = "* const*" if vulkanType.isConst else "**"
            if vulkanType.pointerIndirectionLevels > 2:
                ptrSpec += "*" * (vulkanType.pointerIndirectionLevels - 2)
        else:
            ptrSpec = "*" * vulkanType.pointerIndirectionLevels

        if useParamName and (vulkanType.paramName is not None):
            paramStr = " " + vulkanType.paramName
        else:
            paramStr = ""

        return "%s%s%s%s" % (constness, typeName, ptrSpec, paramStr)

    def makeRichCTypeDecl(self, vulkanType, useParamName=True):
        """Like makeCTypeDecl, plus a ``[N]`` suffix for static arrays."""
        decl = self.makeCTypeDecl(vulkanType, useParamName=useParamName)
        if vulkanType.staticArrExpr:
            decl += "[%s]" % vulkanType.staticArrExpr
        return decl

    # Given a VulkanAPI object, generate the C function protype:
    # <returntype> <funcname>(<parameters>)
    def makeFuncProto(self, vulkanApi, useParamName=True):
        protoBegin = "%s %s" % (self.makeCTypeDecl(
            vulkanApi.retType, useParamName=False), vulkanApi.name)

        separator = ",\n%s" % self.indent(1)
        paramDecls = [
            self.makeRichCTypeDecl(param, useParamName=useParamName)
            for param in vulkanApi.parameters
        ]
        protoParams = "(\n %s)" % separator.join(paramDecls)

        return protoBegin + protoParams

    def makeFuncAlias(self, nameDst, nameSrc):
        return "DEFINE_ALIAS_FUNCTION({}, {})\n\n".format(nameSrc, nameDst)

    def makeFuncDecl(self, vulkanApi):
        return self.makeFuncProto(vulkanApi) + ";\n\n"

    def makeFuncImpl(self, vulkanApi, codegenFunc):
        """Return a full function definition; discards any previously buffered code."""
        self.swapCode()

        self.line(self.makeFuncProto(vulkanApi))
        self.beginBlock()
        codegenFunc(self)
        self.endBlock()

        return self.swapCode() + "\n"

    def emitFuncImpl(self, vulkanApi, codegenFunc):
        """Append a full function definition to the current buffer."""
        self.line(self.makeFuncProto(vulkanApi))
        self.beginBlock()
        codegenFunc(self)
        self.endBlock()

    def makeStructAccess(self,
                         vulkanType,
                         structVarName,
                         asPtr=True,
                         structAsPtr=True,
                         accessIndex=None):
        deref = "->" if structAsPtr else "."

        indexExpr = (" + %s" % accessIndex) if accessIndex else ""

        addrOfExpr = "" if vulkanType.accessibleAsPointer() or (
            not asPtr) else "&"

        return "%s%s%s%s%s" % (addrOfExpr, structVarName, deref,
                               vulkanType.paramName, indexExpr)

    def makeRawLengthAccess(self, vulkanType):
        """Return ``(lengthExpr, guardExpr)`` using the length expression verbatim."""
        lenExpr = vulkanType.getLengthExpression()

        if not lenExpr:
            return None, None

        if lenExpr == "null-terminated":
            return "strlen(%s)" % vulkanType.paramName, None

        return lenExpr, None

    def makeLengthAccessFromStruct(self,
                                   structInfo,
                                   vulkanType,
                                   structVarName,
                                   asPtr=True):
        """Return ``(lengthExpr, guardExpr)`` for a struct member.

        ``guardExpr`` is the struct variable that must be non-null before the
        length expression is evaluated.
        """
        deref = "->" if asPtr else "."

        # Special cases first, mostly where the registry uses latexmath.
        # (struct, member) -> (member holding the length, postprocess).
        specialCases = {
            ("VkShaderModuleCreateInfo", "pCode"):
                ("codeSize", lambda expr: "(%s / 4)" % expr),
            ("VkPipelineMultisampleStateCreateInfo", "pSampleMask"):
                ("rasterizationSamples", lambda expr: "(((%s) + 31) / 32)" % expr),
            ("VkAccelerationStructureVersionInfoKHR", "pVersionData"):
                ("", lambda _: "2*VK_UUID_SIZE"),
        }
        special = specialCases.get((structInfo.name, vulkanType.paramName))
        if special is not None:
            lenExprMember, postprocess = special
            expr = "%s%s%s" % (structVarName, deref, lenExprMember)
            return postprocess(expr), "%s" % structVarName

        lenExpr = vulkanType.getLengthExpression()

        if not lenExpr:
            return None, None

        lenAccessGuardExpr = "%s" % structVarName

        if lenExpr == "null-terminated":
            return "strlen(%s%s%s)" % (structVarName, deref,
                                       vulkanType.paramName), lenAccessGuardExpr

        if not structInfo.getMember(lenExpr):
            return self.makeRawLengthAccess(vulkanType)

        return "%s%s%s" % (structVarName, deref, lenExpr), lenAccessGuardExpr

    def makeLengthAccessFromApi(self, api, vulkanType):
        """Return ``(lengthExpr, guardExpr)`` for an API parameter."""
        lenExpr = vulkanType.getLengthExpression()

        if not lenExpr:
            return None, None

        # "param::member" refers to a member of a struct-pointer parameter.
        if "::" in lenExpr:
            structVarName, memberVarName = lenExpr.split("::")
            return "%s->%s" % (structVarName, memberVarName), "%s" % structVarName

        # Checked before the parameter lookup: "null-terminated" is never a
        # parameter name. paramName is a string attribute, not a method.
        if lenExpr == "null-terminated":
            return "strlen(%s)" % vulkanType.paramName, None

        lenExprInfo = api.getParameter(lenExpr)

        if not lenExprInfo:
            return self.makeRawLengthAccess(vulkanType)

        deref = "*" if lenExprInfo.pointerIndirectionLevels > 0 else ""
        lenAccessGuardExpr = "%s" % lenExpr if deref else None
        return "(%s(%s))" % (deref, lenExpr), lenAccessGuardExpr

    def accessParameter(self, param, asPtr=True):
        if asPtr:
            if param.pointerIndirectionLevels > 0:
                return param.paramName
            else:
                return "&%s" % param.paramName
        else:
            return param.paramName

    def sizeofExpr(self, vulkanType):
        return "sizeof(%s)" % (
            self.makeCTypeDecl(vulkanType, useParamName=False))

    def generalAccess(self,
                      vulkanType,
                      parentVarName=None,
                      asPtr=True,
                      structAsPtr=True):
        """Return a C expression accessing ``vulkanType`` based on its parent kind."""
        parent = vulkanType.parent

        def accessAsParameter():
            # Top-level values and API parameters are accessed by name,
            # optionally renamed to parentVarName.
            if parentVarName is None:
                return self.accessParameter(vulkanType, asPtr=asPtr)
            return self.accessParameter(
                vulkanType.withModifiedName(parentVarName), asPtr=asPtr)

        if parent is None:
            return accessAsParameter()

        if isinstance(parent, VulkanCompoundType):
            return self.makeStructAccess(
                vulkanType, parentVarName, asPtr=asPtr, structAsPtr=structAsPtr)

        if isinstance(parent, VulkanAPI):
            return accessAsParameter()

        # os.abort() takes no arguments; report first, then abort.
        print("Could not find a way to access Vulkan type %s" %
              vulkanType.typeName)
        os.abort()

    def makeLengthAccess(self, vulkanType, parentVarName="parent"):
        """Return ``(lengthExpr, guardExpr)`` based on the parent kind."""
        if vulkanType.parent is None:
            return self.makeRawLengthAccess(vulkanType)

        if isinstance(vulkanType.parent, VulkanCompoundType):
            return self.makeLengthAccessFromStruct(
                vulkanType.parent, vulkanType, parentVarName, asPtr=True)

        if isinstance(vulkanType.parent, VulkanAPI):
            return self.makeLengthAccessFromApi(vulkanType.parent, vulkanType)

        # os.abort() takes no arguments; report first, then abort.
        print("Could not find a way to access length of Vulkan type %s" %
              vulkanType.typeName)
        os.abort()

    def generalLengthAccess(self, vulkanType, parentVarName="parent"):
        return self.makeLengthAccess(vulkanType, parentVarName)[0]

    def generalLengthAccessGuard(self, vulkanType, parentVarName="parent"):
        return self.makeLengthAccess(vulkanType, parentVarName)[1]

    def vkApiCall(self, api, customPrefix="", globalStatePrefix="", customParameters=None, checkForDeviceLost=False, checkForOutOfMemory=False):
        """Emit a call to ``api`` and return ``(retTypeName, retVarName)``.

        For non-void APIs the result is stored in a zero-initialized local.
        For VkResult APIs, optional device-lost / out-of-memory checks are emitted.
        """
        callLhs = None

        retTypeName = api.getRetTypeExpr()
        retVar = None

        if retTypeName != "void":
            retVar = api.getRetVarExpr()
            self.stmt("%s %s = (%s)0" % (retTypeName, retVar, retTypeName))
            callLhs = retVar

        if customParameters is None:
            parameters = [p.paramName for p in api.parameters]
        else:
            parameters = customParameters
        self.funcCall(callLhs, customPrefix + api.name, parameters)

        if retTypeName == "VkResult" and checkForDeviceLost:
            self.stmt("if ((%s) == VK_ERROR_DEVICE_LOST) %sDeviceLost()" % (callLhs, globalStatePrefix))

        if retTypeName == "VkResult" and checkForOutOfMemory:
            if api.name == "vkAllocateMemory":
                self.stmt(
                    "%sCheckOutOfMemory(%s, opcode, context, std::make_optional<uint64_t>(pAllocateInfo->allocationSize))"
                    % (globalStatePrefix, callLhs))
            else:
                self.stmt(
                    "%sCheckOutOfMemory(%s, opcode, context)"
                    % (globalStatePrefix, callLhs))

        return (retTypeName, retVar)

    def makeCheckVkSuccess(self, expr):
        return "((%s) == VK_SUCCESS)" % expr

    def makeReinterpretCast(self, varName, typeName, const=True):
        return "reinterpret_cast<%s%s*>(%s)" % \
            ("const " if const else "", typeName, varName)

    def validPrimitive(self, typeInfo, typeName):
        return typeInfo.getPrimitiveEncodingSize(typeName) is not None

    def _primitiveStreamMethod(self, typeInfo, typeName, writePrefix, readPrefix, direction):
        # Shared by the two public stream-method helpers: prefix + size suffix.
        if not self.validPrimitive(typeInfo, typeName):
            return None

        size = typeInfo.getPrimitiveEncodingSize(typeName)
        suffix = self._STREAM_METHOD_SUFFIXES.get(size)
        if suffix is None:
            return None

        prefix = writePrefix if direction == "write" else readPrefix
        return prefix + suffix

    def makePrimitiveStreamMethod(self, typeInfo, typeName, direction="write"):
        """Return e.g. ``putBe32``/``getBe32``, or None for unsupported types."""
        return self._primitiveStreamMethod(typeInfo, typeName, "put", "get", direction)

    def makePrimitiveStreamMethodInPlace(self, typeInfo, typeName, direction="write"):
        """Return e.g. ``toBe32``/``fromBe32``, or None for unsupported types."""
        return self._primitiveStreamMethod(typeInfo, typeName, "to", "from", direction)

    def _abortNonPrimitive(self, typeInfo, accessType, verb):
        # Non-pointer values must have a primitive encoding size.
        if accessType.pointerIndirectionLevels == 0 and \
                not self.validPrimitive(typeInfo, accessType.typeName):
            print("Tried to %s a non-primitive type: %s" % (verb, accessType.typeName))
            os.abort()

    def _abortUnsupportedSize(self, streamSize, accessType):
        print("Unsupported primitive encoding size %s for type %s" %
              (streamSize, accessType.typeName))
        os.abort()

    def streamPrimitive(self, typeInfo, streamVar, accessExpr, accessType, direction="write"):
        """Emit code that writes/reads one primitive through ``streamVar->put*/get*``."""
        self._abortNonPrimitive(typeInfo, accessType, "stream")

        if accessType.pointerIndirectionLevels > 0:
            # Pointers are always streamed as 64-bit values.
            streamSize = 8
            ptrCast = "(uintptr_t)"
            streamMethod = "putBe64" if direction == "write" else "getBe64"
        else:
            streamSize = typeInfo.getPrimitiveEncodingSize(accessType.typeName)
            ptrCast = ""
            streamMethod = self.makePrimitiveStreamMethod(
                typeInfo, accessType.typeName, direction=direction)
        streamStorageVarType = self._STREAM_STORAGE_TYPES.get(streamSize)

        # Always allocated (even for reads) to keep gensym numbering stable.
        streamStorageVar = self.var()

        accessCast = self.makeRichCTypeDecl(accessType, useParamName=False)

        if direction == "read":
            self.stmt("%s = (%s)%s%s->%s()" %
                      (accessExpr,
                       accessCast,
                       ptrCast,
                       streamVar,
                       streamMethod))
        else:
            if streamStorageVarType is None:
                self._abortUnsupportedSize(streamSize, accessType)
            self.stmt("%s %s = (%s)%s%s" %
                      (streamStorageVarType, streamStorageVar,
                       streamStorageVarType, ptrCast, accessExpr))
            self.stmt("%s->%s(%s)" %
                      (streamVar, streamMethod, streamStorageVar))

    def memcpyPrimitive(self, typeInfo, streamVar, accessExpr, accessType, variant, direction="write"):
        """Emit code that memcpy's one primitive to/from ``streamVar`` and byte-swaps it in place."""
        self._abortNonPrimitive(typeInfo, accessType, "stream")

        if accessType.pointerIndirectionLevels > 0:
            # Pointers are always encoded as 64-bit values.
            streamSize = 8
            ptrCast = "(uintptr_t)"
            streamMethod = "toBe64" if direction == "write" else "fromBe64"
        else:
            streamSize = typeInfo.getPrimitiveEncodingSize(accessType.typeName)
            ptrCast = ""
            streamMethod = self.makePrimitiveStreamMethodInPlace(
                typeInfo, accessType.typeName, direction=direction)
        streamStorageVarType = self._STREAM_STORAGE_TYPES.get(streamSize)

        # Always allocated (even for reads) to keep gensym numbering stable.
        streamStorageVar = self.var()

        # Reads write into the destination, so cast away constness.
        castType = accessType.getForNonConstAccess() if direction == "read" else accessType
        accessCast = self.makeRichCTypeDecl(castType, useParamName=False)

        streamNamespace = "gfxstream::guest" if variant == "guest" else "android::base"

        if direction == "read":
            self.stmt("memcpy((%s*)&%s, %s, %s)" %
                      (accessCast,
                       accessExpr,
                       streamVar,
                       str(streamSize)))
            self.stmt("%s::Stream::%s((uint8_t*)&%s)" % (
                streamNamespace,
                streamMethod,
                accessExpr))
        else:
            if streamStorageVarType is None:
                self._abortUnsupportedSize(streamSize, accessType)
            self.stmt("%s %s = (%s)%s%s" %
                      (streamStorageVarType, streamStorageVar,
                       streamStorageVarType, ptrCast, accessExpr))
            self.stmt("memcpy(%s, &%s, %s)" %
                      (streamVar, streamStorageVar, str(streamSize)))
            self.stmt("%s::Stream::%s((uint8_t*)%s)" % (
                streamNamespace,
                streamMethod,
                streamVar))

    def countPrimitive(self, typeInfo, accessType):
        """Return the encoded size in bytes (8 for any pointer)."""
        self._abortNonPrimitive(typeInfo, accessType, "count")

        if accessType.pointerIndirectionLevels > 0:
            return 8
        return typeInfo.getPrimitiveEncodingSize(accessType.typeName)

# Class to wrap a Vulkan API call.
#
# The user gives a generic callback, |codegenDef|,
# that takes a CodeGen object and a VulkanAPI object as arguments.
# codegenDef uses CodeGen along with the VulkanAPI object
# to generate the function body.
class VulkanAPIWrapper(object):

    def __init__(self,
                 customApiPrefix,
                 extraParameters=None,
                 returnTypeOverride=None,
                 codegenDef=None):
        self.customApiPrefix = customApiPrefix
        self.extraParameters = extraParameters
        self.returnTypeOverride = returnTypeOverride

        self.codegen = CodeGen()

        self.definitionFunc = codegenDef

        # Private function. Stored as a plain instance attribute (not a bound
        # method), so it is invoked as self.makeApi(self, typeInfo, apiName).
        def makeApiFunc(self, typeInfo, apiName):
            # Shallow copy: name/parameters/retType are rebound, not mutated.
            customApi = copy(typeInfo.apis[apiName])
            customApi.name = self.customApiPrefix + customApi.name
            if self.extraParameters is not None:
                if not isinstance(self.extraParameters, list):
                    # os.abort() takes no arguments; report first, then abort.
                    print("Type of extra parameters to custom API not valid. Expected list, got %s" % type(
                        self.extraParameters))
                    os.abort()
                customApi.parameters = \
                    self.extraParameters + customApi.parameters

            if self.returnTypeOverride is not None:
                customApi.retType = self.returnTypeOverride
            return customApi

        self.makeApi = makeApiFunc

    def setCodegenDef(self, codegenDefFunc):
        self.definitionFunc = codegenDefFunc

    def makeDecl(self, typeInfo, apiName):
        return self.codegen.makeFuncProto(
            self.makeApi(self, typeInfo, apiName)) + ";\n\n"

    def makeDefinition(self, typeInfo, apiName, isStatic=False):
        """Return the wrapper's full definition; exits if no codegenDef is set."""
        vulkanApi = self.makeApi(self, typeInfo, apiName)

        self.codegen.swapCode()
        self.codegen.beginBlock()

        if self.definitionFunc is None:
            print("ERROR: No definition found for (%s, %s)" %
                  (vulkanApi.name, self.customApiPrefix))
            sys.exit(1)

        self.definitionFunc(self.codegen, vulkanApi)

        self.codegen.endBlock()

        return ("static " if isStatic else "") + self.codegen.makeFuncProto(
            vulkanApi) + "\n" + self.codegen.swapCode() + "\n"

# Base class for wrapping all Vulkan API objects. These work with Vulkan
# Registry generators and have gen* triggers. They tend to contain
# VulkanAPIWrapper objects to make it easier to generate the code.
class VulkanWrapperGenerator(object):
    """Base class for registry-driven generators with ``on*`` callbacks.

    Subclasses override the ``on*`` hooks. The base implementation of
    ``onGenType`` records every non-alias struct/union whose
    ``structExtendsExpr`` is set in ``extensionStructTypes``, in the order
    the hooks are called.
    """

    def __init__(self, module: Module, typeInfo: VulkanTypeInfo):
        self.module: Module = module
        self.typeInfo: VulkanTypeInfo = typeInfo
        # Struct name -> struct info for every extension struct seen by onGenType.
        self.extensionStructTypes = OrderedDict()

    def onBegin(self):
        pass

    def onEnd(self):
        pass

    def onBeginFeature(self, featureName, featureType):
        pass

    def onFeatureNewCmd(self, cmdName):
        pass

    def onEndFeature(self):
        pass

    def onGenType(self, typeInfo, name, alias):
        category = self.typeInfo.categoryOf(name)
        if category in ["struct", "union"] and not alias:
            structInfo = self.typeInfo.structs[name]
            if structInfo.structExtendsExpr:
                self.extensionStructTypes[name] = structInfo

    def onGenStruct(self, typeInfo, name, alias):
        pass

    def onGenGroup(self, groupinfo, groupName, alias=None):
        pass

    def onGenEnum(self, enuminfo, name, alias):
        pass

    def onGenCmd(self, cmdinfo, name, alias):
        pass

    # Below Vulkan structure types may correspond to multiple Vulkan structs
    # due to a conflict between different Vulkan registries. In order to get
    # the correct Vulkan struct type, we need to check the type of its "root"
    # struct as well.
    ROOT_TYPE_MAPPING = {
        "VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FRAGMENT_DENSITY_MAP_FEATURES_EXT": {
            "VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2": "VkPhysicalDeviceFragmentDensityMapFeaturesEXT",
            "VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO": "VkPhysicalDeviceFragmentDensityMapFeaturesEXT",
            "VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO": "VkImportColorBufferGOOGLE",
            "default": "VkPhysicalDeviceFragmentDensityMapFeaturesEXT",
        },
        "VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FRAGMENT_DENSITY_MAP_PROPERTIES_EXT": {
            "VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2": "VkPhysicalDeviceFragmentDensityMapPropertiesEXT",
            "VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO": "VkCreateBlobGOOGLE",
            "default": "VkPhysicalDeviceFragmentDensityMapPropertiesEXT",
        },
        "VK_STRUCTURE_TYPE_RENDER_PASS_FRAGMENT_DENSITY_MAP_CREATE_INFO_EXT": {
            "VK_STRUCTURE_TYPE_RENDER_PASS_CREATE_INFO": "VkRenderPassFragmentDensityMapCreateInfoEXT",
            "VK_STRUCTURE_TYPE_RENDER_PASS_CREATE_INFO_2": "VkRenderPassFragmentDensityMapCreateInfoEXT",
            "VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO": "VkImportBufferGOOGLE",
            "default": "VkRenderPassFragmentDensityMapCreateInfoEXT",
        },
    }

    @staticmethod
    def _updateFeatureGuard(cgen, currFeature, feature):
        """Emit ``#endif``/``#ifdef`` lines so ``feature``'s guard is the open one.

        Returns the feature whose guard is now open (the new ``currFeature``).
        """
        if not currFeature:
            cgen.leftline("#ifdef %s" % feature)
            currFeature = feature

        if currFeature and feature != currFeature:
            cgen.leftline("#endif")
            cgen.leftline("#ifdef %s" % feature)
            currFeature = feature

        return currFeature

    def emitForEachStructExtension(self, cgen, retType, triggerVar, forEachFunc, autoBreak=True, defaultEmit=None, nullEmit=None, rootTypeVar=None):
        """Emit a ``switch`` over the structure type of the pointer ``triggerVar``.

        Emits one case per extension struct, wrapped in feature ``#ifdef``s.
        Each case calls ``forEachFunc(structInfo, castedAccessExpr, cgen)``.
        For structure types listed in ROOT_TYPE_MAPPING, and when
        ``rootTypeVar`` is given, a nested switch on ``rootTypeVar`` picks the
        concrete struct. ``nullEmit`` / ``defaultEmit`` override the emitted
        default return for the null-pointer and unknown-type paths.
        """
        def readStructType(structTypeName, structVarName, cgen):
            cgen.stmt("uint32_t %s = (uint32_t)%s(%s)" %
                      (structTypeName, "goldfish_vk_struct_type", structVarName))

        def castAsStruct(varName, typeName, const=True):
            return "reinterpret_cast<%s%s*>(%s)" % \
                ("const " if const else "", typeName, varName)

        def doDefaultReturn(cgen):
            if retType.typeName == "void":
                cgen.stmt("return")
            else:
                cgen.stmt("return (%s)0" % retType.typeName)

        cgen.beginIf("!%s" % triggerVar.paramName)
        if nullEmit is None:
            doDefaultReturn(cgen)
        else:
            nullEmit(cgen)
        cgen.endIf()

        readStructType("structType", triggerVar.paramName, cgen)

        cgen.line("switch(structType)")
        cgen.beginBlock()

        currFeature = None

        for ext in self.extensionStructTypes.values():
            currFeature = self._updateFeatureGuard(cgen, currFeature, ext.feature)

            enum = ext.structEnumExpr
            protect = None
            if enum in self.typeInfo.enumElem:
                protect = self.typeInfo.enumElem[enum].get("protect", default=None)
                if protect is not None:
                    cgen.leftline("#ifdef %s" % protect)

            cgen.line("case %s:" % enum)
            cgen.beginBlock()

            if rootTypeVar is not None and enum in VulkanWrapperGenerator.ROOT_TYPE_MAPPING:
                # Ambiguous structure type: disambiguate on the root struct's type.
                cgen.line("switch(%s)" % rootTypeVar.paramName)
                cgen.beginBlock()
                for rootType, structName in VulkanWrapperGenerator.ROOT_TYPE_MAPPING[enum].items():
                    structInfo = self.extensionStructTypes[structName]
                    if rootType == "default":
                        cgen.line("%s:" % rootType)
                    else:
                        cgen.line("case %s:" % rootType)
                    cgen.beginBlock()
                    castedAccess = castAsStruct(
                        triggerVar.paramName, structInfo.name, const=triggerVar.isConst)
                    forEachFunc(structInfo, castedAccess, cgen)
                    cgen.line("break;")
                    cgen.endBlock()
                cgen.endBlock()
            else:
                castedAccess = castAsStruct(
                    triggerVar.paramName, ext.name, const=triggerVar.isConst)
                forEachFunc(ext, castedAccess, cgen)

            if autoBreak:
                cgen.stmt("break")
            cgen.endBlock()

            if protect is not None:
                cgen.leftline("#endif // %s" % protect)

        if currFeature:
            cgen.leftline("#endif")

        cgen.line("default:")
        cgen.beginBlock()
        if defaultEmit is None:
            doDefaultReturn(cgen)
        else:
            defaultEmit(cgen)
        cgen.endBlock()

        cgen.endBlock()

    def emitForEachStructExtensionGeneral(self, cgen, forEachFunc, doFeatureIfdefs=False):
        """Call ``forEachFunc(index, structInfo, cgen)`` for every extension struct.

        With ``doFeatureIfdefs``, consecutive structs are wrapped in their
        feature's ``#ifdef``/``#endif`` guard.
        """
        currFeature = None

        for (i, ext) in enumerate(self.extensionStructTypes.values()):
            if doFeatureIfdefs:
                currFeature = self._updateFeatureGuard(cgen, currFeature, ext.feature)

            forEachFunc(i, ext, cgen)

        if doFeatureIfdefs:
            if currFeature:
                cgen.leftline("#endif")