• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright (c) 2018 The Android Open Source Project
# Copyright (c) 2018 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
16from copy import copy
17
18from .common.codegen import CodeGen
19from .common.vulkantypes import \
20        VulkanAPI, makeVulkanTypeSimple, iterateVulkanType, VulkanTypeIterator
21
22from .wrapperdefs import VulkanWrapperGenerator
23from .wrapperdefs import EQUALITY_VAR_NAMES
24from .wrapperdefs import EQUALITY_ON_FAIL_VAR
25from .wrapperdefs import EQUALITY_ON_FAIL_VAR_TYPE
26from .wrapperdefs import EQUALITY_RET_TYPE
27from .wrapperdefs import API_PREFIX_EQUALITY
28from .wrapperdefs import STRUCT_EXTENSION_PARAM, STRUCT_EXTENSION_PARAM2
29
class VulkanEqualityCodegen(VulkanTypeIterator):
    """Emit C code that compares two Vulkan structs member by member.

    Driven by ``iterateVulkanType``: each ``on*`` callback writes, through
    ``self.cgen``, the C statements comparing one member of the two structs
    named by ``inputVars``. Whenever a generated comparison fails, the
    generated code calls the C callback named ``onFailCompareVar`` with a
    message naming the offending member.

    ``cgen`` may be ``None`` at construction; it must be assigned before
    iteration starts (the accessors below read ``self.cgen`` lazily).
    """

    def __init__(self, cgen, inputVars, onFailCompareVar, prefix):
        """
        Args:
            cgen: code generator used to emit C; may be None until
                iteration begins.
            inputVars: sequence whose first two entries are the C variable
                names of the lhs and rhs structs being compared.
            onFailCompareVar: name of the C callback invoked on mismatch.
            prefix: prepended to a type name to form the name of the
                generated comparison function for that type.
        """
        self.cgen = cgen
        self.inputVars = inputVars
        self.onFailCompareVar = onFailCompareVar
        self.prefix = prefix

        # Each factory returns a closure that resolves self.cgen at call
        # time, because cgen is typically bound after construction.
        def makeAccess(varName, asPtr=True):
            return lambda t: self.cgen.generalAccess(
                t, parentVarName=varName, asPtr=asPtr)

        def makeLengthAccess(varName):
            return lambda t: self.cgen.generalLengthAccess(
                t, parentVarName=varName)

        def makeLengthAccessGuard(varName):
            return lambda t: self.cgen.generalLengthAccessGuard(
                t, parentVarName=varName)

        lhsVar = self.inputVars[0]
        rhsVar = self.inputVars[1]

        self.exprAccessorLhs = makeAccess(lhsVar)
        self.exprAccessorRhs = makeAccess(rhsVar)

        self.exprAccessorValueLhs = makeAccess(lhsVar, asPtr=False)
        self.exprAccessorValueRhs = makeAccess(rhsVar, asPtr=False)

        self.lenAccessorLhs = makeLengthAccess(lhsVar)
        self.lenAccessorRhs = makeLengthAccess(rhsVar)

        self.lenAccessGuardLhs = makeLengthAccessGuard(lhsVar)
        self.lenAccessGuardRhs = makeLengthAccessGuard(rhsVar)

        # True only while onCheck has emitted an `if` that endCheck closes.
        self.checked = False

    def getTypeForCompare(self, vulkanType):
        """Return a copy of ``vulkanType`` adjusted for address access.

        The copy is passed through ``getForAddressAccess`` once if the
        type is not accessible as a pointer, and once more if it has a
        static array expression.
        """
        res = copy(vulkanType)

        if not vulkanType.accessibleAsPointer():
            res = res.getForAddressAccess()

        if vulkanType.staticArrExpr:
            res = res.getForAddressAccess()

        return res

    def makeCastExpr(self, vulkanType):
        """Return a C cast prefix ``(T)`` for ``vulkanType``."""
        return "(%s)" % (
            self.cgen.makeCTypeDecl(vulkanType, useParamName=False))

    def makeEqualExpr(self, lhs, rhs):
        """Return the C expression ``(lhs) == (rhs)``."""
        return "(%s) == (%s)" % (lhs, rhs)

    def makeEqualBufExpr(self, lhs, rhs, size):
        """Return a C expression true when ``size`` bytes at lhs/rhs match."""
        return "(memcmp(%s, %s, %s) == 0)" % (lhs, rhs, size)

    def makeEqualStringExpr(self, lhs, rhs):
        """Return a C expression true when two C strings are equal."""
        return "(strcmp(%s, %s) == 0)" % (lhs, rhs)

    def makeBothNotNullExpr(self, lhs, rhs):
        """Return a C expression true when both pointers are non-null."""
        return "(%s) && (%s)" % (lhs, rhs)

    def makeBothNullExpr(self, lhs, rhs):
        """Return a C expression true when both pointers are null."""
        return "!(%s) && !(%s)" % (lhs, rhs)

    def makeNullMatchExpr(self, lhs, rhs):
        """Return a C expression true when lhs/rhs agree on nullness."""
        return "(%s) || (%s)" % (self.makeBothNullExpr(lhs, rhs),
                                 self.makeBothNotNullExpr(lhs, rhs))

    def compareWithConsequence(self, compareExpr, vulkanType, errMsg=""):
        """Emit ``if (!(compareExpr))`` that calls the failure callback.

        The message passed to the callback is the lhs value-access
        expression for ``vulkanType`` followed by ``(Error: errMsg)``.
        """
        self.cgen.stmt("if (!(%s)) { %s(\"%s (Error: %s)\"); }" %
                       (compareExpr, self.onFailCompareVar,
                        self.exprAccessorValueLhs(vulkanType), errMsg))

    def onCheck(self, vulkanType):
        """Begin a guarded comparison of an optional field.

        Reports a mismatch if exactly one side is null. Unless the pointee
        type is ``void``, it then opens an ``if`` that restricts the
        member comparison that follows to the case where both sides are
        non-null. The matching ``endCheck`` closes that ``if``.
        """
        accessLhs = self.exprAccessorLhs(vulkanType)
        accessRhs = self.exprAccessorRhs(vulkanType)

        self.compareWithConsequence(
            self.makeNullMatchExpr(accessLhs, accessRhs),
            vulkanType,
            "Mismatch in optional field")

        # void pointees are not compared further, so no guard is opened.
        if vulkanType.typeName == "void":
            return

        self.cgen.beginIf("%s && %s" % (accessLhs, accessRhs))
        # Record the open guard only after it has actually been emitted.
        self.checked = True

    def endCheck(self, vulkanType):
        """Close the ``if`` opened by ``onCheck``, if one was opened."""
        if vulkanType.typeName == "void":
            return

        if self.checked:
            self.cgen.endIf()
            self.checked = False

    def onCompoundType(self, vulkanType):
        """Compare a struct-typed member by calling its comparison function.

        For pointer members the call is guarded by a both-non-null check.
        When the member has a length expression, lengths are compared and
        the comparison function is called for each element in a loop.
        """
        accessLhs = self.exprAccessorLhs(vulkanType)
        accessRhs = self.exprAccessorRhs(vulkanType)

        lenAccessLhs = self.lenAccessorLhs(vulkanType)
        lenAccessRhs = self.lenAccessorRhs(vulkanType)

        lenAccessGuardLhs = self.lenAccessGuardLhs(vulkanType)

        needNullCheck = vulkanType.pointerIndirectionLevels > 0

        if needNullCheck:
            self.cgen.beginIf(
                self.makeBothNotNullExpr(accessLhs, accessRhs))

        if lenAccessLhs is not None:
            equalLenExpr = self.makeEqualExpr(lenAccessLhs, lenAccessRhs)

            self.compareWithConsequence(
                equalLenExpr, vulkanType, "Lengths not equal")

            loopVar = "i"
            accessLhs = "%s + %s" % (accessLhs, loopVar)
            accessRhs = "%s + %s" % (accessRhs, loopVar)
            forInit = "uint32_t %s = 0" % loopVar
            forCond = "%s < (uint32_t)%s" % (loopVar, lenAccessLhs)
            forIncr = "++%s" % loopVar

            # Only iterate when lengths agree, so the loop bound taken
            # from lhs cannot run past the end of rhs.
            if needNullCheck:
                self.cgen.beginIf(equalLenExpr)

            if lenAccessGuardLhs is not None:
                self.cgen.beginIf(lenAccessGuardLhs)

            self.cgen.beginFor(forInit, forCond, forIncr)

        self.cgen.funcCall(None, self.prefix + vulkanType.typeName,
                           [accessLhs, accessRhs, self.onFailCompareVar])

        if lenAccessLhs is not None:
            self.cgen.endFor()
            if lenAccessGuardLhs is not None:
                self.cgen.endIf()
            if needNullCheck:
                self.cgen.endIf()

        if needNullCheck:
            self.cgen.endIf()

    def onString(self, vulkanType):
        """Compare a C-string member: nullness first, then with strcmp."""
        accessLhs = self.exprAccessorLhs(vulkanType)
        accessRhs = self.exprAccessorRhs(vulkanType)

        self.compareWithConsequence(
            self.makeNullMatchExpr(accessLhs, accessRhs),
            vulkanType,
            "Mismatch in string pointer nullness")

        self.cgen.beginIf(self.makeBothNotNullExpr(accessLhs, accessRhs))

        self.compareWithConsequence(
            self.makeEqualStringExpr(accessLhs, accessRhs),
            vulkanType, "Unequal strings")

        self.cgen.endIf()

    def onStringArray(self, vulkanType):
        """Compare an array of C strings.

        Checks pointer nullness and array lengths, then, when lengths
        match and both arrays are non-null, compares each element with
        strcmp.
        """
        accessLhs = self.exprAccessorLhs(vulkanType)
        accessRhs = self.exprAccessorRhs(vulkanType)

        lenAccessLhs = self.lenAccessorLhs(vulkanType)
        lenAccessRhs = self.lenAccessorRhs(vulkanType)

        lenAccessGuardLhs = self.lenAccessGuardLhs(vulkanType)

        bothNotNullExpr = self.makeBothNotNullExpr(accessLhs, accessRhs)

        self.compareWithConsequence(
            self.makeNullMatchExpr(accessLhs, accessRhs),
            vulkanType,
            "Mismatch in string array pointer nullness")

        equalLenExpr = self.makeEqualExpr(lenAccessLhs, lenAccessRhs)

        # Emitted once; the original generated this check twice.
        self.compareWithConsequence(
            equalLenExpr,
            vulkanType, "Lengths not equal in string array")

        self.cgen.beginIf("%s && %s" % (equalLenExpr, bothNotNullExpr))

        loopVar = "i"
        accessLhs = "*(%s + %s)" % (accessLhs, loopVar)
        accessRhs = "*(%s + %s)" % (accessRhs, loopVar)
        forInit = "uint32_t %s = 0" % loopVar
        forCond = "%s < (uint32_t)%s" % (loopVar, lenAccessLhs)
        forIncr = "++%s" % loopVar

        if lenAccessGuardLhs is not None:
            self.cgen.beginIf(lenAccessGuardLhs)

        self.cgen.beginFor(forInit, forCond, forIncr)

        self.compareWithConsequence(
            self.makeEqualStringExpr(accessLhs, accessRhs),
            vulkanType, "Unequal string in string array")

        self.cgen.endFor()

        if lenAccessGuardLhs is not None:
            self.cgen.endIf()

        self.cgen.endIf()

    def onStaticArr(self, vulkanType):
        """Compare a fixed-size array member with a single memcmp."""
        accessLhs = self.exprAccessorLhs(vulkanType)
        accessRhs = self.exprAccessorRhs(vulkanType)

        lenAccessLhs = self.lenAccessorLhs(vulkanType)

        finalLenExpr = "%s * %s" % (lenAccessLhs,
                                    self.cgen.sizeofExpr(vulkanType))

        self.compareWithConsequence(
            self.makeEqualBufExpr(accessLhs, accessRhs, finalLenExpr),
            vulkanType, "Unequal static array")

    def onStructExtension(self, vulkanType):
        """Compare struct-extension pointers via the extension comparator.

        Reports a mismatch when exactly one side is null, and calls the
        ``<prefix>extension_struct`` function only when both are non-null,
        so a NULL rhs is never passed to it.
        """
        accessLhs = self.exprAccessorLhs(vulkanType)
        accessRhs = self.exprAccessorRhs(vulkanType)

        self.compareWithConsequence(
            self.makeNullMatchExpr(accessLhs, accessRhs),
            vulkanType,
            "Mismatch in extension struct pointer nullness")

        self.cgen.beginIf(self.makeBothNotNullExpr(accessLhs, accessRhs))
        self.cgen.funcCall(None, self.prefix + "extension_struct",
                           [accessLhs, accessRhs, self.onFailCompareVar])
        self.cgen.endIf()

    def onPointer(self, vulkanType):
        """Compare the pointee(s) of a non-struct pointer with memcmp.

        ``void`` pointers are skipped. With a length expression, lengths
        are compared and ``length * sizeof(element)`` bytes are compared;
        otherwise a single element is compared.
        """
        accessLhs = self.exprAccessorLhs(vulkanType)
        accessRhs = self.exprAccessorRhs(vulkanType)

        if vulkanType.typeName == "void":
            return

        lenAccessLhs = self.lenAccessorLhs(vulkanType)
        lenAccessRhs = self.lenAccessorRhs(vulkanType)

        elemSizeExpr = self.cgen.sizeofExpr(vulkanType.getForValueAccess())

        if lenAccessLhs is not None:
            self.compareWithConsequence(
                self.makeEqualExpr(lenAccessLhs, lenAccessRhs),
                vulkanType, "Lengths not equal")
            finalLenExpr = "%s * %s" % (lenAccessLhs, elemSizeExpr)
        else:
            finalLenExpr = elemSizeExpr

        self.compareWithConsequence(
            self.makeEqualBufExpr(accessLhs, accessRhs, finalLenExpr),
            vulkanType, "Unequal dyn array")

    def onValue(self, vulkanType):
        """Compare a plain value member with ``==``."""
        accessLhs = self.exprAccessorValueLhs(vulkanType)
        accessRhs = self.exprAccessorValueRhs(vulkanType)
        self.compareWithConsequence(
            self.makeEqualExpr(accessLhs, accessRhs), vulkanType,
            "Value not equal")
309
310
class VulkanTesting(VulkanWrapperGenerator):
    """Generate equality-check functions for Vulkan structs and unions.

    For every non-aliased struct/union, a comparison function is declared
    in the module header and implemented by running VulkanEqualityCodegen
    over each member. Aliased structs/unions get a function alias instead.
    One comparison function for struct extensions is declared in onBegin
    and defined in onEnd.
    """

    def __init__(self, module, typeInfo):
        VulkanWrapperGenerator.__init__(self, module, typeInfo)

        self.codegen = CodeGen()

        # Shared member-comparison emitter; its cgen is rebound to the
        # function-body generator for each struct before iterating members.
        self.equalityCodegen = VulkanEqualityCodegen(
            None,
            EQUALITY_VAR_NAMES,
            EQUALITY_ON_FAIL_VAR,
            API_PREFIX_EQUALITY,
        )

        # Type names onGenType skips; nothing in this class adds entries.
        self.knownDefs = {}

        extensionParams = [
            STRUCT_EXTENSION_PARAM,
            STRUCT_EXTENSION_PARAM2,
            EQUALITY_ON_FAIL_VAR_TYPE,
        ]
        self.extensionTestingPrototype = VulkanAPI(
            API_PREFIX_EQUALITY + "extension_struct",
            EQUALITY_RET_TYPE,
            extensionParams,
        )

    def onBegin(self):
        VulkanWrapperGenerator.onBegin(self)
        extensionDecl = self.codegen.makeFuncDecl(
            self.extensionTestingPrototype)
        self.module.appendImpl(extensionDecl)

    def onGenType(self, typeXml, name, alias):
        VulkanWrapperGenerator.onGenType(self, typeXml, name, alias)

        if name in self.knownDefs:
            return

        if self.typeInfo.categoryOf(name) not in ("struct", "union"):
            return

        compareFuncName = API_PREFIX_EQUALITY + name

        if alias:
            # Aliased type: emit a function alias rather than a new body.
            self.module.appendHeader(
                self.codegen.makeFuncAlias(compareFuncName,
                                           API_PREFIX_EQUALITY + alias))
            return

        structInfo = self.typeInfo.structs[name]

        compareParams = [
            makeVulkanTypeSimple(True, name, 1, varName)
            for varName in EQUALITY_VAR_NAMES
        ]
        compareParams.append(EQUALITY_ON_FAIL_VAR_TYPE)

        comparePrototype = VulkanAPI(
            compareFuncName, EQUALITY_RET_TYPE, compareParams)

        def emitMemberComparisons(cgen):
            # Point the shared emitter at this function body, then emit
            # one comparison per struct member.
            self.equalityCodegen.cgen = cgen
            for member in structInfo.members:
                iterateVulkanType(self.typeInfo, member,
                                  self.equalityCodegen)

        self.module.appendHeader(self.codegen.makeFuncDecl(comparePrototype))
        self.module.appendImpl(
            self.codegen.makeFuncImpl(comparePrototype,
                                      emitMemberComparisons))

    def onGenCmd(self, cmdinfo, name, alias):
        VulkanWrapperGenerator.onGenCmd(self, cmdinfo, name, alias)

    def onEnd(self):
        VulkanWrapperGenerator.onEnd(self)

        def emitExtensionCompare(ext, castedAccess, cgen):
            # Compare one known extension type: lhs arrives already cast,
            # rhs is reinterpret-cast to the same extension type.
            rhsCast = cgen.makeReinterpretCast(
                STRUCT_EXTENSION_PARAM2.paramName, ext.name)
            cgen.funcCall(None, API_PREFIX_EQUALITY + ext.name,
                          [castedAccess, rhsCast, EQUALITY_ON_FAIL_VAR])

        def emitExtensionDispatch(cgen):
            self.emitForEachStructExtension(
                cgen,
                EQUALITY_RET_TYPE,
                STRUCT_EXTENSION_PARAM,
                emitExtensionCompare)

        self.module.appendImpl(
            self.codegen.makeFuncImpl(self.extensionTestingPrototype,
                                      emitExtensionDispatch))
400