• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright (c) 2018 The Android Open Source Project
# Copyright (c) 2018 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
16from copy import copy
17
18from .common.codegen import CodeGen
19from .common.vulkantypes import \
20        VulkanAPI, makeVulkanTypeSimple, iterateVulkanType, VulkanTypeIterator, Atom, FuncExpr, FuncExprVal, FuncLambda
21
22from .wrapperdefs import VulkanWrapperGenerator
23from .wrapperdefs import ROOT_TYPE_VAR_NAME, ROOT_TYPE_PARAM
24from .wrapperdefs import STRUCT_EXTENSION_PARAM, STRUCT_EXTENSION_PARAM_FOR_WRITE, EXTENSION_SIZE_WITH_STREAM_FEATURES_API_NAME
25
class VulkanCountingCodegen(VulkanTypeIterator):
    """Type iterator that emits C code accumulating the encoded size of a Vulkan value.

    While a Vulkan type is walked, each ``on*`` callback emits C statements of
    the form ``*<countVar> += <bytes>;`` (see ``genCount``) into ``self.cgen``.
    Nothing is serialized here; only byte counts are summed.

    NOTE(review): the counts presumably have to agree with what the marshaling
    codegen writes. Several methods keep the same shape as the marshaling calls
    they stand in for (e.g. ``onStaticArr``, ``genHandleMappingCall``); change
    them together.

    NOTE(review): ``self.typeInfo`` and ``self.currentStructInfo`` are not set
    in ``__init__``; they are expected to be assigned by the caller / iterator
    machinery before any callback runs — confirm.
    """

    def __init__(self, cgen, featureBitsVar, toCountVar, countVar, rootTypeVar, prefix,
                 forApiOutput=False, mapHandles=True, handleMapOverwrites=False,
                 doFiltering=True):
        """Create a counting codegen.

        Args:
            cgen: code generator the C statements are emitted into.
            featureBitsVar: C name of the stream feature-bits variable.
            toCountVar: C name of the pointer to the object being sized.
            countVar: C name of the pointer accumulator (``*countVar += ...``).
            rootTypeVar: C name of the root structure-type variable.
            prefix: prefix of the per-type counting functions that are called
                for nested compound types and extension structs.
            forApiOutput: when True, ``onCheck`` emits nothing.
            mapHandles: when True, handles are counted by ``genHandleMappingCall``.
            handleMapOverwrites: selects the ``8 * len`` form in
                ``genHandleMappingCall``.
            doFiltering: when False, filter guards are not emitted.
        """
        self.cgen = cgen
        self.featureBitsVar = featureBitsVar
        self.toCountVar = toCountVar
        self.rootTypeVar = rootTypeVar
        self.countVar = countVar
        self.prefix = prefix
        self.forApiOutput = forApiOutput

        # The accessors read self.cgen lazily, so reassigning self.cgen later
        # (one code generator per emitted function) is picked up automatically.
        self.exprAccessor = lambda t: self.cgen.generalAccess(t, parentVarName=self.toCountVar, asPtr=True)
        self.exprValueAccessor = lambda t: self.cgen.generalAccess(t, parentVarName=self.toCountVar, asPtr=False)
        self.exprPrimitiveValueAccessor = lambda t: self.cgen.generalAccess(t, parentVarName=self.toCountVar, asPtr=False)

        self.lenAccessor = lambda t: self.cgen.generalLengthAccess(t, parentVarName=self.toCountVar)
        self.lenAccessorGuard = lambda t: self.cgen.generalLengthAccessGuard(t, parentVarName=self.toCountVar)
        self.filterVarAccessor = lambda t: self.cgen.filterVarAccess(t, parentVarName=self.toCountVar)

        # True while an ``if`` opened by onCheck() is still open; endCheck() closes it.
        self.checked = False

        self.mapHandles = mapHandles
        self.handleMapOverwrites = handleMapOverwrites
        self.doFiltering = doFiltering

    def getTypeForStreaming(self, vulkanType):
        """Return a copy of ``vulkanType`` adjusted for being passed by address.

        ``getForAddressAccess()`` is applied once if the type is not accessible
        as a pointer, and once more if it is a static array.
        """
        res = copy(vulkanType)

        if not vulkanType.accessibleAsPointer():
            res = res.getForAddressAccess()

        if vulkanType.staticArrExpr:
            res = res.getForAddressAccess()

        return res

    def makeCastExpr(self, vulkanType):
        """Return a C cast prefix such as ``(const VkFoo*)`` for ``vulkanType``."""
        return "(%s)" % (
            self.cgen.makeCTypeDecl(vulkanType, useParamName=False))

    def genCount(self, sizeExpr):
        """Emit ``*<countVar> += <sizeExpr>;``."""
        self.cgen.stmt("*%s += %s" % (self.countVar, sizeExpr))

    def genPrimitiveStreamCall(self, vulkanType):
        """Count a primitive using the size reported by ``cgen.countPrimitive``."""
        self.genCount(str(self.cgen.countPrimitive(
            self.typeInfo,
            vulkanType)))

    def genHandleMappingCall(self, vulkanType, access, lenAccess):
        """Count handle(s) that are streamed as 64-bit values (8 bytes each).

        When ``lenAccess`` is None a single handle is counted. Any value other
        than the literal ``"1"`` is treated as a C element-count expression and
        the count is wrapped in ``if (<lenAccess>)``.
        """
        if lenAccess is None:
            lenAccess = "1"
            handle64Bytes = "8"
        else:
            handle64Bytes = "%s * 8" % lenAccess

        handle64Var = self.cgen.var()
        isSingleHandle = lenAccess == "1"
        if isSingleHandle:
            # NOTE(review): this scratch variable is not referenced by the
            # counting code; it is kept so the generated output matches the
            # shape of the marshaling code.
            self.cgen.stmt("uint64_t %s" % handle64Var)
        else:
            self.cgen.beginIf(lenAccess)

        if self.handleMapOverwrites:
            self.genCount("8 * %s" % lenAccess)
        else:
            self.genCount(handle64Bytes)

        if not isSingleHandle:
            self.cgen.endIf()

    def doAllocSpace(self, vulkanType):
        """No-op: counting never allocates (kept for interface parity)."""
        pass

    def getOptionalStringFeatureExpr(self, vulkanType):
        """Return ``"<featureBits> & <feature>"`` for a feature-protected type, else None."""
        feature = vulkanType.getProtectStreamFeature()
        if feature is None:
            return None
        return "%s & %s" % (self.featureBitsVar, feature)

    def onCheck(self, vulkanType):
        """Count an optional pointer's presence value and open a guard on it.

        The value is counted via ``genPrimitiveStreamCall`` (only when the
        protecting stream feature is on, if there is one). Then an ``if`` is
        opened so the pointee is counted only when the pointer is non-null. If
        a protecting feature exists and is off, the pointee is always counted.
        The ``if`` stays open until ``endCheck``.
        """
        if self.forApiOutput:
            return

        featureExpr = self.getOptionalStringFeatureExpr(vulkanType)

        self.checked = True

        access = self.exprAccessor(vulkanType)

        self.cgen.line("// WARNING PTR CHECK")

        if featureExpr is not None:
            self.cgen.beginIf(featureExpr)

        self.genPrimitiveStreamCall(vulkanType)

        if featureExpr is not None:
            self.cgen.endIf()

        # Left open on purpose; closed by endCheck().
        if featureExpr is not None:
            self.cgen.beginIf("(!(%s) || %s)" % (featureExpr, access))
        else:
            self.cgen.beginIf(access)

    def onCheckWithNullOptionalStringFeature(self, vulkanType):
        """Run ``onCheck`` inside ``if (featureBits & ..._NULL_OPTIONAL_STRINGS_BIT)``.

        The outer ``if`` is closed (and an ``else`` opened) by
        ``endCheckWithNullOptionalStringFeature``; the ``else`` is closed by
        ``finalCheckWithNullOptionalStringFeature``.
        """
        self.cgen.beginIf("%s & VULKAN_STREAM_FEATURE_NULL_OPTIONAL_STRINGS_BIT" % self.featureBitsVar)
        self.onCheck(vulkanType)

    def endCheckWithNullOptionalStringFeature(self, vulkanType):
        """Close the check guard and the feature ``if``, then open its ``else``."""
        self.endCheck(vulkanType)
        self.cgen.endIf()
        self.cgen.beginElse()

    def finalCheckWithNullOptionalStringFeature(self, vulkanType):
        """Close the ``else`` opened by ``endCheckWithNullOptionalStringFeature``."""
        self.cgen.endElse()

    def endCheck(self, vulkanType):
        """Close the ``if`` opened by ``onCheck``, if one is open."""
        if self.checked:
            self.cgen.endIf()
            self.checked = False

    def genFilterFunc(self, filterfunc, env):
        """Translate a filter expression tree into a C++ expression string.

        Supported nodes:

        * ``FuncExpr``: built-ins ``not``/``eq``/``and``/``or``/``bitwise_and``/
          ``getfield``/``if``; any other name becomes a plain call.
        * ``FuncLambda``: becomes a capture-less C++ lambda.
        * ``FuncExprVal``: an atom or literal.

        Names bound by an enclosing lambda are emitted verbatim and shadow
        ``env``. Names found in ``env`` are rewritten via ``getEnvAccessExpr``.
        Any other atom is emitted as-is.

        Raises:
            TypeError: for a node that is none of the types above.
        """
        binaryOps = {
            "eq": "==",
            "and": "&&",
            "or": "||",
            "bitwise_and": "&",
        }

        def isLambdaBound(atomname, lambdaEnv):
            return lambdaEnv.get(atomname, None) is not None

        def get_ptrlevels(atomname, lambdaEnv):
            # Lambda parameters are plain values; only environment members
            # can carry pointer indirection.
            if isLambdaBound(atomname, lambdaEnv):
                return 0
            if env.get(atomname, None) is not None:
                return self.getPointerIndirectionLevels(atomname)
            return 0

        def do_expratom(atomname, lambdaEnv):
            if isLambdaBound(atomname, lambdaEnv):
                return atomname
            if env.get(atomname, None) is not None:
                return self.getEnvAccessExpr(atomname)
            return atomname

        def do_func(expr, lambdaEnv):
            fnamestr = expr.name.name
            args = expr.args

            if fnamestr == "not":
                return "!(%s)" % loop(args[0], lambdaEnv)
            if fnamestr in binaryOps:
                return "(%s %s %s)" % (
                    loop(args[0], lambdaEnv), binaryOps[fnamestr], loop(args[1], lambdaEnv))
            if fnamestr == "getfield":
                ptrlevels = get_ptrlevels(args[0].val.name, lambdaEnv)
                if ptrlevels == 0:
                    return "%s.%s" % (loop(args[0], lambdaEnv), args[1].val)
                # Dereference as many times as the member is a pointer.
                return "(%s(%s)).%s" % ("*" * ptrlevels, loop(args[0], lambdaEnv), args[1].val)
            if fnamestr == "if":
                return "((%s) ? (%s) : (%s))" % (
                    loop(args[0], lambdaEnv), loop(args[1], lambdaEnv), loop(args[2], lambdaEnv))

            return "%s(%s)" % (fnamestr, ", ".join(loop(e, lambdaEnv) for e in args))

        def do_exprval(expr, lambdaEnv):
            expratom = expr.val
            if type(expratom) is Atom:
                return do_expratom(expratom.name, lambdaEnv)
            return "%s" % expratom

        def do_lambda(expr, lambdaEnv):
            params = expr.vs
            # Inner bindings extend (and shadow) the enclosing lambda's bindings.
            newEnv = dict(lambdaEnv)
            for p in params:
                newEnv[p.name] = p.typ
            paramDecls = ", ".join("%s %s" % (p.typ, p.name) for p in params)
            return "[](%s) { return %s; }" % (paramDecls, loop(expr.body, newEnv))

        def loop(expr, lambdaEnv):
            exprType = type(expr)
            if exprType is FuncExpr:
                return do_func(expr, lambdaEnv)
            if exprType is FuncLambda:
                return do_lambda(expr, lambdaEnv)
            if exprType is FuncExprVal:
                return do_exprval(expr, lambdaEnv)
            raise TypeError("genFilterFunc: unsupported filter expression node %r" % (expr,))

        return loop(filterfunc, {})

    def beginFilterGuard(self, vulkanType):
        """Open a C ``if`` that skips counting a member that is filtered out.

        Nothing is emitted when the type has no ``filterVar`` or filtering is
        disabled. With neither ``filterVals`` nor ``filterFunc``, the filter
        variable is treated as a boolean. Otherwise the guard is active only
        when VULKAN_STREAM_FEATURE_IGNORED_HANDLES_BIT is set. Must be paired
        with ``endFilterGuard``.
        """
        if vulkanType.filterVar is None:
            return

        if not self.doFiltering:
            return

        filterVarAccess = self.getEnvAccessExpr(vulkanType.filterVar)

        filterValsExpr = None
        filterFuncExpr = None

        filterFeature = "%s & VULKAN_STREAM_FEATURE_IGNORED_HANDLES_BIT" % self.featureBitsVar

        if vulkanType.filterVals is not None:
            filterValsExpr = " || ".join(
                "(%s == %s)" % (filterval, filterVarAccess) for filterval in vulkanType.filterVals)

        if vulkanType.filterFunc is not None:
            filterFuncExpr = self.genFilterFunc(vulkanType.filterFunc, self.currentStructInfo.environment)

        if filterValsExpr is not None and filterFuncExpr is not None:
            filterExpr = "%s || %s" % (filterValsExpr, filterFuncExpr)
        elif filterValsExpr is not None:
            filterExpr = filterValsExpr
        else:
            filterExpr = filterFuncExpr

        if filterExpr is None:
            # Assume is bool
            self.cgen.beginIf(filterVarAccess)
        else:
            # Every non-trivial filter opens exactly one ``if`` so that
            # endFilterGuard's endIf() always has a matching block.
            self.cgen.beginIf("(!(%s) || (%s))" % (filterFeature, filterExpr))

    def endFilterGuard(self, vulkanType, cleanupExpr=None):
        """Close the guard opened by ``beginFilterGuard``.

        If ``cleanupExpr`` is given, it is emitted in an ``else`` branch.
        """
        if vulkanType.filterVar is None:
            return

        if not self.doFiltering:
            return

        self.cgen.endIf()
        if cleanupExpr is not None:
            self.cgen.beginElse()
            self.cgen.stmt(cleanupExpr)
            self.cgen.endElse()

    def _memberNamed(self, varName):
        """Return the first member of the current struct whose paramName is ``varName``.

        Raises IndexError if there is none.
        """
        return [member for member in self.currentStructInfo.members
                if member.paramName == varName][0]

    def getEnvAccessExpr(self, varName):
        """Return a C expression for environment variable ``varName``, or None.

        Struct members are accessed through ``toCountVar``; other
        environment entries are referenced by name.
        """
        parentEnvEntry = self.currentStructInfo.environment.get(varName, None)

        if parentEnvEntry is None:
            return None

        if parentEnvEntry["structmember"]:
            return self.exprValueAccessor(self._memberNamed(varName))
        return varName

    def getPointerIndirectionLevels(self, varName):
        """Return the pointer indirection level of environment member ``varName``.

        Non-member or unknown names yield 0.
        """
        parentEnvEntry = self.currentStructInfo.environment.get(varName, None)

        if parentEnvEntry is not None and parentEnvEntry["structmember"]:
            return self._memberNamed(varName).pointerIndirectionLevels

        return 0

    def onCompoundType(self, vulkanType):
        """Count a struct/union member by calling its ``<prefix><TypeName>`` function.

        Members with a length expression are counted element by element in a
        generated ``for`` loop, optionally behind the length guard.
        """
        access = self.exprAccessor(vulkanType)
        lenAccess = self.lenAccessor(vulkanType)
        lenAccessGuard = self.lenAccessorGuard(vulkanType)

        self.beginFilterGuard(vulkanType)

        if vulkanType.pointerIndirectionLevels > 0:
            self.doAllocSpace(vulkanType)

        if lenAccess is not None:
            if lenAccessGuard is not None:
                self.cgen.beginIf(lenAccessGuard)
            loopVar = "i"
            access = "%s + %s" % (access, loopVar)
            forInit = "uint32_t %s = 0" % loopVar
            forCond = "%s < (uint32_t)%s" % (loopVar, lenAccess)
            forIncr = "++%s" % loopVar
            self.cgen.beginFor(forInit, forCond, forIncr)

        accessWithCast = "%s(%s)" % (self.makeCastExpr(
            self.getTypeForStreaming(vulkanType)), access)

        callParams = [self.featureBitsVar,
                      self.rootTypeVar, accessWithCast, self.countVar]

        # Bound environment values are passed after the standard arguments
        # (presumably filling the callee's free parameters).
        for localName in vulkanType.binds.values():
            callParams.append(self.getEnvAccessExpr(localName))

        self.cgen.funcCall(None, self.prefix + vulkanType.typeName,
                           callParams)

        if lenAccess is not None:
            self.cgen.endFor()
            if lenAccessGuard is not None:
                self.cgen.endIf()

        self.endFilterGuard(vulkanType)

    def onString(self, vulkanType):
        """Count a C string: a ``uint32_t`` plus its ``strlen`` (0 when null)."""
        access = self.exprAccessor(vulkanType)
        self.genCount("sizeof(uint32_t) + (%s ? strlen(%s) : 0)" % (access, access))

    def onStringArray(self, vulkanType):
        """Count an array of C strings.

        The count is a leading ``uint32_t``, then for each element a
        ``uint32_t`` plus its ``strlen`` (0 for a null entry).
        """
        access = self.exprAccessor(vulkanType)
        lenAccess = self.lenAccessor(vulkanType)
        lenAccessGuard = self.lenAccessorGuard(vulkanType)

        self.genCount("sizeof(uint32_t)")
        if lenAccessGuard is not None:
            self.cgen.beginIf(lenAccessGuard)
        self.cgen.beginFor("uint32_t i = 0", "i < %s" % lenAccess, "++i")
        # Compute each element's length once and reuse it, so strlen runs a
        # single time and the generated local is not dead.
        self.cgen.stmt("size_t l = %s[i] ? strlen(%s[i]) : 0" % (access, access))
        self.genCount("sizeof(uint32_t) + l")
        self.cgen.endFor()
        if lenAccessGuard is not None:
            self.cgen.endIf()

    def onStaticArr(self, vulkanType):
        """Count a fixed-size array as ``<len> * sizeof(element)``.

        NOTE(review): the length guard wraps no statements and the count is
        added unconditionally. This keeps the same shape as the marshaling
        code, which presumably writes the array unconditionally too — keep
        them in sync before changing it.
        """
        lenAccess = self.lenAccessor(vulkanType)
        lenAccessGuard = self.lenAccessorGuard(vulkanType)

        if lenAccessGuard is not None:
            self.cgen.beginIf(lenAccessGuard)
        finalLenExpr = "%s * %s" % (lenAccess, self.cgen.sizeofExpr(vulkanType))
        if lenAccessGuard is not None:
            self.cgen.endIf()
        self.genCount(finalLenExpr)

    def onStructExtension(self, vulkanType):
        """Count a ``pNext`` chain by delegating to ``<prefix>extension_struct``.

        The root structure type is latched from this struct's ``sType`` the
        first time (while it is still VK_STRUCTURE_TYPE_MAX_ENUM).
        """
        sTypeParam = copy(vulkanType)
        sTypeParam.paramName = "sType"

        access = self.exprAccessor(vulkanType)
        sTypeAccess = self.exprAccessor(sTypeParam)

        self.cgen.beginIf("%s == VK_STRUCTURE_TYPE_MAX_ENUM" %
                          self.rootTypeVar)
        self.cgen.stmt("%s = %s" % (self.rootTypeVar, sTypeAccess))
        self.cgen.endIf()

        self.cgen.funcCall(None, self.prefix + "extension_struct",
                           [self.featureBitsVar, self.rootTypeVar, access, self.countVar])

    def onPointer(self, vulkanType):
        """Count the pointee(s) of a non-string pointer member.

        * Handles go through ``genHandleMappingCall``.
        * Non-ABI-portable types are counted one element at a time.
        * Everything else is counted as ``[<len> *] sizeof(element)``.
        """
        access = self.exprAccessor(vulkanType)

        lenAccess = self.lenAccessor(vulkanType)
        lenAccessGuard = self.lenAccessorGuard(vulkanType)

        self.beginFilterGuard(vulkanType)
        self.doAllocSpace(vulkanType)

        if vulkanType.filterVar is not None:
            print("onPointer Needs filter: %s filterVar %s" % (access, vulkanType.filterVar))

        if vulkanType.isHandleType() and self.mapHandles:
            self.genHandleMappingCall(vulkanType, access, lenAccess)
        elif self.typeInfo.isNonAbiPortableType(vulkanType.typeName):
            if lenAccess is not None:
                if lenAccessGuard is not None:
                    self.cgen.beginIf(lenAccessGuard)
                self.cgen.beginFor("uint32_t i = 0", "i < (uint32_t)%s" % lenAccess, "++i")
                self.genPrimitiveStreamCall(vulkanType.getForValueAccess())
                self.cgen.endFor()
                if lenAccessGuard is not None:
                    self.cgen.endIf()
            else:
                self.genPrimitiveStreamCall(vulkanType.getForValueAccess())
        else:
            elemSize = self.cgen.sizeofExpr(vulkanType.getForValueAccess())
            if lenAccess is not None:
                finalLenExpr = "%s * %s" % (lenAccess, elemSize)
                guard = lenAccessGuard
            else:
                finalLenExpr = "%s" % elemSize
                guard = None
            if guard is not None:
                self.cgen.beginIf(guard)
            self.genCount(finalLenExpr)
            if guard is not None:
                self.cgen.endIf()

        self.endFilterGuard(vulkanType)

    def onValue(self, vulkanType):
        """Count a by-value member.

        * Handles are counted as one 64-bit handle.
        * Non-ABI-portable types use ``countPrimitive``.
        * Everything else is counted as ``sizeof``.
        """
        self.beginFilterGuard(vulkanType)

        if vulkanType.isHandleType() and self.mapHandles:
            access = self.exprAccessor(vulkanType)
            if vulkanType.filterVar is not None:
                print("onValue Needs filter: %s filterVar %s" % (access, vulkanType.filterVar))
            self.genHandleMappingCall(
                vulkanType.getForAddressAccess(), access, "1")
        elif self.typeInfo.isNonAbiPortableType(vulkanType.typeName):
            self.genPrimitiveStreamCall(vulkanType)
        else:
            self.genCount(self.cgen.sizeofExpr(vulkanType))

        self.endFilterGuard(vulkanType)

    def streamLetParameter(self, structInfo, letParamInfo):
        """Declare a 'let' parameter and count it.

        The generated code declares the parameter initialized to 1. When
        VULKAN_STREAM_FEATURE_IGNORED_HANDLES_BIT is set, it assigns the
        parameter from its binding body and counts it as a primitive.

        ``structInfo`` is unused; the binding body is read from
        ``self.currentStructInfo``.
        """
        filterFeature = "%s & VULKAN_STREAM_FEATURE_IGNORED_HANDLES_BIT" % (self.featureBitsVar)
        self.cgen.stmt("%s %s = 1" % (letParamInfo.typeName, letParamInfo.paramName))

        self.cgen.beginIf(filterFeature)

        bodyExpr = self.currentStructInfo.environment[letParamInfo.paramName]["body"]
        self.cgen.stmt("%s = %s" % (letParamInfo.paramName, self.genFilterFunc(bodyExpr, self.currentStructInfo.environment)))

        self.genPrimitiveStreamCall(letParamInfo)

        self.cgen.endIf()

class VulkanCounting(VulkanWrapperGenerator):
    """Wrapper generator that emits ``count_<Type>`` C functions for Vulkan structs/unions.

    Each emitted function adds the encoded size of one object to ``*count``.
    A generic ``count_extension_struct`` handles ``pNext`` chains by
    dispatching on the extension struct type.
    """

    def __init__(self, module, typeInfo):
        VulkanWrapperGenerator.__init__(self, module, typeInfo)

        self.codegen = CodeGen()

        self.featureBitsVar = "featureBits"
        self.featureBitsVarType = makeVulkanTypeSimple(False, "uint32_t", 0, self.featureBitsVar)
        self.countingPrefix = "count_"
        # [0]: pointer to the object being sized; [1]: size_t* accumulator.
        self.countVars = ["toCount", "count"]
        self.countVarType = makeVulkanTypeSimple(False, "size_t", 1, self.countVars[1])
        self.voidType = makeVulkanTypeSimple(False, "void", 0)
        self.rootTypeVar = ROOT_TYPE_VAR_NAME

        self.countingCodegen = \
            VulkanCountingCodegen(
                self.codegen,
                self.featureBitsVar,
                self.countVars[0],
                self.countVars[1],
                self.rootTypeVar,
                self.countingPrefix)

        self.knownDefs = {}

        self.extensionCountingPrototype = \
            VulkanAPI(self.countingPrefix + "extension_struct",
                      self.voidType,
                      [self.featureBitsVarType,
                       ROOT_TYPE_PARAM,
                       STRUCT_EXTENSION_PARAM,
                       self.countVarType])

    def onBegin(self):
        """Forward-declare ``count_extension_struct`` in the implementation file."""
        VulkanWrapperGenerator.onBegin(self)
        self.module.appendImpl(self.codegen.makeFuncDecl(
            self.extensionCountingPrototype))

    def onGenType(self, typeXml, name, alias):
        """Emit the counting function(s) for struct/union type ``name``.

        * Aliases get a function alias to the aliased type's counter.
        * Other structs/unions get a declaration plus an implementation.
        * Types whose environment has free (unbound) parameters also get an
          overload without them that skips filter guards.
        """
        VulkanWrapperGenerator.onGenType(self, typeXml, name, alias)

        if name in self.knownDefs:
            return

        category = self.typeInfo.categoryOf(name)

        if category not in ["struct", "union"]:
            return

        if alias:
            # TODO(liyl): might not work if freeParams != []
            self.module.appendHeader(
                self.codegen.makeFuncAlias(self.countingPrefix + name,
                                           self.countingPrefix + alias))
            return

        structInfo = self.typeInfo.structs[name]

        # Free params: environment entries with no binding (extra function parameters).
        # Let params: bound entries that are not struct members (computed locals).
        freeParams = []
        letParams = []

        for (envname, bindingInfo) in sorted(structInfo.environment.items(), key=lambda kv: kv[0]):
            if bindingInfo["binding"] is None:
                freeParams.append(makeVulkanTypeSimple(True, bindingInfo["type"], 0, envname))
            elif not bindingInfo["structmember"]:
                letParams.append(makeVulkanTypeSimple(True, bindingInfo["type"], 0, envname))

        countingParams = \
            [makeVulkanTypeSimple(False, "uint32_t", 0, self.featureBitsVar),
             ROOT_TYPE_PARAM,
             makeVulkanTypeSimple(True, name, 1, self.countVars[0]),
             makeVulkanTypeSimple(False, "size_t", 1, self.countVars[1])]

        countingPrototype = \
            VulkanAPI(self.countingPrefix + name,
                      self.voidType,
                      countingParams + freeParams)

        def emitCountingBody(cgen, doFiltering):
            # Shared body of both variants; only filter-guard emission differs.
            codegen = self.countingCodegen
            codegen.cgen = cgen
            codegen.currentStructInfo = structInfo
            savedFiltering = codegen.doFiltering
            codegen.doFiltering = doFiltering
            try:
                # Silence unused-parameter warnings in the generated C.
                for unusedVar in (self.featureBitsVar, self.rootTypeVar,
                                  self.countVars[0], self.countVars[1]):
                    cgen.stmt("(void)%s" % unusedVar)

                if category == "struct":
                    # marshal 'let' parameters first
                    for letp in letParams:
                        codegen.streamLetParameter(self.typeInfo, letp)

                    for member in structInfo.members:
                        iterateVulkanType(self.typeInfo, member, codegen)
                if category == "union":
                    iterateVulkanType(self.typeInfo, structInfo.members[0], codegen)
            finally:
                # Restore even if iteration raises, so later types are unaffected.
                codegen.doFiltering = savedFiltering

        self.module.appendHeader(
            self.codegen.makeFuncDecl(countingPrototype))
        self.module.appendImpl(
            self.codegen.makeFuncImpl(
                countingPrototype, lambda cgen: emitCountingBody(cgen, True)))

        if freeParams:
            countingPrototypeNoFilter = \
                VulkanAPI(self.countingPrefix + name,
                          self.voidType,
                          countingParams)
            # Use the same code generator as above; this class defines no
            # separate header/impl generators.
            self.module.appendHeader(
                self.codegen.makeFuncDecl(countingPrototypeNoFilter))
            self.module.appendImpl(
                self.codegen.makeFuncImpl(
                    countingPrototypeNoFilter,
                    lambda cgen: emitCountingBody(cgen, False)))

    def onGenCmd(self, cmdinfo, name, alias):
        VulkanWrapperGenerator.onGenCmd(self, cmdinfo, name, alias)

    def doExtensionStructCountCodegen(self, cgen, extParam, forEach, funcproto):
        """Emit the body of the generic extension-struct counting function.

        The generated code does the following, in order:

        1. Queries EXTENSION_SIZE_WITH_STREAM_FEATURES_API_NAME for the struct
           size. A non-null struct of size 0 (unknown) is skipped by recursing
           into ``funcproto`` on its ``pNext``.
        2. Otherwise counts a ``uint32_t`` and returns if the struct is null.
        3. Counts the ``VkStructureType``.
        4. Dispatches on the struct type to ``forEach``. The default case aborts.
        """
        accessVar = "structAccess"
        sizeVar = "currExtSize"
        # Only ->pNext is read through this cast (presumably any extension
        # struct shares the sType/pNext header layout).
        cgen.stmt("VkInstanceCreateInfo* %s = (VkInstanceCreateInfo*)(%s)" % (accessVar, extParam.paramName))
        cgen.stmt("size_t %s = %s(%s, %s, %s)" % (sizeVar, EXTENSION_SIZE_WITH_STREAM_FEATURES_API_NAME,
                                                  self.featureBitsVar, ROOT_TYPE_VAR_NAME, extParam.paramName))

        cgen.beginIf("!%s && %s" % (sizeVar, extParam.paramName))

        cgen.line("// unknown struct extension; skip and call on its pNext field")
        cgen.funcCall(None, funcproto.name, [
                      self.featureBitsVar, ROOT_TYPE_VAR_NAME, "(void*)%s->pNext" % accessVar, self.countVars[1]])
        cgen.stmt("return")

        cgen.endIf()
        cgen.beginElse()

        cgen.line("// known or null extension struct")

        cgen.stmt("*%s += sizeof(uint32_t)" % self.countVars[1])

        cgen.beginIf("!%s" % (sizeVar))
        cgen.line("// exit if this was a null extension struct (size == 0 in this branch)")
        cgen.stmt("return")
        cgen.endIf()

        # Closes the else block opened above.
        cgen.endIf()

        cgen.stmt("*%s += sizeof(VkStructureType)" % self.countVars[1])

        def fatalDefault(cgen):
            cgen.line("// fatal; the switch is only taken if the extension struct is known")
            cgen.stmt("abort()")

        self.emitForEachStructExtension(
            cgen,
            makeVulkanTypeSimple(False, "void", 0, "void"),
            extParam,
            forEach,
            defaultEmit=fatalDefault,
            rootTypeVar=ROOT_TYPE_PARAM)

    def onEnd(self):
        """Emit the implementation of ``count_extension_struct``."""
        VulkanWrapperGenerator.onEnd(self)

        def forEachExtensionCounting(ext, castedAccess, cgen):
            cgen.funcCall(None, self.countingPrefix + ext.name,
                          [self.featureBitsVar, self.rootTypeVar, castedAccess, self.countVars[1]])

        self.module.appendImpl(
            self.codegen.makeFuncImpl(
                self.extensionCountingPrototype,
                lambda cgen: self.doExtensionStructCountCodegen(
                    cgen,
                    STRUCT_EXTENSION_PARAM,
                    forEachExtensionCounting,
                    self.extensionCountingPrototype)))