1 // Copyright (c) 2022 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/opt/interface_var_sroa.h"
16
17 #include <iostream>
18
19 #include "source/opt/decoration_manager.h"
20 #include "source/opt/def_use_manager.h"
21 #include "source/opt/function.h"
22 #include "source/opt/log.h"
23 #include "source/opt/type_manager.h"
24 #include "source/util/make_unique.h"
25
// In-operand indices (i.e. excluding result type / result id) used to read
// fields of the SPIR-V instructions handled by this pass.

// OpDecorate: <target> <decoration> <literal...>
static constexpr uint32_t kOpDecorateDecorationInOperandIndex = 1;
static constexpr uint32_t kOpDecorateLiteralInOperandIndex = 2;
// OpEntryPoint: <model> <function> <name> <interface ids...>
static constexpr uint32_t kOpEntryPointInOperandInterface = 3;
// OpVariable: <storage class> [initializer]
static constexpr uint32_t kOpVariableStorageClassInOperandIndex = 0;
// OpTypeArray: <element type> <length id>
static constexpr uint32_t kOpTypeArrayElemTypeInOperandIndex = 0;
static constexpr uint32_t kOpTypeArrayLengthInOperandIndex = 1;
// OpTypeMatrix: <column type> <column count>
static constexpr uint32_t kOpTypeMatrixColCountInOperandIndex = 1;
static constexpr uint32_t kOpTypeMatrixColTypeInOperandIndex = 0;
// OpTypePointer: <storage class> <pointee type>
static constexpr uint32_t kOpTypePtrTypeInOperandIndex = 1;
// OpConstant: <value words...>
static constexpr uint32_t kOpConstantValueInOperandIndex = 0;
36
37 namespace spvtools {
38 namespace opt {
39 namespace {
40
41 // Get the length of the OpTypeArray |array_type|.
GetArrayLength(analysis::DefUseManager * def_use_mgr,Instruction * array_type)42 uint32_t GetArrayLength(analysis::DefUseManager* def_use_mgr,
43 Instruction* array_type) {
44 assert(array_type->opcode() == SpvOpTypeArray);
45 uint32_t const_int_id =
46 array_type->GetSingleWordInOperand(kOpTypeArrayLengthInOperandIndex);
47 Instruction* array_length_inst = def_use_mgr->GetDef(const_int_id);
48 assert(array_length_inst->opcode() == SpvOpConstant);
49 return array_length_inst->GetSingleWordInOperand(
50 kOpConstantValueInOperandIndex);
51 }
52
53 // Get the element type instruction of the OpTypeArray |array_type|.
GetArrayElementType(analysis::DefUseManager * def_use_mgr,Instruction * array_type)54 Instruction* GetArrayElementType(analysis::DefUseManager* def_use_mgr,
55 Instruction* array_type) {
56 assert(array_type->opcode() == SpvOpTypeArray);
57 uint32_t elem_type_id =
58 array_type->GetSingleWordInOperand(kOpTypeArrayElemTypeInOperandIndex);
59 return def_use_mgr->GetDef(elem_type_id);
60 }
61
62 // Get the column type instruction of the OpTypeMatrix |matrix_type|.
GetMatrixColumnType(analysis::DefUseManager * def_use_mgr,Instruction * matrix_type)63 Instruction* GetMatrixColumnType(analysis::DefUseManager* def_use_mgr,
64 Instruction* matrix_type) {
65 assert(matrix_type->opcode() == SpvOpTypeMatrix);
66 uint32_t column_type_id =
67 matrix_type->GetSingleWordInOperand(kOpTypeMatrixColTypeInOperandIndex);
68 return def_use_mgr->GetDef(column_type_id);
69 }
70
71 // Traverses the component type of OpTypeArray or OpTypeMatrix. Repeats it
72 // |depth_to_component| times recursively and returns the component type.
73 // |type_id| is the result id of the OpTypeArray or OpTypeMatrix instruction.
GetComponentTypeOfArrayMatrix(analysis::DefUseManager * def_use_mgr,uint32_t type_id,uint32_t depth_to_component)74 uint32_t GetComponentTypeOfArrayMatrix(analysis::DefUseManager* def_use_mgr,
75 uint32_t type_id,
76 uint32_t depth_to_component) {
77 if (depth_to_component == 0) return type_id;
78
79 Instruction* type_inst = def_use_mgr->GetDef(type_id);
80 if (type_inst->opcode() == SpvOpTypeArray) {
81 uint32_t elem_type_id =
82 type_inst->GetSingleWordInOperand(kOpTypeArrayElemTypeInOperandIndex);
83 return GetComponentTypeOfArrayMatrix(def_use_mgr, elem_type_id,
84 depth_to_component - 1);
85 }
86
87 assert(type_inst->opcode() == SpvOpTypeMatrix);
88 uint32_t column_type_id =
89 type_inst->GetSingleWordInOperand(kOpTypeMatrixColTypeInOperandIndex);
90 return GetComponentTypeOfArrayMatrix(def_use_mgr, column_type_id,
91 depth_to_component - 1);
92 }
93
94 // Creates an OpDecorate instruction whose Target is |var_id| and Decoration is
95 // |decoration|. Adds |literal| as an extra operand of the instruction.
CreateDecoration(analysis::DecorationManager * decoration_mgr,uint32_t var_id,SpvDecoration decoration,uint32_t literal)96 void CreateDecoration(analysis::DecorationManager* decoration_mgr,
97 uint32_t var_id, SpvDecoration decoration,
98 uint32_t literal) {
99 std::vector<Operand> operands({
100 {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {var_id}},
101 {spv_operand_type_t::SPV_OPERAND_TYPE_DECORATION,
102 {static_cast<uint32_t>(decoration)}},
103 {spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, {literal}},
104 });
105 decoration_mgr->AddDecoration(SpvOpDecorate, std::move(operands));
106 }
107
108 // Replaces load instructions with composite construct instructions in all the
109 // users of the loads. |loads_to_composites| is the mapping from each load to
110 // its corresponding OpCompositeConstruct.
ReplaceLoadWithCompositeConstruct(IRContext * context,const std::unordered_map<Instruction *,Instruction * > & loads_to_composites)111 void ReplaceLoadWithCompositeConstruct(
112 IRContext* context,
113 const std::unordered_map<Instruction*, Instruction*>& loads_to_composites) {
114 for (const auto& load_and_composite : loads_to_composites) {
115 Instruction* load = load_and_composite.first;
116 Instruction* composite_construct = load_and_composite.second;
117
118 std::vector<Instruction*> users;
119 context->get_def_use_mgr()->ForEachUse(
120 load, [&users, composite_construct](Instruction* user, uint32_t index) {
121 user->GetOperand(index).words[0] = composite_construct->result_id();
122 users.push_back(user);
123 });
124
125 for (Instruction* user : users)
126 context->get_def_use_mgr()->AnalyzeInstUse(user);
127 }
128 }
129
130 // Returns the storage class of the instruction |var|.
GetStorageClass(Instruction * var)131 SpvStorageClass GetStorageClass(Instruction* var) {
132 return static_cast<SpvStorageClass>(
133 var->GetSingleWordInOperand(kOpVariableStorageClassInOperandIndex));
134 }
135
136 } // namespace
137
HasExtraArrayness(Instruction & entry_point,Instruction * var)138 bool InterfaceVariableScalarReplacement::HasExtraArrayness(
139 Instruction& entry_point, Instruction* var) {
140 SpvExecutionModel execution_model =
141 static_cast<SpvExecutionModel>(entry_point.GetSingleWordInOperand(0));
142 if (execution_model != SpvExecutionModelTessellationEvaluation &&
143 execution_model != SpvExecutionModelTessellationControl) {
144 return false;
145 }
146 if (!context()->get_decoration_mgr()->HasDecoration(var->result_id(),
147 SpvDecorationPatch)) {
148 if (execution_model == SpvExecutionModelTessellationControl) return true;
149 return GetStorageClass(var) != SpvStorageClassOutput;
150 }
151 return false;
152 }
153
154 bool InterfaceVariableScalarReplacement::
CheckExtraArraynessConflictBetweenEntries(Instruction * interface_var,bool has_extra_arrayness)155 CheckExtraArraynessConflictBetweenEntries(Instruction* interface_var,
156 bool has_extra_arrayness) {
157 if (has_extra_arrayness) {
158 return !ReportErrorIfHasNoExtraArraynessForOtherEntry(interface_var);
159 }
160 return !ReportErrorIfHasExtraArraynessForOtherEntry(interface_var);
161 }
162
GetVariableLocation(Instruction * var,uint32_t * location)163 bool InterfaceVariableScalarReplacement::GetVariableLocation(
164 Instruction* var, uint32_t* location) {
165 return !context()->get_decoration_mgr()->WhileEachDecoration(
166 var->result_id(), SpvDecorationLocation,
167 [location](const Instruction& inst) {
168 *location =
169 inst.GetSingleWordInOperand(kOpDecorateLiteralInOperandIndex);
170 return false;
171 });
172 }
173
GetVariableComponent(Instruction * var,uint32_t * component)174 bool InterfaceVariableScalarReplacement::GetVariableComponent(
175 Instruction* var, uint32_t* component) {
176 return !context()->get_decoration_mgr()->WhileEachDecoration(
177 var->result_id(), SpvDecorationComponent,
178 [component](const Instruction& inst) {
179 *component =
180 inst.GetSingleWordInOperand(kOpDecorateLiteralInOperandIndex);
181 return false;
182 });
183 }
184
185 std::vector<Instruction*>
CollectInterfaceVariables(Instruction & entry_point)186 InterfaceVariableScalarReplacement::CollectInterfaceVariables(
187 Instruction& entry_point) {
188 std::vector<Instruction*> interface_vars;
189 for (uint32_t i = kOpEntryPointInOperandInterface;
190 i < entry_point.NumInOperands(); ++i) {
191 Instruction* interface_var = context()->get_def_use_mgr()->GetDef(
192 entry_point.GetSingleWordInOperand(i));
193 assert(interface_var->opcode() == SpvOpVariable);
194
195 SpvStorageClass storage_class = GetStorageClass(interface_var);
196 if (storage_class != SpvStorageClassInput &&
197 storage_class != SpvStorageClassOutput) {
198 continue;
199 }
200
201 interface_vars.push_back(interface_var);
202 }
203 return interface_vars;
204 }
205
KillInstructionAndUsers(Instruction * inst)206 void InterfaceVariableScalarReplacement::KillInstructionAndUsers(
207 Instruction* inst) {
208 if (inst->opcode() == SpvOpEntryPoint) {
209 return;
210 }
211 if (inst->opcode() != SpvOpAccessChain) {
212 context()->KillInst(inst);
213 return;
214 }
215 std::vector<Instruction*> users;
216 context()->get_def_use_mgr()->ForEachUser(
217 inst, [&users](Instruction* user) { users.push_back(user); });
218 for (auto user : users) {
219 context()->KillInst(user);
220 }
221 context()->KillInst(inst);
222 }
223
KillInstructionsAndUsers(const std::vector<Instruction * > & insts)224 void InterfaceVariableScalarReplacement::KillInstructionsAndUsers(
225 const std::vector<Instruction*>& insts) {
226 for (Instruction* inst : insts) {
227 KillInstructionAndUsers(inst);
228 }
229 }
230
KillLocationAndComponentDecorations(uint32_t var_id)231 void InterfaceVariableScalarReplacement::KillLocationAndComponentDecorations(
232 uint32_t var_id) {
233 context()->get_decoration_mgr()->RemoveDecorationsFrom(
234 var_id, [](const Instruction& inst) {
235 uint32_t decoration =
236 inst.GetSingleWordInOperand(kOpDecorateDecorationInOperandIndex);
237 return decoration == SpvDecorationLocation ||
238 decoration == SpvDecorationComponent;
239 });
240 }
241
ReplaceInterfaceVariableWithScalars(Instruction * interface_var,Instruction * interface_var_type,uint32_t location,uint32_t component,uint32_t extra_array_length)242 bool InterfaceVariableScalarReplacement::ReplaceInterfaceVariableWithScalars(
243 Instruction* interface_var, Instruction* interface_var_type,
244 uint32_t location, uint32_t component, uint32_t extra_array_length) {
245 NestedCompositeComponents scalar_interface_vars =
246 CreateScalarInterfaceVarsForReplacement(interface_var_type,
247 GetStorageClass(interface_var),
248 extra_array_length);
249
250 AddLocationAndComponentDecorations(scalar_interface_vars, &location,
251 component);
252 KillLocationAndComponentDecorations(interface_var->result_id());
253
254 if (!ReplaceInterfaceVarWith(interface_var, extra_array_length,
255 scalar_interface_vars)) {
256 return false;
257 }
258
259 context()->KillInst(interface_var);
260 return true;
261 }
262
ReplaceInterfaceVarWith(Instruction * interface_var,uint32_t extra_array_length,const NestedCompositeComponents & scalar_interface_vars)263 bool InterfaceVariableScalarReplacement::ReplaceInterfaceVarWith(
264 Instruction* interface_var, uint32_t extra_array_length,
265 const NestedCompositeComponents& scalar_interface_vars) {
266 std::vector<Instruction*> users;
267 context()->get_def_use_mgr()->ForEachUser(
268 interface_var, [&users](Instruction* user) { users.push_back(user); });
269
270 std::vector<uint32_t> interface_var_component_indices;
271 std::unordered_map<Instruction*, Instruction*> loads_to_composites;
272 std::unordered_map<Instruction*, Instruction*>
273 loads_for_access_chain_to_composites;
274 if (extra_array_length != 0) {
275 // Note that the extra arrayness is the first dimension of the array
276 // interface variable.
277 for (uint32_t index = 0; index < extra_array_length; ++index) {
278 std::unordered_map<Instruction*, Instruction*> loads_to_component_values;
279 if (!ReplaceComponentsOfInterfaceVarWith(
280 interface_var, users, scalar_interface_vars,
281 interface_var_component_indices, &index,
282 &loads_to_component_values,
283 &loads_for_access_chain_to_composites)) {
284 return false;
285 }
286 AddComponentsToCompositesForLoads(loads_to_component_values,
287 &loads_to_composites, 0);
288 }
289 } else if (!ReplaceComponentsOfInterfaceVarWith(
290 interface_var, users, scalar_interface_vars,
291 interface_var_component_indices, nullptr, &loads_to_composites,
292 &loads_for_access_chain_to_composites)) {
293 return false;
294 }
295
296 ReplaceLoadWithCompositeConstruct(context(), loads_to_composites);
297 ReplaceLoadWithCompositeConstruct(context(),
298 loads_for_access_chain_to_composites);
299
300 KillInstructionsAndUsers(users);
301 return true;
302 }
303
AddLocationAndComponentDecorations(const NestedCompositeComponents & vars,uint32_t * location,uint32_t component)304 void InterfaceVariableScalarReplacement::AddLocationAndComponentDecorations(
305 const NestedCompositeComponents& vars, uint32_t* location,
306 uint32_t component) {
307 if (!vars.HasMultipleComponents()) {
308 uint32_t var_id = vars.GetComponentVariable()->result_id();
309 CreateDecoration(context()->get_decoration_mgr(), var_id,
310 SpvDecorationLocation, *location);
311 CreateDecoration(context()->get_decoration_mgr(), var_id,
312 SpvDecorationComponent, component);
313 ++(*location);
314 return;
315 }
316 for (const auto& var : vars.GetComponents()) {
317 AddLocationAndComponentDecorations(var, location, component);
318 }
319 }
320
ReplaceComponentsOfInterfaceVarWith(Instruction * interface_var,const std::vector<Instruction * > & interface_var_users,const NestedCompositeComponents & scalar_interface_vars,std::vector<uint32_t> & interface_var_component_indices,const uint32_t * extra_array_index,std::unordered_map<Instruction *,Instruction * > * loads_to_composites,std::unordered_map<Instruction *,Instruction * > * loads_for_access_chain_to_composites)321 bool InterfaceVariableScalarReplacement::ReplaceComponentsOfInterfaceVarWith(
322 Instruction* interface_var,
323 const std::vector<Instruction*>& interface_var_users,
324 const NestedCompositeComponents& scalar_interface_vars,
325 std::vector<uint32_t>& interface_var_component_indices,
326 const uint32_t* extra_array_index,
327 std::unordered_map<Instruction*, Instruction*>* loads_to_composites,
328 std::unordered_map<Instruction*, Instruction*>*
329 loads_for_access_chain_to_composites) {
330 if (!scalar_interface_vars.HasMultipleComponents()) {
331 for (Instruction* interface_var_user : interface_var_users) {
332 if (!ReplaceComponentOfInterfaceVarWith(
333 interface_var, interface_var_user,
334 scalar_interface_vars.GetComponentVariable(),
335 interface_var_component_indices, extra_array_index,
336 loads_to_composites, loads_for_access_chain_to_composites)) {
337 return false;
338 }
339 }
340 return true;
341 }
342 return ReplaceMultipleComponentsOfInterfaceVarWith(
343 interface_var, interface_var_users, scalar_interface_vars.GetComponents(),
344 interface_var_component_indices, extra_array_index, loads_to_composites,
345 loads_for_access_chain_to_composites);
346 }
347
348 bool InterfaceVariableScalarReplacement::
ReplaceMultipleComponentsOfInterfaceVarWith(Instruction * interface_var,const std::vector<Instruction * > & interface_var_users,const std::vector<NestedCompositeComponents> & components,std::vector<uint32_t> & interface_var_component_indices,const uint32_t * extra_array_index,std::unordered_map<Instruction *,Instruction * > * loads_to_composites,std::unordered_map<Instruction *,Instruction * > * loads_for_access_chain_to_composites)349 ReplaceMultipleComponentsOfInterfaceVarWith(
350 Instruction* interface_var,
351 const std::vector<Instruction*>& interface_var_users,
352 const std::vector<NestedCompositeComponents>& components,
353 std::vector<uint32_t>& interface_var_component_indices,
354 const uint32_t* extra_array_index,
355 std::unordered_map<Instruction*, Instruction*>* loads_to_composites,
356 std::unordered_map<Instruction*, Instruction*>*
357 loads_for_access_chain_to_composites) {
358 for (uint32_t i = 0; i < components.size(); ++i) {
359 interface_var_component_indices.push_back(i);
360 std::unordered_map<Instruction*, Instruction*> loads_to_component_values;
361 std::unordered_map<Instruction*, Instruction*>
362 loads_for_access_chain_to_component_values;
363 if (!ReplaceComponentsOfInterfaceVarWith(
364 interface_var, interface_var_users, components[i],
365 interface_var_component_indices, extra_array_index,
366 &loads_to_component_values,
367 &loads_for_access_chain_to_component_values)) {
368 return false;
369 }
370 interface_var_component_indices.pop_back();
371
372 uint32_t depth_to_component =
373 static_cast<uint32_t>(interface_var_component_indices.size());
374 AddComponentsToCompositesForLoads(
375 loads_for_access_chain_to_component_values,
376 loads_for_access_chain_to_composites, depth_to_component);
377 if (extra_array_index) ++depth_to_component;
378 AddComponentsToCompositesForLoads(loads_to_component_values,
379 loads_to_composites, depth_to_component);
380 }
381 return true;
382 }
383
ReplaceComponentOfInterfaceVarWith(Instruction * interface_var,Instruction * interface_var_user,Instruction * scalar_var,const std::vector<uint32_t> & interface_var_component_indices,const uint32_t * extra_array_index,std::unordered_map<Instruction *,Instruction * > * loads_to_component_values,std::unordered_map<Instruction *,Instruction * > * loads_for_access_chain_to_component_values)384 bool InterfaceVariableScalarReplacement::ReplaceComponentOfInterfaceVarWith(
385 Instruction* interface_var, Instruction* interface_var_user,
386 Instruction* scalar_var,
387 const std::vector<uint32_t>& interface_var_component_indices,
388 const uint32_t* extra_array_index,
389 std::unordered_map<Instruction*, Instruction*>* loads_to_component_values,
390 std::unordered_map<Instruction*, Instruction*>*
391 loads_for_access_chain_to_component_values) {
392 SpvOp opcode = interface_var_user->opcode();
393 if (opcode == SpvOpStore) {
394 uint32_t value_id = interface_var_user->GetSingleWordInOperand(1);
395 StoreComponentOfValueToScalarVar(value_id, interface_var_component_indices,
396 scalar_var, extra_array_index,
397 interface_var_user);
398 return true;
399 }
400 if (opcode == SpvOpLoad) {
401 Instruction* scalar_load =
402 LoadScalarVar(scalar_var, extra_array_index, interface_var_user);
403 loads_to_component_values->insert({interface_var_user, scalar_load});
404 return true;
405 }
406
407 // Copy OpName and annotation instructions only once. Therefore, we create
408 // them only for the first element of the extra array.
409 if (extra_array_index && *extra_array_index != 0) return true;
410
411 if (opcode == SpvOpDecorateId || opcode == SpvOpDecorateString ||
412 opcode == SpvOpDecorate) {
413 CloneAnnotationForVariable(interface_var_user, scalar_var->result_id());
414 return true;
415 }
416
417 if (opcode == SpvOpName) {
418 std::unique_ptr<Instruction> new_inst(interface_var_user->Clone(context()));
419 new_inst->SetInOperand(0, {scalar_var->result_id()});
420 context()->AddDebug2Inst(std::move(new_inst));
421 return true;
422 }
423
424 if (opcode == SpvOpEntryPoint) {
425 return ReplaceInterfaceVarInEntryPoint(interface_var, interface_var_user,
426 scalar_var->result_id());
427 }
428
429 if (opcode == SpvOpAccessChain) {
430 ReplaceAccessChainWith(interface_var_user, interface_var_component_indices,
431 scalar_var,
432 loads_for_access_chain_to_component_values);
433 return true;
434 }
435
436 std::string message("Unhandled instruction");
437 message += "\n " + interface_var_user->PrettyPrint(
438 SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
439 message +=
440 "\nfor interface variable scalar replacement\n " +
441 interface_var->PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
442 context()->consumer()(SPV_MSG_ERROR, "", {0, 0, 0}, message.c_str());
443 return false;
444 }
445
UseBaseAccessChainForAccessChain(Instruction * access_chain,Instruction * base_access_chain)446 void InterfaceVariableScalarReplacement::UseBaseAccessChainForAccessChain(
447 Instruction* access_chain, Instruction* base_access_chain) {
448 assert(base_access_chain->opcode() == SpvOpAccessChain &&
449 access_chain->opcode() == SpvOpAccessChain &&
450 access_chain->GetSingleWordInOperand(0) ==
451 base_access_chain->result_id());
452 Instruction::OperandList new_operands;
453 for (uint32_t i = 0; i < base_access_chain->NumInOperands(); ++i) {
454 new_operands.emplace_back(base_access_chain->GetInOperand(i));
455 }
456 for (uint32_t i = 1; i < access_chain->NumInOperands(); ++i) {
457 new_operands.emplace_back(access_chain->GetInOperand(i));
458 }
459 access_chain->SetInOperands(std::move(new_operands));
460 }
461
CreateAccessChainToVar(uint32_t var_type_id,Instruction * var,const std::vector<uint32_t> & index_ids,Instruction * insert_before,uint32_t * component_type_id)462 Instruction* InterfaceVariableScalarReplacement::CreateAccessChainToVar(
463 uint32_t var_type_id, Instruction* var,
464 const std::vector<uint32_t>& index_ids, Instruction* insert_before,
465 uint32_t* component_type_id) {
466 analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
467 *component_type_id = GetComponentTypeOfArrayMatrix(
468 def_use_mgr, var_type_id, static_cast<uint32_t>(index_ids.size()));
469
470 uint32_t ptr_type_id =
471 GetPointerType(*component_type_id, GetStorageClass(var));
472
473 std::unique_ptr<Instruction> new_access_chain(
474 new Instruction(context(), SpvOpAccessChain, ptr_type_id, TakeNextId(),
475 std::initializer_list<Operand>{
476 {SPV_OPERAND_TYPE_ID, {var->result_id()}}}));
477 for (uint32_t index_id : index_ids) {
478 new_access_chain->AddOperand({SPV_OPERAND_TYPE_ID, {index_id}});
479 }
480
481 Instruction* inst = new_access_chain.get();
482 def_use_mgr->AnalyzeInstDefUse(inst);
483 insert_before->InsertBefore(std::move(new_access_chain));
484 return inst;
485 }
486
CreateAccessChainWithIndex(uint32_t component_type_id,Instruction * var,uint32_t index,Instruction * insert_before)487 Instruction* InterfaceVariableScalarReplacement::CreateAccessChainWithIndex(
488 uint32_t component_type_id, Instruction* var, uint32_t index,
489 Instruction* insert_before) {
490 uint32_t ptr_type_id =
491 GetPointerType(component_type_id, GetStorageClass(var));
492 uint32_t index_id = context()->get_constant_mgr()->GetUIntConst(index);
493 std::unique_ptr<Instruction> new_access_chain(
494 new Instruction(context(), SpvOpAccessChain, ptr_type_id, TakeNextId(),
495 std::initializer_list<Operand>{
496 {SPV_OPERAND_TYPE_ID, {var->result_id()}},
497 {SPV_OPERAND_TYPE_ID, {index_id}},
498 }));
499 Instruction* inst = new_access_chain.get();
500 context()->get_def_use_mgr()->AnalyzeInstDefUse(inst);
501 insert_before->InsertBefore(std::move(new_access_chain));
502 return inst;
503 }
504
ReplaceAccessChainWith(Instruction * access_chain,const std::vector<uint32_t> & interface_var_component_indices,Instruction * scalar_var,std::unordered_map<Instruction *,Instruction * > * loads_to_component_values)505 void InterfaceVariableScalarReplacement::ReplaceAccessChainWith(
506 Instruction* access_chain,
507 const std::vector<uint32_t>& interface_var_component_indices,
508 Instruction* scalar_var,
509 std::unordered_map<Instruction*, Instruction*>* loads_to_component_values) {
510 std::vector<uint32_t> indexes;
511 for (uint32_t i = 1; i < access_chain->NumInOperands(); ++i) {
512 indexes.push_back(access_chain->GetSingleWordInOperand(i));
513 }
514
515 // Note that we have a strong assumption that |access_chain| has only a single
516 // index that is for the extra arrayness.
517 context()->get_def_use_mgr()->ForEachUser(
518 access_chain,
519 [this, access_chain, &indexes, &interface_var_component_indices,
520 scalar_var, loads_to_component_values](Instruction* user) {
521 switch (user->opcode()) {
522 case SpvOpAccessChain: {
523 UseBaseAccessChainForAccessChain(user, access_chain);
524 ReplaceAccessChainWith(user, interface_var_component_indices,
525 scalar_var, loads_to_component_values);
526 return;
527 }
528 case SpvOpStore: {
529 uint32_t value_id = user->GetSingleWordInOperand(1);
530 StoreComponentOfValueToAccessChainToScalarVar(
531 value_id, interface_var_component_indices, scalar_var, indexes,
532 user);
533 return;
534 }
535 case SpvOpLoad: {
536 Instruction* value =
537 LoadAccessChainToVar(scalar_var, indexes, user);
538 loads_to_component_values->insert({user, value});
539 return;
540 }
541 default:
542 break;
543 }
544 });
545 }
546
CloneAnnotationForVariable(Instruction * annotation_inst,uint32_t var_id)547 void InterfaceVariableScalarReplacement::CloneAnnotationForVariable(
548 Instruction* annotation_inst, uint32_t var_id) {
549 assert(annotation_inst->opcode() == SpvOpDecorate ||
550 annotation_inst->opcode() == SpvOpDecorateId ||
551 annotation_inst->opcode() == SpvOpDecorateString);
552 std::unique_ptr<Instruction> new_inst(annotation_inst->Clone(context()));
553 new_inst->SetInOperand(0, {var_id});
554 context()->AddAnnotationInst(std::move(new_inst));
555 }
556
ReplaceInterfaceVarInEntryPoint(Instruction * interface_var,Instruction * entry_point,uint32_t scalar_var_id)557 bool InterfaceVariableScalarReplacement::ReplaceInterfaceVarInEntryPoint(
558 Instruction* interface_var, Instruction* entry_point,
559 uint32_t scalar_var_id) {
560 analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
561 uint32_t interface_var_id = interface_var->result_id();
562 if (interface_vars_removed_from_entry_point_operands_.find(
563 interface_var_id) !=
564 interface_vars_removed_from_entry_point_operands_.end()) {
565 entry_point->AddOperand({SPV_OPERAND_TYPE_ID, {scalar_var_id}});
566 def_use_mgr->AnalyzeInstUse(entry_point);
567 return true;
568 }
569
570 bool success = !entry_point->WhileEachInId(
571 [&interface_var_id, &scalar_var_id](uint32_t* id) {
572 if (*id == interface_var_id) {
573 *id = scalar_var_id;
574 return false;
575 }
576 return true;
577 });
578 if (!success) {
579 std::string message(
580 "interface variable is not an operand of the entry point");
581 message += "\n " + interface_var->PrettyPrint(
582 SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
583 message += "\n " + entry_point->PrettyPrint(
584 SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
585 context()->consumer()(SPV_MSG_ERROR, "", {0, 0, 0}, message.c_str());
586 return false;
587 }
588
589 def_use_mgr->AnalyzeInstUse(entry_point);
590 interface_vars_removed_from_entry_point_operands_.insert(interface_var_id);
591 return true;
592 }
593
GetPointeeTypeIdOfVar(Instruction * var)594 uint32_t InterfaceVariableScalarReplacement::GetPointeeTypeIdOfVar(
595 Instruction* var) {
596 assert(var->opcode() == SpvOpVariable);
597
598 uint32_t ptr_type_id = var->type_id();
599 analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
600 Instruction* ptr_type_inst = def_use_mgr->GetDef(ptr_type_id);
601
602 assert(ptr_type_inst->opcode() == SpvOpTypePointer &&
603 "Variable must have a pointer type.");
604 return ptr_type_inst->GetSingleWordInOperand(kOpTypePtrTypeInOperandIndex);
605 }
606
StoreComponentOfValueToScalarVar(uint32_t value_id,const std::vector<uint32_t> & component_indices,Instruction * scalar_var,const uint32_t * extra_array_index,Instruction * insert_before)607 void InterfaceVariableScalarReplacement::StoreComponentOfValueToScalarVar(
608 uint32_t value_id, const std::vector<uint32_t>& component_indices,
609 Instruction* scalar_var, const uint32_t* extra_array_index,
610 Instruction* insert_before) {
611 uint32_t component_type_id = GetPointeeTypeIdOfVar(scalar_var);
612 Instruction* ptr = scalar_var;
613 if (extra_array_index) {
614 auto* ty_mgr = context()->get_type_mgr();
615 analysis::Array* array_type = ty_mgr->GetType(component_type_id)->AsArray();
616 assert(array_type != nullptr);
617 component_type_id = ty_mgr->GetTypeInstruction(array_type->element_type());
618 ptr = CreateAccessChainWithIndex(component_type_id, scalar_var,
619 *extra_array_index, insert_before);
620 }
621
622 StoreComponentOfValueTo(component_type_id, value_id, component_indices, ptr,
623 extra_array_index, insert_before);
624 }
625
LoadScalarVar(Instruction * scalar_var,const uint32_t * extra_array_index,Instruction * insert_before)626 Instruction* InterfaceVariableScalarReplacement::LoadScalarVar(
627 Instruction* scalar_var, const uint32_t* extra_array_index,
628 Instruction* insert_before) {
629 uint32_t component_type_id = GetPointeeTypeIdOfVar(scalar_var);
630 Instruction* ptr = scalar_var;
631 if (extra_array_index) {
632 auto* ty_mgr = context()->get_type_mgr();
633 analysis::Array* array_type = ty_mgr->GetType(component_type_id)->AsArray();
634 assert(array_type != nullptr);
635 component_type_id = ty_mgr->GetTypeInstruction(array_type->element_type());
636 ptr = CreateAccessChainWithIndex(component_type_id, scalar_var,
637 *extra_array_index, insert_before);
638 }
639
640 return CreateLoad(component_type_id, ptr, insert_before);
641 }
642
CreateLoad(uint32_t type_id,Instruction * ptr,Instruction * insert_before)643 Instruction* InterfaceVariableScalarReplacement::CreateLoad(
644 uint32_t type_id, Instruction* ptr, Instruction* insert_before) {
645 std::unique_ptr<Instruction> load(
646 new Instruction(context(), SpvOpLoad, type_id, TakeNextId(),
647 std::initializer_list<Operand>{
648 {SPV_OPERAND_TYPE_ID, {ptr->result_id()}}}));
649 Instruction* load_inst = load.get();
650 context()->get_def_use_mgr()->AnalyzeInstDefUse(load_inst);
651 insert_before->InsertBefore(std::move(load));
652 return load_inst;
653 }
654
StoreComponentOfValueTo(uint32_t component_type_id,uint32_t value_id,const std::vector<uint32_t> & component_indices,Instruction * ptr,const uint32_t * extra_array_index,Instruction * insert_before)655 void InterfaceVariableScalarReplacement::StoreComponentOfValueTo(
656 uint32_t component_type_id, uint32_t value_id,
657 const std::vector<uint32_t>& component_indices, Instruction* ptr,
658 const uint32_t* extra_array_index, Instruction* insert_before) {
659 std::unique_ptr<Instruction> composite_extract(CreateCompositeExtract(
660 component_type_id, value_id, component_indices, extra_array_index));
661
662 std::unique_ptr<Instruction> new_store(
663 new Instruction(context(), SpvOpStore));
664 new_store->AddOperand({SPV_OPERAND_TYPE_ID, {ptr->result_id()}});
665 new_store->AddOperand(
666 {SPV_OPERAND_TYPE_ID, {composite_extract->result_id()}});
667
668 analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
669 def_use_mgr->AnalyzeInstDefUse(composite_extract.get());
670 def_use_mgr->AnalyzeInstDefUse(new_store.get());
671
672 insert_before->InsertBefore(std::move(composite_extract));
673 insert_before->InsertBefore(std::move(new_store));
674 }
675
CreateCompositeExtract(uint32_t type_id,uint32_t composite_id,const std::vector<uint32_t> & indexes,const uint32_t * extra_first_index)676 Instruction* InterfaceVariableScalarReplacement::CreateCompositeExtract(
677 uint32_t type_id, uint32_t composite_id,
678 const std::vector<uint32_t>& indexes, const uint32_t* extra_first_index) {
679 uint32_t component_id = TakeNextId();
680 Instruction* composite_extract = new Instruction(
681 context(), SpvOpCompositeExtract, type_id, component_id,
682 std::initializer_list<Operand>{{SPV_OPERAND_TYPE_ID, {composite_id}}});
683 if (extra_first_index) {
684 composite_extract->AddOperand(
685 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {*extra_first_index}});
686 }
687 for (uint32_t index : indexes) {
688 composite_extract->AddOperand({SPV_OPERAND_TYPE_LITERAL_INTEGER, {index}});
689 }
690 return composite_extract;
691 }
692
693 void InterfaceVariableScalarReplacement::
StoreComponentOfValueToAccessChainToScalarVar(uint32_t value_id,const std::vector<uint32_t> & component_indices,Instruction * scalar_var,const std::vector<uint32_t> & access_chain_indices,Instruction * insert_before)694 StoreComponentOfValueToAccessChainToScalarVar(
695 uint32_t value_id, const std::vector<uint32_t>& component_indices,
696 Instruction* scalar_var,
697 const std::vector<uint32_t>& access_chain_indices,
698 Instruction* insert_before) {
699 uint32_t component_type_id = GetPointeeTypeIdOfVar(scalar_var);
700 Instruction* ptr = scalar_var;
701 if (!access_chain_indices.empty()) {
702 ptr = CreateAccessChainToVar(component_type_id, scalar_var,
703 access_chain_indices, insert_before,
704 &component_type_id);
705 }
706
707 StoreComponentOfValueTo(component_type_id, value_id, component_indices, ptr,
708 nullptr, insert_before);
709 }
710
LoadAccessChainToVar(Instruction * var,const std::vector<uint32_t> & indexes,Instruction * insert_before)711 Instruction* InterfaceVariableScalarReplacement::LoadAccessChainToVar(
712 Instruction* var, const std::vector<uint32_t>& indexes,
713 Instruction* insert_before) {
714 uint32_t component_type_id = GetPointeeTypeIdOfVar(var);
715 Instruction* ptr = var;
716 if (!indexes.empty()) {
717 ptr = CreateAccessChainToVar(component_type_id, var, indexes, insert_before,
718 &component_type_id);
719 }
720
721 return CreateLoad(component_type_id, ptr, insert_before);
722 }
723
724 Instruction*
CreateCompositeConstructForComponentOfLoad(Instruction * load,uint32_t depth_to_component)725 InterfaceVariableScalarReplacement::CreateCompositeConstructForComponentOfLoad(
726 Instruction* load, uint32_t depth_to_component) {
727 analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
728 uint32_t type_id = load->type_id();
729 if (depth_to_component != 0) {
730 type_id = GetComponentTypeOfArrayMatrix(def_use_mgr, load->type_id(),
731 depth_to_component);
732 }
733 uint32_t new_id = context()->TakeNextId();
734 std::unique_ptr<Instruction> new_composite_construct(
735 new Instruction(context(), SpvOpCompositeConstruct, type_id, new_id, {}));
736 Instruction* composite_construct = new_composite_construct.get();
737 def_use_mgr->AnalyzeInstDefUse(composite_construct);
738
739 // Insert |new_composite_construct| after |load|. When there are multiple
740 // recursive composite construct instructions for a load, we have to place the
741 // composite construct with a lower depth later because it constructs the
742 // composite that contains other composites with lower depths.
743 auto* insert_before = load->NextNode();
744 while (true) {
745 auto itr =
746 composite_ids_to_component_depths.find(insert_before->result_id());
747 if (itr == composite_ids_to_component_depths.end()) break;
748 if (itr->second <= depth_to_component) break;
749 insert_before = insert_before->NextNode();
750 }
751 insert_before->InsertBefore(std::move(new_composite_construct));
752 composite_ids_to_component_depths.insert({new_id, depth_to_component});
753 return composite_construct;
754 }
755
AddComponentsToCompositesForLoads(const std::unordered_map<Instruction *,Instruction * > & loads_to_component_values,std::unordered_map<Instruction *,Instruction * > * loads_to_composites,uint32_t depth_to_component)756 void InterfaceVariableScalarReplacement::AddComponentsToCompositesForLoads(
757 const std::unordered_map<Instruction*, Instruction*>&
758 loads_to_component_values,
759 std::unordered_map<Instruction*, Instruction*>* loads_to_composites,
760 uint32_t depth_to_component) {
761 analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
762 for (auto& load_and_component_vale : loads_to_component_values) {
763 Instruction* load = load_and_component_vale.first;
764 Instruction* component_value = load_and_component_vale.second;
765 Instruction* composite_construct = nullptr;
766 auto itr = loads_to_composites->find(load);
767 if (itr == loads_to_composites->end()) {
768 composite_construct =
769 CreateCompositeConstructForComponentOfLoad(load, depth_to_component);
770 loads_to_composites->insert({load, composite_construct});
771 } else {
772 composite_construct = itr->second;
773 }
774 composite_construct->AddOperand(
775 {SPV_OPERAND_TYPE_ID, {component_value->result_id()}});
776 def_use_mgr->AnalyzeInstDefUse(composite_construct);
777 }
778 }
779
GetArrayType(uint32_t elem_type_id,uint32_t array_length)780 uint32_t InterfaceVariableScalarReplacement::GetArrayType(
781 uint32_t elem_type_id, uint32_t array_length) {
782 analysis::Type* elem_type = context()->get_type_mgr()->GetType(elem_type_id);
783 uint32_t array_length_id =
784 context()->get_constant_mgr()->GetUIntConst(array_length);
785 analysis::Array array_type(
786 elem_type,
787 analysis::Array::LengthInfo{array_length_id, {0, array_length}});
788 return context()->get_type_mgr()->GetTypeInstruction(&array_type);
789 }
790
GetPointerType(uint32_t type_id,SpvStorageClass storage_class)791 uint32_t InterfaceVariableScalarReplacement::GetPointerType(
792 uint32_t type_id, SpvStorageClass storage_class) {
793 analysis::Type* type = context()->get_type_mgr()->GetType(type_id);
794 analysis::Pointer ptr_type(type, storage_class);
795 return context()->get_type_mgr()->GetTypeInstruction(&ptr_type);
796 }
797
798 InterfaceVariableScalarReplacement::NestedCompositeComponents
CreateScalarInterfaceVarsForArray(Instruction * interface_var_type,SpvStorageClass storage_class,uint32_t extra_array_length)799 InterfaceVariableScalarReplacement::CreateScalarInterfaceVarsForArray(
800 Instruction* interface_var_type, SpvStorageClass storage_class,
801 uint32_t extra_array_length) {
802 assert(interface_var_type->opcode() == SpvOpTypeArray);
803
804 analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
805 uint32_t array_length = GetArrayLength(def_use_mgr, interface_var_type);
806 Instruction* elem_type = GetArrayElementType(def_use_mgr, interface_var_type);
807
808 NestedCompositeComponents scalar_vars;
809 while (array_length > 0) {
810 NestedCompositeComponents scalar_vars_for_element =
811 CreateScalarInterfaceVarsForReplacement(elem_type, storage_class,
812 extra_array_length);
813 scalar_vars.AddComponent(scalar_vars_for_element);
814 --array_length;
815 }
816 return scalar_vars;
817 }
818
819 InterfaceVariableScalarReplacement::NestedCompositeComponents
CreateScalarInterfaceVarsForMatrix(Instruction * interface_var_type,SpvStorageClass storage_class,uint32_t extra_array_length)820 InterfaceVariableScalarReplacement::CreateScalarInterfaceVarsForMatrix(
821 Instruction* interface_var_type, SpvStorageClass storage_class,
822 uint32_t extra_array_length) {
823 assert(interface_var_type->opcode() == SpvOpTypeMatrix);
824
825 analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
826 uint32_t column_count = interface_var_type->GetSingleWordInOperand(
827 kOpTypeMatrixColCountInOperandIndex);
828 Instruction* column_type =
829 GetMatrixColumnType(def_use_mgr, interface_var_type);
830
831 NestedCompositeComponents scalar_vars;
832 while (column_count > 0) {
833 NestedCompositeComponents scalar_vars_for_column =
834 CreateScalarInterfaceVarsForReplacement(column_type, storage_class,
835 extra_array_length);
836 scalar_vars.AddComponent(scalar_vars_for_column);
837 --column_count;
838 }
839 return scalar_vars;
840 }
841
842 InterfaceVariableScalarReplacement::NestedCompositeComponents
CreateScalarInterfaceVarsForReplacement(Instruction * interface_var_type,SpvStorageClass storage_class,uint32_t extra_array_length)843 InterfaceVariableScalarReplacement::CreateScalarInterfaceVarsForReplacement(
844 Instruction* interface_var_type, SpvStorageClass storage_class,
845 uint32_t extra_array_length) {
846 // Handle array case.
847 if (interface_var_type->opcode() == SpvOpTypeArray) {
848 return CreateScalarInterfaceVarsForArray(interface_var_type, storage_class,
849 extra_array_length);
850 }
851
852 // Handle matrix case.
853 if (interface_var_type->opcode() == SpvOpTypeMatrix) {
854 return CreateScalarInterfaceVarsForMatrix(interface_var_type, storage_class,
855 extra_array_length);
856 }
857
858 // Handle scalar or vector case.
859 NestedCompositeComponents scalar_var;
860 uint32_t type_id = interface_var_type->result_id();
861 if (extra_array_length != 0) {
862 type_id = GetArrayType(type_id, extra_array_length);
863 }
864 uint32_t ptr_type_id =
865 context()->get_type_mgr()->FindPointerToType(type_id, storage_class);
866 uint32_t id = TakeNextId();
867 std::unique_ptr<Instruction> variable(
868 new Instruction(context(), SpvOpVariable, ptr_type_id, id,
869 std::initializer_list<Operand>{
870 {SPV_OPERAND_TYPE_STORAGE_CLASS,
871 {static_cast<uint32_t>(storage_class)}}}));
872 scalar_var.SetSingleComponentVariable(variable.get());
873 context()->AddGlobalValue(std::move(variable));
874 return scalar_var;
875 }
876
GetTypeOfVariable(Instruction * var)877 Instruction* InterfaceVariableScalarReplacement::GetTypeOfVariable(
878 Instruction* var) {
879 uint32_t pointee_type_id = GetPointeeTypeIdOfVar(var);
880 analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
881 return def_use_mgr->GetDef(pointee_type_id);
882 }
883
Process()884 Pass::Status InterfaceVariableScalarReplacement::Process() {
885 Pass::Status status = Status::SuccessWithoutChange;
886 for (Instruction& entry_point : get_module()->entry_points()) {
887 status =
888 CombineStatus(status, ReplaceInterfaceVarsWithScalars(entry_point));
889 }
890 return status;
891 }
892
893 bool InterfaceVariableScalarReplacement::
ReportErrorIfHasExtraArraynessForOtherEntry(Instruction * var)894 ReportErrorIfHasExtraArraynessForOtherEntry(Instruction* var) {
895 if (vars_with_extra_arrayness.find(var) == vars_with_extra_arrayness.end())
896 return false;
897
898 std::string message(
899 "A variable is arrayed for an entry point but it is not "
900 "arrayed for another entry point");
901 message +=
902 "\n " + var->PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
903 context()->consumer()(SPV_MSG_ERROR, "", {0, 0, 0}, message.c_str());
904 return true;
905 }
906
907 bool InterfaceVariableScalarReplacement::
ReportErrorIfHasNoExtraArraynessForOtherEntry(Instruction * var)908 ReportErrorIfHasNoExtraArraynessForOtherEntry(Instruction* var) {
909 if (vars_without_extra_arrayness.find(var) ==
910 vars_without_extra_arrayness.end())
911 return false;
912
913 std::string message(
914 "A variable is not arrayed for an entry point but it is "
915 "arrayed for another entry point");
916 message +=
917 "\n " + var->PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
918 context()->consumer()(SPV_MSG_ERROR, "", {0, 0, 0}, message.c_str());
919 return true;
920 }
921
922 Pass::Status
ReplaceInterfaceVarsWithScalars(Instruction & entry_point)923 InterfaceVariableScalarReplacement::ReplaceInterfaceVarsWithScalars(
924 Instruction& entry_point) {
925 std::vector<Instruction*> interface_vars =
926 CollectInterfaceVariables(entry_point);
927
928 Pass::Status status = Status::SuccessWithoutChange;
929 for (Instruction* interface_var : interface_vars) {
930 uint32_t location, component;
931 if (!GetVariableLocation(interface_var, &location)) continue;
932 if (!GetVariableComponent(interface_var, &component)) component = 0;
933
934 Instruction* interface_var_type = GetTypeOfVariable(interface_var);
935 uint32_t extra_array_length = 0;
936 if (HasExtraArrayness(entry_point, interface_var)) {
937 extra_array_length =
938 GetArrayLength(context()->get_def_use_mgr(), interface_var_type);
939 interface_var_type =
940 GetArrayElementType(context()->get_def_use_mgr(), interface_var_type);
941 vars_with_extra_arrayness.insert(interface_var);
942 } else {
943 vars_without_extra_arrayness.insert(interface_var);
944 }
945
946 if (!CheckExtraArraynessConflictBetweenEntries(interface_var,
947 extra_array_length != 0)) {
948 return Pass::Status::Failure;
949 }
950
951 if (interface_var_type->opcode() != SpvOpTypeArray &&
952 interface_var_type->opcode() != SpvOpTypeMatrix) {
953 continue;
954 }
955
956 if (!ReplaceInterfaceVariableWithScalars(interface_var, interface_var_type,
957 location, component,
958 extra_array_length)) {
959 return Pass::Status::Failure;
960 }
961 status = Pass::Status::SuccessWithChange;
962 }
963
964 return status;
965 }
966
967 } // namespace opt
968 } // namespace spvtools
969