1 /*
2 * Copyright (c) 2023-2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15 #include "canonicalization.h"
16
17 #include "compiler/optimizer/ir/basicblock.h"
18 #include "compiler/optimizer/ir/inst.h"
19
20 namespace ark::bytecodeopt {
21
RunImpl()22 bool Canonicalization::RunImpl()
23 {
24 Canonicalization visitor(GetGraph());
25 for (auto bb : GetGraph()->GetBlocksRPO()) {
26 for (auto inst : bb->AllInsts()) {
27 if (inst->IsCommutative()) {
28 visitor.VisitCommutative(inst);
29 } else {
30 visitor.VisitInstruction(inst);
31 }
32 }
33 }
34 return visitor.GetStatus();
35 }
36
IsDominateReverseInputs(const compiler::Inst * inst)37 static bool IsDominateReverseInputs(const compiler::Inst *inst)
38 {
39 auto input0 = inst->GetInput(0U).GetInst();
40 auto input1 = inst->GetInput(1U).GetInst();
41 return input0->IsDominate(input1);
42 }
43
ConstantFitsCompareImm(const Inst * cst,uint32_t size)44 static bool ConstantFitsCompareImm(const Inst *cst, uint32_t size)
45 {
46 ASSERT(cst->GetOpcode() == Opcode::Constant);
47 if (compiler::DataType::IsFloatType(cst->GetType())) {
48 return false;
49 }
50 auto val = cst->CastToConstant()->GetIntValue();
51 return (size == compiler::HALF_SIZE) && (val == 0U);
52 }
53
BetterToSwapCompareInputs(const compiler::Inst * inst,const compiler::Inst * input0,const compiler::Inst * input1)54 static bool BetterToSwapCompareInputs(const compiler::Inst *inst, const compiler::Inst *input0,
55 const compiler::Inst *input1)
56 {
57 if (!input0->IsConst()) {
58 return false;
59 }
60 if (!input1->IsConst()) {
61 return true;
62 }
63
64 compiler::DataType::Type type = inst->CastToCompare()->GetOperandsType();
65 uint32_t size = (type == compiler::DataType::UINT64 || type == compiler::DataType::INT64) ? compiler::WORD_SIZE
66 : compiler::HALF_SIZE;
67 return ConstantFitsCompareImm(input0, size) && !ConstantFitsCompareImm(input1, size);
68 }
69
SwapInputsIfNecessary(compiler::Inst * inst,const bool necessary)70 static bool SwapInputsIfNecessary(compiler::Inst *inst, const bool necessary)
71 {
72 if (!necessary) {
73 return false;
74 }
75 auto input0 = inst->GetInput(0U).GetInst();
76 auto input1 = inst->GetInput(1U).GetInst();
77 if ((inst->GetOpcode() == compiler::Opcode::Compare) && !BetterToSwapCompareInputs(inst, input0, input1)) {
78 return false;
79 }
80
81 inst->SwapInputs();
82 return true;
83 }
84
TrySwapConstantInput(Inst * inst)85 bool Canonicalization::TrySwapConstantInput(Inst *inst)
86 {
87 return SwapInputsIfNecessary(inst, inst->GetInput(0U).GetInst()->IsConst());
88 }
89
TrySwapReverseInput(Inst * inst)90 bool Canonicalization::TrySwapReverseInput(Inst *inst)
91 {
92 return SwapInputsIfNecessary(inst, IsDominateReverseInputs(inst));
93 }
94
VisitCommutative(Inst * inst)95 void Canonicalization::VisitCommutative(Inst *inst)
96 {
97 ASSERT(inst->IsCommutative());
98 ASSERT(inst->GetInputsCount() == 2U); // 2 is COMMUTATIVE_INPUT_COUNT
99 if (g_options.GetOptLevel() > 1) {
100 result_ = TrySwapReverseInput(inst);
101 }
102 result_ = TrySwapConstantInput(inst) || result_;
103 }
104
105 // It is not allowed to move a constant input1 with a single user (it's processed Compare instruction).
106 // This is necessary for further merging of the constant and the If instrution in the Lowering pass
AllowSwap(const compiler::Inst * inst)107 bool AllowSwap(const compiler::Inst *inst)
108 {
109 auto input1 = inst->GetInput(1U).GetInst();
110 if (!input1->IsConst()) {
111 return true;
112 }
113 for (const auto &user : input1->GetUsers()) {
114 if (user.GetInst() != inst) {
115 return true;
116 }
117 }
118 return false;
119 }
120
VisitCompare(GraphVisitor * v,Inst * instBase)121 void Canonicalization::VisitCompare([[maybe_unused]] GraphVisitor *v, Inst *instBase)
122 {
123 auto inst = instBase->CastToCompare();
124 if (AllowSwap(inst) && SwapInputsIfNecessary(inst, IsDominateReverseInputs(inst))) {
125 auto revertCc = SwapOperandsConditionCode(inst->GetCc());
126 inst->SetCc(revertCc);
127 }
128 }
129
130 } // namespace ark::bytecodeopt
131