1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/opt/scalar_replacement_pass.h"
16
17 #include <algorithm>
18 #include <queue>
19 #include <tuple>
20 #include <utility>
21
22 #include "source/enum_string_mapping.h"
23 #include "source/extensions.h"
24 #include "source/opt/reflect.h"
25 #include "source/opt/types.h"
26 #include "source/util/make_unique.h"
27
// Operand positions within the debug-info extended instructions handled by
// this pass.

// Position of the 'Value' operand of a DebugValue instruction.
static constexpr uint32_t kDebugValueOperandValueIndex = 5;
// Position of the 'Expression' operand of a DebugValue instruction.
static constexpr uint32_t kDebugValueOperandExpressionIndex = 6;
// Position of the operand of a DebugDeclare instruction that refers to the
// declared variable.
static constexpr uint32_t kDebugDeclareOperandVariableIndex = 5;
31
32 namespace spvtools {
33 namespace opt {
34
Process()35 Pass::Status ScalarReplacementPass::Process() {
36 Status status = Status::SuccessWithoutChange;
37 for (auto& f : *get_module()) {
38 if (f.IsDeclaration()) {
39 continue;
40 }
41
42 Status functionStatus = ProcessFunction(&f);
43 if (functionStatus == Status::Failure)
44 return functionStatus;
45 else if (functionStatus == Status::SuccessWithChange)
46 status = functionStatus;
47 }
48
49 return status;
50 }
51
ProcessFunction(Function * function)52 Pass::Status ScalarReplacementPass::ProcessFunction(Function* function) {
53 std::queue<Instruction*> worklist;
54 BasicBlock& entry = *function->begin();
55 for (auto iter = entry.begin(); iter != entry.end(); ++iter) {
56 // Function storage class OpVariables must appear as the first instructions
57 // of the entry block.
58 if (iter->opcode() != SpvOpVariable) break;
59
60 Instruction* varInst = &*iter;
61 if (CanReplaceVariable(varInst)) {
62 worklist.push(varInst);
63 }
64 }
65
66 Status status = Status::SuccessWithoutChange;
67 while (!worklist.empty()) {
68 Instruction* varInst = worklist.front();
69 worklist.pop();
70
71 Status var_status = ReplaceVariable(varInst, &worklist);
72 if (var_status == Status::Failure)
73 return var_status;
74 else if (var_status == Status::SuccessWithChange)
75 status = var_status;
76 }
77
78 return status;
79 }
80
ReplaceVariable(Instruction * inst,std::queue<Instruction * > * worklist)81 Pass::Status ScalarReplacementPass::ReplaceVariable(
82 Instruction* inst, std::queue<Instruction*>* worklist) {
83 std::vector<Instruction*> replacements;
84 if (!CreateReplacementVariables(inst, &replacements)) {
85 return Status::Failure;
86 }
87
88 std::vector<Instruction*> dead;
89 bool replaced_all_uses = get_def_use_mgr()->WhileEachUser(
90 inst, [this, &replacements, &dead](Instruction* user) {
91 if (user->GetCommonDebugOpcode() == CommonDebugInfoDebugDeclare) {
92 if (ReplaceWholeDebugDeclare(user, replacements)) {
93 dead.push_back(user);
94 return true;
95 }
96 return false;
97 }
98 if (user->GetCommonDebugOpcode() == CommonDebugInfoDebugValue) {
99 if (ReplaceWholeDebugValue(user, replacements)) {
100 dead.push_back(user);
101 return true;
102 }
103 return false;
104 }
105 if (!IsAnnotationInst(user->opcode())) {
106 switch (user->opcode()) {
107 case SpvOpLoad:
108 if (ReplaceWholeLoad(user, replacements)) {
109 dead.push_back(user);
110 } else {
111 return false;
112 }
113 break;
114 case SpvOpStore:
115 if (ReplaceWholeStore(user, replacements)) {
116 dead.push_back(user);
117 } else {
118 return false;
119 }
120 break;
121 case SpvOpAccessChain:
122 case SpvOpInBoundsAccessChain:
123 if (ReplaceAccessChain(user, replacements))
124 dead.push_back(user);
125 else
126 return false;
127 break;
128 case SpvOpName:
129 case SpvOpMemberName:
130 break;
131 default:
132 assert(false && "Unexpected opcode");
133 break;
134 }
135 }
136 return true;
137 });
138
139 if (replaced_all_uses) {
140 dead.push_back(inst);
141 } else {
142 return Status::Failure;
143 }
144
145 // If there are no dead instructions to clean up, return with no changes.
146 if (dead.empty()) return Status::SuccessWithoutChange;
147
148 // Clean up some dead code.
149 while (!dead.empty()) {
150 Instruction* toKill = dead.back();
151 dead.pop_back();
152 context()->KillInst(toKill);
153 }
154
155 // Attempt to further scalarize.
156 for (auto var : replacements) {
157 if (var->opcode() == SpvOpVariable) {
158 if (get_def_use_mgr()->NumUsers(var) == 0) {
159 context()->KillInst(var);
160 } else if (CanReplaceVariable(var)) {
161 worklist->push(var);
162 }
163 }
164 }
165
166 return Status::SuccessWithChange;
167 }
168
ReplaceWholeDebugDeclare(Instruction * dbg_decl,const std::vector<Instruction * > & replacements)169 bool ScalarReplacementPass::ReplaceWholeDebugDeclare(
170 Instruction* dbg_decl, const std::vector<Instruction*>& replacements) {
171 // Insert Deref operation to the front of the operation list of |dbg_decl|.
172 Instruction* dbg_expr = context()->get_def_use_mgr()->GetDef(
173 dbg_decl->GetSingleWordOperand(kDebugValueOperandExpressionIndex));
174 auto* deref_expr =
175 context()->get_debug_info_mgr()->DerefDebugExpression(dbg_expr);
176
177 // Add DebugValue instruction with Indexes operand and Deref operation.
178 int32_t idx = 0;
179 for (const auto* var : replacements) {
180 Instruction* insert_before = var->NextNode();
181 while (insert_before->opcode() == SpvOpVariable)
182 insert_before = insert_before->NextNode();
183 assert(insert_before != nullptr && "unexpected end of list");
184 Instruction* added_dbg_value =
185 context()->get_debug_info_mgr()->AddDebugValueForDecl(
186 dbg_decl, /*value_id=*/var->result_id(),
187 /*insert_before=*/insert_before, /*scope_and_line=*/dbg_decl);
188
189 if (added_dbg_value == nullptr) return false;
190 added_dbg_value->AddOperand(
191 {SPV_OPERAND_TYPE_ID,
192 {context()->get_constant_mgr()->GetSIntConst(idx)}});
193 added_dbg_value->SetOperand(kDebugValueOperandExpressionIndex,
194 {deref_expr->result_id()});
195 if (context()->AreAnalysesValid(IRContext::Analysis::kAnalysisDefUse)) {
196 context()->get_def_use_mgr()->AnalyzeInstUse(added_dbg_value);
197 }
198 ++idx;
199 }
200 return true;
201 }
202
ReplaceWholeDebugValue(Instruction * dbg_value,const std::vector<Instruction * > & replacements)203 bool ScalarReplacementPass::ReplaceWholeDebugValue(
204 Instruction* dbg_value, const std::vector<Instruction*>& replacements) {
205 int32_t idx = 0;
206 BasicBlock* block = context()->get_instr_block(dbg_value);
207 for (auto var : replacements) {
208 // Clone the DebugValue.
209 std::unique_ptr<Instruction> new_dbg_value(dbg_value->Clone(context()));
210 uint32_t new_id = TakeNextId();
211 if (new_id == 0) return false;
212 new_dbg_value->SetResultId(new_id);
213 // Update 'Value' operand to the |replacements|.
214 new_dbg_value->SetOperand(kDebugValueOperandValueIndex, {var->result_id()});
215 // Append 'Indexes' operand.
216 new_dbg_value->AddOperand(
217 {SPV_OPERAND_TYPE_ID,
218 {context()->get_constant_mgr()->GetSIntConst(idx)}});
219 // Insert the new DebugValue to the basic block.
220 auto* added_instr = dbg_value->InsertBefore(std::move(new_dbg_value));
221 get_def_use_mgr()->AnalyzeInstDefUse(added_instr);
222 context()->set_instr_block(added_instr, block);
223 ++idx;
224 }
225 return true;
226 }
227
ReplaceWholeLoad(Instruction * load,const std::vector<Instruction * > & replacements)228 bool ScalarReplacementPass::ReplaceWholeLoad(
229 Instruction* load, const std::vector<Instruction*>& replacements) {
230 // Replaces the load of the entire composite with a load from each replacement
231 // variable followed by a composite construction.
232 BasicBlock* block = context()->get_instr_block(load);
233 std::vector<Instruction*> loads;
234 loads.reserve(replacements.size());
235 BasicBlock::iterator where(load);
236 for (auto var : replacements) {
237 // Create a load of each replacement variable.
238 if (var->opcode() != SpvOpVariable) {
239 loads.push_back(var);
240 continue;
241 }
242
243 Instruction* type = GetStorageType(var);
244 uint32_t loadId = TakeNextId();
245 if (loadId == 0) {
246 return false;
247 }
248 std::unique_ptr<Instruction> newLoad(
249 new Instruction(context(), SpvOpLoad, type->result_id(), loadId,
250 std::initializer_list<Operand>{
251 {SPV_OPERAND_TYPE_ID, {var->result_id()}}}));
252 // Copy memory access attributes which start at index 1. Index 0 is the
253 // pointer to load.
254 for (uint32_t i = 1; i < load->NumInOperands(); ++i) {
255 Operand copy(load->GetInOperand(i));
256 newLoad->AddOperand(std::move(copy));
257 }
258 where = where.InsertBefore(std::move(newLoad));
259 get_def_use_mgr()->AnalyzeInstDefUse(&*where);
260 context()->set_instr_block(&*where, block);
261 where->UpdateDebugInfoFrom(load);
262 loads.push_back(&*where);
263 }
264
265 // Construct a new composite.
266 uint32_t compositeId = TakeNextId();
267 if (compositeId == 0) {
268 return false;
269 }
270 where = load;
271 std::unique_ptr<Instruction> compositeConstruct(new Instruction(
272 context(), SpvOpCompositeConstruct, load->type_id(), compositeId, {}));
273 for (auto l : loads) {
274 Operand op(SPV_OPERAND_TYPE_ID,
275 std::initializer_list<uint32_t>{l->result_id()});
276 compositeConstruct->AddOperand(std::move(op));
277 }
278 where = where.InsertBefore(std::move(compositeConstruct));
279 get_def_use_mgr()->AnalyzeInstDefUse(&*where);
280 where->UpdateDebugInfoFrom(load);
281 context()->set_instr_block(&*where, block);
282 context()->ReplaceAllUsesWith(load->result_id(), compositeId);
283 return true;
284 }
285
ReplaceWholeStore(Instruction * store,const std::vector<Instruction * > & replacements)286 bool ScalarReplacementPass::ReplaceWholeStore(
287 Instruction* store, const std::vector<Instruction*>& replacements) {
288 // Replaces a store to the whole composite with a series of extract and stores
289 // to each element.
290 uint32_t storeInput = store->GetSingleWordInOperand(1u);
291 BasicBlock* block = context()->get_instr_block(store);
292 BasicBlock::iterator where(store);
293 uint32_t elementIndex = 0;
294 for (auto var : replacements) {
295 // Create the extract.
296 if (var->opcode() != SpvOpVariable) {
297 elementIndex++;
298 continue;
299 }
300
301 Instruction* type = GetStorageType(var);
302 uint32_t extractId = TakeNextId();
303 if (extractId == 0) {
304 return false;
305 }
306 std::unique_ptr<Instruction> extract(new Instruction(
307 context(), SpvOpCompositeExtract, type->result_id(), extractId,
308 std::initializer_list<Operand>{
309 {SPV_OPERAND_TYPE_ID, {storeInput}},
310 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {elementIndex++}}}));
311 auto iter = where.InsertBefore(std::move(extract));
312 iter->UpdateDebugInfoFrom(store);
313 get_def_use_mgr()->AnalyzeInstDefUse(&*iter);
314 context()->set_instr_block(&*iter, block);
315
316 // Create the store.
317 std::unique_ptr<Instruction> newStore(
318 new Instruction(context(), SpvOpStore, 0, 0,
319 std::initializer_list<Operand>{
320 {SPV_OPERAND_TYPE_ID, {var->result_id()}},
321 {SPV_OPERAND_TYPE_ID, {extractId}}}));
322 // Copy memory access attributes which start at index 2. Index 0 is the
323 // pointer and index 1 is the data.
324 for (uint32_t i = 2; i < store->NumInOperands(); ++i) {
325 Operand copy(store->GetInOperand(i));
326 newStore->AddOperand(std::move(copy));
327 }
328 iter = where.InsertBefore(std::move(newStore));
329 iter->UpdateDebugInfoFrom(store);
330 get_def_use_mgr()->AnalyzeInstDefUse(&*iter);
331 context()->set_instr_block(&*iter, block);
332 }
333 return true;
334 }
335
ReplaceAccessChain(Instruction * chain,const std::vector<Instruction * > & replacements)336 bool ScalarReplacementPass::ReplaceAccessChain(
337 Instruction* chain, const std::vector<Instruction*>& replacements) {
338 // Replaces the access chain with either another access chain (with one fewer
339 // indexes) or a direct use of the replacement variable.
340 uint32_t indexId = chain->GetSingleWordInOperand(1u);
341 const Instruction* index = get_def_use_mgr()->GetDef(indexId);
342 int64_t indexValue = context()
343 ->get_constant_mgr()
344 ->GetConstantFromInst(index)
345 ->GetSignExtendedValue();
346 if (indexValue < 0 ||
347 indexValue >= static_cast<int64_t>(replacements.size())) {
348 // Out of bounds access, this is illegal IR. Notice that OpAccessChain
349 // indexing is 0-based, so we should also reject index == size-of-array.
350 return false;
351 } else {
352 const Instruction* var = replacements[static_cast<size_t>(indexValue)];
353 if (chain->NumInOperands() > 2) {
354 // Replace input access chain with another access chain.
355 BasicBlock::iterator chainIter(chain);
356 uint32_t replacementId = TakeNextId();
357 if (replacementId == 0) {
358 return false;
359 }
360 std::unique_ptr<Instruction> replacementChain(new Instruction(
361 context(), chain->opcode(), chain->type_id(), replacementId,
362 std::initializer_list<Operand>{
363 {SPV_OPERAND_TYPE_ID, {var->result_id()}}}));
364 // Add the remaining indexes.
365 for (uint32_t i = 2; i < chain->NumInOperands(); ++i) {
366 Operand copy(chain->GetInOperand(i));
367 replacementChain->AddOperand(std::move(copy));
368 }
369 replacementChain->UpdateDebugInfoFrom(chain);
370 auto iter = chainIter.InsertBefore(std::move(replacementChain));
371 get_def_use_mgr()->AnalyzeInstDefUse(&*iter);
372 context()->set_instr_block(&*iter, context()->get_instr_block(chain));
373 context()->ReplaceAllUsesWith(chain->result_id(), replacementId);
374 } else {
375 // Replace with a use of the variable.
376 context()->ReplaceAllUsesWith(chain->result_id(), var->result_id());
377 }
378 }
379
380 return true;
381 }
382
CreateReplacementVariables(Instruction * inst,std::vector<Instruction * > * replacements)383 bool ScalarReplacementPass::CreateReplacementVariables(
384 Instruction* inst, std::vector<Instruction*>* replacements) {
385 Instruction* type = GetStorageType(inst);
386
387 std::unique_ptr<std::unordered_set<int64_t>> components_used =
388 GetUsedComponents(inst);
389
390 uint32_t elem = 0;
391 switch (type->opcode()) {
392 case SpvOpTypeStruct:
393 type->ForEachInOperand(
394 [this, inst, &elem, replacements, &components_used](uint32_t* id) {
395 if (!components_used || components_used->count(elem)) {
396 CreateVariable(*id, inst, elem, replacements);
397 } else {
398 replacements->push_back(CreateNullConstant(*id));
399 }
400 elem++;
401 });
402 break;
403 case SpvOpTypeArray:
404 for (uint32_t i = 0; i != GetArrayLength(type); ++i) {
405 if (!components_used || components_used->count(i)) {
406 CreateVariable(type->GetSingleWordInOperand(0u), inst, i,
407 replacements);
408 } else {
409 replacements->push_back(
410 CreateNullConstant(type->GetSingleWordInOperand(0u)));
411 }
412 }
413 break;
414
415 case SpvOpTypeMatrix:
416 case SpvOpTypeVector:
417 for (uint32_t i = 0; i != GetNumElements(type); ++i) {
418 CreateVariable(type->GetSingleWordInOperand(0u), inst, i, replacements);
419 }
420 break;
421
422 default:
423 assert(false && "Unexpected type.");
424 break;
425 }
426
427 TransferAnnotations(inst, replacements);
428 return std::find(replacements->begin(), replacements->end(), nullptr) ==
429 replacements->end();
430 }
431
TransferAnnotations(const Instruction * source,std::vector<Instruction * > * replacements)432 void ScalarReplacementPass::TransferAnnotations(
433 const Instruction* source, std::vector<Instruction*>* replacements) {
434 // Only transfer invariant and restrict decorations on the variable. There are
435 // no type or member decorations that are necessary to transfer.
436 for (auto inst :
437 get_decoration_mgr()->GetDecorationsFor(source->result_id(), false)) {
438 assert(inst->opcode() == SpvOpDecorate);
439 uint32_t decoration = inst->GetSingleWordInOperand(1u);
440 if (decoration == SpvDecorationInvariant ||
441 decoration == SpvDecorationRestrict) {
442 for (auto var : *replacements) {
443 if (var == nullptr) {
444 continue;
445 }
446
447 std::unique_ptr<Instruction> annotation(
448 new Instruction(context(), SpvOpDecorate, 0, 0,
449 std::initializer_list<Operand>{
450 {SPV_OPERAND_TYPE_ID, {var->result_id()}},
451 {SPV_OPERAND_TYPE_DECORATION, {decoration}}}));
452 for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
453 Operand copy(inst->GetInOperand(i));
454 annotation->AddOperand(std::move(copy));
455 }
456 context()->AddAnnotationInst(std::move(annotation));
457 get_def_use_mgr()->AnalyzeInstUse(&*--context()->annotation_end());
458 }
459 }
460 }
461 }
462
CreateVariable(uint32_t typeId,Instruction * varInst,uint32_t index,std::vector<Instruction * > * replacements)463 void ScalarReplacementPass::CreateVariable(
464 uint32_t typeId, Instruction* varInst, uint32_t index,
465 std::vector<Instruction*>* replacements) {
466 uint32_t ptrId = GetOrCreatePointerType(typeId);
467 uint32_t id = TakeNextId();
468
469 if (id == 0) {
470 replacements->push_back(nullptr);
471 }
472
473 std::unique_ptr<Instruction> variable(new Instruction(
474 context(), SpvOpVariable, ptrId, id,
475 std::initializer_list<Operand>{
476 {SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}}}));
477
478 BasicBlock* block = context()->get_instr_block(varInst);
479 block->begin().InsertBefore(std::move(variable));
480 Instruction* inst = &*block->begin();
481
482 // If varInst was initialized, make sure to initialize its replacement.
483 GetOrCreateInitialValue(varInst, index, inst);
484 get_def_use_mgr()->AnalyzeInstDefUse(inst);
485 context()->set_instr_block(inst, block);
486
487 // Copy decorations from the member to the new variable.
488 Instruction* typeInst = GetStorageType(varInst);
489 for (auto dec_inst :
490 get_decoration_mgr()->GetDecorationsFor(typeInst->result_id(), false)) {
491 uint32_t decoration;
492 if (dec_inst->opcode() != SpvOpMemberDecorate) {
493 continue;
494 }
495
496 if (dec_inst->GetSingleWordInOperand(1) != index) {
497 continue;
498 }
499
500 decoration = dec_inst->GetSingleWordInOperand(2u);
501 switch (decoration) {
502 case SpvDecorationRelaxedPrecision: {
503 std::unique_ptr<Instruction> new_dec_inst(
504 new Instruction(context(), SpvOpDecorate, 0, 0, {}));
505 new_dec_inst->AddOperand(Operand(SPV_OPERAND_TYPE_ID, {id}));
506 for (uint32_t i = 2; i < dec_inst->NumInOperandWords(); ++i) {
507 new_dec_inst->AddOperand(Operand(dec_inst->GetInOperand(i)));
508 }
509 context()->AddAnnotationInst(std::move(new_dec_inst));
510 } break;
511 default:
512 break;
513 }
514 }
515
516 // Update the DebugInfo debug information.
517 inst->UpdateDebugInfoFrom(varInst);
518
519 replacements->push_back(inst);
520 }
521
GetOrCreatePointerType(uint32_t id)522 uint32_t ScalarReplacementPass::GetOrCreatePointerType(uint32_t id) {
523 auto iter = pointee_to_pointer_.find(id);
524 if (iter != pointee_to_pointer_.end()) return iter->second;
525
526 analysis::Type* pointeeTy;
527 std::unique_ptr<analysis::Pointer> pointerTy;
528 std::tie(pointeeTy, pointerTy) =
529 context()->get_type_mgr()->GetTypeAndPointerType(id,
530 SpvStorageClassFunction);
531 uint32_t ptrId = 0;
532 if (pointeeTy->IsUniqueType()) {
533 // Non-ambiguous type, just ask the type manager for an id.
534 ptrId = context()->get_type_mgr()->GetTypeInstruction(pointerTy.get());
535 pointee_to_pointer_[id] = ptrId;
536 return ptrId;
537 }
538
539 // Ambiguous type. We must perform a linear search to try and find the right
540 // type.
541 for (auto global : context()->types_values()) {
542 if (global.opcode() == SpvOpTypePointer &&
543 global.GetSingleWordInOperand(0u) == SpvStorageClassFunction &&
544 global.GetSingleWordInOperand(1u) == id) {
545 if (get_decoration_mgr()->GetDecorationsFor(id, false).empty()) {
546 // Only reuse a decoration-less pointer of the correct type.
547 ptrId = global.result_id();
548 break;
549 }
550 }
551 }
552
553 if (ptrId != 0) {
554 pointee_to_pointer_[id] = ptrId;
555 return ptrId;
556 }
557
558 ptrId = TakeNextId();
559 context()->AddType(MakeUnique<Instruction>(
560 context(), SpvOpTypePointer, 0, ptrId,
561 std::initializer_list<Operand>{
562 {SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}},
563 {SPV_OPERAND_TYPE_ID, {id}}}));
564 Instruction* ptr = &*--context()->types_values_end();
565 get_def_use_mgr()->AnalyzeInstDefUse(ptr);
566 pointee_to_pointer_[id] = ptrId;
567 // Register with the type manager if necessary.
568 context()->get_type_mgr()->RegisterType(ptrId, *pointerTy);
569
570 return ptrId;
571 }
572
GetOrCreateInitialValue(Instruction * source,uint32_t index,Instruction * newVar)573 void ScalarReplacementPass::GetOrCreateInitialValue(Instruction* source,
574 uint32_t index,
575 Instruction* newVar) {
576 assert(source->opcode() == SpvOpVariable);
577 if (source->NumInOperands() < 2) return;
578
579 uint32_t initId = source->GetSingleWordInOperand(1u);
580 uint32_t storageId = GetStorageType(newVar)->result_id();
581 Instruction* init = get_def_use_mgr()->GetDef(initId);
582 uint32_t newInitId = 0;
583 // TODO(dnovillo): Refactor this with constant propagation.
584 if (init->opcode() == SpvOpConstantNull) {
585 // Initialize to appropriate NULL.
586 auto iter = type_to_null_.find(storageId);
587 if (iter == type_to_null_.end()) {
588 newInitId = TakeNextId();
589 type_to_null_[storageId] = newInitId;
590 context()->AddGlobalValue(
591 MakeUnique<Instruction>(context(), SpvOpConstantNull, storageId,
592 newInitId, std::initializer_list<Operand>{}));
593 Instruction* newNull = &*--context()->types_values_end();
594 get_def_use_mgr()->AnalyzeInstDefUse(newNull);
595 } else {
596 newInitId = iter->second;
597 }
598 } else if (IsSpecConstantInst(init->opcode())) {
599 // Create a new constant extract.
600 newInitId = TakeNextId();
601 context()->AddGlobalValue(MakeUnique<Instruction>(
602 context(), SpvOpSpecConstantOp, storageId, newInitId,
603 std::initializer_list<Operand>{
604 {SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER, {SpvOpCompositeExtract}},
605 {SPV_OPERAND_TYPE_ID, {init->result_id()}},
606 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {index}}}));
607 Instruction* newSpecConst = &*--context()->types_values_end();
608 get_def_use_mgr()->AnalyzeInstDefUse(newSpecConst);
609 } else if (init->opcode() == SpvOpConstantComposite) {
610 // Get the appropriate index constant.
611 newInitId = init->GetSingleWordInOperand(index);
612 Instruction* element = get_def_use_mgr()->GetDef(newInitId);
613 if (element->opcode() == SpvOpUndef) {
614 // Undef is not a valid initializer for a variable.
615 newInitId = 0;
616 }
617 } else {
618 assert(false);
619 }
620
621 if (newInitId != 0) {
622 newVar->AddOperand({SPV_OPERAND_TYPE_ID, {newInitId}});
623 }
624 }
625
GetArrayLength(const Instruction * arrayType) const626 uint64_t ScalarReplacementPass::GetArrayLength(
627 const Instruction* arrayType) const {
628 assert(arrayType->opcode() == SpvOpTypeArray);
629 const Instruction* length =
630 get_def_use_mgr()->GetDef(arrayType->GetSingleWordInOperand(1u));
631 return context()
632 ->get_constant_mgr()
633 ->GetConstantFromInst(length)
634 ->GetZeroExtendedValue();
635 }
636
GetNumElements(const Instruction * type) const637 uint64_t ScalarReplacementPass::GetNumElements(const Instruction* type) const {
638 assert(type->opcode() == SpvOpTypeVector ||
639 type->opcode() == SpvOpTypeMatrix);
640 const Operand& op = type->GetInOperand(1u);
641 assert(op.words.size() <= 2);
642 uint64_t len = 0;
643 for (size_t i = 0; i != op.words.size(); ++i) {
644 len |= (static_cast<uint64_t>(op.words[i]) << (32ull * i));
645 }
646 return len;
647 }
648
IsSpecConstant(uint32_t id) const649 bool ScalarReplacementPass::IsSpecConstant(uint32_t id) const {
650 const Instruction* inst = get_def_use_mgr()->GetDef(id);
651 assert(inst);
652 return spvOpcodeIsSpecConstant(inst->opcode());
653 }
654
// Returns the defining instruction of the type that the OpVariable |inst|
// points to, i.e. the pointee of the variable's pointer type.
Instruction* ScalarReplacementPass::GetStorageType(
    const Instruction* inst) const {
  assert(inst->opcode() == SpvOpVariable);

  const Instruction* pointer_type = get_def_use_mgr()->GetDef(inst->type_id());
  // In-operand 1 of the pointer type is the pointee type id.
  const uint32_t pointee_type_id = pointer_type->GetSingleWordInOperand(1u);
  return get_def_use_mgr()->GetDef(pointee_type_id);
}
664
CanReplaceVariable(const Instruction * varInst) const665 bool ScalarReplacementPass::CanReplaceVariable(
666 const Instruction* varInst) const {
667 assert(varInst->opcode() == SpvOpVariable);
668
669 // Can only replace function scope variables.
670 if (varInst->GetSingleWordInOperand(0u) != SpvStorageClassFunction) {
671 return false;
672 }
673
674 if (!CheckTypeAnnotations(get_def_use_mgr()->GetDef(varInst->type_id()))) {
675 return false;
676 }
677
678 const Instruction* typeInst = GetStorageType(varInst);
679 if (!CheckType(typeInst)) {
680 return false;
681 }
682
683 if (!CheckAnnotations(varInst)) {
684 return false;
685 }
686
687 if (!CheckUses(varInst)) {
688 return false;
689 }
690
691 return true;
692 }
693
CheckType(const Instruction * typeInst) const694 bool ScalarReplacementPass::CheckType(const Instruction* typeInst) const {
695 if (!CheckTypeAnnotations(typeInst)) {
696 return false;
697 }
698
699 switch (typeInst->opcode()) {
700 case SpvOpTypeStruct:
701 // Don't bother with empty structs or very large structs.
702 if (typeInst->NumInOperands() == 0 ||
703 IsLargerThanSizeLimit(typeInst->NumInOperands())) {
704 return false;
705 }
706 return true;
707 case SpvOpTypeArray:
708 if (IsSpecConstant(typeInst->GetSingleWordInOperand(1u))) {
709 return false;
710 }
711 if (IsLargerThanSizeLimit(GetArrayLength(typeInst))) {
712 return false;
713 }
714 return true;
715 // TODO(alanbaker): Develop some heuristics for when this should be
716 // re-enabled.
717 //// Specifically including matrix and vector in an attempt to reduce the
718 //// number of vector registers required.
719 // case SpvOpTypeMatrix:
720 // case SpvOpTypeVector:
721 // if (IsLargerThanSizeLimit(GetNumElements(typeInst))) return false;
722 // return true;
723
724 case SpvOpTypeRuntimeArray:
725 default:
726 return false;
727 }
728 }
729
CheckTypeAnnotations(const Instruction * typeInst) const730 bool ScalarReplacementPass::CheckTypeAnnotations(
731 const Instruction* typeInst) const {
732 for (auto inst :
733 get_decoration_mgr()->GetDecorationsFor(typeInst->result_id(), false)) {
734 uint32_t decoration;
735 if (inst->opcode() == SpvOpDecorate) {
736 decoration = inst->GetSingleWordInOperand(1u);
737 } else {
738 assert(inst->opcode() == SpvOpMemberDecorate);
739 decoration = inst->GetSingleWordInOperand(2u);
740 }
741
742 switch (decoration) {
743 case SpvDecorationRowMajor:
744 case SpvDecorationColMajor:
745 case SpvDecorationArrayStride:
746 case SpvDecorationMatrixStride:
747 case SpvDecorationCPacked:
748 case SpvDecorationInvariant:
749 case SpvDecorationRestrict:
750 case SpvDecorationOffset:
751 case SpvDecorationAlignment:
752 case SpvDecorationAlignmentId:
753 case SpvDecorationMaxByteOffset:
754 case SpvDecorationRelaxedPrecision:
755 break;
756 default:
757 return false;
758 }
759 }
760
761 return true;
762 }
763
CheckAnnotations(const Instruction * varInst) const764 bool ScalarReplacementPass::CheckAnnotations(const Instruction* varInst) const {
765 for (auto inst :
766 get_decoration_mgr()->GetDecorationsFor(varInst->result_id(), false)) {
767 assert(inst->opcode() == SpvOpDecorate);
768 uint32_t decoration = inst->GetSingleWordInOperand(1u);
769 switch (decoration) {
770 case SpvDecorationInvariant:
771 case SpvDecorationRestrict:
772 case SpvDecorationAlignment:
773 case SpvDecorationAlignmentId:
774 case SpvDecorationMaxByteOffset:
775 break;
776 default:
777 return false;
778 }
779 }
780
781 return true;
782 }
783
CheckUses(const Instruction * inst) const784 bool ScalarReplacementPass::CheckUses(const Instruction* inst) const {
785 VariableStats stats = {0, 0};
786 bool ok = CheckUses(inst, &stats);
787
788 // TODO(alanbaker/greg-lunarg): Add some meaningful heuristics about when
789 // SRoA is costly, such as when the structure has many (unaccessed?)
790 // members.
791
792 return ok;
793 }
794
CheckUses(const Instruction * inst,VariableStats * stats) const795 bool ScalarReplacementPass::CheckUses(const Instruction* inst,
796 VariableStats* stats) const {
797 uint64_t max_legal_index = GetMaxLegalIndex(inst);
798
799 bool ok = true;
800 get_def_use_mgr()->ForEachUse(inst, [this, max_legal_index, stats, &ok](
801 const Instruction* user,
802 uint32_t index) {
803 if (user->GetCommonDebugOpcode() == CommonDebugInfoDebugDeclare ||
804 user->GetCommonDebugOpcode() == CommonDebugInfoDebugValue) {
805 // TODO: include num_partial_accesses if it uses Fragment operation or
806 // DebugValue has Indexes operand.
807 stats->num_full_accesses++;
808 return;
809 }
810
811 // Annotations are check as a group separately.
812 if (!IsAnnotationInst(user->opcode())) {
813 switch (user->opcode()) {
814 case SpvOpAccessChain:
815 case SpvOpInBoundsAccessChain:
816 if (index == 2u && user->NumInOperands() > 1) {
817 uint32_t id = user->GetSingleWordInOperand(1u);
818 const Instruction* opInst = get_def_use_mgr()->GetDef(id);
819 const auto* constant =
820 context()->get_constant_mgr()->GetConstantFromInst(opInst);
821 if (!constant) {
822 ok = false;
823 } else if (constant->GetZeroExtendedValue() >= max_legal_index) {
824 ok = false;
825 } else {
826 if (!CheckUsesRelaxed(user)) ok = false;
827 }
828 stats->num_partial_accesses++;
829 } else {
830 ok = false;
831 }
832 break;
833 case SpvOpLoad:
834 if (!CheckLoad(user, index)) ok = false;
835 stats->num_full_accesses++;
836 break;
837 case SpvOpStore:
838 if (!CheckStore(user, index)) ok = false;
839 stats->num_full_accesses++;
840 break;
841 case SpvOpName:
842 case SpvOpMemberName:
843 break;
844 default:
845 ok = false;
846 break;
847 }
848 }
849 });
850
851 return ok;
852 }
853
CheckUsesRelaxed(const Instruction * inst) const854 bool ScalarReplacementPass::CheckUsesRelaxed(const Instruction* inst) const {
855 bool ok = true;
856 get_def_use_mgr()->ForEachUse(
857 inst, [this, &ok](const Instruction* user, uint32_t index) {
858 switch (user->opcode()) {
859 case SpvOpAccessChain:
860 case SpvOpInBoundsAccessChain:
861 if (index != 2u) {
862 ok = false;
863 } else {
864 if (!CheckUsesRelaxed(user)) ok = false;
865 }
866 break;
867 case SpvOpLoad:
868 if (!CheckLoad(user, index)) ok = false;
869 break;
870 case SpvOpStore:
871 if (!CheckStore(user, index)) ok = false;
872 break;
873 case SpvOpImageTexelPointer:
874 if (!CheckImageTexelPointer(index)) ok = false;
875 break;
876 case SpvOpExtInst:
877 if (user->GetCommonDebugOpcode() != CommonDebugInfoDebugDeclare ||
878 !CheckDebugDeclare(index))
879 ok = false;
880 break;
881 default:
882 ok = false;
883 break;
884 }
885 });
886
887 return ok;
888 }
889
CheckImageTexelPointer(uint32_t index) const890 bool ScalarReplacementPass::CheckImageTexelPointer(uint32_t index) const {
891 return index == 2u;
892 }
893
CheckLoad(const Instruction * inst,uint32_t index) const894 bool ScalarReplacementPass::CheckLoad(const Instruction* inst,
895 uint32_t index) const {
896 if (index != 2u) return false;
897 if (inst->NumInOperands() >= 2 &&
898 inst->GetSingleWordInOperand(1u) & SpvMemoryAccessVolatileMask)
899 return false;
900 return true;
901 }
902
CheckStore(const Instruction * inst,uint32_t index) const903 bool ScalarReplacementPass::CheckStore(const Instruction* inst,
904 uint32_t index) const {
905 if (index != 0u) return false;
906 if (inst->NumInOperands() >= 3 &&
907 inst->GetSingleWordInOperand(2u) & SpvMemoryAccessVolatileMask)
908 return false;
909 return true;
910 }
911
CheckDebugDeclare(uint32_t index) const912 bool ScalarReplacementPass::CheckDebugDeclare(uint32_t index) const {
913 if (index != kDebugDeclareOperandVariableIndex) return false;
914 return true;
915 }
916
IsLargerThanSizeLimit(uint64_t length) const917 bool ScalarReplacementPass::IsLargerThanSizeLimit(uint64_t length) const {
918 if (max_num_elements_ == 0) {
919 return false;
920 }
921 return length > max_num_elements_;
922 }
923
924 std::unique_ptr<std::unordered_set<int64_t>>
GetUsedComponents(Instruction * inst)925 ScalarReplacementPass::GetUsedComponents(Instruction* inst) {
926 std::unique_ptr<std::unordered_set<int64_t>> result(
927 new std::unordered_set<int64_t>());
928
929 analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
930
931 def_use_mgr->WhileEachUser(inst, [&result, def_use_mgr,
932 this](Instruction* use) {
933 switch (use->opcode()) {
934 case SpvOpLoad: {
935 // Look for extract from the load.
936 std::vector<uint32_t> t;
937 if (def_use_mgr->WhileEachUser(use, [&t](Instruction* use2) {
938 if (use2->opcode() != SpvOpCompositeExtract ||
939 use2->NumInOperands() <= 1) {
940 return false;
941 }
942 t.push_back(use2->GetSingleWordInOperand(1));
943 return true;
944 })) {
945 result->insert(t.begin(), t.end());
946 return true;
947 } else {
948 result.reset(nullptr);
949 return false;
950 }
951 }
952 case SpvOpName:
953 case SpvOpMemberName:
954 case SpvOpStore:
955 // No components are used.
956 return true;
957 case SpvOpAccessChain:
958 case SpvOpInBoundsAccessChain: {
959 // Add the first index it if is a constant.
960 // TODO: Could be improved by checking if the address is used in a load.
961 analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
962 uint32_t index_id = use->GetSingleWordInOperand(1);
963 const analysis::Constant* index_const =
964 const_mgr->FindDeclaredConstant(index_id);
965 if (index_const) {
966 result->insert(index_const->GetSignExtendedValue());
967 return true;
968 } else {
969 // Could be any element. Assuming all are used.
970 result.reset(nullptr);
971 return false;
972 }
973 }
974 default:
975 // We do not know what is happening. Have to assume the worst.
976 result.reset(nullptr);
977 return false;
978 }
979 });
980
981 return result;
982 }
983
CreateNullConstant(uint32_t type_id)984 Instruction* ScalarReplacementPass::CreateNullConstant(uint32_t type_id) {
985 analysis::TypeManager* type_mgr = context()->get_type_mgr();
986 analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
987
988 const analysis::Type* type = type_mgr->GetType(type_id);
989 const analysis::Constant* null_const = const_mgr->GetConstant(type, {});
990 Instruction* null_inst =
991 const_mgr->GetDefiningInstruction(null_const, type_id);
992 if (null_inst != nullptr) {
993 context()->UpdateDefUse(null_inst);
994 }
995 return null_inst;
996 }
997
GetMaxLegalIndex(const Instruction * var_inst) const998 uint64_t ScalarReplacementPass::GetMaxLegalIndex(
999 const Instruction* var_inst) const {
1000 assert(var_inst->opcode() == SpvOpVariable &&
1001 "|var_inst| must be a variable instruction.");
1002 Instruction* type = GetStorageType(var_inst);
1003 switch (type->opcode()) {
1004 case SpvOpTypeStruct:
1005 return type->NumInOperands();
1006 case SpvOpTypeArray:
1007 return GetArrayLength(type);
1008 case SpvOpTypeMatrix:
1009 case SpvOpTypeVector:
1010 return GetNumElements(type);
1011 default:
1012 return 0;
1013 }
1014 return 0;
1015 }
1016
1017 } // namespace opt
1018 } // namespace spvtools
1019