• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2019 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "fix_storage_class.h"
16 
17 #include <set>
18 
19 #include "source/opt/instruction.h"
20 #include "source/opt/ir_context.h"
21 
22 namespace spvtools {
23 namespace opt {
24 
Process()25 Pass::Status FixStorageClass::Process() {
26   bool modified = false;
27 
28   get_module()->ForEachInst([this, &modified](Instruction* inst) {
29     if (inst->opcode() == spv::Op::OpVariable) {
30       std::set<uint32_t> seen;
31       std::vector<std::pair<Instruction*, uint32_t>> uses;
32       get_def_use_mgr()->ForEachUse(inst,
33                                     [&uses](Instruction* use, uint32_t op_idx) {
34                                       uses.push_back({use, op_idx});
35                                     });
36 
37       for (auto& use : uses) {
38         modified |= PropagateStorageClass(
39             use.first,
40             static_cast<spv::StorageClass>(inst->GetSingleWordInOperand(0)),
41             &seen);
42         assert(seen.empty() && "Seen was not properly reset.");
43         modified |=
44             PropagateType(use.first, inst->type_id(), use.second, &seen);
45         assert(seen.empty() && "Seen was not properly reset.");
46       }
47     }
48   });
49   return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
50 }
51 
PropagateStorageClass(Instruction * inst,spv::StorageClass storage_class,std::set<uint32_t> * seen)52 bool FixStorageClass::PropagateStorageClass(Instruction* inst,
53                                             spv::StorageClass storage_class,
54                                             std::set<uint32_t>* seen) {
55   if (!IsPointerResultType(inst)) {
56     return false;
57   }
58 
59   if (IsPointerToStorageClass(inst, storage_class)) {
60     if (inst->opcode() == spv::Op::OpPhi) {
61       if (!seen->insert(inst->result_id()).second) {
62         return false;
63       }
64     }
65 
66     bool modified = false;
67     std::vector<Instruction*> uses;
68     get_def_use_mgr()->ForEachUser(
69         inst, [&uses](Instruction* use) { uses.push_back(use); });
70     for (Instruction* use : uses) {
71       modified |= PropagateStorageClass(use, storage_class, seen);
72     }
73 
74     if (inst->opcode() == spv::Op::OpPhi) {
75       seen->erase(inst->result_id());
76     }
77     return modified;
78   }
79 
80   switch (inst->opcode()) {
81     case spv::Op::OpAccessChain:
82     case spv::Op::OpPtrAccessChain:
83     case spv::Op::OpInBoundsAccessChain:
84     case spv::Op::OpCopyObject:
85     case spv::Op::OpPhi:
86     case spv::Op::OpSelect:
87       FixInstructionStorageClass(inst, storage_class, seen);
88       return true;
89     case spv::Op::OpFunctionCall:
90       // We cannot be sure of the actual connection between the storage class
91       // of the parameter and the storage class of the result, so we should not
92       // do anything.  If the result type needs to be fixed, the function call
93       // should be inlined.
94       return false;
95     case spv::Op::OpImageTexelPointer:
96     case spv::Op::OpLoad:
97     case spv::Op::OpStore:
98     case spv::Op::OpCopyMemory:
99     case spv::Op::OpCopyMemorySized:
100     case spv::Op::OpVariable:
101     case spv::Op::OpBitcast:
102       // Nothing to change for these opcode.  The result type is the same
103       // regardless of the storage class of the operand.
104       return false;
105     default:
106       assert(false &&
107              "Not expecting instruction to have a pointer result type.");
108       return false;
109   }
110 }
111 
FixInstructionStorageClass(Instruction * inst,spv::StorageClass storage_class,std::set<uint32_t> * seen)112 void FixStorageClass::FixInstructionStorageClass(
113     Instruction* inst, spv::StorageClass storage_class,
114     std::set<uint32_t>* seen) {
115   assert(IsPointerResultType(inst) &&
116          "The result type of the instruction must be a pointer.");
117 
118   ChangeResultStorageClass(inst, storage_class);
119 
120   std::vector<Instruction*> uses;
121   get_def_use_mgr()->ForEachUser(
122       inst, [&uses](Instruction* use) { uses.push_back(use); });
123   for (Instruction* use : uses) {
124     PropagateStorageClass(use, storage_class, seen);
125   }
126 }
127 
ChangeResultStorageClass(Instruction * inst,spv::StorageClass storage_class) const128 void FixStorageClass::ChangeResultStorageClass(
129     Instruction* inst, spv::StorageClass storage_class) const {
130   analysis::TypeManager* type_mgr = context()->get_type_mgr();
131   Instruction* result_type_inst = get_def_use_mgr()->GetDef(inst->type_id());
132   assert(result_type_inst->opcode() == spv::Op::OpTypePointer);
133   uint32_t pointee_type_id = result_type_inst->GetSingleWordInOperand(1);
134   uint32_t new_result_type_id =
135       type_mgr->FindPointerToType(pointee_type_id, storage_class);
136   inst->SetResultType(new_result_type_id);
137   context()->UpdateDefUse(inst);
138 }
139 
IsPointerResultType(Instruction * inst)140 bool FixStorageClass::IsPointerResultType(Instruction* inst) {
141   if (inst->type_id() == 0) {
142     return false;
143   }
144   const analysis::Type* ret_type =
145       context()->get_type_mgr()->GetType(inst->type_id());
146   return ret_type->AsPointer() != nullptr;
147 }
148 
IsPointerToStorageClass(Instruction * inst,spv::StorageClass storage_class)149 bool FixStorageClass::IsPointerToStorageClass(Instruction* inst,
150                                               spv::StorageClass storage_class) {
151   analysis::TypeManager* type_mgr = context()->get_type_mgr();
152   analysis::Type* pType = type_mgr->GetType(inst->type_id());
153   const analysis::Pointer* result_type = pType->AsPointer();
154 
155   if (result_type == nullptr) {
156     return false;
157   }
158 
159   return (result_type->storage_class() == storage_class);
160 }
161 
ChangeResultType(Instruction * inst,uint32_t new_type_id)162 bool FixStorageClass::ChangeResultType(Instruction* inst,
163                                        uint32_t new_type_id) {
164   if (inst->type_id() == new_type_id) {
165     return false;
166   }
167 
168   context()->ForgetUses(inst);
169   inst->SetResultType(new_type_id);
170   context()->AnalyzeUses(inst);
171   return true;
172 }
173 
PropagateType(Instruction * inst,uint32_t type_id,uint32_t op_idx,std::set<uint32_t> * seen)174 bool FixStorageClass::PropagateType(Instruction* inst, uint32_t type_id,
175                                     uint32_t op_idx, std::set<uint32_t>* seen) {
176   assert(type_id != 0 && "Not given a valid type in PropagateType");
177   bool modified = false;
178 
179   // If the type of operand |op_idx| forces the result type of |inst| to a
180   // particular type, then we want find that type.
181   uint32_t new_type_id = 0;
182   switch (inst->opcode()) {
183     case spv::Op::OpAccessChain:
184     case spv::Op::OpPtrAccessChain:
185     case spv::Op::OpInBoundsAccessChain:
186     case spv::Op::OpInBoundsPtrAccessChain:
187       if (op_idx == 2) {
188         new_type_id = WalkAccessChainType(inst, type_id);
189       }
190       break;
191     case spv::Op::OpCopyObject:
192       new_type_id = type_id;
193       break;
194     case spv::Op::OpPhi:
195       if (seen->insert(inst->result_id()).second) {
196         new_type_id = type_id;
197       }
198       break;
199     case spv::Op::OpSelect:
200       if (op_idx > 2) {
201         new_type_id = type_id;
202       }
203       break;
204     case spv::Op::OpFunctionCall:
205       // We cannot be sure of the actual connection between the type
206       // of the parameter and the type of the result, so we should not
207       // do anything.  If the result type needs to be fixed, the function call
208       // should be inlined.
209       return false;
210     case spv::Op::OpLoad: {
211       Instruction* type_inst = get_def_use_mgr()->GetDef(type_id);
212       new_type_id = type_inst->GetSingleWordInOperand(1);
213       break;
214     }
215     case spv::Op::OpStore: {
216       uint32_t obj_id = inst->GetSingleWordInOperand(1);
217       Instruction* obj_inst = get_def_use_mgr()->GetDef(obj_id);
218       uint32_t obj_type_id = obj_inst->type_id();
219 
220       uint32_t ptr_id = inst->GetSingleWordInOperand(0);
221       Instruction* ptr_inst = get_def_use_mgr()->GetDef(ptr_id);
222       uint32_t pointee_type_id = GetPointeeTypeId(ptr_inst);
223 
224       if (obj_type_id != pointee_type_id) {
225         if (context()->get_type_mgr()->GetType(obj_type_id)->AsImage() &&
226             context()->get_type_mgr()->GetType(pointee_type_id)->AsImage()) {
227           // When storing an image, allow the type mismatch
228           // and let the later legalization passes eliminate the OpStore.
229           // This is to support assigning an image to a variable,
230           // where the assigned image does not have a pre-defined
231           // image format.
232           return false;
233         }
234 
235         uint32_t copy_id = GenerateCopy(obj_inst, pointee_type_id, inst);
236         inst->SetInOperand(1, {copy_id});
237         context()->UpdateDefUse(inst);
238       }
239     } break;
240     case spv::Op::OpCopyMemory:
241     case spv::Op::OpCopyMemorySized:
242       // TODO: May need to expand the copy as we do with the stores.
243       break;
244     case spv::Op::OpCompositeConstruct:
245     case spv::Op::OpCompositeExtract:
246     case spv::Op::OpCompositeInsert:
247       // TODO: DXC does not seem to generate code that will require changes to
248       // these opcode.  The can be implemented when they come up.
249       break;
250     case spv::Op::OpImageTexelPointer:
251     case spv::Op::OpBitcast:
252       // Nothing to change for these opcode.  The result type is the same
253       // regardless of the type of the operand.
254       return false;
255     default:
256       // I expect the remaining instructions to act on types that are guaranteed
257       // to be unique, so no change will be necessary.
258       break;
259   }
260 
261   // If the operand forces the result type, then make sure the result type
262   // matches, and update the uses of |inst|.  We do not have to check the uses
263   // of |inst| in the result type is not forced because we are only looking for
264   // issue that come from mismatches between function formal and actual
265   // parameters after the function has been inlined.  These parameters are
266   // pointers. Once the type no longer depends on the type of the parameter,
267   // then the types should have be correct.
268   if (new_type_id != 0) {
269     modified = ChangeResultType(inst, new_type_id);
270 
271     std::vector<std::pair<Instruction*, uint32_t>> uses;
272     get_def_use_mgr()->ForEachUse(inst,
273                                   [&uses](Instruction* use, uint32_t idx) {
274                                     uses.push_back({use, idx});
275                                   });
276 
277     for (auto& use : uses) {
278       PropagateType(use.first, new_type_id, use.second, seen);
279     }
280 
281     if (inst->opcode() == spv::Op::OpPhi) {
282       seen->erase(inst->result_id());
283     }
284   }
285   return modified;
286 }
287 
WalkAccessChainType(Instruction * inst,uint32_t id)288 uint32_t FixStorageClass::WalkAccessChainType(Instruction* inst, uint32_t id) {
289   uint32_t start_idx = 0;
290   switch (inst->opcode()) {
291     case spv::Op::OpAccessChain:
292     case spv::Op::OpInBoundsAccessChain:
293       start_idx = 1;
294       break;
295     case spv::Op::OpPtrAccessChain:
296     case spv::Op::OpInBoundsPtrAccessChain:
297       start_idx = 2;
298       break;
299     default:
300       assert(false);
301       break;
302   }
303 
304   Instruction* orig_type_inst = get_def_use_mgr()->GetDef(id);
305   assert(orig_type_inst->opcode() == spv::Op::OpTypePointer);
306   id = orig_type_inst->GetSingleWordInOperand(1);
307 
308   for (uint32_t i = start_idx; i < inst->NumInOperands(); ++i) {
309     Instruction* type_inst = get_def_use_mgr()->GetDef(id);
310     switch (type_inst->opcode()) {
311       case spv::Op::OpTypeArray:
312       case spv::Op::OpTypeRuntimeArray:
313       case spv::Op::OpTypeMatrix:
314       case spv::Op::OpTypeVector:
315         id = type_inst->GetSingleWordInOperand(0);
316         break;
317       case spv::Op::OpTypeStruct: {
318         const analysis::Constant* index_const =
319             context()->get_constant_mgr()->FindDeclaredConstant(
320                 inst->GetSingleWordInOperand(i));
321         // It is highly unlikely that any type would have more fields than could
322         // be indexed by a 32-bit integer, and GetSingleWordInOperand only takes
323         // a 32-bit value, so we would not be able to handle it anyway. But the
324         // specification does allow any scalar integer type, treated as signed,
325         // so we simply downcast the index to 32-bits.
326         uint32_t index =
327             static_cast<uint32_t>(index_const->GetSignExtendedValue());
328         id = type_inst->GetSingleWordInOperand(index);
329         break;
330       }
331       default:
332         break;
333     }
334     assert(id != 0 &&
335            "Tried to extract from an object where it cannot be done.");
336   }
337 
338   return context()->get_type_mgr()->FindPointerToType(
339       id, static_cast<spv::StorageClass>(
340               orig_type_inst->GetSingleWordInOperand(0)));
341 }
342 
343 // namespace opt
344 
345 }  // namespace opt
346 }  // namespace spvtools
347