• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2021-2025 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  * http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "reg_alloc_resolver.h"
17 #include "reg_type.h"
18 #include "compiler/optimizer/code_generator/codegen.h"
19 #include "compiler/optimizer/ir/analysis.h"
20 #include "compiler/optimizer/ir/inst.h"
21 #include "compiler/optimizer/ir/graph.h"
22 #include "compiler/optimizer/ir/basicblock.h"
23 #include "compiler/optimizer/analysis/dominators_tree.h"
24 #include "optimizer/analysis/loop_analyzer.h"
25 
26 namespace ark::compiler {
27 /*
28  * For each instruction set destination register if it is assigned,
29  * Pop inputs from stack and push result on stack if stack slot is assigned.
30  */
Resolve()31 void RegAllocResolver::Resolve()
32 {
33     // We use RPO order because we need to calculate Caller's roots mask before its inlined callees
34     // (see PropagateCallerMasks).
35     for (auto block : GetGraph()->GetBlocksRPO()) {
36         for (auto inst : block->AllInstsSafe()) {
37             if (inst->IsSaveState()) {
38                 ResolveSaveState(inst);
39                 continue;
40             }
41             ResolveInputs(inst);
42             ResolveOutput(inst);
43             if (GetGraph()->IsInstThrowable(inst)) {
44                 AddCatchPhiMoves(inst);
45             }
46         }
47     }
48 }
49 
AddCatchPhiMoves(Inst * inst)50 void RegAllocResolver::AddCatchPhiMoves(Inst *inst)
51 {
52     auto spillFillInst = GetGraph()->CreateInstSpillFill(SpillFillType::INPUT_FILL);
53     auto handlers = GetGraph()->GetThrowableInstHandlers(inst);
54 
55     for (auto catchHandler : handlers) {
56         for (auto catchInst : catchHandler->AllInsts()) {
57             if (!catchInst->IsCatchPhi() || catchInst->CastToCatchPhi()->IsAcc()) {
58                 continue;
59             }
60             auto catchPhi = catchInst->CastToCatchPhi();
61             const auto &throwableInsts = catchPhi->GetThrowableInsts();
62             auto it = std::find(throwableInsts->begin(), throwableInsts->end(), inst);
63             if (it == throwableInsts->end()) {
64                 continue;
65             }
66             int index = std::distance(throwableInsts->begin(), it);
67             auto catchInput = catchPhi->GetDataFlowInput(index);
68             auto inputInterval = liveness_->GetInstLifeIntervals(catchInput);
69             ASSERT(inputInterval->GetSibling() == nullptr);
70             auto catchPhiInterval = liveness_->GetInstLifeIntervals(catchPhi);
71             if (inputInterval->GetLocation() != catchPhiInterval->GetLocation()) {
72                 ConnectIntervals(spillFillInst, inputInterval, catchPhiInterval);
73             }
74         }
75     }
76     if (!spillFillInst->GetSpillFills().empty()) {
77         inst->InsertBefore(spillFillInst);
78     }
79 }
80 
ResolveInputs(Inst * inst)81 void RegAllocResolver::ResolveInputs(Inst *inst)
82 {
83     if (inst->IsPhi() || inst->IsCatchPhi() || IsPseudoUserOfMultiOutput(inst) || InstHasPseudoInputs(inst)) {
84         return;
85     }
86     // Life-position before instruction to analyze intervals, that were splitted directly before it
87     auto insLn = liveness_->GetInstLifeIntervals(inst)->GetBegin();
88     auto preInsLn = insLn - 1U;
89 
90     inputLocations_.clear();
91     for (size_t i = 0; i < inst->GetInputsCount(); ++i) {
92         auto inputInterval = liveness_->GetInstLifeIntervals(inst->GetDataFlowInput(i));
93         auto sibling = inputInterval->FindSiblingAt(insLn);
94         ASSERT(sibling != nullptr);
95         inputLocations_.push_back(sibling->GetLocation());
96     }
97 
98     for (size_t i = 0; i < inst->GetInputsCount(); ++i) {
99         auto location = inst->GetLocation(i);
100         auto inputInterval = liveness_->GetInstLifeIntervals(inst->GetDataFlowInput(i));
101         if (CanReadFromAccumulator(inst, i) || inputInterval->NoDest() || location.IsInvalid()) {
102             continue;
103         }
104 
105         // Interval with fixed register can be splitted before `inst`: we don't need any extra moves in that case,
106         // since fixed register can't be overwrite
107         auto sibling = inputInterval->FindSiblingAt(preInsLn);
108         ASSERT(sibling != nullptr);
109         if (location.IsFixedRegister() && sibling->GetLocation() == location) {
110             auto it = std::find(inputLocations_.begin(), inputLocations_.end(), location);
111             // If some other instruction reside in the location then we can't reuse a split ending
112             // before the inst as a corresponding location will be overridden
113             if (it == inputLocations_.end() || static_cast<size_t>(it - inputLocations_.begin()) == i) {
114                 continue;
115             }
116         }
117 
118         // Otherwise use sibling covering `inst`
119         if (sibling->GetEnd() == preInsLn) {
120             sibling = sibling->GetSibling();
121         }
122 
123         // Input's location required any register: specify the allocated one
124         if (location.IsUnallocatedRegister()) {
125             ASSERT(sibling->HasReg());
126             inst->SetLocation(i, sibling->GetLocation());
127             continue;
128         }
129 
130         // Finally, if input's location is not equal to the required one, add spill-fill
131         if (sibling->GetLocation() != location) {
132             AddMoveToFixedLocation(inst, sibling->GetLocation(), i);
133         }
134     }
135 
136     if (inst->RequireTmpReg()) {
137         auto interval = liveness_->GetTmpRegInterval(inst);
138         ASSERT(interval != nullptr);
139         ASSERT(interval->HasReg());
140         auto regLocation = Location::MakeRegister(interval->GetReg(), interval->GetType());
141         inst->SetTmpLocation(regLocation);
142         GetGraph()->SetRegUsage(regLocation);
143     }
144 }
145 
AddMoveToFixedLocation(Inst * inst,Location inputLocation,size_t inputNum)146 void RegAllocResolver::AddMoveToFixedLocation(Inst *inst, Location inputLocation, size_t inputNum)
147 {
148     // Create or get existing SpillFillInst
149     SpillFillInst *sfInst {};
150     if (inst->GetPrev() != nullptr && inst->GetPrev()->IsSpillFill()) {
151         sfInst = inst->GetPrev()->CastToSpillFill();
152     } else {
153         sfInst = GetGraph()->CreateInstSpillFill(SpillFillType::INPUT_FILL);
154         inst->InsertBefore(sfInst);
155     }
156 
157     // Add move from input to fixed location
158     auto type = ConvertRegType(GetGraph(), inst->GetInputType(inputNum));
159     auto fixedLocation = inst->GetLocation(inputNum);
160     if (fixedLocation.IsFixedRegister()) {
161         GetGraph()->SetRegUsage(fixedLocation.GetValue(), type);
162     }
163     sfInst->AddSpillFill(inputLocation, fixedLocation, type);
164 }
165 
GetFirstUserOrInst(Inst * inst)166 Inst *GetFirstUserOrInst(Inst *inst)
167 {
168     for (auto &user : inst->GetUsers()) {
169         if (user.GetInst()->GetOpcode() != Opcode::ReturnInlined) {
170             return user.GetInst();
171         }
172     }
173     return inst;
174 }
175 
176 // For implicit null check we need to find the first null check's user to
177 // correctly capture SaveState's input locations, because implicit null checks are fired
178 // when its input is accessed by its users (for example, when LoadArray instruction is loading
179 // value from null array reference). Some life intervals may change its location (due to spilling)
180 // between NullCheck and its users, so locations captured at implicit null check could be incorrect.
181 // While implicit NullCheck may have multiple users we can use only a user dominating all other users,
182 // because null check either will be fired at it, or won't be fired at all.
GetExplicitUser(Inst * inst)183 Inst *GetExplicitUser(Inst *inst)
184 {
185     if (!inst->IsNullCheck() || !inst->CastToNullCheck()->IsImplicit() || inst->GetUsers().Empty()) {
186         return inst;
187     }
188     if (inst->HasSingleUser()) {
189         return inst->GetUsers().Front().GetInst();
190     }
191 
192     Inst *userInst {nullptr};
193     for (auto &user : inst->GetUsers()) {
194         auto currInst = user.GetInst();
195         if (!IsSuitableForImplicitNullCheck(currInst)) {
196             continue;
197         }
198         if (currInst->GetInput(0) != inst) {
199             continue;
200         }
201         if (!currInst->CanThrow()) {
202             continue;
203         }
204         userInst = currInst;
205         break;
206     }
207 #ifndef NDEBUG
208     for (auto &user : inst->GetUsers()) {
209         if (user.GetInst()->IsPhi()) {
210             continue;
211         }
212         ASSERT(userInst != nullptr && userInst->IsDominate(user.GetInst()));
213     }
214 #endif
215     return userInst;
216 }
217 
PropagateCallerMasks(SaveStateInst * saveState)218 void RegAllocResolver::PropagateCallerMasks(SaveStateInst *saveState)
219 {
220     saveState->CreateRootsStackMask(GetGraph()->GetAllocator());
221     auto user = GetExplicitUser(GetFirstUserOrInst(saveState));
222     // Get location of save state inputs at the save state user (note that at this point
223     // all inputs will have the same location at all users (excluding ReturnInlined that should be skipped)).
224     FillSaveStateRootsMask(saveState, user, saveState);
225 
226     for (auto callerInst = saveState->GetCallerInst(); callerInst != nullptr;) {
227         auto callerSs = callerInst->GetSaveState();
228         FillSaveStateRootsMask(callerSs, user, saveState);
229         auto saveStateTmp = callerInst->GetSaveState();
230         ASSERT(saveStateTmp != nullptr);
231         callerInst = saveStateTmp->GetCallerInst();
232     }
233 }
234 
FillSaveStateRootsMask(SaveStateInst * saveState,Inst * user,SaveStateInst * targetSs)235 void RegAllocResolver::FillSaveStateRootsMask(SaveStateInst *saveState, Inst *user, SaveStateInst *targetSs)
236 {
237     auto dstLn = liveness_->GetInstLifeIntervals(user)->GetBegin();
238 
239     for (size_t i = 0; i < saveState->GetInputsCount(); ++i) {
240         auto inputInst = saveState->GetDataFlowInput(i);
241         if (!inputInst->IsMovableObject()) {
242             continue;
243         }
244         auto inputInterval = liveness_->GetInstLifeIntervals(inputInst);
245         auto sibling = inputInterval->FindSiblingAt(dstLn);
246         ASSERT(sibling != nullptr);
247         bool isSplitCover;
248         if (user->IsPropagateLiveness()) {
249             isSplitCover = sibling->SplitCover<true>(dstLn);
250         } else {
251             isSplitCover = sibling->SplitCover(dstLn);
252         }
253         if (!isSplitCover) {
254             continue;
255         }
256         AddLocationToRoots(sibling->GetLocation(), targetSs, GetGraph());
257 #ifndef NDEBUG
258         for (auto &testUser : targetSs->GetUsers()) {
259             if (testUser.GetInst()->GetOpcode() == Opcode::ReturnInlined ||
260                 testUser.GetInst()->GetId() == user->GetId()) {
261                 continue;
262             }
263             auto explicitTestUser = GetExplicitUser(testUser.GetInst());
264             auto udstLn = liveness_->GetInstLifeIntervals(explicitTestUser)->GetBegin();
265             ASSERT(sibling->GetLocation() == inputInterval->FindSiblingAt(udstLn)->GetLocation());
266         }
267 #endif
268     }
269 }
270 
271 namespace {
HasSameLocation(LifeIntervals * interval,LifeNumber pos1,LifeNumber pos2)272 bool HasSameLocation(LifeIntervals *interval, LifeNumber pos1, LifeNumber pos2)
273 {
274     auto sibling1 = interval->FindSiblingAt(pos1);
275     auto sibling2 = interval->FindSiblingAt(pos2);
276     ASSERT(sibling1 != nullptr);
277     ASSERT(sibling2 != nullptr);
278     return sibling1->SplitCover(pos1) && sibling1->SplitCover(pos2) &&
279            sibling1->GetLocation() == sibling2->GetLocation();
280 }
281 
SaveStateCopyRequired(Inst * inst,User * currUser,User * prevUser,const LivenessAnalyzer * la)282 bool SaveStateCopyRequired(Inst *inst, User *currUser, User *prevUser, const LivenessAnalyzer *la)
283 {
284     ASSERT(inst->IsSaveState());
285     auto currUserLn = la->GetInstLifeIntervals(GetExplicitUser(currUser->GetInst()))->GetBegin();
286     auto prevUserLn = la->GetInstLifeIntervals(GetExplicitUser(prevUser->GetInst()))->GetBegin();
287     bool needCopy = false;
288     // If current save state is part of inlined method then we have to check location for all
289     // parent save states.
290     for (auto ss = static_cast<SaveStateInst *>(inst); ss != nullptr && !needCopy;) {
291         for (size_t inputIdx = 0; inputIdx < ss->GetInputsCount() && !needCopy; inputIdx++) {
292             auto inputInterval = la->GetInstLifeIntervals(ss->GetDataFlowInput(inputIdx));
293             needCopy = !HasSameLocation(inputInterval, currUserLn, prevUserLn);
294         }
295         auto caller = ss->GetCallerInst();
296         if (caller == nullptr) {
297             ss = nullptr;
298         } else {
299             ss = caller->GetSaveState();
300         }
301     }
302     return needCopy;
303 }
304 }  // namespace
305 
ResolveSaveState(Inst * inst)306 void RegAllocResolver::ResolveSaveState(Inst *inst)
307 {
308     if (GetGraph()->GetCallingConvention() == nullptr) {
309         return;
310     }
311     ASSERT(inst->IsSaveState());
312 
313     bool handledAllUsers = inst->HasSingleUser() || !inst->HasUsers();
314     while (!handledAllUsers) {
315         size_t copyUsers = 0;
316         auto userIt = inst->GetUsers().begin();
317         User *prevUser = &*userIt;
318         ++userIt;
319         bool needCopy = false;
320 
321         // Find first user having different location for some of the save state inputs and use SaveState's
322         // copy for all preceding users.
323         for (; userIt != inst->GetUsers().end() && !needCopy; ++userIt, copyUsers++) {
324             auto &currUser = *userIt;
325             // ReturnInline's SaveState is required only for SaveState's inputs life range propagation,
326             // so it does not actually matter which interval will be actually used.
327             if (prevUser->GetInst()->GetOpcode() == Opcode::ReturnInlined) {
328                 prevUser = &*userIt;
329                 continue;
330             }
331             if (currUser.GetInst()->GetOpcode() == Opcode::ReturnInlined) {
332                 continue;
333             }
334             needCopy = SaveStateCopyRequired(inst, &currUser, prevUser, liveness_);
335             prevUser = &*userIt;
336         }
337         if (needCopy) {
338             auto copy = CopySaveState(GetGraph(), static_cast<SaveStateInst *>(inst));
339             // Replace original SaveState with the copy for first N users (N = `copy_users` ).
340             while (copyUsers > 0) {
341                 auto userInst = inst->GetUsers().Front().GetInst();
342                 userInst->ReplaceInput(inst, copy);
343                 copyUsers--;
344             }
345             inst->GetBasicBlock()->InsertAfter(copy, inst);
346             PropagateCallerMasks(copy);
347             handledAllUsers = inst->HasSingleUser();
348         } else {
349             handledAllUsers = !(userIt != inst->GetUsers().end());
350         }
351     }
352     // At this point inst either has single user or all its inputs have the same location at all users.
353     PropagateCallerMasks(static_cast<SaveStateInst *>(inst));
354 }
355 
356 /*
357  * Pop output on stack from reserved register
358  */
ResolveOutput(Inst * inst)359 void RegAllocResolver::ResolveOutput(Inst *inst)
360 {
361     // Don't process LiveOut, since it is instruction with pseudo destination
362     if (inst->GetOpcode() == Opcode::LiveOut) {
363         return;
364     }
365     // Multi-output instructions' dst registers will be filled after procecssing theirs pseudo users
366     if (inst->GetLinearNumber() == INVALID_LINEAR_NUM || inst->GetDstCount() > 1) {
367         return;
368     }
369 
370     if (CanStoreToAccumulator(inst)) {
371         return;
372     }
373 
374     auto instInterval = liveness_->GetInstLifeIntervals(inst);
375     if (instInterval->NoDest()) {
376         inst->SetDstReg(GetInvalidReg());
377         return;
378     }
379 
380     if (inst->GetOpcode() == Opcode::Parameter) {
381         inst->CastToParameter()->GetLocationData().SetDst(instInterval->GetLocation());
382     }
383     // Process multi-output inst
384     size_t dstMum = inst->GetSrcRegIndex();
385     if (IsPseudoUserOfMultiOutput(inst)) {
386         inst = inst->GetInput(0).GetInst();
387     }
388     // Wrtie dst
389     auto regType = instInterval->GetType();
390     if (instInterval->HasReg()) {
391         auto reg = instInterval->GetReg();
392         inst->SetDstReg(dstMum, reg);
393         GetGraph()->SetRegUsage(reg, regType);
394     } else {
395         ASSERT(inst->IsConst() || inst->IsPhi() || inst->IsParameter());
396     }
397 }
398 
ResolveCatchPhis()399 bool RegAllocResolver::ResolveCatchPhis()
400 {
401     for (auto block : GetGraph()->GetBlocksRPO()) {
402         if (!block->IsCatchBegin()) {
403             continue;
404         }
405         for (auto inst : block->AllInstsSafe()) {
406             if (!inst->IsCatchPhi()) {
407                 break;
408             }
409             if (inst->CastToCatchPhi()->IsAcc()) {
410                 continue;
411             }
412             // This is the case when all throwable instructions were removed from the try-block,
413             // so that catch-handler is unreachable
414             if (inst->GetInputs().Empty()) {
415                 return false;
416             }
417             auto newCatchPhi = SqueezeCatchPhiInputs(inst->CastToCatchPhi());
418             if (newCatchPhi != nullptr) {
419                 inst->ReplaceUsers(newCatchPhi);
420                 block->RemoveInst(inst);
421             }
422         }
423     }
424     return true;
425 }
426 
427 /**
428  * Try to remove catch phi's inputs:
429  * If the input's corresponding throwable instruction dominates other throwable inst, we can remove other equal catch
430  * phi's input
431  *
432  * CatchPhi(v1, v1, v1, v2, v2, v2) -> CatchPhi(v1, v2)
433  *
434  * Return nullptr if inputs count was not reduced.
435  */
SqueezeCatchPhiInputs(CatchPhiInst * catchPhi)436 Inst *RegAllocResolver::SqueezeCatchPhiInputs(CatchPhiInst *catchPhi)
437 {
438     bool inputsAreIdentical = true;
439     auto firstInput = catchPhi->GetInput(0).GetInst();
440     for (size_t i = 1; i < catchPhi->GetInputsCount(); ++i) {
441         if (catchPhi->GetInput(i).GetInst() != firstInput) {
442             inputsAreIdentical = false;
443             break;
444         }
445     }
446     if (inputsAreIdentical) {
447         return firstInput;
448     }
449 
450     // Create a new one and fill it with the necessary inputs
451     auto newCatchPhi = GetGraph()->CreateInstCatchPhi(catchPhi->GetType(), catchPhi->GetPc());
452     ASSERT(catchPhi->GetBasicBlock()->GetFirstInst()->IsCatchPhi());
453     catchPhi->GetBasicBlock()->PrependInst(newCatchPhi);
454     for (size_t i = 0; i < catchPhi->GetInputsCount(); i++) {
455         auto inputInst = catchPhi->GetInput(i).GetInst();
456         auto currentThrowableInst = catchPhi->GetThrowableInst(i);
457         ASSERT(GetGraph()->IsInstThrowable(currentThrowableInst));
458         bool skip = false;
459         for (size_t j = 0; j < newCatchPhi->GetInputsCount(); j++) {
460             auto savedInst = newCatchPhi->GetInput(j).GetInst();
461             if (savedInst != inputInst) {
462                 continue;
463             }
464             auto savedThrowableInst = newCatchPhi->GetThrowableInst(j);
465             if (savedThrowableInst->IsDominate(currentThrowableInst)) {
466                 skip = true;
467             }
468             if (currentThrowableInst->IsDominate(savedThrowableInst)) {
469                 newCatchPhi->ReplaceThrowableInst(savedThrowableInst, currentThrowableInst);
470                 skip = true;
471             }
472             if (skip) {
473                 break;
474             }
475         }
476         if (!skip) {
477             newCatchPhi->AppendInput(inputInst);
478             newCatchPhi->AppendThrowableInst(currentThrowableInst);
479         }
480     }
481     if (newCatchPhi->GetInputsCount() == catchPhi->GetInputsCount()) {
482         newCatchPhi->GetBasicBlock()->RemoveInst(newCatchPhi);
483         return nullptr;
484     }
485     return newCatchPhi;
486 }
487 
488 }  // namespace ark::compiler
489