1 /*
2 * Copyright (c) 2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "infer_flags.h"
17
18 #include <llvm/Analysis/LoopInfo.h>
19 #include <llvm/Analysis/ScalarEvolution.h>
20 #include <llvm/Analysis/ScalarEvolutionExpressions.h>
21 #include <llvm/Analysis/MemoryBuiltins.h>
22 #include <llvm/Analysis/ValueTracking.h>
23 #include <llvm/IR/Dominators.h>
24 #include <llvm/IR/Instructions.h>
25 #include <llvm/IR/InstIterator.h>
26 #include <llvm/IR/Operator.h>
27 #include <llvm/Support/KnownBits.h>
28
29 #include "transforms/transform_utils.h"
30
31 #define DEBUG_TYPE "infer-flags"
32
33 namespace {
34
CanOverflow(const llvm::KnownBits & start,const llvm::KnownBits & step,uint64_t tripCount)35 bool CanOverflow(const llvm::KnownBits &start, const llvm::KnownBits &step, uint64_t tripCount)
36 {
37 ASSERT(start.getBitWidth() == step.getBitWidth());
38 // Check overflow for simple recurrence like:
39 // i32 v0 = phi i32 [start bb0, v1 bb1]
40 // Where:
41 // v1 = op v0, step
42
43 // Map range [stepMin; stepMax) to [stepMin * tripCount; step * tripCount)
44 auto tripByStep = llvm::ConstantRange::fromKnownBits(
45 llvm::KnownBits::mul(step, llvm::KnownBits::makeConstant(llvm::APInt {step.getBitWidth(), tripCount})), true);
46
47 // Get signed ranges for step, and start
48 auto stepRange = llvm::ConstantRange::fromKnownBits(step, true);
49 auto startRange = llvm::ConstantRange::fromKnownBits(start, true);
50 // Check actual overflow
51 bool overflow = startRange.signedAddMayOverflow(tripByStep) != llvm::ConstantRange::OverflowResult::NeverOverflows;
52 LLVM_DEBUG(llvm::dbgs() << "stepRange = " << stepRange << ", startRange = " << startRange
53 << ", tripByStep = " << tripByStep << ", tripCount = " << tripCount
54 << ", canOverflow = " << llvm::toStringRef(overflow) << "\n");
55 return overflow;
56 }
57 } // namespace
58
59 namespace ark::llvmbackend::passes {
60
ShouldInsert(const ark::llvmbackend::LLVMCompilerOptions * options)61 bool InferFlags::ShouldInsert([[maybe_unused]] const ark::llvmbackend::LLVMCompilerOptions *options)
62 {
63 return true;
64 }
65
run(llvm::Function & function,llvm::FunctionAnalysisManager & analysisManager)66 llvm::PreservedAnalyses InferFlags::run(llvm::Function &function, llvm::FunctionAnalysisManager &analysisManager)
67 {
68 LLVM_DEBUG(llvm::dbgs() << "Running on '" << function.getName() << "'\n");
69
70 bool changed = false;
71
72 auto &scalarEvolution = analysisManager.getResult<llvm::ScalarEvolutionAnalysis>(function);
73 auto &loopAnalysis = analysisManager.getResult<llvm::LoopAnalysis>(function);
74
75 for (auto &loop : loopAnalysis) {
76 changed |= RunOnLoop(loop, &scalarEvolution);
77 }
78
79 return changed ? llvm::PreservedAnalyses::none() : llvm::PreservedAnalyses::all();
80 }
81
RunOnLoop(llvm::Loop * loop,llvm::ScalarEvolution * scalarEvolution)82 bool InferFlags::RunOnLoop(llvm::Loop *loop, llvm::ScalarEvolution *scalarEvolution)
83 {
84 bool changed = false;
85 for (auto basicBlock : loop->blocks()) {
86 changed |= RunOnBasicBlock(loop, basicBlock, scalarEvolution);
87 }
88 return changed;
89 }
90
RunOnBasicBlock(llvm::Loop * loop,llvm::BasicBlock * basicBlock,llvm::ScalarEvolution * scalarEvolution)91 bool InferFlags::RunOnBasicBlock(llvm::Loop *loop, llvm::BasicBlock *basicBlock, llvm::ScalarEvolution *scalarEvolution)
92 {
93 bool changed = false;
94
95 for (auto &phi : basicBlock->phis()) {
96 if (!scalarEvolution->isSCEVable(phi.getType())) {
97 continue;
98 }
99 llvm::BinaryOperator *binaryOperator;
100 llvm::Value *step;
101 llvm::Value *start;
102
103 LLVM_DEBUG(llvm::dbgs() << "Trying to match simple recurrence for phi '" << phi << "'\n");
104 if (!llvm::matchSimpleRecurrence(&phi, binaryOperator, start, step)) {
105 continue;
106 }
107 LLVM_DEBUG(llvm::dbgs() << "Matched simple recurrence '" << *binaryOperator << "'\n");
108 if (!llvm::isa<llvm::OverflowingBinaryOperator>(binaryOperator)) {
109 continue;
110 }
111 // Support only add because sub is untested
112 if (binaryOperator->getOpcode() != llvm::Instruction::Add) {
113 continue;
114 }
115
116 auto tripCount = scalarEvolution->getSmallConstantMaxTripCount(loop);
117 LLVM_DEBUG(llvm::dbgs() << "tripCount = '" << tripCount << "'\n");
118 if (tripCount == 0) {
119 continue;
120 }
121 // Now infer range
122 auto dataLayout = basicBlock->getModule()->getDataLayout();
123 auto knownStart = llvm::computeKnownBits(start, dataLayout);
124 auto knownStep = llvm::computeKnownBits(step, dataLayout);
125 if (knownStart.isUnknown() || knownStep.isUnknown()) {
126 LLVM_DEBUG(llvm::dbgs() << "Start or step is unknown\n");
127 continue;
128 }
129 ASSERT(!phi.getType()->isPointerTy());
130 if (!binaryOperator->hasNoSignedWrap() && !CanOverflow(knownStart, knownStep, tripCount)) {
131 LLVM_DEBUG(llvm::dbgs() << "Set nsw to '" << *binaryOperator << "'\n");
132 binaryOperator->setHasNoSignedWrap(true);
133 changed = true;
134 }
135 }
136 return changed;
137 }
138
139 } // namespace ark::llvmbackend::passes
140