• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 Google LLC
2# SPDX-License-Identifier: MIT
3
4from copy import copy
5import hashlib, sys
6
7from .common.codegen import CodeGen, VulkanAPIWrapper
8from .common.vulkantypes import \
9        VulkanAPI, makeVulkanTypeSimple, iterateVulkanType, VulkanTypeIterator, Atom, FuncExpr, FuncExprVal, FuncLambda
10
11from .wrapperdefs import VulkanWrapperGenerator
12from .wrapperdefs import VULKAN_STREAM_VAR_NAME
13from .wrapperdefs import ROOT_TYPE_VAR_NAME, ROOT_TYPE_PARAM
14from .wrapperdefs import STREAM_RET_TYPE
15from .wrapperdefs import MARSHAL_INPUT_VAR_NAME
16from .wrapperdefs import UNMARSHAL_INPUT_VAR_NAME
17from .wrapperdefs import PARAMETERS_MARSHALING
18from .wrapperdefs import PARAMETERS_MARSHALING_GUEST
19from .wrapperdefs import STYPE_OVERRIDE
20from .wrapperdefs import STRUCT_EXTENSION_PARAM, STRUCT_EXTENSION_PARAM_FOR_WRITE, EXTENSION_SIZE_WITH_STREAM_FEATURES_API_NAME
21from .wrapperdefs import API_PREFIX_MARSHAL
22from .wrapperdefs import API_PREFIX_UNMARSHAL
23
24from .marshalingdefs import KNOWN_FUNCTION_OPCODES, CUSTOM_MARSHAL_TYPES
25
class VulkanMarshalingCodegen(VulkanTypeIterator):
    """Emits C++ code that marshals or unmarshals Vulkan types over a stream.

    Instances are driven by ``iterateVulkanType``: each ``on*`` callback
    receives one member/parameter type and writes the corresponding stream
    calls into ``self.cgen``.  ``direction == "write"`` emits writing code;
    any other value emits reading code.

    Attributes used here but assigned by the driver rather than ``__init__``:
      * ``cgen`` is normally reassigned for each generated function.
      * ``currentStructInfo`` must be set before callbacks that use filter
        guards, 'let' parameters or environment lookups.
      * ``typeInfo`` is read by ``genPrimitiveStreamCall``, ``onPointer`` and
        ``onValue`` but never assigned in this class -- presumably provided
        by ``VulkanTypeIterator``/``iterateVulkanType``; TODO confirm.
    """

    def __init__(self,
                 cgen,
                 streamVarName,
                 rootTypeVarName,
                 inputVarName,
                 marshalPrefix,
                 direction = "write",
                 forApiOutput = False,
                 dynAlloc = False,
                 mapHandles = True,
                 handleMapOverwrites = False,
                 doFiltering = True):
        """Configure the generator.

        Args:
            cgen: CodeGen receiving emitted statements (may be None and be
                assigned later).
            streamVarName: C++ name of the stream object in generated code.
            rootTypeVarName: C++ name of the root VkStructureType variable.
            inputVarName: C++ name of the struct pointer being processed; all
                member accesses are generated relative to it.
            marshalPrefix: prefix prepended to type names to form the names
                of nested (un)marshal functions.
            direction: "write" emits writing code; anything else reading code.
            forApiOutput: when True, ``onCheck`` emits nothing.
            dynAlloc: when reading, allocate destination storage with
                ``stream->alloc`` before filling it.
            mapHandles: route handle-typed values through the stream's
                ``handleMapping()`` instead of copying them raw.
            handleMapOverwrites: on write, map handles in place (requires
                8-byte handle types) instead of via a uint64 scratch buffer.
            doFiltering: emit filter guards for members with a filter var.
        """
        self.cgen = cgen
        self.direction = direction
        # Stream method used for plain byte copies in genStreamCall.
        self.processSimple = "write" if self.direction == "write" else "read"
        self.forApiOutput = forApiOutput

        # True while an onCheck() 'if' is open and awaiting endCheck().
        self.checked = False

        self.streamVarName = streamVarName
        self.rootTypeVarName = rootTypeVarName
        self.inputVarName = inputVarName
        self.marshalPrefix = marshalPrefix

        # Callbacks turning a member type into a C++ access expression
        # rooted at inputVarName.
        self.exprAccessor = lambda t: self.cgen.generalAccess(t, parentVarName = self.inputVarName, asPtr = True)
        self.exprValueAccessor = lambda t: self.cgen.generalAccess(t, parentVarName = self.inputVarName, asPtr = False)
        # Currently identical to exprValueAccessor; kept as a separate hook.
        self.exprPrimitiveValueAccessor = lambda t: self.cgen.generalAccess(t, parentVarName = self.inputVarName, asPtr = False)
        self.lenAccessor = lambda t: self.cgen.generalLengthAccess(t, parentVarName = self.inputVarName)
        self.lenAccessorGuard = lambda t: self.cgen.generalLengthAccessGuard(
            t, parentVarName=self.inputVarName)
        self.filterVarAccessor = lambda t: self.cgen.filterVarAccess(t, parentVarName = self.inputVarName)

        self.dynAlloc = dynAlloc
        self.mapHandles = mapHandles
        self.handleMapOverwrites = handleMapOverwrites
        self.doFiltering = doFiltering

    def getTypeForStreaming(self, vulkanType):
        """Return a copy of ``vulkanType`` shaped for use as a stream-call cast.

        Non-pointer-accessible types and static arrays are turned into their
        address form; when reading, constness is also stripped.
        """
        res = copy(vulkanType)

        if not vulkanType.accessibleAsPointer():
            res = res.getForAddressAccess()

        if vulkanType.staticArrExpr:
            res = res.getForAddressAccess()

        if self.direction == "write":
            return res
        return res.getForNonConstAccess()

    def makeCastExpr(self, vulkanType):
        """Return a C-style cast prefix such as ``(const Foo*)``."""
        return "(%s)" % (
            self.cgen.makeCTypeDecl(vulkanType, useParamName=False))

    def genStreamCall(self, vulkanType, toStreamExpr, sizeExpr):
        """Emit ``stream->write/read((T)expr, size)`` for a raw byte copy."""
        varname = self.streamVarName
        func = self.processSimple
        cast = self.makeCastExpr(self.getTypeForStreaming(vulkanType))

        self.cgen.stmt(
            "%s->%s(%s%s, %s)" % (varname, func, cast, toStreamExpr, sizeExpr))

    def genPrimitiveStreamCall(self, vulkanType, access):
        """Delegate streaming of a single primitive value to the CodeGen."""
        self.cgen.streamPrimitive(
            self.typeInfo,
            self.streamVarName,
            access,
            vulkanType,
            direction=self.direction)

    def genHandleMappingCall(self, vulkanType, access, lenAccess):
        """Emit code that streams handle(s) through ``handleMapping()``.

        Handles travel over the stream as uint64 values.  ``lenAccess`` of
        None (or the literal "1") means a single handle backed by a uint64
        local; any other length allocates a uint64 array with
        ``stream->alloc`` inside an ``if (len)`` guard.
        """
        if lenAccess is None:
            lenAccess = "1"
            handle64Bytes = "8"
        else:
            handle64Bytes = "%s * 8" % lenAccess

        isArray = lenAccess != "1"
        handle64Var = self.cgen.var()
        if isArray:
            self.cgen.beginIf(lenAccess)
            self.cgen.stmt("uint64_t* %s" % handle64Var)
            self.cgen.stmt(
                "%s->alloc((void**)&%s, %s * 8)" % \
                (self.streamVarName, handle64Var, lenAccess))
            handle64VarAccess = handle64Var
            handle64VarType = \
                makeVulkanTypeSimple(False, "uint64_t", 1, paramName=handle64Var)
        else:
            self.cgen.stmt("uint64_t %s" % handle64Var)
            handle64VarAccess = "&%s" % handle64Var
            handle64VarType = \
                makeVulkanTypeSimple(False, "uint64_t", 0, paramName=handle64Var)

        if self.direction == "write":
            if self.handleMapOverwrites:
                # Handles are mapped in place, so they must already be 64-bit.
                self.cgen.stmt(
                    "static_assert(8 == sizeof(%s), \"handle map overwrite requires %s to be 8 bytes long\")" % \
                            (vulkanType.typeName, vulkanType.typeName))
                self.cgen.stmt(
                    "%s->handleMapping()->mapHandles_%s((%s*)%s, %s)" %
                    (self.streamVarName, vulkanType.typeName, vulkanType.typeName,
                    access, lenAccess))
                self.genStreamCall(vulkanType, access, "8 * %s" % lenAccess)
            else:
                self.cgen.stmt(
                    "%s->handleMapping()->mapHandles_%s_u64(%s, %s, %s)" %
                    (self.streamVarName, vulkanType.typeName,
                    access,
                    handle64VarAccess, lenAccess))
                self.genStreamCall(handle64VarType, handle64VarAccess, handle64Bytes)
        else:
            # Read the uint64 values first, then map them into the handles.
            self.genStreamCall(handle64VarType, handle64VarAccess, handle64Bytes)
            self.cgen.stmt(
                "%s->handleMapping()->mapHandles_u64_%s(%s, %s%s, %s)" %
                (self.streamVarName, vulkanType.typeName,
                handle64VarAccess,
                self.makeCastExpr(vulkanType.getForNonConstAccess()), access,
                lenAccess))

        if isArray:
            self.cgen.endIf()

    def doAllocSpace(self, vulkanType):
        """When reading with dynAlloc, emit ``stream->alloc`` for the target."""
        if not (self.dynAlloc and self.direction == "read"):
            return

        access = self.exprAccessor(vulkanType)
        lenAccess = self.lenAccessor(vulkanType)
        sizeof = self.cgen.sizeofExpr(vulkanType.getForValueAccess())
        bytesExpr = "%s * %s" % (lenAccess, sizeof) if lenAccess else sizeof

        self.cgen.stmt(
            "%s->alloc((void**)&%s, %s)" %
                (self.streamVarName, access, bytesExpr))

    def getOptionalStringFeatureExpr(self, vulkanType):
        """Return a C++ feature-bit test for the type's protecting stream
        feature, or None when the type has none."""
        streamFeature = vulkanType.getProtectStreamFeature()
        if streamFeature is None:
            return None
        return "%s->getFeatureBits() & %s" % (self.streamVarName, streamFeature)

    def onCheck(self, vulkanType):
        """Stream an optional pointer's value and open ``if (pointer)``.

        The opened ``if`` is closed by ``endCheck``.  When the type has a
        protecting stream feature, the pointer value is only streamed if the
        feature bit is set, and the ``if`` is entered unconditionally when it
        is not.  Emits nothing when ``forApiOutput`` is set.
        """
        if self.forApiOutput:
            return

        featureExpr = self.getOptionalStringFeatureExpr(vulkanType)

        self.checked = True

        access = self.exprAccessor(vulkanType)

        needConsistencyCheck = False

        self.cgen.line("// WARNING PTR CHECK")
        if (self.dynAlloc and self.direction == "read") or self.direction == "write":
            # Stream directly into / out of the pointer field itself.
            checkAccess = access
        else:
            # Read the streamed pointer value into a separate local so it can
            # be compared with the local pointer below.
            checkName = "check_%s" % vulkanType.paramName
            self.cgen.stmt("%s %s" % (
                self.cgen.makeCTypeDecl(vulkanType, useParamName = False), checkName))
            self.cgen.stmt("(void)%s" % checkName)
            checkAccess = checkName
            needConsistencyCheck = True

        if featureExpr is not None:
            self.cgen.beginIf(featureExpr)

        self.genPrimitiveStreamCall(
            vulkanType,
            checkAccess)

        if featureExpr is not None:
            self.cgen.endIf()
            self.cgen.beginIf("(!(%s) || %s)" % (featureExpr, access))
        else:
            self.cgen.beginIf(access)

        if needConsistencyCheck and featureExpr is None:
            # Local pointer is non-null but the streamed value was null.
            self.cgen.beginIf("!(%s)" % checkName)
            self.cgen.stmt(
                "fprintf(stderr, \"fatal: %s inconsistent between guest and host\\n\")" % (access))
            self.cgen.endIf()

    def onCheckWithNullOptionalStringFeature(self, vulkanType):
        """Open ``if (NULL_OPTIONAL_STRINGS feature)`` and then ``onCheck``."""
        self.cgen.beginIf("%s->getFeatureBits() & VULKAN_STREAM_FEATURE_NULL_OPTIONAL_STRINGS_BIT" % self.streamVarName)
        self.onCheck(vulkanType)

    def endCheckWithNullOptionalStringFeature(self, vulkanType):
        """Close the check and feature ``if`` and open the fallback ``else``."""
        self.endCheck(vulkanType)
        self.cgen.endIf()
        self.cgen.beginElse()

    def finalCheckWithNullOptionalStringFeature(self, vulkanType):
        """Close the ``else`` opened by endCheckWithNullOptionalStringFeature."""
        self.cgen.endElse()

    def endCheck(self, vulkanType):
        """Close the ``if`` opened by ``onCheck``, if one is open."""
        if self.checked:
            self.cgen.endIf()
            self.checked = False

    def genFilterFunc(self, filterfunc, env):
        """Translate a filter expression tree into a C++ expression string.

        Supported nodes are FuncExpr (operators and calls), FuncLambda
        (rendered as a C++ lambda) and FuncExprVal (atoms and literals).
        Atoms bound by an enclosing lambda are emitted verbatim; atoms bound
        in ``env`` are rewritten through ``getEnvAccessExpr``.

        Raises:
            TypeError: on a node type that cannot be translated (previously
                this silently produced ``None`` in the generated code).
        """
        binaryOps = {
            "eq": "==",
            "and": "&&",
            "or": "||",
            "bitwise_and": "&",
        }

        # lambdaEnv (names bound by enclosing lambdas) is threaded through
        # every helper so that lambda parameters shadow same-named
        # environment variables.
        def get_ptrlevels(atomname, lambdaEnv):
            if lambdaEnv.get(atomname, None) is not None:
                return 0
            if env.get(atomname, None) is not None:
                return self.getPointerIndirectionLevels(atomname)
            return 0

        def do_expratom(atomname, lambdaEnv):
            if lambdaEnv.get(atomname, None) is not None:
                return atomname
            if env.get(atomname, None) is not None:
                return self.getEnvAccessExpr(atomname)
            return atomname

        def do_func(expr, lambdaEnv):
            fnamestr = expr.name.name
            args = expr.args
            if fnamestr == "not":
                return "!(%s)" % (loop(args[0], lambdaEnv))
            if fnamestr in binaryOps:
                return "(%s %s %s)" % (loop(args[0], lambdaEnv),
                                       binaryOps[fnamestr],
                                       loop(args[1], lambdaEnv))
            if fnamestr == "getfield":
                ptrlevels = get_ptrlevels(args[0].val.name, lambdaEnv)
                if ptrlevels == 0:
                    return "%s.%s" % (loop(args[0], lambdaEnv), args[1].val)
                return "(%s(%s)).%s" % ("*" * ptrlevels, loop(args[0], lambdaEnv), args[1].val)
            if fnamestr == "if":
                return "((%s) ? (%s) : (%s))" % (loop(args[0], lambdaEnv),
                                                 loop(args[1], lambdaEnv),
                                                 loop(args[2], lambdaEnv))
            return "%s(%s)" % (fnamestr, ", ".join(loop(e, lambdaEnv) for e in args))

        def do_exprval(expr, lambdaEnv):
            expratom = expr.val
            if type(expratom) is Atom:
                return do_expratom(expratom.name, lambdaEnv)
            return "%s" % expratom

        def do_lambda(expr, lambdaEnv):
            params = expr.vs
            newEnv = dict(lambdaEnv)
            for p in params:
                newEnv[p.name] = p.typ
            paramDecls = ", ".join("%s %s" % (p.typ, p.name) for p in params)
            return "[](%s) { return %s; }" % (paramDecls, loop(expr.body, newEnv))

        def loop(expr, lambdaEnv):
            exprType = type(expr)
            if exprType is FuncExpr:
                return do_func(expr, lambdaEnv)
            if exprType is FuncLambda:
                return do_lambda(expr, lambdaEnv)
            if exprType is FuncExprVal:
                return do_exprval(expr, lambdaEnv)
            raise TypeError("unsupported filter expression node: %r" % (expr,))

        return loop(filterfunc, {})

    def beginFilterGuard(self, vulkanType):
        """Open an ``if`` guarding a member that declares a filter variable.

        Emits nothing when the member has no filter var or filtering is off.
        Otherwise an ``if`` is always opened (``endFilterGuard`` always closes
        one):
          * no filterVals / filterFunc: the filter var is tested as a bool;
          * otherwise: ``!(IGNORED_HANDLES feature) || (<vals> || <func>)``.
        """
        if vulkanType.filterVar is None or not self.doFiltering:
            return

        filterVarAccess = self.getEnvAccessExpr(vulkanType.filterVar)

        filterFeature = "%s->getFeatureBits() & VULKAN_STREAM_FEATURE_IGNORED_HANDLES_BIT" % self.streamVarName

        filterParts = []
        if vulkanType.filterVals is not None:
            filterParts.append(" || ".join(
                "(%s == %s)" % (filterval, filterVarAccess)
                for filterval in vulkanType.filterVals))
        if vulkanType.filterFunc is not None:
            filterParts.append(self.genFilterFunc(
                vulkanType.filterFunc, self.currentStructInfo.environment))

        if not filterParts:
            # Assume is bool
            self.cgen.beginIf(filterVarAccess)
        else:
            # Previously the combined vals+func case computed its expression
            # but never opened the 'if', unbalancing endFilterGuard's endIf.
            self.cgen.beginIf("(!(%s) || (%s))" % (filterFeature, " || ".join(filterParts)))

    def endFilterGuard(self, vulkanType, cleanupExpr=None):
        """Close the ``if`` from beginFilterGuard; optionally emit
        ``else { cleanupExpr; }``."""
        if vulkanType.filterVar is None or not self.doFiltering:
            return

        self.cgen.endIf()
        if cleanupExpr is not None:
            self.cgen.beginElse()
            self.cgen.stmt(cleanupExpr)
            self.cgen.endElse()

    def _findStructMember(self, varName):
        """Return the first member of currentStructInfo named ``varName``.

        Raises IndexError when no such member exists.
        """
        return [member for member in self.currentStructInfo.members
                if member.paramName == varName][0]

    def getEnvAccessExpr(self, varName):
        """Return the C++ expression for environment variable ``varName``.

        Struct members are accessed through the input struct; other bindings
        are referenced by name.  Returns None if ``varName`` is not in the
        current struct's environment.
        """
        entry = self.currentStructInfo.environment.get(varName, None)
        if entry is None:
            return None
        if entry["structmember"]:
            return self.exprValueAccessor(self._findStructMember(varName))
        return varName

    def getPointerIndirectionLevels(self, varName):
        """Return the pointer depth of struct-member ``varName``, else 0."""
        entry = self.currentStructInfo.environment.get(varName, None)
        if entry is None or not entry["structmember"]:
            return 0
        return self._findStructMember(varName).pointerIndirectionLevels

    def onCompoundType(self, vulkanType):
        """Emit a call to the nested (un)marshal function for a struct type,
        looping over elements when the member has a length."""
        access = self.exprAccessor(vulkanType)
        lenAccess = self.lenAccessor(vulkanType)
        lenAccessGuard = self.lenAccessorGuard(vulkanType)

        self.beginFilterGuard(vulkanType)

        if vulkanType.pointerIndirectionLevels > 0:
            self.doAllocSpace(vulkanType)

        if lenAccess is not None:
            if lenAccessGuard is not None:
                self.cgen.beginIf(lenAccessGuard)
            loopVar = "i"
            access = "%s + %s" % (access, loopVar)
            self.cgen.beginFor("uint32_t %s = 0" % loopVar,
                               "%s < (uint32_t)%s" % (loopVar, lenAccess),
                               "++%s" % loopVar)

        accessWithCast = "%s(%s)" % (self.makeCastExpr(
            self.getTypeForStreaming(vulkanType)), access)

        callParams = [self.streamVarName, self.rootTypeVarName, accessWithCast]
        # Free variables of the nested type are passed positionally.
        for localName in vulkanType.binds.values():
            callParams.append(self.getEnvAccessExpr(localName))

        self.cgen.funcCall(None, self.marshalPrefix + vulkanType.typeName,
                           callParams)

        if lenAccess is not None:
            self.cgen.endFor()
            if lenAccessGuard is not None:
                self.cgen.endIf()

        if self.direction == "read":
            # Filtered-out members are zeroed on the reading side.
            self.endFilterGuard(vulkanType, "%s = 0" % self.exprAccessor(vulkanType))
        else:
            self.endFilterGuard(vulkanType)

    def onString(self, vulkanType):
        """Emit ``putString`` (write) or ``loadStringInPlace`` (read)."""
        access = self.exprAccessor(vulkanType)

        if self.direction == "write":
            self.cgen.stmt("%s->putString(%s)" % (self.streamVarName, access))
            return

        castExpr = self.makeCastExpr(
            self.getTypeForStreaming(vulkanType.getForAddressAccess()))
        self.cgen.stmt(
            "%s->loadStringInPlace(%s&%s)" % (self.streamVarName, castExpr, access))

    def onStringArray(self, vulkanType):
        """Emit ``saveStringArray`` (write) or ``loadStringArrayInPlace`` (read)."""
        access = self.exprAccessor(vulkanType)
        lenAccess = self.lenAccessor(vulkanType)

        if self.direction == "write":
            self.cgen.stmt("saveStringArray(%s, %s, %s)" % (self.streamVarName,
                                                            access, lenAccess))
            return

        castExpr = self.makeCastExpr(
            self.getTypeForStreaming(vulkanType.getForAddressAccess()))
        self.cgen.stmt("%s->loadStringArrayInPlace(%s&%s)" % (self.streamVarName, castExpr, access))

    def onStaticArr(self, vulkanType):
        """Stream a fixed-size array as one raw byte copy."""
        access = self.exprValueAccessor(vulkanType)
        lenAccess = self.lenAccessor(vulkanType)
        finalLenExpr = "%s * %s" % (lenAccess, self.cgen.sizeofExpr(vulkanType))
        self.genStreamCall(vulkanType, access, finalLenExpr)

    # Old version VkEncoder may have some sType values conflict with VkDecoder
    # of new versions. For host decoder, it should not carry the incorrect old
    # sType values to the |forUnmarshaling| struct. Instead it should overwrite
    # the sType value.
    def overwriteSType(self, vulkanType):
        """When reading, force the parent's sType to its STYPE_OVERRIDE value."""
        if self.direction != "read":
            return

        sTypeParam = copy(vulkanType)
        sTypeParam.paramName = "sType"
        sTypeAccess = self.exprAccessor(sTypeParam)

        typeName = vulkanType.parent.typeName
        if typeName in STYPE_OVERRIDE:
            self.cgen.stmt("%s = %s" %
                           (sTypeAccess, STYPE_OVERRIDE[typeName]))

    def onStructExtension(self, vulkanType):
        """Emit code for an extension-struct pointer member.

        Records the root sType if not yet set.  When reading with dynAlloc, a
        32-bit size is read first; if non-zero, the extension's sType is read,
        storage sized by EXTENSION_SIZE_WITH_STREAM_FEATURES_API_NAME is
        allocated, and the generic ``extension_struct`` routine is called.
        """
        self.overwriteSType(vulkanType)

        sTypeParam = copy(vulkanType)
        sTypeParam.paramName = "sType"

        access = self.exprAccessor(vulkanType)
        sizeVar = "%s_size" % vulkanType.paramName

        if self.direction == "read":
            castedAccessExpr = "(%s)(%s)" % ("void*", access)
        else:
            castedAccessExpr = access

        sTypeAccess = self.exprAccessor(sTypeParam)
        self.cgen.beginIf("%s == VK_STRUCTURE_TYPE_MAX_ENUM" %
                          self.rootTypeVarName)
        self.cgen.stmt("%s = %s" % (self.rootTypeVarName, sTypeAccess))
        self.cgen.endIf()

        extensionCallParams = [self.streamVarName, self.rootTypeVarName, castedAccessExpr]

        if not (self.direction == "read" and self.dynAlloc):
            self.cgen.funcCall(None, self.marshalPrefix + "extension_struct",
                               extensionCallParams)
            return

        self.cgen.stmt("size_t %s" % sizeVar)
        self.cgen.stmt("%s = %s->getBe32()" % (sizeVar, self.streamVarName))
        self.cgen.stmt("%s = nullptr" % access)
        self.cgen.beginIf(sizeVar)
        self.cgen.stmt(
                "%s->alloc((void**)&%s, sizeof(VkStructureType))" %
                (self.streamVarName, access))

        self.genStreamCall(vulkanType, access, "sizeof(VkStructureType)")
        self.cgen.stmt("VkStructureType extType = *(VkStructureType*)(%s)" % access)
        self.cgen.stmt(
            "%s->alloc((void**)&%s, %s(%s->getFeatureBits(), %s, %s))" %
            (self.streamVarName, access, EXTENSION_SIZE_WITH_STREAM_FEATURES_API_NAME, self.streamVarName, self.rootTypeVarName, access))
        # The second alloc replaces the pointer, so restore the sType read above.
        self.cgen.stmt("*(VkStructureType*)%s = extType" % access)

        self.cgen.funcCall(None, self.marshalPrefix + "extension_struct",
                           extensionCallParams)
        self.cgen.endIf()

    def onPointer(self, vulkanType):
        """Emit code streaming the pointee(s) of a pointer member.

        Handles go through genHandleMappingCall; non-ABI-portable types are
        streamed element by element; everything else is one raw byte copy.
        """
        access = self.exprAccessor(vulkanType)

        lenAccess = self.lenAccessor(vulkanType)
        lenAccessGuard = self.lenAccessorGuard(vulkanType)

        self.beginFilterGuard(vulkanType)
        self.doAllocSpace(vulkanType)

        if vulkanType.isHandleType() and self.mapHandles:
            self.genHandleMappingCall(vulkanType, access, lenAccess)
        elif self.typeInfo.isNonAbiPortableType(vulkanType.typeName):
            valueType = vulkanType.getForValueAccess()
            if lenAccess is not None:
                if lenAccessGuard is not None:
                    self.cgen.beginIf(lenAccessGuard)
                self.cgen.beginFor("uint32_t i = 0", "i < (uint32_t)%s" % lenAccess, "++i")
                self.genPrimitiveStreamCall(valueType, "%s[i]" % access)
                self.cgen.endFor()
                if lenAccessGuard is not None:
                    self.cgen.endIf()
            else:
                self.genPrimitiveStreamCall(valueType, "(*%s)" % access)
        else:
            elemSize = self.cgen.sizeofExpr(vulkanType.getForValueAccess())
            if lenAccess is not None:
                finalLenExpr = "%s * %s" % (lenAccess, elemSize)
            else:
                finalLenExpr = "%s" % (elemSize)
            self.genStreamCall(vulkanType, access, finalLenExpr)

        if self.direction == "read":
            self.endFilterGuard(vulkanType, "%s = 0" % access)
        else:
            self.endFilterGuard(vulkanType)

    def onValue(self, vulkanType):
        """Emit code streaming a by-value member (handle, primitive or raw)."""
        self.beginFilterGuard(vulkanType)

        if vulkanType.isHandleType() and self.mapHandles:
            access = self.exprAccessor(vulkanType)
            self.genHandleMappingCall(
                vulkanType.getForAddressAccess(), access, "1")
        elif self.typeInfo.isNonAbiPortableType(vulkanType.typeName):
            access = self.exprPrimitiveValueAccessor(vulkanType)
            self.genPrimitiveStreamCall(vulkanType, access)
        else:
            access = self.exprAccessor(vulkanType)
            self.genStreamCall(vulkanType, access, self.cgen.sizeofExpr(vulkanType))

        self.endFilterGuard(vulkanType)

    def streamLetParameter(self, structInfo, letParamInfo):
        """Declare a 'let' parameter (default 1) and stream it when the
        IGNORED_HANDLES feature is enabled.

        On write, its value is first computed from the binding body in
        ``currentStructInfo.environment``.  ``structInfo`` is unused; visible
        callers pass ``typeInfo`` here.
        """
        filterFeature = "%s->getFeatureBits() & VULKAN_STREAM_FEATURE_IGNORED_HANDLES_BIT" % self.streamVarName
        self.cgen.stmt("%s %s = 1" % (letParamInfo.typeName, letParamInfo.paramName))

        self.cgen.beginIf(filterFeature)

        if self.direction == "write":
            environment = self.currentStructInfo.environment
            bodyExpr = environment[letParamInfo.paramName]["body"]
            self.cgen.stmt("%s = %s" % (letParamInfo.paramName, self.genFilterFunc(bodyExpr, environment)))

        self.genPrimitiveStreamCall(letParamInfo, letParamInfo.paramName)

        self.cgen.endIf()
602
603class VulkanMarshaling(VulkanWrapperGenerator):
604
605    def __init__(self, module, typeInfo, variant="host"):
606        VulkanWrapperGenerator.__init__(self, module, typeInfo)
607
608        self.cgenHeader = CodeGen()
609        self.cgenImpl = CodeGen()
610
611        self.variant = variant
612
613        self.currentFeature = None
614        self.apiOpcodes = {}
615        self.dynAlloc = self.variant != "guest"
616
617        if self.variant == "guest":
618            self.marshalingParams = PARAMETERS_MARSHALING_GUEST
619        else:
620            self.marshalingParams = PARAMETERS_MARSHALING
621
622        self.writeCodegen = \
623            VulkanMarshalingCodegen(
624                None,
625                VULKAN_STREAM_VAR_NAME,
626                ROOT_TYPE_VAR_NAME,
627                MARSHAL_INPUT_VAR_NAME,
628                API_PREFIX_MARSHAL,
629                direction = "write")
630
631        self.readCodegen = \
632            VulkanMarshalingCodegen(
633                None,
634                VULKAN_STREAM_VAR_NAME,
635                ROOT_TYPE_VAR_NAME,
636                UNMARSHAL_INPUT_VAR_NAME,
637                API_PREFIX_UNMARSHAL,
638                direction = "read",
639                dynAlloc=self.dynAlloc)
640
641        self.knownDefs = {}
642
643        # Begin Vulkan API opcodes from something high
644        # that is not going to interfere with renderControl
645        # opcodes
646        self.beginOpcodeOld = 20000
647        self.endOpcodeOld = 30000
648
649        self.beginOpcode = 200000000
650        self.endOpcode = 300000000
651        self.knownOpcodes = set()
652
653        self.extensionMarshalPrototype = \
654            VulkanAPI(API_PREFIX_MARSHAL + "extension_struct",
655                      STREAM_RET_TYPE,
656                      self.marshalingParams +
657                      [STRUCT_EXTENSION_PARAM])
658
659        self.extensionUnmarshalPrototype = \
660            VulkanAPI(API_PREFIX_UNMARSHAL + "extension_struct",
661                      STREAM_RET_TYPE,
662                      self.marshalingParams +
663                      [STRUCT_EXTENSION_PARAM_FOR_WRITE])
664
665    def onBegin(self,):
666        VulkanWrapperGenerator.onBegin(self)
667        self.module.appendImpl(self.cgenImpl.makeFuncDecl(self.extensionMarshalPrototype))
668        self.module.appendImpl(self.cgenImpl.makeFuncDecl(self.extensionUnmarshalPrototype))
669
670    def onBeginFeature(self, featureName, featureType):
671        VulkanWrapperGenerator.onBeginFeature(self, featureName, featureType)
672        self.currentFeature = featureName
673
674    def onGenType(self, typeXml, name, alias):
675        VulkanWrapperGenerator.onGenType(self, typeXml, name, alias)
676
677        if name in self.knownDefs:
678            return
679
680        category = self.typeInfo.categoryOf(name)
681
682        if category in ["struct", "union"] and alias:
683            self.module.appendHeader(
684                self.cgenHeader.makeFuncAlias(API_PREFIX_MARSHAL + name,
685                                              API_PREFIX_MARSHAL + alias))
686            self.module.appendHeader(
687                self.cgenHeader.makeFuncAlias(API_PREFIX_UNMARSHAL + name,
688                                              API_PREFIX_UNMARSHAL + alias))
689
690        if category in ["struct", "union"] and not alias:
691
692            structInfo = self.typeInfo.structs[name]
693
694            marshalParams = self.marshalingParams + \
695                [makeVulkanTypeSimple(True, name, 1, MARSHAL_INPUT_VAR_NAME)]
696
697            freeParams = []
698            letParams = []
699
700            for (envname, bindingInfo) in list(sorted(structInfo.environment.items(), key = lambda kv: kv[0])):
701                if None == bindingInfo["binding"]:
702                    freeParams.append(makeVulkanTypeSimple(True, bindingInfo["type"], 0, envname))
703                else:
704                    if not bindingInfo["structmember"]:
705                        letParams.append(makeVulkanTypeSimple(True, bindingInfo["type"], 0, envname))
706
707            marshalPrototype = \
708                VulkanAPI(API_PREFIX_MARSHAL + name,
709                          STREAM_RET_TYPE,
710                          marshalParams + freeParams)
711
712            marshalPrototypeNoFilter = \
713                VulkanAPI(API_PREFIX_MARSHAL + name,
714                          STREAM_RET_TYPE,
715                          marshalParams)
716
717            def structMarshalingCustom(cgen):
718                self.writeCodegen.cgen = cgen
719                self.writeCodegen.currentStructInfo = structInfo
720                self.writeCodegen.cgen.stmt("(void)%s" % ROOT_TYPE_VAR_NAME)
721
722                marshalingCode = \
723                    CUSTOM_MARSHAL_TYPES[name]["common"] + \
724                    CUSTOM_MARSHAL_TYPES[name]["marshaling"].format(
725                        streamVarName=self.writeCodegen.streamVarName,
726                        rootTypeVarName=self.writeCodegen.rootTypeVarName,
727                        inputVarName=self.writeCodegen.inputVarName,
728                        newInputVarName=self.writeCodegen.inputVarName + "_new")
729                for line in marshalingCode.split('\n'):
730                    cgen.line(line)
731
732            def structMarshalingDef(cgen):
733                self.writeCodegen.cgen = cgen
734                self.writeCodegen.currentStructInfo = structInfo
735                self.writeCodegen.cgen.stmt("(void)%s" % ROOT_TYPE_VAR_NAME)
736
737                if category == "struct":
738                    # marshal 'let' parameters first
739                    for letp in letParams:
740                        self.writeCodegen.streamLetParameter(self.typeInfo, letp)
741
742                    for member in structInfo.members:
743                        iterateVulkanType(self.typeInfo, member, self.writeCodegen)
744                if category == "union":
745                    iterateVulkanType(self.typeInfo, structInfo.members[0], self.writeCodegen)
746
747            def structMarshalingDefNoFilter(cgen):
748                self.writeCodegen.cgen = cgen
749                self.writeCodegen.currentStructInfo = structInfo
750                self.writeCodegen.doFiltering = False
751                self.writeCodegen.cgen.stmt("(void)%s" % ROOT_TYPE_VAR_NAME)
752
753                if category == "struct":
754                    # marshal 'let' parameters first
755                    for letp in letParams:
756                        self.writeCodegen.streamLetParameter(self.typeInfo, letp)
757
758                    for member in structInfo.members:
759                        iterateVulkanType(self.typeInfo, member, self.writeCodegen)
760                if category == "union":
761                    iterateVulkanType(self.typeInfo, structInfo.members[0], self.writeCodegen)
762                self.writeCodegen.doFiltering = True
763
764            self.module.appendHeader(
765                self.cgenHeader.makeFuncDecl(marshalPrototype))
766
767            if name in CUSTOM_MARSHAL_TYPES and CUSTOM_MARSHAL_TYPES[name].get("marshaling"):
768                self.module.appendImpl(
769                    self.cgenImpl.makeFuncImpl(
770                        marshalPrototype, structMarshalingCustom))
771            else:
772                self.module.appendImpl(
773                    self.cgenImpl.makeFuncImpl(
774                        marshalPrototype, structMarshalingDef))
775
776            if freeParams != []:
777                self.module.appendHeader(
778                    self.cgenHeader.makeFuncDecl(marshalPrototypeNoFilter))
779                self.module.appendImpl(
780                    self.cgenImpl.makeFuncImpl(
781                        marshalPrototypeNoFilter, structMarshalingDefNoFilter))
782
783            unmarshalPrototype = \
784                VulkanAPI(API_PREFIX_UNMARSHAL + name,
785                          STREAM_RET_TYPE,
786                          self.marshalingParams + [makeVulkanTypeSimple(False, name, 1, UNMARSHAL_INPUT_VAR_NAME)] + freeParams)
787
788            unmarshalPrototypeNoFilter = \
789                VulkanAPI(API_PREFIX_UNMARSHAL + name,
790                          STREAM_RET_TYPE,
791                          self.marshalingParams + [makeVulkanTypeSimple(False, name, 1, UNMARSHAL_INPUT_VAR_NAME)])
792
793            def structUnmarshalingCustom(cgen):
794                self.readCodegen.cgen = cgen
795                self.readCodegen.currentStructInfo = structInfo
796                self.writeCodegen.cgen.stmt("(void)%s" % ROOT_TYPE_VAR_NAME)
797
798                unmarshalingCode = \
799                    CUSTOM_MARSHAL_TYPES[name]["common"] + \
800                    CUSTOM_MARSHAL_TYPES[name]["unmarshaling"].format(
801                        streamVarName=self.readCodegen.streamVarName,
802                        rootTypeVarName=self.readCodegen.rootTypeVarName,
803                        inputVarName=self.readCodegen.inputVarName,
804                        newInputVarName=self.readCodegen.inputVarName + "_new")
805                for line in unmarshalingCode.split('\n'):
806                    cgen.line(line)
807
808            def structUnmarshalingDef(cgen):
809                self.readCodegen.cgen = cgen
810                self.readCodegen.currentStructInfo = structInfo
811                self.writeCodegen.cgen.stmt("(void)%s" % ROOT_TYPE_VAR_NAME)
812
813                if category == "struct":
814                    # unmarshal 'let' parameters first
815                    for letp in letParams:
816                        self.readCodegen.streamLetParameter(self.typeInfo, letp)
817
818                    for member in structInfo.members:
819                        iterateVulkanType(self.typeInfo, member, self.readCodegen)
820                if category == "union":
821                    iterateVulkanType(self.typeInfo, structInfo.members[0], self.readCodegen)
822
823            def structUnmarshalingDefNoFilter(cgen):
824                self.readCodegen.cgen = cgen
825                self.readCodegen.currentStructInfo = structInfo
826                self.readCodegen.doFiltering = False
827                self.writeCodegen.cgen.stmt("(void)%s" % ROOT_TYPE_VAR_NAME)
828
829                if category == "struct":
830                    # unmarshal 'let' parameters first
831                    for letp in letParams:
832                        iterateVulkanType(self.typeInfo, letp, self.readCodegen)
833                    for member in structInfo.members:
834                        iterateVulkanType(self.typeInfo, member, self.readCodegen)
835                if category == "union":
836                    iterateVulkanType(self.typeInfo, structInfo.members[0], self.readCodegen)
837                self.readCodegen.doFiltering = True
838
839            self.module.appendHeader(
840                self.cgenHeader.makeFuncDecl(unmarshalPrototype))
841
842            if name in CUSTOM_MARSHAL_TYPES and CUSTOM_MARSHAL_TYPES[name].get("unmarshaling"):
843                self.module.appendImpl(
844                    self.cgenImpl.makeFuncImpl(
845                        unmarshalPrototype, structUnmarshalingCustom))
846            else:
847                self.module.appendImpl(
848                    self.cgenImpl.makeFuncImpl(
849                        unmarshalPrototype, structUnmarshalingDef))
850
851            if freeParams != []:
852                self.module.appendHeader(
853                    self.cgenHeader.makeFuncDecl(unmarshalPrototypeNoFilter))
854                self.module.appendImpl(
855                    self.cgenImpl.makeFuncImpl(
856                        unmarshalPrototypeNoFilter, structUnmarshalingDefNoFilter))
857
858    def onGenCmd(self, cmdinfo, name, alias):
859        VulkanWrapperGenerator.onGenCmd(self, cmdinfo, name, alias)
860        if name in KNOWN_FUNCTION_OPCODES:
861            opcode = KNOWN_FUNCTION_OPCODES[name]
862        else:
863            hashCode = hashlib.sha256(name.encode()).hexdigest()[:8]
864            hashInt = int(hashCode, 16)
865            opcode = self.beginOpcode + hashInt % (self.endOpcode - self.beginOpcode)
866            hasHashCollision = False
867            while opcode in self.knownOpcodes:
868                hasHashCollision = True
869                opcode += 1
870            if hasHashCollision:
871                print("Hash collision occurred on function '{}'. "
872                      "Please add the following line to marshalingdefs.py:".format(name), file=sys.stderr)
873                print("----------------------", file=sys.stderr)
874                print("    \"{}\": {},".format(name, opcode), file=sys.stderr)
875                print("----------------------", file=sys.stderr)
876
877        self.module.appendHeader(
878            "#define OP_%s %d\n" % (name, opcode))
879        self.apiOpcodes[name] = (opcode, self.currentFeature)
880        self.knownOpcodes.add(opcode)
881
882    def doExtensionStructMarshalingCodegen(self, cgen, retType, extParam, forEach, funcproto, direction):
883        accessVar = "structAccess"
884        sizeVar = "currExtSize"
885        cgen.stmt("VkInstanceCreateInfo* %s = (VkInstanceCreateInfo*)(%s)" % (accessVar, extParam.paramName))
886        cgen.stmt("size_t %s = %s(%s->getFeatureBits(), %s, %s)" % (sizeVar,
887                                                                    EXTENSION_SIZE_WITH_STREAM_FEATURES_API_NAME, VULKAN_STREAM_VAR_NAME, ROOT_TYPE_VAR_NAME, extParam.paramName))
888
889        cgen.beginIf("!%s && %s" % (sizeVar, extParam.paramName))
890
891        cgen.line("// unknown struct extension; skip and call on its pNext field");
892        cgen.funcCall(None, funcproto.name, [
893                      "vkStream", ROOT_TYPE_VAR_NAME, "(void*)%s->pNext" % accessVar])
894        cgen.stmt("return")
895
896        cgen.endIf()
897        cgen.beginElse()
898
899        cgen.line("// known or null extension struct")
900
901        if direction == "write":
902            cgen.stmt("vkStream->putBe32(%s)" % sizeVar)
903        elif not self.dynAlloc:
904            cgen.stmt("vkStream->getBe32()");
905
906        cgen.beginIf("!%s" % (sizeVar))
907        cgen.line("// exit if this was a null extension struct (size == 0 in this branch)")
908        cgen.stmt("return")
909        cgen.endIf()
910
911        cgen.endIf()
912
913        # Now we can do stream stuff
914        if direction == "write":
915            cgen.stmt("vkStream->write(%s, sizeof(VkStructureType))" % extParam.paramName)
916        elif not self.dynAlloc:
917            cgen.stmt("uint64_t pNext_placeholder")
918            placeholderAccess = "(&pNext_placeholder)"
919            cgen.stmt("vkStream->read((void*)(&pNext_placeholder), sizeof(VkStructureType))")
920            cgen.stmt("(void)pNext_placeholder")
921
922        def fatalDefault(cgen):
923            cgen.line("// fatal; the switch is only taken if the extension struct is known")
924            if self.variant != "guest":
925                cgen.stmt("fprintf(stderr, \" %s, Unhandled Vulkan structure type %s [%d], aborting.\\n\", __func__, string_VkStructureType(VkStructureType(structType)), structType)")
926            cgen.stmt("abort()")
927            pass
928
929        self.emitForEachStructExtension(
930            cgen,
931            retType,
932            extParam,
933            forEach,
934            defaultEmit=fatalDefault,
935            rootTypeVar=ROOT_TYPE_PARAM)
936
937    def onEnd(self,):
938        VulkanWrapperGenerator.onEnd(self)
939
940        def forEachExtensionMarshal(ext, castedAccess, cgen):
941            cgen.funcCall(None, API_PREFIX_MARSHAL + ext.name,
942                          [VULKAN_STREAM_VAR_NAME, ROOT_TYPE_VAR_NAME, castedAccess])
943
944        def forEachExtensionUnmarshal(ext, castedAccess, cgen):
945            cgen.funcCall(None, API_PREFIX_UNMARSHAL + ext.name,
946                          [VULKAN_STREAM_VAR_NAME, ROOT_TYPE_VAR_NAME, castedAccess])
947
948        self.module.appendImpl(
949            self.cgenImpl.makeFuncImpl(
950                self.extensionMarshalPrototype,
951                lambda cgen: self.doExtensionStructMarshalingCodegen(
952                    cgen,
953                    STREAM_RET_TYPE,
954                    STRUCT_EXTENSION_PARAM,
955                    forEachExtensionMarshal,
956                    self.extensionMarshalPrototype,
957                    "write")))
958
959        self.module.appendImpl(
960            self.cgenImpl.makeFuncImpl(
961                self.extensionUnmarshalPrototype,
962                lambda cgen: self.doExtensionStructMarshalingCodegen(
963                    cgen,
964                    STREAM_RET_TYPE,
965                    STRUCT_EXTENSION_PARAM_FOR_WRITE,
966                    forEachExtensionUnmarshal,
967                    self.extensionUnmarshalPrototype,
968                    "read")))
969
970        opcode2stringPrototype = \
971            VulkanAPI("api_opcode_to_string",
972                          makeVulkanTypeSimple(True, "char", 1, "none"),
973                          [ makeVulkanTypeSimple(True, "uint32_t", 0, "opcode") ])
974
975        self.module.appendHeader(
976            self.cgenHeader.makeFuncDecl(opcode2stringPrototype))
977
978        def emitOpcode2StringImpl(apiOpcodes, cgen):
979            cgen.line("switch(opcode)")
980            cgen.beginBlock()
981
982            currFeature = None
983
984            for (name, (opcodeNum, feature)) in sorted(apiOpcodes.items(), key = lambda x : x[1][0]):
985                if not currFeature:
986                    cgen.leftline("#ifdef %s" % feature)
987                    currFeature = feature
988
989                if currFeature and feature != currFeature:
990                    cgen.leftline("#endif")
991                    cgen.leftline("#ifdef %s" % feature)
992                    currFeature = feature
993
994                cgen.line("case OP_%s:" % name)
995                cgen.beginBlock()
996                cgen.stmt("return \"OP_%s\"" % name)
997                cgen.endBlock()
998
999            if currFeature:
1000                cgen.leftline("#endif")
1001
1002            cgen.line("default:")
1003            cgen.beginBlock()
1004            cgen.stmt("return \"OP_UNKNOWN_API_CALL\"")
1005            cgen.endBlock()
1006
1007            cgen.endBlock()
1008
1009        self.module.appendImpl(
1010            self.cgenImpl.makeFuncImpl(
1011                opcode2stringPrototype,
1012                lambda cgen: emitOpcode2StringImpl(self.apiOpcodes, cgen)))
1013
1014        self.module.appendHeader(
1015            "#define OP_vkFirst_old %d\n" % (self.beginOpcodeOld))
1016        self.module.appendHeader(
1017            "#define OP_vkLast_old %d\n" % (self.endOpcodeOld))
1018        self.module.appendHeader(
1019            "#define OP_vkFirst %d\n" % (self.beginOpcode))
1020        self.module.appendHeader(
1021            "#define OP_vkLast %d\n" % (self.endOpcode))
1022