# Copyright (c) 2018 The Android Open Source Project # Copyright (c) 2018 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from copy import copy from .common.codegen import CodeGen from .common.vulkantypes import \ VulkanAPI, makeVulkanTypeSimple, iterateVulkanType, VulkanTypeIterator from .wrapperdefs import VulkanWrapperGenerator from .wrapperdefs import EQUALITY_VAR_NAMES from .wrapperdefs import EQUALITY_ON_FAIL_VAR from .wrapperdefs import EQUALITY_ON_FAIL_VAR_TYPE from .wrapperdefs import EQUALITY_RET_TYPE from .wrapperdefs import API_PREFIX_EQUALITY from .wrapperdefs import STRUCT_EXTENSION_PARAM, STRUCT_EXTENSION_PARAM2 class VulkanEqualityCodegen(VulkanTypeIterator): def __init__(self, cgen, inputVars, onFailCompareVar, prefix): self.cgen = cgen self.inputVars = inputVars self.onFailCompareVar = onFailCompareVar self.prefix = prefix def makeAccess(varName, asPtr = True): return lambda t: self.cgen.generalAccess(t, parentVarName = varName, asPtr = asPtr) def makeLengthAccess(varName): return lambda t: self.cgen.generalLengthAccess(t, parentVarName = varName) def makeLengthAccessGuard(varName): return lambda t: self.cgen.generalLengthAccessGuard(t, parentVarName=varName) self.exprAccessorLhs = makeAccess(self.inputVars[0]) self.exprAccessorRhs = makeAccess(self.inputVars[1]) self.exprAccessorValueLhs = makeAccess(self.inputVars[0], asPtr = False) self.exprAccessorValueRhs = makeAccess(self.inputVars[1], asPtr = False) self.lenAccessorLhs = makeLengthAccess(self.inputVars[0]) self.lenAccessorRhs = makeLengthAccess(self.inputVars[1]) self.lenAccessGuardLhs = makeLengthAccessGuard(self.inputVars[0]) self.lenAccessGuardRhs = makeLengthAccessGuard(self.inputVars[1]) self.checked = False def getTypeForCompare(self, vulkanType): res = copy(vulkanType) if not vulkanType.accessibleAsPointer(): res = res.getForAddressAccess() if vulkanType.staticArrExpr: res = res.getForAddressAccess() return res def makeCastExpr(self, vulkanType): return "(%s)" % ( self.cgen.makeCTypeDecl(vulkanType, useParamName=False)) def makeEqualExpr(self, lhs, rhs): return "(%s) == (%s)" % (lhs, rhs) def makeEqualBufExpr(self, lhs, rhs, size): return "(memcmp(%s, %s, %s) == 0)" % (lhs, rhs, size) def makeEqualStringExpr(self, lhs, rhs): return "(strcmp(%s, %s) == 0)" % (lhs, rhs) def makeBothNotNullExpr(self, lhs, rhs): return "(%s) && (%s)" % (lhs, rhs) def makeBothNullExpr(self, lhs, rhs): return "!(%s) && !(%s)" % (lhs, rhs) def compareWithConsequence(self, compareExpr, vulkanType, errMsg=""): self.cgen.stmt("if (!(%s)) { %s(\"%s (Error: %s)\"); }" % (compareExpr, self.onFailCompareVar, self.exprAccessorValueLhs(vulkanType), errMsg)) def onCheck(self, vulkanType): self.checked = True accessLhs = self.exprAccessorLhs(vulkanType) accessRhs = self.exprAccessorRhs(vulkanType) bothNull = self.makeBothNullExpr(accessLhs, accessRhs) bothNotNull = self.makeBothNotNullExpr(accessLhs, accessRhs) nullMatchExpr = "(%s) || (%s)" % (bothNull, bothNotNull) self.compareWithConsequence( \ nullMatchExpr, vulkanType, "Mismatch in optional field") skipStreamInternal = vulkanType.typeName == "void" if skipStreamInternal: return self.cgen.beginIf("%s && %s" % (accessLhs, accessRhs)) def endCheck(self, vulkanType): skipStreamInternal = vulkanType.typeName == "void" if skipStreamInternal: return if self.checked: self.cgen.endIf() self.checked = False def onCompoundType(self, vulkanType): accessLhs = self.exprAccessorLhs(vulkanType) accessRhs = self.exprAccessorRhs(vulkanType) lenAccessLhs = self.lenAccessorLhs(vulkanType) lenAccessRhs = self.lenAccessorRhs(vulkanType) lenAccessGuardLhs = self.lenAccessGuardLhs(vulkanType) lenAccessGuardRhs = self.lenAccessGuardRhs(vulkanType) needNullCheck = vulkanType.pointerIndirectionLevels > 0 if needNullCheck: bothNotNullExpr = self.makeBothNotNullExpr(accessLhs, accessRhs) self.cgen.beginIf(bothNotNullExpr) if lenAccessLhs is not None: equalLenExpr = self.makeEqualExpr(lenAccessLhs, lenAccessRhs) self.compareWithConsequence( \ equalLenExpr, vulkanType, "Lengths not equal") loopVar = "i" accessLhs = "%s + %s" % (accessLhs, loopVar) accessRhs = "%s + %s" % (accessRhs, loopVar) forInit = "uint32_t %s = 0" % loopVar forCond = "%s < (uint32_t)%s" % (loopVar, lenAccessLhs) forIncr = "++%s" % loopVar if needNullCheck: self.cgen.beginIf(equalLenExpr) if lenAccessGuardLhs is not None: self.cgen.beginIf(lenAccessGuardLhs) self.cgen.beginFor(forInit, forCond, forIncr) self.cgen.funcCall(None, self.prefix + vulkanType.typeName, [accessLhs, accessRhs, self.onFailCompareVar]) if lenAccessLhs is not None: self.cgen.endFor() if lenAccessGuardLhs is not None: self.cgen.endIf() if needNullCheck: self.cgen.endIf() if needNullCheck: self.cgen.endIf() def onString(self, vulkanType): accessLhs = self.exprAccessorLhs(vulkanType) accessRhs = self.exprAccessorRhs(vulkanType) bothNullExpr = self.makeBothNullExpr(accessLhs, accessRhs) bothNotNullExpr = self.makeBothNotNullExpr(accessLhs, accessRhs) nullMatchExpr = "(%s) || (%s)" % (bothNullExpr, bothNotNullExpr) self.compareWithConsequence( \ nullMatchExpr, vulkanType, "Mismatch in string pointer nullness") self.cgen.beginIf(bothNotNullExpr) self.compareWithConsequence( self.makeEqualStringExpr(accessLhs, accessRhs), vulkanType, "Unequal strings") self.cgen.endIf() def onStringArray(self, vulkanType): accessLhs = self.exprAccessorLhs(vulkanType) accessRhs = self.exprAccessorRhs(vulkanType) lenAccessLhs = self.lenAccessorLhs(vulkanType) lenAccessRhs = self.lenAccessorRhs(vulkanType) lenAccessGuardLhs = self.lenAccessGuardLhs(vulkanType) lenAccessGuardRhs = self.lenAccessGuardRhs(vulkanType) bothNullExpr = self.makeBothNullExpr(accessLhs, accessRhs) bothNotNullExpr = self.makeBothNotNullExpr(accessLhs, accessRhs) nullMatchExpr = "(%s) || (%s)" % (bothNullExpr, bothNotNullExpr) self.compareWithConsequence( \ nullMatchExpr, vulkanType, "Mismatch in string array pointer nullness") equalLenExpr = self.makeEqualExpr(lenAccessLhs, lenAccessRhs) self.compareWithConsequence( \ equalLenExpr, vulkanType, "Lengths not equal in string array") self.compareWithConsequence( \ equalLenExpr, vulkanType, "Lengths not equal in string array") self.cgen.beginIf("%s && %s" % (equalLenExpr, bothNotNullExpr)) loopVar = "i" accessLhs = "*(%s + %s)" % (accessLhs, loopVar) accessRhs = "*(%s + %s)" % (accessRhs, loopVar) forInit = "uint32_t %s = 0" % loopVar forCond = "%s < (uint32_t)%s" % (loopVar, lenAccessLhs) forIncr = "++%s" % loopVar if lenAccessGuardLhs is not None: self.cgen.beginIf(lenAccessGuardLhs) self.cgen.beginFor(forInit, forCond, forIncr) self.compareWithConsequence( self.makeEqualStringExpr(accessLhs, accessRhs), vulkanType, "Unequal string in string array") self.cgen.endFor() if lenAccessGuardLhs is not None: self.cgen.endIf() self.cgen.endIf() def onStaticArr(self, vulkanType): accessLhs = self.exprAccessorLhs(vulkanType) accessRhs = self.exprAccessorRhs(vulkanType) lenAccessLhs = self.lenAccessorLhs(vulkanType) finalLenExpr = "%s * %s" % (lenAccessLhs, self.cgen.sizeofExpr(vulkanType)) self.compareWithConsequence( self.makeEqualBufExpr(accessLhs, accessRhs, finalLenExpr), vulkanType, "Unequal static array") def onStructExtension(self, vulkanType): lhs = self.exprAccessorLhs(vulkanType) rhs = self.exprAccessorRhs(vulkanType) self.cgen.beginIf(lhs) self.cgen.funcCall(None, self.prefix + "extension_struct", [lhs, rhs, self.onFailCompareVar]) self.cgen.endIf() def onPointer(self, vulkanType): accessLhs = self.exprAccessorLhs(vulkanType) accessRhs = self.exprAccessorRhs(vulkanType) skipStreamInternal = vulkanType.typeName == "void" if skipStreamInternal: return lenAccessLhs = self.lenAccessorLhs(vulkanType) lenAccessRhs = self.lenAccessorRhs(vulkanType) if lenAccessLhs is not None: self.compareWithConsequence( \ self.makeEqualExpr(lenAccessLhs, lenAccessRhs), vulkanType, "Lengths not equal") finalLenExpr = "%s * %s" % (lenAccessLhs, self.cgen.sizeofExpr( vulkanType.getForValueAccess())) else: finalLenExpr = self.cgen.sizeofExpr(vulkanType.getForValueAccess()) self.compareWithConsequence( self.makeEqualBufExpr(accessLhs, accessRhs, finalLenExpr), vulkanType, "Unequal dyn array") def onValue(self, vulkanType): accessLhs = self.exprAccessorValueLhs(vulkanType) accessRhs = self.exprAccessorValueRhs(vulkanType) self.compareWithConsequence( self.makeEqualExpr(accessLhs, accessRhs), vulkanType, "Value not equal") class VulkanTesting(VulkanWrapperGenerator): def __init__(self, module, typeInfo): VulkanWrapperGenerator.__init__(self, module, typeInfo) self.codegen = CodeGen() self.equalityCodegen = \ VulkanEqualityCodegen( None, EQUALITY_VAR_NAMES, EQUALITY_ON_FAIL_VAR, API_PREFIX_EQUALITY) self.knownDefs = {} self.extensionTestingPrototype = \ VulkanAPI(API_PREFIX_EQUALITY + "extension_struct", EQUALITY_RET_TYPE, [STRUCT_EXTENSION_PARAM, STRUCT_EXTENSION_PARAM2, EQUALITY_ON_FAIL_VAR_TYPE]) def onBegin(self,): VulkanWrapperGenerator.onBegin(self) self.module.appendImpl(self.codegen.makeFuncDecl( self.extensionTestingPrototype)) def onGenType(self, typeXml, name, alias): VulkanWrapperGenerator.onGenType(self, typeXml, name, alias) if name in self.knownDefs: return category = self.typeInfo.categoryOf(name) if category in ["struct", "union"] and alias: self.module.appendHeader( self.codegen.makeFuncAlias(API_PREFIX_EQUALITY + name, API_PREFIX_EQUALITY + alias)) if category in ["struct", "union"] and not alias: structInfo = self.typeInfo.structs[name] typeFromName = \ lambda varname: makeVulkanTypeSimple(True, name, 1, varname) compareParams = \ list(map(typeFromName, EQUALITY_VAR_NAMES)) + \ [EQUALITY_ON_FAIL_VAR_TYPE] comparePrototype = \ VulkanAPI(API_PREFIX_EQUALITY + name, EQUALITY_RET_TYPE, compareParams) def structCompareDef(cgen): self.equalityCodegen.cgen = cgen for member in structInfo.members: iterateVulkanType(self.typeInfo, member, self.equalityCodegen) self.module.appendHeader( self.codegen.makeFuncDecl(comparePrototype)) self.module.appendImpl( self.codegen.makeFuncImpl(comparePrototype, structCompareDef)) def onGenCmd(self, cmdinfo, name, alias): VulkanWrapperGenerator.onGenCmd(self, cmdinfo, name, alias) def onEnd(self,): VulkanWrapperGenerator.onEnd(self) def forEachExtensionCompare(ext, castedAccess, cgen): cgen.funcCall(None, API_PREFIX_EQUALITY + ext.name, [castedAccess, cgen.makeReinterpretCast( STRUCT_EXTENSION_PARAM2.paramName, ext.name), EQUALITY_ON_FAIL_VAR]) self.module.appendImpl( self.codegen.makeFuncImpl( self.extensionTestingPrototype, lambda cgen: self.emitForEachStructExtension( cgen, EQUALITY_RET_TYPE, STRUCT_EXTENSION_PARAM, forEachExtensionCompare)))