• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2019 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/opt/desc_sroa.h"
16 
17 #include "source/util/string_utils.h"
18 
19 namespace spvtools {
20 namespace opt {
21 
Process()22 Pass::Status DescriptorScalarReplacement::Process() {
23   bool modified = false;
24 
25   std::vector<Instruction*> vars_to_kill;
26 
27   for (Instruction& var : context()->types_values()) {
28     if (IsCandidate(&var)) {
29       modified = true;
30       if (!ReplaceCandidate(&var)) {
31         return Status::Failure;
32       }
33       vars_to_kill.push_back(&var);
34     }
35   }
36 
37   for (Instruction* var : vars_to_kill) {
38     context()->KillInst(var);
39   }
40 
41   return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
42 }
43 
IsCandidate(Instruction * var)44 bool DescriptorScalarReplacement::IsCandidate(Instruction* var) {
45   if (var->opcode() != SpvOpVariable) {
46     return false;
47   }
48 
49   uint32_t ptr_type_id = var->type_id();
50   Instruction* ptr_type_inst =
51       context()->get_def_use_mgr()->GetDef(ptr_type_id);
52   if (ptr_type_inst->opcode() != SpvOpTypePointer) {
53     return false;
54   }
55 
56   uint32_t var_type_id = ptr_type_inst->GetSingleWordInOperand(1);
57   Instruction* var_type_inst =
58       context()->get_def_use_mgr()->GetDef(var_type_id);
59   if (var_type_inst->opcode() != SpvOpTypeArray &&
60       var_type_inst->opcode() != SpvOpTypeStruct) {
61     return false;
62   }
63 
64   // All structures with descriptor assignments must be replaced by variables,
65   // one for each of their members - with the exceptions of buffers.
66   if (IsTypeOfStructuredBuffer(var_type_inst)) {
67     return false;
68   }
69 
70   bool has_desc_set_decoration = false;
71   context()->get_decoration_mgr()->ForEachDecoration(
72       var->result_id(), SpvDecorationDescriptorSet,
73       [&has_desc_set_decoration](const Instruction&) {
74         has_desc_set_decoration = true;
75       });
76   if (!has_desc_set_decoration) {
77     return false;
78   }
79 
80   bool has_binding_decoration = false;
81   context()->get_decoration_mgr()->ForEachDecoration(
82       var->result_id(), SpvDecorationBinding,
83       [&has_binding_decoration](const Instruction&) {
84         has_binding_decoration = true;
85       });
86   if (!has_binding_decoration) {
87     return false;
88   }
89 
90   return true;
91 }
92 
IsTypeOfStructuredBuffer(const Instruction * type) const93 bool DescriptorScalarReplacement::IsTypeOfStructuredBuffer(
94     const Instruction* type) const {
95   if (type->opcode() != SpvOpTypeStruct) {
96     return false;
97   }
98 
99   // All buffers have offset decorations for members of their structure types.
100   // This is how we distinguish it from a structure of descriptors.
101   bool has_offset_decoration = false;
102   context()->get_decoration_mgr()->ForEachDecoration(
103       type->result_id(), SpvDecorationOffset,
104       [&has_offset_decoration](const Instruction&) {
105         has_offset_decoration = true;
106       });
107   return has_offset_decoration;
108 }
109 
ReplaceCandidate(Instruction * var)110 bool DescriptorScalarReplacement::ReplaceCandidate(Instruction* var) {
111   std::vector<Instruction*> access_chain_work_list;
112   std::vector<Instruction*> load_work_list;
113   bool failed = !get_def_use_mgr()->WhileEachUser(
114       var->result_id(),
115       [this, &access_chain_work_list, &load_work_list](Instruction* use) {
116         if (use->opcode() == SpvOpName) {
117           return true;
118         }
119 
120         if (use->IsDecoration()) {
121           return true;
122         }
123 
124         switch (use->opcode()) {
125           case SpvOpAccessChain:
126           case SpvOpInBoundsAccessChain:
127             access_chain_work_list.push_back(use);
128             return true;
129           case SpvOpLoad:
130             load_work_list.push_back(use);
131             return true;
132           default:
133             context()->EmitErrorMessage(
134                 "Variable cannot be replaced: invalid instruction", use);
135             return false;
136         }
137         return true;
138       });
139 
140   if (failed) {
141     return false;
142   }
143 
144   for (Instruction* use : access_chain_work_list) {
145     if (!ReplaceAccessChain(var, use)) {
146       return false;
147     }
148   }
149   for (Instruction* use : load_work_list) {
150     if (!ReplaceLoadedValue(var, use)) {
151       return false;
152     }
153   }
154   return true;
155 }
156 
ReplaceAccessChain(Instruction * var,Instruction * use)157 bool DescriptorScalarReplacement::ReplaceAccessChain(Instruction* var,
158                                                      Instruction* use) {
159   if (use->NumInOperands() <= 1) {
160     context()->EmitErrorMessage(
161         "Variable cannot be replaced: invalid instruction", use);
162     return false;
163   }
164 
165   uint32_t idx_id = use->GetSingleWordInOperand(1);
166   const analysis::Constant* idx_const =
167       context()->get_constant_mgr()->FindDeclaredConstant(idx_id);
168   if (idx_const == nullptr) {
169     context()->EmitErrorMessage("Variable cannot be replaced: invalid index",
170                                 use);
171     return false;
172   }
173 
174   uint32_t idx = idx_const->GetU32();
175   uint32_t replacement_var = GetReplacementVariable(var, idx);
176 
177   if (use->NumInOperands() == 2) {
178     // We are not indexing into the replacement variable.  We can replaces the
179     // access chain with the replacement varibale itself.
180     context()->ReplaceAllUsesWith(use->result_id(), replacement_var);
181     context()->KillInst(use);
182     return true;
183   }
184 
185   // We need to build a new access chain with the replacement variable as the
186   // base address.
187   Instruction::OperandList new_operands;
188 
189   // Same result id and result type.
190   new_operands.emplace_back(use->GetOperand(0));
191   new_operands.emplace_back(use->GetOperand(1));
192 
193   // Use the replacement variable as the base address.
194   new_operands.push_back({SPV_OPERAND_TYPE_ID, {replacement_var}});
195 
196   // Drop the first index because it is consumed by the replacment, and copy the
197   // rest.
198   for (uint32_t i = 4; i < use->NumOperands(); i++) {
199     new_operands.emplace_back(use->GetOperand(i));
200   }
201 
202   use->ReplaceOperands(new_operands);
203   context()->UpdateDefUse(use);
204   return true;
205 }
206 
GetReplacementVariable(Instruction * var,uint32_t idx)207 uint32_t DescriptorScalarReplacement::GetReplacementVariable(Instruction* var,
208                                                              uint32_t idx) {
209   auto replacement_vars = replacement_variables_.find(var);
210   if (replacement_vars == replacement_variables_.end()) {
211     uint32_t ptr_type_id = var->type_id();
212     Instruction* ptr_type_inst = get_def_use_mgr()->GetDef(ptr_type_id);
213     assert(ptr_type_inst->opcode() == SpvOpTypePointer &&
214            "Variable should be a pointer to an array or structure.");
215     uint32_t pointee_type_id = ptr_type_inst->GetSingleWordInOperand(1);
216     Instruction* pointee_type_inst = get_def_use_mgr()->GetDef(pointee_type_id);
217     const bool is_array = pointee_type_inst->opcode() == SpvOpTypeArray;
218     const bool is_struct = pointee_type_inst->opcode() == SpvOpTypeStruct;
219     assert((is_array || is_struct) &&
220            "Variable should be a pointer to an array or structure.");
221 
222     // For arrays, each array element should be replaced with a new replacement
223     // variable
224     if (is_array) {
225       uint32_t array_len_id = pointee_type_inst->GetSingleWordInOperand(1);
226       const analysis::Constant* array_len_const =
227           context()->get_constant_mgr()->FindDeclaredConstant(array_len_id);
228       assert(array_len_const != nullptr && "Array length must be a constant.");
229       uint32_t array_len = array_len_const->GetU32();
230 
231       replacement_vars = replacement_variables_
232                              .insert({var, std::vector<uint32_t>(array_len, 0)})
233                              .first;
234     }
235     // For structures, each member should be replaced with a new replacement
236     // variable
237     if (is_struct) {
238       const uint32_t num_members = pointee_type_inst->NumInOperands();
239       replacement_vars =
240           replacement_variables_
241               .insert({var, std::vector<uint32_t>(num_members, 0)})
242               .first;
243     }
244   }
245 
246   if (replacement_vars->second[idx] == 0) {
247     replacement_vars->second[idx] = CreateReplacementVariable(var, idx);
248   }
249 
250   return replacement_vars->second[idx];
251 }
252 
CreateReplacementVariable(Instruction * var,uint32_t idx)253 uint32_t DescriptorScalarReplacement::CreateReplacementVariable(
254     Instruction* var, uint32_t idx) {
255   // The storage class for the new variable is the same as the original.
256   SpvStorageClass storage_class =
257       static_cast<SpvStorageClass>(var->GetSingleWordInOperand(0));
258 
259   // The type for the new variable will be a pointer to type of the elements of
260   // the array.
261   uint32_t ptr_type_id = var->type_id();
262   Instruction* ptr_type_inst = get_def_use_mgr()->GetDef(ptr_type_id);
263   assert(ptr_type_inst->opcode() == SpvOpTypePointer &&
264          "Variable should be a pointer to an array or structure.");
265   uint32_t pointee_type_id = ptr_type_inst->GetSingleWordInOperand(1);
266   Instruction* pointee_type_inst = get_def_use_mgr()->GetDef(pointee_type_id);
267   const bool is_array = pointee_type_inst->opcode() == SpvOpTypeArray;
268   const bool is_struct = pointee_type_inst->opcode() == SpvOpTypeStruct;
269   assert((is_array || is_struct) &&
270          "Variable should be a pointer to an array or structure.");
271 
272   uint32_t element_type_id =
273       is_array ? pointee_type_inst->GetSingleWordInOperand(0)
274                : pointee_type_inst->GetSingleWordInOperand(idx);
275 
276   uint32_t ptr_element_type_id = context()->get_type_mgr()->FindPointerToType(
277       element_type_id, storage_class);
278 
279   // Create the variable.
280   uint32_t id = TakeNextId();
281   std::unique_ptr<Instruction> variable(
282       new Instruction(context(), SpvOpVariable, ptr_element_type_id, id,
283                       std::initializer_list<Operand>{
284                           {SPV_OPERAND_TYPE_STORAGE_CLASS,
285                            {static_cast<uint32_t>(storage_class)}}}));
286   context()->AddGlobalValue(std::move(variable));
287 
288   // Copy all of the decorations to the new variable.  The only difference is
289   // the Binding decoration needs to be adjusted.
290   for (auto old_decoration :
291        get_decoration_mgr()->GetDecorationsFor(var->result_id(), true)) {
292     assert(old_decoration->opcode() == SpvOpDecorate);
293     std::unique_ptr<Instruction> new_decoration(
294         old_decoration->Clone(context()));
295     new_decoration->SetInOperand(0, {id});
296 
297     uint32_t decoration = new_decoration->GetSingleWordInOperand(1u);
298     if (decoration == SpvDecorationBinding) {
299       uint32_t new_binding = new_decoration->GetSingleWordInOperand(2);
300       if (is_array) {
301         new_binding += idx * GetNumBindingsUsedByType(ptr_element_type_id);
302       }
303       if (is_struct) {
304         // The binding offset that should be added is the sum of binding numbers
305         // used by previous members of the current struct.
306         for (uint32_t i = 0; i < idx; ++i) {
307           new_binding += GetNumBindingsUsedByType(
308               pointee_type_inst->GetSingleWordInOperand(i));
309         }
310       }
311       new_decoration->SetInOperand(2, {new_binding});
312     }
313     context()->AddAnnotationInst(std::move(new_decoration));
314   }
315 
316   // Create a new OpName for the replacement variable.
317   std::vector<std::unique_ptr<Instruction>> names_to_add;
318   for (auto p : context()->GetNames(var->result_id())) {
319     Instruction* name_inst = p.second;
320     std::string name_str = utils::MakeString(name_inst->GetOperand(1).words);
321     if (is_array) {
322       name_str += "[" + utils::ToString(idx) + "]";
323     }
324     if (is_struct) {
325       Instruction* member_name_inst =
326           context()->GetMemberName(pointee_type_inst->result_id(), idx);
327       name_str += ".";
328       if (member_name_inst)
329         name_str += utils::MakeString(member_name_inst->GetOperand(2).words);
330       else
331         // In case the member does not have a name assigned to it, use the
332         // member index.
333         name_str += utils::ToString(idx);
334     }
335 
336     std::unique_ptr<Instruction> new_name(new Instruction(
337         context(), SpvOpName, 0, 0,
338         std::initializer_list<Operand>{
339             {SPV_OPERAND_TYPE_ID, {id}},
340             {SPV_OPERAND_TYPE_LITERAL_STRING, utils::MakeVector(name_str)}}));
341     Instruction* new_name_inst = new_name.get();
342     get_def_use_mgr()->AnalyzeInstDefUse(new_name_inst);
343     names_to_add.push_back(std::move(new_name));
344   }
345 
346   // We shouldn't add the new names when we are iterating over name ranges
347   // above. We can add all the new names now.
348   for (auto& new_name : names_to_add)
349     context()->AddDebug2Inst(std::move(new_name));
350 
351   return id;
352 }
353 
GetNumBindingsUsedByType(uint32_t type_id)354 uint32_t DescriptorScalarReplacement::GetNumBindingsUsedByType(
355     uint32_t type_id) {
356   Instruction* type_inst = get_def_use_mgr()->GetDef(type_id);
357 
358   // If it's a pointer, look at the underlying type.
359   if (type_inst->opcode() == SpvOpTypePointer) {
360     type_id = type_inst->GetSingleWordInOperand(1);
361     type_inst = get_def_use_mgr()->GetDef(type_id);
362   }
363 
364   // Arrays consume N*M binding numbers where N is the array length, and M is
365   // the number of bindings used by each array element.
366   if (type_inst->opcode() == SpvOpTypeArray) {
367     uint32_t element_type_id = type_inst->GetSingleWordInOperand(0);
368     uint32_t length_id = type_inst->GetSingleWordInOperand(1);
369     const analysis::Constant* length_const =
370         context()->get_constant_mgr()->FindDeclaredConstant(length_id);
371     // OpTypeArray's length must always be a constant
372     assert(length_const != nullptr);
373     uint32_t num_elems = length_const->GetU32();
374     return num_elems * GetNumBindingsUsedByType(element_type_id);
375   }
376 
377   // The number of bindings consumed by a structure is the sum of the bindings
378   // used by its members.
379   if (type_inst->opcode() == SpvOpTypeStruct &&
380       !IsTypeOfStructuredBuffer(type_inst)) {
381     uint32_t sum = 0;
382     for (uint32_t i = 0; i < type_inst->NumInOperands(); i++)
383       sum += GetNumBindingsUsedByType(type_inst->GetSingleWordInOperand(i));
384     return sum;
385   }
386 
387   // All other types are considered to take up 1 binding number.
388   return 1;
389 }
390 
ReplaceLoadedValue(Instruction * var,Instruction * value)391 bool DescriptorScalarReplacement::ReplaceLoadedValue(Instruction* var,
392                                                      Instruction* value) {
393   // |var| is the global variable that has to be eliminated (OpVariable).
394   // |value| is the OpLoad instruction that has loaded |var|.
395   // The function expects all users of |value| to be OpCompositeExtract
396   // instructions. Otherwise the function returns false with an error message.
397   assert(value->opcode() == SpvOpLoad);
398   assert(value->GetSingleWordInOperand(0) == var->result_id());
399   std::vector<Instruction*> work_list;
400   bool failed = !get_def_use_mgr()->WhileEachUser(
401       value->result_id(), [this, &work_list](Instruction* use) {
402         if (use->opcode() != SpvOpCompositeExtract) {
403           context()->EmitErrorMessage(
404               "Variable cannot be replaced: invalid instruction", use);
405           return false;
406         }
407         work_list.push_back(use);
408         return true;
409       });
410 
411   if (failed) {
412     return false;
413   }
414 
415   for (Instruction* use : work_list) {
416     if (!ReplaceCompositeExtract(var, use)) {
417       return false;
418     }
419   }
420 
421   // All usages of the loaded value have been killed. We can kill the OpLoad.
422   context()->KillInst(value);
423   return true;
424 }
425 
ReplaceCompositeExtract(Instruction * var,Instruction * extract)426 bool DescriptorScalarReplacement::ReplaceCompositeExtract(
427     Instruction* var, Instruction* extract) {
428   assert(extract->opcode() == SpvOpCompositeExtract);
429   // We're currently only supporting extractions of one index at a time. If we
430   // need to, we can handle cases with multiple indexes in the future.
431   if (extract->NumInOperands() != 2) {
432     context()->EmitErrorMessage(
433         "Variable cannot be replaced: invalid instruction", extract);
434     return false;
435   }
436 
437   uint32_t replacement_var =
438       GetReplacementVariable(var, extract->GetSingleWordInOperand(1));
439 
440   // The result type of the OpLoad is the same as the result type of the
441   // OpCompositeExtract.
442   uint32_t load_id = TakeNextId();
443   std::unique_ptr<Instruction> load(
444       new Instruction(context(), SpvOpLoad, extract->type_id(), load_id,
445                       std::initializer_list<Operand>{
446                           {SPV_OPERAND_TYPE_ID, {replacement_var}}}));
447   Instruction* load_instr = load.get();
448   get_def_use_mgr()->AnalyzeInstDefUse(load_instr);
449   context()->set_instr_block(load_instr, context()->get_instr_block(extract));
450   extract->InsertBefore(std::move(load));
451   context()->ReplaceAllUsesWith(extract->result_id(), load_id);
452   context()->KillInst(extract);
453   return true;
454 }
455 
456 }  // namespace opt
457 }  // namespace spvtools
458