1 /*
2 * Copyright (c) 2023-2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
#include "compiler_logger.h"
#include "optimizer/analysis/dominators_tree.h"
#include "optimizer/analysis/countable_loop_parser.h"
#include "optimizer/optimizations/loop_idioms.h"

#include <cstdint>
20
21 namespace ark::compiler {
RunImpl()22 bool LoopIdioms::RunImpl()
23 {
24 if (GetGraph()->GetArch() == Arch::AARCH32) {
25 // There is only one supported idiom and intrinsic
26 // emitted for it could not be encoded on Arm32.
27 return false;
28 }
29 GetGraph()->RunPass<LoopAnalyzer>();
30 RunLoopsVisitor();
31 return isApplied_;
32 }
33
InvalidateAnalyses()34 void LoopIdioms::InvalidateAnalyses()
35 {
36 GetGraph()->InvalidateAnalysis<LoopAnalyzer>();
37 InvalidateBlocksOrderAnalyzes(GetGraph());
38 }
39
TransformLoop(Loop * loop)40 bool LoopIdioms::TransformLoop(Loop *loop)
41 {
42 if (TryTransformArrayInitIdiom(loop)) {
43 isApplied_ = true;
44 return true;
45 }
46 return false;
47 }
48
FindStoreForArrayInit(BasicBlock * block)49 StoreInst *FindStoreForArrayInit(BasicBlock *block)
50 {
51 StoreInst *store {nullptr};
52 for (auto inst : block->Insts()) {
53 if (inst->GetOpcode() != Opcode::StoreArray) {
54 continue;
55 }
56 if (store != nullptr) {
57 return nullptr;
58 }
59 store = inst->CastToStoreArray();
60 }
61 // should be a loop invariant
62 if (store != nullptr && store->GetStoredValue()->GetBasicBlock()->GetLoop() == block->GetLoop()) {
63 return nullptr;
64 }
65 return store;
66 }
67
ExtractArrayInitInitialIndexValue(PhiInst * index)68 Inst *ExtractArrayInitInitialIndexValue(PhiInst *index)
69 {
70 auto block = index->GetBasicBlock();
71 BasicBlock *pred = block->GetPredsBlocks().front();
72 if (pred == block) {
73 pred = block->GetPredsBlocks().back();
74 }
75 return index->GetPhiInput(pred);
76 }
77
AllUsesWithinLoop(Inst * inst,Loop * loop)78 bool AllUsesWithinLoop(Inst *inst, Loop *loop)
79 {
80 for (auto &user : inst->GetUsers()) {
81 if (user.GetInst()->GetBasicBlock()->GetLoop() != loop) {
82 return false;
83 }
84 }
85 return true;
86 }
87
CanReplaceLoop(Loop * loop,Marker marker)88 bool CanReplaceLoop(Loop *loop, Marker marker)
89 {
90 ASSERT(loop->GetBlocks().size() == 1);
91 auto block = loop->GetHeader();
92 for (auto inst : block->AllInsts()) {
93 if (inst->IsMarked(marker)) {
94 continue;
95 }
96 auto opcode = inst->GetOpcode();
97 if (opcode != Opcode::NOP && opcode != Opcode::SaveState && opcode != Opcode::SafePoint) {
98 return false;
99 }
100 }
101 return true;
102 }
103
IsLoopContainsArrayInitIdiom(StoreInst * store,Loop * loop,CountableLoopInfo & loopInfo)104 bool IsLoopContainsArrayInitIdiom(StoreInst *store, Loop *loop, CountableLoopInfo &loopInfo)
105 {
106 auto storeIdx = store->GetIndex();
107
108 return loopInfo.constStep == 1UL && loopInfo.index == storeIdx && loopInfo.normalizedCc == ConditionCode::CC_LT &&
109 AllUsesWithinLoop(storeIdx, loop) && AllUsesWithinLoop(loopInfo.update, loop) &&
110 AllUsesWithinLoop(loopInfo.ifImm->GetInput(0).GetInst(), loop);
111 }
112
TryTransformArrayInitIdiom(Loop * loop)113 bool LoopIdioms::TryTransformArrayInitIdiom(Loop *loop)
114 {
115 ASSERT(loop->GetInnerLoops().empty());
116 if (loop->GetBlocks().size() != 1) {
117 return false;
118 }
119
120 auto store = FindStoreForArrayInit(loop->GetHeader());
121 if (store == nullptr) {
122 return false;
123 }
124
125 auto loopInfoOpt = CountableLoopParser {*loop}.Parse();
126 if (!loopInfoOpt.has_value()) {
127 return false;
128 }
129 auto loopInfo = *loopInfoOpt;
130 if (!IsLoopContainsArrayInitIdiom(store, loop, loopInfo)) {
131 return false;
132 }
133 ASSERT(loopInfo.isInc);
134
135 MarkerHolder holder {GetGraph()};
136 Marker marker = holder.GetMarker();
137 store->SetMarker(marker);
138 loopInfo.update->SetMarker(marker);
139 loopInfo.index->SetMarker(marker);
140 loopInfo.ifImm->SetMarker(marker);
141 loopInfo.ifImm->GetInput(0).GetInst()->SetMarker(marker);
142
143 if (!CanReplaceLoop(loop, marker)) {
144 return false;
145 }
146
147 COMPILER_LOG(DEBUG, LOOP_TRANSFORM) << "Array init idiom found in loop: " << loop->GetId()
148 << "\n\tarray: " << *store->GetArray()
149 << "\n\tvalue: " << *store->GetStoredValue()
150 << "\n\tinitial index: " << *loopInfo.init << "\n\ttest: " << *loopInfo.test
151 << "\n\tupdate: " << *loopInfo.update << "\n\tstep: " << loopInfo.constStep
152 << "\n\tindex: " << *loopInfo.index;
153
154 bool alwaysJump = false;
155 if (loopInfo.init->IsConst() && loopInfo.test->IsConst()) {
156 auto iterations =
157 loopInfo.test->CastToConstant()->GetIntValue() - loopInfo.init->CastToConstant()->GetIntValue();
158 if (iterations <= ITERATIONS_THRESHOLD) {
159 COMPILER_LOG(DEBUG, LOOP_TRANSFORM)
160 << "Loop will have " << iterations << " iterations, so intrinsics will not be generated";
161 return false;
162 }
163 alwaysJump = true;
164 }
165
166 return ReplaceArrayInitLoop(loop, &loopInfo, store, alwaysJump);
167 }
168
CreateArrayInitIntrinsic(StoreInst * store,CountableLoopInfo * info)169 Inst *LoopIdioms::CreateArrayInitIntrinsic(StoreInst *store, CountableLoopInfo *info)
170 {
171 auto type = store->GetType();
172 RuntimeInterface::IntrinsicId intrinsicId;
173 switch (type) {
174 case DataType::BOOL:
175 case DataType::INT8:
176 case DataType::UINT8:
177 intrinsicId = RuntimeInterface::IntrinsicId::LIB_CALL_MEMSET_8;
178 break;
179 case DataType::INT16:
180 case DataType::UINT16:
181 intrinsicId = RuntimeInterface::IntrinsicId::LIB_CALL_MEMSET_16;
182 break;
183 case DataType::INT32:
184 case DataType::UINT32:
185 intrinsicId = RuntimeInterface::IntrinsicId::LIB_CALL_MEMSET_32;
186 break;
187 case DataType::INT64:
188 case DataType::UINT64:
189 intrinsicId = RuntimeInterface::IntrinsicId::LIB_CALL_MEMSET_64;
190 break;
191 case DataType::FLOAT32:
192 intrinsicId = RuntimeInterface::IntrinsicId::LIB_CALL_MEMSET_F32;
193 break;
194 case DataType::FLOAT64:
195 intrinsicId = RuntimeInterface::IntrinsicId::LIB_CALL_MEMSET_F64;
196 break;
197 default:
198 return nullptr;
199 }
200
201 auto fillArray = GetGraph()->CreateInstIntrinsic(DataType::VOID, store->GetPc(), intrinsicId);
202 fillArray->ClearFlag(inst_flags::Flags::REQUIRE_STATE);
203 fillArray->ClearFlag(inst_flags::Flags::RUNTIME_CALL);
204 fillArray->ClearFlag(inst_flags::Flags::CAN_THROW);
205 fillArray->SetInputs(GetGraph()->GetAllocator(), {{store->GetArray(), DataType::REFERENCE},
206 {store->GetStoredValue(), type},
207 {info->init, DataType::INT32},
208 {info->test, DataType::INT32}});
209 return fillArray;
210 }
211
ReplaceArrayInitLoop(Loop * loop,CountableLoopInfo * loopInfo,StoreInst * store,bool alwaysJump)212 bool LoopIdioms::ReplaceArrayInitLoop(Loop *loop, CountableLoopInfo *loopInfo, StoreInst *store, bool alwaysJump)
213 {
214 auto inst = CreateArrayInitIntrinsic(store, loopInfo);
215 if (inst == nullptr) {
216 return false;
217 }
218 auto header = loop->GetHeader();
219 auto preHeader = loop->GetPreHeader();
220
221 auto loopSucc = header->GetSuccessor(0) == header ? header->GetSuccessor(1) : header->GetSuccessor(0);
222 if (alwaysJump) {
223 ASSERT(loop->GetBlocks().size() == 1);
224 // insert block before disconnecting header to properly handle Phi in loop_succ
225 auto block = header->InsertNewBlockToSuccEdge(loopSucc);
226 preHeader->ReplaceSucc(header, block, true);
227 GetGraph()->DisconnectBlock(header, false, false);
228 block->AppendInst(inst);
229
230 COMPILER_LOG(INFO, LOOP_TRANSFORM) << "Replaced loop " << loop->GetId() << " with the instruction " << *inst
231 << " inserted into the new block " << block->GetId();
232 } else {
233 auto guardBlock = preHeader->InsertNewBlockToSuccEdge(header);
234 auto sub = GetGraph()->CreateInstSub(DataType::INT32, inst->GetPc(), loopInfo->test, loopInfo->init);
235 auto cmp = GetGraph()->CreateInstCompare(DataType::BOOL, inst->GetPc(), sub,
236 GetGraph()->FindOrCreateConstant(ITERATIONS_THRESHOLD),
237 DataType::INT32, ConditionCode::CC_LE);
238 auto ifImm =
239 GetGraph()->CreateInstIfImm(DataType::NO_TYPE, inst->GetPc(), cmp, 0, DataType::BOOL, ConditionCode::CC_NE);
240 guardBlock->AppendInst(sub);
241 guardBlock->AppendInst(cmp);
242 guardBlock->AppendInst(ifImm);
243
244 auto mergeBlock = header->InsertNewBlockToSuccEdge(loopSucc);
245 auto intrinsicBlock = GetGraph()->CreateEmptyBlock();
246
247 guardBlock->AddSucc(intrinsicBlock);
248 intrinsicBlock->AddSucc(mergeBlock);
249 intrinsicBlock->AppendInst(inst);
250
251 COMPILER_LOG(INFO, LOOP_TRANSFORM) << "Inserted conditional jump into intinsic " << *inst << " before loop "
252 << loop->GetId() << ", inserted blocks: " << intrinsicBlock->GetId() << ", "
253 << guardBlock->GetId() << ", " << mergeBlock->GetId();
254 }
255 return true;
256 }
257
258 } // namespace ark::compiler
259