• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright (c) 2021-2022 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  * http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #ifndef COMPILER_TESTS_GRAPH_COMPARATOR_H
17 #define COMPILER_TESTS_GRAPH_COMPARATOR_H
18 
19 #include <algorithm>
20 #include <iostream>
21 #include "optimizer/analysis/rpo.h"
22 #include "optimizer/ir/basicblock.h"
23 #include "optimizer/ir/graph.h"
24 #include "optimizer/ir/inst.h"
25 
26 namespace panda::compiler {
27 
28 class GraphComparator {
29 public:
Compare(Graph * graph1,Graph * graph2)30     bool Compare(Graph *graph1, Graph *graph2)
31     {
32         ASSERT(graph1 != nullptr);
33         ASSERT(graph2 != nullptr);
34         graph1->InvalidateAnalysis<Rpo>();
35         graph2->InvalidateAnalysis<Rpo>();
36         if (graph1->GetBlocksRPO().size() != graph2->GetBlocksRPO().size()) {
37             std::cerr << "Different number of blocks\n";
38             return false;
39         }
40         for (auto it1 = graph1->GetBlocksRPO().begin(), it2 = graph2->GetBlocksRPO().begin();
41              it1 != graph1->GetBlocksRPO().end(); it1++, it2++) {
42             auto it = bb_map_.insert({*it1, *it2});
43             if (!it.second) {
44                 return false;
45             }
46         }
47         return std::equal(graph1->GetBlocksRPO().begin(), graph1->GetBlocksRPO().end(), graph2->GetBlocksRPO().begin(),
48                           graph2->GetBlocksRPO().end(), [this](auto bb1, auto bb2) { return Compare(bb1, bb2); });
49     }
50 
Compare(BasicBlock * block1,BasicBlock * block2)51     bool Compare(BasicBlock *block1, BasicBlock *block2)
52     {
53         ASSERT(block1 != nullptr);
54         ASSERT(block2 != nullptr);
55         if (block1->GetPredsBlocks().size() != block2->GetPredsBlocks().size()) {
56             std::cerr << "Different number of preds blocks\n";
57             block1->Dump(&std::cerr);
58             block2->Dump(&std::cerr);
59             return false;
60         }
61         if (block1->GetSuccsBlocks().size() != block2->GetSuccsBlocks().size()) {
62             std::cerr << "Different number of succs blocks\n";
63             block1->Dump(&std::cerr);
64             block2->Dump(&std::cerr);
65             return false;
66         }
67         auto inst_cmp = [this](auto inst1, auto inst2) {
68             ASSERT(inst2 != nullptr);
69             bool t = Compare(inst1, inst2);
70             if (!t) {
71                 std::cerr << "Different instructions:\n";
72                 inst1->Dump(&std::cerr);
73                 inst2->Dump(&std::cerr);
74             }
75             return t;
76         };
77         return std::equal(block1->AllInsts().begin(), block1->AllInsts().end(), block2->AllInsts().begin(),
78                           block2->AllInsts().end(), inst_cmp);
79     }
80 
Compare(Inst * inst1,Inst * inst2)81     bool Compare(Inst *inst1, Inst *inst2)
82     {
83         ASSERT(inst1 != nullptr);
84         ASSERT(inst2 != nullptr);
85         if (auto it = inst_compare_map_.insert({inst1, inst2}); !it.second) {
86             if (inst2 == it.first->second) {
87                 return true;
88             }
89             inst_compare_map_.erase(inst1);
90             return false;
91         }
92 
93         if (inst1->GetOpcode() != inst2->GetOpcode() || inst1->GetType() != inst2->GetType() ||
94             inst1->GetInputsCount() != inst2->GetInputsCount()) {
95                 inst_compare_map_.erase(inst1);
96                 return false;
97         }
98 
99         bool result = (inst1->GetOpcode() != Opcode::Phi) ?
100                       CompareNonPhiInputs(inst1, inst2) : ComparePhiInputs(inst1, inst2);
101         if (!result) {
102             inst_compare_map_.erase(inst1);
103             return false;
104         }
105 
106 // NOLINTNEXTLINE(cppcoreguidelines-macro-usage
107 #define CAST(Opc) CastTo##Opc()
108 // NOLINTNEXTLINE(cppcoreguidelines-macro-usage
109 #define CHECK_INST(Opc, Getter)                                                                               \
110     if (inst1->GetOpcode() == Opcode::Opc && inst1->CAST(Opc)->Getter() != inst2->CAST(Opc)->Getter()) { \
111         inst_compare_map_.erase(inst1);                                                                  \
112         return false;                                                                                    \
113     }
114         CHECK_INST(CastAnyTypeValue, GetDeducedType)
115 
116         CHECK_INST(Cmp, GetOperandsType)
117 
118         CHECK_INST(Compare, GetCc)
119         CHECK_INST(Compare, GetOperandsType)
120 
121         CHECK_INST(If, GetCc)
122         CHECK_INST(If, GetOperandsType)
123 
124         CHECK_INST(IfImm, GetCc)
125         CHECK_INST(IfImm, GetImm)
126         CHECK_INST(IfImm, GetOperandsType)
127 
128         CHECK_INST(LoadString, GetNeedBarrier)
129 
130         CHECK_INST(CompareAnyType, GetAnyType)
131         CHECK_INST(CastValueToAnyType, GetAnyType)
132         CHECK_INST(CastAnyTypeValue, GetAnyType)
133 
134         // Those below can fail because unit test Graph don't have proper Runtime links
135         // CHECK_INST(Intrinsic, GetEntrypointId)
136         // CHECK_INST(CallStatic, GetCallMethodId)
137         // CHECK_INST(CallVirtual, GetCallMethodId)
138 
139         // CHECK_INST(InitClass, GetTypeId)
140         // CHECK_INST(LoadAndInitClass, GetTypeId)
141         // CHECK_INST(LoadStatic, GetTypeId)
142         // CHECK_INST(StoreStatic, GetTypeId)
143         // CHECK_INST(LoadObject, GetTypeId)
144         // CHECK_INST(StoreObject, GetTypeId)
145         // CHECK_INST(NewObject, GetTypeId)
146         // CHECK_INST(InitObject, GetTypeId)
147         // CHECK_INST(NewArray, GetTypeId)
148         // CHECK_INST(LoadConstArray, GetTypeId)
149         // CHECK_INST(CHECK_INSTCast, GetTypeId)
150         // CHECK_INST(IsInstance, GetTypeId)
151         // CHECK_INST(LoadString, GetTypeId)
152         // CHECK_INST(LoadType, GetTypeId)
153 #undef CHECK_INST
154 #undef CAST
155         if (!CompareInputTypes(inst1, inst2)
156             || !CompareIntrinsicInst(inst1, inst2) || !CompareConstantInst(inst1, inst2)
157             || !CompareFcmpgInst(inst1, inst2) || !CompareSaveStateInst(inst1, inst2)) {
158             inst_compare_map_.erase(inst1);
159             return false;
160         }
161 
162         return true;
163     }
164 private:
165     std::unordered_map<Inst *, Inst *> inst_compare_map_;
166     std::unordered_map<BasicBlock *, BasicBlock *> bb_map_;
167 
CompareNonPhiInputs(Inst * inst1,Inst * inst2)168     bool CompareNonPhiInputs(Inst *inst1, Inst *inst2)
169     {
170         auto inst1_begin = inst1->GetInputs().begin();
171         auto inst1_end = inst1->GetInputs().end();
172         auto inst2_begin = inst2->GetInputs().begin();
173         auto eq_lambda = [this](Input input1, Input input2) {
174             return Compare(input1.GetInst(), input2.GetInst());
175         };
176         return std::equal(inst1_begin, inst1_end, inst2_begin, eq_lambda);
177     }
178 
ComparePhiInputs(Inst * inst1,Inst * inst2)179     bool ComparePhiInputs(Inst *inst1, Inst *inst2)
180     {
181         if (inst1->GetInputsCount() != inst2->GetInputsCount()) {
182             return false;
183         }
184 
185         for (size_t index1 = 0; index1 < inst1->GetInputsCount(); index1++) {
186             auto input1 = inst1->GetInput(index1).GetInst();
187             auto bb1 = inst1->CastToPhi()->GetPhiInputBb(index1);
188             if (bb_map_.count(bb1) == 0) {
189                 return false;
190             }
191             auto bb2 = bb_map_.at(bb1);
192             auto input2 = inst2->CastToPhi()->GetPhiInput(bb2);
193             if (!Compare(input1, input2)) {
194                 return false;
195             }
196         }
197         return true;
198     }
199 
CompareIntrinsicInst(Inst * inst1,Inst * inst2)200     bool CompareIntrinsicInst(Inst *inst1, Inst *inst2)
201     {
202         if (inst1->GetOpcode() != Opcode::Intrinsic) {
203             return true;
204         }
205 
206         auto intrinsic1 = inst1->CastToIntrinsic();
207         auto intrinsic2 = inst2->CastToIntrinsic();
208         auto same = intrinsic1->GetIntrinsicId() == intrinsic2->GetIntrinsicId();
209         if (intrinsic1->HasImms()) {
210             auto imms1 = intrinsic1->GetImms();
211             auto imms2 = intrinsic2->GetImms();
212             same = same && std::equal(imms1.begin(), imms1.end(), imms2.begin(), imms2.end());
213         }
214         return same;
215     }
216 
CompareConstantInst(Inst * inst1,Inst * inst2)217     bool CompareConstantInst(Inst *inst1, Inst *inst2)
218     {
219         if (inst1->GetOpcode() != Opcode::Constant) {
220             return true;
221         }
222 
223         auto c1 = inst1->CastToConstant();
224         auto c2 = inst2->CastToConstant();
225         bool same = false;
226         switch (inst1->GetType()) {
227             case DataType::FLOAT32:
228             case DataType::INT32:
229                 same = static_cast<uint32_t>(c1->GetRawValue()) == static_cast<uint32_t>(c2->GetRawValue());
230                 break;
231             default:
232                 same = c1->GetRawValue() == c2->GetRawValue();
233                 break;
234         }
235         return same;
236     }
237 
CompareFcmpgInst(Inst * inst1,Inst * inst2)238     bool CompareFcmpgInst(Inst *inst1, Inst *inst2)
239     {
240         if (inst1->GetOpcode() != Opcode::Cmp || !IsFloatType(inst1->GetInput(0).GetInst()->GetType())) {
241             return true;
242         }
243 
244         auto cmp1 = static_cast<CmpInst *>(inst1);
245         auto cmp2 = static_cast<CmpInst *>(inst2);
246         return cmp1->IsFcmpg() == cmp2->IsFcmpg();
247     }
248 
CompareInputTypes(Inst * inst1,Inst * inst2)249     bool CompareInputTypes(Inst *inst1, Inst *inst2)
250     {
251         for (size_t i = 0; i < inst2->GetInputsCount(); i++) {
252             if (inst1->GetInputType(i) != inst2->GetInputType(i)) {
253                 return false;
254             }
255         }
256         return true;
257     }
258 
CompareSaveStateInst(Inst * inst1,Inst * inst2)259     bool CompareSaveStateInst(Inst *inst1, Inst *inst2)
260     {
261         if (!inst1->IsSaveState()) {
262             return true;
263         }
264 
265         auto *sv_st1 = static_cast<SaveStateInst *>(inst1);
266         auto *sv_st2 = static_cast<SaveStateInst *>(inst2);
267         if (sv_st1->GetImmediatesCount() != sv_st2->GetImmediatesCount()) {
268             return false;
269         }
270 
271         std::vector<VirtualRegister::ValueType> regs1;
272         std::vector<VirtualRegister::ValueType> regs2;
273         regs1.reserve(sv_st1->GetInputsCount());
274         regs2.reserve(sv_st2->GetInputsCount());
275         for (size_t i {0}; i < sv_st1->GetInputsCount(); ++i) {
276             regs1.emplace_back(sv_st1->GetVirtualRegister(i).Value());
277             regs2.emplace_back(sv_st2->GetVirtualRegister(i).Value());
278         }
279         std::sort(regs1.begin(), regs1.end());
280         std::sort(regs2.begin(), regs2.end());
281         if (regs1 != regs2) {
282             return false;
283         }
284         if (sv_st1->GetImmediatesCount() != 0) {
285             auto eq_lambda = [](SaveStateImm i1, SaveStateImm i2) {
286                 return i1.value == i2.value && i1.vreg == i2.vreg && i1.is_acc == i2.is_acc;
287             };
288             if (!std::equal(sv_st1->GetImmediates()->begin(), sv_st1->GetImmediates()->end(),
289                             sv_st2->GetImmediates()->begin(), eq_lambda)) {
290                 return false;
291             }
292         }
293         return true;
294     }
295 };
296 }  // namespace panda::compiler
297 
298 #endif  // COMPILER_TESTS_GRAPH_COMPARATOR_H
299