1 /*
2 * Copyright (c) 2023 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "cast_opt.h"
17 #include "irmap.h"
18 #include "mir_builder.h"
19 #include "constantfold.h"
20
21 namespace maple {
// Guards the "zext to 32 bits + read ==> 32-bit read" rewrite in SimplifyCastSingle.
// Kept disabled: it is not safe because handlefunction may ignore a necessary implicit intTrunc.
static constexpr bool mapleSimplifyZextU32 = false;
24
GetCastKindByTwoType(PrimType fromType,PrimType toType)25 CastKind GetCastKindByTwoType(PrimType fromType, PrimType toType)
26 {
27 // This is a workaround, we don't optimize `cvt u1 xx <expr>` because it will be converted to
28 // `ne u1 xx (<expr>, constval xx 0)`. There is no `cvt u1 xx <expr>` in the future.
29 if (toType == PTY_u1 && fromType != PTY_u1) {
30 return CAST_unknown;
31 }
32 const uint32 fromTypeBitSize = GetPrimTypeActualBitSize(fromType);
33 const uint32 toTypeBitSize = GetPrimTypeActualBitSize(toType);
34 // Both integer, ptr/ref/a64/u64... are also integer
35 if (IsPrimitiveInteger(fromType) && IsPrimitiveInteger(toType)) {
36 if (toTypeBitSize == fromTypeBitSize) {
37 return CAST_retype;
38 } else if (toTypeBitSize < fromTypeBitSize) {
39 return CAST_intTrunc;
40 } else {
41 return IsSignedInteger(fromType) ? CAST_sext : CAST_zext;
42 }
43 }
44 // Both fp
45 if (IsPrimitiveFloat(fromType) && IsPrimitiveFloat(toType)) {
46 if (toTypeBitSize == fromTypeBitSize) {
47 return CAST_retype;
48 } else if (toTypeBitSize < fromTypeBitSize) {
49 return CAST_fpTrunc;
50 } else {
51 return CAST_fpExt;
52 }
53 }
54 // int2fp
55 if (IsPrimitiveInteger(fromType) && IsPrimitiveFloat(toType)) {
56 return CAST_int2fp;
57 }
58 // fp2int
59 if (IsPrimitiveFloat(fromType) && IsPrimitiveInteger(toType)) {
60 return CAST_fp2int;
61 }
62 return CAST_unknown;
63 }
64
CreateMapleExprByCastKind(MIRBuilder & mirBuilder,CastKind castKind,PrimType srcType,PrimType dstType,BaseNode * opnd,TyIdx dstTyIdx)65 BaseNode *MapleCastOpt::CreateMapleExprByCastKind(MIRBuilder &mirBuilder, CastKind castKind, PrimType srcType,
66 PrimType dstType, BaseNode *opnd, TyIdx dstTyIdx)
67 {
68 if (castKind == CAST_zext) {
69 return mirBuilder.CreateExprExtractbits(OP_zext, dstType, 0, GetPrimTypeActualBitSize(srcType), opnd);
70 } else if (castKind == CAST_sext) {
71 return mirBuilder.CreateExprExtractbits(OP_sext, dstType, 0, GetPrimTypeActualBitSize(srcType), opnd);
72 } else if (castKind == CAST_retype && srcType == opnd->GetPrimType()) {
73 // If srcType is different from opnd->primType, we should create cvt instead of retype.
74 // Because CGFunc::SelectRetype always use opnd->primType as srcType.
75 CHECK_FATAL(dstTyIdx != 0u, "must specify valid tyIdx for retype");
76 MIRType *dstMIRType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(dstTyIdx);
77 return mirBuilder.CreateExprRetype(*dstMIRType, srcType, opnd);
78 } else {
79 return mirBuilder.CreateExprTypeCvt(OP_cvt, dstType, srcType, *opnd);
80 }
81 }
82
83 // This interface is conservative, which means that some op are explicit type cast but
84 // the interface returns false.
IsExplicitCastOp(Opcode op)85 bool CastOpt::IsExplicitCastOp(Opcode op)
86 {
87 if (op == OP_retype || op == OP_cvt || op == OP_zext || op == OP_sext) {
88 return true;
89 }
90 return false;
91 }
92
93 // This interface is conservative, which means that some op are implicit type cast but
94 // the interface returns false.
IsImplicitCastOp(Opcode op)95 bool CastOpt::IsImplicitCastOp(Opcode op)
96 {
97 if (op == OP_iread || op == OP_regread || op == OP_dread) {
98 return true;
99 }
100 return false;
101 }
102
IsCompareOp(Opcode op)103 bool CastOpt::IsCompareOp(Opcode op)
104 {
105 if (op == OP_eq || op == OP_ne || op == OP_ge || op == OP_gt || op == OP_le || op == OP_lt) {
106 return true;
107 }
108 return false;
109 }
110
111 // If the computed castInfo.kind is CAST_unknown, the computed castInfo is invalid
112 // castInfo.expr should be valid
113 // input: castInfo.expr
114 // output: castInfo.kind, castInfo.srcType, castInfo.dstType
115 template <typename T>
DoComputeCastInfo(CastInfo<T> & castInfo,bool isMeExpr)116 void CastOpt::DoComputeCastInfo(CastInfo<T> &castInfo, bool isMeExpr)
117 {
118 Opcode op = castInfo.GetOp();
119 PrimType dstType = castInfo.GetPrimType();
120 PrimType srcType = PTY_begin;
121 CastKind castKind = CAST_unknown;
122 switch (op) {
123 case OP_zext:
124 case OP_sext: {
125 size_t sizeBit = castInfo.GetBitsSize();
126 // The code can be improved
127 // exclude: sext ixx 1 <expr> because there is no i1 type
128 if (sizeBit == 1 && op == OP_sext) {
129 break;
130 }
131 if (sizeBit == k1BitSize || sizeBit == k8BitSize || sizeBit == k16BitSize || sizeBit == k32BitSize ||
132 sizeBit == k64BitSize) {
133 srcType = GetIntegerPrimTypeBySizeAndSign(sizeBit, op == OP_sext);
134 if (srcType == PTY_begin) {
135 break; // invalid integer type
136 }
137 castKind = (op == OP_sext ? CAST_sext : CAST_zext);
138 }
139 break;
140 }
141 case OP_retype: {
142 // retype's opndType is invalid, we use opnd's primType
143 srcType = castInfo.GetOpndType();
144 if (GetPrimTypeActualBitSize(dstType) != GetPrimTypeActualBitSize(srcType)) {
145 // Example: retype u8 <u8> (iread i32 <* i8> 0 ...)
146 // In the above example, dstType is u8, but we get srcType i32 from iread.
147 // We won't optimize such retype unless we can get real opnd type for all kinds of opnds.
148 // We will improve it if possible.
149 break;
150 }
151 castKind = CAST_retype;
152 break;
153 }
154 case OP_cvt: {
155 srcType = castInfo.GetOpndType();
156 if (srcType == PTY_u1 && dstType != PTY_u1) {
157 srcType = PTY_u8; // From the codegen view, `cvt xx u1` is always same as `cvt xx u8`
158 }
159 castKind = GetCastKindByTwoType(srcType, dstType);
160 break;
161 }
162 case OP_iread:
163 case OP_dread:
164 case OP_regread: {
165 if (op == OP_dread && isMeExpr) {
166 break;
167 }
168 srcType = castInfo.GetOpndType();
169 // Only consider regread/iread/dread with implicit integer extension
170 if (IsPrimitiveInteger(srcType) && IsPrimitiveInteger(dstType) &&
171 GetPrimTypeActualBitSize(srcType) < GetPrimTypeActualBitSize(dstType)) {
172 castKind = (IsSignedInteger(srcType) ? CAST_sext : CAST_zext);
173 }
174 break;
175 }
176 default:
177 break;
178 }
179 castInfo.kind = castKind;
180 castInfo.srcType = srcType;
181 castInfo.dstType = dstType;
182 }
183
184 static constexpr auto kNumCastKinds = static_cast<uint32>(CAST_unknown);
185 static const uint8 castMatrix[kNumCastKinds][kNumCastKinds] = {
186 // i i f f r -+
187 // t n p t f e |
188 // r z s t 2 r p t +- secondCastKind
189 // u e e 2 i u e y |
190 // n x x f n n x p |
191 // c t t p t c t e -+
192 {1, 0, 0, 0, 99, 99, 99, 3}, // intTrunc -+
193 {8, 9, 9, 10, 99, 99, 99, 3}, // zext |
194 {8, 0, 9, 0, 99, 99, 99, 3}, // sext |
195 {99, 99, 99, 99, 0, 0, 0, 4}, // int2fp |
196 {0, 0, 0, 0, 99, 99, 99, 0}, // fp2int +- firstCastKind
197 {99, 99, 99, 99, 0, 0, 0, 4}, // fpTrunc |
198 {99, 99, 99, 99, 2, 8, 2, 4}, // fpExt |
199 {5, 7, 7, 11, 6, 6, 6, 1}, // retype -+
200 };
201
202 // This function determines whether to eliminate a cast pair according to castMatrix
203 // Input is a cast pair like this:
204 // secondCastKind dstType midType2 (firstCastKind midType1 srcType)
205 // If the function returns a valid resultCastKind, the cast pair can be optimized to:
206 // resultCastKind dstType srcType
207 // If the cast pair can NOT be eliminated, -1 will be returned.
208 // ATTENTION: This function may modify srcType
IsEliminableCastPair(CastKind firstCastKind,CastKind secondCastKind,PrimType dstType,PrimType midType2,PrimType midType1,PrimType & srcType)209 int CastOpt::IsEliminableCastPair(CastKind firstCastKind, CastKind secondCastKind, PrimType dstType, PrimType midType2,
210 PrimType midType1, PrimType &srcType)
211 {
212 int castCase = castMatrix[firstCastKind][secondCastKind];
213 uint32 srcSize = GetPrimTypeActualBitSize(srcType);
214 uint32 midSize1 = GetPrimTypeActualBitSize(midType1);
215 uint32 midSize2 = GetPrimTypeActualBitSize(midType2);
216 uint32 dstSize = GetPrimTypeActualBitSize(dstType);
217
218 switch (castCase) {
219 case 0: {
220 // Not allowed
221 return -1;
222 }
223 case 1: { // 1 st case in castMatrix, see the comments above
224 // first intTrunc, then intTrunc
225 // Example: cvt u16 u32 (cvt u32 u64) ==> cvt u16 u64
226 // first retype, then retype
227 // Example: retype i64 u64 (retype u64 ptr) ==> retype i64 ptr
228 return firstCastKind;
229 }
230 case 2: { // 2 nd case in castMatrix, see the comments above
231 // first fpExt, then fpExt
232 // Example: cvt f128 f64 (cvt f64 f32) ==> cvt f128 f32
233 // first fpExt, then fp2int
234 // Example: cvt i64 f64 (cvt f64 f32) ==> cvt i64 f32
235 return secondCastKind;
236 }
237 case 3: { // 3 rd case in castMatrix, see the comments above
238 if (IsPrimitiveInteger(dstType)) {
239 return firstCastKind;
240 }
241 return -1;
242 }
243 case 4: { // 4 th case in castMatrix, see the comments above
244 if (IsPrimitiveFloat(dstType)) {
245 return firstCastKind;
246 }
247 return -1;
248 }
249 case 5: { // 5 th case in castMatrix, see the comments above
250 if (IsPrimitiveInteger(srcType)) {
251 return secondCastKind;
252 }
253 return -1;
254 }
255 case 6: { // 6 th case in castMatrix, see the comments above
256 if (IsPrimitiveFloat(srcType)) {
257 return secondCastKind;
258 }
259 return -1;
260 }
261 case 7: { // 7 th case in castMatrix, see the comments above
262 // first integer retype, then sext/zext
263 if (IsPrimitiveInteger(srcType) && dstSize >= midSize1) {
264 CHECK_FATAL(srcSize == midSize1, "must be");
265 if (midSize2 >= srcSize) {
266 return secondCastKind;
267 }
268 // Example: zext u64 8 (retype u32 i32) ==> zext u64 8
269 srcType = midType2;
270 return secondCastKind;
271 }
272 return -1;
273 }
274 case 8: { // 8 th case in castMatrix, see the comments above
275 if (srcSize == dstSize) {
276 return CAST_retype;
277 } else if (srcSize < dstSize) {
278 return firstCastKind;
279 } else {
280 return secondCastKind;
281 }
282 }
283 // For integer extension pair
284 case 9: { // 9 th case in castMatrix, see the comments above
285 // first zext, then sext
286 // Extreme example: sext i32 16 (zext u64 8) ==> zext i32 8
287 if (firstCastKind != secondCastKind && midSize2 <= midSize1) {
288 if (midSize2 > srcSize) {
289 // The first extension works. After the first zext, the most significant bit must be 0, so the
290 // second sext is actually a zext. Example: sext i64 16 (zext u32 8) ==> zext i64 8
291 return firstCastKind;
292 }
293 // midSize2 <= srcSize
294 // The first extension didn't work
295 // Example: sext i64 8 (zext u32 16) ==> sext i64 8
296 // Example: sext i16 8 (zext u32 16) ==> sext i16 8
297 srcType = midType2;
298 return secondCastKind;
299 }
300
301 // first zext, then zext
302 // first sext, then sext
303 // Example: sext i32 8 (sext i32 8) ==> sext i32 8
304 // Example: zext u16 1 (zext u32 8) ==> zext u16 1 it's ok
305 // midSize2 < srcSize:
306 // Example: zext u64 8 (zext u32 16) ==> zext u64 8
307 // Example: sext i64 8 (sext i32 16) ==> sext i64 8
308 // Example: zext i32 1 (zext u32 8) ==> zext i32 1
309 // Wrong example (midSize2 > midSize1): zext u64 32 (zext u16 8) =[x]=> zext u64 8
310 if (firstCastKind == secondCastKind && midSize2 <= midSize1) {
311 if (midSize2 < srcSize) {
312 srcType = midType2;
313 }
314 return secondCastKind;
315 }
316 return -1;
317 }
318 case 10: { // 10 th case in castMatrix, see the comments above
319 // first zext, then int2fp
320 if (IsSignedInteger(midType2)) {
321 return secondCastKind;
322 }
323 // To improved: consider unsigned
324 return -1;
325 }
326 case 11: { // 11 st case in castMatrix, see the comments above
327 // first retype, then int2fp
328 if (IsPrimitiveInteger(srcType)) {
329 if (IsSignedInteger(srcType) != IsSignedInteger(midType1)) {
330 // If sign diffs, use toType of retype
331 // Example: cvt f64 i64 (retype i64 u64) ==> cvt f64 i64
332 srcType = midType1;
333 }
334 return secondCastKind;
335 }
336 return -1;
337 }
338 case 99: { // 99 is this last case in castMatrix, see the comments above
339 CHECK_FATAL(false, "invalid cast pair");
340 }
341 default: {
342 CHECK_FATAL(false, "can not be here, is castMatrix wrong?");
343 }
344 }
345 }
346
ComputeCastInfo(BaseNodeCastInfo & castInfo)347 void MapleCastOpt::ComputeCastInfo(BaseNodeCastInfo &castInfo)
348 {
349 DoComputeCastInfo(castInfo, false);
350 }
351
TransformCvtU1ToNe(MIRBuilder & mirBuilder,const TypeCvtNode * cvtExpr)352 BaseNode *MapleCastOpt::TransformCvtU1ToNe(MIRBuilder &mirBuilder, const TypeCvtNode *cvtExpr)
353 {
354 PrimType fromType = cvtExpr->FromType();
355 auto *fromMIRType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(fromType));
356 // We use u8 instead of u1 because codegen can't recognize u1
357 auto *toMIRType = GlobalTables::GetTypeTable().GetUInt8();
358 auto *zero = GlobalTables::GetIntConstTable().GetOrCreateIntConst(0, *fromMIRType);
359 auto *converted = mirBuilder.CreateExprCompare(OP_ne, *toMIRType, *fromMIRType, cvtExpr->Opnd(0),
360 mirBuilder.CreateConstval(zero));
361 return converted;
362 }
363
SimplifyCast(MIRBuilder & mirBuilder,BaseNode * expr)364 BaseNode *MapleCastOpt::SimplifyCast(MIRBuilder &mirBuilder, BaseNode *expr)
365 {
366 Opcode op = expr->GetOpCode();
367 if (!IsExplicitCastOp(op)) {
368 return nullptr;
369 }
370 auto *opnd = expr->Opnd(0);
371
372 // Convert `cvt u1 xx <expr>` to `ne u1 xx (<expr>, constval xx 0)`
373 if (op == OP_cvt && expr->GetPrimType() == PTY_u1 &&
374 static_cast<TypeCvtNode *>(expr)->FromType() != PTY_u1) { // No need to convert `cvt u1 u1 <expr>`
375 return TransformCvtU1ToNe(mirBuilder, static_cast<TypeCvtNode *>(expr));
376 }
377
378 // If the opnd is a iread/regread, it's OK because it may be a implicit zext or sext
379 Opcode opndOp = opnd->GetOpCode();
380 if (!IsExplicitCastOp(opndOp) && !IsImplicitCastOp(opndOp)) {
381 // only 1 cast
382 // Exmaple: cvt i32 i64 (add i32) ==> add i32
383 BaseNodeCastInfo castInfo(expr);
384 ComputeCastInfo(castInfo);
385 return SimplifyCastSingle(mirBuilder, castInfo);
386 }
387 BaseNodeCastInfo firstCastInfo(opnd);
388 ComputeCastInfo(firstCastInfo);
389 BaseNodeCastInfo secondCastInfo(expr);
390 ComputeCastInfo(secondCastInfo);
391 auto *simplified1 = SimplifyCastPair(mirBuilder, firstCastInfo, secondCastInfo);
392 BaseNode *simplified2 = nullptr;
393 if (simplified1 != nullptr && IsExplicitCastOp(simplified1->GetOpCode())) {
394 // Simplify cast further
395 secondCastInfo.expr = simplified1;
396 ComputeCastInfo(secondCastInfo);
397 simplified2 = SimplifyCastSingle(mirBuilder, secondCastInfo);
398 }
399 if (simplified2 != nullptr) {
400 return simplified2;
401 }
402 if (simplified1 == nullptr) {
403 BaseNodeCastInfo castInfo(expr);
404 ComputeCastInfo(castInfo);
405 return SimplifyCastSingle(mirBuilder, castInfo);
406 }
407 return simplified1;
408 }
409
SimplifyCastSingle(MIRBuilder & mirBuilder,const BaseNodeCastInfo & castInfo)410 BaseNode *MapleCastOpt::SimplifyCastSingle(MIRBuilder &mirBuilder, const BaseNodeCastInfo &castInfo)
411 {
412 if (castInfo.IsInvalid()) {
413 return nullptr;
414 }
415 auto *castExpr = static_cast<BaseNode *>(castInfo.expr);
416 auto *opnd = castExpr->Opnd(0);
417 Opcode op = castExpr->GetOpCode();
418 Opcode opndOp = opnd->GetOpCode();
419 // cast to integer + compare ==> compare
420 if (IsPrimitiveInteger(castInfo.dstType) && IsCompareOp(opndOp)) {
421 // exclude the following castExpr:
422 // sext xx 1 <expr>
423 bool excluded = (op == OP_sext && static_cast<ExtractbitsNode *>(castExpr)->GetBitsSize() == 1);
424 if (!excluded) {
425 opnd->SetPrimType(castExpr->GetPrimType());
426 return opnd;
427 }
428 }
429 // cast + const ==> const
430 if (castInfo.kind != CAST_retype && opndOp == OP_constval) {
431 ConstantFold cf(*theMIRModule);
432 MIRConst *cst = cf.FoldTypeCvtMIRConst(*static_cast<ConstvalNode *>(opnd)->GetConstVal(), castInfo.srcType,
433 castInfo.dstType);
434 if (cst != nullptr) {
435 return mirBuilder.CreateConstval(cst);
436 }
437 }
438 if (mapleSimplifyZextU32) {
439 // zextTo32 + read ==> read 32
440 if (castInfo.kind == CAST_zext && (opndOp == OP_iread || opndOp == OP_regread || opndOp == OP_dread)) {
441 uint32 dstSize = GetPrimTypeActualBitSize(castInfo.dstType);
442 if (dstSize == k32BitSize && IsUnsignedInteger(castInfo.dstType) &&
443 IsUnsignedInteger(opnd->GetPrimType()) &&
444 GetPrimTypeActualBitSize(castInfo.srcType) == GetPrimTypeActualBitSize(opnd->GetPrimType())) {
445 opnd->SetPrimType(castInfo.dstType);
446 return opnd;
447 }
448 }
449 }
450 if (castInfo.dstType == opnd->GetPrimType() &&
451 GetPrimTypeActualBitSize(castInfo.srcType) >= GetPrimTypeActualBitSize(opnd->GetPrimType())) {
452 return opnd;
453 }
454 return nullptr;
455 }
456
SimplifyCastPair(MIRBuilder & mirBuidler,const BaseNodeCastInfo & firstCastInfo,const BaseNodeCastInfo & secondCastInfo)457 BaseNode *MapleCastOpt::SimplifyCastPair(MIRBuilder &mirBuidler, const BaseNodeCastInfo &firstCastInfo,
458 const BaseNodeCastInfo &secondCastInfo)
459 {
460 if (firstCastInfo.IsInvalid()) {
461 // We can NOT eliminate the first cast, try to simplify the second cast individually
462 return SimplifyCastSingle(mirBuidler, secondCastInfo);
463 }
464 if (secondCastInfo.IsInvalid()) {
465 return nullptr;
466 }
467 PrimType srcType = firstCastInfo.srcType;
468 PrimType origSrcType = srcType;
469 PrimType midType1 = firstCastInfo.dstType;
470 PrimType midType2 = secondCastInfo.srcType;
471 PrimType dstType = secondCastInfo.dstType;
472 int result = IsEliminableCastPair(firstCastInfo.kind, secondCastInfo.kind, dstType, midType2, midType1, srcType);
473 if (result == -1) {
474 return SimplifyCastSingle(mirBuidler, secondCastInfo);
475 }
476 auto resultCastKind = CastKind(result);
477 auto *firstCastExpr = static_cast<BaseNode *>(firstCastInfo.expr);
478 auto *secondCastExpr = static_cast<BaseNode *>(secondCastInfo.expr);
479
480 // To improved: do more powerful optimization for firstCastImplicit
481 bool isFirstCastImplicit = !IsExplicitCastOp(firstCastExpr->GetOpCode());
482 if (isFirstCastImplicit) {
483 // Wrong example: zext u32 u8 (iread u32 <* u16>) =[x]=> iread u32 <* u16>
484 // srcType may be modified, we should use origSrcType
485 if (resultCastKind != CAST_unknown && dstType == midType1 &&
486 GetPrimTypeActualBitSize(midType2) >= GetPrimTypeActualBitSize(origSrcType)) {
487 return firstCastExpr;
488 } else {
489 return nullptr;
490 }
491 }
492
493 auto *toCastExpr = firstCastExpr->Opnd(0);
494 // Example: retype u32 <u32> (dread u32 %x) ==> dread u32 %x
495 // Example: retype ptr <* <$Foo>> (dread ptr %p) ==> dread ptr %p
496 if (resultCastKind == CAST_retype && srcType == dstType) {
497 if (toCastExpr->GetPrimType() != dstType) {
498 // Wrong example: retype i16 i16 (regread i32 %1) =[x]=> regread i32 %1
499 // instead: ==> cvt i16 i32 (regread i32 %1)
500 return mirBuidler.CreateExprTypeCvt(OP_cvt, dstType, toCastExpr->GetPrimType(), *toCastExpr);
501 }
502 return toCastExpr;
503 }
504
505 TyIdx dstTyIdx(0);
506 if (resultCastKind == CAST_retype) {
507 // result retype is generated from `retype t1 t2 (retype t3 t4)`
508 if (secondCastExpr->GetOpCode() == OP_retype) {
509 dstTyIdx = static_cast<RetypeNode *>(secondCastExpr)->GetTyIdx();
510 } else {
511 dstTyIdx = TyIdx(dstType);
512 }
513 }
514 return CreateMapleExprByCastKind(mirBuidler, resultCastKind, srcType, dstType, toCastExpr, dstTyIdx);
515 }
516 } // namespace maple
517