# Copyright 2023 Google LLC
# SPDX-License-Identifier: MIT

from copy import copy

from .common.codegen import CodeGen
from .common.vulkantypes import (
    VulkanAPI, makeVulkanTypeSimple, iterateVulkanType, VulkanTypeIterator,
    Atom, FuncExpr, FuncExprVal, FuncLambda)

from .wrapperdefs import VulkanWrapperGenerator
from .wrapperdefs import ROOT_TYPE_VAR_NAME, ROOT_TYPE_PARAM
from .wrapperdefs import (
    STRUCT_EXTENSION_PARAM, STRUCT_EXTENSION_PARAM_FOR_WRITE,
    EXTENSION_SIZE_WITH_STREAM_FEATURES_API_NAME)


# C operators for the binary filter-expression functions handled by
# VulkanCountingCodegen.genFilterFunc.
_BINARY_FILTER_OPS = {
    "eq": "==",
    "and": "&&",
    "or": "||",
    "bitwise_and": "&",
}


class VulkanCountingCodegen(VulkanTypeIterator):
    """Type-iterator callbacks that emit C code summing a value's encoded size.

    ``iterateVulkanType`` drives the ``on*`` callbacks below for each member
    of a struct. Each callback appends statements of the form
    ``*<countVar> += <size expression>`` to ``self.cgen``.

    NOTE(review): ``self.typeInfo`` is read by ``genPrimitiveStreamCall``,
    ``onPointer`` and ``onValue``, and ``self.currentStructInfo`` is read by
    the filter/environment helpers. Neither is set in ``__init__``:
    ``currentStructInfo`` is assigned by ``VulkanCounting.onGenType``, and
    ``typeInfo`` is presumably supplied by ``VulkanTypeIterator`` /
    ``iterateVulkanType`` -- confirm.
    """

    def __init__(self, cgen, featureBitsVar, toCountVar, countVar, rootTypeVar, prefix,
                 forApiOutput=False, mapHandles=True, handleMapOverwrites=False,
                 doFiltering=True):
        """Store the emission target and the C variable names used in output.

        Args:
            cgen: CodeGen that receives the emitted statements (rebound per
                struct by ``VulkanCounting.onGenType``).
            featureBitsVar: C name of the stream feature-bits variable.
            toCountVar: C name of the pointer to the value being counted.
            countVar: C name of the ``size_t*`` accumulator.
            rootTypeVar: C name of the root structure-type variable.
            prefix: Prefix of the generated per-type counting functions.
            forApiOutput: When true, ``onCheck`` emits nothing.
            mapHandles: Count handle types via ``genHandleMappingCall``.
            handleMapOverwrites: Chooses between two equivalent spellings of
                the per-handle byte count in ``genHandleMappingCall``.
            doFiltering: Emit filter guards around filtered members.
        """
        self.cgen = cgen
        self.featureBitsVar = featureBitsVar
        self.toCountVar = toCountVar
        self.rootTypeVar = rootTypeVar
        self.countVar = countVar
        self.prefix = prefix
        self.forApiOutput = forApiOutput

        # Accessors render C expressions relative to toCountVar. They read
        # self.cgen lazily, so rebinding cgen later is honoured.
        self.exprAccessor = lambda t: self.cgen.generalAccess(t, parentVarName=self.toCountVar, asPtr=True)
        self.exprValueAccessor = lambda t: self.cgen.generalAccess(t, parentVarName=self.toCountVar, asPtr=False)
        self.exprPrimitiveValueAccessor = lambda t: self.cgen.generalAccess(t, parentVarName=self.toCountVar, asPtr=False)

        self.lenAccessor = lambda t: self.cgen.generalLengthAccess(t, parentVarName=self.toCountVar)
        self.lenAccessorGuard = lambda t: self.cgen.generalLengthAccessGuard(t, parentVarName=self.toCountVar)
        self.filterVarAccessor = lambda t: self.cgen.filterVarAccess(t, parentVarName=self.toCountVar)

        # True while onCheck has an open C `if` that endCheck must close.
        self.checked = False

        self.mapHandles = mapHandles
        self.handleMapOverwrites = handleMapOverwrites
        self.doFiltering = doFiltering

    def getTypeForStreaming(self, vulkanType):
        """Return a copy of vulkanType adjusted so it can be cast to a pointer.

        Address access is applied once if the type is not accessible as a
        pointer, and once more if it is a static array.
        """
        res = copy(vulkanType)

        if not vulkanType.accessibleAsPointer():
            res = res.getForAddressAccess()

        if vulkanType.staticArrExpr:
            res = res.getForAddressAccess()

        return res

    def makeCastExpr(self, vulkanType):
        """Return a C cast prefix ``(<type>)`` for vulkanType."""
        return "(%s)" % (
            self.cgen.makeCTypeDecl(vulkanType, useParamName=False))

    def genCount(self, sizeExpr):
        """Emit ``*<countVar> += <sizeExpr>``."""
        self.cgen.stmt("*%s += %s" % (self.countVar, sizeExpr))

    def genPrimitiveStreamCall(self, vulkanType):
        """Add the primitive encoded size of vulkanType (per cgen.countPrimitive)."""
        self.genCount(str(self.cgen.countPrimitive(
            self.typeInfo,
            vulkanType)))

    def genHandleMappingCall(self, vulkanType, access, lenAccess):
        """Count ``lenAccess`` handles (default 1) at 8 bytes each.

        For a length other than the literal ``"1"``, the count is wrapped in
        ``if (<lenAccess>)``. Otherwise a ``uint64_t`` temporary is declared
        and voided. ``vulkanType`` and ``access`` are accepted for interface
        parity with the other stream generators and are not used here.
        """
        if lenAccess is None:
            lenAccess = "1"
            handle64Bytes = "8"
        else:
            handle64Bytes = "%s * 8" % lenAccess

        handle64Var = self.cgen.var()
        isArray = lenAccess != "1"

        if isArray:
            self.cgen.beginIf(lenAccess)
        else:
            self.cgen.stmt("uint64_t %s" % handle64Var)
            self.cgen.stmt("(void)%s" % handle64Var)

        if self.handleMapOverwrites:
            self.genCount("8 * %s" % lenAccess)
        else:
            self.genCount(handle64Bytes)

        if isArray:
            self.cgen.endIf()

    def doAllocSpace(self, vulkanType):
        """No-op: counting does not allocate."""
        pass

    def getOptionalStringFeatureExpr(self, vulkanType):
        """Return ``<featureBits> & <feature>`` for the type's protect feature, or None."""
        feature = vulkanType.getProtectStreamFeature()
        if feature is None:
            return None
        return "%s & %s" % (self.featureBitsVar, feature)

    def onCheck(self, vulkanType):
        """Count an optional pointer's presence word and open a guard on the pointer.

        The guard opened here is closed by ``endCheck``. Nothing is emitted
        when ``forApiOutput`` is set.
        """
        if self.forApiOutput:
            return

        featureExpr = self.getOptionalStringFeatureExpr(vulkanType)

        self.checked = True

        access = self.exprAccessor(vulkanType)

        self.cgen.line("// WARNING PTR CHECK")

        # Count the check value itself; with a protect feature it is only
        # counted when that feature bit is set.
        if featureExpr is not None:
            self.cgen.beginIf(featureExpr)

        self.genPrimitiveStreamCall(vulkanType)

        if featureExpr is not None:
            self.cgen.endIf()

        # Entered when the pointer is non-null, or unconditionally when the
        # protect feature bit is not set.
        if featureExpr is not None:
            self.cgen.beginIf("(!(%s) || %s)" % (featureExpr, access))
        else:
            self.cgen.beginIf(access)

    def onCheckWithNullOptionalStringFeature(self, vulkanType):
        """Open the null-optional-strings feature branch, then ``onCheck``."""
        self.cgen.beginIf("%s & VULKAN_STREAM_FEATURE_NULL_OPTIONAL_STRINGS_BIT" % self.featureBitsVar)
        self.onCheck(vulkanType)

    def endCheckWithNullOptionalStringFeature(self, vulkanType):
        """Close the check and the feature branch, then open its ``else``."""
        self.endCheck(vulkanType)
        self.cgen.endIf()
        self.cgen.beginElse()

    def finalCheckWithNullOptionalStringFeature(self, vulkanType):
        """Close the ``else`` opened by ``endCheckWithNullOptionalStringFeature``."""
        self.cgen.endElse()

    def endCheck(self, vulkanType):
        """Close the guard opened by ``onCheck``, if one is open."""
        if self.checked:
            self.cgen.endIf()
            self.checked = False

    def genFilterFunc(self, filterfunc, env):
        """Render a filter expression tree as a C expression string.

        Args:
            filterfunc: Root node (FuncExpr, FuncLambda or FuncExprVal).
            env: The current struct's environment. Atoms present here are
                rewritten via ``getEnvAccessExpr``.

        Names bound by an enclosing lambda shadow ``env`` entries. The lambda
        environment is threaded through every recursive call, so shadowing
        applies to atoms, ``getfield`` pointer levels and nested lambdas
        inside a lambda body.
        """

        def loop(expr, lambdaEnv):
            if FuncExpr == type(expr):
                return do_func(expr, lambdaEnv)
            if FuncLambda == type(expr):
                return do_lambda(expr, lambdaEnv)
            if FuncExprVal == type(expr):
                return do_exprval(expr, lambdaEnv)
            # NOTE(review): other node types render as None, as before.
            return None

        def do_func(expr, lambdaEnv):
            fnamestr = expr.name.name
            args = expr.args

            if "not" == fnamestr:
                return "!(%s)" % (loop(args[0], lambdaEnv),)

            if fnamestr in _BINARY_FILTER_OPS:
                return "(%s %s %s)" % (loop(args[0], lambdaEnv),
                                       _BINARY_FILTER_OPS[fnamestr],
                                       loop(args[1], lambdaEnv))

            if "getfield" == fnamestr:
                ptrlevels = get_ptrlevels(args[0].val.name, lambdaEnv)
                if ptrlevels == 0:
                    return "%s.%s" % (loop(args[0], lambdaEnv), args[1].val)
                return "(%s(%s)).%s" % ("*" * ptrlevels, loop(args[0], lambdaEnv), args[1].val)

            if "if" == fnamestr:
                return "((%s) ? (%s) : (%s))" % (loop(args[0], lambdaEnv),
                                                  loop(args[1], lambdaEnv),
                                                  loop(args[2], lambdaEnv))

            # Anything else is emitted as a plain C function call.
            return "%s(%s)" % (fnamestr, ", ".join(loop(e, lambdaEnv) for e in args))

        def do_expratom(atomname, lambdaEnv):
            if lambdaEnv.get(atomname, None) is not None:
                return atomname
            if env.get(atomname, None) is not None:
                return self.getEnvAccessExpr(atomname)
            return atomname

        def get_ptrlevels(atomname, lambdaEnv):
            if lambdaEnv.get(atomname, None) is not None:
                return 0
            if env.get(atomname, None) is not None:
                return self.getPointerIndirectionLevels(atomname)
            return 0

        def do_exprval(expr, lambdaEnv):
            expratom = expr.val
            if Atom == type(expratom):
                return do_expratom(expratom.name, lambdaEnv)
            return "%s" % expratom

        def do_lambda(expr, lambdaEnv):
            params = expr.vs
            # Inner bindings extend (and may shadow) the outer lambda bindings.
            newEnv = dict(lambdaEnv)
            for p in params:
                newEnv[p.name] = p.typ
            paramDecls = ", ".join("%s %s" % (p.typ, p.name) for p in params)
            return "[](%s) { return %s; }" % (paramDecls, loop(expr.body, newEnv))

        return loop(filterfunc, {})

    def beginFilterGuard(self, vulkanType):
        """Open a C ``if`` guarding a filtered member; see ``endFilterGuard``.

        With neither filter values nor a filter function, the filter variable
        itself is the condition. Otherwise the guard is
        ``(!(<ignored-handles feature>) || (<filter>))``, where ``<filter>``
        is the value test, the function test, or both joined with ``||``.
        """
        if vulkanType.filterVar is None or not self.doFiltering:
            return

        filterVarAccess = self.getEnvAccessExpr(vulkanType.filterVar)
        filterFeature = "%s & VULKAN_STREAM_FEATURE_IGNORED_HANDLES_BIT" % self.featureBitsVar

        filterValsExpr = None
        filterFuncExpr = None

        if vulkanType.filterVals is not None:
            filterValsExpr = " || ".join(
                "(%s == %s)" % (filterval, filterVarAccess)
                for filterval in vulkanType.filterVals)

        if vulkanType.filterFunc is not None:
            filterFuncExpr = self.genFilterFunc(vulkanType.filterFunc, self.currentStructInfo.environment)

        if filterValsExpr is not None and filterFuncExpr is not None:
            # Previously this case built the expression but never opened the
            # `if`, leaving endFilterGuard's endIf unbalanced.
            filterExpr = "%s || %s" % (filterValsExpr, filterFuncExpr)
        elif filterValsExpr is not None:
            filterExpr = filterValsExpr
        elif filterFuncExpr is not None:
            filterExpr = filterFuncExpr
        else:
            # Assume is bool
            self.cgen.beginIf(filterVarAccess)
            return

        self.cgen.beginIf("(!(%s) || (%s))" % (filterFeature, filterExpr))

    def endFilterGuard(self, vulkanType, cleanupExpr=None):
        """Close the guard opened by ``beginFilterGuard``.

        If ``cleanupExpr`` is given, it is emitted in an ``else`` branch.
        """
        if vulkanType.filterVar is None or not self.doFiltering:
            return

        self.cgen.endIf()
        if cleanupExpr is not None:
            self.cgen.beginElse()
            self.cgen.stmt(cleanupExpr)
            self.cgen.endElse()

    def _findStructMember(self, varName):
        """Return the first current-struct member named varName (IndexError if none)."""
        return [m for m in self.currentStructInfo.members if m.paramName == varName][0]

    def getEnvAccessExpr(self, varName):
        """Return the C expression for an environment variable, or None if unknown.

        Struct members are rendered through ``exprValueAccessor``. Other
        environment entries are referenced by their bare name.
        """
        entry = self.currentStructInfo.environment.get(varName, None)
        if entry is None:
            return None
        if entry["structmember"]:
            return self.exprValueAccessor(self._findStructMember(varName))
        return varName

    def getPointerIndirectionLevels(self, varName):
        """Return the member's pointer depth for struct-member env vars, else 0."""
        entry = self.currentStructInfo.environment.get(varName, None)
        if entry is None or not entry["structmember"]:
            return 0
        return self._findStructMember(varName).pointerIndirectionLevels

    def onCompoundType(self, vulkanType):
        """Call the generated counting function for a nested struct/union.

        With a length expression, the call is made once per element inside a
        ``for`` loop, which is optionally guarded by the length guard.
        Variables listed in ``vulkanType.binds`` are passed as extra
        arguments.
        """
        access = self.exprAccessor(vulkanType)
        lenAccess = self.lenAccessor(vulkanType)
        lenAccessGuard = self.lenAccessorGuard(vulkanType)

        self.beginFilterGuard(vulkanType)

        if vulkanType.pointerIndirectionLevels > 0:
            self.doAllocSpace(vulkanType)

        isArray = lenAccess is not None
        if isArray:
            if lenAccessGuard is not None:
                self.cgen.beginIf(lenAccessGuard)
            loopVar = "i"
            access = "%s + %s" % (access, loopVar)
            self.cgen.beginFor("uint32_t %s = 0" % loopVar,
                               "%s < (uint32_t)%s" % (loopVar, lenAccess),
                               "++%s" % loopVar)

        accessWithCast = "%s(%s)" % (self.makeCastExpr(
            self.getTypeForStreaming(vulkanType)), access)

        callParams = [self.featureBitsVar, self.rootTypeVar, accessWithCast, self.countVar]
        callParams.extend(self.getEnvAccessExpr(localName)
                          for localName in vulkanType.binds.values())

        self.cgen.funcCall(None, self.prefix + vulkanType.typeName, callParams)

        if isArray:
            self.cgen.endFor()
            if lenAccessGuard is not None:
                self.cgen.endIf()

        self.endFilterGuard(vulkanType)

    def onString(self, vulkanType):
        """Count a 4-byte length prefix plus ``strlen`` (0 for a null string)."""
        access = self.exprAccessor(vulkanType)
        self.genCount("sizeof(uint32_t) + (%s ? strlen(%s) : 0)" % (access, access))

    def onStringArray(self, vulkanType):
        """Count a 4-byte array prefix plus a length prefix and ``strlen`` per element."""
        access = self.exprAccessor(vulkanType)
        lenAccess = self.lenAccessor(vulkanType)
        lenAccessGuard = self.lenAccessorGuard(vulkanType)

        self.genCount("sizeof(uint32_t)")
        if lenAccessGuard is not None:
            self.cgen.beginIf(lenAccessGuard)
        self.cgen.beginFor("uint32_t i = 0", "i < %s" % lenAccess, "++i")
        self.genCount("sizeof(uint32_t) + (%s[i] ? strlen(%s[i]) : 0)" % (access, access))
        self.cgen.endFor()
        if lenAccessGuard is not None:
            self.cgen.endIf()

    def onStaticArr(self, vulkanType):
        """Count ``<len> * sizeof(<element>)`` for a static array.

        NOTE(review): when a length guard exists, the emitted ``if`` has an
        empty body and the count is emitted after it, i.e. unconditionally.
        This is kept as-is to avoid changing the generated output.
        """
        lenAccess = self.lenAccessor(vulkanType)
        lenAccessGuard = self.lenAccessorGuard(vulkanType)

        if lenAccessGuard is not None:
            self.cgen.beginIf(lenAccessGuard)
        finalLenExpr = "%s * %s" % (lenAccess, self.cgen.sizeofExpr(vulkanType))
        if lenAccessGuard is not None:
            self.cgen.endIf()
        self.genCount(finalLenExpr)

    def onStructExtension(self, vulkanType):
        """Count a ``pNext`` chain via the generated ``extension_struct`` counter.

        If the root type is still ``VK_STRUCTURE_TYPE_MAX_ENUM``, it is first
        set to the enclosing struct's ``sType``.
        """
        sTypeParam = copy(vulkanType)
        sTypeParam.paramName = "sType"

        access = self.exprAccessor(vulkanType)
        sTypeAccess = self.exprAccessor(sTypeParam)

        self.cgen.beginIf("%s == VK_STRUCTURE_TYPE_MAX_ENUM" %
                          self.rootTypeVar)
        self.cgen.stmt("%s = %s" % (self.rootTypeVar, sTypeAccess))
        self.cgen.endIf()

        self.cgen.funcCall(None, self.prefix + "extension_struct",
                           [self.featureBitsVar, self.rootTypeVar, access, self.countVar])

    def onPointer(self, vulkanType):
        """Count the pointee(s) of a pointer member.

        Handles are counted via ``genHandleMappingCall`` (when mapping).
        Non-ABI-portable types are counted element by element. Everything
        else is counted as ``<len> * sizeof(<pointee>)``, or a single
        ``sizeof`` without a length.
        """
        access = self.exprAccessor(vulkanType)

        lenAccess = self.lenAccessor(vulkanType)
        lenAccessGuard = self.lenAccessorGuard(vulkanType)

        self.beginFilterGuard(vulkanType)
        self.doAllocSpace(vulkanType)

        if vulkanType.isHandleType() and self.mapHandles:
            self.genHandleMappingCall(vulkanType, access, lenAccess)
        elif self.typeInfo.isNonAbiPortableType(vulkanType.typeName):
            if lenAccess is None:
                self.genPrimitiveStreamCall(vulkanType.getForValueAccess())
            else:
                if lenAccessGuard is not None:
                    self.cgen.beginIf(lenAccessGuard)
                self.cgen.beginFor("uint32_t i = 0", "i < (uint32_t)%s" % lenAccess, "++i")
                self.genPrimitiveStreamCall(vulkanType.getForValueAccess())
                self.cgen.endFor()
                if lenAccessGuard is not None:
                    self.cgen.endIf()
        else:
            valueSize = self.cgen.sizeofExpr(vulkanType.getForValueAccess())
            if lenAccess is None:
                self.genCount("%s" % (valueSize,))
            else:
                if lenAccessGuard is not None:
                    self.cgen.beginIf(lenAccessGuard)
                self.genCount("%s * %s" % (lenAccess, valueSize))
                if lenAccessGuard is not None:
                    self.cgen.endIf()

        self.endFilterGuard(vulkanType)

    def onValue(self, vulkanType):
        """Count a by-value member: a handle, a non-ABI-portable primitive, or ``sizeof``."""
        self.beginFilterGuard(vulkanType)

        if vulkanType.isHandleType() and self.mapHandles:
            access = self.exprAccessor(vulkanType)
            self.genHandleMappingCall(
                vulkanType.getForAddressAccess(), access, "1")
        elif self.typeInfo.isNonAbiPortableType(vulkanType.typeName):
            self.genPrimitiveStreamCall(vulkanType)
        else:
            self.genCount(self.cgen.sizeofExpr(vulkanType))

        self.endFilterGuard(vulkanType)

    def streamLetParameter(self, structInfo, letParamInfo):
        """Declare a 'let' parameter (default 1) and count it when filtering is on.

        Under the ignored-handles feature bit, the parameter is assigned from
        its environment ``body`` expression and its primitive size is
        counted. ``structInfo`` is unused; ``self.currentStructInfo`` is read
        instead.
        """
        filterFeature = "%s & VULKAN_STREAM_FEATURE_IGNORED_HANDLES_BIT" % (self.featureBitsVar)
        self.cgen.stmt("%s %s = 1" % (letParamInfo.typeName, letParamInfo.paramName))

        self.cgen.beginIf(filterFeature)

        environment = self.currentStructInfo.environment
        bodyExpr = environment[letParamInfo.paramName]["body"]
        self.cgen.stmt("%s = %s" % (letParamInfo.paramName, self.genFilterFunc(bodyExpr, environment)))

        self.genPrimitiveStreamCall(letParamInfo)

        self.cgen.endIf()


class VulkanCounting(VulkanWrapperGenerator):
    """Module generator emitting ``count_<Type>`` size-counting functions."""

    def __init__(self, module, typeInfo):
        VulkanWrapperGenerator.__init__(self, module, typeInfo)

        self.codegen = CodeGen()

        self.featureBitsVar = "featureBits"
        self.featureBitsVarType = makeVulkanTypeSimple(False, "uint32_t", 0, self.featureBitsVar)
        self.countingPrefix = "count_"
        # [0]: pointer to the value being counted; [1]: size_t* accumulator.
        self.countVars = ["toCount", "count"]
        self.countVarType = makeVulkanTypeSimple(False, "size_t", 1, self.countVars[1])
        self.voidType = makeVulkanTypeSimple(False, "void", 0)
        self.rootTypeVar = ROOT_TYPE_VAR_NAME

        self.countingCodegen = \
            VulkanCountingCodegen(
                self.codegen,
                self.featureBitsVar,
                self.countVars[0],
                self.countVars[1],
                self.rootTypeVar,
                self.countingPrefix)

        self.knownDefs = {}

        self.extensionCountingPrototype = \
            VulkanAPI(self.countingPrefix + "extension_struct",
                      self.voidType,
                      [self.featureBitsVarType,
                       ROOT_TYPE_PARAM,
                       STRUCT_EXTENSION_PARAM,
                       self.countVarType])

    def onBegin(self,):
        """Forward-declare ``count_extension_struct`` in the implementation."""
        VulkanWrapperGenerator.onBegin(self)
        self.module.appendImpl(self.codegen.makeFuncDecl(
            self.extensionCountingPrototype))

    def onGenType(self, typeXml, name, alias):
        """Emit counting function(s) for a struct or union type.

        An aliased struct/union gets a function alias. Otherwise a
        ``count_<name>`` is emitted taking the struct's free parameters. When
        free parameters exist, an overload without them is also emitted with
        filtering disabled.
        """
        VulkanWrapperGenerator.onGenType(self, typeXml, name, alias)

        if name in self.knownDefs:
            return

        category = self.typeInfo.categoryOf(name)

        if category in ["struct", "union"] and alias:
            # TODO(liyl): might not work if freeParams != []
            self.module.appendHeader(
                self.codegen.makeFuncAlias(self.countingPrefix + name,
                                           self.countingPrefix + alias))

        if category in ["struct", "union"] and not alias:

            structInfo = self.typeInfo.structs[name]

            # Unbound environment entries become extra function parameters;
            # bound non-member entries become locally computed 'let' params.
            freeParams = []
            letParams = []

            for (envname, bindingInfo) in sorted(structInfo.environment.items(), key=lambda kv: kv[0]):
                if bindingInfo["binding"] is None:
                    freeParams.append(makeVulkanTypeSimple(True, bindingInfo["type"], 0, envname))
                elif not bindingInfo["structmember"]:
                    letParams.append(makeVulkanTypeSimple(True, bindingInfo["type"], 0, envname))

            typeFromName = \
                lambda varname: \
                makeVulkanTypeSimple(True, name, 1, varname)

            countingParams = \
                [makeVulkanTypeSimple(False, "uint32_t", 0, self.featureBitsVar),
                 ROOT_TYPE_PARAM,
                 typeFromName(self.countVars[0]),
                 makeVulkanTypeSimple(False, "size_t", 1, self.countVars[1])]

            countingPrototype = \
                VulkanAPI(self.countingPrefix + name,
                          self.voidType,
                          countingParams + freeParams)

            countingPrototypeNoFilter = \
                VulkanAPI(self.countingPrefix + name,
                          self.voidType,
                          countingParams)

            def emitStructCountingBody(cgen):
                # Shared body of both overloads; filtering is controlled by
                # self.countingCodegen.doFiltering.
                self.countingCodegen.cgen = cgen
                self.countingCodegen.currentStructInfo = structInfo
                for unusedVar in (self.featureBitsVar, self.rootTypeVar,
                                  self.countVars[0], self.countVars[1]):
                    cgen.stmt("(void)%s" % unusedVar)

                if category == "struct":
                    # marshal 'let' parameters first
                    for letp in letParams:
                        self.countingCodegen.streamLetParameter(self.typeInfo, letp)

                    for member in structInfo.members:
                        iterateVulkanType(self.typeInfo, member, self.countingCodegen)
                if category == "union":
                    iterateVulkanType(self.typeInfo, structInfo.members[0], self.countingCodegen)

            def structCountingDef(cgen):
                emitStructCountingBody(cgen)

            def structCountingDefNoFilter(cgen):
                self.countingCodegen.doFiltering = False
                try:
                    emitStructCountingBody(cgen)
                finally:
                    # Restore even if generation raises, so later types are
                    # not silently generated without filter guards.
                    self.countingCodegen.doFiltering = True

            self.module.appendHeader(
                self.codegen.makeFuncDecl(countingPrototype))
            self.module.appendImpl(
                self.codegen.makeFuncImpl(countingPrototype, structCountingDef))

            if freeParams:
                # Uses self.codegen like the calls above; this class defines
                # no cgenHeader/cgenImpl attributes.
                self.module.appendHeader(
                    self.codegen.makeFuncDecl(countingPrototypeNoFilter))
                self.module.appendImpl(
                    self.codegen.makeFuncImpl(
                        countingPrototypeNoFilter, structCountingDefNoFilter))

    def onGenCmd(self, cmdinfo, name, alias):
        VulkanWrapperGenerator.onGenCmd(self, cmdinfo, name, alias)

    def doExtensionStructCountCodegen(self, cgen, extParam, forEach, funcproto):
        """Emit the body of the extension-struct counting function.

        Emitted logic:
        - A non-null struct with size 0 (unknown) is skipped by recursing on
          its ``pNext`` and returning.
        - Otherwise a ``uint32_t`` is counted; a null struct then returns.
        - A known struct adds ``sizeof(VkStructureType)`` and dispatches per
          struct type through ``forEach``. The default branch aborts.
        """
        accessVar = "structAccess"
        sizeVar = "currExtSize"
        cgen.stmt("VkInstanceCreateInfo* %s = (VkInstanceCreateInfo*)(%s)" % (accessVar, extParam.paramName))
        cgen.stmt("size_t %s = %s(%s, %s, %s)" % (sizeVar, EXTENSION_SIZE_WITH_STREAM_FEATURES_API_NAME,
                                                   self.featureBitsVar, ROOT_TYPE_VAR_NAME, extParam.paramName))

        cgen.beginIf("!%s && %s" % (sizeVar, extParam.paramName))

        cgen.line("// unknown struct extension; skip and call on its pNext field")
        cgen.funcCall(None, funcproto.name, [
            self.featureBitsVar, ROOT_TYPE_VAR_NAME, "(void*)%s->pNext" % accessVar, self.countVars[1]])
        cgen.stmt("return")

        cgen.endIf()
        cgen.beginElse()

        cgen.line("// known or null extension struct")

        cgen.stmt("*%s += sizeof(uint32_t)" % self.countVars[1])

        cgen.beginIf("!%s" % (sizeVar))
        cgen.line("// exit if this was a null extension struct (size == 0 in this branch)")
        cgen.stmt("return")
        cgen.endIf()

        cgen.endIf()

        cgen.stmt("*%s += sizeof(VkStructureType)" % self.countVars[1])

        def fatalDefault(cgen):
            cgen.line("// fatal; the switch is only taken if the extension struct is known")
            cgen.stmt("abort()")

        self.emitForEachStructExtension(
            cgen,
            makeVulkanTypeSimple(False, "void", 0, "void"),
            extParam,
            forEach,
            defaultEmit=fatalDefault,
            rootTypeVar=ROOT_TYPE_PARAM)

    def onEnd(self,):
        """Emit the definition of ``count_extension_struct``."""
        VulkanWrapperGenerator.onEnd(self)

        def forEachExtensionCounting(ext, castedAccess, cgen):
            cgen.funcCall(None, self.countingPrefix + ext.name,
                          [self.featureBitsVar, self.rootTypeVar, castedAccess, self.countVars[1]])

        self.module.appendImpl(
            self.codegen.makeFuncImpl(
                self.extensionCountingPrototype,
                lambda cgen: self.doExtensionStructCountCodegen(
                    cgen,
                    STRUCT_EXTENSION_PARAM,
                    forEachExtensionCounting,
                    self.extensionCountingPrototype)))