# Copyright (c) 2018 The Android Open Source Project
# Copyright (c) 2018 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .vulkantypes import VulkanType, VulkanTypeInfo, VulkanCompoundType, VulkanAPI
from collections import OrderedDict
from copy import copy
from pathlib import Path, PurePosixPath

import os
import sys
import shutil
import subprocess

# Class capturing a single file


class SingleFileModule(object):
    """One generated output file.

    The file is written as ``preamble``, then every ``append()``ed chunk, then
    ``postamble``. When ``suppress`` is true, ``begin``/``append``/``end`` are
    no-ops and no file is created.
    """

    def __init__(self, suffix, directory, basename, customAbsDir=None, suppress=False):
        self.directory = directory
        self.basename = basename
        self.customAbsDir = customAbsDir
        self.suffix = suffix
        # Open file object; assigned by begin().
        self.file = None

        self.preamble = ""
        self.postamble = ""

        self.suppress = suppress

    def begin(self, globalDir):
        """Open ``<dir>/<basename><suffix>`` for writing and emit the preamble.

        ``<dir>`` is ``customAbsDir`` when set, otherwise
        ``os.path.join(globalDir, directory)``.
        """
        if self.suppress:
            return

        if self.customAbsDir:
            absDir = self.customAbsDir
        else:
            absDir = os.path.join(globalDir, self.directory)

        # Create subdirectory, if needed; open() below fails otherwise.
        if absDir:
            os.makedirs(absDir, exist_ok=True)

        filename = os.path.join(absDir, self.basename)

        self.file = open(filename + self.suffix, "w", encoding="utf-8")
        self.file.write(self.preamble)

    def append(self, toAppend):
        """Write ``toAppend`` to the open file (no-op when suppressed)."""
        if self.suppress:
            return

        self.file.write(toAppend)

    def end(self):
        """Write the postamble and close the file (no-op when suppressed)."""
        if self.suppress:
            return

        self.file.write(self.postamble)
        self.file.close()

    def getMakefileSrcEntry(self):
        return ""

    def getCMakeSrcEntry(self):
        return ""

# Class capturing a .cpp file and a .h file (a "C++ module")


class Module(object):
    """A generated C++ module: a ``.h`` header plus a ``.cpp`` implementation.

    ``implOnly`` suppresses the header, ``headerOnly`` suppresses the
    implementation, and ``suppress`` suppresses both.
    """

    def __init__(
            self, directory, basename, customAbsDir=None, suppress=False, implOnly=False,
            headerOnly=False, suppressFeatureGuards=False):
        self._headerFileModule = SingleFileModule(
            ".h", directory, basename, customAbsDir, suppress or implOnly)
        self._implFileModule = SingleFileModule(
            ".cpp", directory, basename, customAbsDir, suppress or headerOnly)

        self._headerOnly = headerOnly
        self._implOnly = implOnly

        self.directory = directory
        self.basename = basename
        self._customAbsDir = customAbsDir

        self.suppressFeatureGuards = suppressFeatureGuards

    @property
    def suppress(self):
        raise AttributeError("suppress is write only")

    @suppress.setter
    def suppress(self, value: bool):
        # implOnly/headerOnly keep their half suppressed regardless of value.
        self._headerFileModule.suppress = self._implOnly or value
        self._implFileModule.suppress = self._headerOnly or value

    @property
    def headerPreamble(self) -> str:
        return self._headerFileModule.preamble

    @headerPreamble.setter
    def headerPreamble(self, value: str):
        self._headerFileModule.preamble = value

    @property
    def headerPostamble(self) -> str:
        return self._headerFileModule.postamble

    @headerPostamble.setter
    def headerPostamble(self, value: str):
        self._headerFileModule.postamble = value

    @property
    def implPreamble(self) -> str:
        return self._implFileModule.preamble

    @implPreamble.setter
    def implPreamble(self, value: str):
        self._implFileModule.preamble = value

    @property
    def implPostamble(self) -> str:
        return self._implFileModule.postamble

    @implPostamble.setter
    def implPostamble(self, value: str):
        self._implFileModule.postamble = value

    def getMakefileSrcEntry(self):
        """Return the module's .cpp file as a backslash-continued Makefile source line."""
        if self._customAbsDir:
            return self.basename + ".cpp \\\n"
        joined = os.path.join(self.directory, self.basename)
        return " " + joined + ".cpp \\\n"

    def getCMakeSrcEntry(self):
        """Return the module's .cpp file as a CMake source-list entry (POSIX separators)."""
        if self._customAbsDir:
            return "\n" + self.basename + ".cpp "
        joined = PurePosixPath(Path(self.directory) / Path(self.basename))
        return "\n " + str(joined) + ".cpp "

    def begin(self, globalDir):
        self._headerFileModule.begin(globalDir)
        self._implFileModule.begin(globalDir)

    def appendHeader(self, toAppend):
        self._headerFileModule.append(toAppend)

    def appendImpl(self, toAppend):
        self._implFileModule.append(toAppend)

    def end(self):
        """Close both files, then run ``clang-format -i --style=file`` on each written file.

        Raises:
            RuntimeError: if clang-format is not on PATH (only checked when a
                file was written) or exits with a non-zero status.
        """
        self._headerFileModule.end()
        self._implFileModule.end()

        writtenFiles = [
            Path(fileModule.file.name)
            for fileModule in (self._headerFileModule, self._implFileModule)
            if not fileModule.suppress
        ]
        if not writtenFiles:
            return

        clang_format_command = shutil.which('clang-format')
        if clang_format_command is None:
            raise RuntimeError("clang-format was not found on PATH")

        for path in writtenFiles:
            # Deliberately not inside an assert: asserts are stripped under
            # `python -O`, which previously skipped formatting entirely.
            returncode = subprocess.call(
                [clang_format_command, "-i", "--style=file", str(path.resolve())])
            if returncode != 0:
                raise RuntimeError(
                    "clang-format failed with exit code %d on %s" % (returncode, path))


class PyScript(SingleFileModule):
    """A generated ``.py`` file."""

    def __init__(self, directory, basename, customAbsDir=None, suppress=False):
        super().__init__(".py", directory, basename, customAbsDir, suppress)


# Class capturing a .proto protobuf definition file
class Proto(SingleFileModule):

    def __init__(self, directory, basename, customAbsDir=None, suppress=False):
        super().__init__(".proto", directory, basename, customAbsDir, suppress)

    def getMakefileSrcEntry(self):
        """Return this .proto file as a backslash-continued Makefile source line."""
        if self.customAbsDir:
            return self.basename + ".proto \\\n"
        joined = os.path.join(self.directory, self.basename)
        return " " + joined + ".proto \\\n"

    def getCMakeSrcEntry(self):
        """Return this .proto file as a CMake source-list entry."""
        if self.customAbsDir:
            return "\n" + self.basename + ".proto "
        # Build the path the same way Module.getCMakeSrcEntry does.
        joined = PurePosixPath(Path(self.directory) / Path(self.basename))
        return "\n " + str(joined) + ".proto "


class CodeGen(object):
    """Accumulates generated C/C++ source text in ``self.code``.

    Tracks the current indentation level and a stack of per-block counters
    (``gensymCounter``) used by ``var()`` to produce unique temporary names.
    """

    # Primitive encoding size in bytes -> suffix of the stream method
    # (combined with put/get or to/from prefixes).
    _STREAM_METHOD_SUFFIX_BY_SIZE = {1: "Byte", 2: "Be16", 4: "Be32", 8: "Be64"}

    # Primitive encoding size in bytes -> unsigned C storage type.
    _STREAM_STORAGE_TYPE_BY_SIZE = {1: "uint8_t", 2: "uint16_t", 4: "uint32_t", 8: "uint64_t"}

    # (struct name, field) -> (length member, postprocess(expr)) for struct
    # fields whose length is not a plain member reference (mostly latexmath).
    _STRUCT_LENGTH_SPECIAL_CASES = {
        ("VkShaderModuleCreateInfo", "pCode"):
            ("codeSize", lambda expr: "(%s / 4)" % expr),
        ("VkPipelineMultisampleStateCreateInfo", "pSampleMask"):
            ("rasterizationSamples", lambda expr: "(((%s) + 31) / 32)" % expr),
        ("VkAccelerationStructureVersionInfoKHR", "pVersionData"):
            ("", lambda _: "2*VK_UUID_SIZE"),
    }

    def __init__(self):
        self.code = ""
        self.indentLevel = 0
        self.gensymCounter = [-1]

    def var(self, prefix="cgen_var"):
        """Return a fresh variable name, unique within the current block nesting."""
        self.gensymCounter[-1] += 1
        res = "%s_%s" % (prefix, '_'.join(str(i) for i in self.gensymCounter if i >= 0))
        return res

    def swapCode(self):
        """Return the accumulated code and reset the buffer to empty."""
        res = self.code
        self.code = ""
        return res

    def indent(self, extra=0):
        return " " * (self.indentLevel + extra)

    def incrIndent(self):
        self.indentLevel += 1

    def decrIndent(self):
        if self.indentLevel > 0:
            self.indentLevel -= 1

    def beginBlock(self, bracketPrint=True):
        if bracketPrint:
            self.code += self.indent() + "{\n"
        self.indentLevel += 1
        # New scope for var() numbering.
        self.gensymCounter.append(-1)

    def endBlock(self, bracketPrint=True):
        self.indentLevel -= 1
        if bracketPrint:
            self.code += self.indent() + "}\n"
        del self.gensymCounter[-1]

    def beginIf(self, cond):
        self.code += self.indent() + "if (" + cond + ")\n"
        self.beginBlock()

    def beginElse(self, cond=None):
        if cond is not None:
            self.code += self.indent() + "else if (" + cond + ")\n"
        else:
            self.code += self.indent() + "else\n"
        self.beginBlock()

    def endElse(self):
        self.endBlock()

    def endIf(self):
        self.endBlock()

    def beginSwitch(self, switchvar):
        self.code += self.indent() + "switch (" + switchvar + ")\n"
        self.beginBlock()

    def switchCase(self, switchval, blocked=False):
        self.code += self.indent() + "case %s:" % switchval
        self.beginBlock(bracketPrint=blocked)

    def switchCaseBreak(self, switchval, blocked=False):
        # NOTE(review): emits a case label and then closes the current block.
        self.code += self.indent() + "case %s:" % switchval
        self.endBlock(bracketPrint=blocked)

    def switchCaseDefault(self, blocked=False):
        # Previously formatted an undefined `switchval` into this label,
        # which raised NameError on every call.
        self.code += self.indent() + "default:"
        self.beginBlock(bracketPrint=blocked)

    def endSwitch(self):
        self.endBlock()

    def beginWhile(self, cond):
        self.code += self.indent() + "while (" + cond + ")\n"
        self.beginBlock()

    def endWhile(self):
        self.endBlock()

    def beginFor(self, initial, condition, increment):
        self.code += \
            self.indent() + "for (" + \
            "; ".join([initial, condition, increment]) + \
            ")\n"
        self.beginBlock()

    def endFor(self):
        self.endBlock()

    def beginLoop(self, loopVarType, loopVar, loopInit, loopBound):
        """Open ``for (T v = init; v < bound; ++v)``."""
        self.beginFor(
            "%s %s = %s" % (loopVarType, loopVar, loopInit),
            "%s < %s" % (loopVar, loopBound),
            "++%s" % (loopVar))

    def endLoop(self):
        self.endBlock()

    def stmt(self, code):
        self.code += self.indent() + code + ";\n"

    def line(self, code):
        self.code += self.indent() + code + "\n"

    def leftline(self, code):
        # Unindented line, e.g. preprocessor directives.
        self.code += code + "\n"

    def makeCallExpr(self, funcName, parameters):
        return funcName + "(%s)" % (", ".join(parameters))

    def funcCall(self, lhs, funcName, parameters):
        res = self.indent()

        if lhs is not None:
            res += lhs + " = "

        res += self.makeCallExpr(funcName, parameters) + ";\n"
        self.code += res

    def funcCallRet(self, _lhs, funcName, parameters):
        res = self.indent()
        res += "return " + self.makeCallExpr(funcName, parameters) + ";\n"
        self.code += res

    def _makePtrSpec(self, vulkanType):
        """Return the pointer declarator suffix ("", "*", "**", "* const*", ...) for vulkanType."""
        levels = vulkanType.pointerIndirectionLevels
        if levels == 0:
            return ""
        if vulkanType.isPointerToConstPointer:
            ptrSpec = "* const*" if vulkanType.isConst else "**"
            if levels > 2:
                ptrSpec += "*" * (levels - 2)
            return ptrSpec
        return "*" * levels

    # Given a VulkanType object, generate a C type declaration
    # with optional parameter name:
    # [const] [typename][*][const*] [paramName]
    def makeCTypeDecl(self, vulkanType, useParamName=True):
        constness = "const " if vulkanType.isConst else ""
        ptrSpec = self._makePtrSpec(vulkanType)

        if useParamName and (vulkanType.paramName is not None):
            paramStr = " " + vulkanType.paramName
        else:
            paramStr = ""

        return "%s%s%s%s" % (constness, vulkanType.typeName, ptrSpec, paramStr)

    def makeRichCTypeDecl(self, vulkanType, useParamName=True):
        """Like makeCTypeDecl, plus a trailing ``[staticArrExpr]`` when the type has one."""
        decl = self.makeCTypeDecl(vulkanType, useParamName=useParamName)
        if vulkanType.staticArrExpr:
            decl += "[%s]" % vulkanType.staticArrExpr
        return decl

    # Given a VulkanAPI object, generate the C function protype:
    # <returntype> <funcname>(<parameters>)
    def makeFuncProto(self, vulkanApi, useParamName=True):
        protoBegin = "%s %s" % (self.makeCTypeDecl(
            vulkanApi.retType, useParamName=False), vulkanApi.name)

        argDecls = [self.makeRichCTypeDecl(param, useParamName=useParamName)
                    for param in vulkanApi.parameters]
        protoParams = "(\n %s)" % ((",\n%s" % self.indent(1)).join(argDecls))

        return protoBegin + protoParams

    def makeFuncAlias(self, nameDst, nameSrc):
        return "DEFINE_ALIAS_FUNCTION({}, {})\n\n".format(nameSrc, nameDst)

    def makeFuncDecl(self, vulkanApi):
        return self.makeFuncProto(vulkanApi) + ";\n\n"

    def makeFuncImpl(self, vulkanApi, codegenFunc):
        """Return the full definition of vulkanApi with a body emitted by codegenFunc(self).

        Any code already buffered in ``self.code`` is discarded first.
        """
        self.swapCode()

        self.line(self.makeFuncProto(vulkanApi))
        self.beginBlock()
        codegenFunc(self)
        self.endBlock()

        return self.swapCode() + "\n"

    def emitFuncImpl(self, vulkanApi, codegenFunc):
        """Append the definition of vulkanApi to ``self.code`` (body emitted by codegenFunc)."""
        self.line(self.makeFuncProto(vulkanApi))
        self.beginBlock()
        codegenFunc(self)
        self.endBlock()

    def makeStructAccess(self,
                         vulkanType,
                         structVarName,
                         asPtr=True,
                         structAsPtr=True,
                         accessIndex=None):
        """Return an expression accessing member vulkanType of structVarName."""
        deref = "->" if structAsPtr else "."

        indexExpr = (" + %s" % accessIndex) if accessIndex else ""

        addrOfExpr = "" if vulkanType.accessibleAsPointer() or (
            not asPtr) else "&"

        return "%s%s%s%s%s" % (addrOfExpr, structVarName, deref,
                               vulkanType.paramName, indexExpr)

    def makeRawLengthAccess(self, vulkanType):
        """Return (length expression, guard) using the type's length expression verbatim."""
        lenExpr = vulkanType.getLengthExpression()

        if not lenExpr:
            return None, None

        if lenExpr == "null-terminated":
            return "strlen(%s)" % vulkanType.paramName, None

        return lenExpr, None

    def makeLengthAccessFromStruct(self,
                                   structInfo,
                                   vulkanType,
                                   structVarName,
                                   asPtr=True):
        """Return (length expression, guard expression) for a struct member.

        The guard is the struct variable expression that must be non-null
        before the length expression may be evaluated.
        """
        deref = "->" if asPtr else "."

        # Handle special cases first (mostly when latexmath is involved).
        specialCase = self._STRUCT_LENGTH_SPECIAL_CASES.get(
            (structInfo.name, vulkanType.paramName))
        if specialCase is not None:
            lenExprMember, postprocess = specialCase
            expr = "%s%s%s" % (structVarName, deref, lenExprMember)
            return postprocess(expr), "%s" % structVarName

        lenExpr = vulkanType.getLengthExpression()

        if not lenExpr:
            return None, None

        lenAccessGuardExpr = "%s" % structVarName
        if lenExpr == "null-terminated":
            return "strlen(%s%s%s)" % (structVarName, deref,
                                       vulkanType.paramName), lenAccessGuardExpr

        if not structInfo.getMember(lenExpr):
            return self.makeRawLengthAccess(vulkanType)

        return "%s%s%s" % (structVarName, deref, lenExpr), lenAccessGuardExpr

    def makeLengthAccessFromApi(self, api, vulkanType):
        """Return (length expression, guard expression) for an API parameter."""
        lenExpr = vulkanType.getLengthExpression()

        if not lenExpr:
            return None, None

        # Special case: "param::member" refers to a member of another parameter.
        if "::" in lenExpr:
            structVarName, memberVarName = lenExpr.split("::")
            return "%s->%s" % (structVarName, memberVarName), "%s" % structVarName

        # Checked before the parameter lookup so this branch is reachable;
        # paramName is a string attribute, not a method.
        if lenExpr == "null-terminated":
            return "strlen(%s)" % vulkanType.paramName, None

        lenExprInfo = api.getParameter(lenExpr)

        if not lenExprInfo:
            return self.makeRawLengthAccess(vulkanType)

        deref = "*" if lenExprInfo.pointerIndirectionLevels > 0 else ""
        lenAccessGuardExpr = "%s" % lenExpr if deref else None
        return "(%s(%s))" % (deref, lenExpr), lenAccessGuardExpr

    def accessParameter(self, param, asPtr=True):
        if asPtr:
            if param.pointerIndirectionLevels > 0:
                return param.paramName
            else:
                return "&%s" % param.paramName
        else:
            return param.paramName

    def sizeofExpr(self, vulkanType):
        return "sizeof(%s)" % (
            self.makeCTypeDecl(vulkanType, useParamName=False))

    def generalAccess(self,
                      vulkanType,
                      parentVarName=None,
                      asPtr=True,
                      structAsPtr=True):
        """Return an access expression for vulkanType based on what its parent is.

        Raises:
            TypeError: if the parent is not None, a VulkanCompoundType or a VulkanAPI.
        """
        parent = vulkanType.parent

        if isinstance(parent, VulkanCompoundType):
            return self.makeStructAccess(
                vulkanType, parentVarName, asPtr=asPtr, structAsPtr=structAsPtr)

        if parent is None or isinstance(parent, VulkanAPI):
            if parentVarName is None:
                return self.accessParameter(vulkanType, asPtr=asPtr)
            return self.accessParameter(vulkanType.withModifiedName(parentVarName), asPtr=asPtr)

        # os.abort() accepts no arguments, so the previous os.abort(msg)
        # surfaced as an unrelated TypeError; raise a descriptive one instead.
        raise TypeError("Could not find a way to access Vulkan type %s" %
                        vulkanType.typeName)

    def makeLengthAccess(self, vulkanType, parentVarName="parent"):
        """Return (length expression, guard expression) based on vulkanType's parent.

        Raises:
            TypeError: if the parent is not None, a VulkanCompoundType or a VulkanAPI.
        """
        parent = vulkanType.parent

        if parent is None:
            return self.makeRawLengthAccess(vulkanType)

        if isinstance(parent, VulkanCompoundType):
            return self.makeLengthAccessFromStruct(
                parent, vulkanType, parentVarName, asPtr=True)

        if isinstance(parent, VulkanAPI):
            return self.makeLengthAccessFromApi(parent, vulkanType)

        raise TypeError("Could not find a way to access length of Vulkan type %s" %
                        vulkanType.typeName)

    def generalLengthAccess(self, vulkanType, parentVarName="parent"):
        return self.makeLengthAccess(vulkanType, parentVarName)[0]

    def generalLengthAccessGuard(self, vulkanType, parentVarName="parent"):
        return self.makeLengthAccess(vulkanType, parentVarName)[1]

    def vkApiCall(self, api, customPrefix="", globalStatePrefix="", customParameters=None, checkForDeviceLost=False, checkForOutOfMemory=False):
        """Emit a call to ``customPrefix + api.name`` and return (retTypeName, retVar).

        For non-void APIs the result is stored in a zero-initialized local
        (retVar); otherwise retVar is None.
        """
        callLhs = None

        retTypeName = api.getRetTypeExpr()
        retVar = None

        if retTypeName != "void":
            retVar = api.getRetVarExpr()
            self.stmt("%s %s = (%s)0" % (retTypeName, retVar, retTypeName))
            callLhs = retVar

        if customParameters is None:
            self.funcCall(
                callLhs, customPrefix + api.name, [p.paramName for p in api.parameters])
        else:
            self.funcCall(
                callLhs, customPrefix + api.name, customParameters)

        if retTypeName == "VkResult" and checkForDeviceLost:
            self.stmt("if ((%s) == VK_ERROR_DEVICE_LOST) %sDeviceLost()" % (callLhs, globalStatePrefix))

        if retTypeName == "VkResult" and checkForOutOfMemory:
            if api.name == "vkAllocateMemory":
                self.stmt(
                    "%sCheckOutOfMemory(%s, opcode, context, std::make_optional<uint64_t>(pAllocateInfo->allocationSize))"
                    % (globalStatePrefix, callLhs))
            else:
                self.stmt(
                    "%sCheckOutOfMemory(%s, opcode, context)"
                    % (globalStatePrefix, callLhs))

        return (retTypeName, retVar)

    def makeCheckVkSuccess(self, expr):
        return "((%s) == VK_SUCCESS)" % expr

    def makeReinterpretCast(self, varName, typeName, const=True):
        return "reinterpret_cast<%s%s*>(%s)" % \
            ("const " if const else "", typeName, varName)

    def validPrimitive(self, typeInfo, typeName):
        """Return True if typeInfo knows a primitive encoding size for typeName."""
        return typeInfo.getPrimitiveEncodingSize(typeName) is not None

    def makePrimitiveStreamMethod(self, typeInfo, typeName, direction="write"):
        """Return "put<Suffix>"/"get<Suffix>" for a primitive, or None if unsupported."""
        if not self.validPrimitive(typeInfo, typeName):
            return None

        suffix = self._STREAM_METHOD_SUFFIX_BY_SIZE.get(
            typeInfo.getPrimitiveEncodingSize(typeName))
        if suffix is None:
            return None

        prefix = "put" if direction == "write" else "get"
        return prefix + suffix

    def makePrimitiveStreamMethodInPlace(self, typeInfo, typeName, direction="write"):
        """Return "to<Suffix>"/"from<Suffix>" for a primitive, or None if unsupported."""
        if not self.validPrimitive(typeInfo, typeName):
            return None

        suffix = self._STREAM_METHOD_SUFFIX_BY_SIZE.get(
            typeInfo.getPrimitiveEncodingSize(typeName))
        if suffix is None:
            return None

        prefix = "to" if direction == "write" else "from"
        return prefix + suffix

    def _streamStorageType(self, streamSize, accessTypeName):
        """Return the unsigned C type for a streamSize-byte primitive; abort if unsupported."""
        storageType = self._STREAM_STORAGE_TYPE_BY_SIZE.get(streamSize)
        if storageType is None:
            print("Unsupported primitive encoding size %s for type: %s" %
                  (streamSize, accessTypeName))
            os.abort()
        return storageType

    def streamPrimitive(self, typeInfo, streamVar, accessExpr, accessType, direction="write"):
        """Emit code that reads/writes one primitive (or pointer) via ``streamVar->put*/get*``.

        Pointers are streamed as 64-bit values through a uintptr_t cast.
        """
        accessTypeName = accessType.typeName

        if accessType.pointerIndirectionLevels == 0 and not self.validPrimitive(typeInfo, accessTypeName):
            print("Tried to stream a non-primitive type: %s" % accessTypeName)
            os.abort()

        needPtrCast = accessType.pointerIndirectionLevels > 0

        if needPtrCast:
            streamStorageVarType = "uint64_t"
            streamMethod = "putBe64" if direction == "write" else "getBe64"
        else:
            streamSize = typeInfo.getPrimitiveEncodingSize(accessTypeName)
            # Previously an unsupported size left this unbound (UnboundLocalError).
            streamStorageVarType = self._streamStorageType(streamSize, accessTypeName)
            streamMethod = self.makePrimitiveStreamMethod(
                typeInfo, accessTypeName, direction=direction)

        # Always allocated (even for reads) so var() numbering stays stable.
        streamStorageVar = self.var()

        accessCast = self.makeRichCTypeDecl(accessType, useParamName=False)

        ptrCast = "(uintptr_t)" if needPtrCast else ""

        if direction == "read":
            self.stmt("%s = (%s)%s%s->%s()" %
                      (accessExpr,
                       accessCast,
                       ptrCast,
                       streamVar,
                       streamMethod))
        else:
            self.stmt("%s %s = (%s)%s%s" %
                      (streamStorageVarType, streamStorageVar,
                       streamStorageVarType, ptrCast, accessExpr))
            self.stmt("%s->%s(%s)" %
                      (streamVar, streamMethod, streamStorageVar))

    def memcpyPrimitive(self, typeInfo, streamVar, accessExpr, accessType, direction="write"):
        """Emit code that memcpy's one primitive (or pointer) to/from the buffer ``streamVar``.

        The bytes are then byte-swapped in place with android::base::Stream::to*/from*.
        """
        accessTypeName = accessType.typeName

        if accessType.pointerIndirectionLevels == 0 and not self.validPrimitive(typeInfo, accessTypeName):
            print("Tried to stream a non-primitive type: %s" % accessTypeName)
            os.abort()

        needPtrCast = accessType.pointerIndirectionLevels > 0

        if needPtrCast:
            streamSize = 8
            streamStorageVarType = "uint64_t"
            streamMethod = "toBe64" if direction == "write" else "fromBe64"
        else:
            streamSize = typeInfo.getPrimitiveEncodingSize(accessTypeName)
            streamStorageVarType = self._streamStorageType(streamSize, accessTypeName)
            streamMethod = self.makePrimitiveStreamMethodInPlace(
                typeInfo, accessTypeName, direction=direction)

        # Always allocated (even for reads) so var() numbering stays stable.
        streamStorageVar = self.var()

        if direction == "read":
            # Reads write into the target, so cast away constness.
            accessCast = self.makeRichCTypeDecl(
                accessType.getForNonConstAccess(), useParamName=False)
            self.stmt("memcpy((%s*)&%s, %s, %s)" %
                      (accessCast,
                       accessExpr,
                       streamVar,
                       str(streamSize)))
            self.stmt("android::base::Stream::%s((uint8_t*)&%s)" % (
                streamMethod,
                accessExpr))
        else:
            ptrCast = "(uintptr_t)" if needPtrCast else ""
            self.stmt("%s %s = (%s)%s%s" %
                      (streamStorageVarType, streamStorageVar,
                       streamStorageVarType, ptrCast, accessExpr))
            self.stmt("memcpy(%s, &%s, %s)" %
                      (streamVar, streamStorageVar, str(streamSize)))
            self.stmt("android::base::Stream::%s((uint8_t*)%s)" % (
                streamMethod,
                streamVar))

    def countPrimitive(self, typeInfo, accessType):
        """Return the encoded size in bytes of a primitive (8 for any pointer)."""
        accessTypeName = accessType.typeName

        if accessType.pointerIndirectionLevels == 0 and not self.validPrimitive(typeInfo, accessTypeName):
            print("Tried to count a non-primitive type: %s" % accessTypeName)
            os.abort()

        if accessType.pointerIndirectionLevels > 0:
            return 8

        return typeInfo.getPrimitiveEncodingSize(accessTypeName)

# Class to wrap a Vulkan API call.
#
# The user gives a generic callback, |codegenDef|,
# that takes a CodeGen object and a VulkanAPI object as arguments.
# codegenDef uses CodeGen along with the VulkanAPI object
# to generate the function body.
class VulkanAPIWrapper(object):

    def __init__(self,
                 customApiPrefix,
                 extraParameters=None,
                 returnTypeOverride=None,
                 codegenDef=None):
        self.customApiPrefix = customApiPrefix
        self.extraParameters = extraParameters
        self.returnTypeOverride = returnTypeOverride

        self.codegen = CodeGen()

        self.definitionFunc = codegenDef

        # Private function, stored unbound: callers invoke it as
        # self.makeApi(self, typeInfo, apiName).
        self.makeApi = VulkanAPIWrapper._makeApiFunc

    def _makeApiFunc(self, typeInfo, apiName):
        """Return a shallow copy of typeInfo.apis[apiName] customized by this wrapper.

        The copy gets the name prefix, any extra leading parameters and the
        optional return type override.

        Raises:
            TypeError: if extraParameters is set but is not a list.
        """
        customApi = copy(typeInfo.apis[apiName])
        customApi.name = self.customApiPrefix + customApi.name
        if self.extraParameters is not None:
            if not isinstance(self.extraParameters, list):
                # os.abort() takes no message; raise the TypeError explicitly.
                raise TypeError(
                    "Type of extra parameters to custom API not valid. Expected list, got %s" % type(
                        self.extraParameters))
            customApi.parameters = \
                self.extraParameters + customApi.parameters

        if self.returnTypeOverride is not None:
            customApi.retType = self.returnTypeOverride
        return customApi

    def setCodegenDef(self, codegenDefFunc):
        self.definitionFunc = codegenDefFunc

    def makeDecl(self, typeInfo, apiName):
        return self.codegen.makeFuncProto(
            self.makeApi(self, typeInfo, apiName)) + ";\n\n"

    def makeDefinition(self, typeInfo, apiName, isStatic=False):
        """Return the full C definition of the customized API; exits(1) if no codegenDef is set."""
        vulkanApi = self.makeApi(self, typeInfo, apiName)

        self.codegen.swapCode()
        self.codegen.beginBlock()

        if self.definitionFunc is None:
            print("ERROR: No definition found for (%s, %s)" %
                  (vulkanApi.name, self.customApiPrefix))
            sys.exit(1)

        self.definitionFunc(self.codegen, vulkanApi)

        self.codegen.endBlock()

        return ("static " if isStatic else "") + self.codegen.makeFuncProto(
            vulkanApi) + "\n" + self.codegen.swapCode() + "\n"

# Base class for wrapping all Vulkan API objects. These work with Vulkan
# Registry generators and have gen* triggers. They tend to contain
# VulkanAPIWrapper objects to make it easier to generate the code.
class VulkanWrapperGenerator(object):
    """Base for generators driven by the Vulkan registry's gen* callbacks.

    Subclasses override whichever ``on*`` hooks they need. The default
    ``onGenType`` records, in registration order, every non-aliased
    struct/union whose ``structExtendsExpr`` is set.
    """

    def __init__(self, module: Module, typeInfo: VulkanTypeInfo):
        self.module: Module = module
        self.typeInfo: VulkanTypeInfo = typeInfo
        # struct name -> struct info, for structs with a structExtendsExpr.
        self.extensionStructTypes = OrderedDict()

    def onBegin(self):
        """Hook: generation is starting."""

    def onEnd(self):
        """Hook: generation has finished."""

    def onBeginFeature(self, featureName, featureType):
        """Hook: a feature/extension section begins."""

    def onFeatureNewCmd(self, cmdName):
        """Hook: a command was declared in the current feature."""

    def onEndFeature(self):
        """Hook: the current feature/extension section ends."""

    def onGenType(self, typeInfo, name, alias):
        """Remember non-aliased struct/union types that declare a structExtendsExpr.

        Looks the type up via ``self.typeInfo``; the ``typeInfo`` argument is unused.
        """
        category = self.typeInfo.categoryOf(name)
        if category not in ["struct", "union"] or alias:
            return
        candidate = self.typeInfo.structs[name]
        if candidate.structExtendsExpr:
            self.extensionStructTypes[name] = candidate

    def onGenStruct(self, typeInfo, name, alias):
        """Hook: a struct type is being generated."""

    def onGenGroup(self, groupinfo, groupName, alias=None):
        """Hook: an enum group is being generated."""

    def onGenEnum(self, enuminfo, name, alias):
        """Hook: an enum constant is being generated."""

    def onGenCmd(self, cmdinfo, name, alias):
        """Hook: a command is being generated."""

    # Below Vulkan structure types may correspond to multiple Vulkan structs
    # due to a conflict between different Vulkan registries. In order to get
    # the correct Vulkan struct type, we need to check the type of its "root"
    # struct as well.
    ROOT_TYPE_MAPPING = {
        "VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FRAGMENT_DENSITY_MAP_FEATURES_EXT": {
            "VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2": "VkPhysicalDeviceFragmentDensityMapFeaturesEXT",
            "VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO": "VkPhysicalDeviceFragmentDensityMapFeaturesEXT",
            "VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO": "VkImportColorBufferGOOGLE",
            "default": "VkPhysicalDeviceFragmentDensityMapFeaturesEXT",
        },
        "VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FRAGMENT_DENSITY_MAP_PROPERTIES_EXT": {
            "VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2": "VkPhysicalDeviceFragmentDensityMapPropertiesEXT",
            "VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO": "VkCreateBlobGOOGLE",
            "default": "VkPhysicalDeviceFragmentDensityMapPropertiesEXT",
        },
        "VK_STRUCTURE_TYPE_RENDER_PASS_FRAGMENT_DENSITY_MAP_CREATE_INFO_EXT": {
            "VK_STRUCTURE_TYPE_RENDER_PASS_CREATE_INFO": "VkRenderPassFragmentDensityMapCreateInfoEXT",
            "VK_STRUCTURE_TYPE_RENDER_PASS_CREATE_INFO_2": "VkRenderPassFragmentDensityMapCreateInfoEXT",
            "VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO": "VkImportBufferGOOGLE",
            "default": "VkRenderPassFragmentDensityMapCreateInfoEXT",
        },
    }

    @staticmethod
    def _switchFeatureIfdef(cgen, feature, activeFeature):
        """Emit the #endif/#ifdef lines needed before an item of ``feature``.

        Returns the feature whose #ifdef is open afterwards.
        """
        if not activeFeature:
            cgen.leftline("#ifdef %s" % feature)
            return feature
        if feature != activeFeature:
            cgen.leftline("#endif")
            cgen.leftline("#ifdef %s" % feature)
            return feature
        return activeFeature

    def emitForEachStructExtension(self, cgen, retType, triggerVar, forEachFunc, autoBreak=True, defaultEmit=None, nullEmit=None, rootTypeVar=None):
        """Emit a C switch over the struct type of ``triggerVar`` with one case per extension struct.

        The generated code does the following:
        - If the trigger pointer is null, it runs ``nullEmit(cgen)``, or by
          default returns a zero value of ``retType``.
        - Otherwise it reads the struct type via ``goldfish_vk_struct_type``.
        - For each recorded extension struct, it calls
          ``forEachFunc(structInfo, castedAccessExpr, cgen)`` inside that
          struct's case.
        - The ``default:`` case runs ``defaultEmit(cgen)``, or the same
          zero-value return.

        When ``rootTypeVar`` is given, enums listed in ROOT_TYPE_MAPPING get a
        nested switch on the root struct type to pick the concrete struct.
        """
        structVar = triggerVar.paramName

        def emitFallback(emitter):
            # Caller-supplied emitter wins; otherwise return a zero of retType.
            if emitter is not None:
                emitter(cgen)
            elif retType.typeName == "void":
                cgen.stmt("return")
            else:
                cgen.stmt("return (%s)0" % retType.typeName)

        def castTo(typeName):
            qualifier = "const " if triggerVar.isConst else ""
            return "reinterpret_cast<%s%s*>(%s)" % (qualifier, typeName, structVar)

        cgen.beginIf("!%s" % structVar)
        emitFallback(nullEmit)
        cgen.endIf()

        cgen.stmt("uint32_t %s = (uint32_t)%s(%s)" %
                  ("structType", "goldfish_vk_struct_type", structVar))

        cgen.line("switch(structType)")
        cgen.beginBlock()

        activeFeature = None

        for ext in self.extensionStructTypes.values():
            activeFeature = self._switchFeatureIfdef(cgen, ext.feature, activeFeature)

            enum = ext.structEnumExpr
            cgen.line("case %s:" % enum)
            cgen.beginBlock()

            rootCases = None
            if rootTypeVar is not None:
                rootCases = VulkanWrapperGenerator.ROOT_TYPE_MAPPING.get(enum)

            if rootCases is None:
                forEachFunc(ext, castTo(ext.name), cgen)
            else:
                cgen.line("switch(%s)" % rootTypeVar.paramName)
                cgen.beginBlock()
                for rootEnum, structName in rootCases.items():
                    target = self.extensionStructTypes[structName]
                    label = "%s:" if rootEnum == "default" else "case %s:"
                    cgen.line(label % rootEnum)
                    cgen.beginBlock()
                    forEachFunc(target, castTo(target.name), cgen)
                    cgen.line("break;")
                    cgen.endBlock()
                cgen.endBlock()

            if autoBreak:
                cgen.stmt("break")
            cgen.endBlock()

        if activeFeature:
            cgen.leftline("#endif")

        cgen.line("default:")
        cgen.beginBlock()
        emitFallback(defaultEmit)
        cgen.endBlock()

        cgen.endBlock()

    def emitForEachStructExtensionGeneral(self, cgen, forEachFunc, doFeatureIfdefs=False):
        """Call ``forEachFunc(index, structInfo, cgen)`` for each recorded extension struct.

        With ``doFeatureIfdefs``, consecutive structs of the same feature are
        wrapped in one ``#ifdef <feature>`` ... ``#endif`` region.
        """
        activeFeature = None

        for index, ext in enumerate(self.extensionStructTypes.values()):
            if doFeatureIfdefs:
                activeFeature = self._switchFeatureIfdef(cgen, ext.feature, activeFeature)
            forEachFunc(index, ext, cgen)

        if doFeatureIfdefs and activeFeature:
            cgen.leftline("#endif")