• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/opt/scalar_replacement_pass.h"
16 
17 #include <algorithm>
18 #include <queue>
19 #include <tuple>
20 #include <utility>
21 
22 #include "source/enum_string_mapping.h"
23 #include "source/extensions.h"
24 #include "source/opt/reflect.h"
25 #include "source/opt/types.h"
26 #include "source/util/make_unique.h"
27 
// Operand positions used when rewriting debug-info instructions in this file.
// The 'Value' operand of a DebugValue is overwritten at this index.
static constexpr uint32_t kDebugValueOperandValueIndex = 5;
// The 'Expression' operand is read/overwritten at this index.
static constexpr uint32_t kDebugValueOperandExpressionIndex = 6;
// NOTE(review): presumably the 'Variable' operand of a DebugDeclare; no use of
// this constant is visible in this part of the file -- confirm.
static constexpr uint32_t kDebugDeclareOperandVariableIndex = 5;
31 
32 namespace spvtools {
33 namespace opt {
34 
Process()35 Pass::Status ScalarReplacementPass::Process() {
36   Status status = Status::SuccessWithoutChange;
37   for (auto& f : *get_module()) {
38     if (f.IsDeclaration()) {
39       continue;
40     }
41 
42     Status functionStatus = ProcessFunction(&f);
43     if (functionStatus == Status::Failure)
44       return functionStatus;
45     else if (functionStatus == Status::SuccessWithChange)
46       status = functionStatus;
47   }
48 
49   return status;
50 }
51 
ProcessFunction(Function * function)52 Pass::Status ScalarReplacementPass::ProcessFunction(Function* function) {
53   std::queue<Instruction*> worklist;
54   BasicBlock& entry = *function->begin();
55   for (auto iter = entry.begin(); iter != entry.end(); ++iter) {
56     // Function storage class OpVariables must appear as the first instructions
57     // of the entry block.
58     if (iter->opcode() != SpvOpVariable) break;
59 
60     Instruction* varInst = &*iter;
61     if (CanReplaceVariable(varInst)) {
62       worklist.push(varInst);
63     }
64   }
65 
66   Status status = Status::SuccessWithoutChange;
67   while (!worklist.empty()) {
68     Instruction* varInst = worklist.front();
69     worklist.pop();
70 
71     Status var_status = ReplaceVariable(varInst, &worklist);
72     if (var_status == Status::Failure)
73       return var_status;
74     else if (var_status == Status::SuccessWithChange)
75       status = var_status;
76   }
77 
78   return status;
79 }
80 
ReplaceVariable(Instruction * inst,std::queue<Instruction * > * worklist)81 Pass::Status ScalarReplacementPass::ReplaceVariable(
82     Instruction* inst, std::queue<Instruction*>* worklist) {
83   std::vector<Instruction*> replacements;
84   if (!CreateReplacementVariables(inst, &replacements)) {
85     return Status::Failure;
86   }
87 
88   std::vector<Instruction*> dead;
89   bool replaced_all_uses = get_def_use_mgr()->WhileEachUser(
90       inst, [this, &replacements, &dead](Instruction* user) {
91         if (user->GetCommonDebugOpcode() == CommonDebugInfoDebugDeclare) {
92           if (ReplaceWholeDebugDeclare(user, replacements)) {
93             dead.push_back(user);
94             return true;
95           }
96           return false;
97         }
98         if (user->GetCommonDebugOpcode() == CommonDebugInfoDebugValue) {
99           if (ReplaceWholeDebugValue(user, replacements)) {
100             dead.push_back(user);
101             return true;
102           }
103           return false;
104         }
105         if (!IsAnnotationInst(user->opcode())) {
106           switch (user->opcode()) {
107             case SpvOpLoad:
108               if (ReplaceWholeLoad(user, replacements)) {
109                 dead.push_back(user);
110               } else {
111                 return false;
112               }
113               break;
114             case SpvOpStore:
115               if (ReplaceWholeStore(user, replacements)) {
116                 dead.push_back(user);
117               } else {
118                 return false;
119               }
120               break;
121             case SpvOpAccessChain:
122             case SpvOpInBoundsAccessChain:
123               if (ReplaceAccessChain(user, replacements))
124                 dead.push_back(user);
125               else
126                 return false;
127               break;
128             case SpvOpName:
129             case SpvOpMemberName:
130               break;
131             default:
132               assert(false && "Unexpected opcode");
133               break;
134           }
135         }
136         return true;
137       });
138 
139   if (replaced_all_uses) {
140     dead.push_back(inst);
141   } else {
142     return Status::Failure;
143   }
144 
145   // If there are no dead instructions to clean up, return with no changes.
146   if (dead.empty()) return Status::SuccessWithoutChange;
147 
148   // Clean up some dead code.
149   while (!dead.empty()) {
150     Instruction* toKill = dead.back();
151     dead.pop_back();
152     context()->KillInst(toKill);
153   }
154 
155   // Attempt to further scalarize.
156   for (auto var : replacements) {
157     if (var->opcode() == SpvOpVariable) {
158       if (get_def_use_mgr()->NumUsers(var) == 0) {
159         context()->KillInst(var);
160       } else if (CanReplaceVariable(var)) {
161         worklist->push(var);
162       }
163     }
164   }
165 
166   return Status::SuccessWithChange;
167 }
168 
ReplaceWholeDebugDeclare(Instruction * dbg_decl,const std::vector<Instruction * > & replacements)169 bool ScalarReplacementPass::ReplaceWholeDebugDeclare(
170     Instruction* dbg_decl, const std::vector<Instruction*>& replacements) {
171   // Insert Deref operation to the front of the operation list of |dbg_decl|.
172   Instruction* dbg_expr = context()->get_def_use_mgr()->GetDef(
173       dbg_decl->GetSingleWordOperand(kDebugValueOperandExpressionIndex));
174   auto* deref_expr =
175       context()->get_debug_info_mgr()->DerefDebugExpression(dbg_expr);
176 
177   // Add DebugValue instruction with Indexes operand and Deref operation.
178   int32_t idx = 0;
179   for (const auto* var : replacements) {
180     Instruction* insert_before = var->NextNode();
181     while (insert_before->opcode() == SpvOpVariable)
182       insert_before = insert_before->NextNode();
183     assert(insert_before != nullptr && "unexpected end of list");
184     Instruction* added_dbg_value =
185         context()->get_debug_info_mgr()->AddDebugValueForDecl(
186             dbg_decl, /*value_id=*/var->result_id(),
187             /*insert_before=*/insert_before, /*scope_and_line=*/dbg_decl);
188 
189     if (added_dbg_value == nullptr) return false;
190     added_dbg_value->AddOperand(
191         {SPV_OPERAND_TYPE_ID,
192          {context()->get_constant_mgr()->GetSIntConst(idx)}});
193     added_dbg_value->SetOperand(kDebugValueOperandExpressionIndex,
194                                 {deref_expr->result_id()});
195     if (context()->AreAnalysesValid(IRContext::Analysis::kAnalysisDefUse)) {
196       context()->get_def_use_mgr()->AnalyzeInstUse(added_dbg_value);
197     }
198     ++idx;
199   }
200   return true;
201 }
202 
ReplaceWholeDebugValue(Instruction * dbg_value,const std::vector<Instruction * > & replacements)203 bool ScalarReplacementPass::ReplaceWholeDebugValue(
204     Instruction* dbg_value, const std::vector<Instruction*>& replacements) {
205   int32_t idx = 0;
206   BasicBlock* block = context()->get_instr_block(dbg_value);
207   for (auto var : replacements) {
208     // Clone the DebugValue.
209     std::unique_ptr<Instruction> new_dbg_value(dbg_value->Clone(context()));
210     uint32_t new_id = TakeNextId();
211     if (new_id == 0) return false;
212     new_dbg_value->SetResultId(new_id);
213     // Update 'Value' operand to the |replacements|.
214     new_dbg_value->SetOperand(kDebugValueOperandValueIndex, {var->result_id()});
215     // Append 'Indexes' operand.
216     new_dbg_value->AddOperand(
217         {SPV_OPERAND_TYPE_ID,
218          {context()->get_constant_mgr()->GetSIntConst(idx)}});
219     // Insert the new DebugValue to the basic block.
220     auto* added_instr = dbg_value->InsertBefore(std::move(new_dbg_value));
221     get_def_use_mgr()->AnalyzeInstDefUse(added_instr);
222     context()->set_instr_block(added_instr, block);
223     ++idx;
224   }
225   return true;
226 }
227 
ReplaceWholeLoad(Instruction * load,const std::vector<Instruction * > & replacements)228 bool ScalarReplacementPass::ReplaceWholeLoad(
229     Instruction* load, const std::vector<Instruction*>& replacements) {
230   // Replaces the load of the entire composite with a load from each replacement
231   // variable followed by a composite construction.
232   BasicBlock* block = context()->get_instr_block(load);
233   std::vector<Instruction*> loads;
234   loads.reserve(replacements.size());
235   BasicBlock::iterator where(load);
236   for (auto var : replacements) {
237     // Create a load of each replacement variable.
238     if (var->opcode() != SpvOpVariable) {
239       loads.push_back(var);
240       continue;
241     }
242 
243     Instruction* type = GetStorageType(var);
244     uint32_t loadId = TakeNextId();
245     if (loadId == 0) {
246       return false;
247     }
248     std::unique_ptr<Instruction> newLoad(
249         new Instruction(context(), SpvOpLoad, type->result_id(), loadId,
250                         std::initializer_list<Operand>{
251                             {SPV_OPERAND_TYPE_ID, {var->result_id()}}}));
252     // Copy memory access attributes which start at index 1. Index 0 is the
253     // pointer to load.
254     for (uint32_t i = 1; i < load->NumInOperands(); ++i) {
255       Operand copy(load->GetInOperand(i));
256       newLoad->AddOperand(std::move(copy));
257     }
258     where = where.InsertBefore(std::move(newLoad));
259     get_def_use_mgr()->AnalyzeInstDefUse(&*where);
260     context()->set_instr_block(&*where, block);
261     where->UpdateDebugInfoFrom(load);
262     loads.push_back(&*where);
263   }
264 
265   // Construct a new composite.
266   uint32_t compositeId = TakeNextId();
267   if (compositeId == 0) {
268     return false;
269   }
270   where = load;
271   std::unique_ptr<Instruction> compositeConstruct(new Instruction(
272       context(), SpvOpCompositeConstruct, load->type_id(), compositeId, {}));
273   for (auto l : loads) {
274     Operand op(SPV_OPERAND_TYPE_ID,
275                std::initializer_list<uint32_t>{l->result_id()});
276     compositeConstruct->AddOperand(std::move(op));
277   }
278   where = where.InsertBefore(std::move(compositeConstruct));
279   get_def_use_mgr()->AnalyzeInstDefUse(&*where);
280   where->UpdateDebugInfoFrom(load);
281   context()->set_instr_block(&*where, block);
282   context()->ReplaceAllUsesWith(load->result_id(), compositeId);
283   return true;
284 }
285 
ReplaceWholeStore(Instruction * store,const std::vector<Instruction * > & replacements)286 bool ScalarReplacementPass::ReplaceWholeStore(
287     Instruction* store, const std::vector<Instruction*>& replacements) {
288   // Replaces a store to the whole composite with a series of extract and stores
289   // to each element.
290   uint32_t storeInput = store->GetSingleWordInOperand(1u);
291   BasicBlock* block = context()->get_instr_block(store);
292   BasicBlock::iterator where(store);
293   uint32_t elementIndex = 0;
294   for (auto var : replacements) {
295     // Create the extract.
296     if (var->opcode() != SpvOpVariable) {
297       elementIndex++;
298       continue;
299     }
300 
301     Instruction* type = GetStorageType(var);
302     uint32_t extractId = TakeNextId();
303     if (extractId == 0) {
304       return false;
305     }
306     std::unique_ptr<Instruction> extract(new Instruction(
307         context(), SpvOpCompositeExtract, type->result_id(), extractId,
308         std::initializer_list<Operand>{
309             {SPV_OPERAND_TYPE_ID, {storeInput}},
310             {SPV_OPERAND_TYPE_LITERAL_INTEGER, {elementIndex++}}}));
311     auto iter = where.InsertBefore(std::move(extract));
312     iter->UpdateDebugInfoFrom(store);
313     get_def_use_mgr()->AnalyzeInstDefUse(&*iter);
314     context()->set_instr_block(&*iter, block);
315 
316     // Create the store.
317     std::unique_ptr<Instruction> newStore(
318         new Instruction(context(), SpvOpStore, 0, 0,
319                         std::initializer_list<Operand>{
320                             {SPV_OPERAND_TYPE_ID, {var->result_id()}},
321                             {SPV_OPERAND_TYPE_ID, {extractId}}}));
322     // Copy memory access attributes which start at index 2. Index 0 is the
323     // pointer and index 1 is the data.
324     for (uint32_t i = 2; i < store->NumInOperands(); ++i) {
325       Operand copy(store->GetInOperand(i));
326       newStore->AddOperand(std::move(copy));
327     }
328     iter = where.InsertBefore(std::move(newStore));
329     iter->UpdateDebugInfoFrom(store);
330     get_def_use_mgr()->AnalyzeInstDefUse(&*iter);
331     context()->set_instr_block(&*iter, block);
332   }
333   return true;
334 }
335 
ReplaceAccessChain(Instruction * chain,const std::vector<Instruction * > & replacements)336 bool ScalarReplacementPass::ReplaceAccessChain(
337     Instruction* chain, const std::vector<Instruction*>& replacements) {
338   // Replaces the access chain with either another access chain (with one fewer
339   // indexes) or a direct use of the replacement variable.
340   uint32_t indexId = chain->GetSingleWordInOperand(1u);
341   const Instruction* index = get_def_use_mgr()->GetDef(indexId);
342   int64_t indexValue = context()
343                            ->get_constant_mgr()
344                            ->GetConstantFromInst(index)
345                            ->GetSignExtendedValue();
346   if (indexValue < 0 ||
347       indexValue >= static_cast<int64_t>(replacements.size())) {
348     // Out of bounds access, this is illegal IR.  Notice that OpAccessChain
349     // indexing is 0-based, so we should also reject index == size-of-array.
350     return false;
351   } else {
352     const Instruction* var = replacements[static_cast<size_t>(indexValue)];
353     if (chain->NumInOperands() > 2) {
354       // Replace input access chain with another access chain.
355       BasicBlock::iterator chainIter(chain);
356       uint32_t replacementId = TakeNextId();
357       if (replacementId == 0) {
358         return false;
359       }
360       std::unique_ptr<Instruction> replacementChain(new Instruction(
361           context(), chain->opcode(), chain->type_id(), replacementId,
362           std::initializer_list<Operand>{
363               {SPV_OPERAND_TYPE_ID, {var->result_id()}}}));
364       // Add the remaining indexes.
365       for (uint32_t i = 2; i < chain->NumInOperands(); ++i) {
366         Operand copy(chain->GetInOperand(i));
367         replacementChain->AddOperand(std::move(copy));
368       }
369       replacementChain->UpdateDebugInfoFrom(chain);
370       auto iter = chainIter.InsertBefore(std::move(replacementChain));
371       get_def_use_mgr()->AnalyzeInstDefUse(&*iter);
372       context()->set_instr_block(&*iter, context()->get_instr_block(chain));
373       context()->ReplaceAllUsesWith(chain->result_id(), replacementId);
374     } else {
375       // Replace with a use of the variable.
376       context()->ReplaceAllUsesWith(chain->result_id(), var->result_id());
377     }
378   }
379 
380   return true;
381 }
382 
CreateReplacementVariables(Instruction * inst,std::vector<Instruction * > * replacements)383 bool ScalarReplacementPass::CreateReplacementVariables(
384     Instruction* inst, std::vector<Instruction*>* replacements) {
385   Instruction* type = GetStorageType(inst);
386 
387   std::unique_ptr<std::unordered_set<int64_t>> components_used =
388       GetUsedComponents(inst);
389 
390   uint32_t elem = 0;
391   switch (type->opcode()) {
392     case SpvOpTypeStruct:
393       type->ForEachInOperand(
394           [this, inst, &elem, replacements, &components_used](uint32_t* id) {
395             if (!components_used || components_used->count(elem)) {
396               CreateVariable(*id, inst, elem, replacements);
397             } else {
398               replacements->push_back(CreateNullConstant(*id));
399             }
400             elem++;
401           });
402       break;
403     case SpvOpTypeArray:
404       for (uint32_t i = 0; i != GetArrayLength(type); ++i) {
405         if (!components_used || components_used->count(i)) {
406           CreateVariable(type->GetSingleWordInOperand(0u), inst, i,
407                          replacements);
408         } else {
409           replacements->push_back(
410               CreateNullConstant(type->GetSingleWordInOperand(0u)));
411         }
412       }
413       break;
414 
415     case SpvOpTypeMatrix:
416     case SpvOpTypeVector:
417       for (uint32_t i = 0; i != GetNumElements(type); ++i) {
418         CreateVariable(type->GetSingleWordInOperand(0u), inst, i, replacements);
419       }
420       break;
421 
422     default:
423       assert(false && "Unexpected type.");
424       break;
425   }
426 
427   TransferAnnotations(inst, replacements);
428   return std::find(replacements->begin(), replacements->end(), nullptr) ==
429          replacements->end();
430 }
431 
TransferAnnotations(const Instruction * source,std::vector<Instruction * > * replacements)432 void ScalarReplacementPass::TransferAnnotations(
433     const Instruction* source, std::vector<Instruction*>* replacements) {
434   // Only transfer invariant and restrict decorations on the variable. There are
435   // no type or member decorations that are necessary to transfer.
436   for (auto inst :
437        get_decoration_mgr()->GetDecorationsFor(source->result_id(), false)) {
438     assert(inst->opcode() == SpvOpDecorate);
439     uint32_t decoration = inst->GetSingleWordInOperand(1u);
440     if (decoration == SpvDecorationInvariant ||
441         decoration == SpvDecorationRestrict) {
442       for (auto var : *replacements) {
443         if (var == nullptr) {
444           continue;
445         }
446 
447         std::unique_ptr<Instruction> annotation(
448             new Instruction(context(), SpvOpDecorate, 0, 0,
449                             std::initializer_list<Operand>{
450                                 {SPV_OPERAND_TYPE_ID, {var->result_id()}},
451                                 {SPV_OPERAND_TYPE_DECORATION, {decoration}}}));
452         for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
453           Operand copy(inst->GetInOperand(i));
454           annotation->AddOperand(std::move(copy));
455         }
456         context()->AddAnnotationInst(std::move(annotation));
457         get_def_use_mgr()->AnalyzeInstUse(&*--context()->annotation_end());
458       }
459     }
460   }
461 }
462 
CreateVariable(uint32_t typeId,Instruction * varInst,uint32_t index,std::vector<Instruction * > * replacements)463 void ScalarReplacementPass::CreateVariable(
464     uint32_t typeId, Instruction* varInst, uint32_t index,
465     std::vector<Instruction*>* replacements) {
466   uint32_t ptrId = GetOrCreatePointerType(typeId);
467   uint32_t id = TakeNextId();
468 
469   if (id == 0) {
470     replacements->push_back(nullptr);
471   }
472 
473   std::unique_ptr<Instruction> variable(new Instruction(
474       context(), SpvOpVariable, ptrId, id,
475       std::initializer_list<Operand>{
476           {SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}}}));
477 
478   BasicBlock* block = context()->get_instr_block(varInst);
479   block->begin().InsertBefore(std::move(variable));
480   Instruction* inst = &*block->begin();
481 
482   // If varInst was initialized, make sure to initialize its replacement.
483   GetOrCreateInitialValue(varInst, index, inst);
484   get_def_use_mgr()->AnalyzeInstDefUse(inst);
485   context()->set_instr_block(inst, block);
486 
487   // Copy decorations from the member to the new variable.
488   Instruction* typeInst = GetStorageType(varInst);
489   for (auto dec_inst :
490        get_decoration_mgr()->GetDecorationsFor(typeInst->result_id(), false)) {
491     uint32_t decoration;
492     if (dec_inst->opcode() != SpvOpMemberDecorate) {
493       continue;
494     }
495 
496     if (dec_inst->GetSingleWordInOperand(1) != index) {
497       continue;
498     }
499 
500     decoration = dec_inst->GetSingleWordInOperand(2u);
501     switch (decoration) {
502       case SpvDecorationRelaxedPrecision: {
503         std::unique_ptr<Instruction> new_dec_inst(
504             new Instruction(context(), SpvOpDecorate, 0, 0, {}));
505         new_dec_inst->AddOperand(Operand(SPV_OPERAND_TYPE_ID, {id}));
506         for (uint32_t i = 2; i < dec_inst->NumInOperandWords(); ++i) {
507           new_dec_inst->AddOperand(Operand(dec_inst->GetInOperand(i)));
508         }
509         context()->AddAnnotationInst(std::move(new_dec_inst));
510       } break;
511       default:
512         break;
513     }
514   }
515 
516   // Update the DebugInfo debug information.
517   inst->UpdateDebugInfoFrom(varInst);
518 
519   replacements->push_back(inst);
520 }
521 
GetOrCreatePointerType(uint32_t id)522 uint32_t ScalarReplacementPass::GetOrCreatePointerType(uint32_t id) {
523   auto iter = pointee_to_pointer_.find(id);
524   if (iter != pointee_to_pointer_.end()) return iter->second;
525 
526   analysis::Type* pointeeTy;
527   std::unique_ptr<analysis::Pointer> pointerTy;
528   std::tie(pointeeTy, pointerTy) =
529       context()->get_type_mgr()->GetTypeAndPointerType(id,
530                                                        SpvStorageClassFunction);
531   uint32_t ptrId = 0;
532   if (pointeeTy->IsUniqueType()) {
533     // Non-ambiguous type, just ask the type manager for an id.
534     ptrId = context()->get_type_mgr()->GetTypeInstruction(pointerTy.get());
535     pointee_to_pointer_[id] = ptrId;
536     return ptrId;
537   }
538 
539   // Ambiguous type. We must perform a linear search to try and find the right
540   // type.
541   for (auto global : context()->types_values()) {
542     if (global.opcode() == SpvOpTypePointer &&
543         global.GetSingleWordInOperand(0u) == SpvStorageClassFunction &&
544         global.GetSingleWordInOperand(1u) == id) {
545       if (get_decoration_mgr()->GetDecorationsFor(id, false).empty()) {
546         // Only reuse a decoration-less pointer of the correct type.
547         ptrId = global.result_id();
548         break;
549       }
550     }
551   }
552 
553   if (ptrId != 0) {
554     pointee_to_pointer_[id] = ptrId;
555     return ptrId;
556   }
557 
558   ptrId = TakeNextId();
559   context()->AddType(MakeUnique<Instruction>(
560       context(), SpvOpTypePointer, 0, ptrId,
561       std::initializer_list<Operand>{
562           {SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}},
563           {SPV_OPERAND_TYPE_ID, {id}}}));
564   Instruction* ptr = &*--context()->types_values_end();
565   get_def_use_mgr()->AnalyzeInstDefUse(ptr);
566   pointee_to_pointer_[id] = ptrId;
567   // Register with the type manager if necessary.
568   context()->get_type_mgr()->RegisterType(ptrId, *pointerTy);
569 
570   return ptrId;
571 }
572 
GetOrCreateInitialValue(Instruction * source,uint32_t index,Instruction * newVar)573 void ScalarReplacementPass::GetOrCreateInitialValue(Instruction* source,
574                                                     uint32_t index,
575                                                     Instruction* newVar) {
576   assert(source->opcode() == SpvOpVariable);
577   if (source->NumInOperands() < 2) return;
578 
579   uint32_t initId = source->GetSingleWordInOperand(1u);
580   uint32_t storageId = GetStorageType(newVar)->result_id();
581   Instruction* init = get_def_use_mgr()->GetDef(initId);
582   uint32_t newInitId = 0;
583   // TODO(dnovillo): Refactor this with constant propagation.
584   if (init->opcode() == SpvOpConstantNull) {
585     // Initialize to appropriate NULL.
586     auto iter = type_to_null_.find(storageId);
587     if (iter == type_to_null_.end()) {
588       newInitId = TakeNextId();
589       type_to_null_[storageId] = newInitId;
590       context()->AddGlobalValue(
591           MakeUnique<Instruction>(context(), SpvOpConstantNull, storageId,
592                                   newInitId, std::initializer_list<Operand>{}));
593       Instruction* newNull = &*--context()->types_values_end();
594       get_def_use_mgr()->AnalyzeInstDefUse(newNull);
595     } else {
596       newInitId = iter->second;
597     }
598   } else if (IsSpecConstantInst(init->opcode())) {
599     // Create a new constant extract.
600     newInitId = TakeNextId();
601     context()->AddGlobalValue(MakeUnique<Instruction>(
602         context(), SpvOpSpecConstantOp, storageId, newInitId,
603         std::initializer_list<Operand>{
604             {SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER, {SpvOpCompositeExtract}},
605             {SPV_OPERAND_TYPE_ID, {init->result_id()}},
606             {SPV_OPERAND_TYPE_LITERAL_INTEGER, {index}}}));
607     Instruction* newSpecConst = &*--context()->types_values_end();
608     get_def_use_mgr()->AnalyzeInstDefUse(newSpecConst);
609   } else if (init->opcode() == SpvOpConstantComposite) {
610     // Get the appropriate index constant.
611     newInitId = init->GetSingleWordInOperand(index);
612     Instruction* element = get_def_use_mgr()->GetDef(newInitId);
613     if (element->opcode() == SpvOpUndef) {
614       // Undef is not a valid initializer for a variable.
615       newInitId = 0;
616     }
617   } else {
618     assert(false);
619   }
620 
621   if (newInitId != 0) {
622     newVar->AddOperand({SPV_OPERAND_TYPE_ID, {newInitId}});
623   }
624 }
625 
GetArrayLength(const Instruction * arrayType) const626 uint64_t ScalarReplacementPass::GetArrayLength(
627     const Instruction* arrayType) const {
628   assert(arrayType->opcode() == SpvOpTypeArray);
629   const Instruction* length =
630       get_def_use_mgr()->GetDef(arrayType->GetSingleWordInOperand(1u));
631   return context()
632       ->get_constant_mgr()
633       ->GetConstantFromInst(length)
634       ->GetZeroExtendedValue();
635 }
636 
GetNumElements(const Instruction * type) const637 uint64_t ScalarReplacementPass::GetNumElements(const Instruction* type) const {
638   assert(type->opcode() == SpvOpTypeVector ||
639          type->opcode() == SpvOpTypeMatrix);
640   const Operand& op = type->GetInOperand(1u);
641   assert(op.words.size() <= 2);
642   uint64_t len = 0;
643   for (size_t i = 0; i != op.words.size(); ++i) {
644     len |= (static_cast<uint64_t>(op.words[i]) << (32ull * i));
645   }
646   return len;
647 }
648 
IsSpecConstant(uint32_t id) const649 bool ScalarReplacementPass::IsSpecConstant(uint32_t id) const {
650   const Instruction* inst = get_def_use_mgr()->GetDef(id);
651   assert(inst);
652   return spvOpcodeIsSpecConstant(inst->opcode());
653 }
654 
// Returns the instruction that defines the pointee type of the OpVariable
// |inst|, i.e. the type named by in-operand 1 of the variable's pointer type.
Instruction* ScalarReplacementPass::GetStorageType(
    const Instruction* inst) const {
  assert(inst->opcode() == SpvOpVariable);

  const Instruction* pointer_type = get_def_use_mgr()->GetDef(inst->type_id());
  const uint32_t pointee_type_id = pointer_type->GetSingleWordInOperand(1u);
  return get_def_use_mgr()->GetDef(pointee_type_id);
}
664 
CanReplaceVariable(const Instruction * varInst) const665 bool ScalarReplacementPass::CanReplaceVariable(
666     const Instruction* varInst) const {
667   assert(varInst->opcode() == SpvOpVariable);
668 
669   // Can only replace function scope variables.
670   if (varInst->GetSingleWordInOperand(0u) != SpvStorageClassFunction) {
671     return false;
672   }
673 
674   if (!CheckTypeAnnotations(get_def_use_mgr()->GetDef(varInst->type_id()))) {
675     return false;
676   }
677 
678   const Instruction* typeInst = GetStorageType(varInst);
679   if (!CheckType(typeInst)) {
680     return false;
681   }
682 
683   if (!CheckAnnotations(varInst)) {
684     return false;
685   }
686 
687   if (!CheckUses(varInst)) {
688     return false;
689   }
690 
691   return true;
692 }
693 
CheckType(const Instruction * typeInst) const694 bool ScalarReplacementPass::CheckType(const Instruction* typeInst) const {
695   if (!CheckTypeAnnotations(typeInst)) {
696     return false;
697   }
698 
699   switch (typeInst->opcode()) {
700     case SpvOpTypeStruct:
701       // Don't bother with empty structs or very large structs.
702       if (typeInst->NumInOperands() == 0 ||
703           IsLargerThanSizeLimit(typeInst->NumInOperands())) {
704         return false;
705       }
706       return true;
707     case SpvOpTypeArray:
708       if (IsSpecConstant(typeInst->GetSingleWordInOperand(1u))) {
709         return false;
710       }
711       if (IsLargerThanSizeLimit(GetArrayLength(typeInst))) {
712         return false;
713       }
714       return true;
715       // TODO(alanbaker): Develop some heuristics for when this should be
716       // re-enabled.
717       //// Specifically including matrix and vector in an attempt to reduce the
718       //// number of vector registers required.
719       // case SpvOpTypeMatrix:
720       // case SpvOpTypeVector:
721       //  if (IsLargerThanSizeLimit(GetNumElements(typeInst))) return false;
722       //  return true;
723 
724     case SpvOpTypeRuntimeArray:
725     default:
726       return false;
727   }
728 }
729 
CheckTypeAnnotations(const Instruction * typeInst) const730 bool ScalarReplacementPass::CheckTypeAnnotations(
731     const Instruction* typeInst) const {
732   for (auto inst :
733        get_decoration_mgr()->GetDecorationsFor(typeInst->result_id(), false)) {
734     uint32_t decoration;
735     if (inst->opcode() == SpvOpDecorate) {
736       decoration = inst->GetSingleWordInOperand(1u);
737     } else {
738       assert(inst->opcode() == SpvOpMemberDecorate);
739       decoration = inst->GetSingleWordInOperand(2u);
740     }
741 
742     switch (decoration) {
743       case SpvDecorationRowMajor:
744       case SpvDecorationColMajor:
745       case SpvDecorationArrayStride:
746       case SpvDecorationMatrixStride:
747       case SpvDecorationCPacked:
748       case SpvDecorationInvariant:
749       case SpvDecorationRestrict:
750       case SpvDecorationOffset:
751       case SpvDecorationAlignment:
752       case SpvDecorationAlignmentId:
753       case SpvDecorationMaxByteOffset:
754       case SpvDecorationRelaxedPrecision:
755         break;
756       default:
757         return false;
758     }
759   }
760 
761   return true;
762 }
763 
CheckAnnotations(const Instruction * varInst) const764 bool ScalarReplacementPass::CheckAnnotations(const Instruction* varInst) const {
765   for (auto inst :
766        get_decoration_mgr()->GetDecorationsFor(varInst->result_id(), false)) {
767     assert(inst->opcode() == SpvOpDecorate);
768     uint32_t decoration = inst->GetSingleWordInOperand(1u);
769     switch (decoration) {
770       case SpvDecorationInvariant:
771       case SpvDecorationRestrict:
772       case SpvDecorationAlignment:
773       case SpvDecorationAlignmentId:
774       case SpvDecorationMaxByteOffset:
775         break;
776       default:
777         return false;
778     }
779   }
780 
781   return true;
782 }
783 
CheckUses(const Instruction * inst) const784 bool ScalarReplacementPass::CheckUses(const Instruction* inst) const {
785   VariableStats stats = {0, 0};
786   bool ok = CheckUses(inst, &stats);
787 
788   // TODO(alanbaker/greg-lunarg): Add some meaningful heuristics about when
789   // SRoA is costly, such as when the structure has many (unaccessed?)
790   // members.
791 
792   return ok;
793 }
794 
CheckUses(const Instruction * inst,VariableStats * stats) const795 bool ScalarReplacementPass::CheckUses(const Instruction* inst,
796                                       VariableStats* stats) const {
797   uint64_t max_legal_index = GetMaxLegalIndex(inst);
798 
799   bool ok = true;
800   get_def_use_mgr()->ForEachUse(inst, [this, max_legal_index, stats, &ok](
801                                           const Instruction* user,
802                                           uint32_t index) {
803     if (user->GetCommonDebugOpcode() == CommonDebugInfoDebugDeclare ||
804         user->GetCommonDebugOpcode() == CommonDebugInfoDebugValue) {
805       // TODO: include num_partial_accesses if it uses Fragment operation or
806       // DebugValue has Indexes operand.
807       stats->num_full_accesses++;
808       return;
809     }
810 
811     // Annotations are check as a group separately.
812     if (!IsAnnotationInst(user->opcode())) {
813       switch (user->opcode()) {
814         case SpvOpAccessChain:
815         case SpvOpInBoundsAccessChain:
816           if (index == 2u && user->NumInOperands() > 1) {
817             uint32_t id = user->GetSingleWordInOperand(1u);
818             const Instruction* opInst = get_def_use_mgr()->GetDef(id);
819             const auto* constant =
820                 context()->get_constant_mgr()->GetConstantFromInst(opInst);
821             if (!constant) {
822               ok = false;
823             } else if (constant->GetZeroExtendedValue() >= max_legal_index) {
824               ok = false;
825             } else {
826               if (!CheckUsesRelaxed(user)) ok = false;
827             }
828             stats->num_partial_accesses++;
829           } else {
830             ok = false;
831           }
832           break;
833         case SpvOpLoad:
834           if (!CheckLoad(user, index)) ok = false;
835           stats->num_full_accesses++;
836           break;
837         case SpvOpStore:
838           if (!CheckStore(user, index)) ok = false;
839           stats->num_full_accesses++;
840           break;
841         case SpvOpName:
842         case SpvOpMemberName:
843           break;
844         default:
845           ok = false;
846           break;
847       }
848     }
849   });
850 
851   return ok;
852 }
853 
CheckUsesRelaxed(const Instruction * inst) const854 bool ScalarReplacementPass::CheckUsesRelaxed(const Instruction* inst) const {
855   bool ok = true;
856   get_def_use_mgr()->ForEachUse(
857       inst, [this, &ok](const Instruction* user, uint32_t index) {
858         switch (user->opcode()) {
859           case SpvOpAccessChain:
860           case SpvOpInBoundsAccessChain:
861             if (index != 2u) {
862               ok = false;
863             } else {
864               if (!CheckUsesRelaxed(user)) ok = false;
865             }
866             break;
867           case SpvOpLoad:
868             if (!CheckLoad(user, index)) ok = false;
869             break;
870           case SpvOpStore:
871             if (!CheckStore(user, index)) ok = false;
872             break;
873           case SpvOpImageTexelPointer:
874             if (!CheckImageTexelPointer(index)) ok = false;
875             break;
876           case SpvOpExtInst:
877             if (user->GetCommonDebugOpcode() != CommonDebugInfoDebugDeclare ||
878                 !CheckDebugDeclare(index))
879               ok = false;
880             break;
881           default:
882             ok = false;
883             break;
884         }
885       });
886 
887   return ok;
888 }
889 
CheckImageTexelPointer(uint32_t index) const890 bool ScalarReplacementPass::CheckImageTexelPointer(uint32_t index) const {
891   return index == 2u;
892 }
893 
CheckLoad(const Instruction * inst,uint32_t index) const894 bool ScalarReplacementPass::CheckLoad(const Instruction* inst,
895                                       uint32_t index) const {
896   if (index != 2u) return false;
897   if (inst->NumInOperands() >= 2 &&
898       inst->GetSingleWordInOperand(1u) & SpvMemoryAccessVolatileMask)
899     return false;
900   return true;
901 }
902 
CheckStore(const Instruction * inst,uint32_t index) const903 bool ScalarReplacementPass::CheckStore(const Instruction* inst,
904                                        uint32_t index) const {
905   if (index != 0u) return false;
906   if (inst->NumInOperands() >= 3 &&
907       inst->GetSingleWordInOperand(2u) & SpvMemoryAccessVolatileMask)
908     return false;
909   return true;
910 }
911 
CheckDebugDeclare(uint32_t index) const912 bool ScalarReplacementPass::CheckDebugDeclare(uint32_t index) const {
913   if (index != kDebugDeclareOperandVariableIndex) return false;
914   return true;
915 }
916 
IsLargerThanSizeLimit(uint64_t length) const917 bool ScalarReplacementPass::IsLargerThanSizeLimit(uint64_t length) const {
918   if (max_num_elements_ == 0) {
919     return false;
920   }
921   return length > max_num_elements_;
922 }
923 
924 std::unique_ptr<std::unordered_set<int64_t>>
GetUsedComponents(Instruction * inst)925 ScalarReplacementPass::GetUsedComponents(Instruction* inst) {
926   std::unique_ptr<std::unordered_set<int64_t>> result(
927       new std::unordered_set<int64_t>());
928 
929   analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
930 
931   def_use_mgr->WhileEachUser(inst, [&result, def_use_mgr,
932                                     this](Instruction* use) {
933     switch (use->opcode()) {
934       case SpvOpLoad: {
935         // Look for extract from the load.
936         std::vector<uint32_t> t;
937         if (def_use_mgr->WhileEachUser(use, [&t](Instruction* use2) {
938               if (use2->opcode() != SpvOpCompositeExtract ||
939                   use2->NumInOperands() <= 1) {
940                 return false;
941               }
942               t.push_back(use2->GetSingleWordInOperand(1));
943               return true;
944             })) {
945           result->insert(t.begin(), t.end());
946           return true;
947         } else {
948           result.reset(nullptr);
949           return false;
950         }
951       }
952       case SpvOpName:
953       case SpvOpMemberName:
954       case SpvOpStore:
955         // No components are used.
956         return true;
957       case SpvOpAccessChain:
958       case SpvOpInBoundsAccessChain: {
959         // Add the first index it if is a constant.
960         // TODO: Could be improved by checking if the address is used in a load.
961         analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
962         uint32_t index_id = use->GetSingleWordInOperand(1);
963         const analysis::Constant* index_const =
964             const_mgr->FindDeclaredConstant(index_id);
965         if (index_const) {
966           result->insert(index_const->GetSignExtendedValue());
967           return true;
968         } else {
969           // Could be any element.  Assuming all are used.
970           result.reset(nullptr);
971           return false;
972         }
973       }
974       default:
975         // We do not know what is happening.  Have to assume the worst.
976         result.reset(nullptr);
977         return false;
978     }
979   });
980 
981   return result;
982 }
983 
CreateNullConstant(uint32_t type_id)984 Instruction* ScalarReplacementPass::CreateNullConstant(uint32_t type_id) {
985   analysis::TypeManager* type_mgr = context()->get_type_mgr();
986   analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
987 
988   const analysis::Type* type = type_mgr->GetType(type_id);
989   const analysis::Constant* null_const = const_mgr->GetConstant(type, {});
990   Instruction* null_inst =
991       const_mgr->GetDefiningInstruction(null_const, type_id);
992   if (null_inst != nullptr) {
993     context()->UpdateDefUse(null_inst);
994   }
995   return null_inst;
996 }
997 
GetMaxLegalIndex(const Instruction * var_inst) const998 uint64_t ScalarReplacementPass::GetMaxLegalIndex(
999     const Instruction* var_inst) const {
1000   assert(var_inst->opcode() == SpvOpVariable &&
1001          "|var_inst| must be a variable instruction.");
1002   Instruction* type = GetStorageType(var_inst);
1003   switch (type->opcode()) {
1004     case SpvOpTypeStruct:
1005       return type->NumInOperands();
1006     case SpvOpTypeArray:
1007       return GetArrayLength(type);
1008     case SpvOpTypeMatrix:
1009     case SpvOpTypeVector:
1010       return GetNumElements(type);
1011     default:
1012       return 0;
1013   }
1014   return 0;
1015 }
1016 
1017 }  // namespace opt
1018 }  // namespace spvtools
1019