1 /*
2 * Copyright (c) 2021-2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "reg_alloc_resolver.h"
17 #include "reg_type.h"
18 #include "compiler/optimizer/code_generator/codegen.h"
19 #include "compiler/optimizer/ir/analysis.h"
20 #include "compiler/optimizer/ir/inst.h"
21 #include "compiler/optimizer/ir/graph.h"
22 #include "compiler/optimizer/ir/basicblock.h"
23 #include "compiler/optimizer/analysis/dominators_tree.h"
24 #include "optimizer/analysis/loop_analyzer.h"
25
26 namespace ark::compiler {
27 /*
28 * For each instruction set destination register if it is assigned,
29 * Pop inputs from stack and push result on stack if stack slot is assigned.
30 */
Resolve()31 void RegAllocResolver::Resolve()
32 {
33 // We use RPO order because we need to calculate Caller's roots mask before its inlined callees
34 // (see PropagateCallerMasks).
35 for (auto block : GetGraph()->GetBlocksRPO()) {
36 for (auto inst : block->AllInstsSafe()) {
37 if (inst->IsSaveState()) {
38 ResolveSaveState(inst);
39 continue;
40 }
41 ResolveInputs(inst);
42 ResolveOutput(inst);
43 if (GetGraph()->IsInstThrowable(inst)) {
44 AddCatchPhiMoves(inst);
45 }
46 }
47 }
48 }
49
AddCatchPhiMoves(Inst * inst)50 void RegAllocResolver::AddCatchPhiMoves(Inst *inst)
51 {
52 auto spillFillInst = GetGraph()->CreateInstSpillFill(SpillFillType::INPUT_FILL);
53 auto handlers = GetGraph()->GetThrowableInstHandlers(inst);
54
55 for (auto catchHandler : handlers) {
56 for (auto catchInst : catchHandler->AllInsts()) {
57 if (!catchInst->IsCatchPhi() || catchInst->CastToCatchPhi()->IsAcc()) {
58 continue;
59 }
60 auto catchPhi = catchInst->CastToCatchPhi();
61 const auto &throwableInsts = catchPhi->GetThrowableInsts();
62 auto it = std::find(throwableInsts->begin(), throwableInsts->end(), inst);
63 if (it == throwableInsts->end()) {
64 continue;
65 }
66 int index = std::distance(throwableInsts->begin(), it);
67 auto catchInput = catchPhi->GetDataFlowInput(index);
68 auto inputInterval = liveness_->GetInstLifeIntervals(catchInput);
69 ASSERT(inputInterval->GetSibling() == nullptr);
70 auto catchPhiInterval = liveness_->GetInstLifeIntervals(catchPhi);
71 if (inputInterval->GetLocation() != catchPhiInterval->GetLocation()) {
72 ConnectIntervals(spillFillInst, inputInterval, catchPhiInterval);
73 }
74 }
75 }
76 if (!spillFillInst->GetSpillFills().empty()) {
77 inst->InsertBefore(spillFillInst);
78 }
79 }
80
ResolveInputs(Inst * inst)81 void RegAllocResolver::ResolveInputs(Inst *inst)
82 {
83 if (inst->IsPhi() || inst->IsCatchPhi() || IsPseudoUserOfMultiOutput(inst)) {
84 return;
85 }
86 // Life-position before instruction to analyze intervals, that were splitted directly before it
87 auto insLn = liveness_->GetInstLifeIntervals(inst)->GetBegin();
88 auto preInsLn = insLn - 1U;
89
90 inputLocations_.clear();
91 for (size_t i = 0; i < inst->GetInputsCount(); ++i) {
92 auto inputInterval = liveness_->GetInstLifeIntervals(inst->GetDataFlowInput(i));
93 inputLocations_.push_back(inputInterval->FindSiblingAt(insLn)->GetLocation());
94 }
95
96 for (size_t i = 0; i < inst->GetInputsCount(); ++i) {
97 auto location = inst->GetLocation(i);
98 auto inputInterval = liveness_->GetInstLifeIntervals(inst->GetDataFlowInput(i));
99 if (CanReadFromAccumulator(inst, i) || inputInterval->NoDest() || location.IsInvalid()) {
100 continue;
101 }
102
103 // Interval with fixed register can be splitted before `inst`: we don't need any extra moves in that case,
104 // since fixed register can't be overwrite
105 auto sibling = inputInterval->FindSiblingAt(preInsLn);
106 ASSERT(sibling != nullptr);
107 if (location.IsFixedRegister() && sibling->GetLocation() == location) {
108 auto it = std::find(inputLocations_.begin(), inputLocations_.end(), location);
109 // If some other instruction reside in the location then we can't reuse a split ending
110 // before the inst as a corresponding location will be overridden
111 if (it == inputLocations_.end() || static_cast<size_t>(it - inputLocations_.begin()) == i) {
112 continue;
113 }
114 }
115
116 // Otherwise use sibling covering `inst`
117 if (sibling->GetEnd() == preInsLn) {
118 sibling = sibling->GetSibling();
119 }
120
121 // Input's location required any register: specify the allocated one
122 if (location.IsUnallocatedRegister()) {
123 ASSERT(sibling->HasReg());
124 inst->SetLocation(i, sibling->GetLocation());
125 continue;
126 }
127
128 // Finally, if input's location is not equal to the required one, add spill-fill
129 if (sibling->GetLocation() != location) {
130 AddMoveToFixedLocation(inst, sibling->GetLocation(), i);
131 }
132 }
133
134 if (inst->RequireTmpReg()) {
135 auto interval = liveness_->GetTmpRegInterval(inst);
136 ASSERT(interval != nullptr);
137 ASSERT(interval->HasReg());
138 auto regLocation = Location::MakeRegister(interval->GetReg(), interval->GetType());
139 inst->SetTmpLocation(regLocation);
140 GetGraph()->SetRegUsage(regLocation);
141 }
142 }
143
AddMoveToFixedLocation(Inst * inst,Location inputLocation,size_t inputNum)144 void RegAllocResolver::AddMoveToFixedLocation(Inst *inst, Location inputLocation, size_t inputNum)
145 {
146 // Create or get existing SpillFillInst
147 SpillFillInst *sfInst {};
148 if (inst->GetPrev() != nullptr && inst->GetPrev()->IsSpillFill()) {
149 sfInst = inst->GetPrev()->CastToSpillFill();
150 } else {
151 sfInst = GetGraph()->CreateInstSpillFill(SpillFillType::INPUT_FILL);
152 inst->InsertBefore(sfInst);
153 }
154
155 // Add move from input to fixed location
156 auto type = ConvertRegType(GetGraph(), inst->GetInputType(inputNum));
157 auto fixedLocation = inst->GetLocation(inputNum);
158 if (fixedLocation.IsFixedRegister()) {
159 GetGraph()->SetRegUsage(fixedLocation.GetValue(), type);
160 }
161 sfInst->AddSpillFill(inputLocation, fixedLocation, type);
162 }
163
GetFirstUserOrInst(Inst * inst)164 Inst *GetFirstUserOrInst(Inst *inst)
165 {
166 for (auto &user : inst->GetUsers()) {
167 if (user.GetInst()->GetOpcode() != Opcode::ReturnInlined) {
168 return user.GetInst();
169 }
170 }
171 return inst;
172 }
173
174 // For implicit null check we need to find the first null check's user to
175 // correctly capture SaveState's input locations, because implicit null checks are fired
176 // when its input is accessed by its users (for example, when LoadArray instruction is loading
177 // value from null array reference). Some life intervals may change its location (due to spilling)
178 // between NullCheck and its users, so locations captured at implicit null check could be incorrect.
179 // While implicit NullCheck may have multiple users we can use only a user dominating all other users,
180 // because null check either will be fired at it, or won't be fired at all.
GetExplicitUser(Inst * inst)181 Inst *GetExplicitUser(Inst *inst)
182 {
183 if (!inst->IsNullCheck() || !inst->CastToNullCheck()->IsImplicit() || inst->GetUsers().Empty()) {
184 return inst;
185 }
186 if (inst->HasSingleUser()) {
187 return inst->GetUsers().Front().GetInst();
188 }
189
190 Inst *userInst {nullptr};
191 for (auto &user : inst->GetUsers()) {
192 auto currInst = user.GetInst();
193 if (!IsSuitableForImplicitNullCheck(currInst)) {
194 continue;
195 }
196 if (currInst->GetInput(0) != inst) {
197 continue;
198 }
199 if (!currInst->CanThrow()) {
200 continue;
201 }
202 userInst = currInst;
203 break;
204 }
205 #ifndef NDEBUG
206 for (auto &user : inst->GetUsers()) {
207 if (user.GetInst()->IsPhi()) {
208 continue;
209 }
210 ASSERT(userInst != nullptr && userInst->IsDominate(user.GetInst()));
211 }
212 #endif
213 return userInst;
214 }
215
PropagateCallerMasks(SaveStateInst * saveState)216 void RegAllocResolver::PropagateCallerMasks(SaveStateInst *saveState)
217 {
218 saveState->CreateRootsStackMask(GetGraph()->GetAllocator());
219 auto user = GetExplicitUser(GetFirstUserOrInst(saveState));
220 // Get location of save state inputs at the save state user (note that at this point
221 // all inputs will have the same location at all users (excluding ReturnInlined that should be skipped)).
222 FillSaveStateRootsMask(saveState, user, saveState);
223 for (auto callerInst = saveState->GetCallerInst(); callerInst != nullptr;
224 callerInst = callerInst->GetSaveState()->GetCallerInst()) {
225 auto callerSs = callerInst->GetSaveState();
226 FillSaveStateRootsMask(callerSs, user, saveState);
227 }
228 }
229
FillSaveStateRootsMask(SaveStateInst * saveState,Inst * user,SaveStateInst * targetSs)230 void RegAllocResolver::FillSaveStateRootsMask(SaveStateInst *saveState, Inst *user, SaveStateInst *targetSs)
231 {
232 auto dstLn = liveness_->GetInstLifeIntervals(user)->GetBegin();
233
234 for (size_t i = 0; i < saveState->GetInputsCount(); ++i) {
235 auto inputInst = saveState->GetDataFlowInput(i);
236 if (!inputInst->IsMovableObject()) {
237 continue;
238 }
239 auto inputInterval = liveness_->GetInstLifeIntervals(inputInst);
240 auto sibling = inputInterval->FindSiblingAt(dstLn);
241 ASSERT(sibling != nullptr);
242 bool isSplitCover;
243 if (user->IsPropagateLiveness()) {
244 isSplitCover = sibling->SplitCover<true>(dstLn);
245 } else {
246 isSplitCover = sibling->SplitCover(dstLn);
247 }
248 if (!isSplitCover) {
249 continue;
250 }
251 AddLocationToRoots(sibling->GetLocation(), targetSs, GetGraph());
252 #ifndef NDEBUG
253 for (auto &testUser : targetSs->GetUsers()) {
254 if (testUser.GetInst()->GetOpcode() == Opcode::ReturnInlined ||
255 testUser.GetInst()->GetId() == user->GetId()) {
256 continue;
257 }
258 auto explicitTestUser = GetExplicitUser(testUser.GetInst());
259 auto udstLn = liveness_->GetInstLifeIntervals(explicitTestUser)->GetBegin();
260 ASSERT(sibling->GetLocation() == inputInterval->FindSiblingAt(udstLn)->GetLocation());
261 }
262 #endif
263 }
264 }
265
266 namespace {
HasSameLocation(LifeIntervals * interval,LifeNumber pos1,LifeNumber pos2)267 bool HasSameLocation(LifeIntervals *interval, LifeNumber pos1, LifeNumber pos2)
268 {
269 auto sibling1 = interval->FindSiblingAt(pos1);
270 auto sibling2 = interval->FindSiblingAt(pos2);
271 ASSERT(sibling1 != nullptr);
272 ASSERT(sibling2 != nullptr);
273 return sibling1->SplitCover(pos1) && sibling1->SplitCover(pos2) &&
274 sibling1->GetLocation() == sibling2->GetLocation();
275 }
276
SaveStateCopyRequired(Inst * inst,User * currUser,User * prevUser,const LivenessAnalyzer * la)277 bool SaveStateCopyRequired(Inst *inst, User *currUser, User *prevUser, const LivenessAnalyzer *la)
278 {
279 ASSERT(inst->IsSaveState());
280 auto currUserLn = la->GetInstLifeIntervals(GetExplicitUser(currUser->GetInst()))->GetBegin();
281 auto prevUserLn = la->GetInstLifeIntervals(GetExplicitUser(prevUser->GetInst()))->GetBegin();
282 bool needCopy = false;
283 // If current save state is part of inlined method then we have to check location for all
284 // parent save states.
285 for (auto ss = static_cast<SaveStateInst *>(inst); ss != nullptr && !needCopy;) {
286 for (size_t inputIdx = 0; inputIdx < ss->GetInputsCount() && !needCopy; inputIdx++) {
287 auto inputInterval = la->GetInstLifeIntervals(ss->GetDataFlowInput(inputIdx));
288 needCopy = !HasSameLocation(inputInterval, currUserLn, prevUserLn);
289 }
290 auto caller = ss->GetCallerInst();
291 if (caller == nullptr) {
292 ss = nullptr;
293 } else {
294 ss = caller->GetSaveState();
295 }
296 }
297 return needCopy;
298 }
299 } // namespace
300
ResolveSaveState(Inst * inst)301 void RegAllocResolver::ResolveSaveState(Inst *inst)
302 {
303 if (GetGraph()->GetCallingConvention() == nullptr) {
304 return;
305 }
306 ASSERT(inst->IsSaveState());
307
308 bool handledAllUsers = inst->HasSingleUser() || !inst->HasUsers();
309 while (!handledAllUsers) {
310 size_t copyUsers = 0;
311 auto userIt = inst->GetUsers().begin();
312 User *prevUser = &*userIt;
313 ++userIt;
314 bool needCopy = false;
315
316 // Find first user having different location for some of the save state inputs and use SaveState's
317 // copy for all preceding users.
318 for (; userIt != inst->GetUsers().end() && !needCopy; ++userIt, copyUsers++) {
319 auto &currUser = *userIt;
320 // ReturnInline's SaveState is required only for SaveState's inputs life range propagation,
321 // so it does not actually matter which interval will be actually used.
322 if (prevUser->GetInst()->GetOpcode() == Opcode::ReturnInlined) {
323 prevUser = &*userIt;
324 continue;
325 }
326 if (currUser.GetInst()->GetOpcode() == Opcode::ReturnInlined) {
327 continue;
328 }
329 needCopy = SaveStateCopyRequired(inst, &currUser, prevUser, liveness_);
330 prevUser = &*userIt;
331 }
332 if (needCopy) {
333 auto copy = CopySaveState(GetGraph(), static_cast<SaveStateInst *>(inst));
334 // Replace original SaveState with the copy for first N users (N = `copy_users` ).
335 while (copyUsers > 0) {
336 auto userInst = inst->GetUsers().Front().GetInst();
337 userInst->ReplaceInput(inst, copy);
338 copyUsers--;
339 }
340 inst->GetBasicBlock()->InsertAfter(copy, inst);
341 PropagateCallerMasks(copy);
342 handledAllUsers = inst->HasSingleUser();
343 } else {
344 handledAllUsers = !(userIt != inst->GetUsers().end());
345 }
346 }
347 // At this point inst either has single user or all its inputs have the same location at all users.
348 PropagateCallerMasks(static_cast<SaveStateInst *>(inst));
349 }
350
351 /*
352 * Pop output on stack from reserved register
353 */
ResolveOutput(Inst * inst)354 void RegAllocResolver::ResolveOutput(Inst *inst)
355 {
356 // Don't process LiveOut, since it is instruction with pseudo destination
357 if (inst->GetOpcode() == Opcode::LiveOut) {
358 return;
359 }
360 // Multi-output instructions' dst registers will be filled after procecssing theirs pseudo users
361 if (inst->GetLinearNumber() == INVALID_LINEAR_NUM || inst->GetDstCount() > 1) {
362 return;
363 }
364
365 if (CanStoreToAccumulator(inst)) {
366 return;
367 }
368
369 auto instInterval = liveness_->GetInstLifeIntervals(inst);
370 if (instInterval->NoDest()) {
371 inst->SetDstReg(GetInvalidReg());
372 return;
373 }
374
375 if (inst->GetOpcode() == Opcode::Parameter) {
376 inst->CastToParameter()->GetLocationData().SetDst(instInterval->GetLocation());
377 }
378 // Process multi-output inst
379 size_t dstMum = inst->GetSrcRegIndex();
380 if (IsPseudoUserOfMultiOutput(inst)) {
381 inst = inst->GetInput(0).GetInst();
382 }
383 // Wrtie dst
384 auto regType = instInterval->GetType();
385 if (instInterval->HasReg()) {
386 auto reg = instInterval->GetReg();
387 inst->SetDstReg(dstMum, reg);
388 GetGraph()->SetRegUsage(reg, regType);
389 } else {
390 ASSERT(inst->IsConst() || inst->IsPhi() || inst->IsParameter());
391 }
392 }
393
ResolveCatchPhis()394 bool RegAllocResolver::ResolveCatchPhis()
395 {
396 for (auto block : GetGraph()->GetBlocksRPO()) {
397 if (!block->IsCatchBegin()) {
398 continue;
399 }
400 for (auto inst : block->AllInstsSafe()) {
401 if (!inst->IsCatchPhi()) {
402 break;
403 }
404 if (inst->CastToCatchPhi()->IsAcc()) {
405 continue;
406 }
407 // This is the case when all throwable instructions were removed from the try-block,
408 // so that catch-handler is unreachable
409 if (inst->GetInputs().Empty()) {
410 return false;
411 }
412 auto newCatchPhi = SqueezeCatchPhiInputs(inst->CastToCatchPhi());
413 if (newCatchPhi != nullptr) {
414 inst->ReplaceUsers(newCatchPhi);
415 block->RemoveInst(inst);
416 }
417 }
418 }
419 return true;
420 }
421
422 /**
423 * Try to remove catch phi's inputs:
424 * If the input's corresponding throwable instruction dominates other throwable inst, we can remove other equal catch
425 * phi's input
426 *
427 * CatchPhi(v1, v1, v1, v2, v2, v2) -> CatchPhi(v1, v2)
428 *
429 * Return nullptr if inputs count was not reduced.
430 */
SqueezeCatchPhiInputs(CatchPhiInst * catchPhi)431 Inst *RegAllocResolver::SqueezeCatchPhiInputs(CatchPhiInst *catchPhi)
432 {
433 bool inputsAreIdentical = true;
434 auto firstInput = catchPhi->GetInput(0).GetInst();
435 for (size_t i = 1; i < catchPhi->GetInputsCount(); ++i) {
436 if (catchPhi->GetInput(i).GetInst() != firstInput) {
437 inputsAreIdentical = false;
438 break;
439 }
440 }
441 if (inputsAreIdentical) {
442 return firstInput;
443 }
444
445 // Create a new one and fill it with the necessary inputs
446 auto newCatchPhi = GetGraph()->CreateInstCatchPhi(catchPhi->GetType(), catchPhi->GetPc());
447 ASSERT(catchPhi->GetBasicBlock()->GetFirstInst()->IsCatchPhi());
448 catchPhi->GetBasicBlock()->PrependInst(newCatchPhi);
449 for (size_t i = 0; i < catchPhi->GetInputsCount(); i++) {
450 auto inputInst = catchPhi->GetInput(i).GetInst();
451 auto currentThrowableInst = catchPhi->GetThrowableInst(i);
452 ASSERT(GetGraph()->IsInstThrowable(currentThrowableInst));
453 bool skip = false;
454 for (size_t j = 0; j < newCatchPhi->GetInputsCount(); j++) {
455 auto savedInst = newCatchPhi->GetInput(j).GetInst();
456 if (savedInst != inputInst) {
457 continue;
458 }
459 auto savedThrowableInst = newCatchPhi->GetThrowableInst(j);
460 if (savedThrowableInst->IsDominate(currentThrowableInst)) {
461 skip = true;
462 }
463 if (currentThrowableInst->IsDominate(savedThrowableInst)) {
464 newCatchPhi->ReplaceThrowableInst(savedThrowableInst, currentThrowableInst);
465 skip = true;
466 }
467 if (skip) {
468 break;
469 }
470 }
471 if (!skip) {
472 newCatchPhi->AppendInput(inputInst);
473 newCatchPhi->AppendThrowableInst(currentThrowableInst);
474 }
475 }
476 if (newCatchPhi->GetInputsCount() == catchPhi->GetInputsCount()) {
477 newCatchPhi->GetBasicBlock()->RemoveInst(newCatchPhi);
478 return nullptr;
479 }
480 return newCatchPhi;
481 }
482
483 } // namespace ark::compiler
484