• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright (c) 2021-2022 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  * http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #ifndef COMPILER_TESTS_GRAPH_COMPARATOR_H
17 #define COMPILER_TESTS_GRAPH_COMPARATOR_H
18 
19 #include "optimizer/ir/ir_constructor.h"
20 #include "optimizer/analysis/rpo.h"
21 
22 namespace panda::compiler {
23 
24 class GraphComparator {
25 public:
Compare(Graph * graph1,Graph * graph2)26     bool Compare(Graph *graph1, Graph *graph2)
27     {
28         graph1->InvalidateAnalysis<Rpo>();
29         graph2->InvalidateAnalysis<Rpo>();
30         if (graph1->GetBlocksRPO().size() != graph2->GetBlocksRPO().size()) {
31             std::cerr << "Different number of blocks\n";
32             return false;
33         }
34         for (auto it1 = graph1->GetBlocksRPO().begin(), it2 = graph2->GetBlocksRPO().begin();
35              it1 != graph1->GetBlocksRPO().end(); it1++, it2++) {
36             auto it = bb_map_.insert({*it1, *it2});
37             if (!it.second) {
38                 return false;
39             }
40         }
41         return std::equal(graph1->GetBlocksRPO().begin(), graph1->GetBlocksRPO().end(), graph2->GetBlocksRPO().begin(),
42                           graph2->GetBlocksRPO().end(), [this](auto bb1, auto bb2) { return Compare(bb1, bb2); });
43     }
44 
Compare(BasicBlock * block1,BasicBlock * block2)45     bool Compare(BasicBlock *block1, BasicBlock *block2)
46     {
47         if (block1->GetPredsBlocks().size() != block2->GetPredsBlocks().size()) {
48             std::cerr << "Different number of preds blocks\n";
49             block1->Dump(&std::cerr);
50             block2->Dump(&std::cerr);
51             return false;
52         }
53         if (block1->GetSuccsBlocks().size() != block2->GetSuccsBlocks().size()) {
54             std::cerr << "Different number of succs blocks\n";
55             block1->Dump(&std::cerr);
56             block2->Dump(&std::cerr);
57             return false;
58         }
59         auto inst_cmp = [this](auto inst1, auto inst2) {
60             ASSERT(inst2 != nullptr);
61             bool t = Compare(inst1, inst2);
62             if (!t) {
63                 std::cerr << "Different instructions:\n";
64                 inst1->Dump(&std::cerr);
65                 inst2->Dump(&std::cerr);
66             }
67             return t;
68         };
69         return std::equal(block1->AllInsts().begin(), block1->AllInsts().end(), block2->AllInsts().begin(),
70                           block2->AllInsts().end(), inst_cmp);
71     }
72 
Compare(Inst * inst1,Inst * inst2)73     bool Compare(Inst *inst1, Inst *inst2)
74     {
75         if (auto it = inst_compare_map_.insert({inst1, inst2}); !it.second) {
76             if (inst2 == it.first->second) {
77                 return true;
78             }
79             inst_compare_map_.erase(inst1);
80             return false;
81         }
82 
83         if (inst1->GetOpcode() != inst2->GetOpcode() || inst1->GetType() != inst2->GetType() ||
84             inst1->GetInputsCount() != inst2->GetInputsCount()) {
85                 inst_compare_map_.erase(inst1);
86                 return false;
87         }
88 
89         bool result = (inst1->GetOpcode() != Opcode::Phi) ?
90                       CompareNonPhiInputs(inst1, inst2) : ComparePhiInputs(inst1, inst2);
91         if (!result) {
92             inst_compare_map_.erase(inst1);
93             return false;
94         }
95 
96 // NOLINTNEXTLINE(cppcoreguidelines-macro-usage
97 #define CAST(Opc) CastTo##Opc()
98 // NOLINTNEXTLINE(cppcoreguidelines-macro-usage
99 #define CHECK(Opc, Getter)                                                                               \
100     if (inst1->GetOpcode() == Opcode::Opc && inst1->CAST(Opc)->Getter() != inst2->CAST(Opc)->Getter()) { \
101         inst_compare_map_.erase(inst1);                                                                  \
102         return false;                                                                                    \
103     }
104         CHECK(Cast, GetOperandsType)
105         CHECK(Cmp, GetOperandsType)
106 
107         CHECK(Compare, GetCc)
108         CHECK(Compare, GetOperandsType)
109 
110         CHECK(If, GetCc)
111         CHECK(If, GetOperandsType)
112 
113         CHECK(IfImm, GetCc)
114         CHECK(IfImm, GetImm)
115         CHECK(IfImm, GetOperandsType)
116 
117         CHECK(Select, GetCc)
118         CHECK(Select, GetOperandsType)
119 
120         CHECK(SelectImm, GetCc)
121         CHECK(SelectImm, GetImm)
122         CHECK(SelectImm, GetOperandsType)
123 
124         CHECK(LoadArrayI, GetImm)
125         CHECK(LoadArrayPairI, GetImm)
126         CHECK(LoadPairPart, GetImm)
127         CHECK(StoreArrayI, GetImm)
128         CHECK(StoreArrayPairI, GetImm)
129         CHECK(BoundsCheckI, GetImm)
130         CHECK(ReturnI, GetImm)
131         CHECK(AddI, GetImm)
132         CHECK(SubI, GetImm)
133         CHECK(ShlI, GetImm)
134         CHECK(ShrI, GetImm)
135         CHECK(AShrI, GetImm)
136         CHECK(AndI, GetImm)
137         CHECK(OrI, GetImm)
138         CHECK(XorI, GetImm)
139 
140         CHECK(LoadArray, GetNeedBarrier)
141         CHECK(LoadArrayPair, GetNeedBarrier)
142         CHECK(StoreArray, GetNeedBarrier)
143         CHECK(StoreArrayPair, GetNeedBarrier)
144         CHECK(LoadArrayI, GetNeedBarrier)
145         CHECK(LoadArrayPairI, GetNeedBarrier)
146         CHECK(StoreArrayI, GetNeedBarrier)
147         CHECK(StoreArrayPairI, GetNeedBarrier)
148         CHECK(LoadStatic, GetNeedBarrier)
149         CHECK(StoreStatic, GetNeedBarrier)
150         CHECK(LoadObject, GetNeedBarrier)
151         CHECK(StoreObject, GetNeedBarrier)
152         CHECK(LoadStatic, GetVolatile)
153         CHECK(StoreStatic, GetVolatile)
154         CHECK(LoadObject, GetVolatile)
155         CHECK(StoreObject, GetVolatile)
156         CHECK(NewObject, GetNeedBarrier)
157         CHECK(NewArray, GetNeedBarrier)
158         CHECK(CheckCast, GetNeedBarrier)
159         CHECK(IsInstance, GetNeedBarrier)
160         CHECK(LoadString, GetNeedBarrier)
161         CHECK(LoadConstArray, GetNeedBarrier)
162         CHECK(LoadType, GetNeedBarrier)
163 
164         CHECK(CallStatic, IsInlined)
165         CHECK(CallVirtual, IsInlined)
166 
167         CHECK(LoadArray, IsArray)
168         CHECK(LenArray, IsArray)
169 
170         CHECK(Deoptimize, GetDeoptimizeType)
171         CHECK(DeoptimizeIf, GetDeoptimizeType)
172 
173         CHECK(CompareAnyType, GetAnyType)
174         CHECK(CastValueToAnyType, GetAnyType)
175         CHECK(CastAnyTypeValue, GetAnyType)
176         CHECK(AnyTypeCheck, GetAnyType)
177 
178         // Those below can fail because unit test Graph don't have proper Runtime links
179         // CHECK(Intrinsic, GetEntrypointId)
180         // CHECK(CallStatic, GetCallMethodId)
181         // CHECK(CallVirtual, GetCallMethodId)
182 
183         // CHECK(InitClass, GetTypeId)
184         // CHECK(LoadAndInitClass, GetTypeId)
185         // CHECK(LoadStatic, GetTypeId)
186         // CHECK(StoreStatic, GetTypeId)
187         // CHECK(LoadObject, GetTypeId)
188         // CHECK(StoreObject, GetTypeId)
189         // CHECK(NewObject, GetTypeId)
190         // CHECK(InitObject, GetTypeId)
191         // CHECK(NewArray, GetTypeId)
192         // CHECK(LoadConstArray, GetTypeId)
193         // CHECK(CheckCast, GetTypeId)
194         // CHECK(IsInstance, GetTypeId)
195         // CHECK(LoadString, GetTypeId)
196         // CHECK(LoadType, GetTypeId)
197 #undef CHECK
198 #undef CAST
199 
200         if (inst1->GetOpcode() == Opcode::Constant) {
201             if (!CompareConstantInst(inst1, inst2)) {
202                 inst_compare_map_.erase(inst1);
203                 return false;
204             }
205         }
206 
207         if (inst1->GetOpcode() == Opcode::Cmp && IsFloatType(inst1->GetInput(0).GetInst()->GetType())) {
208             auto cmp1 = static_cast<CmpInst *>(inst1);
209             auto cmp2 = static_cast<CmpInst *>(inst2);
210             if (cmp1->IsFcmpg() != cmp2->IsFcmpg()) {
211                 inst_compare_map_.erase(inst1);
212                 return false;
213             }
214         }
215 
216         if (!CompareInputTypes(inst1, inst2) || !CompareSaveStateInst(inst1, inst2)) {
217                 inst_compare_map_.erase(inst1);
218                 return false;
219         }
220 
221         return true;
222     }
223 private:
224     std::unordered_map<Inst *, Inst *> inst_compare_map_;
225     std::unordered_map<BasicBlock *, BasicBlock *> bb_map_;
226 
CompareNonPhiInputs(Inst * inst1,Inst * inst2)227     bool CompareNonPhiInputs(Inst *inst1, Inst *inst2)
228     {
229         auto inst1_begin = inst1->GetInputs().begin();
230         auto inst1_end = inst1->GetInputs().end();
231         auto inst2_begin = inst2->GetInputs().begin();
232         auto eq_lambda = [this](Input input1, Input input2) {
233             return Compare(input1.GetInst(), input2.GetInst());
234         };
235         return std::equal(inst1_begin, inst1_end, inst2_begin, eq_lambda);
236     }
237 
ComparePhiInputs(Inst * inst1,Inst * inst2)238     bool ComparePhiInputs(Inst *inst1, Inst *inst2)
239     {
240         if (inst1->GetInputsCount() != inst2->GetInputsCount()) {
241             return false;
242         }
243 
244         for (size_t index1 = 0; index1 < inst1->GetInputsCount(); index1++) {
245             auto input1 = inst1->GetInput(index1).GetInst();
246             auto bb1 = inst1->CastToPhi()->GetPhiInputBb(index1);
247             if (bb_map_.count(bb1) == 0) {
248                 return false;
249             }
250             auto bb2 = bb_map_.at(bb1);
251             auto input2 = inst2->CastToPhi()->GetPhiInput(bb2);
252             if (!Compare(input1, input2)) {
253                 return false;
254             }
255         }
256         return true;
257     }
258 
CompareConstantInst(Inst * inst1,Inst * inst2)259     bool CompareConstantInst(Inst *inst1, Inst *inst2)
260     {
261         auto c1 = inst1->CastToConstant();
262         auto c2 = inst2->CastToConstant();
263         bool same = false;
264         switch (inst1->GetType()) {
265             case DataType::FLOAT32:
266             case DataType::INT32:
267                 same = static_cast<uint32_t>(c1->GetRawValue()) == static_cast<uint32_t>(c2->GetRawValue());
268                 break;
269             default:
270                 same = c1->GetRawValue() == c2->GetRawValue();
271                 break;
272         }
273         return same;
274     }
275 
CompareInputTypes(Inst * inst1,Inst * inst2)276     bool CompareInputTypes(Inst *inst1, Inst *inst2)
277     {
278         for (size_t i = 0; i < inst2->GetInputsCount(); i++) {
279             if (inst1->GetInputType(i) != inst2->GetInputType(i)) {
280                 return false;
281             }
282         }
283         return true;
284     }
285 
CompareSaveStateInst(Inst * inst1,Inst * inst2)286     bool CompareSaveStateInst(Inst *inst1, Inst *inst2)
287     {
288         if (!inst1->IsSaveState()) {
289             return true;
290         }
291 
292         auto *sv_st1 = static_cast<SaveStateInst *>(inst1);
293         auto *sv_st2 = static_cast<SaveStateInst *>(inst2);
294         if (sv_st1->GetImmediatesCount() != sv_st2->GetImmediatesCount()) {
295             return false;
296         }
297 
298         std::vector<VirtualRegister::ValueType> regs1;
299         std::vector<VirtualRegister::ValueType> regs2;
300         regs1.reserve(sv_st1->GetInputsCount());
301         regs2.reserve(sv_st2->GetInputsCount());
302         for (size_t i {0}; i < sv_st1->GetInputsCount(); ++i) {
303             regs1.emplace_back(sv_st1->GetVirtualRegister(i).Value());
304             regs2.emplace_back(sv_st2->GetVirtualRegister(i).Value());
305         }
306         std::sort(regs1.begin(), regs1.end());
307         std::sort(regs2.begin(), regs2.end());
308         if (regs1 != regs2) {
309             return false;
310         }
311         if (sv_st1->GetImmediatesCount() != 0) {
312             auto eq_lambda = [](SaveStateImm i1, SaveStateImm i2) {
313                 return i1.value == i2.value && i1.vreg == i2.vreg && i1.is_acc == i2.is_acc;
314             };
315             if (!std::equal(sv_st1->GetImmediates()->begin(), sv_st1->GetImmediates()->end(),
316                             sv_st2->GetImmediates()->begin(), eq_lambda)) {
317                 return false;
318             }
319         }
320         return true;
321     }
322 };
323 }  // namespace panda::compiler
324 
325 #endif  // COMPILER_TESTS_GRAPH_COMPARATOR_H
326