1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/opt/scalar_replacement_pass.h"
16
17 #include <algorithm>
18 #include <queue>
19 #include <tuple>
20 #include <utility>
21
22 #include "source/enum_string_mapping.h"
23 #include "source/extensions.h"
24 #include "source/opt/reflect.h"
25 #include "source/opt/types.h"
26 #include "source/util/make_unique.h"
27
// Operand positions (counting result type and result id) inside the
// OpenCL.DebugInfo.100 DebugValue / DebugDeclare instructions this pass edits:
// the Value operand and the Expression operand.
static constexpr uint32_t kDebugValueOperandValueIndex = 5;
static constexpr uint32_t kDebugValueOperandExpressionIndex = 6;
30
31 namespace spvtools {
32 namespace opt {
33
Process()34 Pass::Status ScalarReplacementPass::Process() {
35 Status status = Status::SuccessWithoutChange;
36 for (auto& f : *get_module()) {
37 Status functionStatus = ProcessFunction(&f);
38 if (functionStatus == Status::Failure)
39 return functionStatus;
40 else if (functionStatus == Status::SuccessWithChange)
41 status = functionStatus;
42 }
43
44 return status;
45 }
46
ProcessFunction(Function * function)47 Pass::Status ScalarReplacementPass::ProcessFunction(Function* function) {
48 std::queue<Instruction*> worklist;
49 BasicBlock& entry = *function->begin();
50 for (auto iter = entry.begin(); iter != entry.end(); ++iter) {
51 // Function storage class OpVariables must appear as the first instructions
52 // of the entry block.
53 if (iter->opcode() != SpvOpVariable) break;
54
55 Instruction* varInst = &*iter;
56 if (CanReplaceVariable(varInst)) {
57 worklist.push(varInst);
58 }
59 }
60
61 Status status = Status::SuccessWithoutChange;
62 while (!worklist.empty()) {
63 Instruction* varInst = worklist.front();
64 worklist.pop();
65
66 Status var_status = ReplaceVariable(varInst, &worklist);
67 if (var_status == Status::Failure)
68 return var_status;
69 else if (var_status == Status::SuccessWithChange)
70 status = var_status;
71 }
72
73 return status;
74 }
75
ReplaceVariable(Instruction * inst,std::queue<Instruction * > * worklist)76 Pass::Status ScalarReplacementPass::ReplaceVariable(
77 Instruction* inst, std::queue<Instruction*>* worklist) {
78 std::vector<Instruction*> replacements;
79 if (!CreateReplacementVariables(inst, &replacements)) {
80 return Status::Failure;
81 }
82
83 std::vector<Instruction*> dead;
84 bool replaced_all_uses = get_def_use_mgr()->WhileEachUser(
85 inst, [this, &replacements, &dead](Instruction* user) {
86 if (user->GetOpenCL100DebugOpcode() == OpenCLDebugInfo100DebugDeclare) {
87 if (ReplaceWholeDebugDeclare(user, replacements)) {
88 dead.push_back(user);
89 return true;
90 }
91 return false;
92 }
93 if (user->GetOpenCL100DebugOpcode() == OpenCLDebugInfo100DebugValue) {
94 if (ReplaceWholeDebugValue(user, replacements)) {
95 dead.push_back(user);
96 return true;
97 }
98 return false;
99 }
100 if (!IsAnnotationInst(user->opcode())) {
101 switch (user->opcode()) {
102 case SpvOpLoad:
103 if (ReplaceWholeLoad(user, replacements)) {
104 dead.push_back(user);
105 } else {
106 return false;
107 }
108 break;
109 case SpvOpStore:
110 if (ReplaceWholeStore(user, replacements)) {
111 dead.push_back(user);
112 } else {
113 return false;
114 }
115 break;
116 case SpvOpAccessChain:
117 case SpvOpInBoundsAccessChain:
118 if (ReplaceAccessChain(user, replacements))
119 dead.push_back(user);
120 else
121 return false;
122 break;
123 case SpvOpName:
124 case SpvOpMemberName:
125 break;
126 default:
127 assert(false && "Unexpected opcode");
128 break;
129 }
130 }
131 return true;
132 });
133
134 if (replaced_all_uses) {
135 dead.push_back(inst);
136 } else {
137 return Status::Failure;
138 }
139
140 // If there are no dead instructions to clean up, return with no changes.
141 if (dead.empty()) return Status::SuccessWithoutChange;
142
143 // Clean up some dead code.
144 while (!dead.empty()) {
145 Instruction* toKill = dead.back();
146 dead.pop_back();
147 context()->KillInst(toKill);
148 }
149
150 // Attempt to further scalarize.
151 for (auto var : replacements) {
152 if (var->opcode() == SpvOpVariable) {
153 if (get_def_use_mgr()->NumUsers(var) == 0) {
154 context()->KillInst(var);
155 } else if (CanReplaceVariable(var)) {
156 worklist->push(var);
157 }
158 }
159 }
160
161 return Status::SuccessWithChange;
162 }
163
ReplaceWholeDebugDeclare(Instruction * dbg_decl,const std::vector<Instruction * > & replacements)164 bool ScalarReplacementPass::ReplaceWholeDebugDeclare(
165 Instruction* dbg_decl, const std::vector<Instruction*>& replacements) {
166 // Insert Deref operation to the front of the operation list of |dbg_decl|.
167 Instruction* dbg_expr = context()->get_def_use_mgr()->GetDef(
168 dbg_decl->GetSingleWordOperand(kDebugValueOperandExpressionIndex));
169 auto* deref_expr =
170 context()->get_debug_info_mgr()->DerefDebugExpression(dbg_expr);
171
172 // Add DebugValue instruction with Indexes operand and Deref operation.
173 int32_t idx = 0;
174 for (const auto* var : replacements) {
175 Instruction* added_dbg_value =
176 context()->get_debug_info_mgr()->AddDebugValueForDecl(
177 dbg_decl, /*value_id=*/var->result_id(),
178 /*insert_before=*/var->NextNode(), /*scope_and_line=*/dbg_decl);
179
180 if (added_dbg_value == nullptr) return false;
181 added_dbg_value->AddOperand(
182 {SPV_OPERAND_TYPE_ID,
183 {context()->get_constant_mgr()->GetSIntConst(idx)}});
184 added_dbg_value->SetOperand(kDebugValueOperandExpressionIndex,
185 {deref_expr->result_id()});
186 if (context()->AreAnalysesValid(IRContext::Analysis::kAnalysisDefUse)) {
187 context()->get_def_use_mgr()->AnalyzeInstUse(added_dbg_value);
188 }
189 ++idx;
190 }
191 return true;
192 }
193
ReplaceWholeDebugValue(Instruction * dbg_value,const std::vector<Instruction * > & replacements)194 bool ScalarReplacementPass::ReplaceWholeDebugValue(
195 Instruction* dbg_value, const std::vector<Instruction*>& replacements) {
196 int32_t idx = 0;
197 BasicBlock* block = context()->get_instr_block(dbg_value);
198 for (auto var : replacements) {
199 // Clone the DebugValue.
200 std::unique_ptr<Instruction> new_dbg_value(dbg_value->Clone(context()));
201 uint32_t new_id = TakeNextId();
202 if (new_id == 0) return false;
203 new_dbg_value->SetResultId(new_id);
204 // Update 'Value' operand to the |replacements|.
205 new_dbg_value->SetOperand(kDebugValueOperandValueIndex, {var->result_id()});
206 // Append 'Indexes' operand.
207 new_dbg_value->AddOperand(
208 {SPV_OPERAND_TYPE_ID,
209 {context()->get_constant_mgr()->GetSIntConst(idx)}});
210 // Insert the new DebugValue to the basic block.
211 auto* added_instr = dbg_value->InsertBefore(std::move(new_dbg_value));
212 get_def_use_mgr()->AnalyzeInstDefUse(added_instr);
213 context()->set_instr_block(added_instr, block);
214 ++idx;
215 }
216 return true;
217 }
218
ReplaceWholeLoad(Instruction * load,const std::vector<Instruction * > & replacements)219 bool ScalarReplacementPass::ReplaceWholeLoad(
220 Instruction* load, const std::vector<Instruction*>& replacements) {
221 // Replaces the load of the entire composite with a load from each replacement
222 // variable followed by a composite construction.
223 BasicBlock* block = context()->get_instr_block(load);
224 std::vector<Instruction*> loads;
225 loads.reserve(replacements.size());
226 BasicBlock::iterator where(load);
227 for (auto var : replacements) {
228 // Create a load of each replacement variable.
229 if (var->opcode() != SpvOpVariable) {
230 loads.push_back(var);
231 continue;
232 }
233
234 Instruction* type = GetStorageType(var);
235 uint32_t loadId = TakeNextId();
236 if (loadId == 0) {
237 return false;
238 }
239 std::unique_ptr<Instruction> newLoad(
240 new Instruction(context(), SpvOpLoad, type->result_id(), loadId,
241 std::initializer_list<Operand>{
242 {SPV_OPERAND_TYPE_ID, {var->result_id()}}}));
243 // Copy memory access attributes which start at index 1. Index 0 is the
244 // pointer to load.
245 for (uint32_t i = 1; i < load->NumInOperands(); ++i) {
246 Operand copy(load->GetInOperand(i));
247 newLoad->AddOperand(std::move(copy));
248 }
249 where = where.InsertBefore(std::move(newLoad));
250 get_def_use_mgr()->AnalyzeInstDefUse(&*where);
251 context()->set_instr_block(&*where, block);
252 where->UpdateDebugInfoFrom(load);
253 loads.push_back(&*where);
254 }
255
256 // Construct a new composite.
257 uint32_t compositeId = TakeNextId();
258 if (compositeId == 0) {
259 return false;
260 }
261 where = load;
262 std::unique_ptr<Instruction> compositeConstruct(new Instruction(
263 context(), SpvOpCompositeConstruct, load->type_id(), compositeId, {}));
264 for (auto l : loads) {
265 Operand op(SPV_OPERAND_TYPE_ID,
266 std::initializer_list<uint32_t>{l->result_id()});
267 compositeConstruct->AddOperand(std::move(op));
268 }
269 where = where.InsertBefore(std::move(compositeConstruct));
270 get_def_use_mgr()->AnalyzeInstDefUse(&*where);
271 where->UpdateDebugInfoFrom(load);
272 context()->set_instr_block(&*where, block);
273 context()->ReplaceAllUsesWith(load->result_id(), compositeId);
274 return true;
275 }
276
ReplaceWholeStore(Instruction * store,const std::vector<Instruction * > & replacements)277 bool ScalarReplacementPass::ReplaceWholeStore(
278 Instruction* store, const std::vector<Instruction*>& replacements) {
279 // Replaces a store to the whole composite with a series of extract and stores
280 // to each element.
281 uint32_t storeInput = store->GetSingleWordInOperand(1u);
282 BasicBlock* block = context()->get_instr_block(store);
283 BasicBlock::iterator where(store);
284 uint32_t elementIndex = 0;
285 for (auto var : replacements) {
286 // Create the extract.
287 if (var->opcode() != SpvOpVariable) {
288 elementIndex++;
289 continue;
290 }
291
292 Instruction* type = GetStorageType(var);
293 uint32_t extractId = TakeNextId();
294 if (extractId == 0) {
295 return false;
296 }
297 std::unique_ptr<Instruction> extract(new Instruction(
298 context(), SpvOpCompositeExtract, type->result_id(), extractId,
299 std::initializer_list<Operand>{
300 {SPV_OPERAND_TYPE_ID, {storeInput}},
301 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {elementIndex++}}}));
302 auto iter = where.InsertBefore(std::move(extract));
303 iter->UpdateDebugInfoFrom(store);
304 get_def_use_mgr()->AnalyzeInstDefUse(&*iter);
305 context()->set_instr_block(&*iter, block);
306
307 // Create the store.
308 std::unique_ptr<Instruction> newStore(
309 new Instruction(context(), SpvOpStore, 0, 0,
310 std::initializer_list<Operand>{
311 {SPV_OPERAND_TYPE_ID, {var->result_id()}},
312 {SPV_OPERAND_TYPE_ID, {extractId}}}));
313 // Copy memory access attributes which start at index 2. Index 0 is the
314 // pointer and index 1 is the data.
315 for (uint32_t i = 2; i < store->NumInOperands(); ++i) {
316 Operand copy(store->GetInOperand(i));
317 newStore->AddOperand(std::move(copy));
318 }
319 iter = where.InsertBefore(std::move(newStore));
320 iter->UpdateDebugInfoFrom(store);
321 get_def_use_mgr()->AnalyzeInstDefUse(&*iter);
322 context()->set_instr_block(&*iter, block);
323 }
324 return true;
325 }
326
ReplaceAccessChain(Instruction * chain,const std::vector<Instruction * > & replacements)327 bool ScalarReplacementPass::ReplaceAccessChain(
328 Instruction* chain, const std::vector<Instruction*>& replacements) {
329 // Replaces the access chain with either another access chain (with one fewer
330 // indexes) or a direct use of the replacement variable.
331 uint32_t indexId = chain->GetSingleWordInOperand(1u);
332 const Instruction* index = get_def_use_mgr()->GetDef(indexId);
333 int64_t indexValue = context()
334 ->get_constant_mgr()
335 ->GetConstantFromInst(index)
336 ->GetSignExtendedValue();
337 if (indexValue < 0 ||
338 indexValue >= static_cast<int64_t>(replacements.size())) {
339 // Out of bounds access, this is illegal IR. Notice that OpAccessChain
340 // indexing is 0-based, so we should also reject index == size-of-array.
341 return false;
342 } else {
343 const Instruction* var = replacements[static_cast<size_t>(indexValue)];
344 if (chain->NumInOperands() > 2) {
345 // Replace input access chain with another access chain.
346 BasicBlock::iterator chainIter(chain);
347 uint32_t replacementId = TakeNextId();
348 if (replacementId == 0) {
349 return false;
350 }
351 std::unique_ptr<Instruction> replacementChain(new Instruction(
352 context(), chain->opcode(), chain->type_id(), replacementId,
353 std::initializer_list<Operand>{
354 {SPV_OPERAND_TYPE_ID, {var->result_id()}}}));
355 // Add the remaining indexes.
356 for (uint32_t i = 2; i < chain->NumInOperands(); ++i) {
357 Operand copy(chain->GetInOperand(i));
358 replacementChain->AddOperand(std::move(copy));
359 }
360 replacementChain->UpdateDebugInfoFrom(chain);
361 auto iter = chainIter.InsertBefore(std::move(replacementChain));
362 get_def_use_mgr()->AnalyzeInstDefUse(&*iter);
363 context()->set_instr_block(&*iter, context()->get_instr_block(chain));
364 context()->ReplaceAllUsesWith(chain->result_id(), replacementId);
365 } else {
366 // Replace with a use of the variable.
367 context()->ReplaceAllUsesWith(chain->result_id(), var->result_id());
368 }
369 }
370
371 return true;
372 }
373
CreateReplacementVariables(Instruction * inst,std::vector<Instruction * > * replacements)374 bool ScalarReplacementPass::CreateReplacementVariables(
375 Instruction* inst, std::vector<Instruction*>* replacements) {
376 Instruction* type = GetStorageType(inst);
377
378 std::unique_ptr<std::unordered_set<int64_t>> components_used =
379 GetUsedComponents(inst);
380
381 uint32_t elem = 0;
382 switch (type->opcode()) {
383 case SpvOpTypeStruct:
384 type->ForEachInOperand(
385 [this, inst, &elem, replacements, &components_used](uint32_t* id) {
386 if (!components_used || components_used->count(elem)) {
387 CreateVariable(*id, inst, elem, replacements);
388 } else {
389 replacements->push_back(CreateNullConstant(*id));
390 }
391 elem++;
392 });
393 break;
394 case SpvOpTypeArray:
395 for (uint32_t i = 0; i != GetArrayLength(type); ++i) {
396 if (!components_used || components_used->count(i)) {
397 CreateVariable(type->GetSingleWordInOperand(0u), inst, i,
398 replacements);
399 } else {
400 replacements->push_back(
401 CreateNullConstant(type->GetSingleWordInOperand(0u)));
402 }
403 }
404 break;
405
406 case SpvOpTypeMatrix:
407 case SpvOpTypeVector:
408 for (uint32_t i = 0; i != GetNumElements(type); ++i) {
409 CreateVariable(type->GetSingleWordInOperand(0u), inst, i, replacements);
410 }
411 break;
412
413 default:
414 assert(false && "Unexpected type.");
415 break;
416 }
417
418 TransferAnnotations(inst, replacements);
419 return std::find(replacements->begin(), replacements->end(), nullptr) ==
420 replacements->end();
421 }
422
TransferAnnotations(const Instruction * source,std::vector<Instruction * > * replacements)423 void ScalarReplacementPass::TransferAnnotations(
424 const Instruction* source, std::vector<Instruction*>* replacements) {
425 // Only transfer invariant and restrict decorations on the variable. There are
426 // no type or member decorations that are necessary to transfer.
427 for (auto inst :
428 get_decoration_mgr()->GetDecorationsFor(source->result_id(), false)) {
429 assert(inst->opcode() == SpvOpDecorate);
430 uint32_t decoration = inst->GetSingleWordInOperand(1u);
431 if (decoration == SpvDecorationInvariant ||
432 decoration == SpvDecorationRestrict) {
433 for (auto var : *replacements) {
434 if (var == nullptr) {
435 continue;
436 }
437
438 std::unique_ptr<Instruction> annotation(
439 new Instruction(context(), SpvOpDecorate, 0, 0,
440 std::initializer_list<Operand>{
441 {SPV_OPERAND_TYPE_ID, {var->result_id()}},
442 {SPV_OPERAND_TYPE_DECORATION, {decoration}}}));
443 for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
444 Operand copy(inst->GetInOperand(i));
445 annotation->AddOperand(std::move(copy));
446 }
447 context()->AddAnnotationInst(std::move(annotation));
448 get_def_use_mgr()->AnalyzeInstUse(&*--context()->annotation_end());
449 }
450 }
451 }
452 }
453
CreateVariable(uint32_t typeId,Instruction * varInst,uint32_t index,std::vector<Instruction * > * replacements)454 void ScalarReplacementPass::CreateVariable(
455 uint32_t typeId, Instruction* varInst, uint32_t index,
456 std::vector<Instruction*>* replacements) {
457 uint32_t ptrId = GetOrCreatePointerType(typeId);
458 uint32_t id = TakeNextId();
459
460 if (id == 0) {
461 replacements->push_back(nullptr);
462 }
463
464 std::unique_ptr<Instruction> variable(new Instruction(
465 context(), SpvOpVariable, ptrId, id,
466 std::initializer_list<Operand>{
467 {SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}}}));
468
469 BasicBlock* block = context()->get_instr_block(varInst);
470 block->begin().InsertBefore(std::move(variable));
471 Instruction* inst = &*block->begin();
472
473 // If varInst was initialized, make sure to initialize its replacement.
474 GetOrCreateInitialValue(varInst, index, inst);
475 get_def_use_mgr()->AnalyzeInstDefUse(inst);
476 context()->set_instr_block(inst, block);
477
478 // Copy decorations from the member to the new variable.
479 Instruction* typeInst = GetStorageType(varInst);
480 for (auto dec_inst :
481 get_decoration_mgr()->GetDecorationsFor(typeInst->result_id(), false)) {
482 uint32_t decoration;
483 if (dec_inst->opcode() != SpvOpMemberDecorate) {
484 continue;
485 }
486
487 if (dec_inst->GetSingleWordInOperand(1) != index) {
488 continue;
489 }
490
491 decoration = dec_inst->GetSingleWordInOperand(2u);
492 switch (decoration) {
493 case SpvDecorationRelaxedPrecision: {
494 std::unique_ptr<Instruction> new_dec_inst(
495 new Instruction(context(), SpvOpDecorate, 0, 0, {}));
496 new_dec_inst->AddOperand(Operand(SPV_OPERAND_TYPE_ID, {id}));
497 for (uint32_t i = 2; i < dec_inst->NumInOperandWords(); ++i) {
498 new_dec_inst->AddOperand(Operand(dec_inst->GetInOperand(i)));
499 }
500 context()->AddAnnotationInst(std::move(new_dec_inst));
501 } break;
502 default:
503 break;
504 }
505 }
506
507 // Update the OpenCL.DebugInfo.100 debug information.
508 inst->UpdateDebugInfoFrom(varInst);
509
510 replacements->push_back(inst);
511 }
512
GetOrCreatePointerType(uint32_t id)513 uint32_t ScalarReplacementPass::GetOrCreatePointerType(uint32_t id) {
514 auto iter = pointee_to_pointer_.find(id);
515 if (iter != pointee_to_pointer_.end()) return iter->second;
516
517 analysis::Type* pointeeTy;
518 std::unique_ptr<analysis::Pointer> pointerTy;
519 std::tie(pointeeTy, pointerTy) =
520 context()->get_type_mgr()->GetTypeAndPointerType(id,
521 SpvStorageClassFunction);
522 uint32_t ptrId = 0;
523 if (pointeeTy->IsUniqueType()) {
524 // Non-ambiguous type, just ask the type manager for an id.
525 ptrId = context()->get_type_mgr()->GetTypeInstruction(pointerTy.get());
526 pointee_to_pointer_[id] = ptrId;
527 return ptrId;
528 }
529
530 // Ambiguous type. We must perform a linear search to try and find the right
531 // type.
532 for (auto global : context()->types_values()) {
533 if (global.opcode() == SpvOpTypePointer &&
534 global.GetSingleWordInOperand(0u) == SpvStorageClassFunction &&
535 global.GetSingleWordInOperand(1u) == id) {
536 if (get_decoration_mgr()->GetDecorationsFor(id, false).empty()) {
537 // Only reuse a decoration-less pointer of the correct type.
538 ptrId = global.result_id();
539 break;
540 }
541 }
542 }
543
544 if (ptrId != 0) {
545 pointee_to_pointer_[id] = ptrId;
546 return ptrId;
547 }
548
549 ptrId = TakeNextId();
550 context()->AddType(MakeUnique<Instruction>(
551 context(), SpvOpTypePointer, 0, ptrId,
552 std::initializer_list<Operand>{
553 {SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}},
554 {SPV_OPERAND_TYPE_ID, {id}}}));
555 Instruction* ptr = &*--context()->types_values_end();
556 get_def_use_mgr()->AnalyzeInstDefUse(ptr);
557 pointee_to_pointer_[id] = ptrId;
558 // Register with the type manager if necessary.
559 context()->get_type_mgr()->RegisterType(ptrId, *pointerTy);
560
561 return ptrId;
562 }
563
GetOrCreateInitialValue(Instruction * source,uint32_t index,Instruction * newVar)564 void ScalarReplacementPass::GetOrCreateInitialValue(Instruction* source,
565 uint32_t index,
566 Instruction* newVar) {
567 assert(source->opcode() == SpvOpVariable);
568 if (source->NumInOperands() < 2) return;
569
570 uint32_t initId = source->GetSingleWordInOperand(1u);
571 uint32_t storageId = GetStorageType(newVar)->result_id();
572 Instruction* init = get_def_use_mgr()->GetDef(initId);
573 uint32_t newInitId = 0;
574 // TODO(dnovillo): Refactor this with constant propagation.
575 if (init->opcode() == SpvOpConstantNull) {
576 // Initialize to appropriate NULL.
577 auto iter = type_to_null_.find(storageId);
578 if (iter == type_to_null_.end()) {
579 newInitId = TakeNextId();
580 type_to_null_[storageId] = newInitId;
581 context()->AddGlobalValue(
582 MakeUnique<Instruction>(context(), SpvOpConstantNull, storageId,
583 newInitId, std::initializer_list<Operand>{}));
584 Instruction* newNull = &*--context()->types_values_end();
585 get_def_use_mgr()->AnalyzeInstDefUse(newNull);
586 } else {
587 newInitId = iter->second;
588 }
589 } else if (IsSpecConstantInst(init->opcode())) {
590 // Create a new constant extract.
591 newInitId = TakeNextId();
592 context()->AddGlobalValue(MakeUnique<Instruction>(
593 context(), SpvOpSpecConstantOp, storageId, newInitId,
594 std::initializer_list<Operand>{
595 {SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER, {SpvOpCompositeExtract}},
596 {SPV_OPERAND_TYPE_ID, {init->result_id()}},
597 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {index}}}));
598 Instruction* newSpecConst = &*--context()->types_values_end();
599 get_def_use_mgr()->AnalyzeInstDefUse(newSpecConst);
600 } else if (init->opcode() == SpvOpConstantComposite) {
601 // Get the appropriate index constant.
602 newInitId = init->GetSingleWordInOperand(index);
603 Instruction* element = get_def_use_mgr()->GetDef(newInitId);
604 if (element->opcode() == SpvOpUndef) {
605 // Undef is not a valid initializer for a variable.
606 newInitId = 0;
607 }
608 } else {
609 assert(false);
610 }
611
612 if (newInitId != 0) {
613 newVar->AddOperand({SPV_OPERAND_TYPE_ID, {newInitId}});
614 }
615 }
616
GetArrayLength(const Instruction * arrayType) const617 uint64_t ScalarReplacementPass::GetArrayLength(
618 const Instruction* arrayType) const {
619 assert(arrayType->opcode() == SpvOpTypeArray);
620 const Instruction* length =
621 get_def_use_mgr()->GetDef(arrayType->GetSingleWordInOperand(1u));
622 return context()
623 ->get_constant_mgr()
624 ->GetConstantFromInst(length)
625 ->GetZeroExtendedValue();
626 }
627
GetNumElements(const Instruction * type) const628 uint64_t ScalarReplacementPass::GetNumElements(const Instruction* type) const {
629 assert(type->opcode() == SpvOpTypeVector ||
630 type->opcode() == SpvOpTypeMatrix);
631 const Operand& op = type->GetInOperand(1u);
632 assert(op.words.size() <= 2);
633 uint64_t len = 0;
634 for (size_t i = 0; i != op.words.size(); ++i) {
635 len |= (static_cast<uint64_t>(op.words[i]) << (32ull * i));
636 }
637 return len;
638 }
639
IsSpecConstant(uint32_t id) const640 bool ScalarReplacementPass::IsSpecConstant(uint32_t id) const {
641 const Instruction* inst = get_def_use_mgr()->GetDef(id);
642 assert(inst);
643 return spvOpcodeIsSpecConstant(inst->opcode());
644 }
645
// Returns the type instruction of the object stored in variable |inst|: the
// pointee (in-operand 1) of the variable's OpTypePointer result type.
Instruction* ScalarReplacementPass::GetStorageType(
    const Instruction* inst) const {
  assert(inst->opcode() == SpvOpVariable);

  const Instruction* pointer_type = get_def_use_mgr()->GetDef(inst->type_id());
  const uint32_t pointee_id = pointer_type->GetSingleWordInOperand(1u);
  return get_def_use_mgr()->GetDef(pointee_id);
}
655
CanReplaceVariable(const Instruction * varInst) const656 bool ScalarReplacementPass::CanReplaceVariable(
657 const Instruction* varInst) const {
658 assert(varInst->opcode() == SpvOpVariable);
659
660 // Can only replace function scope variables.
661 if (varInst->GetSingleWordInOperand(0u) != SpvStorageClassFunction) {
662 return false;
663 }
664
665 if (!CheckTypeAnnotations(get_def_use_mgr()->GetDef(varInst->type_id()))) {
666 return false;
667 }
668
669 const Instruction* typeInst = GetStorageType(varInst);
670 if (!CheckType(typeInst)) {
671 return false;
672 }
673
674 if (!CheckAnnotations(varInst)) {
675 return false;
676 }
677
678 if (!CheckUses(varInst)) {
679 return false;
680 }
681
682 return true;
683 }
684
CheckType(const Instruction * typeInst) const685 bool ScalarReplacementPass::CheckType(const Instruction* typeInst) const {
686 if (!CheckTypeAnnotations(typeInst)) {
687 return false;
688 }
689
690 switch (typeInst->opcode()) {
691 case SpvOpTypeStruct:
692 // Don't bother with empty structs or very large structs.
693 if (typeInst->NumInOperands() == 0 ||
694 IsLargerThanSizeLimit(typeInst->NumInOperands())) {
695 return false;
696 }
697 return true;
698 case SpvOpTypeArray:
699 if (IsSpecConstant(typeInst->GetSingleWordInOperand(1u))) {
700 return false;
701 }
702 if (IsLargerThanSizeLimit(GetArrayLength(typeInst))) {
703 return false;
704 }
705 return true;
706 // TODO(alanbaker): Develop some heuristics for when this should be
707 // re-enabled.
708 //// Specifically including matrix and vector in an attempt to reduce the
709 //// number of vector registers required.
710 // case SpvOpTypeMatrix:
711 // case SpvOpTypeVector:
712 // if (IsLargerThanSizeLimit(GetNumElements(typeInst))) return false;
713 // return true;
714
715 case SpvOpTypeRuntimeArray:
716 default:
717 return false;
718 }
719 }
720
CheckTypeAnnotations(const Instruction * typeInst) const721 bool ScalarReplacementPass::CheckTypeAnnotations(
722 const Instruction* typeInst) const {
723 for (auto inst :
724 get_decoration_mgr()->GetDecorationsFor(typeInst->result_id(), false)) {
725 uint32_t decoration;
726 if (inst->opcode() == SpvOpDecorate) {
727 decoration = inst->GetSingleWordInOperand(1u);
728 } else {
729 assert(inst->opcode() == SpvOpMemberDecorate);
730 decoration = inst->GetSingleWordInOperand(2u);
731 }
732
733 switch (decoration) {
734 case SpvDecorationRowMajor:
735 case SpvDecorationColMajor:
736 case SpvDecorationArrayStride:
737 case SpvDecorationMatrixStride:
738 case SpvDecorationCPacked:
739 case SpvDecorationInvariant:
740 case SpvDecorationRestrict:
741 case SpvDecorationOffset:
742 case SpvDecorationAlignment:
743 case SpvDecorationAlignmentId:
744 case SpvDecorationMaxByteOffset:
745 case SpvDecorationRelaxedPrecision:
746 break;
747 default:
748 return false;
749 }
750 }
751
752 return true;
753 }
754
CheckAnnotations(const Instruction * varInst) const755 bool ScalarReplacementPass::CheckAnnotations(const Instruction* varInst) const {
756 for (auto inst :
757 get_decoration_mgr()->GetDecorationsFor(varInst->result_id(), false)) {
758 assert(inst->opcode() == SpvOpDecorate);
759 uint32_t decoration = inst->GetSingleWordInOperand(1u);
760 switch (decoration) {
761 case SpvDecorationInvariant:
762 case SpvDecorationRestrict:
763 case SpvDecorationAlignment:
764 case SpvDecorationAlignmentId:
765 case SpvDecorationMaxByteOffset:
766 break;
767 default:
768 return false;
769 }
770 }
771
772 return true;
773 }
774
CheckUses(const Instruction * inst) const775 bool ScalarReplacementPass::CheckUses(const Instruction* inst) const {
776 VariableStats stats = {0, 0};
777 bool ok = CheckUses(inst, &stats);
778
779 // TODO(alanbaker/greg-lunarg): Add some meaningful heuristics about when
780 // SRoA is costly, such as when the structure has many (unaccessed?)
781 // members.
782
783 return ok;
784 }
785
CheckUses(const Instruction * inst,VariableStats * stats) const786 bool ScalarReplacementPass::CheckUses(const Instruction* inst,
787 VariableStats* stats) const {
788 uint64_t max_legal_index = GetMaxLegalIndex(inst);
789
790 bool ok = true;
791 get_def_use_mgr()->ForEachUse(inst, [this, max_legal_index, stats, &ok](
792 const Instruction* user,
793 uint32_t index) {
794 if (user->GetOpenCL100DebugOpcode() == OpenCLDebugInfo100DebugDeclare ||
795 user->GetOpenCL100DebugOpcode() == OpenCLDebugInfo100DebugValue) {
796 // TODO: include num_partial_accesses if it uses Fragment operation or
797 // DebugValue has Indexes operand.
798 stats->num_full_accesses++;
799 return;
800 }
801
802 // Annotations are check as a group separately.
803 if (!IsAnnotationInst(user->opcode())) {
804 switch (user->opcode()) {
805 case SpvOpAccessChain:
806 case SpvOpInBoundsAccessChain:
807 if (index == 2u && user->NumInOperands() > 1) {
808 uint32_t id = user->GetSingleWordInOperand(1u);
809 const Instruction* opInst = get_def_use_mgr()->GetDef(id);
810 const auto* constant =
811 context()->get_constant_mgr()->GetConstantFromInst(opInst);
812 if (!constant) {
813 ok = false;
814 } else if (constant->GetZeroExtendedValue() >= max_legal_index) {
815 ok = false;
816 } else {
817 if (!CheckUsesRelaxed(user)) ok = false;
818 }
819 stats->num_partial_accesses++;
820 } else {
821 ok = false;
822 }
823 break;
824 case SpvOpLoad:
825 if (!CheckLoad(user, index)) ok = false;
826 stats->num_full_accesses++;
827 break;
828 case SpvOpStore:
829 if (!CheckStore(user, index)) ok = false;
830 stats->num_full_accesses++;
831 break;
832 case SpvOpName:
833 case SpvOpMemberName:
834 break;
835 default:
836 ok = false;
837 break;
838 }
839 }
840 });
841
842 return ok;
843 }
844
CheckUsesRelaxed(const Instruction * inst) const845 bool ScalarReplacementPass::CheckUsesRelaxed(const Instruction* inst) const {
846 bool ok = true;
847 get_def_use_mgr()->ForEachUse(
848 inst, [this, &ok](const Instruction* user, uint32_t index) {
849 switch (user->opcode()) {
850 case SpvOpAccessChain:
851 case SpvOpInBoundsAccessChain:
852 if (index != 2u) {
853 ok = false;
854 } else {
855 if (!CheckUsesRelaxed(user)) ok = false;
856 }
857 break;
858 case SpvOpLoad:
859 if (!CheckLoad(user, index)) ok = false;
860 break;
861 case SpvOpStore:
862 if (!CheckStore(user, index)) ok = false;
863 break;
864 case SpvOpImageTexelPointer:
865 if (!CheckImageTexelPointer(index)) ok = false;
866 break;
867 default:
868 ok = false;
869 break;
870 }
871 });
872
873 return ok;
874 }
875
CheckImageTexelPointer(uint32_t index) const876 bool ScalarReplacementPass::CheckImageTexelPointer(uint32_t index) const {
877 return index == 2u;
878 }
879
CheckLoad(const Instruction * inst,uint32_t index) const880 bool ScalarReplacementPass::CheckLoad(const Instruction* inst,
881 uint32_t index) const {
882 if (index != 2u) return false;
883 if (inst->NumInOperands() >= 2 &&
884 inst->GetSingleWordInOperand(1u) & SpvMemoryAccessVolatileMask)
885 return false;
886 return true;
887 }
888
CheckStore(const Instruction * inst,uint32_t index) const889 bool ScalarReplacementPass::CheckStore(const Instruction* inst,
890 uint32_t index) const {
891 if (index != 0u) return false;
892 if (inst->NumInOperands() >= 3 &&
893 inst->GetSingleWordInOperand(2u) & SpvMemoryAccessVolatileMask)
894 return false;
895 return true;
896 }
IsLargerThanSizeLimit(uint64_t length) const897 bool ScalarReplacementPass::IsLargerThanSizeLimit(uint64_t length) const {
898 if (max_num_elements_ == 0) {
899 return false;
900 }
901 return length > max_num_elements_;
902 }
903
904 std::unique_ptr<std::unordered_set<int64_t>>
GetUsedComponents(Instruction * inst)905 ScalarReplacementPass::GetUsedComponents(Instruction* inst) {
906 std::unique_ptr<std::unordered_set<int64_t>> result(
907 new std::unordered_set<int64_t>());
908
909 analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
910
911 def_use_mgr->WhileEachUser(inst, [&result, def_use_mgr,
912 this](Instruction* use) {
913 switch (use->opcode()) {
914 case SpvOpLoad: {
915 // Look for extract from the load.
916 std::vector<uint32_t> t;
917 if (def_use_mgr->WhileEachUser(use, [&t](Instruction* use2) {
918 if (use2->opcode() != SpvOpCompositeExtract ||
919 use2->NumInOperands() <= 1) {
920 return false;
921 }
922 t.push_back(use2->GetSingleWordInOperand(1));
923 return true;
924 })) {
925 result->insert(t.begin(), t.end());
926 return true;
927 } else {
928 result.reset(nullptr);
929 return false;
930 }
931 }
932 case SpvOpName:
933 case SpvOpMemberName:
934 case SpvOpStore:
935 // No components are used.
936 return true;
937 case SpvOpAccessChain:
938 case SpvOpInBoundsAccessChain: {
939 // Add the first index it if is a constant.
940 // TODO: Could be improved by checking if the address is used in a load.
941 analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
942 uint32_t index_id = use->GetSingleWordInOperand(1);
943 const analysis::Constant* index_const =
944 const_mgr->FindDeclaredConstant(index_id);
945 if (index_const) {
946 result->insert(index_const->GetSignExtendedValue());
947 return true;
948 } else {
949 // Could be any element. Assuming all are used.
950 result.reset(nullptr);
951 return false;
952 }
953 }
954 default:
955 // We do not know what is happening. Have to assume the worst.
956 result.reset(nullptr);
957 return false;
958 }
959 });
960
961 return result;
962 }
963
CreateNullConstant(uint32_t type_id)964 Instruction* ScalarReplacementPass::CreateNullConstant(uint32_t type_id) {
965 analysis::TypeManager* type_mgr = context()->get_type_mgr();
966 analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
967
968 const analysis::Type* type = type_mgr->GetType(type_id);
969 const analysis::Constant* null_const = const_mgr->GetConstant(type, {});
970 Instruction* null_inst =
971 const_mgr->GetDefiningInstruction(null_const, type_id);
972 if (null_inst != nullptr) {
973 context()->UpdateDefUse(null_inst);
974 }
975 return null_inst;
976 }
977
GetMaxLegalIndex(const Instruction * var_inst) const978 uint64_t ScalarReplacementPass::GetMaxLegalIndex(
979 const Instruction* var_inst) const {
980 assert(var_inst->opcode() == SpvOpVariable &&
981 "|var_inst| must be a variable instruction.");
982 Instruction* type = GetStorageType(var_inst);
983 switch (type->opcode()) {
984 case SpvOpTypeStruct:
985 return type->NumInOperands();
986 case SpvOpTypeArray:
987 return GetArrayLength(type);
988 case SpvOpTypeMatrix:
989 case SpvOpTypeVector:
990 return GetNumElements(type);
991 default:
992 return 0;
993 }
994 return 0;
995 }
996
997 } // namespace opt
998 } // namespace spvtools
999