/*
 * Copyright (c) 2023-2024 Huawei Device Co., Ltd.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "adjust_arefs.h"
#include "optimizer/ir/basicblock.h"
#include "optimizer/ir/graph.h"
#include "optimizer/analysis/loop_analyzer.h"

namespace ark::compiler {
AdjustRefs::AdjustRefs(Graph *graph)
    : Optimization {graph},
      defs_ {graph->GetLocalAllocator()->Adapter()},
      workset_ {graph->GetLocalAllocator()->Adapter()},
      heads_ {graph->GetLocalAllocator()->Adapter()},
      instsToReplace_ {graph->GetLocalAllocator()->Adapter()}
{
}

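/* A candidate access is a StoreArray/LoadArray that carries no GC barrier
 * (string loads are excluded as well); such accesses can be rewritten into
 * plain Load/Store instructions through a pointer into the array data. */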
static bool IsRefAdjustable(const Inst *inst)
{
    switch (inst->GetOpcode()) {
        case Opcode::StoreArray:
            return !inst->CastToStoreArray()->GetNeedBarrier();
        case Opcode::LoadArray:
            return !inst->CastToLoadArray()->GetNeedBarrier() && !inst->CastToLoadArray()->IsString();
        default:
            break;
    }

    return false;
}

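/* Pass entry point: collect the base array definitions of adjustable accesses,
 * merge chains of accesses to the same array onto a single data pointer
 * (ProcessArrayUses), and then fold constant AddI/SubI index computations
 * into the pointer arithmetic (ProcessIndex). */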
bool AdjustRefs::RunImpl()
{
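    // First pass: collect the distinct base array definitions of all adjustable
    // accesses inside loops (the marker keeps defs_ free of duplicates)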
    auto defMarker = GetGraph()->NewMarker();
    for (const auto &bb : GetGraph()->GetBlocksRPO()) {
        if (bb->GetLoop()->IsRoot()) {
            continue;
        }

        for (auto inst : bb->Insts()) {
            if (!IsRefAdjustable(inst)) {
                continue;
            }
            auto def = inst->GetInput(0).GetInst();
            if (!def->SetMarker(defMarker)) {
                defs_.push_back(def);
            }
        }
    }
    GetGraph()->EraseMarker(defMarker);
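    // For each collected definition, gather its adjustable users that sit
    // inside a loop into workset_ and try to merge them into chains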
    for (auto def : defs_) {
        workset_.clear();
        auto markerHolder = MarkerHolder(GetGraph());
        worksetMarker_ = markerHolder.GetMarker();
        for (auto &user : def->GetUsers()) {
            auto i = user.GetInst();
            if (!IsRefAdjustable(i) || i->GetBasicBlock()->GetLoop()->IsRoot()) {
                continue;
            }
            workset_.push_back(i);
            i->SetMarker(worksetMarker_);
        }
        ProcessArrayUses();
    }
    // We make a second pass, because some LoadArrays and StoreArrays may have been removed above
    for (const auto &bb : GetGraph()->GetBlocksRPO()) {
        for (auto inst : bb->Insts()) {
            if (IsRefAdjustable(inst)) {
                ProcessIndex(inst);
            }
        }
    }

    return added_;
}

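/* Split the instructions of the current workset into dominance chains that are
 * free of safepoints and runtime calls, then merge every chain that contains
 * more than one access (a single access gains nothing from the rewrite). */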
void AdjustRefs::ProcessArrayUses()
{
    ASSERT(heads_.empty());
    auto enteredHolder = MarkerHolder(GetGraph());
    blockEntered_ = enteredHolder.GetMarker();
    GetHeads();
    while (!heads_.empty()) {
        auto head = heads_.back();
        heads_.pop_back();
        ASSERT(IsRefAdjustable(head));
        ASSERT(head->IsMarked(worksetMarker_));
        ASSERT(head->GetBasicBlock() != nullptr);
        loop_ = head->GetBasicBlock()->GetLoop();
        auto processedHolder = MarkerHolder(GetGraph());
        blockProcessed_ = processedHolder.GetMarker();
        instsToReplace_.clear();
        ASSERT(!head->GetBasicBlock()->IsMarked(blockProcessed_));
        WalkChainDown(head->GetBasicBlock(), head, head);
        if (instsToReplace_.size() > 1) {
            ProcessChain(head);
        }
    }
}

/* Create the list of "heads": instructions that are not dominated by any other
 * instruction in the workset, or that have a potential GC trigger between the
 * dominating instruction and themselves and thus cannot be merged with it.
 * In both cases a "head" is potentially the first instruction of a chain. */
void AdjustRefs::GetHeads()
{
    for (const auto i : workset_) {
        auto comp = [i](const Inst *i1) { return i1->IsDominate(i) && i != i1; };
        if (workset_.end() == std::find_if(workset_.begin(), workset_.end(), comp)) {
            heads_.push_back(i);
            i->GetBasicBlock()->SetMarker(blockEntered_);
        }
    }
}

/* Add the instructions that can be merged with head to instsToReplace_.
 * Instructions that are visited but cannot be merged are added to heads_. */
void AdjustRefs::WalkChainDown(BasicBlock *bb, Inst *startFrom, Inst *head)
{
    bb->SetMarker(blockEntered_);
    for (auto cur = startFrom; cur != nullptr; cur = cur->GetNext()) {
        /* potential switch to VM, the chain breaks here */
        if (cur->IsRuntimeCall() || cur->GetOpcode() == Opcode::SafePoint) {
            head = nullptr;
        } else if (cur->IsMarked(worksetMarker_)) {
            if (head == nullptr) {
                heads_.push_back(cur);
                return;
            }
            ASSERT(head->IsDominate(cur));
            instsToReplace_.push_back(cur);
        }
    }
    if (head != nullptr) {
        bb->SetMarker(blockProcessed_);
    }
    for (auto succ : bb->GetSuccsBlocks()) {
        if (succ->GetLoop() != loop_ || succ->IsMarked(blockEntered_)) {
            continue;
        }

        if (head != nullptr) {
            auto blockNotProcessed = [this](BasicBlock *b) { return !b->IsMarked(blockProcessed_); };
            auto it = std::find_if(succ->GetPredsBlocks().begin(), succ->GetPredsBlocks().end(), blockNotProcessed);
            if (it != succ->GetPredsBlocks().end()) {
                continue;
            }
        }
        // If all predecessors of succ were walked with the current value of head,
        // we can be sure that there are no SafePoints or runtime calls
        // on any path from the block with head to succ
        WalkChainDown(succ, succ->GetFirstInst(), head);
    }
}

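/* Merge a chain: materialize a single pointer to the array data in front of
 * the chain head and rewrite every access of the chain into a plain Load/Store
 * relative to that pointer. */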
void AdjustRefs::ProcessChain(Inst *head)
{
    Inst *def = head->GetInput(0).GetInst();
    auto off = GetGraph()->GetRuntime()->GetArrayDataOffset(GetGraph()->GetArch());
    auto arrData = InsertPointerArithmetic(def, off, head, def->GetPc(), true);
    ASSERT(arrData != nullptr);

    for (auto inst : instsToReplace_) {
        auto scale = DataType::ShiftByType(inst->GetType(), GetGraph()->GetArch());
        InsertMem(inst, arrData, inst->GetInput(1).GetInst(), scale);
    }

    added_ = true;
}

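/* Insert an AddI/SubI of POINTER type computing "input +/- imm" right before
 * insertBefore; returns nullptr if the immediate cannot be encoded. */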
Inst *AdjustRefs::InsertPointerArithmetic(Inst *input, uint64_t imm, Inst *insertBefore, uint32_t pc, bool isAdd)
{
    uint32_t size = DataType::GetTypeSize(DataType::POINTER, GetGraph()->GetArch());
    if (!GetGraph()->GetEncoder()->CanEncodeImmAddSubCmp(imm, size, false)) {
        return nullptr;
    }
    Inst *newInst;
    if (isAdd) {
        newInst = GetGraph()->CreateInstAddI(DataType::POINTER, pc, input, imm);
    } else {
        newInst = GetGraph()->CreateInstSubI(DataType::POINTER, pc, input, imm);
    }
    insertBefore->InsertBefore(newInst);
    return newInst;
}

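/* Replace the array access org with a plain Load/Store addressing
 * "base + (index << scale)", preserving the stored value and the type. */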
void AdjustRefs::InsertMem(Inst *org, Inst *base, Inst *index, uint8_t scale)
{
    Inst *ldst = nullptr;

    ASSERT(base->IsDominate(org));

    if (org->IsStore()) {
        constexpr auto VALUE_IDX = 2;
        ldst = GetGraph()->CreateInst(Opcode::Store);
        ldst->SetInput(VALUE_IDX, org->GetInput(VALUE_IDX).GetInst());
        ldst->CastToStore()->SetScale(scale);
    } else if (org->IsLoad()) {
        ldst = GetGraph()->CreateInst(Opcode::Load);
        ldst->CastToLoad()->SetScale(scale);
    } else {
        UNREACHABLE();
    }
    ldst->SetInput(0, base);
    ldst->SetInput(1, index);
    ldst->SetType(org->GetType());
    org->ReplaceUsers(ldst);
    org->RemoveInputs();
    org->GetBasicBlock()->ReplaceInst(org, ldst);
}

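/* Fold a constant increment/decrement of the index (AddI/SubI) into the
 * pointer arithmetic on the base, e.g.: */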
// from
// 3.i32 AddI v2, 0xN -> v4
// 4.i64 LoadArray v1, v3 -> ....
// to
// 5.ptr AddI v1, 0x10 + (0xN << 3) -> v6
// 6.i64 Load v5, v2 -> ....
void AdjustRefs::ProcessIndex(Inst *mem)
{
    Inst *index = mem->GetInput(1).GetInst();
    bool isAdd;
    uint64_t imm;
    if (index->GetOpcode() == Opcode::AddI) {
        isAdd = true;
        imm = index->CastToAddI()->GetImm();
    } else if (index->GetOpcode() == Opcode::SubI) {
        isAdd = false;
        imm = index->CastToSubI()->GetImm();
    } else {
        return;
    }
    auto scale = DataType::ShiftByType(mem->GetType(), GetGraph()->GetArch());
    uint64_t off = GetGraph()->GetRuntime()->GetArrayDataOffset(GetGraph()->GetArch());
    Inst *base = mem->GetInput(0).GetInst();

    Inst *newBase;
    if (!isAdd) {
        if (off > (imm << scale)) {
            uint64_t newOff = off - (imm << scale);
            newBase = InsertPointerArithmetic(base, newOff, mem, mem->GetPc(), true);
        } else if (off < (imm << scale)) {
            uint64_t newOff = (imm << scale) - off;
            newBase = InsertPointerArithmetic(base, newOff, mem, mem->GetPc(), false);
        } else {
            ASSERT(off == (imm << scale));
            newBase = base;
        }
    } else {
        uint64_t newOff = off + (imm << scale);
        newBase = InsertPointerArithmetic(base, newOff, mem, mem->GetPc(), true);
    }
    if (newBase == nullptr) {
        return;
    }
    InsertMem(mem, newBase, index->GetInput(0).GetInst(), scale);

    added_ = true;
}

}  // namespace ark::compiler