• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2024 Shenzhen Kaihong Digital Industry Development Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  *
15  * Copyright (c) 2024 Huawei Device Co., Ltd.
16  * Licensed under the Apache License, Version 2.0 (the "License");
17  * you may not use this file except in compliance with the License.
18  * You may obtain a copy of the License at
19 
20  * http://www.apache.org/licenses/LICENSE-2.0
21  *
22  * Unless required by applicable law or agreed to in writing, software
23  * distributed under the License is distributed on an "AS IS" BASIS,
24  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25  * See the License for the specific language governing permissions and
26  * limitations under the License.
27  */
28 
29 #include <cstdint>
30 #include <set>
31 #include "bytecode_optimizer/constant_propagation/constant_propagation.h"
32 #include "bytecode_optimizer/ir_interface.h"
33 #include "compiler/tests/graph_test.h"
34 #include "gtest/gtest.h"
35 #include "optimizer/ir/basicblock.h"
36 #include "optimizer/ir/graph.h"
37 #include "optimizer/ir/inst.h"
38 #include "optimizer/optimizations/branch_elimination.h"
39 #include "optimizer/optimizations/cleanup.h"
40 
41 using namespace testing::ext;
42 
43 namespace panda::compiler {
44 class BranchEliminationTest : public testing::Test {
45 public:
SetUpTestCase(void)46     static void SetUpTestCase(void) {}
TearDownTestCase(void)47     static void TearDownTestCase(void) {}
SetUp()48     void SetUp() {}
TearDown()49     void TearDown() {}
50 
IsIntrinsicConstInst(Inst * inst)51     static bool IsIntrinsicConstInst(Inst *inst)
52     {
53         ASSERT(inst != nullptr);
54         if (!inst->IsIntrinsic()) {
55             return false;
56         }
57 
58         auto intrinsic_id = inst->CastToIntrinsic()->GetIntrinsicId();
59         return intrinsic_id == IntrinsicInst::IntrinsicId::LDTRUE ||
60                intrinsic_id == IntrinsicInst::IntrinsicId::LDFALSE;
61     }
62 
IsIfWithConstInputs(Inst * inst)63     static bool IsIfWithConstInputs(Inst *inst)
64     {
65         ASSERT(inst != nullptr);
66         if (inst->GetOpcode() != Opcode::IfImm) {
67             return false;
68         }
69 
70         auto input_inst = inst->GetInput(0).GetInst();
71         return input_inst->IsConst() || IsIntrinsicConstInst(input_inst);
72     }
73 
GetConstValue(Inst * inst)74     static bool GetConstValue(Inst *inst)
75     {
76         ASSERT(inst != nullptr);
77         if (inst->IsConst()) {
78             auto const_value = inst->CastToConstant()->GetIntValue();
79             ASSERT(const_value <= 1);
80             return static_cast<bool>(const_value);
81         } else if (IsIntrinsicConstInst(inst)) {
82             return inst->CastToIntrinsic()->GetIntrinsicId() == IntrinsicInst::IntrinsicId::LDTRUE;
83         } else {
84             UNREACHABLE();
85         }
86     }
87 
GetDeadBranch(IfImmInst * inst)88     static BasicBlock* GetDeadBranch(IfImmInst *inst)
89     {
90         auto input_inst = inst->GetInput(0).GetInst();
91         ASSERT(input_inst->IsConst() || IsIntrinsicConstInst(input_inst));
92 
93         bool const_value = GetConstValue(input_inst);
94         bool cond_result = (const_value == inst->GetImm());
95         if (inst->GetCc() == CC_NE) {
96             cond_result = !cond_result;
97         } else {
98             ASSERT(inst->GetCc() == CC_EQ);
99         }
100 
101         if (cond_result) {
102             return inst->GetEdgeIfInputFalse();
103         } else {
104             return inst->GetEdgeIfInputTrue();
105         }
106     }
107 
CollectDominatedDeadBlocks(Graph * graph,std::set<uint32_t> & dead_blocks,IfImmInst * dead_if_inst)108     static void CollectDominatedDeadBlocks(Graph *graph, std::set<uint32_t> &dead_blocks, IfImmInst *dead_if_inst)
109     {
110         ASSERT(graph != nullptr);
111         ASSERT(dead_if_inst != nullptr);
112 
113         // Collect dead blocks that need to be eliminated.
114         auto dead_bb = GetDeadBranch(dead_if_inst);
115         dead_blocks.insert(dead_bb->GetId());
116         for (auto dom_bb : graph->GetBlocksRPO()) {
117             if (!dead_bb->IsDominate(dom_bb)) {
118                 continue;
119             }
120 
121             dead_blocks.insert(dom_bb->GetId());
122         }
123     }
124 
CollectDeadBlocksWithIfInst(Graph * graph,std::set<uint32_t> & dead_if_insts,std::set<uint32_t> & dead_blocks)125     static void CollectDeadBlocksWithIfInst(Graph *graph, std::set<uint32_t> &dead_if_insts,
126                                             std::set<uint32_t> &dead_blocks)
127     {
128         ASSERT(graph != nullptr);
129         for (auto bb : graph->GetBlocksRPO()) {
130             for (auto inst : bb->AllInsts()) {
131                 if (!IsIfWithConstInputs(inst)) {
132                     continue;
133                 }
134 
135                 dead_if_insts.insert(inst->GetId());
136                 CollectDominatedDeadBlocks(graph, dead_blocks, inst->CastToIfImm());
137             }
138         }
139     }
140 
141     GraphTest graph_test_;
142 };
143 
144 /**
145  * @tc.name: branch_elimination_test_001
146  * @tc.desc: Verify branch elimination.
147  * @tc.type: FUNC
148  * @tc.require:
149  */
150 HWTEST_F(BranchEliminationTest, branch_elimination_test_001, TestSize.Level1)
151 {
152     std::string pfile_unopt = GRAPH_TEST_ABC_DIR "branchElimination.abc";
153     options.SetCompilerUseSafepoint(false);
__anonacd6c1790102(Graph* graph, std::string &method_name) 154     graph_test_.TestBuildGraphFromFile(pfile_unopt, [](Graph* graph, std::string &method_name) {
155         if (method_name == "func_main_0") {
156             return;
157         }
158         EXPECT_NE(graph, nullptr);
159         pandasm::AsmEmitter::PandaFileToPandaAsmMaps maps;
160         pandasm::Program *prog = nullptr;
161         bytecodeopt::BytecodeOptIrInterface interface(&maps, prog);
162         graph->RunPass<bytecodeopt::ConstantPropagation>(&interface);
163 
164         std::set<uint32_t> dead_if_insts;
165         std::set<uint32_t> dead_blocks;
166         CollectDeadBlocksWithIfInst(graph, dead_if_insts, dead_blocks);
167 
168         EXPECT_FALSE(dead_if_insts.empty());
169         EXPECT_FALSE(dead_blocks.empty());
170 
171         graph->RunPass<BranchElimination>();
172         graph->RunPass<Cleanup>();
173         for (auto bb : graph->GetBlocksRPO()) {
174             EXPECT_TRUE(dead_blocks.count(bb->GetId()) == 0);
175             for (auto inst : bb->AllInsts()) {
176                 EXPECT_TRUE(dead_if_insts.count(inst->GetId()) == 0);
177             }
178         }
179     });
180 }
181 }  // namespace panda::compiler