1 /*
2 * Copyright (c) 2023 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 /*
17 * This module analyzes the tag distribution in a switch statement and decides
18 * the best strategy in terms of runtime performance to generate code for it.
19 * The generated code makes use of 3 code generation techniques:
20 *
21 * 1. cascade of if-then-else based on equality test
22 * 2. rangegoto
23 * 3. binary search
24 *
25 * 1 is applied only if the number of possibilities is <= 6.
26 * 2 corresponds to indexed jump, but it requires allocating an array
27 * initialized with the jump targets. Since it causes memory usage overhead,
28 * rangegoto is used only if the density is higher than 0.7.
29 * If neither 1 nor 2 is applicable, 3 is applied in the form of a decision
30 * tree. In this case, each test would split the tags into 2 halves. For
31 * each half, the above algorithm is then applied recursively until the
32 * algorithm terminates.
33 *
34 * But we don't want to apply 3 right from the beginning if both 1 and 2 do not
35 * apply, because there may be regions that have density > 0.7. Thus, the
36 * switch lowerer begins by finding clusters. A cluster is defined to be a
37 * maximal range of tags whose density is > 0.7.
38 *
 * In finding clusters, the original switch table is sorted and then each dense
 * region is condensed into 1 switch item. In the switchItems table, the pair's
 * first is always the switch-table index of the item's (first) case, and each
 * item either corresponds to a single entry of the switch table (the pair's
 * second is 0) or to a dense region (the pair's second is the switch-table index
 * of the last case in that dense range). The output code is generated from
 * switchItems; see BuildCodeForSwitchItems(), which is recursive.
44 */
45 #include "switch_lowerer.h"
46 #include "mir_nodes.h"
47 #include "mir_builder.h"
48 #include "mir_lower.h" /* "../../../maple_ir/include/mir_lower.h" */
49
50 namespace maplebe {
51 using namespace maple;
52
CasePairKeyLessThan(const CasePair & left,const CasePair & right)53 static bool CasePairKeyLessThan(const CasePair &left, const CasePair &right)
54 {
55 return left.first < right.first;
56 }
57
FindClusters(MapleVector<Cluster> & clusters) const58 void SwitchLowerer::FindClusters(MapleVector<Cluster> &clusters) const
59 {
60 int32 length = static_cast<int>(stmt->GetSwitchTable().size());
61 int32 i = 0;
62 while (i < length - kClusterSwitchCutoff) {
63 for (int32 j = length - 1; j > i; --j) {
64 float tmp1 = static_cast<float>(j - i);
65 float tmp2 = static_cast<float>(stmt->GetCasePair(static_cast<size_t>(static_cast<uint32>(j))).first) -
66 static_cast<float>(stmt->GetCasePair(static_cast<size_t>(static_cast<uint32>(i))).first);
67 float currDensity = tmp1 / tmp2;
68 if (((j - i) >= kClusterSwitchCutoff) &&
69 ((currDensity >= kClusterSwitchDensityHigh) ||
70 ((currDensity >= kClusterSwitchDensityLow) && (tmp2 < kMaxRangeGotoTableSize)))) {
71 clusters.emplace_back(Cluster(i, j));
72 i = j;
73 break;
74 }
75 }
76 ++i;
77 }
78 }
79
InitSwitchItems(MapleVector<Cluster> & clusters)80 void SwitchLowerer::InitSwitchItems(MapleVector<Cluster> &clusters)
81 {
82 if (clusters.empty()) {
83 for (int32 i = 0; i < static_cast<int>(stmt->GetSwitchTable().size()); ++i) {
84 switchItems.emplace_back(SwitchItem(i, 0));
85 }
86 } else {
87 int32 j = 0;
88 Cluster front = clusters[j];
89 for (int32 i = 0; i < static_cast<int>(stmt->GetSwitchTable().size()); ++i) {
90 if (i == front.first) {
91 switchItems.emplace_back(SwitchItem(i, front.second));
92 i = front.second;
93 ++j;
94 if (static_cast<int>(clusters.size()) > j) {
95 front = clusters[j];
96 }
97 } else {
98 switchItems.emplace_back(SwitchItem(i, 0));
99 }
100 }
101 }
102 }
103
BuildRangeGotoNode(int32 startIdx,int32 endIdx)104 RangeGotoNode *SwitchLowerer::BuildRangeGotoNode(int32 startIdx, int32 endIdx)
105 {
106 RangeGotoNode *node = mirModule.CurFuncCodeMemPool()->New<RangeGotoNode>(mirModule);
107 node->SetOpnd(stmt->GetSwitchOpnd(), 0);
108
109 node->SetRangeGotoTable(SmallCaseVector(mirModule.CurFuncCodeMemPoolAllocator()->Adapter()));
110 node->SetTagOffset(static_cast<int32>(stmt->GetCasePair(static_cast<size_t>(startIdx)).first));
111 uint32 curTag = 0;
112 node->AddRangeGoto(curTag, stmt->GetCasePair(startIdx).second);
113 int64 lastCaseTag = stmt->GetSwitchTable().at(startIdx).first;
114 for (int32 i = startIdx + 1; i <= endIdx; ++i) {
115 /*
116 * The second condition is to solve the problem that compilation falls into a dead loop,
117 * because in some cases the two will fall into a dead loop if they are equal.
118 */
119 while ((stmt->GetCasePair(i).first != (lastCaseTag + 1)) && (stmt->GetCasePair(i).first != lastCaseTag)) {
120 /* fill in a gap in the case tags */
121 curTag = static_cast<uint32>((++lastCaseTag) - node->GetTagOffset());
122 if (stmt->GetDefaultLabel() != 0) {
123 node->AddRangeGoto(curTag, stmt->GetDefaultLabel());
124 }
125 }
126 curTag = static_cast<uint32>(stmt->GetCasePair(static_cast<size_t>(i)).first - node->GetTagOffset());
127 node->AddRangeGoto(curTag, stmt->GetCasePair(i).second);
128 lastCaseTag = stmt->GetCasePair(i).first;
129 }
130 /* If the density is high enough, the range is allowed to be large */
131 DEBUG_ASSERT(node->GetNumOpnds() == 1, "RangeGotoNode is a UnaryOpnd and numOpnds must be 1");
132 return node;
133 }
134
BuildCmpNode(Opcode opCode,uint32 idx)135 CompareNode *SwitchLowerer::BuildCmpNode(Opcode opCode, uint32 idx)
136 {
137 CompareNode *binaryExpr = mirModule.CurFuncCodeMemPool()->New<CompareNode>(opCode);
138 binaryExpr->SetPrimType(PTY_u32);
139 binaryExpr->SetOpndType(stmt->GetSwitchOpnd()->GetPrimType());
140
141 MIRType &type = *GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(stmt->GetSwitchOpnd()->GetPrimType()));
142 MIRConst *constVal = GlobalTables::GetIntConstTable().GetOrCreateIntConst(stmt->GetCasePair(idx).first, type);
143 ConstvalNode *exprConst = mirModule.CurFuncCodeMemPool()->New<ConstvalNode>();
144 exprConst->SetPrimType(stmt->GetSwitchOpnd()->GetPrimType());
145 exprConst->SetConstVal(constVal);
146
147 binaryExpr->SetBOpnd(stmt->GetSwitchOpnd(), 0);
148 binaryExpr->SetBOpnd(exprConst, 1);
149 return binaryExpr;
150 }
151
BuildGotoNode(int32 idx)152 GotoNode *SwitchLowerer::BuildGotoNode(int32 idx)
153 {
154 if (idx == -1 && stmt->GetDefaultLabel() == 0) {
155 return nullptr;
156 }
157 GotoNode *gotoStmt = mirModule.CurFuncCodeMemPool()->New<GotoNode>(OP_goto);
158 if (idx == -1) {
159 gotoStmt->SetOffset(stmt->GetDefaultLabel());
160 } else {
161 gotoStmt->SetOffset(stmt->GetCasePair(idx).second);
162 }
163 return gotoStmt;
164 }
165
BuildCondGotoNode(int32 idx,Opcode opCode,BaseNode & cond)166 CondGotoNode *SwitchLowerer::BuildCondGotoNode(int32 idx, Opcode opCode, BaseNode &cond)
167 {
168 if (idx == -1 && stmt->GetDefaultLabel() == 0) {
169 return nullptr;
170 }
171 CondGotoNode *cGotoStmt = mirModule.CurFuncCodeMemPool()->New<CondGotoNode>(opCode);
172 cGotoStmt->SetOpnd(&cond, 0);
173 if (idx == -1) {
174 cGotoStmt->SetOffset(stmt->GetDefaultLabel());
175 } else {
176 cGotoStmt->SetOffset(stmt->GetCasePair(idx).second);
177 }
178 return cGotoStmt;
179 }
180
181 /* start and end is with respect to switchItems */
BuildCodeForSwitchItems(int32 start,int32 end,bool lowBlockNodeChecked,bool highBlockNodeChecked)182 BlockNode *SwitchLowerer::BuildCodeForSwitchItems(int32 start, int32 end, bool lowBlockNodeChecked,
183 bool highBlockNodeChecked)
184 {
185 DEBUG_ASSERT(start >= 0, "invalid args start");
186 DEBUG_ASSERT(end >= 0, "invalid args end");
187 BlockNode *localBlk = mirModule.CurFuncCodeMemPool()->New<BlockNode>();
188 if (start > end) {
189 return localBlk;
190 }
191 CondGotoNode *cGoto = nullptr;
192 RangeGotoNode *rangeGoto = nullptr;
193 IfStmtNode *ifStmt = nullptr;
194 CompareNode *cmpNode = nullptr;
195 MIRLower mirLowerer(mirModule, mirModule.CurFunction());
196 mirLowerer.Init();
197 /* if low side starts with a dense item, handle it first */
198 while ((start <= end) && (switchItems[start].second != 0)) {
199 if (!lowBlockNodeChecked) {
200 lowBlockNodeChecked = true;
201 if (!(IsUnsignedInteger(stmt->GetSwitchOpnd()->GetPrimType()) &&
202 (stmt->GetCasePair(static_cast<size_t>(switchItems[static_cast<uint64>(start)].first)).first == 0))) {
203 cGoto = BuildCondGotoNode(-1, OP_brtrue, *BuildCmpNode(OP_lt, switchItems[start].first));
204 if (cGoto != nullptr) {
205 localBlk->AddStatement(cGoto);
206 }
207 }
208 }
209 rangeGoto = BuildRangeGotoNode(switchItems[start].first, switchItems[start].second);
210 if (stmt->GetDefaultLabel() == 0) {
211 localBlk->AddStatement(rangeGoto);
212 } else {
213 cmpNode = BuildCmpNode(OP_le, switchItems[start].second);
214 ifStmt = static_cast<IfStmtNode *>(mirModule.GetMIRBuilder()->CreateStmtIf(cmpNode));
215 ifStmt->GetThenPart()->AddStatement(rangeGoto);
216 localBlk->AppendStatementsFromBlock(*mirLowerer.LowerIfStmt(*ifStmt, false));
217 }
218 if (start < end) {
219 lowBlockNodeChecked = (stmt->GetCasePair(switchItems[start].second).first + 1 ==
220 stmt->GetCasePair(switchItems[start + 1].first).first);
221 }
222 ++start;
223 }
224 /* if high side starts with a dense item, handle it also */
225 while ((start <= end) && (switchItems[end].second != 0)) {
226 if (!highBlockNodeChecked) {
227 cGoto = BuildCondGotoNode(-1, OP_brtrue, *BuildCmpNode(OP_gt, switchItems[end].second));
228 if (cGoto != nullptr) {
229 localBlk->AddStatement(cGoto);
230 }
231 highBlockNodeChecked = true;
232 }
233 rangeGoto = BuildRangeGotoNode(switchItems[end].first, switchItems[end].second);
234 if (stmt->GetDefaultLabel() == 0) {
235 localBlk->AddStatement(rangeGoto);
236 } else {
237 cmpNode = BuildCmpNode(OP_ge, switchItems[end].first);
238 ifStmt = static_cast<IfStmtNode *>(mirModule.GetMIRBuilder()->CreateStmtIf(cmpNode));
239 ifStmt->GetThenPart()->AddStatement(rangeGoto);
240 localBlk->AppendStatementsFromBlock(*mirLowerer.LowerIfStmt(*ifStmt, false));
241 }
242 if (start < end) {
243 highBlockNodeChecked = (stmt->GetCasePair(switchItems[end].first).first - 1 ==
244 stmt->GetCasePair(switchItems[end - 1].first).first) ||
245 (stmt->GetCasePair(switchItems[end].first).first - 1 ==
246 stmt->GetCasePair(switchItems[end - 1].second).first);
247 }
248 --end;
249 }
250 if (start > end) {
251 if (!lowBlockNodeChecked || !highBlockNodeChecked) {
252 GotoNode *gotoDft = BuildGotoNode(-1);
253 if (gotoDft != nullptr) {
254 localBlk->AddStatement(gotoDft);
255 jumpToDefaultBlockGenerated = true;
256 }
257 }
258 return localBlk;
259 }
260 if ((start == end) && lowBlockNodeChecked && highBlockNodeChecked) {
261 /* only 1 case with 1 tag remains */
262 auto *gotoStmt = BuildGotoNode(switchItems[static_cast<size_t>(start)].first);
263 if (gotoStmt != nullptr) {
264 localBlk->AddStatement(gotoStmt);
265 }
266 return localBlk;
267 }
268 if (end < (start + kClusterSwitchCutoff)) {
269 /* generate equality checks for what remains */
270 while ((start <= end) && (switchItems[start].second == 0)) {
271 if ((start == end) && lowBlockNodeChecked && highBlockNodeChecked) {
272 cGoto = reinterpret_cast<CondGotoNode *>(
273 BuildGotoNode(switchItems[start].first)); /* can omit the condition */
274 } else {
275 cGoto = BuildCondGotoNode(switchItems[start].first, OP_brtrue,
276 *BuildCmpNode(OP_eq, switchItems[start].first));
277 }
278 if (cGoto != nullptr) {
279 localBlk->AddStatement(cGoto);
280 }
281 if (lowBlockNodeChecked && (start < end)) {
282 lowBlockNodeChecked = (stmt->GetCasePair(switchItems[start].first).first + 1 ==
283 stmt->GetCasePair(switchItems[start + 1].first).first);
284 }
285 ++start;
286 }
287 if (start <= end) { /* recursive call */
288 BlockNode *tmp = BuildCodeForSwitchItems(start, end, lowBlockNodeChecked, highBlockNodeChecked);
289 CHECK_FATAL(tmp != nullptr, "tmp should not be nullptr");
290 localBlk->AppendStatementsFromBlock(*tmp);
291 } else if (!lowBlockNodeChecked || !highBlockNodeChecked) {
292 GotoNode *gotoDft = BuildGotoNode(-1);
293 if (gotoDft != nullptr) {
294 localBlk->AddStatement(gotoDft);
295 jumpToDefaultBlockGenerated = true;
296 }
297 }
298 return localBlk;
299 }
300
301 int64 lowestTag = stmt->GetCasePair(switchItems[start].first).first;
302 int64 highestTag = stmt->GetCasePair(switchItems[end].first).first;
303
304 /*
305 * if lowestTag and higesttag have the same sign, use difference
306 * if lowestTag and higesttag have the diefferent sign, use sum
307 * 1LL << 63 judge lowestTag ^ highestTag operate result highest
308 * bit is 1 or not, the result highest bit is 1 express lowestTag
309 * and highestTag have same sign , otherwise diefferent sign.highestTag
310 * add or subtract lowestTag divide 2 to get middle tag.
311 */
312 int64 middleTag = ((((static_cast<uint64>(lowestTag)) ^ (static_cast<uint64>(highestTag))) & (1ULL << 63)) == 0)
313 ? (highestTag - lowestTag) / 2 + lowestTag
314 : (highestTag + lowestTag) / 2;
315 /* find the mid-point in switch_items between start and end */
316 int32 mid = start;
317 while (stmt->GetCasePair(switchItems[mid].first).first < middleTag) {
318 ++mid;
319 }
320 DEBUG_ASSERT(mid >= start, "switch lowering logic mid should greater than or equal start");
321 DEBUG_ASSERT(mid <= end, "switch lowering logic mid should less than or equal end");
322 /* generate test for binary search */
323 if (stmt->GetDefaultLabel() != 0) {
324 cmpNode = BuildCmpNode(OP_lt, static_cast<uint32>(switchItems[static_cast<uint64>(mid)].first));
325 ifStmt = static_cast<IfStmtNode *>(mirModule.GetMIRBuilder()->CreateStmtIf(cmpNode));
326 bool leftHighBNdChecked = (stmt->GetCasePair(switchItems.at(mid - 1).first).first + 1 ==
327 stmt->GetCasePair(switchItems.at(mid).first).first) ||
328 (stmt->GetCasePair(switchItems.at(mid - 1).second).first + 1 ==
329 stmt->GetCasePair(switchItems.at(mid).first).first);
330 ifStmt->SetThenPart(BuildCodeForSwitchItems(start, mid - 1, lowBlockNodeChecked, leftHighBNdChecked));
331 ifStmt->SetElsePart(BuildCodeForSwitchItems(mid, end, true, highBlockNodeChecked));
332 if (ifStmt->GetElsePart()) {
333 ifStmt->SetNumOpnds(kOperandNumTernary);
334 }
335 localBlk->AppendStatementsFromBlock(*mirLowerer.LowerIfStmt(*ifStmt, false));
336 }
337 return localBlk;
338 }
339
LowerSwitch()340 BlockNode *SwitchLowerer::LowerSwitch()
341 {
342 if (stmt->GetSwitchTable().empty()) { /* change to goto */
343 BlockNode *localBlk = mirModule.CurFuncCodeMemPool()->New<BlockNode>();
344 GotoNode *gotoDft = BuildGotoNode(-1);
345 if (gotoDft != nullptr) {
346 localBlk->AddStatement(gotoDft);
347 }
348 return localBlk;
349 }
350
351 // add case labels to label table's caseLabelSet
352 DEBUG_ASSERT(mirModule.CurFunction() != nullptr, "curFunction should not be nullptr");
353 MIRLabelTable *labelTab = mirModule.CurFunction()->GetLabelTab();
354 for (CasePair &casePair : stmt->GetSwitchTable()) {
355 labelTab->caseLabelSet.insert(casePair.second);
356 }
357
358 MapleVector<Cluster> clusters(ownAllocator->Adapter());
359 stmt->SortCasePair(CasePairKeyLessThan);
360 FindClusters(clusters);
361 InitSwitchItems(clusters);
362 BlockNode *blkNode = BuildCodeForSwitchItems(0, static_cast<int>(switchItems.size()) - 1, false, false);
363 if (!jumpToDefaultBlockGenerated) {
364 GotoNode *gotoDft = BuildGotoNode(-1);
365 if (gotoDft != nullptr) {
366 blkNode->AddStatement(gotoDft);
367 }
368 }
369 return blkNode;
370 }
371 } /* namespace maplebe */
372