1 /**
2 * Copyright (c) 2025 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "lower_boxed_boolean.h"
17 #include "compiler_logger.h"
18 #include <optional>
19
20 #include "optimizer/ir/analysis.h"
21 #include "optimizer/ir/basicblock.h"
22 #include "optimizer/ir/datatype.h"
23 #include "optimizer/ir/graph.h"
24 #include "optimizer/ir/graph_visitor.h"
25 #include "optimizer/ir/inst.h"
26 #include "utils/arena_containers.h"
27
28 namespace ark::compiler {
29
RunImpl()30 bool LowerBoxedBoolean::RunImpl()
31 {
32 COMPILER_LOG(DEBUG, LOWER_BOXED_BOOLEAN) << "Run " << GetPassName();
33 isApplied_ = false;
34 visitedMarker_ = GetGraph()->NewMarker();
35 VisitGraph();
36 instReplacements_.clear();
37 GetGraph()->EraseMarker(visitedMarker_);
38 COMPILER_LOG(DEBUG, LOWER_BOXED_BOOLEAN) << "LowerBoxedBoolean " << (isApplied_ ? "is" : "isn't") << " applied";
39 COMPILER_LOG(DEBUG, LOWER_BOXED_BOOLEAN) << "Finish LowerBoxedBoolean";
40 return isApplied_;
41 }
42
VisitCompare(GraphVisitor * v,Inst * inst)43 void LowerBoxedBoolean::VisitCompare(GraphVisitor *v, Inst *inst)
44 {
45 COMPILER_LOG(DEBUG, LOWER_BOXED_BOOLEAN) << "Start visit Compare with id = " << inst->GetId();
46 auto visitor = static_cast<LowerBoxedBoolean *>(v);
47 auto input = inst->GetInput(0).GetInst();
48
49 if (!IsCompareWithNullPtr(inst)) {
50 return;
51 }
52
53 ProcessInput(v, input);
54
55 if (!visitor->HasReplacement(input)) {
56 return;
57 }
58
59 COMPILER_LOG(DEBUG, LOWER_BOXED_BOOLEAN)
60 << "Applied LowerBoxedBoolean optimization to Compare with id = " << inst->GetId();
61
62 auto graph = inst->GetBasicBlock()->GetGraph();
63 inst->ReplaceUsers(graph->FindOrCreateConstant(0));
64 visitor->SetApplied();
65 }
66
VisitLoadObject(GraphVisitor * v,Inst * inst)67 void LowerBoxedBoolean::VisitLoadObject(GraphVisitor *v, Inst *inst)
68 {
69 COMPILER_LOG(DEBUG, LOWER_BOXED_BOOLEAN) << "Start visit LoadObject with id = " << inst->GetId();
70 auto runtime = inst->GetBasicBlock()->GetGraph()->GetRuntime();
71 auto fieldPtr = inst->CastToLoadObject()->GetObjField();
72 auto visitor = static_cast<LowerBoxedBoolean *>(v);
73 if (!runtime->IsFieldBooleanValue(fieldPtr)) {
74 return;
75 }
76
77 auto input = inst->GetDataFlowInput(0);
78 if (!IsValidLoadObjectInput(input)) {
79 return;
80 }
81
82 ProcessInput(v, input);
83
84 if (visitor->HasReplacement(input)) {
85 COMPILER_LOG(DEBUG, LOWER_BOXED_BOOLEAN)
86 << "Applied LowerBoxedBoolean optimization to LoadObject with id = " << inst->GetId();
87 inst->ReplaceUsers(visitor->GetReplacement(input));
88 visitor->SetApplied();
89 }
90 }
91
IsValidLoadObjectInput(Inst * input)92 bool LowerBoxedBoolean::IsValidLoadObjectInput(Inst *input)
93 {
94 // For LoadObject std.core.Boolean, we only support cases where its input is either:
95 // - LoadStatic of a known Boolean constant (TRUE or FALSE), or
96 // - Phi node merging such values.
97 return input->IsPhi() || GetBooleanFieldValue(input).has_value();
98 }
99
ProcessInput(GraphVisitor * v,Inst * inst)100 void LowerBoxedBoolean::ProcessInput(GraphVisitor *v, Inst *inst)
101 {
102 auto visitor = static_cast<LowerBoxedBoolean *>(v);
103 if (visitor->IsVisited(inst)) {
104 return;
105 }
106 visitor->SetVisited(inst);
107
108 switch (inst->GetOpcode()) {
109 case Opcode::LoadStatic:
110 ProcessLoadStatic(v, inst);
111 break;
112 case Opcode::Phi:
113 ProcessPhi(v, inst);
114 break;
115 default:
116 break;
117 }
118 }
119
ProcessPhi(GraphVisitor * v,Inst * inst)120 void LowerBoxedBoolean::ProcessPhi(GraphVisitor *v, Inst *inst)
121 {
122 auto visitor = static_cast<LowerBoxedBoolean *>(v);
123 if (visitor->HasReplacement(inst)) {
124 return;
125 }
126
127 if (!HasOnlyKnownUsers(inst)) {
128 return;
129 }
130
131 COMPILER_LOG(DEBUG, LOWER_BOXED_BOOLEAN) << "Process Phi with id = " << inst->GetId();
132
133 // We must be sure that all Phi inputs are reducible to constant Boolean values
134 // before replacing the Phi itself.
135 for (auto input : inst->GetInputs()) {
136 ProcessInput(v, input.GetInst());
137 if (!visitor->HasReplacement(input.GetInst())) {
138 return;
139 }
140 }
141
142 // Clone Phi instruction, replace its inputs with optimized versions, and set its type to BOOL.
143 auto graph = inst->GetBasicBlock()->GetGraph();
144 auto clone = inst->Clone(graph);
145 clone->SetType(DataType::BOOL);
146 inst->InsertBefore(clone);
147
148 visitor->SetInstReplacement(inst, clone);
149 for (auto input : inst->GetInputs()) {
150 clone->AppendInput(visitor->GetReplacement(input.GetInst()));
151 }
152
153 inst->SetFlag(inst_flags::NO_NULLPTR);
154 }
155
ProcessLoadStatic(GraphVisitor * v,Inst * inst)156 void LowerBoxedBoolean::ProcessLoadStatic(GraphVisitor *v, Inst *inst)
157 {
158 auto visitor = static_cast<LowerBoxedBoolean *>(v);
159 if (visitor->HasReplacement(inst)) {
160 return;
161 }
162
163 if (!HasOnlyKnownUsers(inst)) {
164 return;
165 }
166
167 COMPILER_LOG(DEBUG, LOWER_BOXED_BOOLEAN) << "Process LoadStatic with id = " << inst->GetId();
168
169 auto graph = inst->GetBasicBlock()->GetGraph();
170 if (auto fieldValue = GetBooleanFieldValue(inst)) {
171 inst->SetFlag(inst_flags::NO_NULLPTR);
172 auto constInst = graph->FindOrCreateConstant(*fieldValue);
173 visitor->SetInstReplacement(inst, constInst);
174 }
175 }
176
HasOnlyKnownUsers(Inst * inst)177 bool LowerBoxedBoolean::HasOnlyKnownUsers(Inst *inst)
178 {
179 for (auto &user : inst->GetUsers()) {
180 auto userInst = user.GetInst();
181 auto opcode = userInst->GetOpcode();
182 if (opcode == Opcode::SaveState) {
183 if (ProcessSaveState(userInst, inst)) {
184 continue;
185 }
186 }
187 switch (opcode) {
188 case Opcode::CallStatic:
189 case Opcode::CheckCast:
190 case Opcode::Compare:
191 case Opcode::IfImm:
192 case Opcode::Intrinsic:
193 case Opcode::LoadObject:
194 case Opcode::NullCheck:
195 case Opcode::Phi:
196 continue;
197 default:
198 return false;
199 }
200 }
201 return true;
202 }
203
ProcessSaveState(Inst * saveState,Inst * inst)204 bool LowerBoxedBoolean::ProcessSaveState(Inst *saveState, Inst *inst)
205 {
206 auto saveStateInst = saveState->CastToSaveState();
207 auto callerInst = saveStateInst->GetCallerInst();
208 if (callerInst != nullptr && callerInst->IsInlined()) {
209 return true;
210 }
211
212 if (!saveStateInst->HasUsers()) {
213 return false;
214 }
215
216 auto firstUser = saveStateInst->GetFirstUser()->GetInst();
217 if (firstUser->GetOpcode() == Opcode::ReturnInlined) {
218 return true;
219 }
220
221 return CheckSaveStateUsers(saveState, inst);
222 }
223
CheckSaveStateUsers(Inst * saveStateInst,Inst * inst)224 bool LowerBoxedBoolean::CheckSaveStateUsers(Inst *saveStateInst, Inst *inst)
225 {
226 for (auto &user : saveStateInst->GetUsers()) {
227 if (!IsNullCheckUsingInput(user.GetInst(), inst)) {
228 continue;
229 }
230 return true;
231 }
232 return false;
233 }
234
IsNullCheckUsingInput(Inst * inst,Inst * input)235 bool LowerBoxedBoolean::IsNullCheckUsingInput(Inst *inst, Inst *input)
236 {
237 if (inst->GetOpcode() != Opcode::NullCheck) {
238 return false;
239 }
240
241 for (auto &in : inst->GetInputs()) {
242 if (in.GetInst() == input) {
243 return true;
244 }
245 }
246 return false;
247 }
248
GetBooleanFieldValue(Inst * inst)249 std::optional<uint32_t> LowerBoxedBoolean::GetBooleanFieldValue(Inst *inst)
250 {
251 if (inst->GetOpcode() != Opcode::LoadStatic) {
252 return std::nullopt;
253 }
254
255 auto graph = inst->GetBasicBlock()->GetGraph();
256 auto runtime = graph->GetRuntime();
257 auto fieldPtr = inst->CastToLoadStatic()->GetObjField();
258
259 bool isTrue = runtime->IsFieldBooleanTrue(fieldPtr);
260 bool isFalse = runtime->IsFieldBooleanFalse(fieldPtr);
261 if (isTrue == isFalse) {
262 return std::nullopt;
263 }
264
265 return isTrue ? 1 : 0;
266 }
267
IsCompareWithNullPtr(Inst * inst)268 bool LowerBoxedBoolean::IsCompareWithNullPtr(Inst *inst)
269 {
270 ASSERT(inst->GetOpcode() == Opcode::Compare);
271 auto input = inst->GetInput(1).GetInst();
272 return input->GetOpcode() == Opcode::NullPtr;
273 }
274
SetInstReplacement(Inst * oldInst,Inst * newInst)275 void LowerBoxedBoolean::SetInstReplacement(Inst *oldInst, Inst *newInst)
276 {
277 instReplacements_[oldInst] = newInst;
278 }
279
GetReplacement(Inst * inst)280 Inst *LowerBoxedBoolean::GetReplacement(Inst *inst)
281 {
282 if (HasReplacement(inst)) {
283 return instReplacements_[inst];
284 }
285
286 return nullptr;
287 }
288
HasReplacement(Inst * inst) const289 bool LowerBoxedBoolean::HasReplacement(Inst *inst) const
290 {
291 return instReplacements_.find(inst) != instReplacements_.end();
292 }
293
SetVisited(Inst * inst)294 void LowerBoxedBoolean::SetVisited(Inst *inst)
295 {
296 inst->SetMarker(visitedMarker_);
297 }
298
IsVisited(Inst * inst) const299 bool LowerBoxedBoolean::IsVisited(Inst *inst) const
300 {
301 return inst->IsMarked(visitedMarker_);
302 }
303
SetApplied()304 void LowerBoxedBoolean::SetApplied()
305 {
306 isApplied_ = true;
307 }
308 } // namespace ark::compiler
309