• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/opt/scalar_replacement_pass.h"
16 
17 #include <algorithm>
18 #include <queue>
19 #include <tuple>
20 #include <utility>
21 
22 #include "source/enum_string_mapping.h"
23 #include "source/extensions.h"
24 #include "source/opt/reflect.h"
25 #include "source/opt/types.h"
26 #include "source/util/make_unique.h"
27 
28 namespace spvtools {
29 namespace opt {
30 
Process()31 Pass::Status ScalarReplacementPass::Process() {
32   Status status = Status::SuccessWithoutChange;
33   for (auto& f : *get_module()) {
34     Status functionStatus = ProcessFunction(&f);
35     if (functionStatus == Status::Failure)
36       return functionStatus;
37     else if (functionStatus == Status::SuccessWithChange)
38       status = functionStatus;
39   }
40 
41   return status;
42 }
43 
ProcessFunction(Function * function)44 Pass::Status ScalarReplacementPass::ProcessFunction(Function* function) {
45   std::queue<Instruction*> worklist;
46   BasicBlock& entry = *function->begin();
47   for (auto iter = entry.begin(); iter != entry.end(); ++iter) {
48     // Function storage class OpVariables must appear as the first instructions
49     // of the entry block.
50     if (iter->opcode() != SpvOpVariable) break;
51 
52     Instruction* varInst = &*iter;
53     if (CanReplaceVariable(varInst)) {
54       worklist.push(varInst);
55     }
56   }
57 
58   Status status = Status::SuccessWithoutChange;
59   while (!worklist.empty()) {
60     Instruction* varInst = worklist.front();
61     worklist.pop();
62 
63     Status var_status = ReplaceVariable(varInst, &worklist);
64     if (var_status == Status::Failure)
65       return var_status;
66     else if (var_status == Status::SuccessWithChange)
67       status = var_status;
68   }
69 
70   return status;
71 }
72 
ReplaceVariable(Instruction * inst,std::queue<Instruction * > * worklist)73 Pass::Status ScalarReplacementPass::ReplaceVariable(
74     Instruction* inst, std::queue<Instruction*>* worklist) {
75   std::vector<Instruction*> replacements;
76   if (!CreateReplacementVariables(inst, &replacements)) {
77     return Status::Failure;
78   }
79 
80   std::vector<Instruction*> dead;
81   bool replaced_all_uses = get_def_use_mgr()->WhileEachUser(
82       inst, [this, &replacements, &dead](Instruction* user) {
83         if (!IsAnnotationInst(user->opcode())) {
84           switch (user->opcode()) {
85             case SpvOpLoad:
86               if (ReplaceWholeLoad(user, replacements)) {
87                 dead.push_back(user);
88               } else {
89                 return false;
90               }
91               break;
92             case SpvOpStore:
93               if (ReplaceWholeStore(user, replacements)) {
94                 dead.push_back(user);
95               } else {
96                 return false;
97               }
98               break;
99             case SpvOpAccessChain:
100             case SpvOpInBoundsAccessChain:
101               if (ReplaceAccessChain(user, replacements))
102                 dead.push_back(user);
103               else
104                 return false;
105               break;
106             case SpvOpName:
107             case SpvOpMemberName:
108               break;
109             default:
110               assert(false && "Unexpected opcode");
111               break;
112           }
113         }
114         return true;
115       });
116 
117   if (replaced_all_uses) {
118     dead.push_back(inst);
119   } else {
120     return Status::Failure;
121   }
122 
123   // If there are no dead instructions to clean up, return with no changes.
124   if (dead.empty()) return Status::SuccessWithoutChange;
125 
126   // Clean up some dead code.
127   while (!dead.empty()) {
128     Instruction* toKill = dead.back();
129     dead.pop_back();
130     context()->KillInst(toKill);
131   }
132 
133   // Attempt to further scalarize.
134   for (auto var : replacements) {
135     if (var->opcode() == SpvOpVariable) {
136       if (get_def_use_mgr()->NumUsers(var) == 0) {
137         context()->KillInst(var);
138       } else if (CanReplaceVariable(var)) {
139         worklist->push(var);
140       }
141     }
142   }
143 
144   return Status::SuccessWithChange;
145 }
146 
ReplaceWholeLoad(Instruction * load,const std::vector<Instruction * > & replacements)147 bool ScalarReplacementPass::ReplaceWholeLoad(
148     Instruction* load, const std::vector<Instruction*>& replacements) {
149   // Replaces the load of the entire composite with a load from each replacement
150   // variable followed by a composite construction.
151   BasicBlock* block = context()->get_instr_block(load);
152   std::vector<Instruction*> loads;
153   loads.reserve(replacements.size());
154   BasicBlock::iterator where(load);
155   for (auto var : replacements) {
156     // Create a load of each replacement variable.
157     if (var->opcode() != SpvOpVariable) {
158       loads.push_back(var);
159       continue;
160     }
161 
162     Instruction* type = GetStorageType(var);
163     uint32_t loadId = TakeNextId();
164     if (loadId == 0) {
165       return false;
166     }
167     std::unique_ptr<Instruction> newLoad(
168         new Instruction(context(), SpvOpLoad, type->result_id(), loadId,
169                         std::initializer_list<Operand>{
170                             {SPV_OPERAND_TYPE_ID, {var->result_id()}}}));
171     // Copy memory access attributes which start at index 1. Index 0 is the
172     // pointer to load.
173     for (uint32_t i = 1; i < load->NumInOperands(); ++i) {
174       Operand copy(load->GetInOperand(i));
175       newLoad->AddOperand(std::move(copy));
176     }
177     where = where.InsertBefore(std::move(newLoad));
178     get_def_use_mgr()->AnalyzeInstDefUse(&*where);
179     context()->set_instr_block(&*where, block);
180     loads.push_back(&*where);
181   }
182 
183   // Construct a new composite.
184   uint32_t compositeId = TakeNextId();
185   if (compositeId == 0) {
186     return false;
187   }
188   where = load;
189   std::unique_ptr<Instruction> compositeConstruct(new Instruction(
190       context(), SpvOpCompositeConstruct, load->type_id(), compositeId, {}));
191   for (auto l : loads) {
192     Operand op(SPV_OPERAND_TYPE_ID,
193                std::initializer_list<uint32_t>{l->result_id()});
194     compositeConstruct->AddOperand(std::move(op));
195   }
196   where = where.InsertBefore(std::move(compositeConstruct));
197   get_def_use_mgr()->AnalyzeInstDefUse(&*where);
198   context()->set_instr_block(&*where, block);
199   context()->ReplaceAllUsesWith(load->result_id(), compositeId);
200   return true;
201 }
202 
ReplaceWholeStore(Instruction * store,const std::vector<Instruction * > & replacements)203 bool ScalarReplacementPass::ReplaceWholeStore(
204     Instruction* store, const std::vector<Instruction*>& replacements) {
205   // Replaces a store to the whole composite with a series of extract and stores
206   // to each element.
207   uint32_t storeInput = store->GetSingleWordInOperand(1u);
208   BasicBlock* block = context()->get_instr_block(store);
209   BasicBlock::iterator where(store);
210   uint32_t elementIndex = 0;
211   for (auto var : replacements) {
212     // Create the extract.
213     if (var->opcode() != SpvOpVariable) {
214       elementIndex++;
215       continue;
216     }
217 
218     Instruction* type = GetStorageType(var);
219     uint32_t extractId = TakeNextId();
220     if (extractId == 0) {
221       return false;
222     }
223     std::unique_ptr<Instruction> extract(new Instruction(
224         context(), SpvOpCompositeExtract, type->result_id(), extractId,
225         std::initializer_list<Operand>{
226             {SPV_OPERAND_TYPE_ID, {storeInput}},
227             {SPV_OPERAND_TYPE_LITERAL_INTEGER, {elementIndex++}}}));
228     auto iter = where.InsertBefore(std::move(extract));
229     get_def_use_mgr()->AnalyzeInstDefUse(&*iter);
230     context()->set_instr_block(&*iter, block);
231 
232     // Create the store.
233     std::unique_ptr<Instruction> newStore(
234         new Instruction(context(), SpvOpStore, 0, 0,
235                         std::initializer_list<Operand>{
236                             {SPV_OPERAND_TYPE_ID, {var->result_id()}},
237                             {SPV_OPERAND_TYPE_ID, {extractId}}}));
238     // Copy memory access attributes which start at index 2. Index 0 is the
239     // pointer and index 1 is the data.
240     for (uint32_t i = 2; i < store->NumInOperands(); ++i) {
241       Operand copy(store->GetInOperand(i));
242       newStore->AddOperand(std::move(copy));
243     }
244     iter = where.InsertBefore(std::move(newStore));
245     get_def_use_mgr()->AnalyzeInstDefUse(&*iter);
246     context()->set_instr_block(&*iter, block);
247   }
248   return true;
249 }
250 
ReplaceAccessChain(Instruction * chain,const std::vector<Instruction * > & replacements)251 bool ScalarReplacementPass::ReplaceAccessChain(
252     Instruction* chain, const std::vector<Instruction*>& replacements) {
253   // Replaces the access chain with either another access chain (with one fewer
254   // indexes) or a direct use of the replacement variable.
255   uint32_t indexId = chain->GetSingleWordInOperand(1u);
256   const Instruction* index = get_def_use_mgr()->GetDef(indexId);
257   int64_t indexValue = context()
258                            ->get_constant_mgr()
259                            ->GetConstantFromInst(index)
260                            ->GetSignExtendedValue();
261   if (indexValue < 0 ||
262       indexValue >= static_cast<int64_t>(replacements.size())) {
263     // Out of bounds access, this is illegal IR.  Notice that OpAccessChain
264     // indexing is 0-based, so we should also reject index == size-of-array.
265     return false;
266   } else {
267     const Instruction* var = replacements[static_cast<size_t>(indexValue)];
268     if (chain->NumInOperands() > 2) {
269       // Replace input access chain with another access chain.
270       BasicBlock::iterator chainIter(chain);
271       uint32_t replacementId = TakeNextId();
272       if (replacementId == 0) {
273         return false;
274       }
275       std::unique_ptr<Instruction> replacementChain(new Instruction(
276           context(), chain->opcode(), chain->type_id(), replacementId,
277           std::initializer_list<Operand>{
278               {SPV_OPERAND_TYPE_ID, {var->result_id()}}}));
279       // Add the remaining indexes.
280       for (uint32_t i = 2; i < chain->NumInOperands(); ++i) {
281         Operand copy(chain->GetInOperand(i));
282         replacementChain->AddOperand(std::move(copy));
283       }
284       auto iter = chainIter.InsertBefore(std::move(replacementChain));
285       get_def_use_mgr()->AnalyzeInstDefUse(&*iter);
286       context()->set_instr_block(&*iter, context()->get_instr_block(chain));
287       context()->ReplaceAllUsesWith(chain->result_id(), replacementId);
288     } else {
289       // Replace with a use of the variable.
290       context()->ReplaceAllUsesWith(chain->result_id(), var->result_id());
291     }
292   }
293 
294   return true;
295 }
296 
CreateReplacementVariables(Instruction * inst,std::vector<Instruction * > * replacements)297 bool ScalarReplacementPass::CreateReplacementVariables(
298     Instruction* inst, std::vector<Instruction*>* replacements) {
299   Instruction* type = GetStorageType(inst);
300 
301   std::unique_ptr<std::unordered_set<int64_t>> components_used =
302       GetUsedComponents(inst);
303 
304   uint32_t elem = 0;
305   switch (type->opcode()) {
306     case SpvOpTypeStruct:
307       type->ForEachInOperand(
308           [this, inst, &elem, replacements, &components_used](uint32_t* id) {
309             if (!components_used || components_used->count(elem)) {
310               CreateVariable(*id, inst, elem, replacements);
311             } else {
312               replacements->push_back(CreateNullConstant(*id));
313             }
314             elem++;
315           });
316       break;
317     case SpvOpTypeArray:
318       for (uint32_t i = 0; i != GetArrayLength(type); ++i) {
319         if (!components_used || components_used->count(i)) {
320           CreateVariable(type->GetSingleWordInOperand(0u), inst, i,
321                          replacements);
322         } else {
323           replacements->push_back(
324               CreateNullConstant(type->GetSingleWordInOperand(0u)));
325         }
326       }
327       break;
328 
329     case SpvOpTypeMatrix:
330     case SpvOpTypeVector:
331       for (uint32_t i = 0; i != GetNumElements(type); ++i) {
332         CreateVariable(type->GetSingleWordInOperand(0u), inst, i, replacements);
333       }
334       break;
335 
336     default:
337       assert(false && "Unexpected type.");
338       break;
339   }
340 
341   TransferAnnotations(inst, replacements);
342   return std::find(replacements->begin(), replacements->end(), nullptr) ==
343          replacements->end();
344 }
345 
TransferAnnotations(const Instruction * source,std::vector<Instruction * > * replacements)346 void ScalarReplacementPass::TransferAnnotations(
347     const Instruction* source, std::vector<Instruction*>* replacements) {
348   // Only transfer invariant and restrict decorations on the variable. There are
349   // no type or member decorations that are necessary to transfer.
350   for (auto inst :
351        get_decoration_mgr()->GetDecorationsFor(source->result_id(), false)) {
352     assert(inst->opcode() == SpvOpDecorate);
353     uint32_t decoration = inst->GetSingleWordInOperand(1u);
354     if (decoration == SpvDecorationInvariant ||
355         decoration == SpvDecorationRestrict) {
356       for (auto var : *replacements) {
357         if (var == nullptr) {
358           continue;
359         }
360 
361         std::unique_ptr<Instruction> annotation(
362             new Instruction(context(), SpvOpDecorate, 0, 0,
363                             std::initializer_list<Operand>{
364                                 {SPV_OPERAND_TYPE_ID, {var->result_id()}},
365                                 {SPV_OPERAND_TYPE_DECORATION, {decoration}}}));
366         for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
367           Operand copy(inst->GetInOperand(i));
368           annotation->AddOperand(std::move(copy));
369         }
370         context()->AddAnnotationInst(std::move(annotation));
371         get_def_use_mgr()->AnalyzeInstUse(&*--context()->annotation_end());
372       }
373     }
374   }
375 }
376 
CreateVariable(uint32_t typeId,Instruction * varInst,uint32_t index,std::vector<Instruction * > * replacements)377 void ScalarReplacementPass::CreateVariable(
378     uint32_t typeId, Instruction* varInst, uint32_t index,
379     std::vector<Instruction*>* replacements) {
380   uint32_t ptrId = GetOrCreatePointerType(typeId);
381   uint32_t id = TakeNextId();
382 
383   if (id == 0) {
384     replacements->push_back(nullptr);
385   }
386 
387   std::unique_ptr<Instruction> variable(new Instruction(
388       context(), SpvOpVariable, ptrId, id,
389       std::initializer_list<Operand>{
390           {SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}}}));
391 
392   BasicBlock* block = context()->get_instr_block(varInst);
393   block->begin().InsertBefore(std::move(variable));
394   Instruction* inst = &*block->begin();
395 
396   // If varInst was initialized, make sure to initialize its replacement.
397   GetOrCreateInitialValue(varInst, index, inst);
398   get_def_use_mgr()->AnalyzeInstDefUse(inst);
399   context()->set_instr_block(inst, block);
400 
401   // Copy decorations from the member to the new variable.
402   Instruction* typeInst = GetStorageType(varInst);
403   for (auto dec_inst :
404        get_decoration_mgr()->GetDecorationsFor(typeInst->result_id(), false)) {
405     uint32_t decoration;
406     if (dec_inst->opcode() != SpvOpMemberDecorate) {
407       continue;
408     }
409 
410     if (dec_inst->GetSingleWordInOperand(1) != index) {
411       continue;
412     }
413 
414     decoration = dec_inst->GetSingleWordInOperand(2u);
415     switch (decoration) {
416       case SpvDecorationRelaxedPrecision: {
417         std::unique_ptr<Instruction> new_dec_inst(
418             new Instruction(context(), SpvOpDecorate, 0, 0, {}));
419         new_dec_inst->AddOperand(Operand(SPV_OPERAND_TYPE_ID, {id}));
420         for (uint32_t i = 2; i < dec_inst->NumInOperandWords(); ++i) {
421           new_dec_inst->AddOperand(Operand(dec_inst->GetInOperand(i)));
422         }
423         context()->AddAnnotationInst(std::move(new_dec_inst));
424       } break;
425       default:
426         break;
427     }
428   }
429 
430   replacements->push_back(inst);
431 }
432 
GetOrCreatePointerType(uint32_t id)433 uint32_t ScalarReplacementPass::GetOrCreatePointerType(uint32_t id) {
434   auto iter = pointee_to_pointer_.find(id);
435   if (iter != pointee_to_pointer_.end()) return iter->second;
436 
437   analysis::Type* pointeeTy;
438   std::unique_ptr<analysis::Pointer> pointerTy;
439   std::tie(pointeeTy, pointerTy) =
440       context()->get_type_mgr()->GetTypeAndPointerType(id,
441                                                        SpvStorageClassFunction);
442   uint32_t ptrId = 0;
443   if (pointeeTy->IsUniqueType()) {
444     // Non-ambiguous type, just ask the type manager for an id.
445     ptrId = context()->get_type_mgr()->GetTypeInstruction(pointerTy.get());
446     pointee_to_pointer_[id] = ptrId;
447     return ptrId;
448   }
449 
450   // Ambiguous type. We must perform a linear search to try and find the right
451   // type.
452   for (auto global : context()->types_values()) {
453     if (global.opcode() == SpvOpTypePointer &&
454         global.GetSingleWordInOperand(0u) == SpvStorageClassFunction &&
455         global.GetSingleWordInOperand(1u) == id) {
456       if (get_decoration_mgr()->GetDecorationsFor(id, false).empty()) {
457         // Only reuse a decoration-less pointer of the correct type.
458         ptrId = global.result_id();
459         break;
460       }
461     }
462   }
463 
464   if (ptrId != 0) {
465     pointee_to_pointer_[id] = ptrId;
466     return ptrId;
467   }
468 
469   ptrId = TakeNextId();
470   context()->AddType(MakeUnique<Instruction>(
471       context(), SpvOpTypePointer, 0, ptrId,
472       std::initializer_list<Operand>{
473           {SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}},
474           {SPV_OPERAND_TYPE_ID, {id}}}));
475   Instruction* ptr = &*--context()->types_values_end();
476   get_def_use_mgr()->AnalyzeInstDefUse(ptr);
477   pointee_to_pointer_[id] = ptrId;
478   // Register with the type manager if necessary.
479   context()->get_type_mgr()->RegisterType(ptrId, *pointerTy);
480 
481   return ptrId;
482 }
483 
GetOrCreateInitialValue(Instruction * source,uint32_t index,Instruction * newVar)484 void ScalarReplacementPass::GetOrCreateInitialValue(Instruction* source,
485                                                     uint32_t index,
486                                                     Instruction* newVar) {
487   assert(source->opcode() == SpvOpVariable);
488   if (source->NumInOperands() < 2) return;
489 
490   uint32_t initId = source->GetSingleWordInOperand(1u);
491   uint32_t storageId = GetStorageType(newVar)->result_id();
492   Instruction* init = get_def_use_mgr()->GetDef(initId);
493   uint32_t newInitId = 0;
494   // TODO(dnovillo): Refactor this with constant propagation.
495   if (init->opcode() == SpvOpConstantNull) {
496     // Initialize to appropriate NULL.
497     auto iter = type_to_null_.find(storageId);
498     if (iter == type_to_null_.end()) {
499       newInitId = TakeNextId();
500       type_to_null_[storageId] = newInitId;
501       context()->AddGlobalValue(
502           MakeUnique<Instruction>(context(), SpvOpConstantNull, storageId,
503                                   newInitId, std::initializer_list<Operand>{}));
504       Instruction* newNull = &*--context()->types_values_end();
505       get_def_use_mgr()->AnalyzeInstDefUse(newNull);
506     } else {
507       newInitId = iter->second;
508     }
509   } else if (IsSpecConstantInst(init->opcode())) {
510     // Create a new constant extract.
511     newInitId = TakeNextId();
512     context()->AddGlobalValue(MakeUnique<Instruction>(
513         context(), SpvOpSpecConstantOp, storageId, newInitId,
514         std::initializer_list<Operand>{
515             {SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER, {SpvOpCompositeExtract}},
516             {SPV_OPERAND_TYPE_ID, {init->result_id()}},
517             {SPV_OPERAND_TYPE_LITERAL_INTEGER, {index}}}));
518     Instruction* newSpecConst = &*--context()->types_values_end();
519     get_def_use_mgr()->AnalyzeInstDefUse(newSpecConst);
520   } else if (init->opcode() == SpvOpConstantComposite) {
521     // Get the appropriate index constant.
522     newInitId = init->GetSingleWordInOperand(index);
523     Instruction* element = get_def_use_mgr()->GetDef(newInitId);
524     if (element->opcode() == SpvOpUndef) {
525       // Undef is not a valid initializer for a variable.
526       newInitId = 0;
527     }
528   } else {
529     assert(false);
530   }
531 
532   if (newInitId != 0) {
533     newVar->AddOperand({SPV_OPERAND_TYPE_ID, {newInitId}});
534   }
535 }
536 
GetArrayLength(const Instruction * arrayType) const537 uint64_t ScalarReplacementPass::GetArrayLength(
538     const Instruction* arrayType) const {
539   assert(arrayType->opcode() == SpvOpTypeArray);
540   const Instruction* length =
541       get_def_use_mgr()->GetDef(arrayType->GetSingleWordInOperand(1u));
542   return context()
543       ->get_constant_mgr()
544       ->GetConstantFromInst(length)
545       ->GetZeroExtendedValue();
546 }
547 
GetNumElements(const Instruction * type) const548 uint64_t ScalarReplacementPass::GetNumElements(const Instruction* type) const {
549   assert(type->opcode() == SpvOpTypeVector ||
550          type->opcode() == SpvOpTypeMatrix);
551   const Operand& op = type->GetInOperand(1u);
552   assert(op.words.size() <= 2);
553   uint64_t len = 0;
554   for (size_t i = 0; i != op.words.size(); ++i) {
555     len |= (static_cast<uint64_t>(op.words[i]) << (32ull * i));
556   }
557   return len;
558 }
559 
IsSpecConstant(uint32_t id) const560 bool ScalarReplacementPass::IsSpecConstant(uint32_t id) const {
561   const Instruction* inst = get_def_use_mgr()->GetDef(id);
562   assert(inst);
563   return spvOpcodeIsSpecConstant(inst->opcode());
564 }
565 
// Returns the pointee type instruction of the OpVariable |inst|, i.e. the
// type of the object the variable stores.
Instruction* ScalarReplacementPass::GetStorageType(
    const Instruction* inst) const {
  assert(inst->opcode() == SpvOpVariable);
  // The variable's result type is an OpTypePointer; in-operand 1 of that
  // pointer type is the pointee type id.
  const Instruction* ptr_type = get_def_use_mgr()->GetDef(inst->type_id());
  const uint32_t pointee_id = ptr_type->GetSingleWordInOperand(1u);
  return get_def_use_mgr()->GetDef(pointee_id);
}
575 
CanReplaceVariable(const Instruction * varInst) const576 bool ScalarReplacementPass::CanReplaceVariable(
577     const Instruction* varInst) const {
578   assert(varInst->opcode() == SpvOpVariable);
579 
580   // Can only replace function scope variables.
581   if (varInst->GetSingleWordInOperand(0u) != SpvStorageClassFunction) {
582     return false;
583   }
584 
585   if (!CheckTypeAnnotations(get_def_use_mgr()->GetDef(varInst->type_id()))) {
586     return false;
587   }
588 
589   const Instruction* typeInst = GetStorageType(varInst);
590   if (!CheckType(typeInst)) {
591     return false;
592   }
593 
594   if (!CheckAnnotations(varInst)) {
595     return false;
596   }
597 
598   if (!CheckUses(varInst)) {
599     return false;
600   }
601 
602   return true;
603 }
604 
CheckType(const Instruction * typeInst) const605 bool ScalarReplacementPass::CheckType(const Instruction* typeInst) const {
606   if (!CheckTypeAnnotations(typeInst)) {
607     return false;
608   }
609 
610   switch (typeInst->opcode()) {
611     case SpvOpTypeStruct:
612       // Don't bother with empty structs or very large structs.
613       if (typeInst->NumInOperands() == 0 ||
614           IsLargerThanSizeLimit(typeInst->NumInOperands())) {
615         return false;
616       }
617       return true;
618     case SpvOpTypeArray:
619       if (IsSpecConstant(typeInst->GetSingleWordInOperand(1u))) {
620         return false;
621       }
622       if (IsLargerThanSizeLimit(GetArrayLength(typeInst))) {
623         return false;
624       }
625       return true;
626       // TODO(alanbaker): Develop some heuristics for when this should be
627       // re-enabled.
628       //// Specifically including matrix and vector in an attempt to reduce the
629       //// number of vector registers required.
630       // case SpvOpTypeMatrix:
631       // case SpvOpTypeVector:
632       //  if (IsLargerThanSizeLimit(GetNumElements(typeInst))) return false;
633       //  return true;
634 
635     case SpvOpTypeRuntimeArray:
636     default:
637       return false;
638   }
639 }
640 
CheckTypeAnnotations(const Instruction * typeInst) const641 bool ScalarReplacementPass::CheckTypeAnnotations(
642     const Instruction* typeInst) const {
643   for (auto inst :
644        get_decoration_mgr()->GetDecorationsFor(typeInst->result_id(), false)) {
645     uint32_t decoration;
646     if (inst->opcode() == SpvOpDecorate) {
647       decoration = inst->GetSingleWordInOperand(1u);
648     } else {
649       assert(inst->opcode() == SpvOpMemberDecorate);
650       decoration = inst->GetSingleWordInOperand(2u);
651     }
652 
653     switch (decoration) {
654       case SpvDecorationRowMajor:
655       case SpvDecorationColMajor:
656       case SpvDecorationArrayStride:
657       case SpvDecorationMatrixStride:
658       case SpvDecorationCPacked:
659       case SpvDecorationInvariant:
660       case SpvDecorationRestrict:
661       case SpvDecorationOffset:
662       case SpvDecorationAlignment:
663       case SpvDecorationAlignmentId:
664       case SpvDecorationMaxByteOffset:
665       case SpvDecorationRelaxedPrecision:
666         break;
667       default:
668         return false;
669     }
670   }
671 
672   return true;
673 }
674 
CheckAnnotations(const Instruction * varInst) const675 bool ScalarReplacementPass::CheckAnnotations(const Instruction* varInst) const {
676   for (auto inst :
677        get_decoration_mgr()->GetDecorationsFor(varInst->result_id(), false)) {
678     assert(inst->opcode() == SpvOpDecorate);
679     uint32_t decoration = inst->GetSingleWordInOperand(1u);
680     switch (decoration) {
681       case SpvDecorationInvariant:
682       case SpvDecorationRestrict:
683       case SpvDecorationAlignment:
684       case SpvDecorationAlignmentId:
685       case SpvDecorationMaxByteOffset:
686         break;
687       default:
688         return false;
689     }
690   }
691 
692   return true;
693 }
694 
CheckUses(const Instruction * inst) const695 bool ScalarReplacementPass::CheckUses(const Instruction* inst) const {
696   VariableStats stats = {0, 0};
697   bool ok = CheckUses(inst, &stats);
698 
699   // TODO(alanbaker/greg-lunarg): Add some meaningful heuristics about when
700   // SRoA is costly, such as when the structure has many (unaccessed?)
701   // members.
702 
703   return ok;
704 }
705 
CheckUses(const Instruction * inst,VariableStats * stats) const706 bool ScalarReplacementPass::CheckUses(const Instruction* inst,
707                                       VariableStats* stats) const {
708   uint64_t max_legal_index = GetMaxLegalIndex(inst);
709 
710   bool ok = true;
711   get_def_use_mgr()->ForEachUse(inst, [this, max_legal_index, stats, &ok](
712                                           const Instruction* user,
713                                           uint32_t index) {
714     // Annotations are check as a group separately.
715     if (!IsAnnotationInst(user->opcode())) {
716       switch (user->opcode()) {
717         case SpvOpAccessChain:
718         case SpvOpInBoundsAccessChain:
719           if (index == 2u && user->NumInOperands() > 1) {
720             uint32_t id = user->GetSingleWordInOperand(1u);
721             const Instruction* opInst = get_def_use_mgr()->GetDef(id);
722             const auto* constant =
723                 context()->get_constant_mgr()->GetConstantFromInst(opInst);
724             if (!constant) {
725               ok = false;
726             } else if (constant->GetZeroExtendedValue() >= max_legal_index) {
727               ok = false;
728             } else {
729               if (!CheckUsesRelaxed(user)) ok = false;
730             }
731             stats->num_partial_accesses++;
732           } else {
733             ok = false;
734           }
735           break;
736         case SpvOpLoad:
737           if (!CheckLoad(user, index)) ok = false;
738           stats->num_full_accesses++;
739           break;
740         case SpvOpStore:
741           if (!CheckStore(user, index)) ok = false;
742           stats->num_full_accesses++;
743           break;
744         case SpvOpName:
745         case SpvOpMemberName:
746           break;
747         default:
748           ok = false;
749           break;
750       }
751     }
752   });
753 
754   return ok;
755 }
756 
CheckUsesRelaxed(const Instruction * inst) const757 bool ScalarReplacementPass::CheckUsesRelaxed(const Instruction* inst) const {
758   bool ok = true;
759   get_def_use_mgr()->ForEachUse(
760       inst, [this, &ok](const Instruction* user, uint32_t index) {
761         switch (user->opcode()) {
762           case SpvOpAccessChain:
763           case SpvOpInBoundsAccessChain:
764             if (index != 2u) {
765               ok = false;
766             } else {
767               if (!CheckUsesRelaxed(user)) ok = false;
768             }
769             break;
770           case SpvOpLoad:
771             if (!CheckLoad(user, index)) ok = false;
772             break;
773           case SpvOpStore:
774             if (!CheckStore(user, index)) ok = false;
775             break;
776           default:
777             ok = false;
778             break;
779         }
780       });
781 
782   return ok;
783 }
784 
CheckLoad(const Instruction * inst,uint32_t index) const785 bool ScalarReplacementPass::CheckLoad(const Instruction* inst,
786                                       uint32_t index) const {
787   if (index != 2u) return false;
788   if (inst->NumInOperands() >= 2 &&
789       inst->GetSingleWordInOperand(1u) & SpvMemoryAccessVolatileMask)
790     return false;
791   return true;
792 }
793 
CheckStore(const Instruction * inst,uint32_t index) const794 bool ScalarReplacementPass::CheckStore(const Instruction* inst,
795                                        uint32_t index) const {
796   if (index != 0u) return false;
797   if (inst->NumInOperands() >= 3 &&
798       inst->GetSingleWordInOperand(2u) & SpvMemoryAccessVolatileMask)
799     return false;
800   return true;
801 }
IsLargerThanSizeLimit(uint64_t length) const802 bool ScalarReplacementPass::IsLargerThanSizeLimit(uint64_t length) const {
803   if (max_num_elements_ == 0) {
804     return false;
805   }
806   return length > max_num_elements_;
807 }
808 
809 std::unique_ptr<std::unordered_set<int64_t>>
GetUsedComponents(Instruction * inst)810 ScalarReplacementPass::GetUsedComponents(Instruction* inst) {
811   std::unique_ptr<std::unordered_set<int64_t>> result(
812       new std::unordered_set<int64_t>());
813 
814   analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
815 
816   def_use_mgr->WhileEachUser(inst, [&result, def_use_mgr,
817                                     this](Instruction* use) {
818     switch (use->opcode()) {
819       case SpvOpLoad: {
820         // Look for extract from the load.
821         std::vector<uint32_t> t;
822         if (def_use_mgr->WhileEachUser(use, [&t](Instruction* use2) {
823               if (use2->opcode() != SpvOpCompositeExtract ||
824                   use2->NumInOperands() <= 1) {
825                 return false;
826               }
827               t.push_back(use2->GetSingleWordInOperand(1));
828               return true;
829             })) {
830           result->insert(t.begin(), t.end());
831           return true;
832         } else {
833           result.reset(nullptr);
834           return false;
835         }
836       }
837       case SpvOpName:
838       case SpvOpMemberName:
839       case SpvOpStore:
840         // No components are used.
841         return true;
842       case SpvOpAccessChain:
843       case SpvOpInBoundsAccessChain: {
844         // Add the first index it if is a constant.
845         // TODO: Could be improved by checking if the address is used in a load.
846         analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
847         uint32_t index_id = use->GetSingleWordInOperand(1);
848         const analysis::Constant* index_const =
849             const_mgr->FindDeclaredConstant(index_id);
850         if (index_const) {
851           result->insert(index_const->GetSignExtendedValue());
852           return true;
853         } else {
854           // Could be any element.  Assuming all are used.
855           result.reset(nullptr);
856           return false;
857         }
858       }
859       default:
860         // We do not know what is happening.  Have to assume the worst.
861         result.reset(nullptr);
862         return false;
863     }
864   });
865 
866   return result;
867 }
868 
CreateNullConstant(uint32_t type_id)869 Instruction* ScalarReplacementPass::CreateNullConstant(uint32_t type_id) {
870   analysis::TypeManager* type_mgr = context()->get_type_mgr();
871   analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
872 
873   const analysis::Type* type = type_mgr->GetType(type_id);
874   const analysis::Constant* null_const = const_mgr->GetConstant(type, {});
875   Instruction* null_inst =
876       const_mgr->GetDefiningInstruction(null_const, type_id);
877   if (null_inst != nullptr) {
878     context()->UpdateDefUse(null_inst);
879   }
880   return null_inst;
881 }
882 
GetMaxLegalIndex(const Instruction * var_inst) const883 uint64_t ScalarReplacementPass::GetMaxLegalIndex(
884     const Instruction* var_inst) const {
885   assert(var_inst->opcode() == SpvOpVariable &&
886          "|var_inst| must be a variable instruction.");
887   Instruction* type = GetStorageType(var_inst);
888   switch (type->opcode()) {
889     case SpvOpTypeStruct:
890       return type->NumInOperands();
891     case SpvOpTypeArray:
892       return GetArrayLength(type);
893     case SpvOpTypeMatrix:
894     case SpvOpTypeVector:
895       return GetNumElements(type);
896     default:
897       return 0;
898   }
899   return 0;
900 }
901 
902 }  // namespace opt
903 }  // namespace spvtools
904