• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  * http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "infer_flags.h"
17 
18 #include <llvm/Analysis/LoopInfo.h>
19 #include <llvm/Analysis/ScalarEvolution.h>
20 #include <llvm/Analysis/ScalarEvolutionExpressions.h>
21 #include <llvm/Analysis/MemoryBuiltins.h>
22 #include <llvm/Analysis/ValueTracking.h>
23 #include <llvm/IR/Dominators.h>
24 #include <llvm/IR/Instructions.h>
25 #include <llvm/IR/InstIterator.h>
26 #include <llvm/IR/Operator.h>
27 #include <llvm/Support/KnownBits.h>
28 
29 #include "transforms/transform_utils.h"
30 
31 #define DEBUG_TYPE "infer-flags"
32 
33 namespace {
34 
CanOverflow(const llvm::KnownBits & start,const llvm::KnownBits & step,uint64_t tripCount)35 bool CanOverflow(const llvm::KnownBits &start, const llvm::KnownBits &step, uint64_t tripCount)
36 {
37     ASSERT(start.getBitWidth() == step.getBitWidth());
38     // Check overflow for simple recurrence like:
39     //     i32 v0 = phi i32 [start bb0, v1 bb1]
40     // Where:
41     //     v1 = op v0, step
42 
43     // Map range [stepMin; stepMax) to [stepMin * tripCount; step * tripCount)
44     auto tripByStep = llvm::ConstantRange::fromKnownBits(
45         llvm::KnownBits::mul(step, llvm::KnownBits::makeConstant(llvm::APInt {step.getBitWidth(), tripCount})), true);
46 
47     // Get signed ranges for step, and start
48     auto stepRange = llvm::ConstantRange::fromKnownBits(step, true);
49     auto startRange = llvm::ConstantRange::fromKnownBits(start, true);
50     // Check actual overflow
51     bool overflow = startRange.signedAddMayOverflow(tripByStep) != llvm::ConstantRange::OverflowResult::NeverOverflows;
52     LLVM_DEBUG(llvm::dbgs() << "stepRange = " << stepRange << ", startRange = " << startRange
53                             << ", tripByStep = " << tripByStep << ", tripCount = " << tripCount
54                             << ", canOverflow = " << llvm::toStringRef(overflow) << "\n");
55     return overflow;
56 }
57 }  // namespace
58 
59 namespace ark::llvmbackend::passes {
60 
ShouldInsert(const ark::llvmbackend::LLVMCompilerOptions * options)61 bool InferFlags::ShouldInsert([[maybe_unused]] const ark::llvmbackend::LLVMCompilerOptions *options)
62 {
63     return true;
64 }
65 
run(llvm::Function & function,llvm::FunctionAnalysisManager & analysisManager)66 llvm::PreservedAnalyses InferFlags::run(llvm::Function &function, llvm::FunctionAnalysisManager &analysisManager)
67 {
68     LLVM_DEBUG(llvm::dbgs() << "Running on '" << function.getName() << "'\n");
69 
70     bool changed = false;
71 
72     auto &scalarEvolution = analysisManager.getResult<llvm::ScalarEvolutionAnalysis>(function);
73     auto &loopAnalysis = analysisManager.getResult<llvm::LoopAnalysis>(function);
74 
75     for (auto &loop : loopAnalysis) {
76         changed |= RunOnLoop(loop, &scalarEvolution);
77     }
78 
79     return changed ? llvm::PreservedAnalyses::none() : llvm::PreservedAnalyses::all();
80 }
81 
RunOnLoop(llvm::Loop * loop,llvm::ScalarEvolution * scalarEvolution)82 bool InferFlags::RunOnLoop(llvm::Loop *loop, llvm::ScalarEvolution *scalarEvolution)
83 {
84     bool changed = false;
85     for (auto basicBlock : loop->blocks()) {
86         changed |= RunOnBasicBlock(loop, basicBlock, scalarEvolution);
87     }
88     return changed;
89 }
90 
RunOnBasicBlock(llvm::Loop * loop,llvm::BasicBlock * basicBlock,llvm::ScalarEvolution * scalarEvolution)91 bool InferFlags::RunOnBasicBlock(llvm::Loop *loop, llvm::BasicBlock *basicBlock, llvm::ScalarEvolution *scalarEvolution)
92 {
93     bool changed = false;
94 
95     for (auto &phi : basicBlock->phis()) {
96         if (!scalarEvolution->isSCEVable(phi.getType())) {
97             continue;
98         }
99         llvm::BinaryOperator *binaryOperator;
100         llvm::Value *step;
101         llvm::Value *start;
102 
103         LLVM_DEBUG(llvm::dbgs() << "Trying to match simple  recurrence for phi '" << phi << "'\n");
104         if (!llvm::matchSimpleRecurrence(&phi, binaryOperator, start, step)) {
105             continue;
106         }
107         LLVM_DEBUG(llvm::dbgs() << "Matched simple recurrence '" << *binaryOperator << "'\n");
108         if (!llvm::isa<llvm::OverflowingBinaryOperator>(binaryOperator)) {
109             continue;
110         }
111         // Support only add because sub is untested
112         if (binaryOperator->getOpcode() != llvm::Instruction::Add) {
113             continue;
114         }
115 
116         auto tripCount = scalarEvolution->getSmallConstantMaxTripCount(loop);
117         LLVM_DEBUG(llvm::dbgs() << "tripCount = '" << tripCount << "'\n");
118         if (tripCount == 0) {
119             continue;
120         }
121         // Now infer range
122         auto dataLayout = basicBlock->getModule()->getDataLayout();
123         auto knownStart = llvm::computeKnownBits(start, dataLayout);
124         auto knownStep = llvm::computeKnownBits(step, dataLayout);
125         if (knownStart.isUnknown() || knownStep.isUnknown()) {
126             LLVM_DEBUG(llvm::dbgs() << "Start or step is unknown\n");
127             continue;
128         }
129         ASSERT(!phi.getType()->isPointerTy());
130         if (!binaryOperator->hasNoSignedWrap() && !CanOverflow(knownStart, knownStep, tripCount)) {
131             LLVM_DEBUG(llvm::dbgs() << "Set nsw to '" << *binaryOperator << "'\n");
132             binaryOperator->setHasNoSignedWrap(true);
133             changed = true;
134         }
135     }
136     return changed;
137 }
138 
139 }  // namespace ark::llvmbackend::passes
140