• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright (c) 2021-2022 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  * http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #ifndef COMPILER_TESTS_GRAPH_COMPARATOR_H
17 #define COMPILER_TESTS_GRAPH_COMPARATOR_H
18 
19 #include "optimizer/ir/ir_constructor.h"
20 #include "optimizer/analysis/rpo.h"
21 
22 namespace panda::compiler {
23 
24 class GraphComparator {
25 public:
Compare(Graph * graph1,Graph * graph2)26     bool Compare(Graph *graph1, Graph *graph2)
27     {
28         graph1->InvalidateAnalysis<Rpo>();
29         graph2->InvalidateAnalysis<Rpo>();
30         if (graph1->GetBlocksRPO().size() != graph2->GetBlocksRPO().size()) {
31             std::cerr << "Different number of blocks\n";
32             return false;
33         }
34         for (auto it1 = graph1->GetBlocksRPO().begin(), it2 = graph2->GetBlocksRPO().begin();
35              it1 != graph1->GetBlocksRPO().end(); it1++, it2++) {
36             auto it = bb_map_.insert({*it1, *it2});
37             if (!it.second) {
38                 return false;
39             }
40         }
41         return std::equal(graph1->GetBlocksRPO().begin(), graph1->GetBlocksRPO().end(), graph2->GetBlocksRPO().begin(),
42                           graph2->GetBlocksRPO().end(), [this](auto bb1, auto bb2) { return Compare(bb1, bb2); });
43     }
44 
Compare(BasicBlock * block1,BasicBlock * block2)45     bool Compare(BasicBlock *block1, BasicBlock *block2)
46     {
47         if (block1->GetPredsBlocks().size() != block2->GetPredsBlocks().size()) {
48             std::cerr << "Different number of preds blocks\n";
49             block1->Dump(&std::cerr);
50             block2->Dump(&std::cerr);
51             return false;
52         }
53         if (block1->GetSuccsBlocks().size() != block2->GetSuccsBlocks().size()) {
54             std::cerr << "Different number of succs blocks\n";
55             block1->Dump(&std::cerr);
56             block2->Dump(&std::cerr);
57             return false;
58         }
59         auto inst_cmp = [this](auto inst1, auto inst2) {
60             ASSERT(inst2 != nullptr);
61             bool t = Compare(inst1, inst2);
62             if (!t) {
63                 std::cerr << "Different instructions:\n";
64                 inst1->Dump(&std::cerr);
65                 inst2->Dump(&std::cerr);
66             }
67             return t;
68         };
69         return std::equal(block1->AllInsts().begin(), block1->AllInsts().end(), block2->AllInsts().begin(),
70                           block2->AllInsts().end(), inst_cmp);
71     }
72 
Compare(Inst * inst1,Inst * inst2)73     bool Compare(Inst *inst1, Inst *inst2)
74     {
75         if (auto it = inst_compare_map_.insert({inst1, inst2}); !it.second) {
76             if (inst2 == it.first->second) {
77                 return true;
78             }
79             inst_compare_map_.erase(inst1);
80             return false;
81         }
82 
83         if (inst1->GetOpcode() != inst2->GetOpcode() || inst1->GetType() != inst2->GetType() ||
84             inst1->GetInputsCount() != inst2->GetInputsCount()) {
85             inst_compare_map_.erase(inst1);
86             return false;
87         }
88         if (inst1->GetOpcode() != Opcode::Phi) {
89             auto inst1_begin = inst1->GetInputs().begin();
90             auto inst1_end = inst1->GetInputs().end();
91             auto inst2_begin = inst2->GetInputs().begin();
92             auto eq_lambda = [this](Input input1, Input input2) { return Compare(input1.GetInst(), input2.GetInst()); };
93             if (!std::equal(inst1_begin, inst1_end, inst2_begin, eq_lambda)) {
94                 inst_compare_map_.erase(inst1);
95                 return false;
96             }
97         } else {
98             if (inst1->GetInputsCount() != inst2->GetInputsCount()) {
99                 inst_compare_map_.erase(inst1);
100                 return false;
101             }
102             for (size_t index1 = 0; index1 < inst1->GetInputsCount(); index1++) {
103                 auto input1 = inst1->GetInput(index1).GetInst();
104                 auto bb1 = inst1->CastToPhi()->GetPhiInputBb(index1);
105                 if (bb_map_.count(bb1) == 0) {
106                     inst_compare_map_.erase(inst1);
107                     return false;
108                 }
109                 auto bb2 = bb_map_.at(bb1);
110                 auto input2 = inst2->CastToPhi()->GetPhiInput(bb2);
111                 if (!Compare(input1, input2)) {
112                     inst_compare_map_.erase(inst1);
113                     return false;
114                 }
115             }
116         }
117 
118 // NOLINTNEXTLINE(cppcoreguidelines-macro-usage
119 #define CAST(Opc) CastTo##Opc()
120 // NOLINTNEXTLINE(cppcoreguidelines-macro-usage
121 #define CHECK(Opc, Getter)                                                                               \
122     if (inst1->GetOpcode() == Opcode::Opc && inst1->CAST(Opc)->Getter() != inst2->CAST(Opc)->Getter()) { \
123         inst_compare_map_.erase(inst1);                                                                  \
124         return false;                                                                                    \
125     }
126         CHECK(Cast, GetOperandsType)
127         CHECK(Cmp, GetOperandsType)
128 
129         CHECK(Compare, GetCc)
130         CHECK(Compare, GetOperandsType)
131 
132         CHECK(If, GetCc)
133         CHECK(If, GetOperandsType)
134 
135         CHECK(IfImm, GetCc)
136         CHECK(IfImm, GetImm)
137         CHECK(IfImm, GetOperandsType)
138 
139         CHECK(Select, GetCc)
140         CHECK(Select, GetOperandsType)
141 
142         CHECK(SelectImm, GetCc)
143         CHECK(SelectImm, GetImm)
144         CHECK(SelectImm, GetOperandsType)
145 
146         CHECK(LoadArrayI, GetImm)
147         CHECK(LoadArrayPairI, GetImm)
148         CHECK(LoadPairPart, GetImm)
149         CHECK(StoreArrayI, GetImm)
150         CHECK(StoreArrayPairI, GetImm)
151         CHECK(BoundsCheckI, GetImm)
152         CHECK(ReturnI, GetImm)
153         CHECK(AddI, GetImm)
154         CHECK(SubI, GetImm)
155         CHECK(ShlI, GetImm)
156         CHECK(ShrI, GetImm)
157         CHECK(AShrI, GetImm)
158         CHECK(AndI, GetImm)
159         CHECK(OrI, GetImm)
160         CHECK(XorI, GetImm)
161 
162         CHECK(LoadArray, GetNeedBarrier)
163         CHECK(LoadArrayPair, GetNeedBarrier)
164         CHECK(StoreArray, GetNeedBarrier)
165         CHECK(StoreArrayPair, GetNeedBarrier)
166         CHECK(LoadArrayI, GetNeedBarrier)
167         CHECK(LoadArrayPairI, GetNeedBarrier)
168         CHECK(StoreArrayI, GetNeedBarrier)
169         CHECK(StoreArrayPairI, GetNeedBarrier)
170         CHECK(LoadStatic, GetNeedBarrier)
171         CHECK(StoreStatic, GetNeedBarrier)
172         CHECK(LoadObject, GetNeedBarrier)
173         CHECK(StoreObject, GetNeedBarrier)
174         CHECK(LoadStatic, GetVolatile)
175         CHECK(StoreStatic, GetVolatile)
176         CHECK(LoadObject, GetVolatile)
177         CHECK(StoreObject, GetVolatile)
178         CHECK(NewObject, GetNeedBarrier)
179         CHECK(NewArray, GetNeedBarrier)
180         CHECK(CheckCast, GetNeedBarrier)
181         CHECK(IsInstance, GetNeedBarrier)
182         CHECK(LoadString, GetNeedBarrier)
183         CHECK(LoadConstArray, GetNeedBarrier)
184         CHECK(LoadType, GetNeedBarrier)
185 
186         CHECK(CallStatic, IsInlined)
187         CHECK(CallVirtual, IsInlined)
188 
189         CHECK(LoadArray, IsArray)
190         CHECK(LenArray, IsArray)
191 
192         CHECK(Deoptimize, GetDeoptimizeType)
193         CHECK(DeoptimizeIf, GetDeoptimizeType)
194 
195         CHECK(CompareAnyType, GetAnyType)
196         CHECK(CastValueToAnyType, GetAnyType)
197         CHECK(CastAnyTypeValue, GetAnyType)
198         CHECK(AnyTypeCheck, GetAnyType)
199 
200         // Those below can fail because unit test Graph don't have proper Runtime links
201         // CHECK(Intrinsic, GetEntrypointId)
202         // CHECK(CallStatic, GetCallMethodId)
203         // CHECK(CallVirtual, GetCallMethodId)
204 
205         // CHECK(InitClass, GetTypeId)
206         // CHECK(LoadAndInitClass, GetTypeId)
207         // CHECK(LoadStatic, GetTypeId)
208         // CHECK(StoreStatic, GetTypeId)
209         // CHECK(LoadObject, GetTypeId)
210         // CHECK(StoreObject, GetTypeId)
211         // CHECK(NewObject, GetTypeId)
212         // CHECK(InitObject, GetTypeId)
213         // CHECK(NewArray, GetTypeId)
214         // CHECK(LoadConstArray, GetTypeId)
215         // CHECK(CheckCast, GetTypeId)
216         // CHECK(IsInstance, GetTypeId)
217         // CHECK(LoadString, GetTypeId)
218         // CHECK(LoadType, GetTypeId)
219 #undef CHECK
220 #undef CAST
221         if (inst1->GetOpcode() == Opcode::Constant) {
222             auto c1 = inst1->CastToConstant();
223             auto c2 = inst2->CastToConstant();
224             bool same = false;
225             switch (inst1->GetType()) {
226                 case DataType::FLOAT32:
227                 case DataType::INT32:
228                     same = static_cast<uint32_t>(c1->GetRawValue()) == static_cast<uint32_t>(c2->GetRawValue());
229                     break;
230                 default:
231                     same = c1->GetRawValue() == c2->GetRawValue();
232                     break;
233             }
234             if (!same) {
235                 inst_compare_map_.erase(inst1);
236                 return false;
237             }
238         }
239         if (inst1->GetOpcode() == Opcode::Cmp && IsFloatType(inst1->GetInput(0).GetInst()->GetType())) {
240             auto cmp1 = static_cast<CmpInst *>(inst1);
241             auto cmp2 = static_cast<CmpInst *>(inst2);
242             if (cmp1->IsFcmpg() != cmp2->IsFcmpg()) {
243                 inst_compare_map_.erase(inst1);
244                 return false;
245             }
246         }
247         for (size_t i = 0; i < inst2->GetInputsCount(); i++) {
248             if (inst1->GetInputType(i) != inst2->GetInputType(i)) {
249                 inst_compare_map_.erase(inst1);
250                 return false;
251             }
252         }
253         if (inst1->IsSaveState()) {
254             auto *sv_st1 = static_cast<SaveStateInst *>(inst1);
255             auto *sv_st2 = static_cast<SaveStateInst *>(inst2);
256             if (sv_st1->GetImmediatesCount() != sv_st2->GetImmediatesCount()) {
257                 inst_compare_map_.erase(inst1);
258                 return false;
259             }
260 
261             std::vector<VirtualRegister::ValueType> regs1;
262             std::vector<VirtualRegister::ValueType> regs2;
263             regs1.reserve(sv_st1->GetInputsCount());
264             regs2.reserve(sv_st2->GetInputsCount());
265             for (size_t i {0}; i < sv_st1->GetInputsCount(); ++i) {
266                 regs1.emplace_back(sv_st1->GetVirtualRegister(i).Value());
267                 regs2.emplace_back(sv_st2->GetVirtualRegister(i).Value());
268             }
269             std::sort(regs1.begin(), regs1.end());
270             std::sort(regs2.begin(), regs2.end());
271             if (regs1 != regs2) {
272                 inst_compare_map_.erase(inst1);
273                 return false;
274             }
275             if (sv_st1->GetImmediatesCount() != 0) {
276                 auto eq_lambda = [](SaveStateImm i1, SaveStateImm i2) {
277                     return i1.value == i2.value && i1.vreg == i2.vreg && i1.is_acc == i2.is_acc;
278                 };
279                 if (!std::equal(sv_st1->GetImmediates()->begin(), sv_st1->GetImmediates()->end(),
280                                 sv_st2->GetImmediates()->begin(), eq_lambda)) {
281                     inst_compare_map_.erase(inst1);
282                     return false;
283                 }
284             }
285         }
286         return true;
287     }
288 
289 private:
290     std::unordered_map<Inst *, Inst *> inst_compare_map_;
291     std::unordered_map<BasicBlock *, BasicBlock *> bb_map_;
292 };
293 }  // namespace panda::compiler
294 
295 #endif  // COMPILER_TESTS_GRAPH_COMPARATOR_H
296