• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2021-2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  * http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #ifndef COMPILER_TESTS_GRAPH_COMPARATOR_H
17 #define COMPILER_TESTS_GRAPH_COMPARATOR_H
18 
19 #include "optimizer/ir/ir_constructor.h"
20 #include "optimizer/analysis/rpo.h"
21 
22 namespace panda::compiler {
23 
24 class GraphComparator {
25 public:
Compare(Graph * graph1,Graph * graph2)26     bool Compare(Graph *graph1, Graph *graph2)
27     {
28         graph1->InvalidateAnalysis<Rpo>();
29         graph2->InvalidateAnalysis<Rpo>();
30         if (graph1->GetBlocksRPO().size() != graph2->GetBlocksRPO().size()) {
31             std::cerr << "Different number of blocks\n";
32             return false;
33         }
34         for (auto it1 = graph1->GetBlocksRPO().begin(), it2 = graph2->GetBlocksRPO().begin();
35              it1 != graph1->GetBlocksRPO().end(); it1++, it2++) {
36             auto it = bbMap_.insert({*it1, *it2});
37             if (!it.second) {
38                 return false;
39             }
40         }
41         return std::equal(graph1->GetBlocksRPO().begin(), graph1->GetBlocksRPO().end(), graph2->GetBlocksRPO().begin(),
42                           graph2->GetBlocksRPO().end(), [this](auto bb1, auto bb2) { return Compare(bb1, bb2); });
43     }
44 
Compare(BasicBlock * block1,BasicBlock * block2)45     bool Compare(BasicBlock *block1, BasicBlock *block2)
46     {
47         if (block1->GetPredsBlocks().size() != block2->GetPredsBlocks().size()) {
48             std::cerr << "Different number of preds blocks\n";
49             block1->Dump(&std::cerr);
50             block2->Dump(&std::cerr);
51             return false;
52         }
53         if (block1->GetSuccsBlocks().size() != block2->GetSuccsBlocks().size()) {
54             std::cerr << "Different number of succs blocks\n";
55             block1->Dump(&std::cerr);
56             block2->Dump(&std::cerr);
57             return false;
58         }
59         auto instCmp = [this](auto inst1, auto inst2) {
60             ASSERT(inst2 != nullptr);
61             bool t = Compare(inst1, inst2);
62             if (!t) {
63                 std::cerr << "Different instructions:\n";
64                 inst1->Dump(&std::cerr);
65                 inst2->Dump(&std::cerr);
66             }
67             return t;
68         };
69         return std::equal(block1->AllInsts().begin(), block1->AllInsts().end(), block2->AllInsts().begin(),
70                           block2->AllInsts().end(), instCmp);
71     }
72 
InstInitialCompare(Inst * inst1,Inst * inst2)73     bool InstInitialCompare(Inst *inst1, Inst *inst2)
74     {
75         if (inst1->GetOpcode() != inst2->GetOpcode() || inst1->GetType() != inst2->GetType() ||
76             inst1->GetInputsCount() != inst2->GetInputsCount()) {
77             instCompareMap_.erase(inst1);
78             return false;
79         }
80         if (inst1->GetFlagsMask() != inst2->GetFlagsMask()) {
81             instCompareMap_.erase(inst1);
82             return false;
83         }
84         if (inst1->GetOpcode() != Opcode::Phi) {
85             auto inst1Begin = inst1->GetInputs().begin();
86             auto inst1End = inst1->GetInputs().end();
87             auto inst2Begin = inst2->GetInputs().begin();
88             auto eqLambda = [this](Input input1, Input input2) { return Compare(input1.GetInst(), input2.GetInst()); };
89             if (!std::equal(inst1Begin, inst1End, inst2Begin, eqLambda)) {
90                 instCompareMap_.erase(inst1);
91                 return false;
92             }
93         } else {
94             if (inst1->GetInputsCount() != inst2->GetInputsCount()) {
95                 instCompareMap_.erase(inst1);
96                 return false;
97             }
98             for (size_t index1 = 0; index1 < inst1->GetInputsCount(); index1++) {
99                 auto input1 = inst1->GetInput(index1).GetInst();
100                 auto bb1 = inst1->CastToPhi()->GetPhiInputBb(index1);
101                 if (bbMap_.count(bb1) == 0) {
102                     instCompareMap_.erase(inst1);
103                     return false;
104                 }
105                 auto bb2 = bbMap_.at(bb1);
106                 auto input2 = inst2->CastToPhi()->GetPhiInput(bb2);
107                 if (!Compare(input1, input2)) {
108                     instCompareMap_.erase(inst1);
109                     return false;
110                 }
111             }
112         }
113         return true;
114     }
115 
116 // NOLINTNEXTLINE(cppcoreguidelines-macro-usage
117 #define CAST(Opc) CastTo##Opc()
118 // NOLINTNEXTLINE(cppcoreguidelines-macro-usage
119 #define CHECK(Opc, Getter)                                                                               \
120     if (inst1->GetOpcode() == Opcode::Opc && inst1->CAST(Opc)->Getter() != inst2->CAST(Opc)->Getter()) { \
121         instCompareMap_.erase(inst1);                                                                    \
122         return false;                                                                                    \
123     }
124 
InstPropertiesCompare(Inst * inst1,Inst * inst2)125     bool InstPropertiesCompare(Inst *inst1, Inst *inst2)
126     {
127         CHECK(Cast, GetOperandsType)
128         CHECK(Cmp, GetOperandsType)
129 
130         CHECK(Compare, GetCc)
131         CHECK(Compare, GetOperandsType)
132 
133         CHECK(If, GetCc)
134         CHECK(If, GetOperandsType)
135 
136         CHECK(IfImm, GetCc)
137         CHECK(IfImm, GetImm)
138         CHECK(IfImm, GetOperandsType)
139 
140         CHECK(Select, GetCc)
141         CHECK(Select, GetOperandsType)
142 
143         CHECK(SelectImm, GetCc)
144         CHECK(SelectImm, GetImm)
145         CHECK(SelectImm, GetOperandsType)
146 
147         CHECK(LoadArrayI, GetImm)
148         CHECK(LoadArrayPairI, GetImm)
149         CHECK(LoadPairPart, GetImm)
150         CHECK(StoreArrayI, GetImm)
151         CHECK(StoreArrayPairI, GetImm)
152         CHECK(LoadArrayPair, GetImm)
153         CHECK(StoreArrayPair, GetImm)
154         CHECK(BoundsCheckI, GetImm)
155         CHECK(ReturnI, GetImm)
156         CHECK(AddI, GetImm)
157         CHECK(SubI, GetImm)
158         CHECK(ShlI, GetImm)
159         CHECK(ShrI, GetImm)
160         CHECK(AShrI, GetImm)
161         CHECK(AndI, GetImm)
162         CHECK(OrI, GetImm)
163         CHECK(XorI, GetImm)
164 
165         return true;
166     }
167 
InstAdditionalPropertiesCompare(Inst * inst1,Inst * inst2)168     bool InstAdditionalPropertiesCompare(Inst *inst1, Inst *inst2)
169     {
170         CHECK(LoadArray, GetNeedBarrier)
171         CHECK(LoadArrayPair, GetNeedBarrier)
172         CHECK(StoreArray, GetNeedBarrier)
173         CHECK(StoreArrayPair, GetNeedBarrier)
174         CHECK(LoadArrayI, GetNeedBarrier)
175         CHECK(LoadArrayPairI, GetNeedBarrier)
176         CHECK(StoreArrayI, GetNeedBarrier)
177         CHECK(StoreArrayPairI, GetNeedBarrier)
178         CHECK(LoadStatic, GetNeedBarrier)
179         CHECK(StoreStatic, GetNeedBarrier)
180         CHECK(LoadObject, GetNeedBarrier)
181         CHECK(StoreObject, GetNeedBarrier)
182         CHECK(LoadStatic, GetVolatile)
183         CHECK(StoreStatic, GetVolatile)
184         CHECK(LoadObject, GetVolatile)
185         CHECK(StoreObject, GetVolatile)
186         CHECK(NewObject, GetNeedBarrier)
187         CHECK(NewArray, GetNeedBarrier)
188         CHECK(CheckCast, GetNeedBarrier)
189         CHECK(IsInstance, GetNeedBarrier)
190         CHECK(LoadString, GetNeedBarrier)
191         CHECK(LoadConstArray, GetNeedBarrier)
192         CHECK(LoadType, GetNeedBarrier)
193 
194         CHECK(CallStatic, IsInlined)
195         CHECK(CallVirtual, IsInlined)
196 
197         CHECK(LoadArray, IsArray)
198         CHECK(LenArray, IsArray)
199 
200         CHECK(Deoptimize, GetDeoptimizeType)
201         CHECK(DeoptimizeIf, GetDeoptimizeType)
202 
203         CHECK(CompareAnyType, GetAnyType)
204         CHECK(CastValueToAnyType, GetAnyType)
205         CHECK(CastAnyTypeValue, GetAnyType)
206         CHECK(AnyTypeCheck, GetAnyType)
207 
208         CHECK(HclassCheck, GetCheckIsFunction)
209         CHECK(HclassCheck, GetCheckFunctionIsNotClassConstructor)
210 
211         // Those below can fail because unit test Graph don't have proper Runtime links
212         // CHECK(Intrinsic, GetEntrypointId)
213         // CHECK(CallStatic, GetCallMethodId)
214         // CHECK(CallVirtual, GetCallMethodId)
215 
216         // CHECK(InitClass, GetTypeId)
217         // CHECK(LoadAndInitClass, GetTypeId)
218         // CHECK(LoadStatic, GetTypeId)
219         // CHECK(StoreStatic, GetTypeId)
220         // CHECK(LoadObject, GetTypeId)
221         // CHECK(StoreObject, GetTypeId)
222         // CHECK(NewObject, GetTypeId)
223         // CHECK(InitObject, GetTypeId)
224         // CHECK(NewArray, GetTypeId)
225         // CHECK(LoadConstArray, GetTypeId)
226         // CHECK(CheckCast, GetTypeId)
227         // CHECK(IsInstance, GetTypeId)
228         // CHECK(LoadString, GetTypeId)
229         // CHECK(LoadType, GetTypeId)
230 
231         return true;
232     }
233 #undef CHECK
234 #undef CAST
235 
InstSaveStateCompare(Inst * inst1,Inst * inst2)236     bool InstSaveStateCompare(Inst *inst1, Inst *inst2)
237     {
238         auto *svSt1 = static_cast<SaveStateInst *>(inst1);
239         auto *svSt2 = static_cast<SaveStateInst *>(inst2);
240         if (svSt1->GetImmediatesCount() != svSt2->GetImmediatesCount()) {
241             instCompareMap_.erase(inst1);
242             return false;
243         }
244 
245         std::vector<VirtualRegister::ValueType> regs1;
246         std::vector<VirtualRegister::ValueType> regs2;
247         regs1.reserve(svSt1->GetInputsCount());
248         regs2.reserve(svSt2->GetInputsCount());
249         for (size_t i {0}; i < svSt1->GetInputsCount(); ++i) {
250             regs1.emplace_back(svSt1->GetVirtualRegister(i).Value());
251             regs2.emplace_back(svSt2->GetVirtualRegister(i).Value());
252         }
253         std::sort(regs1.begin(), regs1.end());
254         std::sort(regs2.begin(), regs2.end());
255         if (regs1 != regs2) {
256             instCompareMap_.erase(inst1);
257             return false;
258         }
259         if (svSt1->GetImmediatesCount() != 0) {
260             auto eqLambda = [](SaveStateImm i1, SaveStateImm i2) {
261                 return i1.value == i2.value && i1.vreg == i2.vreg && i1.vregType == i2.vregType && i1.type == i2.type;
262             };
263             if (!std::equal(svSt1->GetImmediates()->begin(), svSt1->GetImmediates()->end(),
264                             svSt2->GetImmediates()->begin(), eqLambda)) {
265                 instCompareMap_.erase(inst1);
266                 return false;
267             }
268         }
269         return true;
270     }
271 
Compare(Inst * inst1,Inst * inst2)272     bool Compare(Inst *inst1, Inst *inst2)
273     {
274         if (auto it = instCompareMap_.insert({inst1, inst2}); !it.second) {
275             if (inst2 == it.first->second) {
276                 return true;
277             }
278             instCompareMap_.erase(inst1);
279             return false;
280         }
281 
282         if (!InstInitialCompare(inst1, inst2)) {
283             return false;
284         }
285 
286         if (!InstPropertiesCompare(inst1, inst2)) {
287             return false;
288         }
289 
290         if (!InstAdditionalPropertiesCompare(inst1, inst2)) {
291             return false;
292         }
293 
294         if (inst1->GetOpcode() == Opcode::Constant) {
295             auto c1 = inst1->CastToConstant();
296             auto c2 = inst2->CastToConstant();
297             bool same = false;
298             switch (inst1->GetType()) {
299                 case DataType::FLOAT32:
300                 case DataType::INT32:
301                     same = static_cast<uint32_t>(c1->GetRawValue()) == static_cast<uint32_t>(c2->GetRawValue());
302                     break;
303                 default:
304                     same = c1->GetRawValue() == c2->GetRawValue();
305                     break;
306             }
307             if (!same) {
308                 instCompareMap_.erase(inst1);
309                 return false;
310             }
311         }
312         if (inst1->GetOpcode() == Opcode::Cmp && IsFloatType(inst1->GetInput(0).GetInst()->GetType())) {
313             auto cmp1 = static_cast<CmpInst *>(inst1);
314             auto cmp2 = static_cast<CmpInst *>(inst2);
315             if (cmp1->IsFcmpg() != cmp2->IsFcmpg()) {
316                 instCompareMap_.erase(inst1);
317                 return false;
318             }
319         }
320         for (size_t i = 0; i < inst2->GetInputsCount(); i++) {
321             if (inst1->GetInputType(i) != inst2->GetInputType(i)) {
322                 instCompareMap_.erase(inst1);
323                 return false;
324             }
325         }
326         if (inst1->IsSaveState()) {
327             return InstSaveStateCompare(inst1, inst2);
328         }
329         return true;
330     }
331 
332 private:
333     std::unordered_map<Inst *, Inst *> instCompareMap_;
334     std::unordered_map<BasicBlock *, BasicBlock *> bbMap_;
335 };
336 }  // namespace panda::compiler
337 
338 #endif  // COMPILER_TESTS_GRAPH_COMPARATOR_H
339