• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright (c) 2018 The Android Open Source Project
# Copyright (c) 2018 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
16from copy import copy
17
18from .common.codegen import CodeGen
19from .common.vulkantypes import \
20        VulkanAPI, makeVulkanTypeSimple, iterateVulkanType, VulkanTypeIterator, Atom, FuncExpr, FuncExprVal, FuncLambda
21
22from .wrapperdefs import VulkanWrapperGenerator
23from .wrapperdefs import ROOT_TYPE_VAR_NAME, ROOT_TYPE_PARAM
24from .wrapperdefs import STRUCT_EXTENSION_PARAM, STRUCT_EXTENSION_PARAM_FOR_WRITE, EXTENSION_SIZE_WITH_STREAM_FEATURES_API_NAME
25
class VulkanCountingCodegen(VulkanTypeIterator):
    """Type-iterator callbacks that emit code adding up a byte count.

    Every ``on*`` callback writes statements through ``self.cgen``; sizes are
    accumulated with ``*<countVar> += <expr>`` (see ``genCount``).

    Two attributes are read here but never assigned by this class, so they
    must be provided from outside before the callbacks run:

    * ``typeInfo`` -- queried for ``isNonAbiPortableType`` and handed to
      ``cgen.countPrimitive`` (presumably installed by the iteration
      machinery or the base class; TODO confirm).
    * ``currentStructInfo`` -- supplies ``environment`` and ``members`` for
      filter and binding lookups; ``VulkanCounting`` assigns it before
      iterating a struct.
    """

    def __init__(self, cgen, featureBitsVar, toCountVar, countVar, rootTypeVar, prefix, forApiOutput=False, mapHandles=True, handleMapOverwrites=False, doFiltering=True):
        """Store the code generator and the names used in the emitted code.

        Args:
            cgen: code generator that receives every emitted statement.
            featureBitsVar: name of the stream-feature bitmask variable.
            toCountVar: name of the variable holding the object being counted.
            countVar: name of the pointer variable the count is added to.
            rootTypeVar: name of the root structure-type variable.
            prefix: prefix of the per-type counting functions that get called.
            forApiOutput: when true, ``onCheck`` emits nothing.
            mapHandles: when true, handle types go through
                ``genHandleMappingCall``.
            handleMapOverwrites: selects how ``genHandleMappingCall`` spells
                the size expression of a handle.
            doFiltering: when false, no filter guards are emitted.
        """
        self.cgen = cgen
        self.featureBitsVar = featureBitsVar
        self.toCountVar = toCountVar
        self.rootTypeVar = rootTypeVar
        self.countVar = countVar
        self.prefix = prefix
        self.forApiOutput = forApiOutput

        # Accessors resolve a member relative to the counted object. They go
        # through self.cgen at call time because callers swap self.cgen.
        self.exprAccessor = lambda t: self.cgen.generalAccess(t, parentVarName = self.toCountVar, asPtr = True)
        self.exprValueAccessor = lambda t: self.cgen.generalAccess(t, parentVarName = self.toCountVar, asPtr = False)
        self.exprPrimitiveValueAccessor = lambda t: self.cgen.generalAccess(t, parentVarName = self.toCountVar, asPtr = False)

        self.lenAccessor = lambda t: self.cgen.generalLengthAccess(t, parentVarName = self.toCountVar)
        self.lenAccessorGuard = lambda t: self.cgen.generalLengthAccessGuard(t, parentVarName = self.toCountVar)
        self.filterVarAccessor = lambda t: self.cgen.filterVarAccess(t, parentVarName = self.toCountVar)

        # True while an `if` opened by onCheck() is waiting for endCheck().
        self.checked = False

        self.mapHandles = mapHandles
        self.handleMapOverwrites = handleMapOverwrites
        self.doFiltering = doFiltering

    def getTypeForStreaming(self, vulkanType):
        """Return a copy of ``vulkanType`` adjusted for address access.

        One level of address access is added when the type is not accessible
        as a pointer, and one more when it has a static array expression.
        """
        res = copy(vulkanType)

        if not vulkanType.accessibleAsPointer():
            res = res.getForAddressAccess()

        if vulkanType.staticArrExpr:
            res = res.getForAddressAccess()

        return res

    def makeCastExpr(self, vulkanType):
        """Return a parenthesised cast to the declared type of ``vulkanType``."""
        return "(%s)" % (
            self.cgen.makeCTypeDecl(vulkanType, useParamName=False))

    def genCount(self, sizeExpr):
        """Emit ``*<countVar> += <sizeExpr>``."""
        self.cgen.stmt("*%s += %s" % (self.countVar, sizeExpr))

    def genPrimitiveStreamCall(self, vulkanType):
        """Add the size ``cgen.countPrimitive`` reports for ``vulkanType``."""
        self.genCount(str(self.cgen.countPrimitive(
            self.typeInfo,
            vulkanType)))

    def genHandleMappingCall(self, vulkanType, access, lenAccess):
        """Count 8 bytes per handle.

        Args:
            vulkanType: the handle type (not inspected here).
            access: access expression of the handle (not inspected here).
            lenAccess: element-count expression, or ``None`` for one handle.

        When the length is anything other than the literal ``"1"`` the count
        statement is wrapped in ``if (<lenAccess>)``; otherwise a
        ``uint64_t`` temporary declaration is emitted before it.
        """
        if lenAccess is None:
            lenAccess = "1"
            handle64Bytes = "8"
        else:
            handle64Bytes = "%s * 8" % lenAccess

        isArray = lenAccess != "1"

        # The temporary's name is always requested from the generator, even
        # when no declaration is emitted, so the sequence of var() calls the
        # generator sees stays the same.
        handle64Var = self.cgen.var()
        if isArray:
            self.cgen.beginIf(lenAccess)
        else:
            self.cgen.stmt("uint64_t %s" % handle64Var)

        if self.handleMapOverwrites:
            self.genCount("8 * %s" % lenAccess)
        else:
            self.genCount(handle64Bytes)

        if isArray:
            self.cgen.endIf()

    def doAllocSpace(self, vulkanType):
        """No-op; this iterator emits nothing for allocation."""
        pass

    def getOptionalStringFeatureExpr(self, vulkanType):
        """Return ``"<featureBits> & <FEATURE>"`` for a ``streamFeature:`` tag.

        Returns ``None`` when ``vulkanType.optionalStr`` is unset or does not
        start with ``streamFeature:``.
        """
        optionalStr = vulkanType.optionalStr
        if optionalStr is not None and optionalStr.startswith("streamFeature:"):
            return "%s & %s" % (self.featureBitsVar, optionalStr.split(":")[1])
        return None

    def onCheck(self, vulkanType):
        """Count the presence marker of ``vulkanType`` and open a guard on it.

        Emits nothing when ``forApiOutput`` is set. Otherwise the primitive
        size is counted (only under the stream feature, if the type names
        one) and an ``if`` on the access expression is opened; ``endCheck``
        closes it.
        """
        if self.forApiOutput:
            return

        featureExpr = self.getOptionalStringFeatureExpr(vulkanType)

        self.checked = True

        access = self.exprAccessor(vulkanType)

        self.cgen.line("// WARNING PTR CHECK")

        if featureExpr is not None:
            self.cgen.beginIf(featureExpr)

        self.genPrimitiveStreamCall(
            vulkanType)

        if featureExpr is not None:
            self.cgen.endIf()

        if featureExpr is not None:
            # Without the feature the marker is not counted, so the guarded
            # body must run unconditionally in that case.
            self.cgen.beginIf("(!(%s) || %s)" % (featureExpr, access))
        else:
            self.cgen.beginIf(access)

    def onCheckWithNullOptionalStringFeature(self, vulkanType):
        """Open the null-optional-strings feature branch, then ``onCheck``."""
        self.cgen.beginIf("%s & VULKAN_STREAM_FEATURE_NULL_OPTIONAL_STRINGS_BIT" % self.featureBitsVar)
        self.onCheck(vulkanType)

    def endCheckWithNullOptionalStringFeature(self, vulkanType):
        """Close the check and the feature branch, then open its ``else``."""
        self.endCheck(vulkanType)
        self.cgen.endIf()
        self.cgen.beginElse()

    def finalCheckWithNullOptionalStringFeature(self, vulkanType):
        """Close the ``else`` opened by ``endCheckWithNullOptionalStringFeature``."""
        self.cgen.endElse()

    def endCheck(self, vulkanType):
        """Close the guard opened by ``onCheck`` if one is pending."""
        if self.checked:
            self.cgen.endIf()
            self.checked = False

    def genFilterFunc(self, filterfunc, env):
        """Translate a filter expression tree into a C++ expression string.

        Args:
            filterfunc: tree of ``FuncExpr`` / ``FuncLambda`` / ``FuncExprVal``.
            env: struct environment; atoms with an entry in it are replaced
                by ``getEnvAccessExpr(name)``.

        Returns:
            The expression string, or ``None`` when the root node is none of
            the three known node types.

        Names bound by an enclosing lambda take precedence over ``env`` and
        are emitted verbatim with zero pointer levels. The lambda
        environment is passed explicitly through every helper so that nested
        lambdas and their bodies see all enclosing parameters.
        """
        binaryFormats = {
            "eq": "(%s == %s)",
            "and": "(%s && %s)",
            "or": "(%s || %s)",
            "bitwise_and": "(%s & %s)",
        }

        def do_expratom(atomname, lambdaEnv):
            if atomname in lambdaEnv:
                return atomname
            if env.get(atomname, None) is not None:
                return self.getEnvAccessExpr(atomname)
            return atomname

        def get_ptrlevels(atomname, lambdaEnv):
            if atomname in lambdaEnv:
                return 0
            if env.get(atomname, None) is not None:
                return self.getPointerIndirectionLevels(atomname)
            return 0

        def do_func(expr, lambdaEnv):
            fnamestr = expr.name.name
            args = expr.args

            if "not" == fnamestr:
                return "!(%s)" % (loop(args[0], lambdaEnv))

            if fnamestr in binaryFormats:
                return binaryFormats[fnamestr] % (loop(args[0], lambdaEnv), loop(args[1], lambdaEnv))

            if "getfield" == fnamestr:
                ptrlevels = get_ptrlevels(args[0].val.name, lambdaEnv)
                if ptrlevels == 0:
                    return "%s.%s" % (loop(args[0], lambdaEnv), args[1].val)
                return "(%s(%s)).%s" % ("*" * ptrlevels, loop(args[0], lambdaEnv), args[1].val)

            if "if" == fnamestr:
                return "((%s) ? (%s) : (%s))" % (loop(args[0], lambdaEnv), loop(args[1], lambdaEnv), loop(args[2], lambdaEnv))

            # Anything else is emitted as a plain call.
            return "%s(%s)" % (fnamestr, ", ".join([loop(e, lambdaEnv) for e in args]))

        def do_exprval(expr, lambdaEnv):
            expratom = expr.val
            if type(expratom) is Atom:
                return do_expratom(expratom.name, lambdaEnv)
            return "%s" % expratom

        def do_lambda(expr, lambdaEnv):
            params = expr.vs
            # Extend a copy so sibling expressions keep their own scope.
            newEnv = dict(lambdaEnv)
            for p in params:
                newEnv[p.name] = p.typ

            paramDecls = ", ".join(["%s %s" % (p.typ, p.name) for p in params])
            return "[](%s) { return %s; }" % (paramDecls, loop(expr.body, newEnv))

        def loop(expr, lambdaEnv):
            exprType = type(expr)
            if exprType is FuncExpr:
                return do_func(expr, lambdaEnv)
            if exprType is FuncLambda:
                return do_lambda(expr, lambdaEnv)
            if exprType is FuncExprVal:
                return do_exprval(expr, lambdaEnv)
            return None

        return loop(filterfunc, {})

    def beginFilterGuard(self, vulkanType):
        """Open the ``if`` that filters ``vulkanType``, when it has a filter.

        Nothing is emitted when the type has no ``filterVar`` or filtering is
        disabled. With neither ``filterVals`` nor ``filterFunc`` the filter
        variable itself is the condition. Otherwise the condition is
        ``(!(<ignored-handles feature>) || (<filter>))`` where ``<filter>``
        is the value comparisons, the filter function, or both joined by
        ``||``. Exactly one ``if`` is opened in every such case so that
        ``endFilterGuard`` always has a block to close.
        """
        if vulkanType.filterVar is None:
            return

        if not self.doFiltering:
            return

        filterVarAccess = self.getEnvAccessExpr(vulkanType.filterVar)

        filterValsExpr = None
        filterFuncExpr = None

        filterFeature = "%s & VULKAN_STREAM_FEATURE_IGNORED_HANDLES_BIT" % self.featureBitsVar

        if vulkanType.filterVals is not None:
            filterValsExpr = " || ".join(
                ["(%s == %s)" % (filterval, filterVarAccess) for filterval in vulkanType.filterVals])

        if vulkanType.filterFunc is not None:
            filterFuncExpr = self.genFilterFunc(vulkanType.filterFunc, self.currentStructInfo.environment)

        if filterValsExpr is None and filterFuncExpr is None:
            # Assume is bool
            self.cgen.beginIf(filterVarAccess)
            return

        if filterValsExpr is not None and filterFuncExpr is not None:
            filterExpr = "%s || %s" % (filterValsExpr, filterFuncExpr)
        elif filterValsExpr is not None:
            filterExpr = filterValsExpr
        else:
            filterExpr = filterFuncExpr

        self.cgen.beginIf("(!(%s) || (%s))" % (filterFeature, filterExpr))

    def endFilterGuard(self, vulkanType, cleanupExpr=None):
        """Close the guard opened by ``beginFilterGuard``.

        When ``cleanupExpr`` is given it is emitted inside an ``else`` block.
        """
        if vulkanType.filterVar is None:
            return

        if not self.doFiltering:
            return

        self.cgen.endIf()
        if cleanupExpr is not None:
            self.cgen.beginElse()
            self.cgen.stmt(cleanupExpr)
            self.cgen.endElse()

    def getEnvAccessExpr(self, varName):
        """Return the access expression for an environment variable.

        Struct members are accessed by value through the counted object;
        other entries are referred to by their bare name. Returns ``None``
        when ``varName`` has no environment entry.
        """
        parentEnvEntry = self.currentStructInfo.environment.get(varName, None)
        if parentEnvEntry is None:
            return None

        if parentEnvEntry["structmember"]:
            matching = [member for member in self.currentStructInfo.members if member.paramName == varName]
            return self.exprValueAccessor(matching[0])

        return varName

    def getPointerIndirectionLevels(self, varName):
        """Return the pointer depth of a struct-member environment variable.

        Returns 0 for names without an environment entry and for entries
        that are not struct members.
        """
        parentEnvEntry = self.currentStructInfo.environment.get(varName, None)
        if parentEnvEntry is None:
            return 0

        if parentEnvEntry["structmember"]:
            matching = [member for member in self.currentStructInfo.members if member.paramName == varName]
            return matching[0].pointerIndirectionLevels

        return 0

    def onCompoundType(self, vulkanType):
        """Emit a call to ``<prefix><typeName>`` for a struct-typed member.

        With a length expression the call is placed in a ``for`` loop over
        the elements (inside the length guard, if there is one). Bound
        environment variables are appended to the call arguments.
        """
        access = self.exprAccessor(vulkanType)
        lenAccess = self.lenAccessor(vulkanType)
        lenAccessGuard = self.lenAccessorGuard(vulkanType)

        self.beginFilterGuard(vulkanType)

        if vulkanType.pointerIndirectionLevels > 0:
            self.doAllocSpace(vulkanType)

        if lenAccess is not None:
            if lenAccessGuard is not None:
                self.cgen.beginIf(lenAccessGuard)
            loopVar = "i"
            access = "%s + %s" % (access, loopVar)
            forInit = "uint32_t %s = 0" % loopVar
            forCond = "%s < (uint32_t)%s" % (loopVar, lenAccess)
            forIncr = "++%s" % loopVar
            self.cgen.beginFor(forInit, forCond, forIncr)

        accessWithCast = "%s(%s)" % (self.makeCastExpr(
            self.getTypeForStreaming(vulkanType)), access)

        callParams = [self.featureBitsVar,
                      self.rootTypeVar, accessWithCast, self.countVar]

        for localName in vulkanType.binds.values():
            callParams.append(self.getEnvAccessExpr(localName))

        self.cgen.funcCall(None, self.prefix + vulkanType.typeName,
                           callParams)

        if lenAccess is not None:
            self.cgen.endFor()
            if lenAccessGuard is not None:
                self.cgen.endIf()

        self.endFilterGuard(vulkanType)

    def onString(self, vulkanType):
        """Count a ``uint32_t`` plus ``strlen`` of the string (0 when null)."""
        access = self.exprAccessor(vulkanType)
        self.genCount("sizeof(uint32_t) + (%s ? strlen(%s) : 0)" % (access, access))

    def onStringArray(self, vulkanType):
        """Count a ``uint32_t``, then per element a ``uint32_t`` plus ``strlen``."""
        access = self.exprAccessor(vulkanType)
        lenAccess = self.lenAccessor(vulkanType)
        lenAccessGuard = self.lenAccessorGuard(vulkanType)

        self.genCount("sizeof(uint32_t)")
        if lenAccessGuard is not None:
            self.cgen.beginIf(lenAccessGuard)
        self.cgen.beginFor("uint32_t i = 0", "i < %s" % lenAccess, "++i")
        # NOTE(review): `l` is declared in the emitted code but the count
        # statement below recomputes the length; kept so output is unchanged.
        self.cgen.stmt("size_t l = %s[i] ? strlen(%s[i]) : 0" % (access, access))
        self.genCount("sizeof(uint32_t) + (%s[i] ? strlen(%s[i]) : 0)" % (access, access))
        self.cgen.endFor()
        if lenAccessGuard is not None:
            self.cgen.endIf()

    def onStaticArr(self, vulkanType):
        """Count ``<length> * sizeof(<type>)`` for a static array."""
        lenAccess = self.lenAccessor(vulkanType)
        lenAccessGuard = self.lenAccessorGuard(vulkanType)

        # NOTE(review): the guard is opened and closed before the count is
        # emitted, so the count itself is unconditional. The emitted output
        # is kept as is; confirm whether the count was meant to be guarded.
        if lenAccessGuard is not None:
            self.cgen.beginIf(lenAccessGuard)
        finalLenExpr = "%s * %s" % (lenAccess, self.cgen.sizeofExpr(vulkanType))
        if lenAccessGuard is not None:
            self.cgen.endIf()
        self.genCount(finalLenExpr)

    def onStructExtension(self, vulkanType):
        """Emit the ``<prefix>extension_struct`` call for an extension member.

        Before the call, the root type variable is set from the ``sType``
        member if it still holds ``VK_STRUCTURE_TYPE_MAX_ENUM``.
        """
        # Same type description, but addressed under the name "sType".
        sTypeParam = copy(vulkanType)
        sTypeParam.paramName = "sType"

        access = self.exprAccessor(vulkanType)
        sTypeAccess = self.exprAccessor(sTypeParam)

        self.cgen.beginIf("%s == VK_STRUCTURE_TYPE_MAX_ENUM" %
                          self.rootTypeVar)
        self.cgen.stmt("%s = %s" % (self.rootTypeVar, sTypeAccess))
        self.cgen.endIf()

        self.cgen.funcCall(None, self.prefix + "extension_struct",
                           [self.featureBitsVar, self.rootTypeVar, access, self.countVar])

    def onPointer(self, vulkanType):
        """Count the data behind a pointer member.

        Handles (when ``mapHandles`` is set) go through
        ``genHandleMappingCall``. Non-ABI-portable types are counted per
        element with ``genPrimitiveStreamCall``. Everything else is counted
        as ``<length> * sizeof(<value type>)``, or a single ``sizeof`` when
        there is no length expression.
        """
        access = self.exprAccessor(vulkanType)

        lenAccess = self.lenAccessor(vulkanType)
        lenAccessGuard = self.lenAccessorGuard(vulkanType)

        self.beginFilterGuard(vulkanType)
        self.doAllocSpace(vulkanType)

        if vulkanType.filterVar is not None:
            print("onPointer Needs filter: %s filterVar %s" % (access, vulkanType.filterVar))

        if vulkanType.isHandleType() and self.mapHandles:
            self.genHandleMappingCall(vulkanType, access, lenAccess)
        elif self.typeInfo.isNonAbiPortableType(vulkanType.typeName):
            if lenAccess is not None:
                if lenAccessGuard is not None:
                    self.cgen.beginIf(lenAccessGuard)
                self.cgen.beginFor("uint32_t i = 0", "i < (uint32_t)%s" % lenAccess, "++i")
                self.genPrimitiveStreamCall(vulkanType.getForValueAccess())
                self.cgen.endFor()
                if lenAccessGuard is not None:
                    self.cgen.endIf()
            else:
                self.genPrimitiveStreamCall(vulkanType.getForValueAccess())
        else:
            valueSizeExpr = self.cgen.sizeofExpr(vulkanType.getForValueAccess())
            if lenAccess is not None:
                finalLenExpr = "%s * %s" % (lenAccess, valueSizeExpr)
                # The guard only applies when a length is actually used.
                guarded = lenAccessGuard is not None
            else:
                finalLenExpr = "%s" % (valueSizeExpr)
                guarded = False

            if guarded:
                self.cgen.beginIf(lenAccessGuard)
            self.genCount(finalLenExpr)
            if guarded:
                self.cgen.endIf()

        self.endFilterGuard(vulkanType)

    def onValue(self, vulkanType):
        """Count a by-value member.

        Handles (when ``mapHandles`` is set) are counted as one handle,
        non-ABI-portable types via ``genPrimitiveStreamCall``, and all other
        types as ``sizeof`` of the type.
        """
        self.beginFilterGuard(vulkanType)

        if vulkanType.isHandleType() and self.mapHandles:
            access = self.exprAccessor(vulkanType)
            if vulkanType.filterVar is not None:
                print("onValue Needs filter: %s filterVar %s" % (access, vulkanType.filterVar))
            self.genHandleMappingCall(
                vulkanType.getForAddressAccess(), access, "1")
        elif self.typeInfo.isNonAbiPortableType(vulkanType.typeName):
            self.genPrimitiveStreamCall(vulkanType)
        else:
            self.genCount(self.cgen.sizeofExpr(vulkanType))

        self.endFilterGuard(vulkanType)

    def streamLetParameter(self, structInfo, letParamInfo):
        """Declare a ``let`` parameter and count it under the filter feature.

        The variable is declared with initial value 1. Inside an ``if`` on
        the ignored-handles feature it is assigned the translated body
        expression from the current struct environment and its primitive
        size is counted.

        Args:
            structInfo: unused; the body is looked up in
                ``self.currentStructInfo``.
            letParamInfo: type description of the ``let`` parameter.
        """
        filterFeature = "%s & VULKAN_STREAM_FEATURE_IGNORED_HANDLES_BIT" % (self.featureBitsVar)
        self.cgen.stmt("%s %s = 1" % (letParamInfo.typeName, letParamInfo.paramName))

        self.cgen.beginIf(filterFeature)

        environment = self.currentStructInfo.environment
        bodyExpr = environment[letParamInfo.paramName]["body"]
        self.cgen.stmt("%s = %s" % (letParamInfo.paramName, self.genFilterFunc(bodyExpr, environment)))

        self.genPrimitiveStreamCall(letParamInfo)

        self.cgen.endIf()
499
class VulkanCounting(VulkanWrapperGenerator):
    """Generator that emits a ``count_<Type>`` function per struct and union.

    Declarations go to the module header and definitions to the module
    implementation. ``onEnd`` additionally emits the definition of
    ``count_extension_struct``, which dispatches on extension structs.
    """

    def __init__(self, module, typeInfo):
        """Set up the code generator, emitted variable names and prototypes."""
        VulkanWrapperGenerator.__init__(self, module, typeInfo)

        self.codegen = CodeGen()

        self.featureBitsVar = "featureBits"
        self.featureBitsVarType = makeVulkanTypeSimple(False, "uint32_t", 0, self.featureBitsVar)
        self.countingPrefix = "count_"
        # [0]: the object being counted, [1]: the size_t* accumulator.
        self.countVars = ["toCount", "count"]
        self.countVarType = makeVulkanTypeSimple(False, "size_t", 1, self.countVars[1])
        self.voidType = makeVulkanTypeSimple(False, "void", 0)
        self.rootTypeVar = ROOT_TYPE_VAR_NAME

        self.countingCodegen = \
            VulkanCountingCodegen(
                self.codegen,
                self.featureBitsVar,
                self.countVars[0],
                self.countVars[1],
                self.rootTypeVar,
                self.countingPrefix)

        self.knownDefs = {}

        self.extensionCountingPrototype = \
            VulkanAPI(self.countingPrefix + "extension_struct",
                      self.voidType,
                      [self.featureBitsVarType,
                       ROOT_TYPE_PARAM,
                       STRUCT_EXTENSION_PARAM,
                       self.countVarType])

    def onBegin(self,):
        """Forward-declare the extension-struct counting function."""
        VulkanWrapperGenerator.onBegin(self)
        self.module.appendImpl(self.codegen.makeFuncDecl(
            self.extensionCountingPrototype))

    def _emitStructCountingBody(self, cgen, category, structInfo, letParams, doFiltering):
        """Emit the body of one struct/union counting function.

        Args:
            cgen: code generator receiving the body.
            category: ``"struct"`` or ``"union"``.
            structInfo: description of the type being counted.
            letParams: ``let`` parameters, counted before the members.
            doFiltering: value of ``countingCodegen.doFiltering`` to use
                while emitting; the previous value is restored afterwards.
        """
        counting = self.countingCodegen
        counting.cgen = cgen
        counting.currentStructInfo = structInfo

        previousFiltering = counting.doFiltering
        counting.doFiltering = doFiltering
        try:
            # Reference every parameter once in the emitted function body.
            for varName in (self.featureBitsVar, self.rootTypeVar,
                            self.countVars[0], self.countVars[1]):
                cgen.stmt("(void)%s" % varName)

            if category == "struct":
                # marshal 'let' parameters first
                for letp in letParams:
                    counting.streamLetParameter(self.typeInfo, letp)

                for member in structInfo.members:
                    iterateVulkanType(self.typeInfo, member, counting)
            if category == "union":
                # Only the first union member is iterated.
                iterateVulkanType(self.typeInfo, structInfo.members[0], counting)
        finally:
            counting.doFiltering = previousFiltering

    def onGenType(self, typeXml, name, alias):
        """Emit the counting function (or alias) for a struct or union type.

        Aliased types get a function alias in the header. Other structs and
        unions get a declaration and a definition. When the struct
        environment has unbound ("free") variables they are appended to the
        parameter list, and a second overload without them is emitted with
        filtering disabled.
        """
        VulkanWrapperGenerator.onGenType(self, typeXml, name, alias)

        if name in self.knownDefs:
            return

        category = self.typeInfo.categoryOf(name)

        if category not in ("struct", "union"):
            return

        if alias:
            # TODO(liyl): might not work if freeParams != []
            self.module.appendHeader(
                self.codegen.makeFuncAlias(self.countingPrefix + name,
                                           self.countingPrefix + alias))
            return

        structInfo = self.typeInfo.structs[name]

        freeParams = []
        letParams = []

        # Sorted by name so the generated parameter order is deterministic.
        for (envname, bindingInfo) in sorted(structInfo.environment.items(), key = lambda kv: kv[0]):
            if bindingInfo["binding"] is None:
                freeParams.append(makeVulkanTypeSimple(True, bindingInfo["type"], 0, envname))
            elif not bindingInfo["structmember"]:
                letParams.append(makeVulkanTypeSimple(True, bindingInfo["type"], 0, envname))

        countingParams = \
            [makeVulkanTypeSimple(False, "uint32_t", 0, self.featureBitsVar),
             ROOT_TYPE_PARAM,
             makeVulkanTypeSimple(True, name, 1, self.countVars[0]),
             makeVulkanTypeSimple(False, "size_t", 1, self.countVars[1])]

        countingPrototype = \
            VulkanAPI(self.countingPrefix + name,
                      self.voidType,
                      countingParams + freeParams)

        countingPrototypeNoFilter = \
            VulkanAPI(self.countingPrefix + name,
                      self.voidType,
                      countingParams)

        def structCountingDef(cgen):
            self._emitStructCountingBody(cgen, category, structInfo, letParams, True)

        def structCountingDefNoFilter(cgen):
            self._emitStructCountingBody(cgen, category, structInfo, letParams, False)

        self.module.appendHeader(
            self.codegen.makeFuncDecl(countingPrototype))
        self.module.appendImpl(
            self.codegen.makeFuncImpl(countingPrototype, structCountingDef))

        if freeParams:
            # Uses self.codegen like every other emission in this class; the
            # class never assigns separate header/impl generators.
            self.module.appendHeader(
                self.codegen.makeFuncDecl(countingPrototypeNoFilter))
            self.module.appendImpl(
                self.codegen.makeFuncImpl(
                    countingPrototypeNoFilter, structCountingDefNoFilter))

    def onGenCmd(self, cmdinfo, name, alias):
        """Delegate to the base class; commands produce no counting code."""
        VulkanWrapperGenerator.onGenCmd(self, cmdinfo, name, alias)

    def doExtensionStructCountCodegen(self, cgen, extParam, forEach, funcproto):
        """Emit the body of the extension-struct counting function.

        The emitted code looks up the extension size. For a non-null struct
        of size 0 (unknown extension) it calls itself on ``pNext`` and
        returns. Otherwise it counts a ``uint32_t``, returns if the size is
        0 (null extension), counts a ``VkStructureType`` and dispatches over
        the known extension structs via ``emitForEachStructExtension``.

        Args:
            cgen: code generator receiving the body.
            extParam: parameter describing the extension struct pointer.
            forEach: callback emitting the per-extension call.
            funcproto: prototype of the function being emitted; its name is
                used for the call on ``pNext``.
        """
        accessVar = "structAccess"
        sizeVar = "currExtSize"
        cgen.stmt("VkInstanceCreateInfo* %s = (VkInstanceCreateInfo*)(%s)" % (accessVar, extParam.paramName))
        cgen.stmt("size_t %s = %s(%s, %s, %s)" % (sizeVar, EXTENSION_SIZE_WITH_STREAM_FEATURES_API_NAME,
                                                  self.featureBitsVar, ROOT_TYPE_VAR_NAME, extParam.paramName))

        cgen.beginIf("!%s && %s" % (sizeVar, extParam.paramName))

        cgen.line("// unknown struct extension; skip and call on its pNext field")
        cgen.funcCall(None, funcproto.name, [
                      self.featureBitsVar, ROOT_TYPE_VAR_NAME, "(void*)%s->pNext" % accessVar, self.countVars[1]])
        cgen.stmt("return")

        cgen.endIf()
        cgen.beginElse()

        cgen.line("// known or null extension struct")

        cgen.stmt("*%s += sizeof(uint32_t)" % self.countVars[1])

        cgen.beginIf("!%s" % (sizeVar))
        cgen.line("// exit if this was a null extension struct (size == 0 in this branch)")
        cgen.stmt("return")
        cgen.endIf()

        # Closes the else-block opened above.
        cgen.endElse()

        cgen.stmt("*%s += sizeof(VkStructureType)" % self.countVars[1])

        def fatalDefault(cgen):
            cgen.line("// fatal; the switch is only taken if the extension struct is known")
            cgen.stmt("abort()")

        self.emitForEachStructExtension(
            cgen,
            makeVulkanTypeSimple(False, "void", 0, "void"),
            extParam,
            forEach,
            defaultEmit=fatalDefault,
            rootTypeVar=ROOT_TYPE_PARAM)

    def onEnd(self,):
        """Emit the definition of the extension-struct counting function."""
        VulkanWrapperGenerator.onEnd(self)

        def forEachExtensionCounting(ext, castedAccess, cgen):
            cgen.funcCall(None, self.countingPrefix + ext.name,
                          [self.featureBitsVar, self.rootTypeVar, castedAccess, self.countVars[1]])

        self.module.appendImpl(
            self.codegen.makeFuncImpl(
                self.extensionCountingPrototype,
                lambda cgen: self.doExtensionStructCountCodegen(
                    cgen,
                    STRUCT_EXTENSION_PARAM,
                    forEachExtensionCounting,
                    self.extensionCountingPrototype)))
699