1 /*
2 * Copyright (c) 2021-2025 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "optimizer/ir/basicblock.h"
17 #include "optimizer/ir/graph.h"
18 #include "optimizer/ir/datatype.h"
19 #include "optimizer/optimizations/const_folding.h"
20 #include "utils/math_helpers.h"
21
22 #include <cmath>
23 #include <map>
24
25 namespace ark::compiler {
26 template <class T>
ConvertIntToInt(T value,DataType::Type targetType)27 uint64_t ConvertIntToInt(T value, DataType::Type targetType)
28 {
29 switch (targetType) {
30 case DataType::BOOL:
31 return static_cast<uint64_t>(static_cast<bool>(value));
32 case DataType::UINT8:
33 return static_cast<uint64_t>(static_cast<uint8_t>(value));
34 case DataType::INT8:
35 return static_cast<uint64_t>(static_cast<int8_t>(value));
36 case DataType::UINT16:
37 return static_cast<uint64_t>(static_cast<uint16_t>(value));
38 case DataType::INT16:
39 return static_cast<uint64_t>(static_cast<int16_t>(value));
40 case DataType::UINT32:
41 return static_cast<uint64_t>(static_cast<uint32_t>(value));
42 case DataType::INT32:
43 return static_cast<uint64_t>(static_cast<int32_t>(value));
44 case DataType::UINT64:
45 return static_cast<uint64_t>(value);
46 case DataType::INT64:
47 return static_cast<uint64_t>(static_cast<int64_t>(value));
48 default:
49 UNREACHABLE();
50 }
51 }
52
53 template <class T>
ConvertIntToFloat(uint64_t value,DataType::Type sourceType)54 T ConvertIntToFloat(uint64_t value, DataType::Type sourceType)
55 {
56 switch (sourceType) {
57 case DataType::BOOL:
58 return static_cast<T>(static_cast<bool>(value));
59 case DataType::UINT8:
60 return static_cast<T>(static_cast<uint8_t>(value));
61 case DataType::INT8:
62 return static_cast<T>(static_cast<int8_t>(value));
63 case DataType::UINT16:
64 return static_cast<T>(static_cast<uint16_t>(value));
65 case DataType::INT16:
66 return static_cast<T>(static_cast<int16_t>(value));
67 case DataType::UINT32:
68 return static_cast<T>(static_cast<uint32_t>(value));
69 case DataType::INT32:
70 return static_cast<T>(static_cast<int32_t>(value));
71 case DataType::UINT64:
72 return static_cast<T>(value);
73 case DataType::INT64:
74 return static_cast<T>(static_cast<int64_t>(value));
75 default:
76 UNREACHABLE();
77 }
78 }
79
/**
 * Saturating floating-point -> integer conversion.
 * - NaN becomes 0;
 * - values at or below the smallest `To` become numeric_limits<To>::min();
 * - values at or above the largest `To` (compared as `From`) become numeric_limits<To>::max();
 * - everything else is converted with static_cast (truncation toward zero for integer targets).
 * @tparam To   integer result type
 * @tparam From floating-point source type
 */
template <class To, class From>
To ConvertFloatToInt(From value)
{
    constexpr To MIN_INT = std::numeric_limits<To>::min();
    constexpr To MAX_INT = std::numeric_limits<To>::max();

    if (std::isnan(value)) {
        return static_cast<To>(0);
    }
    // Bounds are compared in the floating-point domain. For wide targets the cast of MAX_INT
    // rounds (e.g. to 2^63 for int64_t in double), so ">=" saturates every value that would overflow.
    if (value <= static_cast<From>(MIN_INT)) {
        return MIN_INT;
    }
    if (value >= static_cast<From>(MAX_INT)) {
        return MAX_INT;
    }
    return static_cast<To>(value);
}
104
105 template <class From>
ConvertFloatToInt(From value,DataType::Type targetType)106 uint64_t ConvertFloatToInt(From value, DataType::Type targetType)
107 {
108 ASSERT(DataType::GetCommonType(targetType) == DataType::INT64);
109 switch (targetType) {
110 case DataType::BOOL:
111 return static_cast<uint64_t>(ConvertFloatToInt<bool>(value));
112 case DataType::UINT8:
113 return static_cast<uint64_t>(ConvertFloatToInt<uint8_t>(value));
114 case DataType::INT8:
115 return static_cast<uint64_t>(ConvertFloatToInt<int8_t>(value));
116 case DataType::UINT16:
117 return static_cast<uint64_t>(ConvertFloatToInt<uint16_t>(value));
118 case DataType::INT16:
119 return static_cast<uint64_t>(ConvertFloatToInt<int16_t>(value));
120 case DataType::UINT32:
121 return static_cast<uint64_t>(ConvertFloatToInt<uint32_t>(value));
122 case DataType::INT32:
123 return static_cast<uint64_t>(ConvertFloatToInt<int32_t>(value));
124 case DataType::UINT64:
125 return ConvertFloatToInt<uint64_t>(value);
126 case DataType::INT64:
127 return static_cast<uint64_t>(ConvertFloatToInt<int64_t>(value));
128 default:
129 UNREACHABLE();
130 }
131 }
132
133 template <class From>
ConvertFloatToIntDyn(From value,RuntimeInterface * runtime,size_t bits)134 uint64_t ConvertFloatToIntDyn(From value, RuntimeInterface *runtime, size_t bits)
135 {
136 return runtime->DynamicCastDoubleToInt(static_cast<double>(value), bits);
137 }
138
ConstFoldingCreateIntConst(Inst * inst,uint64_t value,bool isLiteralData)139 ConstantInst *ConstFoldingCreateIntConst(Inst *inst, uint64_t value, bool isLiteralData)
140 {
141 auto graph = inst->GetBasicBlock()->GetGraph();
142 if (graph->IsBytecodeOptimizer() && IsInt32Bit(inst->GetType()) && !isLiteralData) {
143 return graph->FindOrCreateConstant<uint32_t>(value);
144 }
145 return graph->FindOrCreateConstant(value);
146 }
147
148 template <typename T>
ConstFoldingCreateConst(Inst * inst,ConstantInst * cnst,bool isLiteralData=false)149 ConstantInst *ConstFoldingCreateConst(Inst *inst, ConstantInst *cnst, bool isLiteralData = false)
150 {
151 return ConstFoldingCreateIntConst(inst, ConvertIntToInt(static_cast<T>(cnst->GetIntValue()), inst->GetType()),
152 isLiteralData);
153 }
154
ConstFoldingCastInt2Int(Inst * inst,ConstantInst * cnst)155 ConstantInst *ConstFoldingCastInt2Int(Inst *inst, ConstantInst *cnst)
156 {
157 switch (inst->GetInputType(0)) {
158 case DataType::BOOL:
159 return ConstFoldingCreateConst<bool>(inst, cnst);
160 case DataType::UINT8:
161 return ConstFoldingCreateConst<uint8_t>(inst, cnst);
162 case DataType::INT8:
163 return ConstFoldingCreateConst<int8_t>(inst, cnst);
164 case DataType::UINT16:
165 return ConstFoldingCreateConst<uint16_t>(inst, cnst);
166 case DataType::INT16:
167 return ConstFoldingCreateConst<int16_t>(inst, cnst);
168 case DataType::UINT32:
169 return ConstFoldingCreateConst<uint32_t>(inst, cnst);
170 case DataType::INT32:
171 return ConstFoldingCreateConst<int32_t>(inst, cnst);
172 case DataType::UINT64:
173 return ConstFoldingCreateConst<uint64_t>(inst, cnst);
174 case DataType::INT64:
175 return ConstFoldingCreateConst<int64_t>(inst, cnst);
176 default:
177 return nullptr;
178 }
179 }
180
ConstFoldingCastIntConst(Graph * graph,Inst * inst,ConstantInst * cnst,bool isLiteralData=false)181 ConstantInst *ConstFoldingCastIntConst(Graph *graph, Inst *inst, ConstantInst *cnst, bool isLiteralData = false)
182 {
183 auto instType = DataType::GetCommonType(inst->GetType());
184 if (instType == DataType::INT64) {
185 // INT -> INT
186 return ConstFoldingCastInt2Int(inst, cnst);
187 }
188 if (instType == DataType::FLOAT32) {
189 // INT -> FLOAT
190 if (graph->IsBytecodeOptimizer() && !isLiteralData) {
191 return nullptr;
192 }
193 return graph->FindOrCreateConstant(ConvertIntToFloat<float>(cnst->GetIntValue(), inst->GetInputType(0)));
194 }
195 if (instType == DataType::FLOAT64) {
196 // INT -> DOUBLE
197 return graph->FindOrCreateConstant(ConvertIntToFloat<double>(cnst->GetIntValue(), inst->GetInputType(0)));
198 }
199 return nullptr;
200 }
201
ConstFoldingCastConst(Inst * inst,Inst * input,bool isLiteralData)202 ConstantInst *ConstFoldingCastConst(Inst *inst, Inst *input, bool isLiteralData)
203 {
204 auto graph = inst->GetBasicBlock()->GetGraph();
205 auto cnst = static_cast<ConstantInst *>(input);
206 auto instType = DataType::GetCommonType(inst->GetType());
207 if (cnst->GetType() == DataType::INT32 || cnst->GetType() == DataType::INT64) {
208 return ConstFoldingCastIntConst(graph, inst, cnst);
209 }
210 if (cnst->GetType() == DataType::FLOAT32) {
211 if (graph->IsBytecodeOptimizer() && !isLiteralData) {
212 return nullptr;
213 }
214 if (instType == DataType::INT64) {
215 // FLOAT->INT
216 return graph->FindOrCreateConstant(ConvertFloatToInt(cnst->GetFloatValue(), inst->GetType()));
217 }
218 if (instType == DataType::FLOAT32) {
219 // FLOAT -> FLOAT
220 return cnst;
221 }
222 if (instType == DataType::FLOAT64) {
223 // FLOAT -> DOUBLE
224 return graph->FindOrCreateConstant(static_cast<double>(cnst->GetFloatValue()));
225 }
226 } else if (cnst->GetType() == DataType::FLOAT64) {
227 if (instType == DataType::INT64) {
228 // DOUBLE->INT/LONG
229 uint64_t val = graph->IsDynamicMethod()
230 ? ConvertFloatToIntDyn(cnst->GetDoubleValue(), graph->GetRuntime(),
231 DataType::GetTypeSize(inst->GetType(), graph->GetArch()))
232 : ConvertFloatToInt(cnst->GetDoubleValue(), inst->GetType());
233 return ConstFoldingCreateIntConst(inst, val, isLiteralData);
234 }
235 if (instType == DataType::FLOAT32) {
236 // DOUBLE -> FLOAT
237 if (graph->IsBytecodeOptimizer() && !isLiteralData) {
238 return nullptr;
239 }
240 return graph->FindOrCreateConstant(static_cast<float>(cnst->GetDoubleValue()));
241 }
242 if (instType == DataType::FLOAT64) {
243 // DOUBLE -> DOUBLE
244 return cnst;
245 }
246 }
247 return nullptr;
248 }
249
ConstFoldingCast(Inst * inst)250 bool ConstFoldingCast(Inst *inst)
251 {
252 ASSERT(inst->GetOpcode() == Opcode::Cast);
253 auto input = inst->GetInput(0).GetInst();
254 if (input->IsConst()) {
255 ConstantInst *nwCnst = ConstFoldingCastConst(inst, input);
256 if (nwCnst != nullptr) {
257 inst->ReplaceUsers(nwCnst);
258 return true;
259 }
260 }
261 return false;
262 }
263
ConstFoldingNeg(Inst * inst)264 bool ConstFoldingNeg(Inst *inst)
265 {
266 ASSERT(inst->GetOpcode() == Opcode::Neg);
267 auto input = inst->GetInput(0);
268 auto graph = inst->GetBasicBlock()->GetGraph();
269 if (input.GetInst()->IsConst()) {
270 auto cnst = static_cast<ConstantInst *>(input.GetInst());
271 ConstantInst *newCnst = nullptr;
272 switch (DataType::GetCommonType(inst->GetType())) {
273 case DataType::INT64:
274 newCnst = ConstFoldingCreateIntConst(inst, ConvertIntToInt(-cnst->GetIntValue(), inst->GetType()));
275 break;
276 case DataType::FLOAT32:
277 newCnst = graph->FindOrCreateConstant(-cnst->GetFloatValue());
278 break;
279 case DataType::FLOAT64:
280 newCnst = graph->FindOrCreateConstant(-cnst->GetDoubleValue());
281 break;
282 default:
283 UNREACHABLE();
284 }
285 inst->ReplaceUsers(newCnst);
286 return true;
287 }
288 return false;
289 }
290
ConstFoldingAbs(Inst * inst)291 bool ConstFoldingAbs(Inst *inst)
292 {
293 ASSERT(inst->GetOpcode() == Opcode::Abs);
294 auto input = inst->GetInput(0);
295 auto graph = inst->GetBasicBlock()->GetGraph();
296 if (input.GetInst()->IsConst()) {
297 auto cnst = static_cast<ConstantInst *>(input.GetInst());
298 ConstantInst *newCnst = nullptr;
299 switch (DataType::GetCommonType(inst->GetType())) {
300 case DataType::INT64: {
301 ASSERT(DataType::IsTypeSigned(inst->GetType()));
302 auto value = static_cast<int64_t>(cnst->GetIntValue());
303 if (value == INT64_MIN) {
304 newCnst = cnst;
305 break;
306 }
307 auto uvalue = static_cast<uint64_t>((value < 0) ? -value : value);
308 newCnst = ConstFoldingCreateIntConst(inst, ConvertIntToInt(uvalue, inst->GetType()));
309 break;
310 }
311 case DataType::FLOAT32:
312 newCnst = graph->FindOrCreateConstant(std::abs(cnst->GetFloatValue()));
313 break;
314 case DataType::FLOAT64:
315 newCnst = graph->FindOrCreateConstant(std::abs(cnst->GetDoubleValue()));
316 break;
317 default:
318 UNREACHABLE();
319 }
320 inst->ReplaceUsers(newCnst);
321 return true;
322 }
323 return false;
324 }
325
ConstFoldingNot(Inst * inst)326 bool ConstFoldingNot(Inst *inst)
327 {
328 ASSERT(inst->GetOpcode() == Opcode::Not);
329 auto input = inst->GetInput(0);
330 ASSERT(DataType::GetCommonType(inst->GetType()) == DataType::INT64);
331 if (input.GetInst()->IsConst()) {
332 auto cnst = static_cast<ConstantInst *>(input.GetInst());
333 auto newCnst = ConstFoldingCreateIntConst(inst, ConvertIntToInt(~cnst->GetIntValue(), inst->GetType()));
334 inst->ReplaceUsers(newCnst);
335 return true;
336 }
337 return false;
338 }
339
ConstFoldingAdd(Inst * inst)340 bool ConstFoldingAdd(Inst *inst)
341 {
342 ASSERT(inst->GetOpcode() == Opcode::Add);
343 auto input0 = inst->GetInput(0);
344 auto input1 = inst->GetInput(1);
345 auto graph = inst->GetBasicBlock()->GetGraph();
346 if (input0.GetInst()->IsConst() && input1.GetInst()->IsConst()) {
347 auto cnst0 = static_cast<ConstantInst *>(input0.GetInst());
348 auto cnst1 = static_cast<ConstantInst *>(input1.GetInst());
349 ConstantInst *newCnst = nullptr;
350 switch (DataType::GetCommonType(inst->GetType())) {
351 case DataType::INT64:
352 newCnst = ConstFoldingCreateIntConst(
353 inst, ConvertIntToInt(cnst0->GetIntValue() + cnst1->GetIntValue(), inst->GetType()));
354 break;
355 case DataType::FLOAT32:
356 newCnst = graph->FindOrCreateConstant(cnst0->GetFloatValue() + cnst1->GetFloatValue());
357 break;
358 case DataType::FLOAT64:
359 newCnst = graph->FindOrCreateConstant(cnst0->GetDoubleValue() + cnst1->GetDoubleValue());
360 break;
361 default:
362 UNREACHABLE();
363 }
364 inst->ReplaceUsers(newCnst);
365 return true;
366 }
367 return ConstFoldingBinaryMathWithNan(inst);
368 }
369
ConstFoldingSub(Inst * inst)370 bool ConstFoldingSub(Inst *inst)
371 {
372 ASSERT(inst->GetOpcode() == Opcode::Sub);
373 auto input0 = inst->GetInput(0);
374 auto input1 = inst->GetInput(1);
375 auto graph = inst->GetBasicBlock()->GetGraph();
376 if (input0.GetInst()->IsConst() && input1.GetInst()->IsConst()) {
377 ConstantInst *newCnst = nullptr;
378 auto cnst0 = static_cast<ConstantInst *>(input0.GetInst());
379 auto cnst1 = static_cast<ConstantInst *>(input1.GetInst());
380 switch (DataType::GetCommonType(inst->GetType())) {
381 case DataType::INT64:
382 newCnst = ConstFoldingCreateIntConst(
383 inst, ConvertIntToInt(cnst0->GetIntValue() - cnst1->GetIntValue(), inst->GetType()));
384 break;
385 case DataType::FLOAT32:
386 newCnst = graph->FindOrCreateConstant(cnst0->GetFloatValue() - cnst1->GetFloatValue());
387 break;
388 case DataType::FLOAT64:
389 newCnst = graph->FindOrCreateConstant(cnst0->GetDoubleValue() - cnst1->GetDoubleValue());
390 break;
391 default:
392 UNREACHABLE();
393 }
394 inst->ReplaceUsers(newCnst);
395 return true;
396 }
397 if (input0.GetInst() == input1.GetInst() && DataType::GetCommonType(inst->GetType()) == DataType::INT64) {
398 // for floating point values 'x-x -> 0' optimization is not applicable because of NaN/Infinity values
399 auto newCnst = ConstFoldingCreateIntConst(inst, 0);
400 inst->ReplaceUsers(newCnst);
401 return true;
402 }
403 return ConstFoldingBinaryMathWithNan(inst);
404 }
405
ConstFoldingMul(Inst * inst)406 bool ConstFoldingMul(Inst *inst)
407 {
408 ASSERT(inst->GetOpcode() == Opcode::Mul);
409 auto input0 = inst->GetInput(0).GetInst();
410 auto input1 = inst->GetInput(1).GetInst();
411 auto graph = inst->GetBasicBlock()->GetGraph();
412 ConstantInst *newCnst = nullptr;
413 if (input0->IsConst() && input1->IsConst()) {
414 auto cnst0 = static_cast<ConstantInst *>(input0);
415 auto cnst1 = static_cast<ConstantInst *>(input1);
416 switch (DataType::GetCommonType(inst->GetType())) {
417 case DataType::INT64:
418 newCnst = ConstFoldingCreateIntConst(
419 inst, ConvertIntToInt(cnst0->GetIntValue() * cnst1->GetIntValue(), inst->GetType()));
420 break;
421 case DataType::FLOAT32:
422 newCnst = graph->FindOrCreateConstant(cnst0->GetFloatValue() * cnst1->GetFloatValue());
423 break;
424 case DataType::FLOAT64:
425 newCnst = graph->FindOrCreateConstant(cnst0->GetDoubleValue() * cnst1->GetDoubleValue());
426 break;
427 default:
428 UNREACHABLE();
429 }
430 inst->ReplaceUsers(newCnst);
431 return true;
432 }
433 if (ConstFoldingBinaryMathWithNan(inst)) {
434 return true;
435 }
436 // Const is always in input1
437 if (input0->IsConst()) {
438 std::swap(input0, input1);
439 }
440 if (input1->IsConst() && input1->CastToConstant()->IsEqualConst(0, graph->IsBytecodeOptimizer())) {
441 inst->ReplaceUsers(input1);
442 return true;
443 }
444 return false;
445 }
446
ConstFoldingBinaryMathWithNan(Inst * inst)447 bool ConstFoldingBinaryMathWithNan(Inst *inst)
448 {
449 ASSERT(inst->GetInputsCount() == 2U);
450 auto input0 = inst->GetInput(0).GetInst();
451 auto input1 = inst->GetInput(1).GetInst();
452 ASSERT(!input0->IsConst() || !input1->IsConst());
453 if (!DataType::IsFloatType(inst->GetType())) {
454 return false;
455 }
456 if (input0->IsConst()) {
457 std::swap(input0, input1);
458 }
459 if (!input1->IsConst()) {
460 return false;
461 }
462 if (!input1->CastToConstant()->IsNaNConst()) {
463 return false;
464 }
465 inst->ReplaceUsers(input1);
466 return true;
467 }
468
/**
 * Folds an integer Div of two constants.
 * @return the quotient constant, or nullptr (no folding) when the divisor payload is zero
 *
 * Signed division special-cases MIN / -1, whose true result overflows (and is undefined behavior
 * in C++). It is folded to MIN, the two's-complement wrapped value. In the bytecode optimizer,
 * 32-bit signed results are computed in int32_t and created as uint32_t constants.
 */
ConstantInst *ConstFoldingDivInt2Int(Inst *inst, Graph *graph, ConstantInst *cnst0, ConstantInst *cnst1)
{
    // Division by zero is left unfolded.
    // NOTE(review): this checks the full 64-bit payload; the int32/narrow-unsigned paths below divide by
    // a truncated value, so this assumes constants are canonical (truncated divisor is also non-zero).
    if (cnst1->GetIntValue() == 0) {
        return nullptr;
    }
    if (DataType::IsTypeSigned(inst->GetType())) {
        if (graph->IsBytecodeOptimizer() && IsInt32Bit(inst->GetType())) {
            // INT32_MIN / -1 overflows int32_t; the wrapped result is INT32_MIN.
            if (static_cast<int32_t>(cnst0->GetIntValue()) == INT32_MIN &&
                static_cast<int32_t>(cnst1->GetIntValue()) == -1) {
                return graph->FindOrCreateConstant<uint32_t>(INT32_MIN);
            }
            return graph->FindOrCreateConstant<uint32_t>(
                ConvertIntToInt(static_cast<int32_t>(cnst0->GetIntValue()) / static_cast<int32_t>(cnst1->GetIntValue()),
                                inst->GetType()));
        }
        // INT64_MIN / -1 overflows int64_t; the wrapped result is INT64_MIN.
        if (static_cast<int64_t>(cnst0->GetIntValue()) == INT64_MIN &&
            static_cast<int64_t>(cnst1->GetIntValue()) == -1) {
            return graph->FindOrCreateConstant<uint64_t>(INT64_MIN);
        }
        return graph->FindOrCreateConstant(ConvertIntToInt(
            static_cast<int64_t>(cnst0->GetIntValue()) / static_cast<int64_t>(cnst1->GetIntValue()), inst->GetType()));
    }

    // Unsigned: normalize both operands to the result type before dividing.
    return ConstFoldingCreateIntConst(inst, ConvertIntToInt(cnst0->GetIntValue(), inst->GetType()) /
                                                ConvertIntToInt(cnst1->GetIntValue(), inst->GetType()));
}
495
ConstFoldingDiv(Inst * inst)496 bool ConstFoldingDiv(Inst *inst)
497 {
498 ASSERT(inst->GetOpcode() == Opcode::Div);
499 auto input0 = inst->GetDataFlowInput(0);
500 auto input1 = inst->GetDataFlowInput(1);
501 auto graph = inst->GetBasicBlock()->GetGraph();
502 if (!input0->IsConst() || !input1->IsConst()) {
503 return ConstFoldingBinaryMathWithNan(inst);
504 }
505 auto cnst0 = input0->CastToConstant();
506 auto cnst1 = input1->CastToConstant();
507 ConstantInst *newCnst = nullptr;
508 switch (DataType::GetCommonType(inst->GetType())) {
509 case DataType::INT64:
510 newCnst = ConstFoldingDivInt2Int(inst, graph, cnst0, cnst1);
511 if (newCnst == nullptr) {
512 return false;
513 }
514 break;
515 case DataType::FLOAT32:
516 newCnst = graph->FindOrCreateConstant(cnst0->GetFloatValue() / cnst1->GetFloatValue());
517 break;
518 case DataType::FLOAT64:
519 newCnst = graph->FindOrCreateConstant(cnst0->GetDoubleValue() / cnst1->GetDoubleValue());
520 break;
521 default:
522 UNREACHABLE();
523 }
524 inst->ReplaceUsers(newCnst);
525 return true;
526 }
527
ConstFoldingMinInt(Inst * inst,Graph * graph,ConstantInst * cnst0,ConstantInst * cnst1)528 ConstantInst *ConstFoldingMinInt(Inst *inst, Graph *graph, ConstantInst *cnst0, ConstantInst *cnst1)
529 {
530 if (DataType::IsTypeSigned(inst->GetType())) {
531 if (graph->IsBytecodeOptimizer() && IsInt32Bit(inst->GetType())) {
532 return graph->FindOrCreateConstant<uint32_t>(ConvertIntToInt(
533 std::min(static_cast<int32_t>(cnst0->GetIntValue()), static_cast<int32_t>(cnst1->GetIntValue())),
534 inst->GetType()));
535 }
536 return graph->FindOrCreateConstant(ConvertIntToInt(
537 std::min(static_cast<int64_t>(cnst0->GetIntValue()), static_cast<int64_t>(cnst1->GetIntValue())),
538 inst->GetType()));
539 }
540 return ConstFoldingCreateIntConst(
541 inst, ConvertIntToInt(std::min(cnst0->GetIntValue(), cnst1->GetIntValue()), inst->GetType()));
542 }
543
ConstFoldingMin(Inst * inst)544 bool ConstFoldingMin(Inst *inst)
545 {
546 ASSERT(inst->GetOpcode() == Opcode::Min);
547 auto input0 = inst->GetInput(0);
548 auto input1 = inst->GetInput(1);
549 auto graph = inst->GetBasicBlock()->GetGraph();
550 if (input0.GetInst()->IsConst() && input1.GetInst()->IsConst()) {
551 auto cnst0 = static_cast<ConstantInst *>(input0.GetInst());
552 auto cnst1 = static_cast<ConstantInst *>(input1.GetInst());
553 ConstantInst *newCnst = nullptr;
554 switch (DataType::GetCommonType(inst->GetType())) {
555 case DataType::INT64:
556 newCnst = ConstFoldingMinInt(inst, graph, cnst0, cnst1);
557 ASSERT(newCnst != nullptr);
558 break;
559 case DataType::FLOAT32:
560 newCnst = graph->FindOrCreateConstant(
561 ark::helpers::math::Min(cnst0->GetFloatValue(), cnst1->GetFloatValue()));
562 break;
563 case DataType::FLOAT64:
564 newCnst = graph->FindOrCreateConstant(
565 ark::helpers::math::Min(cnst0->GetDoubleValue(), cnst1->GetDoubleValue()));
566 break;
567 default:
568 UNREACHABLE();
569 }
570 inst->ReplaceUsers(newCnst);
571 return true;
572 }
573 return ConstFoldingBinaryMathWithNan(inst);
574 }
575
ConstFoldingMaxInt(Inst * inst,Graph * graph,ConstantInst * cnst0,ConstantInst * cnst1)576 ConstantInst *ConstFoldingMaxInt(Inst *inst, Graph *graph, ConstantInst *cnst0, ConstantInst *cnst1)
577 {
578 if (DataType::IsTypeSigned(inst->GetType())) {
579 if (graph->IsBytecodeOptimizer() && IsInt32Bit(inst->GetType())) {
580 return graph->FindOrCreateConstant<uint32_t>(ConvertIntToInt(
581 std::max(static_cast<int32_t>(cnst0->GetIntValue()), static_cast<int32_t>(cnst1->GetIntValue())),
582 inst->GetType()));
583 }
584 return graph->FindOrCreateConstant(ConvertIntToInt(
585 std::max(static_cast<int64_t>(cnst0->GetIntValue()), static_cast<int64_t>(cnst1->GetIntValue())),
586 inst->GetType()));
587 }
588 return ConstFoldingCreateIntConst(
589 inst, ConvertIntToInt(std::max(cnst0->GetIntValue(), cnst1->GetIntValue()), inst->GetType()));
590 }
591
ConstFoldingMax(Inst * inst)592 bool ConstFoldingMax(Inst *inst)
593 {
594 ASSERT(inst->GetOpcode() == Opcode::Max);
595 auto input0 = inst->GetInput(0);
596 auto input1 = inst->GetInput(1);
597 auto graph = inst->GetBasicBlock()->GetGraph();
598 if (input0.GetInst()->IsConst() && input1.GetInst()->IsConst()) {
599 auto cnst0 = static_cast<ConstantInst *>(input0.GetInst());
600 auto cnst1 = static_cast<ConstantInst *>(input1.GetInst());
601 ConstantInst *newCnst = nullptr;
602 switch (DataType::GetCommonType(inst->GetType())) {
603 case DataType::INT64:
604 newCnst = ConstFoldingMaxInt(inst, graph, cnst0, cnst1);
605 ASSERT(newCnst != nullptr);
606 break;
607 case DataType::FLOAT32:
608 newCnst = graph->FindOrCreateConstant(
609 ark::helpers::math::Max(cnst0->GetFloatValue(), cnst1->GetFloatValue()));
610 break;
611 case DataType::FLOAT64:
612 newCnst = graph->FindOrCreateConstant(
613 ark::helpers::math::Max(cnst0->GetDoubleValue(), cnst1->GetDoubleValue()));
614 break;
615 default:
616 UNREACHABLE();
617 }
618 inst->ReplaceUsers(newCnst);
619 return true;
620 }
621 return ConstFoldingBinaryMathWithNan(inst);
622 }
623
/**
 * Computes the integer remainder of two constants.
 * Expects a non-zero divisor; ConstFoldingMod rejects a zero divisor before calling this.
 *
 * Signed remainder special-cases MIN % -1, which overflows in C++ (undefined behavior); its
 * mathematical result is 0. In the bytecode optimizer, 32-bit signed results are computed in
 * int32_t and created as uint32_t constants.
 */
ConstantInst *ConstFoldingModIntConst(Graph *graph, Inst *inst, ConstantInst *cnst0, ConstantInst *cnst1)
{
    if (DataType::IsTypeSigned(inst->GetType())) {
        if (graph->IsBytecodeOptimizer() && IsInt32Bit(inst->GetType())) {
            // INT32_MIN % -1 would overflow in C++; the result is 0.
            if (static_cast<int32_t>(cnst0->GetIntValue()) == INT32_MIN &&
                static_cast<int32_t>(cnst1->GetIntValue()) == -1) {
                return graph->FindOrCreateConstant<uint32_t>(0);
            }
            return graph->FindOrCreateConstant<uint32_t>(
                ConvertIntToInt(static_cast<int32_t>(cnst0->GetIntValue()) % static_cast<int32_t>(cnst1->GetIntValue()),
                                inst->GetType()));
        }
        // INT64_MIN % -1 would overflow in C++; the result is 0.
        if (static_cast<int64_t>(cnst0->GetIntValue()) == INT64_MIN &&
            static_cast<int64_t>(cnst1->GetIntValue()) == -1) {
            return graph->FindOrCreateConstant<uint64_t>(0);
        }
        return graph->FindOrCreateConstant(ConvertIntToInt(
            static_cast<int64_t>(cnst0->GetIntValue()) % static_cast<int64_t>(cnst1->GetIntValue()), inst->GetType()));
    }
    // Unsigned: normalize both operands to the result type before taking the remainder.
    // NOTE(review): for narrow unsigned types a non-zero 64-bit divisor could truncate to 0; assumes
    // constants are canonical — confirm.
    return ConstFoldingCreateIntConst(inst, ConvertIntToInt(cnst0->GetIntValue(), inst->GetType()) %
                                                ConvertIntToInt(cnst1->GetIntValue(), inst->GetType()));
}
646
ConstFoldingMod(Inst * inst)647 bool ConstFoldingMod(Inst *inst)
648 {
649 ASSERT(inst->GetOpcode() == Opcode::Mod);
650 auto input0 = inst->GetDataFlowInput(0);
651 auto input1 = inst->GetDataFlowInput(1);
652 auto graph = inst->GetBasicBlock()->GetGraph();
653 if (input1->IsConst() && !DataType::IsFloatType(inst->GetType()) && input1->CastToConstant()->GetIntValue() == 1) {
654 ConstantInst *cnst = ConstFoldingCreateIntConst(inst, 0);
655 inst->ReplaceUsers(cnst);
656 return true;
657 }
658 if (!input0->IsConst() || !input1->IsConst()) {
659 return ConstFoldingBinaryMathWithNan(inst);
660 }
661 ConstantInst *newCnst = nullptr;
662 auto cnst0 = input0->CastToConstant();
663 auto cnst1 = input1->CastToConstant();
664 if (DataType::GetCommonType(inst->GetType()) == DataType::INT64) {
665 if (cnst1->GetIntValue() == 0) {
666 return false;
667 }
668 newCnst = ConstFoldingModIntConst(graph, inst, cnst0, cnst1);
669 } else if (inst->GetType() == DataType::FLOAT32) {
670 if (cnst1->GetFloatValue() == 0) {
671 return false;
672 }
673 newCnst =
674 graph->FindOrCreateConstant(static_cast<float>(fmodf(cnst0->GetFloatValue(), cnst1->GetFloatValue())));
675 } else if (inst->GetType() == DataType::FLOAT64) {
676 if (cnst1->GetDoubleValue() == 0) {
677 return false;
678 }
679 newCnst = graph->FindOrCreateConstant(fmod(cnst0->GetDoubleValue(), cnst1->GetDoubleValue()));
680 }
681 inst->ReplaceUsers(newCnst);
682 return true;
683 }
684
ConstFoldingShl(Inst * inst)685 bool ConstFoldingShl(Inst *inst)
686 {
687 ASSERT(inst->GetOpcode() == Opcode::Shl);
688 auto input0 = inst->GetInput(0);
689 auto input1 = inst->GetInput(1);
690 auto graph = inst->GetBasicBlock()->GetGraph();
691 ASSERT(DataType::GetCommonType(inst->GetType()) == DataType::INT64);
692 if (input0.GetInst()->IsConst() && input1.GetInst()->IsConst()) {
693 auto cnst0 = input0.GetInst()->CastToConstant()->GetIntValue();
694 auto cnst1 = input1.GetInst()->CastToConstant()->GetIntValue();
695 ConstantInst *newCnst = nullptr;
696 uint64_t sizeMask = DataType::GetTypeSize(inst->GetType(), graph->GetArch()) - 1;
697 if (graph->IsBytecodeOptimizer() && IsInt32Bit(inst->GetType())) {
698 newCnst = graph->FindOrCreateConstant<uint32_t>(ConvertIntToInt(
699 static_cast<uint32_t>(cnst0) << (static_cast<uint32_t>(cnst1) & static_cast<uint32_t>(sizeMask)),
700 inst->GetType()));
701 } else {
702 newCnst = graph->FindOrCreateConstant(ConvertIntToInt(cnst0 << (cnst1 & sizeMask), inst->GetType()));
703 }
704 inst->ReplaceUsers(newCnst);
705 return true;
706 }
707 return false;
708 }
709
/**
 * Folds a logical right shift of two integer constants.
 * The shift amount is masked with (bit width of the result type - 1). For result types of at most
 * 32 bits, the bits of the 64-bit payload above the type width are cleared first so that they are
 * not shifted into the result. In the bytecode optimizer, 32-bit shifts are evaluated in uint32_t
 * and created as uint32_t constants.
 * @return true if the instruction was folded
 */
bool ConstFoldingShr(Inst *inst)
{
    ASSERT(inst->GetOpcode() == Opcode::Shr);
    auto input0 = inst->GetInput(0);
    auto input1 = inst->GetInput(1);
    auto graph = inst->GetBasicBlock()->GetGraph();
    ASSERT(DataType::GetCommonType(inst->GetType()) == DataType::INT64);
    if (input0.GetInst()->IsConst() && input1.GetInst()->IsConst()) {
        auto cnst0 = input0.GetInst()->CastToConstant()->GetIntValue();
        auto cnst1 = input1.GetInst()->CastToConstant()->GetIntValue();
        uint64_t sizeMask = DataType::GetTypeSize(inst->GetType(), graph->GetArch()) - 1;
        // zero the high part of the constant (only for types of at most 32 bits, so the
        // shift below never reaches 64)
        if (sizeMask < DataType::GetTypeSize(DataType::INT32, graph->GetArch())) {
            // NOLINTNEXTLINE(clang-analyzer-core.UndefinedBinaryOperatorResult)
            uint64_t typeMask = (1ULL << (sizeMask + 1)) - 1;
            cnst0 = cnst0 & typeMask;
        }
        ConstantInst *newCnst = nullptr;
        if (graph->IsBytecodeOptimizer() && IsInt32Bit(inst->GetType())) {
            newCnst = graph->FindOrCreateConstant<uint32_t>(ConvertIntToInt(
                static_cast<uint32_t>(cnst0) >> (static_cast<uint32_t>(cnst1) & static_cast<uint32_t>(sizeMask)),
                inst->GetType()));
        } else {
            newCnst = graph->FindOrCreateConstant(ConvertIntToInt(cnst0 >> (cnst1 & sizeMask), inst->GetType()));
        }
        inst->ReplaceUsers(newCnst);
        return true;
    }
    return false;
}
740
/**
 * Folds an arithmetic right shift of two integer constants.
 * The value is shifted as a signed integer (int32_t in the bytecode-optimizer 32-bit case, int64_t
 * otherwise), and the shift amount is masked with (bit width of the result type - 1).
 * @return true if the instruction was folded
 */
bool ConstFoldingAShr(Inst *inst)
{
    ASSERT(inst->GetOpcode() == Opcode::AShr);
    auto input0 = inst->GetInput(0);
    auto input1 = inst->GetInput(1);
    auto graph = inst->GetBasicBlock()->GetGraph();
    ASSERT(DataType::GetCommonType(inst->GetType()) == DataType::INT64);
    if (input0.GetInst()->IsConst() && input1.GetInst()->IsConst()) {
        // Signed storage so that `>>` below propagates the sign bit.
        int64_t cnst0 = input0.GetInst()->CastToConstant()->GetIntValue();
        auto cnst1 = input1.GetInst()->CastToConstant()->GetIntValue();
        uint64_t sizeMask = DataType::GetTypeSize(inst->GetType(), graph->GetArch()) - 1;
        ConstantInst *newCnst = nullptr;
        if (graph->IsBytecodeOptimizer() && IsInt32Bit(inst->GetType())) {
            newCnst = graph->FindOrCreateConstant<uint32_t>(
                // NOLINTNEXTLINE(hicpp-signed-bitwise)
                ConvertIntToInt(static_cast<int32_t>(cnst0) >>
                                    (static_cast<uint32_t>(cnst1) & static_cast<uint32_t>(sizeMask)),
                                inst->GetType()));
        } else {
            newCnst = graph->FindOrCreateConstant(
                // NOLINTNEXTLINE(hicpp-signed-bitwise)
                ConvertIntToInt(cnst0 >> (cnst1 & sizeMask), inst->GetType()));
        }
        inst->ReplaceUsers(newCnst);
        return true;
    }
    return false;
}
769
ConstFoldingAnd(Inst * inst)770 bool ConstFoldingAnd(Inst *inst)
771 {
772 ASSERT(inst->GetOpcode() == Opcode::And);
773 auto input0 = inst->GetInput(0);
774 auto input1 = inst->GetInput(1);
775 ConstantInst *newCnst = nullptr;
776 ASSERT(DataType::GetCommonType(inst->GetType()) == DataType::INT64);
777 if (input0.GetInst()->IsConst() && input1.GetInst()->IsConst()) {
778 auto cnst0 = static_cast<ConstantInst *>(input0.GetInst());
779 auto cnst1 = static_cast<ConstantInst *>(input1.GetInst());
780 newCnst = ConstFoldingCreateIntConst(
781 inst, ConvertIntToInt(cnst0->GetIntValue() & cnst1->GetIntValue(), inst->GetType()));
782 inst->ReplaceUsers(newCnst);
783 return true;
784 }
785 if (input0.GetInst()->IsConst()) {
786 newCnst = static_cast<ConstantInst *>(input0.GetInst());
787 } else if (input1.GetInst()->IsConst()) {
788 newCnst = static_cast<ConstantInst *>(input1.GetInst());
789 }
790 if (newCnst != nullptr && newCnst->GetIntValue() == 0) {
791 inst->ReplaceUsers(newCnst);
792 return true;
793 }
794 return false;
795 }
796
ConstFoldingOr(Inst * inst)797 bool ConstFoldingOr(Inst *inst)
798 {
799 ASSERT(inst->GetOpcode() == Opcode::Or);
800 auto input0 = inst->GetInput(0);
801 auto input1 = inst->GetInput(1);
802 ConstantInst *newCnst = nullptr;
803 ASSERT(DataType::GetCommonType(inst->GetType()) == DataType::INT64);
804 if (input0.GetInst()->IsConst() && input1.GetInst()->IsConst()) {
805 auto cnst0 = static_cast<ConstantInst *>(input0.GetInst());
806 auto cnst1 = static_cast<ConstantInst *>(input1.GetInst());
807 newCnst = ConstFoldingCreateIntConst(
808 inst, ConvertIntToInt(cnst0->GetIntValue() | cnst1->GetIntValue(), inst->GetType()));
809 inst->ReplaceUsers(newCnst);
810 return true;
811 }
812 if (input0.GetInst()->IsConst()) {
813 newCnst = static_cast<ConstantInst *>(input0.GetInst());
814 } else if (input1.GetInst()->IsConst()) {
815 newCnst = static_cast<ConstantInst *>(input1.GetInst());
816 }
817 if (newCnst != nullptr && newCnst->GetIntValue() == static_cast<uint64_t>(-1)) {
818 inst->ReplaceUsers(newCnst);
819 return true;
820 }
821 return false;
822 }
823
ConstFoldingXor(Inst * inst)824 bool ConstFoldingXor(Inst *inst)
825 {
826 ASSERT(inst->GetOpcode() == Opcode::Xor);
827 auto input0 = inst->GetInput(0).GetInst();
828 auto input1 = inst->GetInput(1).GetInst();
829 ASSERT(DataType::GetCommonType(inst->GetType()) == DataType::INT64);
830 if (input0->IsConst() && input1->IsConst()) {
831 auto cnst0 = static_cast<ConstantInst *>(input0);
832 auto cnst1 = static_cast<ConstantInst *>(input1);
833 ConstantInst *newCnst = nullptr;
834 newCnst = ConstFoldingCreateIntConst(
835 inst, ConvertIntToInt(cnst0->GetIntValue() ^ cnst1->GetIntValue(), inst->GetType()));
836 inst->ReplaceUsers(newCnst);
837 return true;
838 }
839 // A xor A = 0
840 if (input0 == input1) {
841 inst->ReplaceUsers(ConstFoldingCreateIntConst(inst, 0));
842 return true;
843 }
844 return false;
845 }
846
847 template <class T>
GetResult(T l,T r,const CmpInst * cmp)848 int64_t GetResult(T l, T r, [[maybe_unused]] const CmpInst *cmp)
849 {
850 // NOLINTNEXTLINE(readability-braces-around-statements, bugprone-suspicious-semicolon)
851 if constexpr (std::is_same<T, float>() || std::is_same<T, double>()) {
852 ASSERT(DataType::IsFloatType(cmp->GetInputType(0)));
853 if (std::isnan(l) || std::isnan(r)) {
854 if (cmp->IsFcmpg()) {
855 return 1;
856 }
857 return -1;
858 }
859 }
860 if (l > r) {
861 return 1;
862 }
863 if (l < r) {
864 return -1;
865 }
866 return 0;
867 }
868
GetIntResult(ConstantInst * cnst0,ConstantInst * cnst1,DataType::Type inputType,const CmpInst * cmp)869 int64_t GetIntResult(ConstantInst *cnst0, ConstantInst *cnst1, DataType::Type inputType, const CmpInst *cmp)
870 {
871 auto l = ConvertIntToInt(cnst0->GetIntValue(), inputType);
872 auto r = ConvertIntToInt(cnst1->GetIntValue(), inputType);
873 auto graph = cnst0->GetBasicBlock()->GetGraph();
874 if (DataType::IsTypeSigned(inputType)) {
875 if (graph->IsBytecodeOptimizer() && IsInt32Bit(inputType)) {
876 return GetResult(static_cast<int32_t>(l), static_cast<int32_t>(r), cmp);
877 }
878 return GetResult(static_cast<int64_t>(l), static_cast<int64_t>(r), cmp);
879 }
880 if (graph->IsBytecodeOptimizer() && IsInt32Bit(inputType)) {
881 return GetResult(static_cast<uint32_t>(l), static_cast<uint32_t>(r), cmp);
882 }
883 return GetResult(l, r, cmp);
884 }
885
ConstFoldingCmpFloatNan(Inst * inst)886 bool ConstFoldingCmpFloatNan(Inst *inst)
887 {
888 ASSERT(inst->GetOpcode() == Opcode::Cmp);
889 auto input0 = inst->GetInput(0).GetInst();
890 auto input1 = inst->GetInput(1).GetInst();
891 if (!input0->IsConst() && !input1->IsConst()) {
892 return false;
893 }
894
895 if (!DataType::IsFloatType(inst->CastToCmp()->GetOperandsType())) {
896 return false;
897 }
898
899 // One of the constant always will be in input1
900 if (input0->IsConst()) {
901 std::swap(input0, input1);
902 }
903
904 // For Float constant is applied only NaN cases
905 if (!input1->CastToConstant()->IsNaNConst()) {
906 return false;
907 }
908 // Result related with Fcmpg as wrote in spec
909 int64_t res {-1};
910 if (inst->CastToCmp()->IsFcmpg()) {
911 res = 1;
912 }
913 inst->ReplaceUsers(ConstFoldingCreateIntConst(inst, res));
914 return true;
915 }
916
ConstFoldingCmp(Inst * inst)917 bool ConstFoldingCmp(Inst *inst)
918 {
919 ASSERT(inst->GetOpcode() == Opcode::Cmp);
920 auto input0 = inst->GetInput(0).GetInst();
921 auto input1 = inst->GetInput(1).GetInst();
922 auto cmp = inst->CastToCmp();
923 auto inputType = cmp->GetInputType(0);
924 if (ConstFoldingCmpFloatNan(inst)) {
925 return true;
926 }
927 if (input0->IsConst() && input1->IsConst()) {
928 auto cnst0 = static_cast<ConstantInst *>(input0);
929 auto cnst1 = static_cast<ConstantInst *>(input1);
930 int64_t result = 0;
931 switch (DataType::GetCommonType(inputType)) {
932 case DataType::INT64: {
933 result = GetIntResult(cnst0, cnst1, inputType, cmp);
934 break;
935 }
936 case DataType::FLOAT32:
937 result = GetResult(cnst0->GetFloatValue(), cnst1->GetFloatValue(), cmp);
938 break;
939 case DataType::FLOAT64:
940 result = GetResult(cnst0->GetDoubleValue(), cnst1->GetDoubleValue(), cmp);
941 break;
942 default:
943 break;
944 }
945 auto newCnst = inst->GetBasicBlock()->GetGraph()->FindOrCreateConstant(result);
946 inst->ReplaceUsers(newCnst);
947 return true;
948 }
949 if (input0 == input1 && DataType::GetCommonType(inputType) == DataType::INT64) {
950 // for floating point values result may be non-zero if x is NaN
951 inst->ReplaceUsers(ConstFoldingCreateIntConst(inst, 0));
952 return true;
953 }
954 return false;
955 }
956
ConstFoldingCompareCreateConst(Inst * inst,bool value)957 ConstantInst *ConstFoldingCompareCreateConst(Inst *inst, bool value)
958 {
959 auto graph = inst->GetBasicBlock()->GetGraph();
960 if (graph->IsBytecodeOptimizer() && IsInt32Bit(inst->GetType())) {
961 return graph->FindOrCreateConstant(static_cast<uint32_t>(value));
962 }
963 return graph->FindOrCreateConstant(static_cast<uint64_t>(value));
964 }
965
ConstFoldingCompareCreateNewConst(Inst * inst,uint64_t cnstVal0,uint64_t cnstVal1)966 ConstantInst *ConstFoldingCompareCreateNewConst(Inst *inst, uint64_t cnstVal0, uint64_t cnstVal1)
967 {
968 switch (inst->CastToCompare()->GetCc()) {
969 case ConditionCode::CC_EQ:
970 return ConstFoldingCompareCreateConst(inst, (cnstVal0 == cnstVal1));
971 case ConditionCode::CC_NE:
972 return ConstFoldingCompareCreateConst(inst, (cnstVal0 != cnstVal1));
973 case ConditionCode::CC_LT:
974 return ConstFoldingCompareCreateConst(inst,
975 (static_cast<int64_t>(cnstVal0) < static_cast<int64_t>(cnstVal1)));
976 case ConditionCode::CC_B:
977 return ConstFoldingCompareCreateConst(inst, (cnstVal0 < cnstVal1));
978 case ConditionCode::CC_LE:
979 return ConstFoldingCompareCreateConst(inst,
980 (static_cast<int64_t>(cnstVal0) <= static_cast<int64_t>(cnstVal1)));
981 case ConditionCode::CC_BE:
982 return ConstFoldingCompareCreateConst(inst, (cnstVal0 <= cnstVal1));
983 case ConditionCode::CC_GT:
984 return ConstFoldingCompareCreateConst(inst,
985 (static_cast<int64_t>(cnstVal0) > static_cast<int64_t>(cnstVal1)));
986 case ConditionCode::CC_A:
987 return ConstFoldingCompareCreateConst(inst, (cnstVal0 > cnstVal1));
988 case ConditionCode::CC_GE:
989 return ConstFoldingCompareCreateConst(inst,
990 (static_cast<int64_t>(cnstVal0) >= static_cast<int64_t>(cnstVal1)));
991 case ConditionCode::CC_AE:
992 return ConstFoldingCompareCreateConst(inst, (cnstVal0 >= cnstVal1));
993 case ConditionCode::CC_TST_EQ:
994 return ConstFoldingCompareCreateConst(inst, ((cnstVal0 & cnstVal1) == 0));
995 case ConditionCode::CC_TST_NE:
996 return ConstFoldingCompareCreateConst(inst, ((cnstVal0 & cnstVal1) != 0));
997 default:
998 UNREACHABLE();
999 }
1000 }
1001
ConstFoldingCompareEqualInputs(Inst * inst,Inst * input0,Inst * input1)1002 bool ConstFoldingCompareEqualInputs(Inst *inst, Inst *input0, Inst *input1)
1003 {
1004 if (input0 != input1) {
1005 return false;
1006 }
1007 auto cmpInst = inst->CastToCompare();
1008 auto commonType = DataType::GetCommonType(input0->GetType());
1009 switch (cmpInst->GetCc()) {
1010 case ConditionCode::CC_EQ:
1011 case ConditionCode::CC_LE:
1012 case ConditionCode::CC_GE:
1013 case ConditionCode::CC_BE:
1014 case ConditionCode::CC_AE:
1015 // for floating point values result may be non-zero if x is NaN
1016 if (commonType == DataType::INT64 || commonType == DataType::POINTER) {
1017 inst->ReplaceUsers(ConstFoldingCreateIntConst(inst, 1));
1018 return true;
1019 }
1020 break;
1021 case ConditionCode::CC_NE:
1022 if (commonType == DataType::INT64 || commonType == DataType::POINTER) {
1023 inst->ReplaceUsers(ConstFoldingCreateIntConst(inst, 0));
1024 return true;
1025 }
1026 break;
1027 case ConditionCode::CC_LT:
1028 case ConditionCode::CC_GT:
1029 case ConditionCode::CC_B:
1030 case ConditionCode::CC_A:
1031 // x<x is false even for x=NaN
1032 inst->ReplaceUsers(ConstFoldingCreateIntConst(inst, 0));
1033 return true;
1034 default:
1035 return false;
1036 }
1037 return false;
1038 }
1039
IsUniqueRef(Inst * inst)1040 static bool IsUniqueRef(Inst *inst)
1041 {
1042 return inst->IsAllocation() || inst->GetOpcode() == Opcode::NullPtr ||
1043 inst->GetOpcode() == Opcode::LoadUniqueObject;
1044 }
1045
ConstFoldingCompareFloatNan(Inst * inst)1046 bool ConstFoldingCompareFloatNan(Inst *inst)
1047 {
1048 ASSERT(DataType::IsFloatType(inst->CastToCompare()->GetOperandsType()));
1049 auto input0 = inst->GetInput(0).GetInst();
1050 auto input1 = inst->GetInput(1).GetInst();
1051 if (!input0->IsConst() && !input1->IsConst()) {
1052 return false;
1053 }
1054
1055 // One of the constant always will be in input1
1056 if (input0->IsConst()) {
1057 std::swap(input0, input1);
1058 }
1059
1060 // For Float constant is applied only NaN cases
1061 if (!input1->CastToConstant()->IsNaNConst()) {
1062 return false;
1063 }
1064
1065 // If both operands is NaN constant - it is OK, all optimization will be applied anyway
1066 bool resultConst {};
1067 // We shouldn't reverse ConditionCode, because the results is not related to order of inputs
1068 switch (inst->CastToCompare()->GetCc()) {
1069 case ConditionCode::CC_NE:
1070 // NaN != number is true
1071 resultConst = true;
1072 break;
1073 case ConditionCode::CC_EQ: // ==
1074 case ConditionCode::CC_LT: // <
1075 case ConditionCode::CC_LE: // <=
1076 case ConditionCode::CC_GT: // >
1077 case ConditionCode::CC_GE: // >=
1078 // All these CC with NaN give false
1079 resultConst = false;
1080 break;
1081 default:
1082 UNREACHABLE();
1083 }
1084 inst->ReplaceUsers(ConstFoldingCompareCreateConst(inst, resultConst));
1085 return true;
1086 }
1087
ConstFoldingCompareIntConstant(Inst * inst)1088 bool ConstFoldingCompareIntConstant(Inst *inst)
1089 {
1090 ASSERT(!DataType::IsFloatType(inst->CastToCompare()->GetOperandsType()));
1091 auto input0 = inst->GetInput(0).GetInst();
1092 auto input1 = inst->GetInput(1).GetInst();
1093 if (!input0->IsConst() || !input1->IsConst()) {
1094 return false;
1095 }
1096
1097 auto cnst0 = input0->CastToConstant();
1098 auto cnst1 = input1->CastToConstant();
1099 ConstantInst *newCnst = nullptr;
1100 auto type = inst->GetInputType(0);
1101 if (DataType::GetCommonType(type) == DataType::INT64) {
1102 uint64_t cnstVal0 = ConvertIntToInt(cnst0->GetIntValue(), type);
1103 uint64_t cnstVal1 = ConvertIntToInt(cnst1->GetIntValue(), type);
1104 newCnst = ConstFoldingCompareCreateNewConst(inst, cnstVal0, cnstVal1);
1105 inst->ReplaceUsers(newCnst);
1106 return true;
1107 }
1108 return false;
1109 }
1110
ConstFoldingCompare(Inst * inst)1111 bool ConstFoldingCompare(Inst *inst)
1112 {
1113 ASSERT(inst->GetOpcode() == Opcode::Compare);
1114 auto input0 = inst->GetInput(0).GetInst();
1115 auto input1 = inst->GetInput(1).GetInst();
1116
1117 if (DataType::IsFloatType(inst->CastToCompare()->GetOperandsType())) {
1118 if (ConstFoldingCompareFloatNan(inst)) {
1119 return true;
1120 }
1121 } else {
1122 if (ConstFoldingCompareIntConstant(inst)) {
1123 return true;
1124 }
1125 }
1126 if (input0->GetOpcode() == Opcode::LoadImmediate && input1->GetOpcode() == Opcode::LoadImmediate) {
1127 auto class0 = input0->CastToLoadImmediate()->GetObject();
1128 auto class1 = input1->CastToLoadImmediate()->GetObject();
1129 auto cc = inst->CastToCompare()->GetCc();
1130 ASSERT(cc == CC_NE || cc == CC_EQ);
1131 bool res {(class0 == class1) == (cc == CC_EQ)};
1132 inst->ReplaceUsers(ConstFoldingCompareCreateConst(inst, res));
1133 return true;
1134 }
1135 if (IsZeroConstantOrNullPtr(input0) && IsZeroConstantOrNullPtr(input1)) {
1136 auto cc = inst->CastToCompare()->GetCc();
1137 ASSERT(cc == CC_NE || cc == CC_EQ);
1138 inst->ReplaceUsers(ConstFoldingCompareCreateConst(inst, cc == CC_EQ));
1139 return true;
1140 }
1141 if (ConstFoldingCompareEqualInputs(inst, input0, input1)) {
1142 return true;
1143 }
1144 if (inst->GetInputType(0) == DataType::REFERENCE) {
1145 ASSERT(input0 != input1);
1146 if (IsUniqueRef(input0) && IsUniqueRef(input1)) {
1147 auto cc = inst->CastToCompare()->GetCc();
1148 inst->ReplaceUsers(ConstFoldingCompareCreateConst(inst, cc == CC_NE));
1149 return true;
1150 }
1151 }
1152 return false;
1153 }
1154
ConstFoldingSqrt(Inst * inst)1155 bool ConstFoldingSqrt(Inst *inst)
1156 {
1157 ASSERT(inst->GetOpcode() == Opcode::Sqrt);
1158 auto input = inst->GetInput(0).GetInst();
1159 if (input->IsConst()) {
1160 auto cnst = input->CastToConstant();
1161 Inst *newCnst = nullptr;
1162 if (cnst->GetType() == DataType::FLOAT32) {
1163 newCnst = inst->GetBasicBlock()->GetGraph()->FindOrCreateConstant(std::sqrt(cnst->GetFloatValue()));
1164 } else {
1165 ASSERT(cnst->GetType() == DataType::FLOAT64);
1166 newCnst = inst->GetBasicBlock()->GetGraph()->FindOrCreateConstant(std::sqrt(cnst->GetDoubleValue()));
1167 }
1168 inst->ReplaceUsers(newCnst);
1169 return true;
1170 }
1171 return false;
1172 }
1173
ConstFoldingLoadStatic(Inst * inst)1174 bool ConstFoldingLoadStatic(Inst *inst)
1175 {
1176 auto field = inst->CastToLoadStatic()->GetObjField();
1177 ASSERT(field != nullptr);
1178 auto isFloatType = DataType::IsFloatType(inst->GetType());
1179 if (DataType::GetCommonType(inst->GetType()) != DataType::INT64 && !isFloatType) {
1180 return false;
1181 }
1182 auto graph = inst->GetBasicBlock()->GetGraph();
1183 auto runtime = graph->GetRuntime();
1184 auto klass = runtime->GetClassForField(field);
1185 if (runtime->IsFieldReadonly(field) && runtime->IsClassInitialized(reinterpret_cast<uintptr_t>(klass))) {
1186 if (isFloatType) {
1187 auto value = runtime->GetStaticFieldFloatValue(field);
1188 if (inst->GetType() == DataType::FLOAT32) {
1189 inst->ReplaceUsers(graph->FindOrCreateConstant(static_cast<float>(value)));
1190 } else {
1191 ASSERT(inst->GetType() == DataType::FLOAT64);
1192 inst->ReplaceUsers(graph->FindOrCreateConstant(value));
1193 }
1194 } else {
1195 auto value = runtime->GetStaticFieldIntegerValue(field);
1196 inst->ReplaceUsers(graph->FindOrCreateConstant(ConvertIntToInt(value, inst->GetType())));
1197 }
1198 return true;
1199 }
1200 return false;
1201 }
1202
1203 } // namespace ark::compiler
1204