• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2021-2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  * http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #ifndef COMPILER_TESTS_GRAPH_COMPARATOR_H
17 #define COMPILER_TESTS_GRAPH_COMPARATOR_H
18 
19 #include "optimizer/ir/ir_constructor.h"
20 #include "optimizer/analysis/rpo.h"
21 
22 namespace ark::compiler {
23 
24 class GraphComparator {
25 public:
Compare(Graph * graph1,Graph * graph2)26     bool Compare(Graph *graph1, Graph *graph2)
27     {
28         graph1->InvalidateAnalysis<Rpo>();
29         graph2->InvalidateAnalysis<Rpo>();
30         if (graph1->GetBlocksRPO().size() != graph2->GetBlocksRPO().size()) {
31             std::cerr << "Different number of blocks\n";
32             return false;
33         }
34         for (auto it1 = graph1->GetBlocksRPO().begin(), it2 = graph2->GetBlocksRPO().begin();
35              it1 != graph1->GetBlocksRPO().end(); it1++, it2++) {
36             auto it = bbMap_.insert({*it1, *it2});
37             if (!it.second) {
38                 return false;
39             }
40         }
41         return std::equal(graph1->GetBlocksRPO().begin(), graph1->GetBlocksRPO().end(), graph2->GetBlocksRPO().begin(),
42                           graph2->GetBlocksRPO().end(), [this](auto bb1, auto bb2) { return Compare(bb1, bb2); });
43     }
44 
Compare(BasicBlock * block1,BasicBlock * block2)45     bool Compare(BasicBlock *block1, BasicBlock *block2)
46     {
47         if (block1->GetPredsBlocks().size() != block2->GetPredsBlocks().size()) {
48             std::cerr << "Different number of preds blocks\n";
49             block1->Dump(&std::cerr);
50             block2->Dump(&std::cerr);
51             return false;
52         }
53         if (block1->GetSuccsBlocks().size() != block2->GetSuccsBlocks().size()) {
54             std::cerr << "Different number of succs blocks\n";
55             block1->Dump(&std::cerr);
56             block2->Dump(&std::cerr);
57             return false;
58         }
59         auto instCmp = [this](auto inst1, auto inst2) {
60             ASSERT(inst2 != nullptr);
61             bool t = Compare(inst1, inst2);
62             if (!t) {
63                 std::cerr << "Different instructions:\n";
64                 inst1->Dump(&std::cerr);
65                 inst2->Dump(&std::cerr);
66             }
67             return t;
68         };
69         return std::equal(block1->AllInsts().begin(), block1->AllInsts().end(), block2->AllInsts().begin(),
70                           block2->AllInsts().end(), instCmp);
71     }
72 
InstInitialCompare(Inst * inst1,Inst * inst2)73     bool InstInitialCompare(Inst *inst1, Inst *inst2)
74     {
75         if (inst1->GetOpcode() != inst2->GetOpcode() || inst1->GetType() != inst2->GetType() ||
76             inst1->GetInputsCount() != inst2->GetInputsCount()) {
77             instCompareMap_.erase(inst1);
78             return false;
79         }
80         if (inst1->GetFlagsMask() != inst2->GetFlagsMask()) {
81             instCompareMap_.erase(inst1);
82             return false;
83         }
84         if (inst1->GetOpcode() != Opcode::Phi) {
85             auto inst1Begin = inst1->GetInputs().begin();
86             auto inst1End = inst1->GetInputs().end();
87             auto inst2Begin = inst2->GetInputs().begin();
88             auto eqLambda = [this](Input input1, Input input2) { return Compare(input1.GetInst(), input2.GetInst()); };
89             if (!std::equal(inst1Begin, inst1End, inst2Begin, eqLambda)) {
90                 instCompareMap_.erase(inst1);
91                 return false;
92             }
93         } else {
94             if (inst1->GetInputsCount() != inst2->GetInputsCount()) {
95                 instCompareMap_.erase(inst1);
96                 return false;
97             }
98             for (size_t index1 = 0; index1 < inst1->GetInputsCount(); index1++) {
99                 auto input1 = inst1->GetInput(index1).GetInst();
100                 auto bb1 = inst1->CastToPhi()->GetPhiInputBb(index1);
101                 if (bbMap_.count(bb1) == 0) {
102                     instCompareMap_.erase(inst1);
103                     return false;
104                 }
105                 auto bb2 = bbMap_.at(bb1);
106                 auto input2 = inst2->CastToPhi()->GetPhiInput(bb2);
107                 if (!Compare(input1, input2)) {
108                     instCompareMap_.erase(inst1);
109                     return false;
110                 }
111             }
112         }
113         return true;
114     }
115 
116 // NOLINTNEXTLINE(cppcoreguidelines-macro-usage
117 #define CAST(Opc) CastTo##Opc()
118 // NOLINTNEXTLINE(cppcoreguidelines-macro-usage
119 #define CHECK_OR_RETURN(Opc, Getter) /* CC-OFFNXT(G.PRE.02) namespace member */                          \
120     if (inst1->GetOpcode() == Opcode::Opc && inst1->CAST(Opc)->Getter() != inst2->CAST(Opc)->Getter()) { \
121         instCompareMap_.erase(inst1);                                                                    \
122         return false;                                                                                    \
123     }
124 
InstPropertiesCompare(Inst * inst1,Inst * inst2)125     bool InstPropertiesCompare(Inst *inst1, Inst *inst2)
126     {
127         CHECK_OR_RETURN(Cast, GetOperandsType)
128         CHECK_OR_RETURN(Cmp, GetOperandsType)
129 
130         CHECK_OR_RETURN(Compare, GetCc)
131         CHECK_OR_RETURN(Compare, GetOperandsType)
132 
133         CHECK_OR_RETURN(If, GetCc)
134         CHECK_OR_RETURN(If, GetOperandsType)
135 
136         CHECK_OR_RETURN(IfImm, GetCc)
137         CHECK_OR_RETURN(IfImm, GetImm)
138         CHECK_OR_RETURN(IfImm, GetOperandsType)
139 
140         CHECK_OR_RETURN(Select, GetCc)
141         CHECK_OR_RETURN(Select, GetOperandsType)
142 
143         CHECK_OR_RETURN(SelectImm, GetCc)
144         CHECK_OR_RETURN(SelectImm, GetImm)
145         CHECK_OR_RETURN(SelectImm, GetOperandsType)
146 
147         CHECK_OR_RETURN(LoadArrayI, GetImm)
148         CHECK_OR_RETURN(LoadArrayPairI, GetImm)
149         CHECK_OR_RETURN(LoadPairPart, GetImm)
150         CHECK_OR_RETURN(StoreArrayI, GetImm)
151         CHECK_OR_RETURN(StoreArrayPairI, GetImm)
152         CHECK_OR_RETURN(LoadArrayPair, GetImm)
153         CHECK_OR_RETURN(StoreArrayPair, GetImm)
154         CHECK_OR_RETURN(BoundsCheckI, GetImm)
155         CHECK_OR_RETURN(ReturnI, GetImm)
156         CHECK_OR_RETURN(AddI, GetImm)
157         CHECK_OR_RETURN(SubI, GetImm)
158         CHECK_OR_RETURN(ShlI, GetImm)
159         CHECK_OR_RETURN(ShrI, GetImm)
160         CHECK_OR_RETURN(AShrI, GetImm)
161         CHECK_OR_RETURN(AndI, GetImm)
162         CHECK_OR_RETURN(OrI, GetImm)
163         CHECK_OR_RETURN(XorI, GetImm)
164 
165         return true;
166     }
167 
InstAdditionalPropertiesCompare(Inst * inst1,Inst * inst2)168     bool InstAdditionalPropertiesCompare(Inst *inst1, Inst *inst2)
169     {
170         CHECK_OR_RETURN(LoadArray, GetNeedBarrier)
171         CHECK_OR_RETURN(LoadArrayPair, GetNeedBarrier)
172         CHECK_OR_RETURN(StoreArray, GetNeedBarrier)
173         CHECK_OR_RETURN(StoreArrayPair, GetNeedBarrier)
174         CHECK_OR_RETURN(LoadArrayI, GetNeedBarrier)
175         CHECK_OR_RETURN(LoadArrayPairI, GetNeedBarrier)
176         CHECK_OR_RETURN(StoreArrayI, GetNeedBarrier)
177         CHECK_OR_RETURN(StoreArrayPairI, GetNeedBarrier)
178         CHECK_OR_RETURN(LoadStatic, GetNeedBarrier)
179         CHECK_OR_RETURN(StoreStatic, GetNeedBarrier)
180         CHECK_OR_RETURN(LoadObject, GetNeedBarrier)
181         CHECK_OR_RETURN(StoreObject, GetNeedBarrier)
182         CHECK_OR_RETURN(LoadStatic, GetVolatile)
183         CHECK_OR_RETURN(StoreStatic, GetVolatile)
184         CHECK_OR_RETURN(LoadObject, GetVolatile)
185         CHECK_OR_RETURN(StoreObject, GetVolatile)
186         CHECK_OR_RETURN(NewObject, GetNeedBarrier)
187         CHECK_OR_RETURN(NewArray, GetNeedBarrier)
188         CHECK_OR_RETURN(CheckCast, GetNeedBarrier)
189         CHECK_OR_RETURN(IsInstance, GetNeedBarrier)
190         CHECK_OR_RETURN(LoadString, GetNeedBarrier)
191         CHECK_OR_RETURN(LoadConstArray, GetNeedBarrier)
192         CHECK_OR_RETURN(LoadType, GetNeedBarrier)
193 
194         CHECK_OR_RETURN(CallStatic, IsInlined)
195         CHECK_OR_RETURN(CallVirtual, IsInlined)
196 
197         CHECK_OR_RETURN(LoadArray, IsArray)
198         CHECK_OR_RETURN(LenArray, IsArray)
199 
200         CHECK_OR_RETURN(Deoptimize, GetDeoptimizeType)
201         CHECK_OR_RETURN(DeoptimizeIf, GetDeoptimizeType)
202 
203         CHECK_OR_RETURN(CompareAnyType, GetAnyType)
204         CHECK_OR_RETURN(CastValueToAnyType, GetAnyType)
205         CHECK_OR_RETURN(CastAnyTypeValue, GetAnyType)
206         CHECK_OR_RETURN(AnyTypeCheck, GetAnyType)
207 
208         CHECK_OR_RETURN(HclassCheck, GetCheckIsFunction)
209         CHECK_OR_RETURN(HclassCheck, GetCheckFunctionIsNotClassConstructor)
210 
211         // Those below can fail because unit test Graph don't have proper Runtime links
212         // CHECK_OR_RETURN(Intrinsic, GetEntrypointId)
213         // CHECK_OR_RETURN(CallStatic, GetCallMethodId)
214         // CHECK_OR_RETURN(CallVirtual, GetCallMethodId)
215 
216         // CHECK_OR_RETURN(InitClass, GetTypeId)
217         // CHECK_OR_RETURN(LoadAndInitClass, GetTypeId)
218         // CHECK_OR_RETURN(LoadStatic, GetTypeId)
219         // CHECK_OR_RETURN(StoreStatic, GetTypeId)
220         // CHECK_OR_RETURN(LoadObject, GetTypeId)
221         // CHECK_OR_RETURN(StoreObject, GetTypeId)
222         // CHECK_OR_RETURN(NewObject, GetTypeId)
223         // CHECK_OR_RETURN(InitObject, GetTypeId)
224         // CHECK_OR_RETURN(NewArray, GetTypeId)
225         // CHECK_OR_RETURN(LoadConstArray, GetTypeId)
226         // CHECK_OR_RETURN(CheckCast, GetTypeId)
227         // CHECK_OR_RETURN(IsInstance, GetTypeId)
228         // CHECK_OR_RETURN(LoadString, GetTypeId)
229         // CHECK_OR_RETURN(LoadType, GetTypeId)
230 
231         return true;
232     }
233 #undef CHECK_OR_RETURN
234 #undef CAST
235 
InstSaveStateCompare(Inst * inst1,Inst * inst2)236     bool InstSaveStateCompare(Inst *inst1, Inst *inst2)
237     {
238         auto *svSt1 = static_cast<SaveStateInst *>(inst1);
239         auto *svSt2 = static_cast<SaveStateInst *>(inst2);
240         if (svSt1->GetImmediatesCount() != svSt2->GetImmediatesCount()) {
241             instCompareMap_.erase(inst1);
242             return false;
243         }
244 
245         std::vector<VirtualRegister::ValueType> regs1;
246         std::vector<VirtualRegister::ValueType> regs2;
247         regs1.reserve(svSt1->GetInputsCount());
248         regs2.reserve(svSt2->GetInputsCount());
249         for (size_t i {0}; i < svSt1->GetInputsCount(); ++i) {
250             regs1.emplace_back(svSt1->GetVirtualRegister(i).Value());
251             regs2.emplace_back(svSt2->GetVirtualRegister(i).Value());
252         }
253         std::sort(regs1.begin(), regs1.end());
254         std::sort(regs2.begin(), regs2.end());
255         if (regs1 != regs2) {
256             instCompareMap_.erase(inst1);
257             return false;
258         }
259         if (svSt1->GetImmediatesCount() != 0) {
260             auto eqLambda = [](SaveStateImm i1, SaveStateImm i2) {
261                 return i1.value == i2.value && i1.vreg == i2.vreg && i1.vregType == i2.vregType && i1.type == i2.type;
262             };
263             if (!std::equal(svSt1->GetImmediates()->begin(), svSt1->GetImmediates()->end(),
264                             svSt2->GetImmediates()->begin(), eqLambda)) {
265                 instCompareMap_.erase(inst1);
266                 return false;
267             }
268         }
269         return true;
270     }
271 
272 private:
CompareCommon(Inst * inst1,Inst * inst2)273     bool CompareCommon(Inst *inst1, Inst *inst2)
274     {
275         if (auto it = instCompareMap_.insert({inst1, inst2}); !it.second) {
276             if (inst2 == it.first->second) {
277                 return true;
278             }
279             instCompareMap_.erase(inst1);
280             return false;
281         }
282 
283         if (!InstInitialCompare(inst1, inst2)) {
284             return false;
285         }
286 
287         if (!InstPropertiesCompare(inst1, inst2)) {
288             return false;
289         }
290 
291         if (!InstAdditionalPropertiesCompare(inst1, inst2)) {
292             return false;
293         }
294         return true;
295     }
296 
297 public:
Compare(Inst * inst1,Inst * inst2)298     bool Compare(Inst *inst1, Inst *inst2)
299     {
300         if (!CompareCommon(inst1, inst2)) {
301             return false;
302         }
303 
304         if (inst1->GetOpcode() == Opcode::Constant) {
305             auto c1 = inst1->CastToConstant();
306             auto c2 = inst2->CastToConstant();
307             bool same = false;
308             switch (inst1->GetType()) {
309                 case DataType::FLOAT32:
310                 case DataType::INT32:
311                     same = static_cast<uint32_t>(c1->GetRawValue()) == static_cast<uint32_t>(c2->GetRawValue());
312                     break;
313                 default:
314                     same = c1->GetRawValue() == c2->GetRawValue();
315                     break;
316             }
317             if (!same) {
318                 instCompareMap_.erase(inst1);
319                 return false;
320             }
321         }
322         if (inst1->GetOpcode() == Opcode::Cmp && IsFloatType(inst1->GetInput(0).GetInst()->GetType())) {
323             auto cmp1 = static_cast<CmpInst *>(inst1);
324             auto cmp2 = static_cast<CmpInst *>(inst2);
325             if (cmp1->IsFcmpg() != cmp2->IsFcmpg()) {
326                 instCompareMap_.erase(inst1);
327                 return false;
328             }
329         }
330         for (size_t i = 0; i < inst2->GetInputsCount(); i++) {
331             if (inst1->GetInputType(i) != inst2->GetInputType(i)) {
332                 instCompareMap_.erase(inst1);
333                 return false;
334             }
335         }
336         if (inst1->IsSaveState()) {
337             return InstSaveStateCompare(inst1, inst2);
338         }
339         return true;
340     }
341 
342 private:
343     std::unordered_map<Inst *, Inst *> instCompareMap_;
344     std::unordered_map<BasicBlock *, BasicBlock *> bbMap_;
345 };
346 }  // namespace ark::compiler
347 
348 #endif  // COMPILER_TESTS_GRAPH_COMPARATOR_H
349