1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/opt/private_to_local_pass.h"
16
17 #include <memory>
18 #include <utility>
19 #include <vector>
20
21 #include "source/opt/ir_context.h"
22 #include "source/spirv_constant.h"
23
24 namespace spvtools {
25 namespace opt {
namespace {
// In-operand index of the storage-class operand of an OpVariable.
constexpr uint32_t kVariableStorageClassInIdx{0};
// In-operand index of the pointee-type id operand of an OpTypePointer.
constexpr uint32_t kSpvTypePointerTypeIdInIdx{1};
}  // namespace
30
Process()31 Pass::Status PrivateToLocalPass::Process() {
32 bool modified = false;
33
34 // Private variables require the shader capability. If this is not a shader,
35 // there is no work to do.
36 if (context()->get_feature_mgr()->HasCapability(spv::Capability::Addresses))
37 return Status::SuccessWithoutChange;
38
39 std::vector<std::pair<Instruction*, Function*>> variables_to_move;
40 std::unordered_set<uint32_t> localized_variables;
41 for (auto& inst : context()->types_values()) {
42 if (inst.opcode() != spv::Op::OpVariable) {
43 continue;
44 }
45
46 if (spv::StorageClass(inst.GetSingleWordInOperand(
47 kVariableStorageClassInIdx)) != spv::StorageClass::Private) {
48 continue;
49 }
50
51 Function* target_function = FindLocalFunction(inst);
52 if (target_function != nullptr) {
53 variables_to_move.push_back({&inst, target_function});
54 }
55 }
56
57 modified = !variables_to_move.empty();
58 for (auto p : variables_to_move) {
59 if (!MoveVariable(p.first, p.second)) {
60 return Status::Failure;
61 }
62 localized_variables.insert(p.first->result_id());
63 }
64
65 if (get_module()->version() >= SPV_SPIRV_VERSION_WORD(1, 4)) {
66 // In SPIR-V 1.4 and later entry points must list private storage class
67 // variables that are statically used by the entry point. Go through the
68 // entry points and remove any references to variables that were localized.
69 for (auto& entry : get_module()->entry_points()) {
70 std::vector<Operand> new_operands;
71 for (uint32_t i = 0; i < entry.NumInOperands(); ++i) {
72 // Execution model, function id and name are always kept.
73 if (i < 3 ||
74 !localized_variables.count(entry.GetSingleWordInOperand(i))) {
75 new_operands.push_back(entry.GetInOperand(i));
76 }
77 }
78 if (new_operands.size() != entry.NumInOperands()) {
79 entry.SetInOperands(std::move(new_operands));
80 context()->AnalyzeUses(&entry);
81 }
82 }
83 }
84
85 return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
86 }
87
FindLocalFunction(const Instruction & inst) const88 Function* PrivateToLocalPass::FindLocalFunction(const Instruction& inst) const {
89 bool found_first_use = false;
90 Function* target_function = nullptr;
91 context()->get_def_use_mgr()->ForEachUser(
92 inst.result_id(),
93 [&target_function, &found_first_use, this](Instruction* use) {
94 BasicBlock* current_block = context()->get_instr_block(use);
95 if (current_block == nullptr) {
96 return;
97 }
98
99 if (!IsValidUse(use)) {
100 found_first_use = true;
101 target_function = nullptr;
102 return;
103 }
104 Function* current_function = current_block->GetParent();
105 if (!found_first_use) {
106 found_first_use = true;
107 target_function = current_function;
108 } else if (target_function != current_function) {
109 target_function = nullptr;
110 }
111 });
112 return target_function;
113 } // namespace opt
114
MoveVariable(Instruction * variable,Function * function)115 bool PrivateToLocalPass::MoveVariable(Instruction* variable,
116 Function* function) {
117 // The variable needs to be removed from the global section, and placed in the
118 // header of the function. First step remove from the global list.
119 variable->RemoveFromList();
120 std::unique_ptr<Instruction> var(variable); // Take ownership.
121 context()->ForgetUses(variable);
122
123 // Update the storage class of the variable.
124 variable->SetInOperand(kVariableStorageClassInIdx,
125 {uint32_t(spv::StorageClass::Function)});
126
127 // Update the type as well.
128 uint32_t new_type_id = GetNewType(variable->type_id());
129 if (new_type_id == 0) {
130 return false;
131 }
132 variable->SetResultType(new_type_id);
133
134 // Place the variable at the start of the first basic block.
135 context()->AnalyzeUses(variable);
136 context()->set_instr_block(variable, &*function->begin());
137 function->begin()->begin()->InsertBefore(std::move(var));
138
139 // Update uses where the type may have changed.
140 return UpdateUses(variable);
141 }
142
GetNewType(uint32_t old_type_id)143 uint32_t PrivateToLocalPass::GetNewType(uint32_t old_type_id) {
144 auto type_mgr = context()->get_type_mgr();
145 Instruction* old_type_inst = get_def_use_mgr()->GetDef(old_type_id);
146 uint32_t pointee_type_id =
147 old_type_inst->GetSingleWordInOperand(kSpvTypePointerTypeIdInIdx);
148 uint32_t new_type_id =
149 type_mgr->FindPointerToType(pointee_type_id, spv::StorageClass::Function);
150 if (new_type_id != 0) {
151 context()->UpdateDefUse(context()->get_def_use_mgr()->GetDef(new_type_id));
152 }
153 return new_type_id;
154 }
155
IsValidUse(const Instruction * inst) const156 bool PrivateToLocalPass::IsValidUse(const Instruction* inst) const {
157 // The cases in this switch have to match the cases in |UpdateUse|.
158 // If we don't know how to update it, it is not valid.
159 if (inst->GetCommonDebugOpcode() == CommonDebugInfoDebugGlobalVariable) {
160 return true;
161 }
162 switch (inst->opcode()) {
163 case spv::Op::OpLoad:
164 case spv::Op::OpStore:
165 case spv::Op::OpImageTexelPointer: // Treat like a load
166 return true;
167 case spv::Op::OpAccessChain:
168 return context()->get_def_use_mgr()->WhileEachUser(
169 inst, [this](const Instruction* user) {
170 if (!IsValidUse(user)) return false;
171 return true;
172 });
173 case spv::Op::OpName:
174 return true;
175 default:
176 return spvOpcodeIsDecoration(inst->opcode());
177 }
178 }
179
UpdateUse(Instruction * inst,Instruction * user)180 bool PrivateToLocalPass::UpdateUse(Instruction* inst, Instruction* user) {
181 // The cases in this switch have to match the cases in |IsValidUse|. If we
182 // don't think it is valid, the optimization will not view the variable as a
183 // candidate, and therefore the use will not be updated.
184 if (inst->GetCommonDebugOpcode() == CommonDebugInfoDebugGlobalVariable) {
185 context()->get_debug_info_mgr()->ConvertDebugGlobalToLocalVariable(inst,
186 user);
187 return true;
188 }
189 switch (inst->opcode()) {
190 case spv::Op::OpLoad:
191 case spv::Op::OpStore:
192 case spv::Op::OpImageTexelPointer: // Treat like a load
193 // The type is fine because it is the type pointed to, and that does not
194 // change.
195 break;
196 case spv::Op::OpAccessChain: {
197 context()->ForgetUses(inst);
198 uint32_t new_type_id = GetNewType(inst->type_id());
199 if (new_type_id == 0) {
200 return false;
201 }
202 inst->SetResultType(new_type_id);
203 context()->AnalyzeUses(inst);
204
205 // Update uses where the type may have changed.
206 if (!UpdateUses(inst)) {
207 return false;
208 }
209 } break;
210 case spv::Op::OpName:
211 case spv::Op::OpEntryPoint: // entry points will be updated separately.
212 break;
213 default:
214 assert(spvOpcodeIsDecoration(inst->opcode()) &&
215 "Do not know how to update the type for this instruction.");
216 break;
217 }
218 return true;
219 }
220
UpdateUses(Instruction * inst)221 bool PrivateToLocalPass::UpdateUses(Instruction* inst) {
222 uint32_t id = inst->result_id();
223 std::vector<Instruction*> uses;
224 context()->get_def_use_mgr()->ForEachUser(
225 id, [&uses](Instruction* use) { uses.push_back(use); });
226
227 for (Instruction* use : uses) {
228 if (!UpdateUse(use, inst)) {
229 return false;
230 }
231 }
232 return true;
233 }
234
235 } // namespace opt
236 } // namespace spvtools
237