• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/opt/strength_reduction_pass.h"
16 
17 #include <algorithm>
18 #include <cstdio>
19 #include <cstring>
20 #include <memory>
21 #include <unordered_map>
22 #include <unordered_set>
23 #include <utility>
24 #include <vector>
25 
26 #include "source/opt/def_use_manager.h"
27 #include "source/opt/ir_context.h"
28 #include "source/opt/log.h"
29 #include "source/opt/reflect.h"
30 
31 namespace spvtools {
32 namespace opt {
namespace {
// Returns the number of trailing zero bits in the binary representation of
// |constVal|. When |constVal| is 0 every bit is zero, so the result is 32 (the
// bit width of uint32_t); this matches the C++20 std::countr_zero convention.
uint32_t CountTrailingZeros(uint32_t constVal) {
  // Without this guard a zero input has no set bit to stop at and the loop
  // below would never terminate.
  if (constVal == 0) return 32;

  // A hardware count-trailing-zeros instruction or a lookup table would be
  // faster; a simple shift loop is sufficient here.
  uint32_t shiftAmount = 0;
  while ((constVal & 1u) == 0) {
    ++shiftAmount;
    constVal >>= 1;
  }
  return shiftAmount;
}

// Returns true if |val| is a power of 2 (exactly one bit set).
bool IsPowerOf2(uint32_t val) {
  // Subtracting 1 flips the lowest set bit and every zero below it, so the
  // bitwise AND clears the least significant 1 bit. A power of 2 has exactly
  // one set bit, so the result is 0. Zero has no set bits and is rejected
  // explicitly.
  if (val == 0) return false;
  return ((val - 1) & val) == 0;
}

}  // namespace
57 
Process()58 Pass::Status StrengthReductionPass::Process() {
59   // Initialize the member variables on a per module basis.
60   bool modified = false;
61   int32_type_id_ = 0;
62   uint32_type_id_ = 0;
63   std::memset(constant_ids_, 0, sizeof(constant_ids_));
64 
65   FindIntTypesAndConstants();
66   modified = ScanFunctions();
67   return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
68 }
69 
ReplaceMultiplyByPowerOf2(BasicBlock::iterator * inst)70 bool StrengthReductionPass::ReplaceMultiplyByPowerOf2(
71     BasicBlock::iterator* inst) {
72   assert((*inst)->opcode() == spv::Op::OpIMul &&
73          "Only works for multiplication of integers.");
74   bool modified = false;
75 
76   // Currently only works on 32-bit integers.
77   if ((*inst)->type_id() != int32_type_id_ &&
78       (*inst)->type_id() != uint32_type_id_) {
79     return modified;
80   }
81 
82   // Check the operands for a constant that is a power of 2.
83   for (int i = 0; i < 2; i++) {
84     uint32_t opId = (*inst)->GetSingleWordInOperand(i);
85     Instruction* opInst = get_def_use_mgr()->GetDef(opId);
86     if (opInst->opcode() == spv::Op::OpConstant) {
87       // We found a constant operand.
88       uint32_t constVal = opInst->GetSingleWordOperand(2);
89 
90       if (IsPowerOf2(constVal)) {
91         modified = true;
92         uint32_t shiftAmount = CountTrailingZeros(constVal);
93         uint32_t shiftConstResultId = GetConstantId(shiftAmount);
94 
95         // Create the new instruction.
96         uint32_t newResultId = TakeNextId();
97         std::vector<Operand> newOperands;
98         newOperands.push_back((*inst)->GetInOperand(1 - i));
99         Operand shiftOperand(spv_operand_type_t::SPV_OPERAND_TYPE_ID,
100                              {shiftConstResultId});
101         newOperands.push_back(shiftOperand);
102         std::unique_ptr<Instruction> newInstruction(
103             new Instruction(context(), spv::Op::OpShiftLeftLogical,
104                             (*inst)->type_id(), newResultId, newOperands));
105 
106         // Insert the new instruction and update the data structures.
107         (*inst) = (*inst).InsertBefore(std::move(newInstruction));
108         get_def_use_mgr()->AnalyzeInstDefUse(&*(*inst));
109         ++(*inst);
110         context()->ReplaceAllUsesWith((*inst)->result_id(), newResultId);
111 
112         // Remove the old instruction.
113         Instruction* inst_to_delete = &*(*inst);
114         --(*inst);
115         context()->KillInst(inst_to_delete);
116 
117         // We do not want to replace the instruction twice if both operands
118         // are constants that are a power of 2.  So we break here.
119         break;
120       }
121     }
122   }
123 
124   return modified;
125 }
126 
FindIntTypesAndConstants()127 void StrengthReductionPass::FindIntTypesAndConstants() {
128   analysis::Integer int32(32, true);
129   int32_type_id_ = context()->get_type_mgr()->GetId(&int32);
130   analysis::Integer uint32(32, false);
131   uint32_type_id_ = context()->get_type_mgr()->GetId(&uint32);
132   for (auto iter = get_module()->types_values_begin();
133        iter != get_module()->types_values_end(); ++iter) {
134     switch (iter->opcode()) {
135       case spv::Op::OpConstant:
136         if (iter->type_id() == uint32_type_id_) {
137           uint32_t value = iter->GetSingleWordOperand(2);
138           if (value <= 32) constant_ids_[value] = iter->result_id();
139         }
140         break;
141       default:
142         break;
143     }
144   }
145 }
146 
GetConstantId(uint32_t val)147 uint32_t StrengthReductionPass::GetConstantId(uint32_t val) {
148   assert(val <= 32 &&
149          "This function does not handle constants larger than 32.");
150 
151   if (constant_ids_[val] == 0) {
152     if (uint32_type_id_ == 0) {
153       analysis::Integer uint(32, false);
154       uint32_type_id_ = context()->get_type_mgr()->GetTypeInstruction(&uint);
155     }
156 
157     // Construct the constant.
158     uint32_t resultId = TakeNextId();
159     Operand constant(spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER,
160                      {val});
161     std::unique_ptr<Instruction> newConstant(new Instruction(
162         context(), spv::Op::OpConstant, uint32_type_id_, resultId, {constant}));
163     get_module()->AddGlobalValue(std::move(newConstant));
164 
165     // Notify the DefUseManager about this constant.
166     auto constantIter = --get_module()->types_values_end();
167     get_def_use_mgr()->AnalyzeInstDef(&*constantIter);
168 
169     // Store the result id for next time.
170     constant_ids_[val] = resultId;
171   }
172 
173   return constant_ids_[val];
174 }
175 
ScanFunctions()176 bool StrengthReductionPass::ScanFunctions() {
177   // I did not use |ForEachInst| in the module because the function that acts on
178   // the instruction gets a pointer to the instruction.  We cannot use that to
179   // insert a new instruction.  I want an iterator.
180   bool modified = false;
181   for (auto& func : *get_module()) {
182     for (auto& bb : func) {
183       for (auto inst = bb.begin(); inst != bb.end(); ++inst) {
184         switch (inst->opcode()) {
185           case spv::Op::OpIMul:
186             if (ReplaceMultiplyByPowerOf2(&inst)) modified = true;
187             break;
188           default:
189             break;
190         }
191       }
192     }
193   }
194   return modified;
195 }
196 
197 }  // namespace opt
198 }  // namespace spvtools
199