1 /* 2 * Copyright (c) 2021-2024 Huawei Device Co., Ltd. 3 * Licensed under the Apache License, Version 2.0 (the "License"); 4 * you may not use this file except in compliance with the License. 5 * You may obtain a copy of the License at 6 * 7 * http://www.apache.org/licenses/LICENSE-2.0 8 * 9 * Unless required by applicable law or agreed to in writing, software 10 * distributed under the License is distributed on an "AS IS" BASIS, 11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 * See the License for the specific language governing permissions and 13 * limitations under the License. 14 */ 15 16 #ifndef COMPILER_TESTS_GRAPH_COMPARATOR_H 17 #define COMPILER_TESTS_GRAPH_COMPARATOR_H 18 19 #include "optimizer/ir/ir_constructor.h" 20 #include "optimizer/analysis/rpo.h" 21 22 namespace ark::compiler { 23 24 class GraphComparator { 25 public: Compare(Graph * graph1,Graph * graph2)26 bool Compare(Graph *graph1, Graph *graph2) 27 { 28 graph1->InvalidateAnalysis<Rpo>(); 29 graph2->InvalidateAnalysis<Rpo>(); 30 if (graph1->GetBlocksRPO().size() != graph2->GetBlocksRPO().size()) { 31 std::cerr << "Different number of blocks\n"; 32 return false; 33 } 34 for (auto it1 = graph1->GetBlocksRPO().begin(), it2 = graph2->GetBlocksRPO().begin(); 35 it1 != graph1->GetBlocksRPO().end(); it1++, it2++) { 36 auto it = bbMap_.insert({*it1, *it2}); 37 if (!it.second) { 38 return false; 39 } 40 } 41 return std::equal(graph1->GetBlocksRPO().begin(), graph1->GetBlocksRPO().end(), graph2->GetBlocksRPO().begin(), 42 graph2->GetBlocksRPO().end(), [this](auto bb1, auto bb2) { return Compare(bb1, bb2); }); 43 } 44 Compare(BasicBlock * block1,BasicBlock * block2)45 bool Compare(BasicBlock *block1, BasicBlock *block2) 46 { 47 if (block1->GetPredsBlocks().size() != block2->GetPredsBlocks().size()) { 48 std::cerr << "Different number of preds blocks\n"; 49 block1->Dump(&std::cerr); 50 block2->Dump(&std::cerr); 51 return false; 52 } 53 if (block1->GetSuccsBlocks().size() != block2->GetSuccsBlocks().size()) { 54 std::cerr << "Different number of succs blocks\n"; 55 block1->Dump(&std::cerr); 56 block2->Dump(&std::cerr); 57 return false; 58 } 59 auto instCmp = [this](auto inst1, auto inst2) { 60 ASSERT(inst2 != nullptr); 61 bool t = Compare(inst1, inst2); 62 if (!t) { 63 std::cerr << "Different instructions:\n"; 64 inst1->Dump(&std::cerr); 65 inst2->Dump(&std::cerr); 66 } 67 return t; 68 }; 69 return std::equal(block1->AllInsts().begin(), block1->AllInsts().end(), block2->AllInsts().begin(), 70 block2->AllInsts().end(), instCmp); 71 } 72 InstInitialCompare(Inst * inst1,Inst * inst2)73 bool InstInitialCompare(Inst *inst1, Inst *inst2) 74 { 75 if (inst1->GetOpcode() != inst2->GetOpcode() || inst1->GetType() != inst2->GetType() || 76 inst1->GetInputsCount() != inst2->GetInputsCount()) { 77 instCompareMap_.erase(inst1); 78 return false; 79 } 80 if (inst1->GetFlagsMask() != inst2->GetFlagsMask()) { 81 instCompareMap_.erase(inst1); 82 return false; 83 } 84 if (inst1->GetOpcode() != Opcode::Phi) { 85 auto inst1Begin = inst1->GetInputs().begin(); 86 auto inst1End = inst1->GetInputs().end(); 87 auto inst2Begin = inst2->GetInputs().begin(); 88 auto eqLambda = [this](Input input1, Input input2) { return Compare(input1.GetInst(), input2.GetInst()); }; 89 if (!std::equal(inst1Begin, inst1End, inst2Begin, eqLambda)) { 90 instCompareMap_.erase(inst1); 91 return false; 92 } 93 } else { 94 if (inst1->GetInputsCount() != inst2->GetInputsCount()) { 95 instCompareMap_.erase(inst1); 96 return false; 97 } 98 for (size_t index1 = 0; index1 < inst1->GetInputsCount(); index1++) { 99 auto input1 = inst1->GetInput(index1).GetInst(); 100 auto bb1 = inst1->CastToPhi()->GetPhiInputBb(index1); 101 if (bbMap_.count(bb1) == 0) { 102 instCompareMap_.erase(inst1); 103 return false; 104 } 105 auto bb2 = bbMap_.at(bb1); 106 auto input2 = inst2->CastToPhi()->GetPhiInput(bb2); 107 if (!Compare(input1, input2)) { 108 instCompareMap_.erase(inst1); 109 return false; 110 } 111 } 112 } 113 return true; 114 } 115 116 // NOLINTNEXTLINE(cppcoreguidelines-macro-usage 117 #define CAST(Opc) CastTo##Opc() 118 // NOLINTNEXTLINE(cppcoreguidelines-macro-usage 119 #define CHECK_OR_RETURN(Opc, Getter) /* CC-OFFNXT(G.PRE.02) namespace member */ \ 120 if (inst1->GetOpcode() == Opcode::Opc && inst1->CAST(Opc)->Getter() != inst2->CAST(Opc)->Getter()) { \ 121 instCompareMap_.erase(inst1); \ 122 return false; \ 123 } 124 InstPropertiesCompare(Inst * inst1,Inst * inst2)125 bool InstPropertiesCompare(Inst *inst1, Inst *inst2) 126 { 127 CHECK_OR_RETURN(Cast, GetOperandsType) 128 CHECK_OR_RETURN(Cmp, GetOperandsType) 129 130 CHECK_OR_RETURN(Compare, GetCc) 131 CHECK_OR_RETURN(Compare, GetOperandsType) 132 133 CHECK_OR_RETURN(If, GetCc) 134 CHECK_OR_RETURN(If, GetOperandsType) 135 136 CHECK_OR_RETURN(IfImm, GetCc) 137 CHECK_OR_RETURN(IfImm, GetImm) 138 CHECK_OR_RETURN(IfImm, GetOperandsType) 139 140 CHECK_OR_RETURN(Select, GetCc) 141 CHECK_OR_RETURN(Select, GetOperandsType) 142 143 CHECK_OR_RETURN(SelectImm, GetCc) 144 CHECK_OR_RETURN(SelectImm, GetImm) 145 CHECK_OR_RETURN(SelectImm, GetOperandsType) 146 147 CHECK_OR_RETURN(LoadArrayI, GetImm) 148 CHECK_OR_RETURN(LoadArrayPairI, GetImm) 149 CHECK_OR_RETURN(LoadPairPart, GetImm) 150 CHECK_OR_RETURN(StoreArrayI, GetImm) 151 CHECK_OR_RETURN(StoreArrayPairI, GetImm) 152 CHECK_OR_RETURN(LoadArrayPair, GetImm) 153 CHECK_OR_RETURN(StoreArrayPair, GetImm) 154 CHECK_OR_RETURN(BoundsCheckI, GetImm) 155 CHECK_OR_RETURN(ReturnI, GetImm) 156 CHECK_OR_RETURN(AddI, GetImm) 157 CHECK_OR_RETURN(SubI, GetImm) 158 CHECK_OR_RETURN(ShlI, GetImm) 159 CHECK_OR_RETURN(ShrI, GetImm) 160 CHECK_OR_RETURN(AShrI, GetImm) 161 CHECK_OR_RETURN(AndI, GetImm) 162 CHECK_OR_RETURN(OrI, GetImm) 163 CHECK_OR_RETURN(XorI, GetImm) 164 165 return true; 166 } 167 InstAdditionalPropertiesCompare(Inst * inst1,Inst * inst2)168 bool InstAdditionalPropertiesCompare(Inst *inst1, Inst *inst2) 169 { 170 CHECK_OR_RETURN(LoadArray, GetNeedBarrier) 171 CHECK_OR_RETURN(LoadArrayPair, GetNeedBarrier) 172 CHECK_OR_RETURN(StoreArray, GetNeedBarrier) 173 CHECK_OR_RETURN(StoreArrayPair, GetNeedBarrier) 174 CHECK_OR_RETURN(LoadArrayI, GetNeedBarrier) 175 CHECK_OR_RETURN(LoadArrayPairI, GetNeedBarrier) 176 CHECK_OR_RETURN(StoreArrayI, GetNeedBarrier) 177 CHECK_OR_RETURN(StoreArrayPairI, GetNeedBarrier) 178 CHECK_OR_RETURN(LoadStatic, GetNeedBarrier) 179 CHECK_OR_RETURN(StoreStatic, GetNeedBarrier) 180 CHECK_OR_RETURN(LoadObject, GetNeedBarrier) 181 CHECK_OR_RETURN(StoreObject, GetNeedBarrier) 182 CHECK_OR_RETURN(LoadStatic, GetVolatile) 183 CHECK_OR_RETURN(StoreStatic, GetVolatile) 184 CHECK_OR_RETURN(LoadObject, GetVolatile) 185 CHECK_OR_RETURN(StoreObject, GetVolatile) 186 CHECK_OR_RETURN(NewObject, GetNeedBarrier) 187 CHECK_OR_RETURN(NewArray, GetNeedBarrier) 188 CHECK_OR_RETURN(CheckCast, GetNeedBarrier) 189 CHECK_OR_RETURN(IsInstance, GetNeedBarrier) 190 CHECK_OR_RETURN(LoadString, GetNeedBarrier) 191 CHECK_OR_RETURN(LoadConstArray, GetNeedBarrier) 192 CHECK_OR_RETURN(LoadType, GetNeedBarrier) 193 194 CHECK_OR_RETURN(CallStatic, IsInlined) 195 CHECK_OR_RETURN(CallVirtual, IsInlined) 196 197 CHECK_OR_RETURN(LoadArray, IsArray) 198 CHECK_OR_RETURN(LenArray, IsArray) 199 200 CHECK_OR_RETURN(Deoptimize, GetDeoptimizeType) 201 CHECK_OR_RETURN(DeoptimizeIf, GetDeoptimizeType) 202 203 CHECK_OR_RETURN(CompareAnyType, GetAnyType) 204 CHECK_OR_RETURN(CastValueToAnyType, GetAnyType) 205 CHECK_OR_RETURN(CastAnyTypeValue, GetAnyType) 206 CHECK_OR_RETURN(AnyTypeCheck, GetAnyType) 207 208 CHECK_OR_RETURN(HclassCheck, GetCheckIsFunction) 209 CHECK_OR_RETURN(HclassCheck, GetCheckFunctionIsNotClassConstructor) 210 211 // Those below can fail because unit test Graph don't have proper Runtime links 212 // CHECK_OR_RETURN(Intrinsic, GetEntrypointId) 213 // CHECK_OR_RETURN(CallStatic, GetCallMethodId) 214 // CHECK_OR_RETURN(CallVirtual, GetCallMethodId) 215 216 // CHECK_OR_RETURN(InitClass, GetTypeId) 217 // CHECK_OR_RETURN(LoadAndInitClass, GetTypeId) 218 // CHECK_OR_RETURN(LoadStatic, GetTypeId) 219 // CHECK_OR_RETURN(StoreStatic, GetTypeId) 220 // CHECK_OR_RETURN(LoadObject, GetTypeId) 221 // CHECK_OR_RETURN(StoreObject, GetTypeId) 222 // CHECK_OR_RETURN(NewObject, GetTypeId) 223 // CHECK_OR_RETURN(InitObject, GetTypeId) 224 // CHECK_OR_RETURN(NewArray, GetTypeId) 225 // CHECK_OR_RETURN(LoadConstArray, GetTypeId) 226 // CHECK_OR_RETURN(CheckCast, GetTypeId) 227 // CHECK_OR_RETURN(IsInstance, GetTypeId) 228 // CHECK_OR_RETURN(LoadString, GetTypeId) 229 // CHECK_OR_RETURN(LoadType, GetTypeId) 230 231 return true; 232 } 233 #undef CHECK_OR_RETURN 234 #undef CAST 235 InstSaveStateCompare(Inst * inst1,Inst * inst2)236 bool InstSaveStateCompare(Inst *inst1, Inst *inst2) 237 { 238 auto *svSt1 = static_cast<SaveStateInst *>(inst1); 239 auto *svSt2 = static_cast<SaveStateInst *>(inst2); 240 if (svSt1->GetImmediatesCount() != svSt2->GetImmediatesCount()) { 241 instCompareMap_.erase(inst1); 242 return false; 243 } 244 245 std::vector<VirtualRegister::ValueType> regs1; 246 std::vector<VirtualRegister::ValueType> regs2; 247 regs1.reserve(svSt1->GetInputsCount()); 248 regs2.reserve(svSt2->GetInputsCount()); 249 for (size_t i {0}; i < svSt1->GetInputsCount(); ++i) { 250 regs1.emplace_back(svSt1->GetVirtualRegister(i).Value()); 251 regs2.emplace_back(svSt2->GetVirtualRegister(i).Value()); 252 } 253 std::sort(regs1.begin(), regs1.end()); 254 std::sort(regs2.begin(), regs2.end()); 255 if (regs1 != regs2) { 256 instCompareMap_.erase(inst1); 257 return false; 258 } 259 if (svSt1->GetImmediatesCount() != 0) { 260 auto eqLambda = [](SaveStateImm i1, SaveStateImm i2) { 261 return i1.value == i2.value && i1.vreg == i2.vreg && i1.vregType == i2.vregType && i1.type == i2.type; 262 }; 263 if (!std::equal(svSt1->GetImmediates()->begin(), svSt1->GetImmediates()->end(), 264 svSt2->GetImmediates()->begin(), eqLambda)) { 265 instCompareMap_.erase(inst1); 266 return false; 267 } 268 } 269 return true; 270 } 271 272 private: CompareCommon(Inst * inst1,Inst * inst2)273 bool CompareCommon(Inst *inst1, Inst *inst2) 274 { 275 if (auto it = instCompareMap_.insert({inst1, inst2}); !it.second) { 276 if (inst2 == it.first->second) { 277 return true; 278 } 279 instCompareMap_.erase(inst1); 280 return false; 281 } 282 283 if (!InstInitialCompare(inst1, inst2)) { 284 return false; 285 } 286 287 if (!InstPropertiesCompare(inst1, inst2)) { 288 return false; 289 } 290 291 if (!InstAdditionalPropertiesCompare(inst1, inst2)) { 292 return false; 293 } 294 return true; 295 } 296 297 public: Compare(Inst * inst1,Inst * inst2)298 bool Compare(Inst *inst1, Inst *inst2) 299 { 300 if (!CompareCommon(inst1, inst2)) { 301 return false; 302 } 303 304 if (inst1->GetOpcode() == Opcode::Constant) { 305 auto c1 = inst1->CastToConstant(); 306 auto c2 = inst2->CastToConstant(); 307 bool same = false; 308 switch (inst1->GetType()) { 309 case DataType::FLOAT32: 310 case DataType::INT32: 311 same = static_cast<uint32_t>(c1->GetRawValue()) == static_cast<uint32_t>(c2->GetRawValue()); 312 break; 313 default: 314 same = c1->GetRawValue() == c2->GetRawValue(); 315 break; 316 } 317 if (!same) { 318 instCompareMap_.erase(inst1); 319 return false; 320 } 321 } 322 if (inst1->GetOpcode() == Opcode::Cmp && IsFloatType(inst1->GetInput(0).GetInst()->GetType())) { 323 auto cmp1 = static_cast<CmpInst *>(inst1); 324 auto cmp2 = static_cast<CmpInst *>(inst2); 325 if (cmp1->IsFcmpg() != cmp2->IsFcmpg()) { 326 instCompareMap_.erase(inst1); 327 return false; 328 } 329 } 330 for (size_t i = 0; i < inst2->GetInputsCount(); i++) { 331 if (inst1->GetInputType(i) != inst2->GetInputType(i)) { 332 instCompareMap_.erase(inst1); 333 return false; 334 } 335 } 336 if (inst1->IsSaveState()) { 337 return InstSaveStateCompare(inst1, inst2); 338 } 339 return true; 340 } 341 342 private: 343 std::unordered_map<Inst *, Inst *> instCompareMap_; 344 std::unordered_map<BasicBlock *, BasicBlock *> bbMap_; 345 }; 346 } // namespace ark::compiler 347 348 #endif // COMPILER_TESTS_GRAPH_COMPARATOR_H 349