1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/opt/private_to_local_pass.h"
16
17 #include <memory>
18 #include <utility>
19 #include <vector>
20
21 #include "source/opt/ir_context.h"
22 #include "source/spirv_constant.h"
23
24 namespace spvtools {
25 namespace opt {
namespace {

// In-operand index of the storage-class word of an OpVariable.
constexpr uint32_t kVariableStorageClassInIdx = 0;

// In-operand index of the pointee type id of an OpTypePointer.
constexpr uint32_t kSpvTypePointerTypeIdInIdx = 1;

}  // namespace
32
Process()33 Pass::Status PrivateToLocalPass::Process() {
34 bool modified = false;
35
36 // Private variables require the shader capability. If this is not a shader,
37 // there is no work to do.
38 if (context()->get_feature_mgr()->HasCapability(SpvCapabilityAddresses))
39 return Status::SuccessWithoutChange;
40
41 std::vector<std::pair<Instruction*, Function*>> variables_to_move;
42 std::unordered_set<uint32_t> localized_variables;
43 for (auto& inst : context()->types_values()) {
44 if (inst.opcode() != SpvOpVariable) {
45 continue;
46 }
47
48 if (inst.GetSingleWordInOperand(kVariableStorageClassInIdx) !=
49 SpvStorageClassPrivate) {
50 continue;
51 }
52
53 Function* target_function = FindLocalFunction(inst);
54 if (target_function != nullptr) {
55 variables_to_move.push_back({&inst, target_function});
56 }
57 }
58
59 modified = !variables_to_move.empty();
60 for (auto p : variables_to_move) {
61 if (!MoveVariable(p.first, p.second)) {
62 return Status::Failure;
63 }
64 localized_variables.insert(p.first->result_id());
65 }
66
67 if (get_module()->version() >= SPV_SPIRV_VERSION_WORD(1, 4)) {
68 // In SPIR-V 1.4 and later entry points must list private storage class
69 // variables that are statically used by the entry point. Go through the
70 // entry points and remove any references to variables that were localized.
71 for (auto& entry : get_module()->entry_points()) {
72 std::vector<Operand> new_operands;
73 for (uint32_t i = 0; i < entry.NumInOperands(); ++i) {
74 // Execution model, function id and name are always kept.
75 if (i < 3 ||
76 !localized_variables.count(entry.GetSingleWordInOperand(i))) {
77 new_operands.push_back(entry.GetInOperand(i));
78 }
79 }
80 if (new_operands.size() != entry.NumInOperands()) {
81 entry.SetInOperands(std::move(new_operands));
82 context()->AnalyzeUses(&entry);
83 }
84 }
85 }
86
87 return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
88 }
89
FindLocalFunction(const Instruction & inst) const90 Function* PrivateToLocalPass::FindLocalFunction(const Instruction& inst) const {
91 bool found_first_use = false;
92 Function* target_function = nullptr;
93 context()->get_def_use_mgr()->ForEachUser(
94 inst.result_id(),
95 [&target_function, &found_first_use, this](Instruction* use) {
96 BasicBlock* current_block = context()->get_instr_block(use);
97 if (current_block == nullptr) {
98 return;
99 }
100
101 if (!IsValidUse(use)) {
102 found_first_use = true;
103 target_function = nullptr;
104 return;
105 }
106 Function* current_function = current_block->GetParent();
107 if (!found_first_use) {
108 found_first_use = true;
109 target_function = current_function;
110 } else if (target_function != current_function) {
111 target_function = nullptr;
112 }
113 });
114 return target_function;
115 } // namespace opt
116
MoveVariable(Instruction * variable,Function * function)117 bool PrivateToLocalPass::MoveVariable(Instruction* variable,
118 Function* function) {
119 // The variable needs to be removed from the global section, and placed in the
120 // header of the function. First step remove from the global list.
121 variable->RemoveFromList();
122 std::unique_ptr<Instruction> var(variable); // Take ownership.
123 context()->ForgetUses(variable);
124
125 // Update the storage class of the variable.
126 variable->SetInOperand(kVariableStorageClassInIdx, {SpvStorageClassFunction});
127
128 // Update the type as well.
129 uint32_t new_type_id = GetNewType(variable->type_id());
130 if (new_type_id == 0) {
131 return false;
132 }
133 variable->SetResultType(new_type_id);
134
135 // Place the variable at the start of the first basic block.
136 context()->AnalyzeUses(variable);
137 context()->set_instr_block(variable, &*function->begin());
138 function->begin()->begin()->InsertBefore(std::move(var));
139
140 // Update uses where the type may have changed.
141 return UpdateUses(variable);
142 }
143
GetNewType(uint32_t old_type_id)144 uint32_t PrivateToLocalPass::GetNewType(uint32_t old_type_id) {
145 auto type_mgr = context()->get_type_mgr();
146 Instruction* old_type_inst = get_def_use_mgr()->GetDef(old_type_id);
147 uint32_t pointee_type_id =
148 old_type_inst->GetSingleWordInOperand(kSpvTypePointerTypeIdInIdx);
149 uint32_t new_type_id =
150 type_mgr->FindPointerToType(pointee_type_id, SpvStorageClassFunction);
151 if (new_type_id != 0) {
152 context()->UpdateDefUse(context()->get_def_use_mgr()->GetDef(new_type_id));
153 }
154 return new_type_id;
155 }
156
IsValidUse(const Instruction * inst) const157 bool PrivateToLocalPass::IsValidUse(const Instruction* inst) const {
158 // The cases in this switch have to match the cases in |UpdateUse|.
159 // If we don't know how to update it, it is not valid.
160 if (inst->GetCommonDebugOpcode() == CommonDebugInfoDebugGlobalVariable) {
161 return true;
162 }
163 switch (inst->opcode()) {
164 case SpvOpLoad:
165 case SpvOpStore:
166 case SpvOpImageTexelPointer: // Treat like a load
167 return true;
168 case SpvOpAccessChain:
169 return context()->get_def_use_mgr()->WhileEachUser(
170 inst, [this](const Instruction* user) {
171 if (!IsValidUse(user)) return false;
172 return true;
173 });
174 case SpvOpName:
175 return true;
176 default:
177 return spvOpcodeIsDecoration(inst->opcode());
178 }
179 }
180
UpdateUse(Instruction * inst,Instruction * user)181 bool PrivateToLocalPass::UpdateUse(Instruction* inst, Instruction* user) {
182 // The cases in this switch have to match the cases in |IsValidUse|. If we
183 // don't think it is valid, the optimization will not view the variable as a
184 // candidate, and therefore the use will not be updated.
185 if (inst->GetCommonDebugOpcode() == CommonDebugInfoDebugGlobalVariable) {
186 context()->get_debug_info_mgr()->ConvertDebugGlobalToLocalVariable(inst,
187 user);
188 return true;
189 }
190 switch (inst->opcode()) {
191 case SpvOpLoad:
192 case SpvOpStore:
193 case SpvOpImageTexelPointer: // Treat like a load
194 // The type is fine because it is the type pointed to, and that does not
195 // change.
196 break;
197 case SpvOpAccessChain: {
198 context()->ForgetUses(inst);
199 uint32_t new_type_id = GetNewType(inst->type_id());
200 if (new_type_id == 0) {
201 return false;
202 }
203 inst->SetResultType(new_type_id);
204 context()->AnalyzeUses(inst);
205
206 // Update uses where the type may have changed.
207 if (!UpdateUses(inst)) {
208 return false;
209 }
210 } break;
211 case SpvOpName:
212 case SpvOpEntryPoint: // entry points will be updated separately.
213 break;
214 default:
215 assert(spvOpcodeIsDecoration(inst->opcode()) &&
216 "Do not know how to update the type for this instruction.");
217 break;
218 }
219 return true;
220 }
221
UpdateUses(Instruction * inst)222 bool PrivateToLocalPass::UpdateUses(Instruction* inst) {
223 uint32_t id = inst->result_id();
224 std::vector<Instruction*> uses;
225 context()->get_def_use_mgr()->ForEachUser(
226 id, [&uses](Instruction* use) { uses.push_back(use); });
227
228 for (Instruction* use : uses) {
229 if (!UpdateUse(use, inst)) {
230 return false;
231 }
232 }
233 return true;
234 }
235
236 } // namespace opt
237 } // namespace spvtools
238