1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/opt/scalar_replacement_pass.h"
16
17 #include <algorithm>
18 #include <queue>
19 #include <tuple>
20 #include <utility>
21
22 #include "source/extensions.h"
23 #include "source/opt/reflect.h"
24 #include "source/opt/types.h"
25 #include "source/util/make_unique.h"
26
27 namespace spvtools {
28 namespace opt {
namespace {
// Operand positions inside the debug-info extended instructions that this
// file reads and rewrites.

// Position at which a DebugDeclare refers to its variable (compared against
// the use index in CheckDebugDeclare).
constexpr uint32_t kDebugDeclareOperandVariableIndex = 5;

// Position of the 'Value' operand of a DebugValue.
constexpr uint32_t kDebugValueOperandValueIndex = 5;

// Position of the expression operand. This file uses it both to read the
// expression of a DebugDeclare and to set the expression of a DebugValue.
constexpr uint32_t kDebugValueOperandExpressionIndex = 6;
}  // namespace
34
Process()35 Pass::Status ScalarReplacementPass::Process() {
36 Status status = Status::SuccessWithoutChange;
37 for (auto& f : *get_module()) {
38 if (f.IsDeclaration()) {
39 continue;
40 }
41
42 Status functionStatus = ProcessFunction(&f);
43 if (functionStatus == Status::Failure)
44 return functionStatus;
45 else if (functionStatus == Status::SuccessWithChange)
46 status = functionStatus;
47 }
48
49 return status;
50 }
51
ProcessFunction(Function * function)52 Pass::Status ScalarReplacementPass::ProcessFunction(Function* function) {
53 std::queue<Instruction*> worklist;
54 BasicBlock& entry = *function->begin();
55 for (auto iter = entry.begin(); iter != entry.end(); ++iter) {
56 // Function storage class OpVariables must appear as the first instructions
57 // of the entry block.
58 if (iter->opcode() != spv::Op::OpVariable) break;
59
60 Instruction* varInst = &*iter;
61 if (CanReplaceVariable(varInst)) {
62 worklist.push(varInst);
63 }
64 }
65
66 Status status = Status::SuccessWithoutChange;
67 while (!worklist.empty()) {
68 Instruction* varInst = worklist.front();
69 worklist.pop();
70
71 Status var_status = ReplaceVariable(varInst, &worklist);
72 if (var_status == Status::Failure)
73 return var_status;
74 else if (var_status == Status::SuccessWithChange)
75 status = var_status;
76 }
77
78 return status;
79 }
80
ReplaceVariable(Instruction * inst,std::queue<Instruction * > * worklist)81 Pass::Status ScalarReplacementPass::ReplaceVariable(
82 Instruction* inst, std::queue<Instruction*>* worklist) {
83 std::vector<Instruction*> replacements;
84 if (!CreateReplacementVariables(inst, &replacements)) {
85 return Status::Failure;
86 }
87
88 std::vector<Instruction*> dead;
89 bool replaced_all_uses = get_def_use_mgr()->WhileEachUser(
90 inst, [this, &replacements, &dead](Instruction* user) {
91 if (user->GetCommonDebugOpcode() == CommonDebugInfoDebugDeclare) {
92 if (ReplaceWholeDebugDeclare(user, replacements)) {
93 dead.push_back(user);
94 return true;
95 }
96 return false;
97 }
98 if (user->GetCommonDebugOpcode() == CommonDebugInfoDebugValue) {
99 if (ReplaceWholeDebugValue(user, replacements)) {
100 dead.push_back(user);
101 return true;
102 }
103 return false;
104 }
105 if (!IsAnnotationInst(user->opcode())) {
106 switch (user->opcode()) {
107 case spv::Op::OpLoad:
108 if (ReplaceWholeLoad(user, replacements)) {
109 dead.push_back(user);
110 } else {
111 return false;
112 }
113 break;
114 case spv::Op::OpStore:
115 if (ReplaceWholeStore(user, replacements)) {
116 dead.push_back(user);
117 } else {
118 return false;
119 }
120 break;
121 case spv::Op::OpAccessChain:
122 case spv::Op::OpInBoundsAccessChain:
123 if (ReplaceAccessChain(user, replacements))
124 dead.push_back(user);
125 else
126 return false;
127 break;
128 case spv::Op::OpName:
129 case spv::Op::OpMemberName:
130 break;
131 default:
132 assert(false && "Unexpected opcode");
133 break;
134 }
135 }
136 return true;
137 });
138
139 if (replaced_all_uses) {
140 dead.push_back(inst);
141 } else {
142 return Status::Failure;
143 }
144
145 // If there are no dead instructions to clean up, return with no changes.
146 if (dead.empty()) return Status::SuccessWithoutChange;
147
148 // Clean up some dead code.
149 while (!dead.empty()) {
150 Instruction* toKill = dead.back();
151 dead.pop_back();
152 context()->KillInst(toKill);
153 }
154
155 // Attempt to further scalarize.
156 for (auto var : replacements) {
157 if (var->opcode() == spv::Op::OpVariable) {
158 if (get_def_use_mgr()->NumUsers(var) == 0) {
159 context()->KillInst(var);
160 } else if (CanReplaceVariable(var)) {
161 worklist->push(var);
162 }
163 }
164 }
165
166 return Status::SuccessWithChange;
167 }
168
ReplaceWholeDebugDeclare(Instruction * dbg_decl,const std::vector<Instruction * > & replacements)169 bool ScalarReplacementPass::ReplaceWholeDebugDeclare(
170 Instruction* dbg_decl, const std::vector<Instruction*>& replacements) {
171 // Insert Deref operation to the front of the operation list of |dbg_decl|.
172 Instruction* dbg_expr = context()->get_def_use_mgr()->GetDef(
173 dbg_decl->GetSingleWordOperand(kDebugValueOperandExpressionIndex));
174 auto* deref_expr =
175 context()->get_debug_info_mgr()->DerefDebugExpression(dbg_expr);
176
177 // Add DebugValue instruction with Indexes operand and Deref operation.
178 int32_t idx = 0;
179 for (const auto* var : replacements) {
180 Instruction* insert_before = var->NextNode();
181 while (insert_before->opcode() == spv::Op::OpVariable)
182 insert_before = insert_before->NextNode();
183 assert(insert_before != nullptr && "unexpected end of list");
184 Instruction* added_dbg_value =
185 context()->get_debug_info_mgr()->AddDebugValueForDecl(
186 dbg_decl, /*value_id=*/var->result_id(),
187 /*insert_before=*/insert_before, /*scope_and_line=*/dbg_decl);
188
189 if (added_dbg_value == nullptr) return false;
190 added_dbg_value->AddOperand(
191 {SPV_OPERAND_TYPE_ID,
192 {context()->get_constant_mgr()->GetSIntConstId(idx)}});
193 added_dbg_value->SetOperand(kDebugValueOperandExpressionIndex,
194 {deref_expr->result_id()});
195 if (context()->AreAnalysesValid(IRContext::Analysis::kAnalysisDefUse)) {
196 context()->get_def_use_mgr()->AnalyzeInstUse(added_dbg_value);
197 }
198 ++idx;
199 }
200 return true;
201 }
202
ReplaceWholeDebugValue(Instruction * dbg_value,const std::vector<Instruction * > & replacements)203 bool ScalarReplacementPass::ReplaceWholeDebugValue(
204 Instruction* dbg_value, const std::vector<Instruction*>& replacements) {
205 int32_t idx = 0;
206 BasicBlock* block = context()->get_instr_block(dbg_value);
207 for (auto var : replacements) {
208 // Clone the DebugValue.
209 std::unique_ptr<Instruction> new_dbg_value(dbg_value->Clone(context()));
210 uint32_t new_id = TakeNextId();
211 if (new_id == 0) return false;
212 new_dbg_value->SetResultId(new_id);
213 // Update 'Value' operand to the |replacements|.
214 new_dbg_value->SetOperand(kDebugValueOperandValueIndex, {var->result_id()});
215 // Append 'Indexes' operand.
216 new_dbg_value->AddOperand(
217 {SPV_OPERAND_TYPE_ID,
218 {context()->get_constant_mgr()->GetSIntConstId(idx)}});
219 // Insert the new DebugValue to the basic block.
220 auto* added_instr = dbg_value->InsertBefore(std::move(new_dbg_value));
221 get_def_use_mgr()->AnalyzeInstDefUse(added_instr);
222 context()->set_instr_block(added_instr, block);
223 ++idx;
224 }
225 return true;
226 }
227
ReplaceWholeLoad(Instruction * load,const std::vector<Instruction * > & replacements)228 bool ScalarReplacementPass::ReplaceWholeLoad(
229 Instruction* load, const std::vector<Instruction*>& replacements) {
230 // Replaces the load of the entire composite with a load from each replacement
231 // variable followed by a composite construction.
232 BasicBlock* block = context()->get_instr_block(load);
233 std::vector<Instruction*> loads;
234 loads.reserve(replacements.size());
235 BasicBlock::iterator where(load);
236 for (auto var : replacements) {
237 // Create a load of each replacement variable.
238 if (var->opcode() != spv::Op::OpVariable) {
239 loads.push_back(var);
240 continue;
241 }
242
243 Instruction* type = GetStorageType(var);
244 uint32_t loadId = TakeNextId();
245 if (loadId == 0) {
246 return false;
247 }
248 std::unique_ptr<Instruction> newLoad(
249 new Instruction(context(), spv::Op::OpLoad, type->result_id(), loadId,
250 std::initializer_list<Operand>{
251 {SPV_OPERAND_TYPE_ID, {var->result_id()}}}));
252 // Copy memory access attributes which start at index 1. Index 0 is the
253 // pointer to load.
254 for (uint32_t i = 1; i < load->NumInOperands(); ++i) {
255 Operand copy(load->GetInOperand(i));
256 newLoad->AddOperand(std::move(copy));
257 }
258 where = where.InsertBefore(std::move(newLoad));
259 get_def_use_mgr()->AnalyzeInstDefUse(&*where);
260 context()->set_instr_block(&*where, block);
261 where->UpdateDebugInfoFrom(load);
262 loads.push_back(&*where);
263 }
264
265 // Construct a new composite.
266 uint32_t compositeId = TakeNextId();
267 if (compositeId == 0) {
268 return false;
269 }
270 where = load;
271 std::unique_ptr<Instruction> compositeConstruct(
272 new Instruction(context(), spv::Op::OpCompositeConstruct, load->type_id(),
273 compositeId, {}));
274 for (auto l : loads) {
275 Operand op(SPV_OPERAND_TYPE_ID,
276 std::initializer_list<uint32_t>{l->result_id()});
277 compositeConstruct->AddOperand(std::move(op));
278 }
279 where = where.InsertBefore(std::move(compositeConstruct));
280 get_def_use_mgr()->AnalyzeInstDefUse(&*where);
281 where->UpdateDebugInfoFrom(load);
282 context()->set_instr_block(&*where, block);
283 context()->ReplaceAllUsesWith(load->result_id(), compositeId);
284 return true;
285 }
286
ReplaceWholeStore(Instruction * store,const std::vector<Instruction * > & replacements)287 bool ScalarReplacementPass::ReplaceWholeStore(
288 Instruction* store, const std::vector<Instruction*>& replacements) {
289 // Replaces a store to the whole composite with a series of extract and stores
290 // to each element.
291 uint32_t storeInput = store->GetSingleWordInOperand(1u);
292 BasicBlock* block = context()->get_instr_block(store);
293 BasicBlock::iterator where(store);
294 uint32_t elementIndex = 0;
295 for (auto var : replacements) {
296 // Create the extract.
297 if (var->opcode() != spv::Op::OpVariable) {
298 elementIndex++;
299 continue;
300 }
301
302 Instruction* type = GetStorageType(var);
303 uint32_t extractId = TakeNextId();
304 if (extractId == 0) {
305 return false;
306 }
307 std::unique_ptr<Instruction> extract(new Instruction(
308 context(), spv::Op::OpCompositeExtract, type->result_id(), extractId,
309 std::initializer_list<Operand>{
310 {SPV_OPERAND_TYPE_ID, {storeInput}},
311 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {elementIndex++}}}));
312 auto iter = where.InsertBefore(std::move(extract));
313 iter->UpdateDebugInfoFrom(store);
314 get_def_use_mgr()->AnalyzeInstDefUse(&*iter);
315 context()->set_instr_block(&*iter, block);
316
317 // Create the store.
318 std::unique_ptr<Instruction> newStore(
319 new Instruction(context(), spv::Op::OpStore, 0, 0,
320 std::initializer_list<Operand>{
321 {SPV_OPERAND_TYPE_ID, {var->result_id()}},
322 {SPV_OPERAND_TYPE_ID, {extractId}}}));
323 // Copy memory access attributes which start at index 2. Index 0 is the
324 // pointer and index 1 is the data.
325 for (uint32_t i = 2; i < store->NumInOperands(); ++i) {
326 Operand copy(store->GetInOperand(i));
327 newStore->AddOperand(std::move(copy));
328 }
329 iter = where.InsertBefore(std::move(newStore));
330 iter->UpdateDebugInfoFrom(store);
331 get_def_use_mgr()->AnalyzeInstDefUse(&*iter);
332 context()->set_instr_block(&*iter, block);
333 }
334 return true;
335 }
336
ReplaceAccessChain(Instruction * chain,const std::vector<Instruction * > & replacements)337 bool ScalarReplacementPass::ReplaceAccessChain(
338 Instruction* chain, const std::vector<Instruction*>& replacements) {
339 // Replaces the access chain with either another access chain (with one fewer
340 // indexes) or a direct use of the replacement variable.
341 uint32_t indexId = chain->GetSingleWordInOperand(1u);
342 const Instruction* index = get_def_use_mgr()->GetDef(indexId);
343 int64_t indexValue = context()
344 ->get_constant_mgr()
345 ->GetConstantFromInst(index)
346 ->GetSignExtendedValue();
347 if (indexValue < 0 ||
348 indexValue >= static_cast<int64_t>(replacements.size())) {
349 // Out of bounds access, this is illegal IR. Notice that OpAccessChain
350 // indexing is 0-based, so we should also reject index == size-of-array.
351 return false;
352 } else {
353 const Instruction* var = replacements[static_cast<size_t>(indexValue)];
354 if (chain->NumInOperands() > 2) {
355 // Replace input access chain with another access chain.
356 BasicBlock::iterator chainIter(chain);
357 uint32_t replacementId = TakeNextId();
358 if (replacementId == 0) {
359 return false;
360 }
361 std::unique_ptr<Instruction> replacementChain(new Instruction(
362 context(), chain->opcode(), chain->type_id(), replacementId,
363 std::initializer_list<Operand>{
364 {SPV_OPERAND_TYPE_ID, {var->result_id()}}}));
365 // Add the remaining indexes.
366 for (uint32_t i = 2; i < chain->NumInOperands(); ++i) {
367 Operand copy(chain->GetInOperand(i));
368 replacementChain->AddOperand(std::move(copy));
369 }
370 replacementChain->UpdateDebugInfoFrom(chain);
371 auto iter = chainIter.InsertBefore(std::move(replacementChain));
372 get_def_use_mgr()->AnalyzeInstDefUse(&*iter);
373 context()->set_instr_block(&*iter, context()->get_instr_block(chain));
374 context()->ReplaceAllUsesWith(chain->result_id(), replacementId);
375 } else {
376 // Replace with a use of the variable.
377 context()->ReplaceAllUsesWith(chain->result_id(), var->result_id());
378 }
379 }
380
381 return true;
382 }
383
CreateReplacementVariables(Instruction * inst,std::vector<Instruction * > * replacements)384 bool ScalarReplacementPass::CreateReplacementVariables(
385 Instruction* inst, std::vector<Instruction*>* replacements) {
386 Instruction* type = GetStorageType(inst);
387
388 std::unique_ptr<std::unordered_set<int64_t>> components_used =
389 GetUsedComponents(inst);
390
391 uint32_t elem = 0;
392 switch (type->opcode()) {
393 case spv::Op::OpTypeStruct:
394 type->ForEachInOperand(
395 [this, inst, &elem, replacements, &components_used](uint32_t* id) {
396 if (!components_used || components_used->count(elem)) {
397 CreateVariable(*id, inst, elem, replacements);
398 } else {
399 replacements->push_back(GetUndef(*id));
400 }
401 elem++;
402 });
403 break;
404 case spv::Op::OpTypeArray:
405 for (uint32_t i = 0; i != GetArrayLength(type); ++i) {
406 if (!components_used || components_used->count(i)) {
407 CreateVariable(type->GetSingleWordInOperand(0u), inst, i,
408 replacements);
409 } else {
410 uint32_t element_type_id = type->GetSingleWordInOperand(0);
411 replacements->push_back(GetUndef(element_type_id));
412 }
413 }
414 break;
415
416 case spv::Op::OpTypeMatrix:
417 case spv::Op::OpTypeVector:
418 for (uint32_t i = 0; i != GetNumElements(type); ++i) {
419 CreateVariable(type->GetSingleWordInOperand(0u), inst, i, replacements);
420 }
421 break;
422
423 default:
424 assert(false && "Unexpected type.");
425 break;
426 }
427
428 TransferAnnotations(inst, replacements);
429 return std::find(replacements->begin(), replacements->end(), nullptr) ==
430 replacements->end();
431 }
432
GetUndef(uint32_t type_id)433 Instruction* ScalarReplacementPass::GetUndef(uint32_t type_id) {
434 return get_def_use_mgr()->GetDef(Type2Undef(type_id));
435 }
436
TransferAnnotations(const Instruction * source,std::vector<Instruction * > * replacements)437 void ScalarReplacementPass::TransferAnnotations(
438 const Instruction* source, std::vector<Instruction*>* replacements) {
439 // Only transfer invariant and restrict decorations on the variable. There are
440 // no type or member decorations that are necessary to transfer.
441 for (auto inst :
442 get_decoration_mgr()->GetDecorationsFor(source->result_id(), false)) {
443 assert(inst->opcode() == spv::Op::OpDecorate);
444 auto decoration = spv::Decoration(inst->GetSingleWordInOperand(1u));
445 if (decoration == spv::Decoration::Invariant ||
446 decoration == spv::Decoration::Restrict) {
447 for (auto var : *replacements) {
448 if (var == nullptr) {
449 continue;
450 }
451
452 std::unique_ptr<Instruction> annotation(new Instruction(
453 context(), spv::Op::OpDecorate, 0, 0,
454 std::initializer_list<Operand>{
455 {SPV_OPERAND_TYPE_ID, {var->result_id()}},
456 {SPV_OPERAND_TYPE_DECORATION, {uint32_t(decoration)}}}));
457 for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
458 Operand copy(inst->GetInOperand(i));
459 annotation->AddOperand(std::move(copy));
460 }
461 context()->AddAnnotationInst(std::move(annotation));
462 get_def_use_mgr()->AnalyzeInstUse(&*--context()->annotation_end());
463 }
464 }
465 }
466 }
467
CreateVariable(uint32_t type_id,Instruction * var_inst,uint32_t index,std::vector<Instruction * > * replacements)468 void ScalarReplacementPass::CreateVariable(
469 uint32_t type_id, Instruction* var_inst, uint32_t index,
470 std::vector<Instruction*>* replacements) {
471 uint32_t ptr_id = GetOrCreatePointerType(type_id);
472 uint32_t id = TakeNextId();
473
474 if (id == 0) {
475 replacements->push_back(nullptr);
476 }
477
478 std::unique_ptr<Instruction> variable(
479 new Instruction(context(), spv::Op::OpVariable, ptr_id, id,
480 std::initializer_list<Operand>{
481 {SPV_OPERAND_TYPE_STORAGE_CLASS,
482 {uint32_t(spv::StorageClass::Function)}}}));
483
484 BasicBlock* block = context()->get_instr_block(var_inst);
485 block->begin().InsertBefore(std::move(variable));
486 Instruction* inst = &*block->begin();
487
488 // If varInst was initialized, make sure to initialize its replacement.
489 GetOrCreateInitialValue(var_inst, index, inst);
490 get_def_use_mgr()->AnalyzeInstDefUse(inst);
491 context()->set_instr_block(inst, block);
492
493 CopyDecorationsToVariable(var_inst, inst, index);
494 inst->UpdateDebugInfoFrom(var_inst);
495
496 replacements->push_back(inst);
497 }
498
GetOrCreatePointerType(uint32_t id)499 uint32_t ScalarReplacementPass::GetOrCreatePointerType(uint32_t id) {
500 auto iter = pointee_to_pointer_.find(id);
501 if (iter != pointee_to_pointer_.end()) return iter->second;
502
503 analysis::TypeManager* type_mgr = context()->get_type_mgr();
504 uint32_t ptr_type_id =
505 type_mgr->FindPointerToType(id, spv::StorageClass::Function);
506 pointee_to_pointer_[id] = ptr_type_id;
507 return ptr_type_id;
508 }
509
GetOrCreateInitialValue(Instruction * source,uint32_t index,Instruction * newVar)510 void ScalarReplacementPass::GetOrCreateInitialValue(Instruction* source,
511 uint32_t index,
512 Instruction* newVar) {
513 assert(source->opcode() == spv::Op::OpVariable);
514 if (source->NumInOperands() < 2) return;
515
516 uint32_t initId = source->GetSingleWordInOperand(1u);
517 uint32_t storageId = GetStorageType(newVar)->result_id();
518 Instruction* init = get_def_use_mgr()->GetDef(initId);
519 uint32_t newInitId = 0;
520 // TODO(dnovillo): Refactor this with constant propagation.
521 if (init->opcode() == spv::Op::OpConstantNull) {
522 // Initialize to appropriate NULL.
523 auto iter = type_to_null_.find(storageId);
524 if (iter == type_to_null_.end()) {
525 newInitId = TakeNextId();
526 type_to_null_[storageId] = newInitId;
527 context()->AddGlobalValue(
528 MakeUnique<Instruction>(context(), spv::Op::OpConstantNull, storageId,
529 newInitId, std::initializer_list<Operand>{}));
530 Instruction* newNull = &*--context()->types_values_end();
531 get_def_use_mgr()->AnalyzeInstDefUse(newNull);
532 } else {
533 newInitId = iter->second;
534 }
535 } else if (IsSpecConstantInst(init->opcode())) {
536 // Create a new constant extract.
537 newInitId = TakeNextId();
538 context()->AddGlobalValue(MakeUnique<Instruction>(
539 context(), spv::Op::OpSpecConstantOp, storageId, newInitId,
540 std::initializer_list<Operand>{
541 {SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER,
542 {uint32_t(spv::Op::OpCompositeExtract)}},
543 {SPV_OPERAND_TYPE_ID, {init->result_id()}},
544 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {index}}}));
545 Instruction* newSpecConst = &*--context()->types_values_end();
546 get_def_use_mgr()->AnalyzeInstDefUse(newSpecConst);
547 } else if (init->opcode() == spv::Op::OpConstantComposite) {
548 // Get the appropriate index constant.
549 newInitId = init->GetSingleWordInOperand(index);
550 Instruction* element = get_def_use_mgr()->GetDef(newInitId);
551 if (element->opcode() == spv::Op::OpUndef) {
552 // Undef is not a valid initializer for a variable.
553 newInitId = 0;
554 }
555 } else {
556 assert(false);
557 }
558
559 if (newInitId != 0) {
560 newVar->AddOperand({SPV_OPERAND_TYPE_ID, {newInitId}});
561 }
562 }
563
GetArrayLength(const Instruction * arrayType) const564 uint64_t ScalarReplacementPass::GetArrayLength(
565 const Instruction* arrayType) const {
566 assert(arrayType->opcode() == spv::Op::OpTypeArray);
567 const Instruction* length =
568 get_def_use_mgr()->GetDef(arrayType->GetSingleWordInOperand(1u));
569 return context()
570 ->get_constant_mgr()
571 ->GetConstantFromInst(length)
572 ->GetZeroExtendedValue();
573 }
574
GetNumElements(const Instruction * type) const575 uint64_t ScalarReplacementPass::GetNumElements(const Instruction* type) const {
576 assert(type->opcode() == spv::Op::OpTypeVector ||
577 type->opcode() == spv::Op::OpTypeMatrix);
578 const Operand& op = type->GetInOperand(1u);
579 assert(op.words.size() <= 2);
580 uint64_t len = 0;
581 for (size_t i = 0; i != op.words.size(); ++i) {
582 len |= (static_cast<uint64_t>(op.words[i]) << (32ull * i));
583 }
584 return len;
585 }
586
IsSpecConstant(uint32_t id) const587 bool ScalarReplacementPass::IsSpecConstant(uint32_t id) const {
588 const Instruction* inst = get_def_use_mgr()->GetDef(id);
589 assert(inst);
590 return spvOpcodeIsSpecConstant(inst->opcode());
591 }
592
// Returns the instruction defining the type that the OpVariable |inst|
// stores, i.e. the type named by in-operand 1 of the variable's own type.
Instruction* ScalarReplacementPass::GetStorageType(
    const Instruction* inst) const {
  assert(inst->opcode() == spv::Op::OpVariable);

  auto* def_use = get_def_use_mgr();
  const Instruction* pointer_type = def_use->GetDef(inst->type_id());
  const uint32_t pointee_type_id = pointer_type->GetSingleWordInOperand(1u);
  return def_use->GetDef(pointee_type_id);
}
602
CanReplaceVariable(const Instruction * varInst) const603 bool ScalarReplacementPass::CanReplaceVariable(
604 const Instruction* varInst) const {
605 assert(varInst->opcode() == spv::Op::OpVariable);
606
607 // Can only replace function scope variables.
608 if (spv::StorageClass(varInst->GetSingleWordInOperand(0u)) !=
609 spv::StorageClass::Function) {
610 return false;
611 }
612
613 if (!CheckTypeAnnotations(get_def_use_mgr()->GetDef(varInst->type_id()))) {
614 return false;
615 }
616
617 const Instruction* typeInst = GetStorageType(varInst);
618 if (!CheckType(typeInst)) {
619 return false;
620 }
621
622 if (!CheckAnnotations(varInst)) {
623 return false;
624 }
625
626 if (!CheckUses(varInst)) {
627 return false;
628 }
629
630 return true;
631 }
632
CheckType(const Instruction * typeInst) const633 bool ScalarReplacementPass::CheckType(const Instruction* typeInst) const {
634 if (!CheckTypeAnnotations(typeInst)) {
635 return false;
636 }
637
638 switch (typeInst->opcode()) {
639 case spv::Op::OpTypeStruct:
640 // Don't bother with empty structs or very large structs.
641 if (typeInst->NumInOperands() == 0 ||
642 IsLargerThanSizeLimit(typeInst->NumInOperands())) {
643 return false;
644 }
645 return true;
646 case spv::Op::OpTypeArray:
647 if (IsSpecConstant(typeInst->GetSingleWordInOperand(1u))) {
648 return false;
649 }
650 if (IsLargerThanSizeLimit(GetArrayLength(typeInst))) {
651 return false;
652 }
653 return true;
654 // TODO(alanbaker): Develop some heuristics for when this should be
655 // re-enabled.
656 //// Specifically including matrix and vector in an attempt to reduce the
657 //// number of vector registers required.
658 // case spv::Op::OpTypeMatrix:
659 // case spv::Op::OpTypeVector:
660 // if (IsLargerThanSizeLimit(GetNumElements(typeInst))) return false;
661 // return true;
662
663 case spv::Op::OpTypeRuntimeArray:
664 default:
665 return false;
666 }
667 }
668
CheckTypeAnnotations(const Instruction * typeInst) const669 bool ScalarReplacementPass::CheckTypeAnnotations(
670 const Instruction* typeInst) const {
671 for (auto inst :
672 get_decoration_mgr()->GetDecorationsFor(typeInst->result_id(), false)) {
673 uint32_t decoration;
674 if (inst->opcode() == spv::Op::OpDecorate) {
675 decoration = inst->GetSingleWordInOperand(1u);
676 } else {
677 assert(inst->opcode() == spv::Op::OpMemberDecorate);
678 decoration = inst->GetSingleWordInOperand(2u);
679 }
680
681 switch (spv::Decoration(decoration)) {
682 case spv::Decoration::RowMajor:
683 case spv::Decoration::ColMajor:
684 case spv::Decoration::ArrayStride:
685 case spv::Decoration::MatrixStride:
686 case spv::Decoration::CPacked:
687 case spv::Decoration::Invariant:
688 case spv::Decoration::Restrict:
689 case spv::Decoration::Offset:
690 case spv::Decoration::Alignment:
691 case spv::Decoration::AlignmentId:
692 case spv::Decoration::MaxByteOffset:
693 case spv::Decoration::RelaxedPrecision:
694 case spv::Decoration::AliasedPointer:
695 case spv::Decoration::RestrictPointer:
696 break;
697 default:
698 return false;
699 }
700 }
701
702 return true;
703 }
704
CheckAnnotations(const Instruction * varInst) const705 bool ScalarReplacementPass::CheckAnnotations(const Instruction* varInst) const {
706 for (auto inst :
707 get_decoration_mgr()->GetDecorationsFor(varInst->result_id(), false)) {
708 assert(inst->opcode() == spv::Op::OpDecorate);
709 auto decoration = spv::Decoration(inst->GetSingleWordInOperand(1u));
710 switch (decoration) {
711 case spv::Decoration::Invariant:
712 case spv::Decoration::Restrict:
713 case spv::Decoration::Alignment:
714 case spv::Decoration::AlignmentId:
715 case spv::Decoration::MaxByteOffset:
716 case spv::Decoration::AliasedPointer:
717 case spv::Decoration::RestrictPointer:
718 break;
719 default:
720 return false;
721 }
722 }
723
724 return true;
725 }
726
CheckUses(const Instruction * inst) const727 bool ScalarReplacementPass::CheckUses(const Instruction* inst) const {
728 VariableStats stats = {0, 0};
729 bool ok = CheckUses(inst, &stats);
730
731 // TODO(alanbaker/greg-lunarg): Add some meaningful heuristics about when
732 // SRoA is costly, such as when the structure has many (unaccessed?)
733 // members.
734
735 return ok;
736 }
737
CheckUses(const Instruction * inst,VariableStats * stats) const738 bool ScalarReplacementPass::CheckUses(const Instruction* inst,
739 VariableStats* stats) const {
740 uint64_t max_legal_index = GetMaxLegalIndex(inst);
741
742 bool ok = true;
743 get_def_use_mgr()->ForEachUse(inst, [this, max_legal_index, stats, &ok](
744 const Instruction* user,
745 uint32_t index) {
746 if (user->GetCommonDebugOpcode() == CommonDebugInfoDebugDeclare ||
747 user->GetCommonDebugOpcode() == CommonDebugInfoDebugValue) {
748 // TODO: include num_partial_accesses if it uses Fragment operation or
749 // DebugValue has Indexes operand.
750 stats->num_full_accesses++;
751 return;
752 }
753
754 // Annotations are check as a group separately.
755 if (!IsAnnotationInst(user->opcode())) {
756 switch (user->opcode()) {
757 case spv::Op::OpAccessChain:
758 case spv::Op::OpInBoundsAccessChain:
759 if (index == 2u && user->NumInOperands() > 1) {
760 uint32_t id = user->GetSingleWordInOperand(1u);
761 const Instruction* opInst = get_def_use_mgr()->GetDef(id);
762 const auto* constant =
763 context()->get_constant_mgr()->GetConstantFromInst(opInst);
764 if (!constant) {
765 ok = false;
766 } else if (constant->GetZeroExtendedValue() >= max_legal_index) {
767 ok = false;
768 } else {
769 if (!CheckUsesRelaxed(user)) ok = false;
770 }
771 stats->num_partial_accesses++;
772 } else {
773 ok = false;
774 }
775 break;
776 case spv::Op::OpLoad:
777 if (!CheckLoad(user, index)) ok = false;
778 stats->num_full_accesses++;
779 break;
780 case spv::Op::OpStore:
781 if (!CheckStore(user, index)) ok = false;
782 stats->num_full_accesses++;
783 break;
784 case spv::Op::OpName:
785 case spv::Op::OpMemberName:
786 break;
787 default:
788 ok = false;
789 break;
790 }
791 }
792 });
793
794 return ok;
795 }
796
CheckUsesRelaxed(const Instruction * inst) const797 bool ScalarReplacementPass::CheckUsesRelaxed(const Instruction* inst) const {
798 bool ok = true;
799 get_def_use_mgr()->ForEachUse(
800 inst, [this, &ok](const Instruction* user, uint32_t index) {
801 switch (user->opcode()) {
802 case spv::Op::OpAccessChain:
803 case spv::Op::OpInBoundsAccessChain:
804 if (index != 2u) {
805 ok = false;
806 } else {
807 if (!CheckUsesRelaxed(user)) ok = false;
808 }
809 break;
810 case spv::Op::OpLoad:
811 if (!CheckLoad(user, index)) ok = false;
812 break;
813 case spv::Op::OpStore:
814 if (!CheckStore(user, index)) ok = false;
815 break;
816 case spv::Op::OpImageTexelPointer:
817 if (!CheckImageTexelPointer(index)) ok = false;
818 break;
819 case spv::Op::OpExtInst:
820 if (user->GetCommonDebugOpcode() != CommonDebugInfoDebugDeclare ||
821 !CheckDebugDeclare(index))
822 ok = false;
823 break;
824 default:
825 ok = false;
826 break;
827 }
828 });
829
830 return ok;
831 }
832
CheckImageTexelPointer(uint32_t index) const833 bool ScalarReplacementPass::CheckImageTexelPointer(uint32_t index) const {
834 return index == 2u;
835 }
836
CheckLoad(const Instruction * inst,uint32_t index) const837 bool ScalarReplacementPass::CheckLoad(const Instruction* inst,
838 uint32_t index) const {
839 if (index != 2u) return false;
840 if (inst->NumInOperands() >= 2 &&
841 inst->GetSingleWordInOperand(1u) &
842 uint32_t(spv::MemoryAccessMask::Volatile))
843 return false;
844 return true;
845 }
846
CheckStore(const Instruction * inst,uint32_t index) const847 bool ScalarReplacementPass::CheckStore(const Instruction* inst,
848 uint32_t index) const {
849 if (index != 0u) return false;
850 if (inst->NumInOperands() >= 3 &&
851 inst->GetSingleWordInOperand(2u) &
852 uint32_t(spv::MemoryAccessMask::Volatile))
853 return false;
854 return true;
855 }
856
CheckDebugDeclare(uint32_t index) const857 bool ScalarReplacementPass::CheckDebugDeclare(uint32_t index) const {
858 if (index != kDebugDeclareOperandVariableIndex) return false;
859 return true;
860 }
861
IsLargerThanSizeLimit(uint64_t length) const862 bool ScalarReplacementPass::IsLargerThanSizeLimit(uint64_t length) const {
863 if (max_num_elements_ == 0) {
864 return false;
865 }
866 return length > max_num_elements_;
867 }
868
869 std::unique_ptr<std::unordered_set<int64_t>>
GetUsedComponents(Instruction * inst)870 ScalarReplacementPass::GetUsedComponents(Instruction* inst) {
871 std::unique_ptr<std::unordered_set<int64_t>> result(
872 new std::unordered_set<int64_t>());
873
874 analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
875
876 def_use_mgr->WhileEachUser(inst, [&result, def_use_mgr,
877 this](Instruction* use) {
878 switch (use->opcode()) {
879 case spv::Op::OpLoad: {
880 // Look for extract from the load.
881 std::vector<uint32_t> t;
882 if (def_use_mgr->WhileEachUser(use, [&t](Instruction* use2) {
883 if (use2->opcode() != spv::Op::OpCompositeExtract ||
884 use2->NumInOperands() <= 1) {
885 return false;
886 }
887 t.push_back(use2->GetSingleWordInOperand(1));
888 return true;
889 })) {
890 result->insert(t.begin(), t.end());
891 return true;
892 } else {
893 result.reset(nullptr);
894 return false;
895 }
896 }
897 case spv::Op::OpName:
898 case spv::Op::OpMemberName:
899 case spv::Op::OpStore:
900 // No components are used.
901 return true;
902 case spv::Op::OpAccessChain:
903 case spv::Op::OpInBoundsAccessChain: {
904 // Add the first index it if is a constant.
905 // TODO: Could be improved by checking if the address is used in a load.
906 analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
907 uint32_t index_id = use->GetSingleWordInOperand(1);
908 const analysis::Constant* index_const =
909 const_mgr->FindDeclaredConstant(index_id);
910 if (index_const) {
911 result->insert(index_const->GetSignExtendedValue());
912 return true;
913 } else {
914 // Could be any element. Assuming all are used.
915 result.reset(nullptr);
916 return false;
917 }
918 }
919 default:
920 // We do not know what is happening. Have to assume the worst.
921 result.reset(nullptr);
922 return false;
923 }
924 });
925
926 return result;
927 }
928
GetMaxLegalIndex(const Instruction * var_inst) const929 uint64_t ScalarReplacementPass::GetMaxLegalIndex(
930 const Instruction* var_inst) const {
931 assert(var_inst->opcode() == spv::Op::OpVariable &&
932 "|var_inst| must be a variable instruction.");
933 Instruction* type = GetStorageType(var_inst);
934 switch (type->opcode()) {
935 case spv::Op::OpTypeStruct:
936 return type->NumInOperands();
937 case spv::Op::OpTypeArray:
938 return GetArrayLength(type);
939 case spv::Op::OpTypeMatrix:
940 case spv::Op::OpTypeVector:
941 return GetNumElements(type);
942 default:
943 return 0;
944 }
945 return 0;
946 }
947
CopyDecorationsToVariable(Instruction * from,Instruction * to,uint32_t member_index)948 void ScalarReplacementPass::CopyDecorationsToVariable(Instruction* from,
949 Instruction* to,
950 uint32_t member_index) {
951 CopyPointerDecorationsToVariable(from, to);
952 CopyNecessaryMemberDecorationsToVariable(from, to, member_index);
953 }
954
CopyPointerDecorationsToVariable(Instruction * from,Instruction * to)955 void ScalarReplacementPass::CopyPointerDecorationsToVariable(Instruction* from,
956 Instruction* to) {
957 // The RestrictPointer and AliasedPointer decorations are copied to all
958 // members even if the new variable does not contain a pointer. It does
959 // not hurt to do so.
960 for (auto dec_inst :
961 get_decoration_mgr()->GetDecorationsFor(from->result_id(), false)) {
962 uint32_t decoration;
963 decoration = dec_inst->GetSingleWordInOperand(1u);
964 switch (spv::Decoration(decoration)) {
965 case spv::Decoration::AliasedPointer:
966 case spv::Decoration::RestrictPointer: {
967 std::unique_ptr<Instruction> new_dec_inst(dec_inst->Clone(context()));
968 new_dec_inst->SetInOperand(0, {to->result_id()});
969 context()->AddAnnotationInst(std::move(new_dec_inst));
970 } break;
971 default:
972 break;
973 }
974 }
975 }
976
CopyNecessaryMemberDecorationsToVariable(Instruction * from,Instruction * to,uint32_t member_index)977 void ScalarReplacementPass::CopyNecessaryMemberDecorationsToVariable(
978 Instruction* from, Instruction* to, uint32_t member_index) {
979 Instruction* type_inst = GetStorageType(from);
980 for (auto dec_inst :
981 get_decoration_mgr()->GetDecorationsFor(type_inst->result_id(), false)) {
982 uint32_t decoration;
983 if (dec_inst->opcode() == spv::Op::OpMemberDecorate) {
984 if (dec_inst->GetSingleWordInOperand(1) != member_index) {
985 continue;
986 }
987
988 decoration = dec_inst->GetSingleWordInOperand(2u);
989 switch (spv::Decoration(decoration)) {
990 case spv::Decoration::ArrayStride:
991 case spv::Decoration::Alignment:
992 case spv::Decoration::AlignmentId:
993 case spv::Decoration::MaxByteOffset:
994 case spv::Decoration::MaxByteOffsetId:
995 case spv::Decoration::RelaxedPrecision: {
996 std::unique_ptr<Instruction> new_dec_inst(
997 new Instruction(context(), spv::Op::OpDecorate, 0, 0, {}));
998 new_dec_inst->AddOperand(
999 Operand(SPV_OPERAND_TYPE_ID, {to->result_id()}));
1000 for (uint32_t i = 2; i < dec_inst->NumInOperandWords(); ++i) {
1001 new_dec_inst->AddOperand(Operand(dec_inst->GetInOperand(i)));
1002 }
1003 context()->AddAnnotationInst(std::move(new_dec_inst));
1004 } break;
1005 default:
1006 break;
1007 }
1008 }
1009 }
1010 }
1011
1012 } // namespace opt
1013 } // namespace spvtools
1014