# Copyright (c) 2018 The Android Open Source Project
# Copyright (c) 2018 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .vulkantypes import VulkanType, VulkanTypeInfo, VulkanCompoundType, VulkanAPI
from collections import OrderedDict
from copy import copy
from pathlib import Path, PurePosixPath

import os
import sys
import shutil
import subprocess


# Class capturing a single file
class SingleFileModule(object):
    """A single generated output file.

    The file is opened by begin(), written to by append(), and finished and
    closed by end(). When ``suppress`` is true, all three are no-ops and no
    file is created.
    """

    def __init__(self, suffix, directory, basename, customAbsDir=None, suppress=False):
        # suffix: extension appended to basename (including the leading dot).
        # directory: output directory relative to the globalDir passed to
        #   begin(); ignored when customAbsDir is set.
        self.directory = directory
        self.basename = basename
        self.customAbsDir = customAbsDir
        self.suffix = suffix
        self.file = None

        # Text written at the very start (begin) and very end (end) of the file.
        self.preamble = ""
        self.postamble = ""

        self.suppress = suppress

    def begin(self, globalDir):
        """Open the output file for writing and emit the preamble."""
        if self.suppress:
            return

        # Resolve the output directory. NOTE: the directory is not created
        # here; it must already exist or open() below fails.
        if self.customAbsDir:
            absDir = self.customAbsDir
        else:
            absDir = os.path.join(globalDir, self.directory)

        filename = os.path.join(absDir, self.basename)

        # Deliberately not a `with` block: the handle stays open across
        # append() calls and is closed in end().
        self.file = open(filename + self.suffix, "w", encoding="utf-8")
        self.file.write(self.preamble)

    def append(self, toAppend):
        """Write ``toAppend`` to the open file (requires begin() first)."""
        if self.suppress:
            return

        self.file.write(toAppend)

    def end(self):
        """Emit the postamble and close the file."""
        if self.suppress:
            return

        self.file.write(self.postamble)
        self.file.close()

    def getMakefileSrcEntry(self):
        """Return this file's Makefile source entry; none by default."""
        return ""

    def getCMakeSrcEntry(self):
        """Return this file's CMake source entry; none by default."""
        return ""


# Class capturing a .cpp file and a .h file (a "C++ module")
class Module(object):
    """A C++ module: a paired ``.h`` header and ``.cpp`` implementation file.

    ``implOnly`` suppresses the header, ``headerOnly`` suppresses the
    implementation, and ``suppress`` suppresses both.
    """

    def __init__(
            self, directory, basename, customAbsDir=None, suppress=False, implOnly=False,
            headerOnly=False, suppressFeatureGuards=False):
        self._headerFileModule = SingleFileModule(
            ".h", directory, basename, customAbsDir, suppress or implOnly)
        self._implFileModule = SingleFileModule(
            ".cpp", directory, basename, customAbsDir, suppress or headerOnly)

        self._headerOnly = headerOnly
        self._implOnly = implOnly

        self.directory = directory
        self.basename = basename
        self._customAbsDir = customAbsDir

        self.suppressFeatureGuards = suppressFeatureGuards

    @property
    def suppress(self):
        # Write-only: reading raises so callers do not rely on a combined value.
        raise AttributeError("suppress is write only")

    @suppress.setter
    def suppress(self, value: bool):
        # implOnly/headerOnly keep their file suppressed regardless of value.
        self._headerFileModule.suppress = self._implOnly or value
        self._implFileModule.suppress = self._headerOnly or value

    @property
    def headerPreamble(self) -> str:
        return self._headerFileModule.preamble

    @headerPreamble.setter
    def headerPreamble(self, value: str):
        self._headerFileModule.preamble = value

    @property
    def headerPostamble(self) -> str:
        return self._headerFileModule.postamble

    @headerPostamble.setter
    def headerPostamble(self, value: str):
        self._headerFileModule.postamble = value

    @property
    def implPreamble(self) -> str:
        return self._implFileModule.preamble

    @implPreamble.setter
    def implPreamble(self, value: str):
        self._implFileModule.preamble = value

    @property
    def implPostamble(self) -> str:
        return self._implFileModule.postamble

    @implPostamble.setter
    def implPostamble(self, value: str):
        self._implFileModule.postamble = value

    def getMakefileSrcEntry(self):
        """Return the .cpp file's Makefile source-list line (with continuation)."""
        if self._customAbsDir:
            return self.basename + ".cpp \\\n"
        joined = os.path.join(self.directory, self.basename)
        return " " + joined + ".cpp \\\n"

    def getCMakeSrcEntry(self):
        """Return the .cpp file's CMake source entry, always '/'-separated."""
        if self._customAbsDir:
            return "\n" + self.basename + ".cpp "
        joined = PurePosixPath(Path(self.directory) / Path(self.basename))
        return "\n " + str(joined) + ".cpp "

    def begin(self, globalDir):
        """Open both (non-suppressed) files under ``globalDir``."""
        self._headerFileModule.begin(globalDir)
        self._implFileModule.begin(globalDir)

    def appendHeader(self, toAppend):
        self._headerFileModule.append(toAppend)

    def appendImpl(self, toAppend):
        self._implFileModule.append(toAppend)

    def end(self):
        """Finish both files, then run ``clang-format -i --style=file`` on them.

        Raises:
            AssertionError: if clang-format is not on PATH (checked even when
                both files are suppressed), or if it exits non-zero.
        """
        self._headerFileModule.end()
        self._implFileModule.end()

        # Explicit checks instead of `assert` statements: asserts are stripped
        # under `python -O`, which used to skip the clang-format call itself.
        # AssertionError is kept so existing failure behavior is unchanged.
        clang_format_command = shutil.which('clang-format')
        if clang_format_command is None:
            raise AssertionError("clang-format not found on PATH")

        def formatFile(filename: Path):
            returnCode = subprocess.call(
                [clang_format_command, "-i", "--style=file", str(filename.resolve())])
            if returnCode != 0:
                raise AssertionError(
                    "clang-format failed with exit code %d on %s" % (returnCode, filename))

        if not self._headerFileModule.suppress:
            formatFile(Path(self._headerFileModule.file.name))

        if not self._implFileModule.suppress:
            formatFile(Path(self._implFileModule.file.name))


class PyScript(SingleFileModule):
    """A generated ``.py`` file."""

    def __init__(self, directory, basename, customAbsDir=None, suppress=False):
        super().__init__(".py", directory, basename, customAbsDir, suppress)


# Class capturing a .proto protobuf definition file
class Proto(SingleFileModule):
    """A generated ``.proto`` protobuf definition file."""

    def __init__(self, directory, basename, customAbsDir=None, suppress=False):
        super().__init__(".proto", directory, basename, customAbsDir, suppress)

    def getMakefileSrcEntry(self):
        """Return this file's Makefile source-list line (with continuation)."""
        if self.customAbsDir:
            return self.basename + ".proto \\\n"
        joined = os.path.join(self.directory, self.basename)
        return " " + joined + ".proto \\\n"

    def getCMakeSrcEntry(self):
        """Return this file's CMake source entry."""
        if self.customAbsDir:
            return "\n" + self.basename + ".proto "
        joined = os.path.join(self.directory, self.basename)
        return "\n " + joined + ".proto "


class CodeGen(object):
    """Accumulates generated C/C++ source text.

    Emitted text is appended to ``self.code``. Each indented line is prefixed
    with ``indent()``, which depends on ``indentLevel``. ``gensymCounter`` is a
    stack holding one counter per open block; var() uses it to produce fresh
    variable names.
    """

    # Struct members whose length cannot be taken directly from a sibling
    # member named in the registry (mostly where latexmath is involved).
    # "postprocess" turns "<struct><deref><lenExprMember>" into the final
    # length expression.
    _STRUCT_LENGTH_SPECIAL_CASES = (
        {
            "structName": "VkShaderModuleCreateInfo",
            "field": "pCode",
            "lenExprMember": "codeSize",
            "postprocess": lambda expr: "(%s / 4)" % expr
        },
        {
            "structName": "VkPipelineMultisampleStateCreateInfo",
            "field": "pSampleMask",
            "lenExprMember": "rasterizationSamples",
            "postprocess": lambda expr: "(((%s) + 31) / 32)" % expr
        },
        {
            "structName": "VkAccelerationStructureVersionInfoKHR",
            "field": "pVersionData",
            "lenExprMember": "",
            "postprocess": lambda _: "2*VK_UUID_SIZE"
        },
    )

    # Stream method suffix for each supported primitive encoding size (bytes).
    _STREAM_METHOD_SUFFIXES = {1: "Byte", 2: "Be16", 4: "Be32", 8: "Be64"}

    # Unsigned C storage type for each supported primitive encoding size.
    _STORAGE_VAR_TYPES = {1: "uint8_t", 2: "uint16_t", 4: "uint32_t", 8: "uint64_t"}

    def __init__(self,):
        self.code = ""
        self.indentLevel = 0
        self.gensymCounter = [-1]

    def var(self, prefix="cgen_var"):
        """Return a fresh variable name such as ``cgen_var_0`` or ``cgen_var_1_0``.

        The suffix joins the non-negative counters of all open blocks, so
        names differ between nesting levels.
        """
        self.gensymCounter[-1] += 1
        res = "%s_%s" % (prefix, '_'.join(str(i) for i in self.gensymCounter if i >= 0))
        return res

    def swapCode(self,):
        """Return the accumulated code and reset the buffer to empty."""
        res = self.code
        self.code = ""
        return res

    def indent(self, extra=0):
        """Return the indentation prefix for ``indentLevel + extra`` levels."""
        return " " * (self.indentLevel + extra)

    def incrIndent(self,):
        self.indentLevel += 1

    def decrIndent(self,):
        """Decrease the indent level, never going below zero."""
        if self.indentLevel > 0:
            self.indentLevel -= 1

    def beginBlock(self, bracketPrint=True):
        """Open a block: optionally emit '{', indent, and push a var() counter."""
        if bracketPrint:
            self.code += self.indent() + "{\n"
        self.indentLevel += 1
        self.gensymCounter.append(-1)

    def endBlock(self, bracketPrint=True):
        """Close a block opened by beginBlock (indent is not clamped at 0)."""
        self.indentLevel -= 1
        if bracketPrint:
            self.code += self.indent() + "}\n"
        del self.gensymCounter[-1]

    def beginIf(self, cond):
        self.code += self.indent() + "if (" + cond + ")\n"
        self.beginBlock()

    def beginElse(self, cond = None):
        """Open an ``else`` block, or ``else if (cond)`` when cond is given."""
        if cond is not None:
            self.code += self.indent() + "else if (" + cond + ")\n"
        else:
            self.code += self.indent() + "else\n"
        self.beginBlock()

    def endElse(self):
        self.endBlock()

    def endIf(self):
        self.endBlock()

    def beginSwitch(self, switchvar):
        self.code += self.indent() + "switch (" + switchvar + ")\n"
        self.beginBlock()

    def switchCase(self, switchval, blocked = False):
        """Emit ``case switchval:`` and open a (braced if ``blocked``) block."""
        self.code += self.indent() + "case %s:" % switchval
        self.beginBlock(bracketPrint = blocked)

    def switchCaseBreak(self, switchval, blocked = False):
        self.code += self.indent() + "case %s:" % switchval
        self.endBlock(bracketPrint = blocked)

    def switchCaseDefault(self, blocked = False):
        """Emit ``default:`` and open a (braced if ``blocked``) block."""
        # Previously formatted "default:" with an undefined name, which
        # raised NameError on every call.
        self.code += self.indent() + "default:"
        self.beginBlock(bracketPrint = blocked)

    def endSwitch(self):
        self.endBlock()

    def beginWhile(self, cond):
        self.code += self.indent() + "while (" + cond + ")\n"
        self.beginBlock()

    def endWhile(self):
        self.endBlock()

    def beginFor(self, initial, condition, increment):
        self.code += \
            self.indent() + "for (" + \
            "; ".join([initial, condition, increment]) + \
            ")\n"
        self.beginBlock()

    def endFor(self):
        self.endBlock()

    def beginLoop(self, loopVarType, loopVar, loopInit, loopBound):
        """Open ``for (T v = init; v < bound; ++v)``."""
        self.beginFor(
            "%s %s = %s" % (loopVarType, loopVar, loopInit),
            "%s < %s" % (loopVar, loopBound),
            "++%s" % (loopVar))

    def endLoop(self):
        self.endBlock()

    def stmt(self, code):
        """Emit an indented statement terminated by ';'."""
        self.code += self.indent() + code + ";\n"

    def line(self, code):
        """Emit an indented line without adding ';'."""
        self.code += self.indent() + code + "\n"

    def leftline(self, code):
        """Emit a line at column 0 (e.g. preprocessor directives)."""
        self.code += code + "\n"

    def makeCallExpr(self, funcName, parameters):
        return funcName + "(%s)" % (", ".join(parameters))

    def funcCall(self, lhs, funcName, parameters):
        """Emit ``[lhs = ]funcName(parameters);``."""
        res = self.indent()

        if lhs is not None:
            res += lhs + " = "

        res += self.makeCallExpr(funcName, parameters) + ";\n"
        self.code += res

    def funcCallRet(self, _lhs, funcName, parameters):
        """Emit ``return funcName(parameters);`` (``_lhs`` is ignored)."""
        res = self.indent()
        res += "return " + self.makeCallExpr(funcName, parameters) + ";\n"
        self.code += res

    def _pointerSpec(self, vulkanType):
        """Return the pointer declarator suffix (``*``, ``**``, ``* const*``...)."""
        levels = vulkanType.pointerIndirectionLevels
        if levels == 0:
            return ""
        if vulkanType.isPointerToConstPointer:
            ptrSpec = "* const*" if vulkanType.isConst else "**"
            if levels > 2:
                ptrSpec += "*" * (levels - 2)
            return ptrSpec
        return "*" * levels

    # Given a VulkanType object, generate a C type declaration
    # with optional parameter name:
    # [const] [typename][*][const*] [paramName]
    def makeCTypeDecl(self, vulkanType, useParamName=True):
        """Return the C declaration of ``vulkanType`` (without array extent)."""
        constness = "const " if vulkanType.isConst else ""

        if useParamName and (vulkanType.paramName is not None):
            paramStr = (" " + vulkanType.paramName)
        else:
            paramStr = ""

        return "%s%s%s%s" % (constness, vulkanType.typeName,
                             self._pointerSpec(vulkanType), paramStr)

    def makeRichCTypeDecl(self, vulkanType, useParamName=True):
        """Like makeCTypeDecl, plus a ``[staticArrExpr]`` suffix when present."""
        if vulkanType.staticArrExpr:
            staticArrInfo = "[%s]" % vulkanType.staticArrExpr
        else:
            staticArrInfo = ""
        return self.makeCTypeDecl(vulkanType, useParamName=useParamName) + staticArrInfo

    # Given a VulkanAPI object, generate the C function protype:
    # <returntype> <funcname>(<parameters>)
    def makeFuncProto(self, vulkanApi, useParamName=True):
        """Return the C prototype text of ``vulkanApi``, one parameter per line."""
        protoBegin = "%s %s" % (self.makeCTypeDecl(
            vulkanApi.retType, useParamName=False), vulkanApi.name)

        separator = ",\n%s" % self.indent(1)
        paramDecls = [self.makeRichCTypeDecl(param, useParamName=useParamName)
                      for param in vulkanApi.parameters]
        protoParams = "(\n %s)" % separator.join(paramDecls)

        return protoBegin + protoParams

    def makeFuncAlias(self, nameDst, nameSrc):
        return "DEFINE_ALIAS_FUNCTION({}, {})\n\n".format(nameSrc, nameDst)

    def makeFuncDecl(self, vulkanApi):
        return self.makeFuncProto(vulkanApi) + ";\n\n"

    def makeFuncImpl(self, vulkanApi, codegenFunc):
        """Return a complete function definition whose body codegenFunc emits.

        NOTE: any code already accumulated in this CodeGen is discarded.
        """
        self.swapCode()

        self.line(self.makeFuncProto(vulkanApi))
        self.beginBlock()
        codegenFunc(self)
        self.endBlock()

        return self.swapCode() + "\n"

    def emitFuncImpl(self, vulkanApi, codegenFunc):
        """Append a function definition (body emitted by codegenFunc) to the code."""
        self.line(self.makeFuncProto(vulkanApi))
        self.beginBlock()
        codegenFunc(self)
        self.endBlock()

    def makeStructAccess(self,
                         vulkanType,
                         structVarName,
                         asPtr=True,
                         structAsPtr=True,
                         accessIndex=None):
        """Return an expression accessing member ``vulkanType`` of ``structVarName``.

        A leading '&' is added when ``asPtr`` is requested and the member is
        not already accessible as a pointer.
        """
        deref = "->" if structAsPtr else "."

        indexExpr = (" + %s" % accessIndex) if accessIndex else ""

        addrOfExpr = "" if vulkanType.accessibleAsPointer() or (
            not asPtr) else "&"

        return "%s%s%s%s%s" % (addrOfExpr, structVarName, deref,
                               vulkanType.paramName, indexExpr)

    def makeRawLengthAccess(self, vulkanType):
        """Return (lengthExpr, None) from the type's own length expression.

        Returns (None, None) when the type has no length expression.
        """
        lenExpr = vulkanType.getLengthExpression()

        if not lenExpr:
            return None, None

        if lenExpr == "null-terminated":
            return "strlen(%s)" % vulkanType.paramName, None

        return lenExpr, None

    def makeLengthAccessFromStruct(self,
                                   structInfo,
                                   vulkanType,
                                   structVarName,
                                   asPtr=True):
        """Return (lengthExpr, guardExpr) for a struct member's array length.

        ``guardExpr`` names the struct variable that must be non-null before
        ``lengthExpr`` may be evaluated. Returns (None, None) when there is no
        length.
        """
        deref = "->" if asPtr else "."

        # Special cases first (mostly where latexmath is involved).
        for case in self._STRUCT_LENGTH_SPECIAL_CASES:
            if (structInfo.name, vulkanType.paramName) == (case["structName"],
                                                           case["field"]):
                expr = "%s%s%s" % (structVarName, deref, case["lenExprMember"])
                return case["postprocess"](expr), "%s" % structVarName

        lenExpr = vulkanType.getLengthExpression()

        if not lenExpr:
            return None, None

        lenAccessGuardExpr = "%s" % structVarName

        if lenExpr == "null-terminated":
            return "strlen(%s%s%s)" % (structVarName, deref,
                                       vulkanType.paramName), lenAccessGuardExpr

        # Length is not a sibling member: fall back to the raw expression.
        if not structInfo.getMember(lenExpr):
            return self.makeRawLengthAccess(vulkanType)

        return "%s%s%s" % (structVarName, deref, lenExpr), lenAccessGuardExpr

    def makeLengthAccessFromApi(self, api, vulkanType):
        """Return (lengthExpr, guardExpr) for an API parameter's array length.

        Handles ``struct::member`` lengths, null-terminated strings, and
        lengths given by another (possibly pointer) parameter. Returns
        (None, None) when there is no length.
        """
        lenExpr = vulkanType.getLengthExpression()

        if not lenExpr:
            return None, None

        # "struct::member" refers to a member of another (pointer) parameter.
        if "::" in lenExpr:
            structVarName, memberVarName = lenExpr.split("::")
            return "%s->%s" % (structVarName, memberVarName), "%s" % structVarName

        # Checked before the parameter lookup. This previously called the
        # string paramName as a function, which raised TypeError.
        if lenExpr == "null-terminated":
            return "strlen(%s)" % vulkanType.paramName, None

        lenExprInfo = api.getParameter(lenExpr)

        if not lenExprInfo:
            return self.makeRawLengthAccess(vulkanType)

        # A length passed by pointer must be dereferenced and null-guarded.
        deref = "*" if lenExprInfo.pointerIndirectionLevels > 0 else ""
        lenAccessGuardExpr = "%s" % lenExpr if deref else None
        return "(%s(%s))" % (deref, lenExpr), lenAccessGuardExpr

    def accessParameter(self, param, asPtr=True):
        """Return the parameter name, prefixed with '&' if a pointer is needed."""
        if asPtr and param.pointerIndirectionLevels == 0:
            return "&%s" % param.paramName
        return param.paramName

    def sizeofExpr(self, vulkanType):
        return "sizeof(%s)" % (
            self.makeCTypeDecl(vulkanType, useParamName=False))

    def generalAccess(self,
                      vulkanType,
                      parentVarName=None,
                      asPtr=True,
                      structAsPtr=True):
        """Return an access expression for ``vulkanType`` based on its parent.

        Struct members are accessed through ``parentVarName``. Top-level and
        API parameters are accessed by name, or as ``parentVarName`` when
        given. Aborts the process when the parent kind is unrecognised.
        """
        parent = vulkanType.parent

        if isinstance(parent, VulkanCompoundType):
            return self.makeStructAccess(
                vulkanType, parentVarName, asPtr=asPtr, structAsPtr=structAsPtr)

        if parent is None or isinstance(parent, VulkanAPI):
            target = vulkanType if parentVarName is None \
                else vulkanType.withModifiedName(parentVarName)
            return self.accessParameter(target, asPtr=asPtr)

        # os.abort() takes no arguments; passing the message raised TypeError.
        print("Could not find a way to access Vulkan type %s" %
              vulkanType.typeName)
        os.abort()

    def makeLengthAccess(self, vulkanType, parentVarName="parent"):
        """Return (lengthExpr, guardExpr) for ``vulkanType`` based on its parent.

        Aborts the process when the parent kind is unrecognised.
        """
        if vulkanType.parent is None:
            return self.makeRawLengthAccess(vulkanType)

        if isinstance(vulkanType.parent, VulkanCompoundType):
            return self.makeLengthAccessFromStruct(
                vulkanType.parent, vulkanType, parentVarName, asPtr=True)

        if isinstance(vulkanType.parent, VulkanAPI):
            return self.makeLengthAccessFromApi(vulkanType.parent, vulkanType)

        # os.abort() takes no arguments; passing the message raised TypeError.
        print("Could not find a way to access length of Vulkan type %s" %
              vulkanType.typeName)
        os.abort()

    def generalLengthAccess(self, vulkanType, parentVarName="parent"):
        return self.makeLengthAccess(vulkanType, parentVarName)[0]

    def generalLengthAccessGuard(self, vulkanType, parentVarName="parent"):
        return self.makeLengthAccess(vulkanType, parentVarName)[1]

    def vkApiCall(self, api, customPrefix="", globalStatePrefix="", customParameters=None, checkForDeviceLost=False):
        """Emit a call to ``customPrefix + api.name``, capturing any return value.

        Returns (retTypeName, retVar); retVar is None for void APIs.
        """
        callLhs = None

        retTypeName = api.getRetTypeExpr()
        retVar = None

        if retTypeName != "void":
            retVar = api.getRetVarExpr()
            self.stmt("%s %s = (%s)0" % (retTypeName, retVar, retTypeName))
            callLhs = retVar

        if customParameters is None:
            parameters = [p.paramName for p in api.parameters]
        else:
            parameters = customParameters
        self.funcCall(callLhs, customPrefix + api.name, parameters)

        if retTypeName == "VkResult" and checkForDeviceLost:
            self.stmt("if ((%s) == VK_ERROR_DEVICE_LOST) %sDeviceLost()" % (callLhs, globalStatePrefix))

        return (retTypeName, retVar)

    def makeCheckVkSuccess(self, expr):
        return "((%s) == VK_SUCCESS)" % expr

    def makeReinterpretCast(self, varName, typeName, const=True):
        return "reinterpret_cast<%s%s*>(%s)" % \
            ("const " if const else "", typeName, varName)

    def validPrimitive(self, typeInfo, typeName):
        """Return True if ``typeName`` has a primitive encoding size."""
        return typeInfo.getPrimitiveEncodingSize(typeName) is not None

    def _streamMethodForSize(self, size, direction, inPlace):
        """Return the stream method name for ``size`` bytes, or None if unsupported.

        Uses the put/get prefix for stream objects, or the to/from prefix for
        in-place (memcpy) conversion.
        """
        if inPlace:
            prefix = "to" if direction == "write" else "from"
        else:
            prefix = "put" if direction == "write" else "get"
        suffix = self._STREAM_METHOD_SUFFIXES.get(size)
        return prefix + suffix if suffix else None

    def _storageVarTypeForSize(self, size, accessTypeName):
        """Return the C storage type for ``size`` bytes; aborts if unsupported."""
        storageType = self._STORAGE_VAR_TYPES.get(size)
        if storageType is None:
            print("Unsupported primitive encoding size %s for type: %s" %
                  (size, accessTypeName))
            os.abort()
        return storageType

    def makePrimitiveStreamMethod(self, typeInfo, typeName, direction="write"):
        """Return e.g. ``putBe32``/``getBe32`` for a primitive, or None."""
        if not self.validPrimitive(typeInfo, typeName):
            return None

        size = typeInfo.getPrimitiveEncodingSize(typeName)
        return self._streamMethodForSize(size, direction, inPlace=False)

    def makePrimitiveStreamMethodInPlace(self, typeInfo, typeName, direction="write"):
        """Return e.g. ``toBe32``/``fromBe32`` for a primitive, or None."""
        if not self.validPrimitive(typeInfo, typeName):
            return None

        size = typeInfo.getPrimitiveEncodingSize(typeName)
        return self._streamMethodForSize(size, direction, inPlace=True)

    def streamPrimitive(self, typeInfo, streamVar, accessExpr, accessType, direction="write"):
        """Emit code that writes/reads a primitive or pointer through ``streamVar``.

        Pointers are transported as 64-bit values. Aborts the process for
        non-primitive, non-pointer types.
        """
        accessTypeName = accessType.typeName
        isPointer = accessType.pointerIndirectionLevels > 0

        if not isPointer and not self.validPrimitive(typeInfo, accessTypeName):
            print("Tried to stream a non-primitive type: %s" % accessTypeName)
            os.abort()

        if isPointer:
            streamSize = 8
            streamMethod = self._streamMethodForSize(8, direction, inPlace=False)
        else:
            streamSize = typeInfo.getPrimitiveEncodingSize(accessTypeName)
            streamMethod = self.makePrimitiveStreamMethod(
                typeInfo, accessTypeName, direction=direction)
        streamStorageVarType = self._storageVarTypeForSize(streamSize, accessTypeName)

        # Always allocate the name, even when reading, so var() numbering
        # stays the same regardless of direction.
        streamStorageVar = self.var()

        ptrCast = "(uintptr_t)" if isPointer else ""

        if direction == "read":
            accessCast = self.makeRichCTypeDecl(accessType, useParamName=False)
            self.stmt("%s = (%s)%s%s->%s()" %
                      (accessExpr,
                       accessCast,
                       ptrCast,
                       streamVar,
                       streamMethod))
        else:
            self.stmt("%s %s = (%s)%s%s" %
                      (streamStorageVarType, streamStorageVar,
                       streamStorageVarType, ptrCast, accessExpr))
            self.stmt("%s->%s(%s)" %
                      (streamVar, streamMethod, streamStorageVar))

    def memcpyPrimitive(self, typeInfo, streamVar, accessExpr, accessType, direction="write"):
        """Emit code that copies a primitive/pointer to/from raw memory at ``streamVar``.

        A memcpy is followed by an in-place
        ``android::base::Stream::to*/from*`` byte-order conversion. Aborts the
        process for non-primitive, non-pointer types.
        """
        accessTypeName = accessType.typeName
        isPointer = accessType.pointerIndirectionLevels > 0

        if not isPointer and not self.validPrimitive(typeInfo, accessTypeName):
            print("Tried to stream a non-primitive type: %s" % accessTypeName)
            os.abort()

        if isPointer:
            streamSize = 8
            streamMethod = self._streamMethodForSize(8, direction, inPlace=True)
        else:
            streamSize = typeInfo.getPrimitiveEncodingSize(accessTypeName)
            streamMethod = self.makePrimitiveStreamMethodInPlace(
                typeInfo, accessTypeName, direction=direction)
        streamStorageVarType = self._storageVarTypeForSize(streamSize, accessTypeName)

        # Allocated unconditionally to keep var() numbering direction-independent.
        streamStorageVar = self.var()

        ptrCast = "(uintptr_t)" if isPointer else ""

        if direction == "read":
            # Cast away const so memcpy may write into the destination.
            accessCast = self.makeRichCTypeDecl(
                accessType.getForNonConstAccess(), useParamName=False)
            self.stmt("memcpy((%s*)&%s, %s, %s)" %
                      (accessCast,
                       accessExpr,
                       streamVar,
                       str(streamSize)))
            self.stmt("android::base::Stream::%s((uint8_t*)&%s)" % (
                streamMethod,
                accessExpr))
        else:
            self.stmt("%s %s = (%s)%s%s" %
                      (streamStorageVarType, streamStorageVar,
                       streamStorageVarType, ptrCast, accessExpr))
            self.stmt("memcpy(%s, &%s, %s)" %
                      (streamVar, streamStorageVar, str(streamSize)))
            self.stmt("android::base::Stream::%s((uint8_t*)%s)" % (
                streamMethod,
                streamVar))

    def countPrimitive(self, typeInfo, accessType):
        """Return the encoded byte size of a primitive (8 for any pointer).

        Aborts the process for non-primitive, non-pointer types.
        """
        accessTypeName = accessType.typeName

        if accessType.pointerIndirectionLevels == 0 and not self.validPrimitive(typeInfo, accessTypeName):
            print("Tried to count a non-primitive type: %s" % accessTypeName)
            os.abort()

        if accessType.pointerIndirectionLevels > 0:
            return 8
        return typeInfo.getPrimitiveEncodingSize(accessTypeName)


# Class to wrap a Vulkan API call.
#
# The user gives a generic callback, |codegenDef|,
# that takes a CodeGen object and a VulkanAPI object as arguments.
# codegenDef uses CodeGen along with the VulkanAPI object
# to generate the function body.
class VulkanAPIWrapper(object):
    """Generates prefixed wrapper declarations/definitions for Vulkan APIs."""

    def __init__(self,
                 customApiPrefix,
                 extraParameters=None,
                 returnTypeOverride=None,
                 codegenDef=None):
        self.customApiPrefix = customApiPrefix
        # Must be a list of parameters prepended to the API's own; anything
        # else aborts when an API is built.
        self.extraParameters = extraParameters
        self.returnTypeOverride = returnTypeOverride

        self.codegen = CodeGen()

        self.definitionFunc = codegenDef

        # Stored as a plain (unbound) function on the instance, so callers
        # pass the wrapper explicitly: self.makeApi(self, typeInfo, apiName).
        self.makeApi = self._makeApiFunc

    @staticmethod
    def _makeApiFunc(wrapper, typeInfo, apiName):
        """Return a shallow copy of ``typeInfo.apis[apiName]`` customised by ``wrapper``.

        The copy gets the prefixed name, any extra leading parameters, and
        the return-type override.
        """
        customApi = copy(typeInfo.apis[apiName])
        customApi.name = wrapper.customApiPrefix + customApi.name
        if wrapper.extraParameters is not None:
            if isinstance(wrapper.extraParameters, list):
                customApi.parameters = \
                    wrapper.extraParameters + customApi.parameters
            else:
                # os.abort() takes no arguments; passing the message raised TypeError.
                print(
                    "Type of extra parameters to custom API not valid. Expected list, got %s" % type(
                        wrapper.extraParameters))
                os.abort()

        if wrapper.returnTypeOverride is not None:
            customApi.retType = wrapper.returnTypeOverride
        return customApi

    def setCodegenDef(self, codegenDefFunc):
        self.definitionFunc = codegenDefFunc

    def makeDecl(self, typeInfo, apiName):
        """Return the wrapper's prototype declaration text."""
        return self.codegen.makeFuncProto(
            self.makeApi(self, typeInfo, apiName)) + ";\n\n"

    def makeDefinition(self, typeInfo, apiName, isStatic=False):
        """Return the wrapper's full definition; exits(1) if no codegenDef is set."""
        vulkanApi = self.makeApi(self, typeInfo, apiName)

        # Discard any pending code so only this body is returned.
        self.codegen.swapCode()
        self.codegen.beginBlock()

        if self.definitionFunc is None:
            print("ERROR: No definition found for (%s, %s)" %
                  (vulkanApi.name, self.customApiPrefix))
            sys.exit(1)

        self.definitionFunc(self.codegen, vulkanApi)

        self.codegen.endBlock()

        return ("static " if isStatic else "") + self.codegen.makeFuncProto(
            vulkanApi) + "\n" + self.codegen.swapCode() + "\n"


# Base class for wrapping all Vulkan API objects. These work with Vulkan
# Registry generators and have gen* triggers. They tend to contain
# VulkanAPIWrapper objects to make it easier to generate the code.
class VulkanWrapperGenerator(object):
    """Base class for Vulkan-registry-driven code generators.

    Subclasses override the on* hooks. This base class records every
    non-aliased struct/union whose ``structExtendsExpr`` is set into
    ``extensionStructTypes`` (insertion-ordered), so the emit* helpers can
    generate per-extension-struct code.
    """

    def __init__(self, module: Module, typeInfo: VulkanTypeInfo):
        self.module: Module = module
        self.typeInfo: VulkanTypeInfo = typeInfo
        # Struct name -> struct info, for structs that extend other structs.
        self.extensionStructTypes = OrderedDict()

    def onBegin(self):
        pass

    def onEnd(self):
        pass

    def onBeginFeature(self, featureName, featureType):
        pass

    def onFeatureNewCmd(self, cmdName):
        pass

    def onEndFeature(self):
        pass

    def onGenType(self, typeInfo, name, alias):
        """Record non-aliased extension structs/unions.

        Looks types up in ``self.typeInfo``; the ``typeInfo`` argument is
        unused.
        """
        category = self.typeInfo.categoryOf(name)
        if category in ("struct", "union") and not alias:
            structInfo = self.typeInfo.structs[name]
            if structInfo.structExtendsExpr:
                self.extensionStructTypes[name] = structInfo

    def onGenStruct(self, typeInfo, name, alias):
        pass

    def onGenGroup(self, groupinfo, groupName, alias=None):
        pass

    def onGenEnum(self, enuminfo, name, alias):
        pass

    def onGenCmd(self, cmdinfo, name, alias):
        pass

    # Below Vulkan structure types may correspond to multiple Vulkan structs
    # due to a conflict between different Vulkan registries. In order to get
    # the correct Vulkan struct type, we need to check the type of its "root"
    # struct as well.
    ROOT_TYPE_MAPPING = {
        "VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FRAGMENT_DENSITY_MAP_FEATURES_EXT": {
            "VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2": "VkPhysicalDeviceFragmentDensityMapFeaturesEXT",
            "VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO": "VkPhysicalDeviceFragmentDensityMapFeaturesEXT",
            "VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO": "VkImportColorBufferGOOGLE",
            "default": "VkPhysicalDeviceFragmentDensityMapFeaturesEXT",
        },
        "VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FRAGMENT_DENSITY_MAP_PROPERTIES_EXT": {
            "VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2": "VkPhysicalDeviceFragmentDensityMapPropertiesEXT",
            "VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO": "VkImportPhysicalAddressGOOGLE",
            "default": "VkPhysicalDeviceFragmentDensityMapPropertiesEXT",
        },
        "VK_STRUCTURE_TYPE_RENDER_PASS_FRAGMENT_DENSITY_MAP_CREATE_INFO_EXT": {
            "VK_STRUCTURE_TYPE_RENDER_PASS_CREATE_INFO": "VkRenderPassFragmentDensityMapCreateInfoEXT",
            "VK_STRUCTURE_TYPE_RENDER_PASS_CREATE_INFO_2": "VkRenderPassFragmentDensityMapCreateInfoEXT",
            "VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO": "VkImportBufferGOOGLE",
            "default": "VkRenderPassFragmentDensityMapCreateInfoEXT",
        },
    }

    @staticmethod
    def _emitFeatureGuardTransition(cgen, currFeature, ext):
        """Open or switch the ``#ifdef`` guard for ``ext.feature``.

        Returns the feature now considered open. The caller emits the final
        ``#endif`` when the returned value is truthy.
        """
        if not currFeature:
            cgen.leftline("#ifdef %s" % ext.feature)
            currFeature = ext.feature

        if currFeature and ext.feature != currFeature:
            cgen.leftline("#endif")
            cgen.leftline("#ifdef %s" % ext.feature)
            currFeature = ext.feature

        return currFeature

    def emitForEachStructExtension(self, cgen, retType, triggerVar, forEachFunc, autoBreak=True, defaultEmit=None, nullEmit=None, rootTypeVar=None):
        """Emit a C++ ``switch`` over the structure type of ``triggerVar``.

        Emits, in order:
        - a null check on ``triggerVar`` (``nullEmit`` or a default return);
        - reading of ``structType`` via ``goldfish_vk_struct_type``;
        - one ``case`` per recorded extension struct, grouped in ``#ifdef``
          feature guards. Each case calls
          ``forEachFunc(structInfo, castedAccess, cgen)``;
        - a ``default`` case (``defaultEmit`` or a default return).

        When ``rootTypeVar`` is given, enums listed in ROOT_TYPE_MAPPING get a
        nested ``switch`` on ``rootTypeVar`` to pick the right struct. That
        lookup raises KeyError if a mapped struct was never recorded.
        """
        def readStructType(structTypeName, structVarName, cgen):
            cgen.stmt("uint32_t %s = (uint32_t)%s(%s)" % \
                (structTypeName, "goldfish_vk_struct_type", structVarName))

        def castAsStruct(varName, typeName, const=True):
            return "reinterpret_cast<%s%s*>(%s)" % \
                ("const " if const else "", typeName, varName)

        def doDefaultReturn(cgen):
            # Zero-valued return of the function's return type.
            if retType.typeName == "void":
                cgen.stmt("return")
            else:
                cgen.stmt("return (%s)0" % retType.typeName)

        cgen.beginIf("!%s" % triggerVar.paramName)
        if nullEmit is None:
            doDefaultReturn(cgen)
        else:
            nullEmit(cgen)
        cgen.endIf()

        readStructType("structType", triggerVar.paramName, cgen)

        cgen.line("switch(structType)")
        cgen.beginBlock()

        currFeature = None

        for ext in self.extensionStructTypes.values():
            currFeature = self._emitFeatureGuardTransition(cgen, currFeature, ext)

            enum = ext.structEnumExpr
            cgen.line("case %s:" % enum)
            cgen.beginBlock()

            if rootTypeVar is not None and enum in VulkanWrapperGenerator.ROOT_TYPE_MAPPING:
                # Ambiguous enum: dispatch again on the root struct's type.
                cgen.line("switch(%s)" % rootTypeVar.paramName)
                cgen.beginBlock()
                rootMapping = VulkanWrapperGenerator.ROOT_TYPE_MAPPING[enum]
                for rootType, structName in rootMapping.items():
                    structInfo = self.extensionStructTypes[structName]
                    if rootType == "default":
                        cgen.line("%s:" % rootType)
                    else:
                        cgen.line("case %s:" % rootType)
                    cgen.beginBlock()
                    castedAccess = castAsStruct(
                        triggerVar.paramName, structInfo.name, const=triggerVar.isConst)
                    forEachFunc(structInfo, castedAccess, cgen)
                    cgen.line("break;")
                    cgen.endBlock()
                cgen.endBlock()
            else:
                castedAccess = castAsStruct(
                    triggerVar.paramName, ext.name, const=triggerVar.isConst)
                forEachFunc(ext, castedAccess, cgen)

            if autoBreak:
                cgen.stmt("break")
            cgen.endBlock()

        if currFeature:
            cgen.leftline("#endif")

        cgen.line("default:")
        cgen.beginBlock()
        if defaultEmit is None:
            doDefaultReturn(cgen)
        else:
            defaultEmit(cgen)
        cgen.endBlock()

        cgen.endBlock()

    def emitForEachStructExtensionGeneral(self, cgen, forEachFunc, doFeatureIfdefs=False):
        """Call ``forEachFunc(index, structInfo, cgen)`` for each extension struct.

        When ``doFeatureIfdefs`` is set, the emitted code is wrapped in
        ``#ifdef``/``#endif`` guards per feature.
        """
        currFeature = None

        for (i, ext) in enumerate(self.extensionStructTypes.values()):
            if doFeatureIfdefs:
                currFeature = self._emitFeatureGuardTransition(cgen, currFeature, ext)

            forEachFunc(i, ext, cgen)

        if doFeatureIfdefs and currFeature:
            cgen.leftline("#endif")