1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/opt/strength_reduction_pass.h"
16
17 #include <algorithm>
18 #include <cstdio>
19 #include <cstring>
20 #include <memory>
21 #include <unordered_map>
22 #include <unordered_set>
23 #include <utility>
24 #include <vector>
25
26 #include "source/opt/def_use_manager.h"
27 #include "source/opt/ir_context.h"
28 #include "source/opt/log.h"
29 #include "source/opt/reflect.h"
30
31 namespace spvtools {
32 namespace opt {
33 namespace {
// Count the number of trailing zero bits in the binary representation of
// |constVal|. For example, 8 (0b1000) yields 3.
//
// A zero input has no set bit to stop at, so it reports the full width of the
// type (32) instead of looping forever. Within this pass the function is only
// called on values that already passed IsPowerOf2, so that case is defensive.
uint32_t CountTrailingZeros(uint32_t constVal) {
  if (constVal == 0) return 32;

  // A hardware count-trailing-zeros instruction or a lookup table would be
  // faster; a simple shift loop is sufficient for the values seen here.
  uint32_t shiftAmount = 0;
  while ((constVal & 1u) == 0) {
    ++shiftAmount;
    constVal >>= 1;
  }
  return shiftAmount;
}
46
// Returns true exactly when |val| has a single bit set, i.e. is a power of 2.
// Subtracting 1 flips the lowest set bit and every bit below it, so ANDing
// with the original clears that lowest bit. The result is zero only when no
// other bit was set. Zero itself has no set bit and is rejected first.
bool IsPowerOf2(uint32_t val) {
  return val != 0 && (val & (val - 1)) == 0;
}
55
56 } // namespace
57
Process()58 Pass::Status StrengthReductionPass::Process() {
59 // Initialize the member variables on a per module basis.
60 bool modified = false;
61 int32_type_id_ = 0;
62 uint32_type_id_ = 0;
63 std::memset(constant_ids_, 0, sizeof(constant_ids_));
64
65 FindIntTypesAndConstants();
66 modified = ScanFunctions();
67 return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
68 }
69
ReplaceMultiplyByPowerOf2(BasicBlock::iterator * inst)70 bool StrengthReductionPass::ReplaceMultiplyByPowerOf2(
71 BasicBlock::iterator* inst) {
72 assert((*inst)->opcode() == spv::Op::OpIMul &&
73 "Only works for multiplication of integers.");
74 bool modified = false;
75
76 // Currently only works on 32-bit integers.
77 if ((*inst)->type_id() != int32_type_id_ &&
78 (*inst)->type_id() != uint32_type_id_) {
79 return modified;
80 }
81
82 // Check the operands for a constant that is a power of 2.
83 for (int i = 0; i < 2; i++) {
84 uint32_t opId = (*inst)->GetSingleWordInOperand(i);
85 Instruction* opInst = get_def_use_mgr()->GetDef(opId);
86 if (opInst->opcode() == spv::Op::OpConstant) {
87 // We found a constant operand.
88 uint32_t constVal = opInst->GetSingleWordOperand(2);
89
90 if (IsPowerOf2(constVal)) {
91 modified = true;
92 uint32_t shiftAmount = CountTrailingZeros(constVal);
93 uint32_t shiftConstResultId = GetConstantId(shiftAmount);
94
95 // Create the new instruction.
96 uint32_t newResultId = TakeNextId();
97 std::vector<Operand> newOperands;
98 newOperands.push_back((*inst)->GetInOperand(1 - i));
99 Operand shiftOperand(spv_operand_type_t::SPV_OPERAND_TYPE_ID,
100 {shiftConstResultId});
101 newOperands.push_back(shiftOperand);
102 std::unique_ptr<Instruction> newInstruction(
103 new Instruction(context(), spv::Op::OpShiftLeftLogical,
104 (*inst)->type_id(), newResultId, newOperands));
105
106 // Insert the new instruction and update the data structures.
107 (*inst) = (*inst).InsertBefore(std::move(newInstruction));
108 get_def_use_mgr()->AnalyzeInstDefUse(&*(*inst));
109 ++(*inst);
110 context()->ReplaceAllUsesWith((*inst)->result_id(), newResultId);
111
112 // Remove the old instruction.
113 Instruction* inst_to_delete = &*(*inst);
114 --(*inst);
115 context()->KillInst(inst_to_delete);
116
117 // We do not want to replace the instruction twice if both operands
118 // are constants that are a power of 2. So we break here.
119 break;
120 }
121 }
122 }
123
124 return modified;
125 }
126
FindIntTypesAndConstants()127 void StrengthReductionPass::FindIntTypesAndConstants() {
128 analysis::Integer int32(32, true);
129 int32_type_id_ = context()->get_type_mgr()->GetId(&int32);
130 analysis::Integer uint32(32, false);
131 uint32_type_id_ = context()->get_type_mgr()->GetId(&uint32);
132 for (auto iter = get_module()->types_values_begin();
133 iter != get_module()->types_values_end(); ++iter) {
134 switch (iter->opcode()) {
135 case spv::Op::OpConstant:
136 if (iter->type_id() == uint32_type_id_) {
137 uint32_t value = iter->GetSingleWordOperand(2);
138 if (value <= 32) constant_ids_[value] = iter->result_id();
139 }
140 break;
141 default:
142 break;
143 }
144 }
145 }
146
GetConstantId(uint32_t val)147 uint32_t StrengthReductionPass::GetConstantId(uint32_t val) {
148 assert(val <= 32 &&
149 "This function does not handle constants larger than 32.");
150
151 if (constant_ids_[val] == 0) {
152 if (uint32_type_id_ == 0) {
153 analysis::Integer uint(32, false);
154 uint32_type_id_ = context()->get_type_mgr()->GetTypeInstruction(&uint);
155 }
156
157 // Construct the constant.
158 uint32_t resultId = TakeNextId();
159 Operand constant(spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER,
160 {val});
161 std::unique_ptr<Instruction> newConstant(new Instruction(
162 context(), spv::Op::OpConstant, uint32_type_id_, resultId, {constant}));
163 get_module()->AddGlobalValue(std::move(newConstant));
164
165 // Notify the DefUseManager about this constant.
166 auto constantIter = --get_module()->types_values_end();
167 get_def_use_mgr()->AnalyzeInstDef(&*constantIter);
168
169 // Store the result id for next time.
170 constant_ids_[val] = resultId;
171 }
172
173 return constant_ids_[val];
174 }
175
ScanFunctions()176 bool StrengthReductionPass::ScanFunctions() {
177 // I did not use |ForEachInst| in the module because the function that acts on
178 // the instruction gets a pointer to the instruction. We cannot use that to
179 // insert a new instruction. I want an iterator.
180 bool modified = false;
181 for (auto& func : *get_module()) {
182 for (auto& bb : func) {
183 for (auto inst = bb.begin(); inst != bb.end(); ++inst) {
184 switch (inst->opcode()) {
185 case spv::Op::OpIMul:
186 if (ReplaceMultiplyByPowerOf2(&inst)) modified = true;
187 break;
188 default:
189 break;
190 }
191 }
192 }
193 }
194 return modified;
195 }
196
197 } // namespace opt
198 } // namespace spvtools
199