• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 /*
17  * This module analyzes the tag distribution in a switch statement and decides
18  * the best strategy in terms of runtime performance to generate code for it.
19  * The generated code makes use of 3 code generation techniques:
20  *
21  * 1. cascade of if-then-else based on equality test
22  * 2. rangegoto
23  * 3. binary search
24  *
25  * 1 is applied only if the number of possibilities is <= 6.
26  * 2 corresponds to indexed jump, but it requires allocating an array
27  * initialized with the jump targets.  Since it causes memory usage overhead,
28  * rangegoto is used only if the density is higher than 0.7.
29  * If neither 1 nor 2 is applicable, 3 is applied in the form of a decision
30  * tree.  In this case, each test would split the tags into 2 halves.  For
31  * each half, the above algorithm is then applied recursively until the
32  * algorithm terminates.
33  *
34  * But we don't want to apply 3 right from the beginning if both 1 and 2 do not
35  * apply, because there may be regions that have density > 0.7.  Thus, the
36  * switch lowerer begins by finding clusters.  A cluster is defined to be a
37  * maximal range of tags whose density is > 0.7.
38  *
 * In finding clusters, the original switch table is sorted and then each dense
 * region is condensed into 1 switch item. In the switchItems table, each item
 * either corresponds to an original entry in the switch table (the pair's second
 * is 0) or to a dense region (the pair's second is the index of the last case in
 * the dense range). The output code is generated from switchItems; see
 * BuildCodeForSwitchItems(), which is recursive.
44  */
45 #include "switch_lowerer.h"
46 #include "mir_nodes.h"
47 #include "mir_builder.h"
48 #include "mir_lower.h" /* "../../../maple_ir/include/mir_lower.h" */
49 
50 namespace maplebe {
51 using namespace maple;
52 
CasePairKeyLessThan(const CasePair & left,const CasePair & right)53 static bool CasePairKeyLessThan(const CasePair &left, const CasePair &right)
54 {
55     return left.first < right.first;
56 }
57 
FindClusters(MapleVector<Cluster> & clusters) const58 void SwitchLowerer::FindClusters(MapleVector<Cluster> &clusters) const
59 {
60     int32 length = static_cast<int>(stmt->GetSwitchTable().size());
61     int32 i = 0;
62     while (i < length - kClusterSwitchCutoff) {
63         for (int32 j = length - 1; j > i; --j) {
64             float tmp1 = static_cast<float>(j - i);
65             float tmp2 = static_cast<float>(stmt->GetCasePair(static_cast<size_t>(static_cast<uint32>(j))).first) -
66                          static_cast<float>(stmt->GetCasePair(static_cast<size_t>(static_cast<uint32>(i))).first);
67             float currDensity = tmp1 / tmp2;
68             if (((j - i) >= kClusterSwitchCutoff) &&
69                 ((currDensity >= kClusterSwitchDensityHigh) ||
70                  ((currDensity >= kClusterSwitchDensityLow) && (tmp2 < kMaxRangeGotoTableSize)))) {
71                 clusters.emplace_back(Cluster(i, j));
72                 i = j;
73                 break;
74             }
75         }
76         ++i;
77     }
78 }
79 
InitSwitchItems(MapleVector<Cluster> & clusters)80 void SwitchLowerer::InitSwitchItems(MapleVector<Cluster> &clusters)
81 {
82     if (clusters.empty()) {
83         for (int32 i = 0; i < static_cast<int>(stmt->GetSwitchTable().size()); ++i) {
84             switchItems.emplace_back(SwitchItem(i, 0));
85         }
86     } else {
87         int32 j = 0;
88         Cluster front = clusters[j];
89         for (int32 i = 0; i < static_cast<int>(stmt->GetSwitchTable().size()); ++i) {
90             if (i == front.first) {
91                 switchItems.emplace_back(SwitchItem(i, front.second));
92                 i = front.second;
93                 ++j;
94                 if (static_cast<int>(clusters.size()) > j) {
95                     front = clusters[j];
96                 }
97             } else {
98                 switchItems.emplace_back(SwitchItem(i, 0));
99             }
100         }
101     }
102 }
103 
BuildRangeGotoNode(int32 startIdx,int32 endIdx)104 RangeGotoNode *SwitchLowerer::BuildRangeGotoNode(int32 startIdx, int32 endIdx)
105 {
106     RangeGotoNode *node = mirModule.CurFuncCodeMemPool()->New<RangeGotoNode>(mirModule);
107     node->SetOpnd(stmt->GetSwitchOpnd(), 0);
108 
109     node->SetRangeGotoTable(SmallCaseVector(mirModule.CurFuncCodeMemPoolAllocator()->Adapter()));
110     node->SetTagOffset(static_cast<int32>(stmt->GetCasePair(static_cast<size_t>(startIdx)).first));
111     uint32 curTag = 0;
112     node->AddRangeGoto(curTag, stmt->GetCasePair(startIdx).second);
113     int64 lastCaseTag = stmt->GetSwitchTable().at(startIdx).first;
114     for (int32 i = startIdx + 1; i <= endIdx; ++i) {
115         /*
116          * The second condition is to solve the problem that compilation falls into a dead loop,
117          * because in some cases the two will fall into a dead loop if they are equal.
118          */
119         while ((stmt->GetCasePair(i).first != (lastCaseTag + 1)) && (stmt->GetCasePair(i).first != lastCaseTag)) {
120             /* fill in a gap in the case tags */
121             curTag = static_cast<uint32>((++lastCaseTag) - node->GetTagOffset());
122             if (stmt->GetDefaultLabel() != 0) {
123                 node->AddRangeGoto(curTag, stmt->GetDefaultLabel());
124             }
125         }
126         curTag = static_cast<uint32>(stmt->GetCasePair(static_cast<size_t>(i)).first - node->GetTagOffset());
127         node->AddRangeGoto(curTag, stmt->GetCasePair(i).second);
128         lastCaseTag = stmt->GetCasePair(i).first;
129     }
130     /* If the density is high enough, the range is allowed to be large */
131     DEBUG_ASSERT(node->GetNumOpnds() == 1, "RangeGotoNode is a UnaryOpnd and numOpnds must be 1");
132     return node;
133 }
134 
BuildCmpNode(Opcode opCode,uint32 idx)135 CompareNode *SwitchLowerer::BuildCmpNode(Opcode opCode, uint32 idx)
136 {
137     CompareNode *binaryExpr = mirModule.CurFuncCodeMemPool()->New<CompareNode>(opCode);
138     binaryExpr->SetPrimType(PTY_u32);
139     binaryExpr->SetOpndType(stmt->GetSwitchOpnd()->GetPrimType());
140 
141     MIRType &type = *GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(stmt->GetSwitchOpnd()->GetPrimType()));
142     MIRConst *constVal = GlobalTables::GetIntConstTable().GetOrCreateIntConst(stmt->GetCasePair(idx).first, type);
143     ConstvalNode *exprConst = mirModule.CurFuncCodeMemPool()->New<ConstvalNode>();
144     exprConst->SetPrimType(stmt->GetSwitchOpnd()->GetPrimType());
145     exprConst->SetConstVal(constVal);
146 
147     binaryExpr->SetBOpnd(stmt->GetSwitchOpnd(), 0);
148     binaryExpr->SetBOpnd(exprConst, 1);
149     return binaryExpr;
150 }
151 
BuildGotoNode(int32 idx)152 GotoNode *SwitchLowerer::BuildGotoNode(int32 idx)
153 {
154     if (idx == -1 && stmt->GetDefaultLabel() == 0) {
155         return nullptr;
156     }
157     GotoNode *gotoStmt = mirModule.CurFuncCodeMemPool()->New<GotoNode>(OP_goto);
158     if (idx == -1) {
159         gotoStmt->SetOffset(stmt->GetDefaultLabel());
160     } else {
161         gotoStmt->SetOffset(stmt->GetCasePair(idx).second);
162     }
163     return gotoStmt;
164 }
165 
BuildCondGotoNode(int32 idx,Opcode opCode,BaseNode & cond)166 CondGotoNode *SwitchLowerer::BuildCondGotoNode(int32 idx, Opcode opCode, BaseNode &cond)
167 {
168     if (idx == -1 && stmt->GetDefaultLabel() == 0) {
169         return nullptr;
170     }
171     CondGotoNode *cGotoStmt = mirModule.CurFuncCodeMemPool()->New<CondGotoNode>(opCode);
172     cGotoStmt->SetOpnd(&cond, 0);
173     if (idx == -1) {
174         cGotoStmt->SetOffset(stmt->GetDefaultLabel());
175     } else {
176         cGotoStmt->SetOffset(stmt->GetCasePair(idx).second);
177     }
178     return cGotoStmt;
179 }
180 
181 /* start and end is with respect to switchItems */
BuildCodeForSwitchItems(int32 start,int32 end,bool lowBlockNodeChecked,bool highBlockNodeChecked)182 BlockNode *SwitchLowerer::BuildCodeForSwitchItems(int32 start, int32 end, bool lowBlockNodeChecked,
183                                                   bool highBlockNodeChecked)
184 {
185     DEBUG_ASSERT(start >= 0, "invalid args start");
186     DEBUG_ASSERT(end >= 0, "invalid args end");
187     BlockNode *localBlk = mirModule.CurFuncCodeMemPool()->New<BlockNode>();
188     if (start > end) {
189         return localBlk;
190     }
191     CondGotoNode *cGoto = nullptr;
192     RangeGotoNode *rangeGoto = nullptr;
193     IfStmtNode *ifStmt = nullptr;
194     CompareNode *cmpNode = nullptr;
195     MIRLower mirLowerer(mirModule, mirModule.CurFunction());
196     mirLowerer.Init();
197     /* if low side starts with a dense item, handle it first */
198     while ((start <= end) && (switchItems[start].second != 0)) {
199         if (!lowBlockNodeChecked) {
200             lowBlockNodeChecked = true;
201             if (!(IsUnsignedInteger(stmt->GetSwitchOpnd()->GetPrimType()) &&
202                   (stmt->GetCasePair(static_cast<size_t>(switchItems[static_cast<uint64>(start)].first)).first == 0))) {
203                 cGoto = BuildCondGotoNode(-1, OP_brtrue, *BuildCmpNode(OP_lt, switchItems[start].first));
204                 if (cGoto != nullptr) {
205                     localBlk->AddStatement(cGoto);
206                 }
207             }
208         }
209         rangeGoto = BuildRangeGotoNode(switchItems[start].first, switchItems[start].second);
210         if (stmt->GetDefaultLabel() == 0) {
211             localBlk->AddStatement(rangeGoto);
212         } else {
213             cmpNode = BuildCmpNode(OP_le, switchItems[start].second);
214             ifStmt = static_cast<IfStmtNode *>(mirModule.GetMIRBuilder()->CreateStmtIf(cmpNode));
215             ifStmt->GetThenPart()->AddStatement(rangeGoto);
216             localBlk->AppendStatementsFromBlock(*mirLowerer.LowerIfStmt(*ifStmt, false));
217         }
218         if (start < end) {
219             lowBlockNodeChecked = (stmt->GetCasePair(switchItems[start].second).first + 1 ==
220                                    stmt->GetCasePair(switchItems[start + 1].first).first);
221         }
222         ++start;
223     }
224     /* if high side starts with a dense item, handle it also */
225     while ((start <= end) && (switchItems[end].second != 0)) {
226         if (!highBlockNodeChecked) {
227             cGoto = BuildCondGotoNode(-1, OP_brtrue, *BuildCmpNode(OP_gt, switchItems[end].second));
228             if (cGoto != nullptr) {
229                 localBlk->AddStatement(cGoto);
230             }
231             highBlockNodeChecked = true;
232         }
233         rangeGoto = BuildRangeGotoNode(switchItems[end].first, switchItems[end].second);
234         if (stmt->GetDefaultLabel() == 0) {
235             localBlk->AddStatement(rangeGoto);
236         } else {
237             cmpNode = BuildCmpNode(OP_ge, switchItems[end].first);
238             ifStmt = static_cast<IfStmtNode *>(mirModule.GetMIRBuilder()->CreateStmtIf(cmpNode));
239             ifStmt->GetThenPart()->AddStatement(rangeGoto);
240             localBlk->AppendStatementsFromBlock(*mirLowerer.LowerIfStmt(*ifStmt, false));
241         }
242         if (start < end) {
243             highBlockNodeChecked = (stmt->GetCasePair(switchItems[end].first).first - 1 ==
244                                     stmt->GetCasePair(switchItems[end - 1].first).first) ||
245                                    (stmt->GetCasePair(switchItems[end].first).first - 1 ==
246                                     stmt->GetCasePair(switchItems[end - 1].second).first);
247         }
248         --end;
249     }
250     if (start > end) {
251         if (!lowBlockNodeChecked || !highBlockNodeChecked) {
252             GotoNode *gotoDft = BuildGotoNode(-1);
253             if (gotoDft != nullptr) {
254                 localBlk->AddStatement(gotoDft);
255                 jumpToDefaultBlockGenerated = true;
256             }
257         }
258         return localBlk;
259     }
260     if ((start == end) && lowBlockNodeChecked && highBlockNodeChecked) {
261         /* only 1 case with 1 tag remains */
262         auto *gotoStmt = BuildGotoNode(switchItems[static_cast<size_t>(start)].first);
263         if (gotoStmt != nullptr) {
264             localBlk->AddStatement(gotoStmt);
265         }
266         return localBlk;
267     }
268     if (end < (start + kClusterSwitchCutoff)) {
269         /* generate equality checks for what remains */
270         while ((start <= end) && (switchItems[start].second == 0)) {
271             if ((start == end) && lowBlockNodeChecked && highBlockNodeChecked) {
272                 cGoto = reinterpret_cast<CondGotoNode *>(
273                     BuildGotoNode(switchItems[start].first)); /* can omit the condition */
274             } else {
275                 cGoto = BuildCondGotoNode(switchItems[start].first, OP_brtrue,
276                                           *BuildCmpNode(OP_eq, switchItems[start].first));
277             }
278             if (cGoto != nullptr) {
279                 localBlk->AddStatement(cGoto);
280             }
281             if (lowBlockNodeChecked && (start < end)) {
282                 lowBlockNodeChecked = (stmt->GetCasePair(switchItems[start].first).first + 1 ==
283                                        stmt->GetCasePair(switchItems[start + 1].first).first);
284             }
285             ++start;
286         }
287         if (start <= end) { /* recursive call */
288             BlockNode *tmp = BuildCodeForSwitchItems(start, end, lowBlockNodeChecked, highBlockNodeChecked);
289             CHECK_FATAL(tmp != nullptr, "tmp should not be nullptr");
290             localBlk->AppendStatementsFromBlock(*tmp);
291         } else if (!lowBlockNodeChecked || !highBlockNodeChecked) {
292             GotoNode *gotoDft = BuildGotoNode(-1);
293             if (gotoDft != nullptr) {
294                 localBlk->AddStatement(gotoDft);
295                 jumpToDefaultBlockGenerated = true;
296             }
297         }
298         return localBlk;
299     }
300 
301     int64 lowestTag = stmt->GetCasePair(switchItems[start].first).first;
302     int64 highestTag = stmt->GetCasePair(switchItems[end].first).first;
303 
304     /*
305      * if lowestTag and higesttag have the same sign, use difference
306      * if lowestTag and higesttag have the diefferent sign, use sum
307      * 1LL << 63 judge lowestTag ^ highestTag operate result highest
308      * bit is 1 or not, the result highest bit is 1 express lowestTag
309      * and highestTag have same sign , otherwise diefferent sign.highestTag
310      * add or subtract lowestTag divide 2 to get middle tag.
311      */
312     int64 middleTag = ((((static_cast<uint64>(lowestTag)) ^ (static_cast<uint64>(highestTag))) & (1ULL << 63)) == 0)
313                           ? (highestTag - lowestTag) / 2 + lowestTag
314                           : (highestTag + lowestTag) / 2;
315     /* find the mid-point in switch_items between start and end */
316     int32 mid = start;
317     while (stmt->GetCasePair(switchItems[mid].first).first < middleTag) {
318         ++mid;
319     }
320     DEBUG_ASSERT(mid >= start, "switch lowering logic mid should greater than or equal start");
321     DEBUG_ASSERT(mid <= end, "switch lowering logic mid should less than or equal end");
322     /* generate test for binary search */
323     if (stmt->GetDefaultLabel() != 0) {
324         cmpNode = BuildCmpNode(OP_lt, static_cast<uint32>(switchItems[static_cast<uint64>(mid)].first));
325         ifStmt = static_cast<IfStmtNode *>(mirModule.GetMIRBuilder()->CreateStmtIf(cmpNode));
326         bool leftHighBNdChecked = (stmt->GetCasePair(switchItems.at(mid - 1).first).first + 1 ==
327                                    stmt->GetCasePair(switchItems.at(mid).first).first) ||
328                                   (stmt->GetCasePair(switchItems.at(mid - 1).second).first + 1 ==
329                                    stmt->GetCasePair(switchItems.at(mid).first).first);
330         ifStmt->SetThenPart(BuildCodeForSwitchItems(start, mid - 1, lowBlockNodeChecked, leftHighBNdChecked));
331         ifStmt->SetElsePart(BuildCodeForSwitchItems(mid, end, true, highBlockNodeChecked));
332         if (ifStmt->GetElsePart()) {
333             ifStmt->SetNumOpnds(kOperandNumTernary);
334         }
335         localBlk->AppendStatementsFromBlock(*mirLowerer.LowerIfStmt(*ifStmt, false));
336     }
337     return localBlk;
338 }
339 
LowerSwitch()340 BlockNode *SwitchLowerer::LowerSwitch()
341 {
342     if (stmt->GetSwitchTable().empty()) { /* change to goto */
343         BlockNode *localBlk = mirModule.CurFuncCodeMemPool()->New<BlockNode>();
344         GotoNode *gotoDft = BuildGotoNode(-1);
345         if (gotoDft != nullptr) {
346             localBlk->AddStatement(gotoDft);
347         }
348         return localBlk;
349     }
350 
351     // add case labels to label table's caseLabelSet
352     DEBUG_ASSERT(mirModule.CurFunction() != nullptr, "curFunction should not be nullptr");
353     MIRLabelTable *labelTab = mirModule.CurFunction()->GetLabelTab();
354     for (CasePair &casePair : stmt->GetSwitchTable()) {
355         labelTab->caseLabelSet.insert(casePair.second);
356     }
357 
358     MapleVector<Cluster> clusters(ownAllocator->Adapter());
359     stmt->SortCasePair(CasePairKeyLessThan);
360     FindClusters(clusters);
361     InitSwitchItems(clusters);
362     BlockNode *blkNode = BuildCodeForSwitchItems(0, static_cast<int>(switchItems.size()) - 1, false, false);
363     if (!jumpToDefaultBlockGenerated) {
364         GotoNode *gotoDft = BuildGotoNode(-1);
365         if (gotoDft != nullptr) {
366             blkNode->AddStatement(gotoDft);
367         }
368     }
369     return blkNode;
370 }
371 } /* namespace maplebe */
372