• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2021-2025 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  * http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "optimizer/ir/basicblock.h"
17 #include "optimizer/ir/graph.h"
18 #include "optimizer/ir/datatype.h"
19 #include "optimizer/optimizations/const_folding.h"
20 #include "utils/math_helpers.h"
21 
22 #include <cmath>
23 #include <map>
24 
25 namespace ark::compiler {
26 template <class T>
ConvertIntToInt(T value,DataType::Type targetType)27 uint64_t ConvertIntToInt(T value, DataType::Type targetType)
28 {
29     switch (targetType) {
30         case DataType::BOOL:
31             return static_cast<uint64_t>(static_cast<bool>(value));
32         case DataType::UINT8:
33             return static_cast<uint64_t>(static_cast<uint8_t>(value));
34         case DataType::INT8:
35             return static_cast<uint64_t>(static_cast<int8_t>(value));
36         case DataType::UINT16:
37             return static_cast<uint64_t>(static_cast<uint16_t>(value));
38         case DataType::INT16:
39             return static_cast<uint64_t>(static_cast<int16_t>(value));
40         case DataType::UINT32:
41             return static_cast<uint64_t>(static_cast<uint32_t>(value));
42         case DataType::INT32:
43             return static_cast<uint64_t>(static_cast<int32_t>(value));
44         case DataType::UINT64:
45             return static_cast<uint64_t>(value);
46         case DataType::INT64:
47             return static_cast<uint64_t>(static_cast<int64_t>(value));
48         default:
49             UNREACHABLE();
50     }
51 }
52 
53 template <class T>
ConvertIntToFloat(uint64_t value,DataType::Type sourceType)54 T ConvertIntToFloat(uint64_t value, DataType::Type sourceType)
55 {
56     switch (sourceType) {
57         case DataType::BOOL:
58             return static_cast<T>(static_cast<bool>(value));
59         case DataType::UINT8:
60             return static_cast<T>(static_cast<uint8_t>(value));
61         case DataType::INT8:
62             return static_cast<T>(static_cast<int8_t>(value));
63         case DataType::UINT16:
64             return static_cast<T>(static_cast<uint16_t>(value));
65         case DataType::INT16:
66             return static_cast<T>(static_cast<int16_t>(value));
67         case DataType::UINT32:
68             return static_cast<T>(static_cast<uint32_t>(value));
69         case DataType::INT32:
70             return static_cast<T>(static_cast<int32_t>(value));
71         case DataType::UINT64:
72             return static_cast<T>(value);
73         case DataType::INT64:
74             return static_cast<T>(static_cast<int64_t>(value));
75         default:
76             UNREACHABLE();
77     }
78 }
79 
/**
 * Saturating floating-point to integer conversion.
 * NaN becomes 0; values at or below the range of To become its minimum; values at or above the
 * range become its maximum; everything else is truncated toward zero.
 *
 * @tparam To   destination integer type
 * @tparam From source floating-point type
 */
template <class To, class From>
To ConvertFloatToInt(From value)
{
    constexpr To LOWER = std::numeric_limits<To>::min();
    constexpr To UPPER = std::numeric_limits<To>::max();
    // Range bounds in the floating-point domain. The upper bound may round up (e.g. int64 max
    // becomes 2^63), which is why values equal to it are saturated instead of cast.
    const auto lowerAsFloat = static_cast<From>(LOWER);
    const auto upperAsFloat = static_cast<From>(UPPER);

    if (std::isnan(value)) {
        return static_cast<To>(0);
    }
    if (value <= lowerAsFloat) {
        return LOWER;
    }
    if (value >= upperAsFloat) {
        return UPPER;
    }
    return static_cast<To>(value);
}
104 
105 template <class From>
ConvertFloatToInt(From value,DataType::Type targetType)106 uint64_t ConvertFloatToInt(From value, DataType::Type targetType)
107 {
108     ASSERT(DataType::GetCommonType(targetType) == DataType::INT64);
109     switch (targetType) {
110         case DataType::BOOL:
111             return static_cast<uint64_t>(ConvertFloatToInt<bool>(value));
112         case DataType::UINT8:
113             return static_cast<uint64_t>(ConvertFloatToInt<uint8_t>(value));
114         case DataType::INT8:
115             return static_cast<uint64_t>(ConvertFloatToInt<int8_t>(value));
116         case DataType::UINT16:
117             return static_cast<uint64_t>(ConvertFloatToInt<uint16_t>(value));
118         case DataType::INT16:
119             return static_cast<uint64_t>(ConvertFloatToInt<int16_t>(value));
120         case DataType::UINT32:
121             return static_cast<uint64_t>(ConvertFloatToInt<uint32_t>(value));
122         case DataType::INT32:
123             return static_cast<uint64_t>(ConvertFloatToInt<int32_t>(value));
124         case DataType::UINT64:
125             return ConvertFloatToInt<uint64_t>(value);
126         case DataType::INT64:
127             return static_cast<uint64_t>(ConvertFloatToInt<int64_t>(value));
128         default:
129             UNREACHABLE();
130     }
131 }
132 
133 template <class From>
ConvertFloatToIntDyn(From value,RuntimeInterface * runtime,size_t bits)134 uint64_t ConvertFloatToIntDyn(From value, RuntimeInterface *runtime, size_t bits)
135 {
136     return runtime->DynamicCastDoubleToInt(static_cast<double>(value), bits);
137 }
138 
ConstFoldingCreateIntConst(Inst * inst,uint64_t value,bool isLiteralData)139 ConstantInst *ConstFoldingCreateIntConst(Inst *inst, uint64_t value, bool isLiteralData)
140 {
141     auto graph = inst->GetBasicBlock()->GetGraph();
142     if (graph->IsBytecodeOptimizer() && IsInt32Bit(inst->GetType()) && !isLiteralData) {
143         return graph->FindOrCreateConstant<uint32_t>(value);
144     }
145     return graph->FindOrCreateConstant(value);
146 }
147 
148 template <typename T>
ConstFoldingCreateConst(Inst * inst,ConstantInst * cnst,bool isLiteralData=false)149 ConstantInst *ConstFoldingCreateConst(Inst *inst, ConstantInst *cnst, bool isLiteralData = false)
150 {
151     return ConstFoldingCreateIntConst(inst, ConvertIntToInt(static_cast<T>(cnst->GetIntValue()), inst->GetType()),
152                                       isLiteralData);
153 }
154 
ConstFoldingCastInt2Int(Inst * inst,ConstantInst * cnst)155 ConstantInst *ConstFoldingCastInt2Int(Inst *inst, ConstantInst *cnst)
156 {
157     switch (inst->GetInputType(0)) {
158         case DataType::BOOL:
159             return ConstFoldingCreateConst<bool>(inst, cnst);
160         case DataType::UINT8:
161             return ConstFoldingCreateConst<uint8_t>(inst, cnst);
162         case DataType::INT8:
163             return ConstFoldingCreateConst<int8_t>(inst, cnst);
164         case DataType::UINT16:
165             return ConstFoldingCreateConst<uint16_t>(inst, cnst);
166         case DataType::INT16:
167             return ConstFoldingCreateConst<int16_t>(inst, cnst);
168         case DataType::UINT32:
169             return ConstFoldingCreateConst<uint32_t>(inst, cnst);
170         case DataType::INT32:
171             return ConstFoldingCreateConst<int32_t>(inst, cnst);
172         case DataType::UINT64:
173             return ConstFoldingCreateConst<uint64_t>(inst, cnst);
174         case DataType::INT64:
175             return ConstFoldingCreateConst<int64_t>(inst, cnst);
176         default:
177             return nullptr;
178     }
179 }
180 
ConstFoldingCastIntConst(Graph * graph,Inst * inst,ConstantInst * cnst,bool isLiteralData=false)181 ConstantInst *ConstFoldingCastIntConst(Graph *graph, Inst *inst, ConstantInst *cnst, bool isLiteralData = false)
182 {
183     auto instType = DataType::GetCommonType(inst->GetType());
184     if (instType == DataType::INT64) {
185         // INT -> INT
186         return ConstFoldingCastInt2Int(inst, cnst);
187     }
188     if (instType == DataType::FLOAT32) {
189         // INT -> FLOAT
190         if (graph->IsBytecodeOptimizer() && !isLiteralData) {
191             return nullptr;
192         }
193         return graph->FindOrCreateConstant(ConvertIntToFloat<float>(cnst->GetIntValue(), inst->GetInputType(0)));
194     }
195     if (instType == DataType::FLOAT64) {
196         // INT -> DOUBLE
197         return graph->FindOrCreateConstant(ConvertIntToFloat<double>(cnst->GetIntValue(), inst->GetInputType(0)));
198     }
199     return nullptr;
200 }
201 
ConstFoldingCastConst(Inst * inst,Inst * input,bool isLiteralData)202 ConstantInst *ConstFoldingCastConst(Inst *inst, Inst *input, bool isLiteralData)
203 {
204     auto graph = inst->GetBasicBlock()->GetGraph();
205     auto cnst = static_cast<ConstantInst *>(input);
206     auto instType = DataType::GetCommonType(inst->GetType());
207     if (cnst->GetType() == DataType::INT32 || cnst->GetType() == DataType::INT64) {
208         return ConstFoldingCastIntConst(graph, inst, cnst);
209     }
210     if (cnst->GetType() == DataType::FLOAT32) {
211         if (graph->IsBytecodeOptimizer() && !isLiteralData) {
212             return nullptr;
213         }
214         if (instType == DataType::INT64) {
215             // FLOAT->INT
216             return graph->FindOrCreateConstant(ConvertFloatToInt(cnst->GetFloatValue(), inst->GetType()));
217         }
218         if (instType == DataType::FLOAT32) {
219             // FLOAT -> FLOAT
220             return cnst;
221         }
222         if (instType == DataType::FLOAT64) {
223             // FLOAT -> DOUBLE
224             return graph->FindOrCreateConstant(static_cast<double>(cnst->GetFloatValue()));
225         }
226     } else if (cnst->GetType() == DataType::FLOAT64) {
227         if (instType == DataType::INT64) {
228             // DOUBLE->INT/LONG
229             uint64_t val = graph->IsDynamicMethod()
230                                ? ConvertFloatToIntDyn(cnst->GetDoubleValue(), graph->GetRuntime(),
231                                                       DataType::GetTypeSize(inst->GetType(), graph->GetArch()))
232                                : ConvertFloatToInt(cnst->GetDoubleValue(), inst->GetType());
233             return ConstFoldingCreateIntConst(inst, val, isLiteralData);
234         }
235         if (instType == DataType::FLOAT32) {
236             // DOUBLE -> FLOAT
237             if (graph->IsBytecodeOptimizer() && !isLiteralData) {
238                 return nullptr;
239             }
240             return graph->FindOrCreateConstant(static_cast<float>(cnst->GetDoubleValue()));
241         }
242         if (instType == DataType::FLOAT64) {
243             // DOUBLE -> DOUBLE
244             return cnst;
245         }
246     }
247     return nullptr;
248 }
249 
ConstFoldingCast(Inst * inst)250 bool ConstFoldingCast(Inst *inst)
251 {
252     ASSERT(inst->GetOpcode() == Opcode::Cast);
253     auto input = inst->GetInput(0).GetInst();
254     if (input->IsConst()) {
255         ConstantInst *nwCnst = ConstFoldingCastConst(inst, input);
256         if (nwCnst != nullptr) {
257             inst->ReplaceUsers(nwCnst);
258             return true;
259         }
260     }
261     return false;
262 }
263 
ConstFoldingNeg(Inst * inst)264 bool ConstFoldingNeg(Inst *inst)
265 {
266     ASSERT(inst->GetOpcode() == Opcode::Neg);
267     auto input = inst->GetInput(0);
268     auto graph = inst->GetBasicBlock()->GetGraph();
269     if (input.GetInst()->IsConst()) {
270         auto cnst = static_cast<ConstantInst *>(input.GetInst());
271         ConstantInst *newCnst = nullptr;
272         switch (DataType::GetCommonType(inst->GetType())) {
273             case DataType::INT64:
274                 newCnst = ConstFoldingCreateIntConst(inst, ConvertIntToInt(-cnst->GetIntValue(), inst->GetType()));
275                 break;
276             case DataType::FLOAT32:
277                 newCnst = graph->FindOrCreateConstant(-cnst->GetFloatValue());
278                 break;
279             case DataType::FLOAT64:
280                 newCnst = graph->FindOrCreateConstant(-cnst->GetDoubleValue());
281                 break;
282             default:
283                 UNREACHABLE();
284         }
285         inst->ReplaceUsers(newCnst);
286         return true;
287     }
288     return false;
289 }
290 
ConstFoldingAbs(Inst * inst)291 bool ConstFoldingAbs(Inst *inst)
292 {
293     ASSERT(inst->GetOpcode() == Opcode::Abs);
294     auto input = inst->GetInput(0);
295     auto graph = inst->GetBasicBlock()->GetGraph();
296     if (input.GetInst()->IsConst()) {
297         auto cnst = static_cast<ConstantInst *>(input.GetInst());
298         ConstantInst *newCnst = nullptr;
299         switch (DataType::GetCommonType(inst->GetType())) {
300             case DataType::INT64: {
301                 ASSERT(DataType::IsTypeSigned(inst->GetType()));
302                 auto value = static_cast<int64_t>(cnst->GetIntValue());
303                 if (value == INT64_MIN) {
304                     newCnst = cnst;
305                     break;
306                 }
307                 auto uvalue = static_cast<uint64_t>((value < 0) ? -value : value);
308                 newCnst = ConstFoldingCreateIntConst(inst, ConvertIntToInt(uvalue, inst->GetType()));
309                 break;
310             }
311             case DataType::FLOAT32:
312                 newCnst = graph->FindOrCreateConstant(std::abs(cnst->GetFloatValue()));
313                 break;
314             case DataType::FLOAT64:
315                 newCnst = graph->FindOrCreateConstant(std::abs(cnst->GetDoubleValue()));
316                 break;
317             default:
318                 UNREACHABLE();
319         }
320         inst->ReplaceUsers(newCnst);
321         return true;
322     }
323     return false;
324 }
325 
ConstFoldingNot(Inst * inst)326 bool ConstFoldingNot(Inst *inst)
327 {
328     ASSERT(inst->GetOpcode() == Opcode::Not);
329     auto input = inst->GetInput(0);
330     ASSERT(DataType::GetCommonType(inst->GetType()) == DataType::INT64);
331     if (input.GetInst()->IsConst()) {
332         auto cnst = static_cast<ConstantInst *>(input.GetInst());
333         auto newCnst = ConstFoldingCreateIntConst(inst, ConvertIntToInt(~cnst->GetIntValue(), inst->GetType()));
334         inst->ReplaceUsers(newCnst);
335         return true;
336     }
337     return false;
338 }
339 
ConstFoldingAdd(Inst * inst)340 bool ConstFoldingAdd(Inst *inst)
341 {
342     ASSERT(inst->GetOpcode() == Opcode::Add);
343     auto input0 = inst->GetInput(0);
344     auto input1 = inst->GetInput(1);
345     auto graph = inst->GetBasicBlock()->GetGraph();
346     if (input0.GetInst()->IsConst() && input1.GetInst()->IsConst()) {
347         auto cnst0 = static_cast<ConstantInst *>(input0.GetInst());
348         auto cnst1 = static_cast<ConstantInst *>(input1.GetInst());
349         ConstantInst *newCnst = nullptr;
350         switch (DataType::GetCommonType(inst->GetType())) {
351             case DataType::INT64:
352                 newCnst = ConstFoldingCreateIntConst(
353                     inst, ConvertIntToInt(cnst0->GetIntValue() + cnst1->GetIntValue(), inst->GetType()));
354                 break;
355             case DataType::FLOAT32:
356                 newCnst = graph->FindOrCreateConstant(cnst0->GetFloatValue() + cnst1->GetFloatValue());
357                 break;
358             case DataType::FLOAT64:
359                 newCnst = graph->FindOrCreateConstant(cnst0->GetDoubleValue() + cnst1->GetDoubleValue());
360                 break;
361             default:
362                 UNREACHABLE();
363         }
364         inst->ReplaceUsers(newCnst);
365         return true;
366     }
367     return ConstFoldingBinaryMathWithNan(inst);
368 }
369 
ConstFoldingSub(Inst * inst)370 bool ConstFoldingSub(Inst *inst)
371 {
372     ASSERT(inst->GetOpcode() == Opcode::Sub);
373     auto input0 = inst->GetInput(0);
374     auto input1 = inst->GetInput(1);
375     auto graph = inst->GetBasicBlock()->GetGraph();
376     if (input0.GetInst()->IsConst() && input1.GetInst()->IsConst()) {
377         ConstantInst *newCnst = nullptr;
378         auto cnst0 = static_cast<ConstantInst *>(input0.GetInst());
379         auto cnst1 = static_cast<ConstantInst *>(input1.GetInst());
380         switch (DataType::GetCommonType(inst->GetType())) {
381             case DataType::INT64:
382                 newCnst = ConstFoldingCreateIntConst(
383                     inst, ConvertIntToInt(cnst0->GetIntValue() - cnst1->GetIntValue(), inst->GetType()));
384                 break;
385             case DataType::FLOAT32:
386                 newCnst = graph->FindOrCreateConstant(cnst0->GetFloatValue() - cnst1->GetFloatValue());
387                 break;
388             case DataType::FLOAT64:
389                 newCnst = graph->FindOrCreateConstant(cnst0->GetDoubleValue() - cnst1->GetDoubleValue());
390                 break;
391             default:
392                 UNREACHABLE();
393         }
394         inst->ReplaceUsers(newCnst);
395         return true;
396     }
397     if (input0.GetInst() == input1.GetInst() && DataType::GetCommonType(inst->GetType()) == DataType::INT64) {
398         // for floating point values 'x-x -> 0' optimization is not applicable because of NaN/Infinity values
399         auto newCnst = ConstFoldingCreateIntConst(inst, 0);
400         inst->ReplaceUsers(newCnst);
401         return true;
402     }
403     return ConstFoldingBinaryMathWithNan(inst);
404 }
405 
ConstFoldingMul(Inst * inst)406 bool ConstFoldingMul(Inst *inst)
407 {
408     ASSERT(inst->GetOpcode() == Opcode::Mul);
409     auto input0 = inst->GetInput(0).GetInst();
410     auto input1 = inst->GetInput(1).GetInst();
411     auto graph = inst->GetBasicBlock()->GetGraph();
412     ConstantInst *newCnst = nullptr;
413     if (input0->IsConst() && input1->IsConst()) {
414         auto cnst0 = static_cast<ConstantInst *>(input0);
415         auto cnst1 = static_cast<ConstantInst *>(input1);
416         switch (DataType::GetCommonType(inst->GetType())) {
417             case DataType::INT64:
418                 newCnst = ConstFoldingCreateIntConst(
419                     inst, ConvertIntToInt(cnst0->GetIntValue() * cnst1->GetIntValue(), inst->GetType()));
420                 break;
421             case DataType::FLOAT32:
422                 newCnst = graph->FindOrCreateConstant(cnst0->GetFloatValue() * cnst1->GetFloatValue());
423                 break;
424             case DataType::FLOAT64:
425                 newCnst = graph->FindOrCreateConstant(cnst0->GetDoubleValue() * cnst1->GetDoubleValue());
426                 break;
427             default:
428                 UNREACHABLE();
429         }
430         inst->ReplaceUsers(newCnst);
431         return true;
432     }
433     if (ConstFoldingBinaryMathWithNan(inst)) {
434         return true;
435     }
436     // Const is always in input1
437     if (input0->IsConst()) {
438         std::swap(input0, input1);
439     }
440     if (input1->IsConst() && input1->CastToConstant()->IsEqualConst(0, graph->IsBytecodeOptimizer())) {
441         inst->ReplaceUsers(input1);
442         return true;
443     }
444     return false;
445 }
446 
ConstFoldingBinaryMathWithNan(Inst * inst)447 bool ConstFoldingBinaryMathWithNan(Inst *inst)
448 {
449     ASSERT(inst->GetInputsCount() == 2U);
450     auto input0 = inst->GetInput(0).GetInst();
451     auto input1 = inst->GetInput(1).GetInst();
452     ASSERT(!input0->IsConst() || !input1->IsConst());
453     if (!DataType::IsFloatType(inst->GetType())) {
454         return false;
455     }
456     if (input0->IsConst()) {
457         std::swap(input0, input1);
458     }
459     if (!input1->IsConst()) {
460         return false;
461     }
462     if (!input1->CastToConstant()->IsNaNConst()) {
463         return false;
464     }
465     inst->ReplaceUsers(input1);
466     return true;
467 }
468 
ConstFoldingDivInt2Int(Inst * inst,Graph * graph,ConstantInst * cnst0,ConstantInst * cnst1)469 ConstantInst *ConstFoldingDivInt2Int(Inst *inst, Graph *graph, ConstantInst *cnst0, ConstantInst *cnst1)
470 {
471     if (cnst1->GetIntValue() == 0) {
472         return nullptr;
473     }
474     if (DataType::IsTypeSigned(inst->GetType())) {
475         if (graph->IsBytecodeOptimizer() && IsInt32Bit(inst->GetType())) {
476             if (static_cast<int32_t>(cnst0->GetIntValue()) == INT32_MIN &&
477                 static_cast<int32_t>(cnst1->GetIntValue()) == -1) {
478                 return graph->FindOrCreateConstant<uint32_t>(INT32_MIN);
479             }
480             return graph->FindOrCreateConstant<uint32_t>(
481                 ConvertIntToInt(static_cast<int32_t>(cnst0->GetIntValue()) / static_cast<int32_t>(cnst1->GetIntValue()),
482                                 inst->GetType()));
483         }
484         if (static_cast<int64_t>(cnst0->GetIntValue()) == INT64_MIN &&
485             static_cast<int64_t>(cnst1->GetIntValue()) == -1) {
486             return graph->FindOrCreateConstant<uint64_t>(INT64_MIN);
487         }
488         return graph->FindOrCreateConstant(ConvertIntToInt(
489             static_cast<int64_t>(cnst0->GetIntValue()) / static_cast<int64_t>(cnst1->GetIntValue()), inst->GetType()));
490     }
491 
492     return ConstFoldingCreateIntConst(inst, ConvertIntToInt(cnst0->GetIntValue(), inst->GetType()) /
493                                                 ConvertIntToInt(cnst1->GetIntValue(), inst->GetType()));
494 }
495 
ConstFoldingDiv(Inst * inst)496 bool ConstFoldingDiv(Inst *inst)
497 {
498     ASSERT(inst->GetOpcode() == Opcode::Div);
499     auto input0 = inst->GetDataFlowInput(0);
500     auto input1 = inst->GetDataFlowInput(1);
501     auto graph = inst->GetBasicBlock()->GetGraph();
502     if (!input0->IsConst() || !input1->IsConst()) {
503         return ConstFoldingBinaryMathWithNan(inst);
504     }
505     auto cnst0 = input0->CastToConstant();
506     auto cnst1 = input1->CastToConstant();
507     ConstantInst *newCnst = nullptr;
508     switch (DataType::GetCommonType(inst->GetType())) {
509         case DataType::INT64:
510             newCnst = ConstFoldingDivInt2Int(inst, graph, cnst0, cnst1);
511             if (newCnst == nullptr) {
512                 return false;
513             }
514             break;
515         case DataType::FLOAT32:
516             newCnst = graph->FindOrCreateConstant(cnst0->GetFloatValue() / cnst1->GetFloatValue());
517             break;
518         case DataType::FLOAT64:
519             newCnst = graph->FindOrCreateConstant(cnst0->GetDoubleValue() / cnst1->GetDoubleValue());
520             break;
521         default:
522             UNREACHABLE();
523     }
524     inst->ReplaceUsers(newCnst);
525     return true;
526 }
527 
ConstFoldingMinInt(Inst * inst,Graph * graph,ConstantInst * cnst0,ConstantInst * cnst1)528 ConstantInst *ConstFoldingMinInt(Inst *inst, Graph *graph, ConstantInst *cnst0, ConstantInst *cnst1)
529 {
530     if (DataType::IsTypeSigned(inst->GetType())) {
531         if (graph->IsBytecodeOptimizer() && IsInt32Bit(inst->GetType())) {
532             return graph->FindOrCreateConstant<uint32_t>(ConvertIntToInt(
533                 std::min(static_cast<int32_t>(cnst0->GetIntValue()), static_cast<int32_t>(cnst1->GetIntValue())),
534                 inst->GetType()));
535         }
536         return graph->FindOrCreateConstant(ConvertIntToInt(
537             std::min(static_cast<int64_t>(cnst0->GetIntValue()), static_cast<int64_t>(cnst1->GetIntValue())),
538             inst->GetType()));
539     }
540     return ConstFoldingCreateIntConst(
541         inst, ConvertIntToInt(std::min(cnst0->GetIntValue(), cnst1->GetIntValue()), inst->GetType()));
542 }
543 
ConstFoldingMin(Inst * inst)544 bool ConstFoldingMin(Inst *inst)
545 {
546     ASSERT(inst->GetOpcode() == Opcode::Min);
547     auto input0 = inst->GetInput(0);
548     auto input1 = inst->GetInput(1);
549     auto graph = inst->GetBasicBlock()->GetGraph();
550     if (input0.GetInst()->IsConst() && input1.GetInst()->IsConst()) {
551         auto cnst0 = static_cast<ConstantInst *>(input0.GetInst());
552         auto cnst1 = static_cast<ConstantInst *>(input1.GetInst());
553         ConstantInst *newCnst = nullptr;
554         switch (DataType::GetCommonType(inst->GetType())) {
555             case DataType::INT64:
556                 newCnst = ConstFoldingMinInt(inst, graph, cnst0, cnst1);
557                 ASSERT(newCnst != nullptr);
558                 break;
559             case DataType::FLOAT32:
560                 newCnst = graph->FindOrCreateConstant(
561                     ark::helpers::math::Min(cnst0->GetFloatValue(), cnst1->GetFloatValue()));
562                 break;
563             case DataType::FLOAT64:
564                 newCnst = graph->FindOrCreateConstant(
565                     ark::helpers::math::Min(cnst0->GetDoubleValue(), cnst1->GetDoubleValue()));
566                 break;
567             default:
568                 UNREACHABLE();
569         }
570         inst->ReplaceUsers(newCnst);
571         return true;
572     }
573     return ConstFoldingBinaryMathWithNan(inst);
574 }
575 
ConstFoldingMaxInt(Inst * inst,Graph * graph,ConstantInst * cnst0,ConstantInst * cnst1)576 ConstantInst *ConstFoldingMaxInt(Inst *inst, Graph *graph, ConstantInst *cnst0, ConstantInst *cnst1)
577 {
578     if (DataType::IsTypeSigned(inst->GetType())) {
579         if (graph->IsBytecodeOptimizer() && IsInt32Bit(inst->GetType())) {
580             return graph->FindOrCreateConstant<uint32_t>(ConvertIntToInt(
581                 std::max(static_cast<int32_t>(cnst0->GetIntValue()), static_cast<int32_t>(cnst1->GetIntValue())),
582                 inst->GetType()));
583         }
584         return graph->FindOrCreateConstant(ConvertIntToInt(
585             std::max(static_cast<int64_t>(cnst0->GetIntValue()), static_cast<int64_t>(cnst1->GetIntValue())),
586             inst->GetType()));
587     }
588     return ConstFoldingCreateIntConst(
589         inst, ConvertIntToInt(std::max(cnst0->GetIntValue(), cnst1->GetIntValue()), inst->GetType()));
590 }
591 
ConstFoldingMax(Inst * inst)592 bool ConstFoldingMax(Inst *inst)
593 {
594     ASSERT(inst->GetOpcode() == Opcode::Max);
595     auto input0 = inst->GetInput(0);
596     auto input1 = inst->GetInput(1);
597     auto graph = inst->GetBasicBlock()->GetGraph();
598     if (input0.GetInst()->IsConst() && input1.GetInst()->IsConst()) {
599         auto cnst0 = static_cast<ConstantInst *>(input0.GetInst());
600         auto cnst1 = static_cast<ConstantInst *>(input1.GetInst());
601         ConstantInst *newCnst = nullptr;
602         switch (DataType::GetCommonType(inst->GetType())) {
603             case DataType::INT64:
604                 newCnst = ConstFoldingMaxInt(inst, graph, cnst0, cnst1);
605                 ASSERT(newCnst != nullptr);
606                 break;
607             case DataType::FLOAT32:
608                 newCnst = graph->FindOrCreateConstant(
609                     ark::helpers::math::Max(cnst0->GetFloatValue(), cnst1->GetFloatValue()));
610                 break;
611             case DataType::FLOAT64:
612                 newCnst = graph->FindOrCreateConstant(
613                     ark::helpers::math::Max(cnst0->GetDoubleValue(), cnst1->GetDoubleValue()));
614                 break;
615             default:
616                 UNREACHABLE();
617         }
618         inst->ReplaceUsers(newCnst);
619         return true;
620     }
621     return ConstFoldingBinaryMathWithNan(inst);
622 }
623 
ConstFoldingModIntConst(Graph * graph,Inst * inst,ConstantInst * cnst0,ConstantInst * cnst1)624 ConstantInst *ConstFoldingModIntConst(Graph *graph, Inst *inst, ConstantInst *cnst0, ConstantInst *cnst1)
625 {
626     if (DataType::IsTypeSigned(inst->GetType())) {
627         if (graph->IsBytecodeOptimizer() && IsInt32Bit(inst->GetType())) {
628             if (static_cast<int32_t>(cnst0->GetIntValue()) == INT32_MIN &&
629                 static_cast<int32_t>(cnst1->GetIntValue()) == -1) {
630                 return graph->FindOrCreateConstant<uint32_t>(0);
631             }
632             return graph->FindOrCreateConstant<uint32_t>(
633                 ConvertIntToInt(static_cast<int32_t>(cnst0->GetIntValue()) % static_cast<int32_t>(cnst1->GetIntValue()),
634                                 inst->GetType()));
635         }
636         if (static_cast<int64_t>(cnst0->GetIntValue()) == INT64_MIN &&
637             static_cast<int64_t>(cnst1->GetIntValue()) == -1) {
638             return graph->FindOrCreateConstant<uint64_t>(0);
639         }
640         return graph->FindOrCreateConstant(ConvertIntToInt(
641             static_cast<int64_t>(cnst0->GetIntValue()) % static_cast<int64_t>(cnst1->GetIntValue()), inst->GetType()));
642     }
643     return ConstFoldingCreateIntConst(inst, ConvertIntToInt(cnst0->GetIntValue(), inst->GetType()) %
644                                                 ConvertIntToInt(cnst1->GetIntValue(), inst->GetType()));
645 }
646 
ConstFoldingMod(Inst * inst)647 bool ConstFoldingMod(Inst *inst)
648 {
649     ASSERT(inst->GetOpcode() == Opcode::Mod);
650     auto input0 = inst->GetDataFlowInput(0);
651     auto input1 = inst->GetDataFlowInput(1);
652     auto graph = inst->GetBasicBlock()->GetGraph();
653     if (input1->IsConst() && !DataType::IsFloatType(inst->GetType()) && input1->CastToConstant()->GetIntValue() == 1) {
654         ConstantInst *cnst = ConstFoldingCreateIntConst(inst, 0);
655         inst->ReplaceUsers(cnst);
656         return true;
657     }
658     if (!input0->IsConst() || !input1->IsConst()) {
659         return ConstFoldingBinaryMathWithNan(inst);
660     }
661     ConstantInst *newCnst = nullptr;
662     auto cnst0 = input0->CastToConstant();
663     auto cnst1 = input1->CastToConstant();
664     if (DataType::GetCommonType(inst->GetType()) == DataType::INT64) {
665         if (cnst1->GetIntValue() == 0) {
666             return false;
667         }
668         newCnst = ConstFoldingModIntConst(graph, inst, cnst0, cnst1);
669     } else if (inst->GetType() == DataType::FLOAT32) {
670         if (cnst1->GetFloatValue() == 0) {
671             return false;
672         }
673         newCnst =
674             graph->FindOrCreateConstant(static_cast<float>(fmodf(cnst0->GetFloatValue(), cnst1->GetFloatValue())));
675     } else if (inst->GetType() == DataType::FLOAT64) {
676         if (cnst1->GetDoubleValue() == 0) {
677             return false;
678         }
679         newCnst = graph->FindOrCreateConstant(fmod(cnst0->GetDoubleValue(), cnst1->GetDoubleValue()));
680     }
681     inst->ReplaceUsers(newCnst);
682     return true;
683 }
684 
ConstFoldingShl(Inst * inst)685 bool ConstFoldingShl(Inst *inst)
686 {
687     ASSERT(inst->GetOpcode() == Opcode::Shl);
688     auto input0 = inst->GetInput(0);
689     auto input1 = inst->GetInput(1);
690     auto graph = inst->GetBasicBlock()->GetGraph();
691     ASSERT(DataType::GetCommonType(inst->GetType()) == DataType::INT64);
692     if (input0.GetInst()->IsConst() && input1.GetInst()->IsConst()) {
693         auto cnst0 = input0.GetInst()->CastToConstant()->GetIntValue();
694         auto cnst1 = input1.GetInst()->CastToConstant()->GetIntValue();
695         ConstantInst *newCnst = nullptr;
696         uint64_t sizeMask = DataType::GetTypeSize(inst->GetType(), graph->GetArch()) - 1;
697         if (graph->IsBytecodeOptimizer() && IsInt32Bit(inst->GetType())) {
698             newCnst = graph->FindOrCreateConstant<uint32_t>(ConvertIntToInt(
699                 static_cast<uint32_t>(cnst0) << (static_cast<uint32_t>(cnst1) & static_cast<uint32_t>(sizeMask)),
700                 inst->GetType()));
701         } else {
702             newCnst = graph->FindOrCreateConstant(ConvertIntToInt(cnst0 << (cnst1 & sizeMask), inst->GetType()));
703         }
704         inst->ReplaceUsers(newCnst);
705         return true;
706     }
707     return false;
708 }
709 
ConstFoldingShr(Inst * inst)710 bool ConstFoldingShr(Inst *inst)
711 {
712     ASSERT(inst->GetOpcode() == Opcode::Shr);
713     auto input0 = inst->GetInput(0);
714     auto input1 = inst->GetInput(1);
715     auto graph = inst->GetBasicBlock()->GetGraph();
716     ASSERT(DataType::GetCommonType(inst->GetType()) == DataType::INT64);
717     if (input0.GetInst()->IsConst() && input1.GetInst()->IsConst()) {
718         auto cnst0 = input0.GetInst()->CastToConstant()->GetIntValue();
719         auto cnst1 = input1.GetInst()->CastToConstant()->GetIntValue();
720         uint64_t sizeMask = DataType::GetTypeSize(inst->GetType(), graph->GetArch()) - 1;
721         // zerod high part of the constant
722         if (sizeMask < DataType::GetTypeSize(DataType::INT32, graph->GetArch())) {
723             // NOLINTNEXTLINE(clang-analyzer-core.UndefinedBinaryOperatorResult)
724             uint64_t typeMask = (1ULL << (sizeMask + 1)) - 1;
725             cnst0 = cnst0 & typeMask;
726         }
727         ConstantInst *newCnst = nullptr;
728         if (graph->IsBytecodeOptimizer() && IsInt32Bit(inst->GetType())) {
729             newCnst = graph->FindOrCreateConstant<uint32_t>(ConvertIntToInt(
730                 static_cast<uint32_t>(cnst0) >> (static_cast<uint32_t>(cnst1) & static_cast<uint32_t>(sizeMask)),
731                 inst->GetType()));
732         } else {
733             newCnst = graph->FindOrCreateConstant(ConvertIntToInt(cnst0 >> (cnst1 & sizeMask), inst->GetType()));
734         }
735         inst->ReplaceUsers(newCnst);
736         return true;
737     }
738     return false;
739 }
740 
ConstFoldingAShr(Inst * inst)741 bool ConstFoldingAShr(Inst *inst)
742 {
743     ASSERT(inst->GetOpcode() == Opcode::AShr);
744     auto input0 = inst->GetInput(0);
745     auto input1 = inst->GetInput(1);
746     auto graph = inst->GetBasicBlock()->GetGraph();
747     ASSERT(DataType::GetCommonType(inst->GetType()) == DataType::INT64);
748     if (input0.GetInst()->IsConst() && input1.GetInst()->IsConst()) {
749         int64_t cnst0 = input0.GetInst()->CastToConstant()->GetIntValue();
750         auto cnst1 = input1.GetInst()->CastToConstant()->GetIntValue();
751         uint64_t sizeMask = DataType::GetTypeSize(inst->GetType(), graph->GetArch()) - 1;
752         ConstantInst *newCnst = nullptr;
753         if (graph->IsBytecodeOptimizer() && IsInt32Bit(inst->GetType())) {
754             newCnst = graph->FindOrCreateConstant<uint32_t>(
755                 // NOLINTNEXTLINE(hicpp-signed-bitwise)
756                 ConvertIntToInt(static_cast<int32_t>(cnst0) >>
757                                     (static_cast<uint32_t>(cnst1) & static_cast<uint32_t>(sizeMask)),
758                                 inst->GetType()));
759         } else {
760             newCnst = graph->FindOrCreateConstant(
761                 // NOLINTNEXTLINE(hicpp-signed-bitwise)
762                 ConvertIntToInt(cnst0 >> (cnst1 & sizeMask), inst->GetType()));
763         }
764         inst->ReplaceUsers(newCnst);
765         return true;
766     }
767     return false;
768 }
769 
ConstFoldingAnd(Inst * inst)770 bool ConstFoldingAnd(Inst *inst)
771 {
772     ASSERT(inst->GetOpcode() == Opcode::And);
773     auto input0 = inst->GetInput(0);
774     auto input1 = inst->GetInput(1);
775     ConstantInst *newCnst = nullptr;
776     ASSERT(DataType::GetCommonType(inst->GetType()) == DataType::INT64);
777     if (input0.GetInst()->IsConst() && input1.GetInst()->IsConst()) {
778         auto cnst0 = static_cast<ConstantInst *>(input0.GetInst());
779         auto cnst1 = static_cast<ConstantInst *>(input1.GetInst());
780         newCnst = ConstFoldingCreateIntConst(
781             inst, ConvertIntToInt(cnst0->GetIntValue() & cnst1->GetIntValue(), inst->GetType()));
782         inst->ReplaceUsers(newCnst);
783         return true;
784     }
785     if (input0.GetInst()->IsConst()) {
786         newCnst = static_cast<ConstantInst *>(input0.GetInst());
787     } else if (input1.GetInst()->IsConst()) {
788         newCnst = static_cast<ConstantInst *>(input1.GetInst());
789     }
790     if (newCnst != nullptr && newCnst->GetIntValue() == 0) {
791         inst->ReplaceUsers(newCnst);
792         return true;
793     }
794     return false;
795 }
796 
ConstFoldingOr(Inst * inst)797 bool ConstFoldingOr(Inst *inst)
798 {
799     ASSERT(inst->GetOpcode() == Opcode::Or);
800     auto input0 = inst->GetInput(0);
801     auto input1 = inst->GetInput(1);
802     ConstantInst *newCnst = nullptr;
803     ASSERT(DataType::GetCommonType(inst->GetType()) == DataType::INT64);
804     if (input0.GetInst()->IsConst() && input1.GetInst()->IsConst()) {
805         auto cnst0 = static_cast<ConstantInst *>(input0.GetInst());
806         auto cnst1 = static_cast<ConstantInst *>(input1.GetInst());
807         newCnst = ConstFoldingCreateIntConst(
808             inst, ConvertIntToInt(cnst0->GetIntValue() | cnst1->GetIntValue(), inst->GetType()));
809         inst->ReplaceUsers(newCnst);
810         return true;
811     }
812     if (input0.GetInst()->IsConst()) {
813         newCnst = static_cast<ConstantInst *>(input0.GetInst());
814     } else if (input1.GetInst()->IsConst()) {
815         newCnst = static_cast<ConstantInst *>(input1.GetInst());
816     }
817     if (newCnst != nullptr && newCnst->GetIntValue() == static_cast<uint64_t>(-1)) {
818         inst->ReplaceUsers(newCnst);
819         return true;
820     }
821     return false;
822 }
823 
ConstFoldingXor(Inst * inst)824 bool ConstFoldingXor(Inst *inst)
825 {
826     ASSERT(inst->GetOpcode() == Opcode::Xor);
827     auto input0 = inst->GetInput(0).GetInst();
828     auto input1 = inst->GetInput(1).GetInst();
829     ASSERT(DataType::GetCommonType(inst->GetType()) == DataType::INT64);
830     if (input0->IsConst() && input1->IsConst()) {
831         auto cnst0 = static_cast<ConstantInst *>(input0);
832         auto cnst1 = static_cast<ConstantInst *>(input1);
833         ConstantInst *newCnst = nullptr;
834         newCnst = ConstFoldingCreateIntConst(
835             inst, ConvertIntToInt(cnst0->GetIntValue() ^ cnst1->GetIntValue(), inst->GetType()));
836         inst->ReplaceUsers(newCnst);
837         return true;
838     }
839     // A xor A = 0
840     if (input0 == input1) {
841         inst->ReplaceUsers(ConstFoldingCreateIntConst(inst, 0));
842         return true;
843     }
844     return false;
845 }
846 
847 template <class T>
GetResult(T l,T r,const CmpInst * cmp)848 int64_t GetResult(T l, T r, [[maybe_unused]] const CmpInst *cmp)
849 {
850     // NOLINTNEXTLINE(readability-braces-around-statements, bugprone-suspicious-semicolon)
851     if constexpr (std::is_same<T, float>() || std::is_same<T, double>()) {
852         ASSERT(DataType::IsFloatType(cmp->GetInputType(0)));
853         if (std::isnan(l) || std::isnan(r)) {
854             if (cmp->IsFcmpg()) {
855                 return 1;
856             }
857             return -1;
858         }
859     }
860     if (l > r) {
861         return 1;
862     }
863     if (l < r) {
864         return -1;
865     }
866     return 0;
867 }
868 
GetIntResult(ConstantInst * cnst0,ConstantInst * cnst1,DataType::Type inputType,const CmpInst * cmp)869 int64_t GetIntResult(ConstantInst *cnst0, ConstantInst *cnst1, DataType::Type inputType, const CmpInst *cmp)
870 {
871     auto l = ConvertIntToInt(cnst0->GetIntValue(), inputType);
872     auto r = ConvertIntToInt(cnst1->GetIntValue(), inputType);
873     auto graph = cnst0->GetBasicBlock()->GetGraph();
874     if (DataType::IsTypeSigned(inputType)) {
875         if (graph->IsBytecodeOptimizer() && IsInt32Bit(inputType)) {
876             return GetResult(static_cast<int32_t>(l), static_cast<int32_t>(r), cmp);
877         }
878         return GetResult(static_cast<int64_t>(l), static_cast<int64_t>(r), cmp);
879     }
880     if (graph->IsBytecodeOptimizer() && IsInt32Bit(inputType)) {
881         return GetResult(static_cast<uint32_t>(l), static_cast<uint32_t>(r), cmp);
882     }
883     return GetResult(l, r, cmp);
884 }
885 
ConstFoldingCmpFloatNan(Inst * inst)886 bool ConstFoldingCmpFloatNan(Inst *inst)
887 {
888     ASSERT(inst->GetOpcode() == Opcode::Cmp);
889     auto input0 = inst->GetInput(0).GetInst();
890     auto input1 = inst->GetInput(1).GetInst();
891     if (!input0->IsConst() && !input1->IsConst()) {
892         return false;
893     }
894 
895     if (!DataType::IsFloatType(inst->CastToCmp()->GetOperandsType())) {
896         return false;
897     }
898 
899     // One of the constant always will be in input1
900     if (input0->IsConst()) {
901         std::swap(input0, input1);
902     }
903 
904     // For Float constant is applied only NaN cases
905     if (!input1->CastToConstant()->IsNaNConst()) {
906         return false;
907     }
908     // Result related with Fcmpg as wrote in spec
909     int64_t res {-1};
910     if (inst->CastToCmp()->IsFcmpg()) {
911         res = 1;
912     }
913     inst->ReplaceUsers(ConstFoldingCreateIntConst(inst, res));
914     return true;
915 }
916 
ConstFoldingCmp(Inst * inst)917 bool ConstFoldingCmp(Inst *inst)
918 {
919     ASSERT(inst->GetOpcode() == Opcode::Cmp);
920     auto input0 = inst->GetInput(0).GetInst();
921     auto input1 = inst->GetInput(1).GetInst();
922     auto cmp = inst->CastToCmp();
923     auto inputType = cmp->GetInputType(0);
924     if (ConstFoldingCmpFloatNan(inst)) {
925         return true;
926     }
927     if (input0->IsConst() && input1->IsConst()) {
928         auto cnst0 = static_cast<ConstantInst *>(input0);
929         auto cnst1 = static_cast<ConstantInst *>(input1);
930         int64_t result = 0;
931         switch (DataType::GetCommonType(inputType)) {
932             case DataType::INT64: {
933                 result = GetIntResult(cnst0, cnst1, inputType, cmp);
934                 break;
935             }
936             case DataType::FLOAT32:
937                 result = GetResult(cnst0->GetFloatValue(), cnst1->GetFloatValue(), cmp);
938                 break;
939             case DataType::FLOAT64:
940                 result = GetResult(cnst0->GetDoubleValue(), cnst1->GetDoubleValue(), cmp);
941                 break;
942             default:
943                 break;
944         }
945         auto newCnst = inst->GetBasicBlock()->GetGraph()->FindOrCreateConstant(result);
946         inst->ReplaceUsers(newCnst);
947         return true;
948     }
949     if (input0 == input1 && DataType::GetCommonType(inputType) == DataType::INT64) {
950         // for floating point values result may be non-zero if x is NaN
951         inst->ReplaceUsers(ConstFoldingCreateIntConst(inst, 0));
952         return true;
953     }
954     return false;
955 }
956 
ConstFoldingCompareCreateConst(Inst * inst,bool value)957 ConstantInst *ConstFoldingCompareCreateConst(Inst *inst, bool value)
958 {
959     auto graph = inst->GetBasicBlock()->GetGraph();
960     if (graph->IsBytecodeOptimizer() && IsInt32Bit(inst->GetType())) {
961         return graph->FindOrCreateConstant(static_cast<uint32_t>(value));
962     }
963     return graph->FindOrCreateConstant(static_cast<uint64_t>(value));
964 }
965 
ConstFoldingCompareCreateNewConst(Inst * inst,uint64_t cnstVal0,uint64_t cnstVal1)966 ConstantInst *ConstFoldingCompareCreateNewConst(Inst *inst, uint64_t cnstVal0, uint64_t cnstVal1)
967 {
968     switch (inst->CastToCompare()->GetCc()) {
969         case ConditionCode::CC_EQ:
970             return ConstFoldingCompareCreateConst(inst, (cnstVal0 == cnstVal1));
971         case ConditionCode::CC_NE:
972             return ConstFoldingCompareCreateConst(inst, (cnstVal0 != cnstVal1));
973         case ConditionCode::CC_LT:
974             return ConstFoldingCompareCreateConst(inst,
975                                                   (static_cast<int64_t>(cnstVal0) < static_cast<int64_t>(cnstVal1)));
976         case ConditionCode::CC_B:
977             return ConstFoldingCompareCreateConst(inst, (cnstVal0 < cnstVal1));
978         case ConditionCode::CC_LE:
979             return ConstFoldingCompareCreateConst(inst,
980                                                   (static_cast<int64_t>(cnstVal0) <= static_cast<int64_t>(cnstVal1)));
981         case ConditionCode::CC_BE:
982             return ConstFoldingCompareCreateConst(inst, (cnstVal0 <= cnstVal1));
983         case ConditionCode::CC_GT:
984             return ConstFoldingCompareCreateConst(inst,
985                                                   (static_cast<int64_t>(cnstVal0) > static_cast<int64_t>(cnstVal1)));
986         case ConditionCode::CC_A:
987             return ConstFoldingCompareCreateConst(inst, (cnstVal0 > cnstVal1));
988         case ConditionCode::CC_GE:
989             return ConstFoldingCompareCreateConst(inst,
990                                                   (static_cast<int64_t>(cnstVal0) >= static_cast<int64_t>(cnstVal1)));
991         case ConditionCode::CC_AE:
992             return ConstFoldingCompareCreateConst(inst, (cnstVal0 >= cnstVal1));
993         case ConditionCode::CC_TST_EQ:
994             return ConstFoldingCompareCreateConst(inst, ((cnstVal0 & cnstVal1) == 0));
995         case ConditionCode::CC_TST_NE:
996             return ConstFoldingCompareCreateConst(inst, ((cnstVal0 & cnstVal1) != 0));
997         default:
998             UNREACHABLE();
999     }
1000 }
1001 
ConstFoldingCompareEqualInputs(Inst * inst,Inst * input0,Inst * input1)1002 bool ConstFoldingCompareEqualInputs(Inst *inst, Inst *input0, Inst *input1)
1003 {
1004     if (input0 != input1) {
1005         return false;
1006     }
1007     auto cmpInst = inst->CastToCompare();
1008     auto commonType = DataType::GetCommonType(input0->GetType());
1009     switch (cmpInst->GetCc()) {
1010         case ConditionCode::CC_EQ:
1011         case ConditionCode::CC_LE:
1012         case ConditionCode::CC_GE:
1013         case ConditionCode::CC_BE:
1014         case ConditionCode::CC_AE:
1015             // for floating point values result may be non-zero if x is NaN
1016             if (commonType == DataType::INT64 || commonType == DataType::POINTER) {
1017                 inst->ReplaceUsers(ConstFoldingCreateIntConst(inst, 1));
1018                 return true;
1019             }
1020             break;
1021         case ConditionCode::CC_NE:
1022             if (commonType == DataType::INT64 || commonType == DataType::POINTER) {
1023                 inst->ReplaceUsers(ConstFoldingCreateIntConst(inst, 0));
1024                 return true;
1025             }
1026             break;
1027         case ConditionCode::CC_LT:
1028         case ConditionCode::CC_GT:
1029         case ConditionCode::CC_B:
1030         case ConditionCode::CC_A:
1031             // x<x is false even for x=NaN
1032             inst->ReplaceUsers(ConstFoldingCreateIntConst(inst, 0));
1033             return true;
1034         default:
1035             return false;
1036     }
1037     return false;
1038 }
1039 
IsUniqueRef(Inst * inst)1040 static bool IsUniqueRef(Inst *inst)
1041 {
1042     return inst->IsAllocation() || inst->GetOpcode() == Opcode::NullPtr ||
1043            inst->GetOpcode() == Opcode::LoadUniqueObject;
1044 }
1045 
ConstFoldingCompareFloatNan(Inst * inst)1046 bool ConstFoldingCompareFloatNan(Inst *inst)
1047 {
1048     ASSERT(DataType::IsFloatType(inst->CastToCompare()->GetOperandsType()));
1049     auto input0 = inst->GetInput(0).GetInst();
1050     auto input1 = inst->GetInput(1).GetInst();
1051     if (!input0->IsConst() && !input1->IsConst()) {
1052         return false;
1053     }
1054 
1055     // One of the constant always will be in input1
1056     if (input0->IsConst()) {
1057         std::swap(input0, input1);
1058     }
1059 
1060     // For Float constant is applied only NaN cases
1061     if (!input1->CastToConstant()->IsNaNConst()) {
1062         return false;
1063     }
1064 
1065     // If both operands is NaN constant - it is OK, all optimization will be applied anyway
1066     bool resultConst {};
1067     // We shouldn't reverse ConditionCode, because the results is not related to order of inputs
1068     switch (inst->CastToCompare()->GetCc()) {
1069         case ConditionCode::CC_NE:
1070             // NaN != number is true
1071             resultConst = true;
1072             break;
1073         case ConditionCode::CC_EQ:  // ==
1074         case ConditionCode::CC_LT:  // <
1075         case ConditionCode::CC_LE:  // <=
1076         case ConditionCode::CC_GT:  // >
1077         case ConditionCode::CC_GE:  // >=
1078             // All these CC with NaN give false
1079             resultConst = false;
1080             break;
1081         default:
1082             UNREACHABLE();
1083     }
1084     inst->ReplaceUsers(ConstFoldingCompareCreateConst(inst, resultConst));
1085     return true;
1086 }
1087 
ConstFoldingCompareIntConstant(Inst * inst)1088 bool ConstFoldingCompareIntConstant(Inst *inst)
1089 {
1090     ASSERT(!DataType::IsFloatType(inst->CastToCompare()->GetOperandsType()));
1091     auto input0 = inst->GetInput(0).GetInst();
1092     auto input1 = inst->GetInput(1).GetInst();
1093     if (!input0->IsConst() || !input1->IsConst()) {
1094         return false;
1095     }
1096 
1097     auto cnst0 = input0->CastToConstant();
1098     auto cnst1 = input1->CastToConstant();
1099     ConstantInst *newCnst = nullptr;
1100     auto type = inst->GetInputType(0);
1101     if (DataType::GetCommonType(type) == DataType::INT64) {
1102         uint64_t cnstVal0 = ConvertIntToInt(cnst0->GetIntValue(), type);
1103         uint64_t cnstVal1 = ConvertIntToInt(cnst1->GetIntValue(), type);
1104         newCnst = ConstFoldingCompareCreateNewConst(inst, cnstVal0, cnstVal1);
1105         inst->ReplaceUsers(newCnst);
1106         return true;
1107     }
1108     return false;
1109 }
1110 
ConstFoldingCompare(Inst * inst)1111 bool ConstFoldingCompare(Inst *inst)
1112 {
1113     ASSERT(inst->GetOpcode() == Opcode::Compare);
1114     auto input0 = inst->GetInput(0).GetInst();
1115     auto input1 = inst->GetInput(1).GetInst();
1116 
1117     if (DataType::IsFloatType(inst->CastToCompare()->GetOperandsType())) {
1118         if (ConstFoldingCompareFloatNan(inst)) {
1119             return true;
1120         }
1121     } else {
1122         if (ConstFoldingCompareIntConstant(inst)) {
1123             return true;
1124         }
1125     }
1126     if (input0->GetOpcode() == Opcode::LoadImmediate && input1->GetOpcode() == Opcode::LoadImmediate) {
1127         auto class0 = input0->CastToLoadImmediate()->GetObject();
1128         auto class1 = input1->CastToLoadImmediate()->GetObject();
1129         auto cc = inst->CastToCompare()->GetCc();
1130         ASSERT(cc == CC_NE || cc == CC_EQ);
1131         bool res {(class0 == class1) == (cc == CC_EQ)};
1132         inst->ReplaceUsers(ConstFoldingCompareCreateConst(inst, res));
1133         return true;
1134     }
1135     if (IsZeroConstantOrNullPtr(input0) && IsZeroConstantOrNullPtr(input1)) {
1136         auto cc = inst->CastToCompare()->GetCc();
1137         ASSERT(cc == CC_NE || cc == CC_EQ);
1138         inst->ReplaceUsers(ConstFoldingCompareCreateConst(inst, cc == CC_EQ));
1139         return true;
1140     }
1141     if (ConstFoldingCompareEqualInputs(inst, input0, input1)) {
1142         return true;
1143     }
1144     if (inst->GetInputType(0) == DataType::REFERENCE) {
1145         ASSERT(input0 != input1);
1146         if (IsUniqueRef(input0) && IsUniqueRef(input1)) {
1147             auto cc = inst->CastToCompare()->GetCc();
1148             inst->ReplaceUsers(ConstFoldingCompareCreateConst(inst, cc == CC_NE));
1149             return true;
1150         }
1151     }
1152     return false;
1153 }
1154 
ConstFoldingSqrt(Inst * inst)1155 bool ConstFoldingSqrt(Inst *inst)
1156 {
1157     ASSERT(inst->GetOpcode() == Opcode::Sqrt);
1158     auto input = inst->GetInput(0).GetInst();
1159     if (input->IsConst()) {
1160         auto cnst = input->CastToConstant();
1161         Inst *newCnst = nullptr;
1162         if (cnst->GetType() == DataType::FLOAT32) {
1163             newCnst = inst->GetBasicBlock()->GetGraph()->FindOrCreateConstant(std::sqrt(cnst->GetFloatValue()));
1164         } else {
1165             ASSERT(cnst->GetType() == DataType::FLOAT64);
1166             newCnst = inst->GetBasicBlock()->GetGraph()->FindOrCreateConstant(std::sqrt(cnst->GetDoubleValue()));
1167         }
1168         inst->ReplaceUsers(newCnst);
1169         return true;
1170     }
1171     return false;
1172 }
1173 
ConstFoldingLoadStatic(Inst * inst)1174 bool ConstFoldingLoadStatic(Inst *inst)
1175 {
1176     auto field = inst->CastToLoadStatic()->GetObjField();
1177     ASSERT(field != nullptr);
1178     auto isFloatType = DataType::IsFloatType(inst->GetType());
1179     if (DataType::GetCommonType(inst->GetType()) != DataType::INT64 && !isFloatType) {
1180         return false;
1181     }
1182     auto graph = inst->GetBasicBlock()->GetGraph();
1183     auto runtime = graph->GetRuntime();
1184     auto klass = runtime->GetClassForField(field);
1185     if (runtime->IsFieldReadonly(field) && runtime->IsClassInitialized(reinterpret_cast<uintptr_t>(klass))) {
1186         if (isFloatType) {
1187             auto value = runtime->GetStaticFieldFloatValue(field);
1188             if (inst->GetType() == DataType::FLOAT32) {
1189                 inst->ReplaceUsers(graph->FindOrCreateConstant(static_cast<float>(value)));
1190             } else {
1191                 ASSERT(inst->GetType() == DataType::FLOAT64);
1192                 inst->ReplaceUsers(graph->FindOrCreateConstant(value));
1193             }
1194         } else {
1195             auto value = runtime->GetStaticFieldIntegerValue(field);
1196             inst->ReplaceUsers(graph->FindOrCreateConstant(ConvertIntToInt(value, inst->GetType())));
1197         }
1198         return true;
1199     }
1200     return false;
1201 }
1202 
1203 }  // namespace ark::compiler
1204