1 // Copyright (c) 2019 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "fix_storage_class.h"
16
17 #include <set>
18
19 #include "source/opt/instruction.h"
20 #include "source/opt/ir_context.h"
21
22 namespace spvtools {
23 namespace opt {
24
Process()25 Pass::Status FixStorageClass::Process() {
26 bool modified = false;
27
28 get_module()->ForEachInst([this, &modified](Instruction* inst) {
29 if (inst->opcode() == spv::Op::OpVariable) {
30 std::set<uint32_t> seen;
31 std::vector<std::pair<Instruction*, uint32_t>> uses;
32 get_def_use_mgr()->ForEachUse(inst,
33 [&uses](Instruction* use, uint32_t op_idx) {
34 uses.push_back({use, op_idx});
35 });
36
37 for (auto& use : uses) {
38 modified |= PropagateStorageClass(
39 use.first,
40 static_cast<spv::StorageClass>(inst->GetSingleWordInOperand(0)),
41 &seen);
42 assert(seen.empty() && "Seen was not properly reset.");
43 modified |=
44 PropagateType(use.first, inst->type_id(), use.second, &seen);
45 assert(seen.empty() && "Seen was not properly reset.");
46 }
47 }
48 });
49 return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
50 }
51
PropagateStorageClass(Instruction * inst,spv::StorageClass storage_class,std::set<uint32_t> * seen)52 bool FixStorageClass::PropagateStorageClass(Instruction* inst,
53 spv::StorageClass storage_class,
54 std::set<uint32_t>* seen) {
55 if (!IsPointerResultType(inst)) {
56 return false;
57 }
58
59 if (IsPointerToStorageClass(inst, storage_class)) {
60 if (inst->opcode() == spv::Op::OpPhi) {
61 if (!seen->insert(inst->result_id()).second) {
62 return false;
63 }
64 }
65
66 bool modified = false;
67 std::vector<Instruction*> uses;
68 get_def_use_mgr()->ForEachUser(
69 inst, [&uses](Instruction* use) { uses.push_back(use); });
70 for (Instruction* use : uses) {
71 modified |= PropagateStorageClass(use, storage_class, seen);
72 }
73
74 if (inst->opcode() == spv::Op::OpPhi) {
75 seen->erase(inst->result_id());
76 }
77 return modified;
78 }
79
80 switch (inst->opcode()) {
81 case spv::Op::OpAccessChain:
82 case spv::Op::OpPtrAccessChain:
83 case spv::Op::OpInBoundsAccessChain:
84 case spv::Op::OpCopyObject:
85 case spv::Op::OpPhi:
86 case spv::Op::OpSelect:
87 FixInstructionStorageClass(inst, storage_class, seen);
88 return true;
89 case spv::Op::OpFunctionCall:
90 // We cannot be sure of the actual connection between the storage class
91 // of the parameter and the storage class of the result, so we should not
92 // do anything. If the result type needs to be fixed, the function call
93 // should be inlined.
94 return false;
95 case spv::Op::OpImageTexelPointer:
96 case spv::Op::OpLoad:
97 case spv::Op::OpStore:
98 case spv::Op::OpCopyMemory:
99 case spv::Op::OpCopyMemorySized:
100 case spv::Op::OpVariable:
101 case spv::Op::OpBitcast:
102 // Nothing to change for these opcode. The result type is the same
103 // regardless of the storage class of the operand.
104 return false;
105 default:
106 assert(false &&
107 "Not expecting instruction to have a pointer result type.");
108 return false;
109 }
110 }
111
FixInstructionStorageClass(Instruction * inst,spv::StorageClass storage_class,std::set<uint32_t> * seen)112 void FixStorageClass::FixInstructionStorageClass(
113 Instruction* inst, spv::StorageClass storage_class,
114 std::set<uint32_t>* seen) {
115 assert(IsPointerResultType(inst) &&
116 "The result type of the instruction must be a pointer.");
117
118 ChangeResultStorageClass(inst, storage_class);
119
120 std::vector<Instruction*> uses;
121 get_def_use_mgr()->ForEachUser(
122 inst, [&uses](Instruction* use) { uses.push_back(use); });
123 for (Instruction* use : uses) {
124 PropagateStorageClass(use, storage_class, seen);
125 }
126 }
127
ChangeResultStorageClass(Instruction * inst,spv::StorageClass storage_class) const128 void FixStorageClass::ChangeResultStorageClass(
129 Instruction* inst, spv::StorageClass storage_class) const {
130 analysis::TypeManager* type_mgr = context()->get_type_mgr();
131 Instruction* result_type_inst = get_def_use_mgr()->GetDef(inst->type_id());
132 assert(result_type_inst->opcode() == spv::Op::OpTypePointer);
133 uint32_t pointee_type_id = result_type_inst->GetSingleWordInOperand(1);
134 uint32_t new_result_type_id =
135 type_mgr->FindPointerToType(pointee_type_id, storage_class);
136 inst->SetResultType(new_result_type_id);
137 context()->UpdateDefUse(inst);
138 }
139
IsPointerResultType(Instruction * inst)140 bool FixStorageClass::IsPointerResultType(Instruction* inst) {
141 if (inst->type_id() == 0) {
142 return false;
143 }
144 const analysis::Type* ret_type =
145 context()->get_type_mgr()->GetType(inst->type_id());
146 return ret_type->AsPointer() != nullptr;
147 }
148
IsPointerToStorageClass(Instruction * inst,spv::StorageClass storage_class)149 bool FixStorageClass::IsPointerToStorageClass(Instruction* inst,
150 spv::StorageClass storage_class) {
151 analysis::TypeManager* type_mgr = context()->get_type_mgr();
152 analysis::Type* pType = type_mgr->GetType(inst->type_id());
153 const analysis::Pointer* result_type = pType->AsPointer();
154
155 if (result_type == nullptr) {
156 return false;
157 }
158
159 return (result_type->storage_class() == storage_class);
160 }
161
ChangeResultType(Instruction * inst,uint32_t new_type_id)162 bool FixStorageClass::ChangeResultType(Instruction* inst,
163 uint32_t new_type_id) {
164 if (inst->type_id() == new_type_id) {
165 return false;
166 }
167
168 context()->ForgetUses(inst);
169 inst->SetResultType(new_type_id);
170 context()->AnalyzeUses(inst);
171 return true;
172 }
173
PropagateType(Instruction * inst,uint32_t type_id,uint32_t op_idx,std::set<uint32_t> * seen)174 bool FixStorageClass::PropagateType(Instruction* inst, uint32_t type_id,
175 uint32_t op_idx, std::set<uint32_t>* seen) {
176 assert(type_id != 0 && "Not given a valid type in PropagateType");
177 bool modified = false;
178
179 // If the type of operand |op_idx| forces the result type of |inst| to a
180 // particular type, then we want find that type.
181 uint32_t new_type_id = 0;
182 switch (inst->opcode()) {
183 case spv::Op::OpAccessChain:
184 case spv::Op::OpPtrAccessChain:
185 case spv::Op::OpInBoundsAccessChain:
186 case spv::Op::OpInBoundsPtrAccessChain:
187 if (op_idx == 2) {
188 new_type_id = WalkAccessChainType(inst, type_id);
189 }
190 break;
191 case spv::Op::OpCopyObject:
192 new_type_id = type_id;
193 break;
194 case spv::Op::OpPhi:
195 if (seen->insert(inst->result_id()).second) {
196 new_type_id = type_id;
197 }
198 break;
199 case spv::Op::OpSelect:
200 if (op_idx > 2) {
201 new_type_id = type_id;
202 }
203 break;
204 case spv::Op::OpFunctionCall:
205 // We cannot be sure of the actual connection between the type
206 // of the parameter and the type of the result, so we should not
207 // do anything. If the result type needs to be fixed, the function call
208 // should be inlined.
209 return false;
210 case spv::Op::OpLoad: {
211 Instruction* type_inst = get_def_use_mgr()->GetDef(type_id);
212 new_type_id = type_inst->GetSingleWordInOperand(1);
213 break;
214 }
215 case spv::Op::OpStore: {
216 uint32_t obj_id = inst->GetSingleWordInOperand(1);
217 Instruction* obj_inst = get_def_use_mgr()->GetDef(obj_id);
218 uint32_t obj_type_id = obj_inst->type_id();
219
220 uint32_t ptr_id = inst->GetSingleWordInOperand(0);
221 Instruction* ptr_inst = get_def_use_mgr()->GetDef(ptr_id);
222 uint32_t pointee_type_id = GetPointeeTypeId(ptr_inst);
223
224 if (obj_type_id != pointee_type_id) {
225 if (context()->get_type_mgr()->GetType(obj_type_id)->AsImage() &&
226 context()->get_type_mgr()->GetType(pointee_type_id)->AsImage()) {
227 // When storing an image, allow the type mismatch
228 // and let the later legalization passes eliminate the OpStore.
229 // This is to support assigning an image to a variable,
230 // where the assigned image does not have a pre-defined
231 // image format.
232 return false;
233 }
234
235 uint32_t copy_id = GenerateCopy(obj_inst, pointee_type_id, inst);
236 inst->SetInOperand(1, {copy_id});
237 context()->UpdateDefUse(inst);
238 }
239 } break;
240 case spv::Op::OpCopyMemory:
241 case spv::Op::OpCopyMemorySized:
242 // TODO: May need to expand the copy as we do with the stores.
243 break;
244 case spv::Op::OpCompositeConstruct:
245 case spv::Op::OpCompositeExtract:
246 case spv::Op::OpCompositeInsert:
247 // TODO: DXC does not seem to generate code that will require changes to
248 // these opcode. The can be implemented when they come up.
249 break;
250 case spv::Op::OpImageTexelPointer:
251 case spv::Op::OpBitcast:
252 // Nothing to change for these opcode. The result type is the same
253 // regardless of the type of the operand.
254 return false;
255 default:
256 // I expect the remaining instructions to act on types that are guaranteed
257 // to be unique, so no change will be necessary.
258 break;
259 }
260
261 // If the operand forces the result type, then make sure the result type
262 // matches, and update the uses of |inst|. We do not have to check the uses
263 // of |inst| in the result type is not forced because we are only looking for
264 // issue that come from mismatches between function formal and actual
265 // parameters after the function has been inlined. These parameters are
266 // pointers. Once the type no longer depends on the type of the parameter,
267 // then the types should have be correct.
268 if (new_type_id != 0) {
269 modified = ChangeResultType(inst, new_type_id);
270
271 std::vector<std::pair<Instruction*, uint32_t>> uses;
272 get_def_use_mgr()->ForEachUse(inst,
273 [&uses](Instruction* use, uint32_t idx) {
274 uses.push_back({use, idx});
275 });
276
277 for (auto& use : uses) {
278 PropagateType(use.first, new_type_id, use.second, seen);
279 }
280
281 if (inst->opcode() == spv::Op::OpPhi) {
282 seen->erase(inst->result_id());
283 }
284 }
285 return modified;
286 }
287
WalkAccessChainType(Instruction * inst,uint32_t id)288 uint32_t FixStorageClass::WalkAccessChainType(Instruction* inst, uint32_t id) {
289 uint32_t start_idx = 0;
290 switch (inst->opcode()) {
291 case spv::Op::OpAccessChain:
292 case spv::Op::OpInBoundsAccessChain:
293 start_idx = 1;
294 break;
295 case spv::Op::OpPtrAccessChain:
296 case spv::Op::OpInBoundsPtrAccessChain:
297 start_idx = 2;
298 break;
299 default:
300 assert(false);
301 break;
302 }
303
304 Instruction* orig_type_inst = get_def_use_mgr()->GetDef(id);
305 assert(orig_type_inst->opcode() == spv::Op::OpTypePointer);
306 id = orig_type_inst->GetSingleWordInOperand(1);
307
308 for (uint32_t i = start_idx; i < inst->NumInOperands(); ++i) {
309 Instruction* type_inst = get_def_use_mgr()->GetDef(id);
310 switch (type_inst->opcode()) {
311 case spv::Op::OpTypeArray:
312 case spv::Op::OpTypeRuntimeArray:
313 case spv::Op::OpTypeMatrix:
314 case spv::Op::OpTypeVector:
315 id = type_inst->GetSingleWordInOperand(0);
316 break;
317 case spv::Op::OpTypeStruct: {
318 const analysis::Constant* index_const =
319 context()->get_constant_mgr()->FindDeclaredConstant(
320 inst->GetSingleWordInOperand(i));
321 // It is highly unlikely that any type would have more fields than could
322 // be indexed by a 32-bit integer, and GetSingleWordInOperand only takes
323 // a 32-bit value, so we would not be able to handle it anyway. But the
324 // specification does allow any scalar integer type, treated as signed,
325 // so we simply downcast the index to 32-bits.
326 uint32_t index =
327 static_cast<uint32_t>(index_const->GetSignExtendedValue());
328 id = type_inst->GetSingleWordInOperand(index);
329 break;
330 }
331 default:
332 break;
333 }
334 assert(id != 0 &&
335 "Tried to extract from an object where it cannot be done.");
336 }
337
338 return context()->get_type_mgr()->FindPointerToType(
339 id, static_cast<spv::StorageClass>(
340 orig_type_inst->GetSingleWordInOperand(0)));
341 }
342
343 // namespace opt
344
345 } // namespace opt
346 } // namespace spvtools
347