# Copyright (c) 2018 The Android Open Source Project
# Copyright (c) 2018 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Generate C equality-checking functions for Vulkan structs and unions.

For every non-aliased struct/union this emits a C function named
``API_PREFIX_EQUALITY + <type name>`` that compares two instances member by
member. Each detected mismatch results in a call to the on-fail callback
(``EQUALITY_ON_FAIL_VAR``) with a message string; nothing else is returned.
"""

from copy import copy

from .common.codegen import CodeGen
from .common.vulkantypes import \
        VulkanAPI, makeVulkanTypeSimple, iterateVulkanType, VulkanTypeIterator

from .wrapperdefs import VulkanWrapperGenerator
from .wrapperdefs import EQUALITY_VAR_NAMES
from .wrapperdefs import EQUALITY_ON_FAIL_VAR
from .wrapperdefs import EQUALITY_ON_FAIL_VAR_TYPE
from .wrapperdefs import EQUALITY_RET_TYPE
from .wrapperdefs import API_PREFIX_EQUALITY
from .wrapperdefs import STRUCT_EXTENSION_PARAM, STRUCT_EXTENSION_PARAM2


class VulkanEqualityCodegen(VulkanTypeIterator):
    """Type-iterator hooks that emit C code comparing one member at a time.

    Each ``on*`` hook receives a ``vulkanType`` describing a single member and
    writes statements into ``self.cgen`` that compare that member between the
    two input variables. Mismatches are reported by emitting a call to
    ``onFailCompareVar`` with a message string.
    """

    def __init__(self, cgen, inputVars, onFailCompareVar, prefix):
        """Set up accessors for the two compared variables.

        Args:
            cgen: code generator to emit into. May be None at construction
                and assigned later (VulkanTesting assigns it per struct).
            inputVars: two C variable names, ``[lhs, rhs]``.
            onFailCompareVar: name of the C callable invoked on mismatch.
            prefix: prepended to a type name to form the name of that type's
                generated comparison function.
        """
        self.cgen = cgen
        self.inputVars = inputVars
        self.onFailCompareVar = onFailCompareVar
        self.prefix = prefix

        # The accessor factories defer to self.cgen at call time, so they keep
        # working after self.cgen is reassigned.
        def makeAccess(varName, asPtr=True):
            return lambda t: self.cgen.generalAccess(t, parentVarName=varName, asPtr=asPtr)

        def makeLengthAccess(varName):
            return lambda t: self.cgen.generalLengthAccess(t, parentVarName=varName)

        def makeLengthAccessGuard(varName):
            return lambda t: self.cgen.generalLengthAccessGuard(t, parentVarName=varName)

        self.exprAccessorLhs = makeAccess(self.inputVars[0])
        self.exprAccessorRhs = makeAccess(self.inputVars[1])

        self.exprAccessorValueLhs = makeAccess(self.inputVars[0], asPtr=False)
        self.exprAccessorValueRhs = makeAccess(self.inputVars[1], asPtr=False)

        self.lenAccessorLhs = makeLengthAccess(self.inputVars[0])
        self.lenAccessorRhs = makeLengthAccess(self.inputVars[1])

        self.lenAccessGuardLhs = makeLengthAccessGuard(self.inputVars[0])
        self.lenAccessGuardRhs = makeLengthAccessGuard(self.inputVars[1])

        # True while onCheck() has opened an ``if`` scope that endCheck() must close.
        self.checked = False

    def getTypeForCompare(self, vulkanType):
        """Return a copy of ``vulkanType`` converted for address access when it
        is not pointer-accessible or is a static array.

        NOTE(review): not referenced elsewhere in this module.
        """
        res = copy(vulkanType)

        if not vulkanType.accessibleAsPointer():
            res = res.getForAddressAccess()

        if vulkanType.staticArrExpr:
            res = res.getForAddressAccess()

        return res

    def makeCastExpr(self, vulkanType):
        """Return a C cast prefix ``(<type>)`` for ``vulkanType``."""
        return "(%s)" % (
            self.cgen.makeCTypeDecl(vulkanType, useParamName=False))

    def makeEqualExpr(self, lhs, rhs):
        """Return the C expression ``(lhs) == (rhs)``."""
        return "(%s) == (%s)" % (lhs, rhs)

    def makeEqualBufExpr(self, lhs, rhs, size):
        """Return a C expression that is true when ``size`` bytes match."""
        return "(memcmp(%s, %s, %s) == 0)" % (lhs, rhs, size)

    def makeEqualStringExpr(self, lhs, rhs):
        """Return a C expression that is true when two C strings match."""
        return "(strcmp(%s, %s) == 0)" % (lhs, rhs)

    def makeBothNotNullExpr(self, lhs, rhs):
        """Return a C expression that is true when both operands are non-null."""
        return "(%s) && (%s)" % (lhs, rhs)

    def makeBothNullExpr(self, lhs, rhs):
        """Return a C expression that is true when both operands are null."""
        return "!(%s) && !(%s)" % (lhs, rhs)

    def _makeNullMatchExpr(self, lhs, rhs):
        """Return a C expression true when both operands are null or both non-null."""
        return "(%s) || (%s)" % (self.makeBothNullExpr(lhs, rhs),
                                 self.makeBothNotNullExpr(lhs, rhs))

    def compareWithConsequence(self, compareExpr, vulkanType, errMsg=""):
        """Emit ``if (!(compareExpr)) { onFail("<member> (Error: errMsg)"); }``.

        The message embeds the lhs value-access expression text of the member.
        """
        self.cgen.stmt("if (!(%s)) { %s(\"%s (Error: %s)\"); }" %
                       (compareExpr, self.onFailCompareVar,
                        self.exprAccessorValueLhs(vulkanType), errMsg))

    def onCheck(self, vulkanType):
        """Begin comparing an optional pointer member.

        Always emits a nullness-match check. For non-void pointees it also
        opens an ``if (lhs && rhs)`` scope that endCheck() closes.
        """
        accessLhs = self.exprAccessorLhs(vulkanType)
        accessRhs = self.exprAccessorRhs(vulkanType)

        self.compareWithConsequence(
            self._makeNullMatchExpr(accessLhs, accessRhs),
            vulkanType,
            "Mismatch in optional field")

        # A void pointee has no comparable contents, so no scope is opened.
        if vulkanType.typeName == "void":
            return

        self.cgen.beginIf("%s && %s" % (accessLhs, accessRhs))
        # Set only once a scope is actually open, so endCheck() never closes
        # a scope that was not opened.
        self.checked = True

    def endCheck(self, vulkanType):
        """Close the scope opened by the preceding onCheck(), if one was opened."""
        if self.checked:
            self.cgen.endIf()
            self.checked = False

    def onCompoundType(self, vulkanType):
        """Compare a struct-typed member via its generated comparison function.

        Pointer members are compared only when both sides are non-null. Members
        with a length expression get a length check first. The element compare
        is then called in a loop over the lhs length, gated on equal lengths
        for pointer members.
        """
        accessLhs = self.exprAccessorLhs(vulkanType)
        accessRhs = self.exprAccessorRhs(vulkanType)

        lenAccessLhs = self.lenAccessorLhs(vulkanType)
        lenAccessRhs = self.lenAccessorRhs(vulkanType)

        lenAccessGuardLhs = self.lenAccessGuardLhs(vulkanType)

        needNullCheck = vulkanType.pointerIndirectionLevels > 0

        if needNullCheck:
            self.cgen.beginIf(self.makeBothNotNullExpr(accessLhs, accessRhs))

        if lenAccessLhs is not None:
            equalLenExpr = self.makeEqualExpr(lenAccessLhs, lenAccessRhs)

            self.compareWithConsequence(
                equalLenExpr,
                vulkanType, "Lengths not equal")

            loopVar = "i"
            accessLhs = "%s + %s" % (accessLhs, loopVar)
            accessRhs = "%s + %s" % (accessRhs, loopVar)
            forInit = "uint32_t %s = 0" % loopVar
            forCond = "%s < (uint32_t)%s" % (loopVar, lenAccessLhs)
            forIncr = "++%s" % loopVar

            if needNullCheck:
                self.cgen.beginIf(equalLenExpr)

            if lenAccessGuardLhs is not None:
                self.cgen.beginIf(lenAccessGuardLhs)

            self.cgen.beginFor(forInit, forCond, forIncr)

        self.cgen.funcCall(None, self.prefix + vulkanType.typeName,
                           [accessLhs, accessRhs, self.onFailCompareVar])

        if lenAccessLhs is not None:
            self.cgen.endFor()
            if lenAccessGuardLhs is not None:
                self.cgen.endIf()
            if needNullCheck:
                self.cgen.endIf()

        if needNullCheck:
            self.cgen.endIf()

    def onString(self, vulkanType):
        """Compare a C-string member: nullness first, then strcmp if both set."""
        accessLhs = self.exprAccessorLhs(vulkanType)
        accessRhs = self.exprAccessorRhs(vulkanType)

        bothNotNullExpr = self.makeBothNotNullExpr(accessLhs, accessRhs)

        self.compareWithConsequence(
            self._makeNullMatchExpr(accessLhs, accessRhs),
            vulkanType,
            "Mismatch in string pointer nullness")

        self.cgen.beginIf(bothNotNullExpr)

        self.compareWithConsequence(
            self.makeEqualStringExpr(accessLhs, accessRhs),
            vulkanType, "Unequal strings")

        self.cgen.endIf()

    def onStringArray(self, vulkanType):
        """Compare an array of C strings.

        Checks nullness and length once each. Element-wise strcmp runs only
        when lengths are equal and both arrays are non-null.
        """
        accessLhs = self.exprAccessorLhs(vulkanType)
        accessRhs = self.exprAccessorRhs(vulkanType)

        lenAccessLhs = self.lenAccessorLhs(vulkanType)
        lenAccessRhs = self.lenAccessorRhs(vulkanType)

        lenAccessGuardLhs = self.lenAccessGuardLhs(vulkanType)

        bothNotNullExpr = self.makeBothNotNullExpr(accessLhs, accessRhs)

        self.compareWithConsequence(
            self._makeNullMatchExpr(accessLhs, accessRhs),
            vulkanType,
            "Mismatch in string array pointer nullness")

        equalLenExpr = self.makeEqualExpr(lenAccessLhs, lenAccessRhs)

        # Emitted once: a duplicate here made one mismatch report twice.
        self.compareWithConsequence(
            equalLenExpr,
            vulkanType, "Lengths not equal in string array")

        self.cgen.beginIf("%s && %s" % (equalLenExpr, bothNotNullExpr))

        loopVar = "i"
        accessLhs = "*(%s + %s)" % (accessLhs, loopVar)
        accessRhs = "*(%s + %s)" % (accessRhs, loopVar)
        forInit = "uint32_t %s = 0" % loopVar
        forCond = "%s < (uint32_t)%s" % (loopVar, lenAccessLhs)
        forIncr = "++%s" % loopVar

        if lenAccessGuardLhs is not None:
            self.cgen.beginIf(lenAccessGuardLhs)

        self.cgen.beginFor(forInit, forCond, forIncr)

        self.compareWithConsequence(
            self.makeEqualStringExpr(accessLhs, accessRhs),
            vulkanType, "Unequal string in string array")

        self.cgen.endFor()

        if lenAccessGuardLhs is not None:
            self.cgen.endIf()

        self.cgen.endIf()

    def onStaticArr(self, vulkanType):
        """Compare a fixed-size array member with memcmp over its full length."""
        accessLhs = self.exprAccessorLhs(vulkanType)
        accessRhs = self.exprAccessorRhs(vulkanType)

        lenAccessLhs = self.lenAccessorLhs(vulkanType)

        finalLenExpr = "%s * %s" % (lenAccessLhs,
                                    self.cgen.sizeofExpr(vulkanType))

        self.compareWithConsequence(
            self.makeEqualBufExpr(accessLhs, accessRhs, finalLenExpr),
            vulkanType, "Unequal static array")

    def onStructExtension(self, vulkanType):
        """Compare struct-extension members through ``<prefix>extension_struct``.

        The dispatcher forwards rhs to a per-type compare function that
        dereferences it. It is therefore only called when both sides are
        non-null, and a nullness mismatch is reported.
        """
        lhs = self.exprAccessorLhs(vulkanType)
        rhs = self.exprAccessorRhs(vulkanType)

        self.compareWithConsequence(
            self._makeNullMatchExpr(lhs, rhs),
            vulkanType,
            "Mismatch in extension struct pointer nullness")

        self.cgen.beginIf(self.makeBothNotNullExpr(lhs, rhs))
        self.cgen.funcCall(None, self.prefix + "extension_struct",
                           [lhs, rhs, self.onFailCompareVar])
        self.cgen.endIf()

    def onPointer(self, vulkanType):
        """Compare a non-struct pointer member's pointee bytes with memcmp.

        With a length expression, lengths are compared first. The memcmp (over
        the lhs length) is emitted only under an equal-length guard, so it
        cannot read past a shorter rhs buffer. ``void`` pointees are skipped.
        """
        accessLhs = self.exprAccessorLhs(vulkanType)
        accessRhs = self.exprAccessorRhs(vulkanType)

        if vulkanType.typeName == "void":
            return

        lenAccessLhs = self.lenAccessorLhs(vulkanType)
        lenAccessRhs = self.lenAccessorRhs(vulkanType)

        elemSizeExpr = self.cgen.sizeofExpr(vulkanType.getForValueAccess())

        if lenAccessLhs is None:
            self.compareWithConsequence(
                self.makeEqualBufExpr(accessLhs, accessRhs, elemSizeExpr),
                vulkanType, "Unequal dyn array")
            return

        equalLenExpr = self.makeEqualExpr(lenAccessLhs, lenAccessRhs)
        self.compareWithConsequence(
            equalLenExpr,
            vulkanType, "Lengths not equal")

        finalLenExpr = "%s * %s" % (lenAccessLhs, elemSizeExpr)

        self.cgen.beginIf(equalLenExpr)
        self.compareWithConsequence(
            self.makeEqualBufExpr(accessLhs, accessRhs, finalLenExpr),
            vulkanType, "Unequal dyn array")
        self.cgen.endIf()

    def onValue(self, vulkanType):
        """Compare a plain value member with ``==``."""
        accessLhs = self.exprAccessorValueLhs(vulkanType)
        accessRhs = self.exprAccessorValueRhs(vulkanType)
        self.compareWithConsequence(
            self.makeEqualExpr(accessLhs, accessRhs), vulkanType,
            "Value not equal")


class VulkanTesting(VulkanWrapperGenerator):
    """Wrapper generator that emits equality-check functions per struct/union.

    Each non-aliased struct/union gets a compare function: its declaration is
    appended to the module header and its definition to the impl. Aliases get
    a function alias. A ``<prefix>extension_struct`` dispatcher is declared at
    begin and defined at end.
    """

    def __init__(self, module, typeInfo):
        VulkanWrapperGenerator.__init__(self, module, typeInfo)

        self.codegen = CodeGen()

        # cgen is None here; structCompareDef assigns it before each use.
        self.equalityCodegen = \
            VulkanEqualityCodegen(
                None,
                EQUALITY_VAR_NAMES,
                EQUALITY_ON_FAIL_VAR,
                API_PREFIX_EQUALITY)

        # NOTE(review): nothing in this class adds entries to knownDefs.
        self.knownDefs = {}

        self.extensionTestingPrototype = \
            VulkanAPI(API_PREFIX_EQUALITY + "extension_struct",
                      EQUALITY_RET_TYPE,
                      [STRUCT_EXTENSION_PARAM,
                       STRUCT_EXTENSION_PARAM2,
                       EQUALITY_ON_FAIL_VAR_TYPE])

    def onBegin(self):
        """Forward declare the extension dispatcher in the impl."""
        VulkanWrapperGenerator.onBegin(self)
        self.module.appendImpl(self.codegen.makeFuncDecl(
            self.extensionTestingPrototype))

    def onGenType(self, typeXml, name, alias):
        """Emit an alias or a compare function for a struct/union type."""
        VulkanWrapperGenerator.onGenType(self, typeXml, name, alias)

        if name in self.knownDefs:
            return

        category = self.typeInfo.categoryOf(name)

        if category not in ["struct", "union"]:
            return

        if alias:
            self.module.appendHeader(
                self.codegen.makeFuncAlias(API_PREFIX_EQUALITY + name,
                                           API_PREFIX_EQUALITY + alias))
            return

        structInfo = self.typeInfo.structs[name]

        # One parameter per compared variable (presumably const pointers to
        # the struct), followed by the on-fail callback parameter.
        compareParams = \
            [makeVulkanTypeSimple(True, name, 1, varname)
             for varname in EQUALITY_VAR_NAMES] + \
            [EQUALITY_ON_FAIL_VAR_TYPE]

        comparePrototype = \
            VulkanAPI(API_PREFIX_EQUALITY + name,
                      EQUALITY_RET_TYPE,
                      compareParams)

        def structCompareDef(cgen):
            # Point the shared iterator at this function's code generator.
            self.equalityCodegen.cgen = cgen
            for member in structInfo.members:
                iterateVulkanType(self.typeInfo, member,
                                  self.equalityCodegen)

        self.module.appendHeader(
            self.codegen.makeFuncDecl(comparePrototype))
        self.module.appendImpl(
            self.codegen.makeFuncImpl(comparePrototype, structCompareDef))

    def onGenCmd(self, cmdinfo, name, alias):
        VulkanWrapperGenerator.onGenCmd(self, cmdinfo, name, alias)

    def onEnd(self):
        """Define the extension dispatcher in the impl."""
        VulkanWrapperGenerator.onEnd(self)

        def forEachExtensionCompare(ext, castedAccess, cgen):
            # rhs is reinterpret-cast to the extension's type and handed to
            # that type's generated compare function.
            cgen.funcCall(None, API_PREFIX_EQUALITY + ext.name,
                          [castedAccess,
                           cgen.makeReinterpretCast(
                               STRUCT_EXTENSION_PARAM2.paramName, ext.name),
                           EQUALITY_ON_FAIL_VAR])

        self.module.appendImpl(
            self.codegen.makeFuncImpl(
                self.extensionTestingPrototype,
                lambda cgen: self.emitForEachStructExtension(
                    cgen,
                    EQUALITY_RET_TYPE,
                    STRUCT_EXTENSION_PARAM,
                    forEachExtensionCompare)))