• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/opt/scalar_replacement_pass.h"
16 
17 #include <algorithm>
18 #include <queue>
19 #include <tuple>
20 #include <utility>
21 
22 #include "source/extensions.h"
23 #include "source/opt/reflect.h"
24 #include "source/opt/types.h"
25 #include "source/util/make_unique.h"
26 
27 namespace spvtools {
28 namespace opt {
namespace {
// Operand positions used when rewriting debug-info extended instructions.
//
// 'Value' operand of a DebugValue instruction.
constexpr uint32_t kDebugValueOperandValueIndex{5};
// 'Expression' operand of a DebugValue instruction. This file also uses it to
// read the expression of a DebugDeclare.
constexpr uint32_t kDebugValueOperandExpressionIndex{6};
// NOTE(review): presumably the 'Variable' operand of a DebugDeclare
// instruction -- confirm against the debug-info instruction layout.
constexpr uint32_t kDebugDeclareOperandVariableIndex{5};
}  // namespace
34 
Process()35 Pass::Status ScalarReplacementPass::Process() {
36   Status status = Status::SuccessWithoutChange;
37   for (auto& f : *get_module()) {
38     if (f.IsDeclaration()) {
39       continue;
40     }
41 
42     Status functionStatus = ProcessFunction(&f);
43     if (functionStatus == Status::Failure)
44       return functionStatus;
45     else if (functionStatus == Status::SuccessWithChange)
46       status = functionStatus;
47   }
48 
49   return status;
50 }
51 
ProcessFunction(Function * function)52 Pass::Status ScalarReplacementPass::ProcessFunction(Function* function) {
53   std::queue<Instruction*> worklist;
54   BasicBlock& entry = *function->begin();
55   for (auto iter = entry.begin(); iter != entry.end(); ++iter) {
56     // Function storage class OpVariables must appear as the first instructions
57     // of the entry block.
58     if (iter->opcode() != spv::Op::OpVariable) break;
59 
60     Instruction* varInst = &*iter;
61     if (CanReplaceVariable(varInst)) {
62       worklist.push(varInst);
63     }
64   }
65 
66   Status status = Status::SuccessWithoutChange;
67   while (!worklist.empty()) {
68     Instruction* varInst = worklist.front();
69     worklist.pop();
70 
71     Status var_status = ReplaceVariable(varInst, &worklist);
72     if (var_status == Status::Failure)
73       return var_status;
74     else if (var_status == Status::SuccessWithChange)
75       status = var_status;
76   }
77 
78   return status;
79 }
80 
ReplaceVariable(Instruction * inst,std::queue<Instruction * > * worklist)81 Pass::Status ScalarReplacementPass::ReplaceVariable(
82     Instruction* inst, std::queue<Instruction*>* worklist) {
83   std::vector<Instruction*> replacements;
84   if (!CreateReplacementVariables(inst, &replacements)) {
85     return Status::Failure;
86   }
87 
88   std::vector<Instruction*> dead;
89   bool replaced_all_uses = get_def_use_mgr()->WhileEachUser(
90       inst, [this, &replacements, &dead](Instruction* user) {
91         if (user->GetCommonDebugOpcode() == CommonDebugInfoDebugDeclare) {
92           if (ReplaceWholeDebugDeclare(user, replacements)) {
93             dead.push_back(user);
94             return true;
95           }
96           return false;
97         }
98         if (user->GetCommonDebugOpcode() == CommonDebugInfoDebugValue) {
99           if (ReplaceWholeDebugValue(user, replacements)) {
100             dead.push_back(user);
101             return true;
102           }
103           return false;
104         }
105         if (!IsAnnotationInst(user->opcode())) {
106           switch (user->opcode()) {
107             case spv::Op::OpLoad:
108               if (ReplaceWholeLoad(user, replacements)) {
109                 dead.push_back(user);
110               } else {
111                 return false;
112               }
113               break;
114             case spv::Op::OpStore:
115               if (ReplaceWholeStore(user, replacements)) {
116                 dead.push_back(user);
117               } else {
118                 return false;
119               }
120               break;
121             case spv::Op::OpAccessChain:
122             case spv::Op::OpInBoundsAccessChain:
123               if (ReplaceAccessChain(user, replacements))
124                 dead.push_back(user);
125               else
126                 return false;
127               break;
128             case spv::Op::OpName:
129             case spv::Op::OpMemberName:
130               break;
131             default:
132               assert(false && "Unexpected opcode");
133               break;
134           }
135         }
136         return true;
137       });
138 
139   if (replaced_all_uses) {
140     dead.push_back(inst);
141   } else {
142     return Status::Failure;
143   }
144 
145   // If there are no dead instructions to clean up, return with no changes.
146   if (dead.empty()) return Status::SuccessWithoutChange;
147 
148   // Clean up some dead code.
149   while (!dead.empty()) {
150     Instruction* toKill = dead.back();
151     dead.pop_back();
152     context()->KillInst(toKill);
153   }
154 
155   // Attempt to further scalarize.
156   for (auto var : replacements) {
157     if (var->opcode() == spv::Op::OpVariable) {
158       if (get_def_use_mgr()->NumUsers(var) == 0) {
159         context()->KillInst(var);
160       } else if (CanReplaceVariable(var)) {
161         worklist->push(var);
162       }
163     }
164   }
165 
166   return Status::SuccessWithChange;
167 }
168 
ReplaceWholeDebugDeclare(Instruction * dbg_decl,const std::vector<Instruction * > & replacements)169 bool ScalarReplacementPass::ReplaceWholeDebugDeclare(
170     Instruction* dbg_decl, const std::vector<Instruction*>& replacements) {
171   // Insert Deref operation to the front of the operation list of |dbg_decl|.
172   Instruction* dbg_expr = context()->get_def_use_mgr()->GetDef(
173       dbg_decl->GetSingleWordOperand(kDebugValueOperandExpressionIndex));
174   auto* deref_expr =
175       context()->get_debug_info_mgr()->DerefDebugExpression(dbg_expr);
176 
177   // Add DebugValue instruction with Indexes operand and Deref operation.
178   int32_t idx = 0;
179   for (const auto* var : replacements) {
180     Instruction* insert_before = var->NextNode();
181     while (insert_before->opcode() == spv::Op::OpVariable)
182       insert_before = insert_before->NextNode();
183     assert(insert_before != nullptr && "unexpected end of list");
184     Instruction* added_dbg_value =
185         context()->get_debug_info_mgr()->AddDebugValueForDecl(
186             dbg_decl, /*value_id=*/var->result_id(),
187             /*insert_before=*/insert_before, /*scope_and_line=*/dbg_decl);
188 
189     if (added_dbg_value == nullptr) return false;
190     added_dbg_value->AddOperand(
191         {SPV_OPERAND_TYPE_ID,
192          {context()->get_constant_mgr()->GetSIntConstId(idx)}});
193     added_dbg_value->SetOperand(kDebugValueOperandExpressionIndex,
194                                 {deref_expr->result_id()});
195     if (context()->AreAnalysesValid(IRContext::Analysis::kAnalysisDefUse)) {
196       context()->get_def_use_mgr()->AnalyzeInstUse(added_dbg_value);
197     }
198     ++idx;
199   }
200   return true;
201 }
202 
ReplaceWholeDebugValue(Instruction * dbg_value,const std::vector<Instruction * > & replacements)203 bool ScalarReplacementPass::ReplaceWholeDebugValue(
204     Instruction* dbg_value, const std::vector<Instruction*>& replacements) {
205   int32_t idx = 0;
206   BasicBlock* block = context()->get_instr_block(dbg_value);
207   for (auto var : replacements) {
208     // Clone the DebugValue.
209     std::unique_ptr<Instruction> new_dbg_value(dbg_value->Clone(context()));
210     uint32_t new_id = TakeNextId();
211     if (new_id == 0) return false;
212     new_dbg_value->SetResultId(new_id);
213     // Update 'Value' operand to the |replacements|.
214     new_dbg_value->SetOperand(kDebugValueOperandValueIndex, {var->result_id()});
215     // Append 'Indexes' operand.
216     new_dbg_value->AddOperand(
217         {SPV_OPERAND_TYPE_ID,
218          {context()->get_constant_mgr()->GetSIntConstId(idx)}});
219     // Insert the new DebugValue to the basic block.
220     auto* added_instr = dbg_value->InsertBefore(std::move(new_dbg_value));
221     get_def_use_mgr()->AnalyzeInstDefUse(added_instr);
222     context()->set_instr_block(added_instr, block);
223     ++idx;
224   }
225   return true;
226 }
227 
ReplaceWholeLoad(Instruction * load,const std::vector<Instruction * > & replacements)228 bool ScalarReplacementPass::ReplaceWholeLoad(
229     Instruction* load, const std::vector<Instruction*>& replacements) {
230   // Replaces the load of the entire composite with a load from each replacement
231   // variable followed by a composite construction.
232   BasicBlock* block = context()->get_instr_block(load);
233   std::vector<Instruction*> loads;
234   loads.reserve(replacements.size());
235   BasicBlock::iterator where(load);
236   for (auto var : replacements) {
237     // Create a load of each replacement variable.
238     if (var->opcode() != spv::Op::OpVariable) {
239       loads.push_back(var);
240       continue;
241     }
242 
243     Instruction* type = GetStorageType(var);
244     uint32_t loadId = TakeNextId();
245     if (loadId == 0) {
246       return false;
247     }
248     std::unique_ptr<Instruction> newLoad(
249         new Instruction(context(), spv::Op::OpLoad, type->result_id(), loadId,
250                         std::initializer_list<Operand>{
251                             {SPV_OPERAND_TYPE_ID, {var->result_id()}}}));
252     // Copy memory access attributes which start at index 1. Index 0 is the
253     // pointer to load.
254     for (uint32_t i = 1; i < load->NumInOperands(); ++i) {
255       Operand copy(load->GetInOperand(i));
256       newLoad->AddOperand(std::move(copy));
257     }
258     where = where.InsertBefore(std::move(newLoad));
259     get_def_use_mgr()->AnalyzeInstDefUse(&*where);
260     context()->set_instr_block(&*where, block);
261     where->UpdateDebugInfoFrom(load);
262     loads.push_back(&*where);
263   }
264 
265   // Construct a new composite.
266   uint32_t compositeId = TakeNextId();
267   if (compositeId == 0) {
268     return false;
269   }
270   where = load;
271   std::unique_ptr<Instruction> compositeConstruct(
272       new Instruction(context(), spv::Op::OpCompositeConstruct, load->type_id(),
273                       compositeId, {}));
274   for (auto l : loads) {
275     Operand op(SPV_OPERAND_TYPE_ID,
276                std::initializer_list<uint32_t>{l->result_id()});
277     compositeConstruct->AddOperand(std::move(op));
278   }
279   where = where.InsertBefore(std::move(compositeConstruct));
280   get_def_use_mgr()->AnalyzeInstDefUse(&*where);
281   where->UpdateDebugInfoFrom(load);
282   context()->set_instr_block(&*where, block);
283   context()->ReplaceAllUsesWith(load->result_id(), compositeId);
284   return true;
285 }
286 
ReplaceWholeStore(Instruction * store,const std::vector<Instruction * > & replacements)287 bool ScalarReplacementPass::ReplaceWholeStore(
288     Instruction* store, const std::vector<Instruction*>& replacements) {
289   // Replaces a store to the whole composite with a series of extract and stores
290   // to each element.
291   uint32_t storeInput = store->GetSingleWordInOperand(1u);
292   BasicBlock* block = context()->get_instr_block(store);
293   BasicBlock::iterator where(store);
294   uint32_t elementIndex = 0;
295   for (auto var : replacements) {
296     // Create the extract.
297     if (var->opcode() != spv::Op::OpVariable) {
298       elementIndex++;
299       continue;
300     }
301 
302     Instruction* type = GetStorageType(var);
303     uint32_t extractId = TakeNextId();
304     if (extractId == 0) {
305       return false;
306     }
307     std::unique_ptr<Instruction> extract(new Instruction(
308         context(), spv::Op::OpCompositeExtract, type->result_id(), extractId,
309         std::initializer_list<Operand>{
310             {SPV_OPERAND_TYPE_ID, {storeInput}},
311             {SPV_OPERAND_TYPE_LITERAL_INTEGER, {elementIndex++}}}));
312     auto iter = where.InsertBefore(std::move(extract));
313     iter->UpdateDebugInfoFrom(store);
314     get_def_use_mgr()->AnalyzeInstDefUse(&*iter);
315     context()->set_instr_block(&*iter, block);
316 
317     // Create the store.
318     std::unique_ptr<Instruction> newStore(
319         new Instruction(context(), spv::Op::OpStore, 0, 0,
320                         std::initializer_list<Operand>{
321                             {SPV_OPERAND_TYPE_ID, {var->result_id()}},
322                             {SPV_OPERAND_TYPE_ID, {extractId}}}));
323     // Copy memory access attributes which start at index 2. Index 0 is the
324     // pointer and index 1 is the data.
325     for (uint32_t i = 2; i < store->NumInOperands(); ++i) {
326       Operand copy(store->GetInOperand(i));
327       newStore->AddOperand(std::move(copy));
328     }
329     iter = where.InsertBefore(std::move(newStore));
330     iter->UpdateDebugInfoFrom(store);
331     get_def_use_mgr()->AnalyzeInstDefUse(&*iter);
332     context()->set_instr_block(&*iter, block);
333   }
334   return true;
335 }
336 
ReplaceAccessChain(Instruction * chain,const std::vector<Instruction * > & replacements)337 bool ScalarReplacementPass::ReplaceAccessChain(
338     Instruction* chain, const std::vector<Instruction*>& replacements) {
339   // Replaces the access chain with either another access chain (with one fewer
340   // indexes) or a direct use of the replacement variable.
341   uint32_t indexId = chain->GetSingleWordInOperand(1u);
342   const Instruction* index = get_def_use_mgr()->GetDef(indexId);
343   int64_t indexValue = context()
344                            ->get_constant_mgr()
345                            ->GetConstantFromInst(index)
346                            ->GetSignExtendedValue();
347   if (indexValue < 0 ||
348       indexValue >= static_cast<int64_t>(replacements.size())) {
349     // Out of bounds access, this is illegal IR.  Notice that OpAccessChain
350     // indexing is 0-based, so we should also reject index == size-of-array.
351     return false;
352   } else {
353     const Instruction* var = replacements[static_cast<size_t>(indexValue)];
354     if (chain->NumInOperands() > 2) {
355       // Replace input access chain with another access chain.
356       BasicBlock::iterator chainIter(chain);
357       uint32_t replacementId = TakeNextId();
358       if (replacementId == 0) {
359         return false;
360       }
361       std::unique_ptr<Instruction> replacementChain(new Instruction(
362           context(), chain->opcode(), chain->type_id(), replacementId,
363           std::initializer_list<Operand>{
364               {SPV_OPERAND_TYPE_ID, {var->result_id()}}}));
365       // Add the remaining indexes.
366       for (uint32_t i = 2; i < chain->NumInOperands(); ++i) {
367         Operand copy(chain->GetInOperand(i));
368         replacementChain->AddOperand(std::move(copy));
369       }
370       replacementChain->UpdateDebugInfoFrom(chain);
371       auto iter = chainIter.InsertBefore(std::move(replacementChain));
372       get_def_use_mgr()->AnalyzeInstDefUse(&*iter);
373       context()->set_instr_block(&*iter, context()->get_instr_block(chain));
374       context()->ReplaceAllUsesWith(chain->result_id(), replacementId);
375     } else {
376       // Replace with a use of the variable.
377       context()->ReplaceAllUsesWith(chain->result_id(), var->result_id());
378     }
379   }
380 
381   return true;
382 }
383 
CreateReplacementVariables(Instruction * inst,std::vector<Instruction * > * replacements)384 bool ScalarReplacementPass::CreateReplacementVariables(
385     Instruction* inst, std::vector<Instruction*>* replacements) {
386   Instruction* type = GetStorageType(inst);
387 
388   std::unique_ptr<std::unordered_set<int64_t>> components_used =
389       GetUsedComponents(inst);
390 
391   uint32_t elem = 0;
392   switch (type->opcode()) {
393     case spv::Op::OpTypeStruct:
394       type->ForEachInOperand(
395           [this, inst, &elem, replacements, &components_used](uint32_t* id) {
396             if (!components_used || components_used->count(elem)) {
397               CreateVariable(*id, inst, elem, replacements);
398             } else {
399               replacements->push_back(GetUndef(*id));
400             }
401             elem++;
402           });
403       break;
404     case spv::Op::OpTypeArray:
405       for (uint32_t i = 0; i != GetArrayLength(type); ++i) {
406         if (!components_used || components_used->count(i)) {
407           CreateVariable(type->GetSingleWordInOperand(0u), inst, i,
408                          replacements);
409         } else {
410           uint32_t element_type_id = type->GetSingleWordInOperand(0);
411           replacements->push_back(GetUndef(element_type_id));
412         }
413       }
414       break;
415 
416     case spv::Op::OpTypeMatrix:
417     case spv::Op::OpTypeVector:
418       for (uint32_t i = 0; i != GetNumElements(type); ++i) {
419         CreateVariable(type->GetSingleWordInOperand(0u), inst, i, replacements);
420       }
421       break;
422 
423     default:
424       assert(false && "Unexpected type.");
425       break;
426   }
427 
428   TransferAnnotations(inst, replacements);
429   return std::find(replacements->begin(), replacements->end(), nullptr) ==
430          replacements->end();
431 }
432 
GetUndef(uint32_t type_id)433 Instruction* ScalarReplacementPass::GetUndef(uint32_t type_id) {
434   return get_def_use_mgr()->GetDef(Type2Undef(type_id));
435 }
436 
TransferAnnotations(const Instruction * source,std::vector<Instruction * > * replacements)437 void ScalarReplacementPass::TransferAnnotations(
438     const Instruction* source, std::vector<Instruction*>* replacements) {
439   // Only transfer invariant and restrict decorations on the variable. There are
440   // no type or member decorations that are necessary to transfer.
441   for (auto inst :
442        get_decoration_mgr()->GetDecorationsFor(source->result_id(), false)) {
443     assert(inst->opcode() == spv::Op::OpDecorate);
444     auto decoration = spv::Decoration(inst->GetSingleWordInOperand(1u));
445     if (decoration == spv::Decoration::Invariant ||
446         decoration == spv::Decoration::Restrict) {
447       for (auto var : *replacements) {
448         if (var == nullptr) {
449           continue;
450         }
451 
452         std::unique_ptr<Instruction> annotation(new Instruction(
453             context(), spv::Op::OpDecorate, 0, 0,
454             std::initializer_list<Operand>{
455                 {SPV_OPERAND_TYPE_ID, {var->result_id()}},
456                 {SPV_OPERAND_TYPE_DECORATION, {uint32_t(decoration)}}}));
457         for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
458           Operand copy(inst->GetInOperand(i));
459           annotation->AddOperand(std::move(copy));
460         }
461         context()->AddAnnotationInst(std::move(annotation));
462         get_def_use_mgr()->AnalyzeInstUse(&*--context()->annotation_end());
463       }
464     }
465   }
466 }
467 
CreateVariable(uint32_t type_id,Instruction * var_inst,uint32_t index,std::vector<Instruction * > * replacements)468 void ScalarReplacementPass::CreateVariable(
469     uint32_t type_id, Instruction* var_inst, uint32_t index,
470     std::vector<Instruction*>* replacements) {
471   uint32_t ptr_id = GetOrCreatePointerType(type_id);
472   uint32_t id = TakeNextId();
473 
474   if (id == 0) {
475     replacements->push_back(nullptr);
476   }
477 
478   std::unique_ptr<Instruction> variable(
479       new Instruction(context(), spv::Op::OpVariable, ptr_id, id,
480                       std::initializer_list<Operand>{
481                           {SPV_OPERAND_TYPE_STORAGE_CLASS,
482                            {uint32_t(spv::StorageClass::Function)}}}));
483 
484   BasicBlock* block = context()->get_instr_block(var_inst);
485   block->begin().InsertBefore(std::move(variable));
486   Instruction* inst = &*block->begin();
487 
488   // If varInst was initialized, make sure to initialize its replacement.
489   GetOrCreateInitialValue(var_inst, index, inst);
490   get_def_use_mgr()->AnalyzeInstDefUse(inst);
491   context()->set_instr_block(inst, block);
492 
493   CopyDecorationsToVariable(var_inst, inst, index);
494   inst->UpdateDebugInfoFrom(var_inst);
495 
496   replacements->push_back(inst);
497 }
498 
GetOrCreatePointerType(uint32_t id)499 uint32_t ScalarReplacementPass::GetOrCreatePointerType(uint32_t id) {
500   auto iter = pointee_to_pointer_.find(id);
501   if (iter != pointee_to_pointer_.end()) return iter->second;
502 
503   analysis::TypeManager* type_mgr = context()->get_type_mgr();
504   uint32_t ptr_type_id =
505       type_mgr->FindPointerToType(id, spv::StorageClass::Function);
506   pointee_to_pointer_[id] = ptr_type_id;
507   return ptr_type_id;
508 }
509 
GetOrCreateInitialValue(Instruction * source,uint32_t index,Instruction * newVar)510 void ScalarReplacementPass::GetOrCreateInitialValue(Instruction* source,
511                                                     uint32_t index,
512                                                     Instruction* newVar) {
513   assert(source->opcode() == spv::Op::OpVariable);
514   if (source->NumInOperands() < 2) return;
515 
516   uint32_t initId = source->GetSingleWordInOperand(1u);
517   uint32_t storageId = GetStorageType(newVar)->result_id();
518   Instruction* init = get_def_use_mgr()->GetDef(initId);
519   uint32_t newInitId = 0;
520   // TODO(dnovillo): Refactor this with constant propagation.
521   if (init->opcode() == spv::Op::OpConstantNull) {
522     // Initialize to appropriate NULL.
523     auto iter = type_to_null_.find(storageId);
524     if (iter == type_to_null_.end()) {
525       newInitId = TakeNextId();
526       type_to_null_[storageId] = newInitId;
527       context()->AddGlobalValue(
528           MakeUnique<Instruction>(context(), spv::Op::OpConstantNull, storageId,
529                                   newInitId, std::initializer_list<Operand>{}));
530       Instruction* newNull = &*--context()->types_values_end();
531       get_def_use_mgr()->AnalyzeInstDefUse(newNull);
532     } else {
533       newInitId = iter->second;
534     }
535   } else if (IsSpecConstantInst(init->opcode())) {
536     // Create a new constant extract.
537     newInitId = TakeNextId();
538     context()->AddGlobalValue(MakeUnique<Instruction>(
539         context(), spv::Op::OpSpecConstantOp, storageId, newInitId,
540         std::initializer_list<Operand>{
541             {SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER,
542              {uint32_t(spv::Op::OpCompositeExtract)}},
543             {SPV_OPERAND_TYPE_ID, {init->result_id()}},
544             {SPV_OPERAND_TYPE_LITERAL_INTEGER, {index}}}));
545     Instruction* newSpecConst = &*--context()->types_values_end();
546     get_def_use_mgr()->AnalyzeInstDefUse(newSpecConst);
547   } else if (init->opcode() == spv::Op::OpConstantComposite) {
548     // Get the appropriate index constant.
549     newInitId = init->GetSingleWordInOperand(index);
550     Instruction* element = get_def_use_mgr()->GetDef(newInitId);
551     if (element->opcode() == spv::Op::OpUndef) {
552       // Undef is not a valid initializer for a variable.
553       newInitId = 0;
554     }
555   } else {
556     assert(false);
557   }
558 
559   if (newInitId != 0) {
560     newVar->AddOperand({SPV_OPERAND_TYPE_ID, {newInitId}});
561   }
562 }
563 
GetArrayLength(const Instruction * arrayType) const564 uint64_t ScalarReplacementPass::GetArrayLength(
565     const Instruction* arrayType) const {
566   assert(arrayType->opcode() == spv::Op::OpTypeArray);
567   const Instruction* length =
568       get_def_use_mgr()->GetDef(arrayType->GetSingleWordInOperand(1u));
569   return context()
570       ->get_constant_mgr()
571       ->GetConstantFromInst(length)
572       ->GetZeroExtendedValue();
573 }
574 
GetNumElements(const Instruction * type) const575 uint64_t ScalarReplacementPass::GetNumElements(const Instruction* type) const {
576   assert(type->opcode() == spv::Op::OpTypeVector ||
577          type->opcode() == spv::Op::OpTypeMatrix);
578   const Operand& op = type->GetInOperand(1u);
579   assert(op.words.size() <= 2);
580   uint64_t len = 0;
581   for (size_t i = 0; i != op.words.size(); ++i) {
582     len |= (static_cast<uint64_t>(op.words[i]) << (32ull * i));
583   }
584   return len;
585 }
586 
IsSpecConstant(uint32_t id) const587 bool ScalarReplacementPass::IsSpecConstant(uint32_t id) const {
588   const Instruction* inst = get_def_use_mgr()->GetDef(id);
589   assert(inst);
590   return spvOpcodeIsSpecConstant(inst->opcode());
591 }
592 
// Returns the type instruction referenced by in-operand 1 of the pointer type
// of the OpVariable |inst|, i.e. the type of the storage the variable
// provides.
Instruction* ScalarReplacementPass::GetStorageType(
    const Instruction* inst) const {
  assert(inst->opcode() == spv::Op::OpVariable);

  const Instruction* pointer_type = get_def_use_mgr()->GetDef(inst->type_id());
  const uint32_t pointee_type_id = pointer_type->GetSingleWordInOperand(1u);
  return get_def_use_mgr()->GetDef(pointee_type_id);
}
602 
CanReplaceVariable(const Instruction * varInst) const603 bool ScalarReplacementPass::CanReplaceVariable(
604     const Instruction* varInst) const {
605   assert(varInst->opcode() == spv::Op::OpVariable);
606 
607   // Can only replace function scope variables.
608   if (spv::StorageClass(varInst->GetSingleWordInOperand(0u)) !=
609       spv::StorageClass::Function) {
610     return false;
611   }
612 
613   if (!CheckTypeAnnotations(get_def_use_mgr()->GetDef(varInst->type_id()))) {
614     return false;
615   }
616 
617   const Instruction* typeInst = GetStorageType(varInst);
618   if (!CheckType(typeInst)) {
619     return false;
620   }
621 
622   if (!CheckAnnotations(varInst)) {
623     return false;
624   }
625 
626   if (!CheckUses(varInst)) {
627     return false;
628   }
629 
630   return true;
631 }
632 
CheckType(const Instruction * typeInst) const633 bool ScalarReplacementPass::CheckType(const Instruction* typeInst) const {
634   if (!CheckTypeAnnotations(typeInst)) {
635     return false;
636   }
637 
638   switch (typeInst->opcode()) {
639     case spv::Op::OpTypeStruct:
640       // Don't bother with empty structs or very large structs.
641       if (typeInst->NumInOperands() == 0 ||
642           IsLargerThanSizeLimit(typeInst->NumInOperands())) {
643         return false;
644       }
645       return true;
646     case spv::Op::OpTypeArray:
647       if (IsSpecConstant(typeInst->GetSingleWordInOperand(1u))) {
648         return false;
649       }
650       if (IsLargerThanSizeLimit(GetArrayLength(typeInst))) {
651         return false;
652       }
653       return true;
654       // TODO(alanbaker): Develop some heuristics for when this should be
655       // re-enabled.
656       //// Specifically including matrix and vector in an attempt to reduce the
657       //// number of vector registers required.
658       // case spv::Op::OpTypeMatrix:
659       // case spv::Op::OpTypeVector:
660       //  if (IsLargerThanSizeLimit(GetNumElements(typeInst))) return false;
661       //  return true;
662 
663     case spv::Op::OpTypeRuntimeArray:
664     default:
665       return false;
666   }
667 }
668 
CheckTypeAnnotations(const Instruction * typeInst) const669 bool ScalarReplacementPass::CheckTypeAnnotations(
670     const Instruction* typeInst) const {
671   for (auto inst :
672        get_decoration_mgr()->GetDecorationsFor(typeInst->result_id(), false)) {
673     uint32_t decoration;
674     if (inst->opcode() == spv::Op::OpDecorate) {
675       decoration = inst->GetSingleWordInOperand(1u);
676     } else {
677       assert(inst->opcode() == spv::Op::OpMemberDecorate);
678       decoration = inst->GetSingleWordInOperand(2u);
679     }
680 
681     switch (spv::Decoration(decoration)) {
682       case spv::Decoration::RowMajor:
683       case spv::Decoration::ColMajor:
684       case spv::Decoration::ArrayStride:
685       case spv::Decoration::MatrixStride:
686       case spv::Decoration::CPacked:
687       case spv::Decoration::Invariant:
688       case spv::Decoration::Restrict:
689       case spv::Decoration::Offset:
690       case spv::Decoration::Alignment:
691       case spv::Decoration::AlignmentId:
692       case spv::Decoration::MaxByteOffset:
693       case spv::Decoration::RelaxedPrecision:
694       case spv::Decoration::AliasedPointer:
695       case spv::Decoration::RestrictPointer:
696         break;
697       default:
698         return false;
699     }
700   }
701 
702   return true;
703 }
704 
CheckAnnotations(const Instruction * varInst) const705 bool ScalarReplacementPass::CheckAnnotations(const Instruction* varInst) const {
706   for (auto inst :
707        get_decoration_mgr()->GetDecorationsFor(varInst->result_id(), false)) {
708     assert(inst->opcode() == spv::Op::OpDecorate);
709     auto decoration = spv::Decoration(inst->GetSingleWordInOperand(1u));
710     switch (decoration) {
711       case spv::Decoration::Invariant:
712       case spv::Decoration::Restrict:
713       case spv::Decoration::Alignment:
714       case spv::Decoration::AlignmentId:
715       case spv::Decoration::MaxByteOffset:
716       case spv::Decoration::AliasedPointer:
717       case spv::Decoration::RestrictPointer:
718         break;
719       default:
720         return false;
721     }
722   }
723 
724   return true;
725 }
726 
CheckUses(const Instruction * inst) const727 bool ScalarReplacementPass::CheckUses(const Instruction* inst) const {
728   VariableStats stats = {0, 0};
729   bool ok = CheckUses(inst, &stats);
730 
731   // TODO(alanbaker/greg-lunarg): Add some meaningful heuristics about when
732   // SRoA is costly, such as when the structure has many (unaccessed?)
733   // members.
734 
735   return ok;
736 }
737 
CheckUses(const Instruction * inst,VariableStats * stats) const738 bool ScalarReplacementPass::CheckUses(const Instruction* inst,
739                                       VariableStats* stats) const {
740   uint64_t max_legal_index = GetMaxLegalIndex(inst);
741 
742   bool ok = true;
743   get_def_use_mgr()->ForEachUse(inst, [this, max_legal_index, stats, &ok](
744                                           const Instruction* user,
745                                           uint32_t index) {
746     if (user->GetCommonDebugOpcode() == CommonDebugInfoDebugDeclare ||
747         user->GetCommonDebugOpcode() == CommonDebugInfoDebugValue) {
748       // TODO: include num_partial_accesses if it uses Fragment operation or
749       // DebugValue has Indexes operand.
750       stats->num_full_accesses++;
751       return;
752     }
753 
754     // Annotations are check as a group separately.
755     if (!IsAnnotationInst(user->opcode())) {
756       switch (user->opcode()) {
757         case spv::Op::OpAccessChain:
758         case spv::Op::OpInBoundsAccessChain:
759           if (index == 2u && user->NumInOperands() > 1) {
760             uint32_t id = user->GetSingleWordInOperand(1u);
761             const Instruction* opInst = get_def_use_mgr()->GetDef(id);
762             const auto* constant =
763                 context()->get_constant_mgr()->GetConstantFromInst(opInst);
764             if (!constant) {
765               ok = false;
766             } else if (constant->GetZeroExtendedValue() >= max_legal_index) {
767               ok = false;
768             } else {
769               if (!CheckUsesRelaxed(user)) ok = false;
770             }
771             stats->num_partial_accesses++;
772           } else {
773             ok = false;
774           }
775           break;
776         case spv::Op::OpLoad:
777           if (!CheckLoad(user, index)) ok = false;
778           stats->num_full_accesses++;
779           break;
780         case spv::Op::OpStore:
781           if (!CheckStore(user, index)) ok = false;
782           stats->num_full_accesses++;
783           break;
784         case spv::Op::OpName:
785         case spv::Op::OpMemberName:
786           break;
787         default:
788           ok = false;
789           break;
790       }
791     }
792   });
793 
794   return ok;
795 }
796 
CheckUsesRelaxed(const Instruction * inst) const797 bool ScalarReplacementPass::CheckUsesRelaxed(const Instruction* inst) const {
798   bool ok = true;
799   get_def_use_mgr()->ForEachUse(
800       inst, [this, &ok](const Instruction* user, uint32_t index) {
801         switch (user->opcode()) {
802           case spv::Op::OpAccessChain:
803           case spv::Op::OpInBoundsAccessChain:
804             if (index != 2u) {
805               ok = false;
806             } else {
807               if (!CheckUsesRelaxed(user)) ok = false;
808             }
809             break;
810           case spv::Op::OpLoad:
811             if (!CheckLoad(user, index)) ok = false;
812             break;
813           case spv::Op::OpStore:
814             if (!CheckStore(user, index)) ok = false;
815             break;
816           case spv::Op::OpImageTexelPointer:
817             if (!CheckImageTexelPointer(index)) ok = false;
818             break;
819           case spv::Op::OpExtInst:
820             if (user->GetCommonDebugOpcode() != CommonDebugInfoDebugDeclare ||
821                 !CheckDebugDeclare(index))
822               ok = false;
823             break;
824           default:
825             ok = false;
826             break;
827         }
828       });
829 
830   return ok;
831 }
832 
CheckImageTexelPointer(uint32_t index) const833 bool ScalarReplacementPass::CheckImageTexelPointer(uint32_t index) const {
834   return index == 2u;
835 }
836 
CheckLoad(const Instruction * inst,uint32_t index) const837 bool ScalarReplacementPass::CheckLoad(const Instruction* inst,
838                                       uint32_t index) const {
839   if (index != 2u) return false;
840   if (inst->NumInOperands() >= 2 &&
841       inst->GetSingleWordInOperand(1u) &
842           uint32_t(spv::MemoryAccessMask::Volatile))
843     return false;
844   return true;
845 }
846 
CheckStore(const Instruction * inst,uint32_t index) const847 bool ScalarReplacementPass::CheckStore(const Instruction* inst,
848                                        uint32_t index) const {
849   if (index != 0u) return false;
850   if (inst->NumInOperands() >= 3 &&
851       inst->GetSingleWordInOperand(2u) &
852           uint32_t(spv::MemoryAccessMask::Volatile))
853     return false;
854   return true;
855 }
856 
CheckDebugDeclare(uint32_t index) const857 bool ScalarReplacementPass::CheckDebugDeclare(uint32_t index) const {
858   if (index != kDebugDeclareOperandVariableIndex) return false;
859   return true;
860 }
861 
IsLargerThanSizeLimit(uint64_t length) const862 bool ScalarReplacementPass::IsLargerThanSizeLimit(uint64_t length) const {
863   if (max_num_elements_ == 0) {
864     return false;
865   }
866   return length > max_num_elements_;
867 }
868 
869 std::unique_ptr<std::unordered_set<int64_t>>
GetUsedComponents(Instruction * inst)870 ScalarReplacementPass::GetUsedComponents(Instruction* inst) {
871   std::unique_ptr<std::unordered_set<int64_t>> result(
872       new std::unordered_set<int64_t>());
873 
874   analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
875 
876   def_use_mgr->WhileEachUser(inst, [&result, def_use_mgr,
877                                     this](Instruction* use) {
878     switch (use->opcode()) {
879       case spv::Op::OpLoad: {
880         // Look for extract from the load.
881         std::vector<uint32_t> t;
882         if (def_use_mgr->WhileEachUser(use, [&t](Instruction* use2) {
883               if (use2->opcode() != spv::Op::OpCompositeExtract ||
884                   use2->NumInOperands() <= 1) {
885                 return false;
886               }
887               t.push_back(use2->GetSingleWordInOperand(1));
888               return true;
889             })) {
890           result->insert(t.begin(), t.end());
891           return true;
892         } else {
893           result.reset(nullptr);
894           return false;
895         }
896       }
897       case spv::Op::OpName:
898       case spv::Op::OpMemberName:
899       case spv::Op::OpStore:
900         // No components are used.
901         return true;
902       case spv::Op::OpAccessChain:
903       case spv::Op::OpInBoundsAccessChain: {
904         // Add the first index it if is a constant.
905         // TODO: Could be improved by checking if the address is used in a load.
906         analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
907         uint32_t index_id = use->GetSingleWordInOperand(1);
908         const analysis::Constant* index_const =
909             const_mgr->FindDeclaredConstant(index_id);
910         if (index_const) {
911           result->insert(index_const->GetSignExtendedValue());
912           return true;
913         } else {
914           // Could be any element.  Assuming all are used.
915           result.reset(nullptr);
916           return false;
917         }
918       }
919       default:
920         // We do not know what is happening.  Have to assume the worst.
921         result.reset(nullptr);
922         return false;
923     }
924   });
925 
926   return result;
927 }
928 
GetMaxLegalIndex(const Instruction * var_inst) const929 uint64_t ScalarReplacementPass::GetMaxLegalIndex(
930     const Instruction* var_inst) const {
931   assert(var_inst->opcode() == spv::Op::OpVariable &&
932          "|var_inst| must be a variable instruction.");
933   Instruction* type = GetStorageType(var_inst);
934   switch (type->opcode()) {
935     case spv::Op::OpTypeStruct:
936       return type->NumInOperands();
937     case spv::Op::OpTypeArray:
938       return GetArrayLength(type);
939     case spv::Op::OpTypeMatrix:
940     case spv::Op::OpTypeVector:
941       return GetNumElements(type);
942     default:
943       return 0;
944   }
945   return 0;
946 }
947 
CopyDecorationsToVariable(Instruction * from,Instruction * to,uint32_t member_index)948 void ScalarReplacementPass::CopyDecorationsToVariable(Instruction* from,
949                                                       Instruction* to,
950                                                       uint32_t member_index) {
951   CopyPointerDecorationsToVariable(from, to);
952   CopyNecessaryMemberDecorationsToVariable(from, to, member_index);
953 }
954 
CopyPointerDecorationsToVariable(Instruction * from,Instruction * to)955 void ScalarReplacementPass::CopyPointerDecorationsToVariable(Instruction* from,
956                                                              Instruction* to) {
957   // The RestrictPointer and AliasedPointer decorations are copied to all
958   // members even if the new variable does not contain a pointer. It does
959   // not hurt to do so.
960   for (auto dec_inst :
961        get_decoration_mgr()->GetDecorationsFor(from->result_id(), false)) {
962     uint32_t decoration;
963     decoration = dec_inst->GetSingleWordInOperand(1u);
964     switch (spv::Decoration(decoration)) {
965       case spv::Decoration::AliasedPointer:
966       case spv::Decoration::RestrictPointer: {
967         std::unique_ptr<Instruction> new_dec_inst(dec_inst->Clone(context()));
968         new_dec_inst->SetInOperand(0, {to->result_id()});
969         context()->AddAnnotationInst(std::move(new_dec_inst));
970       } break;
971       default:
972         break;
973     }
974   }
975 }
976 
CopyNecessaryMemberDecorationsToVariable(Instruction * from,Instruction * to,uint32_t member_index)977 void ScalarReplacementPass::CopyNecessaryMemberDecorationsToVariable(
978     Instruction* from, Instruction* to, uint32_t member_index) {
979   Instruction* type_inst = GetStorageType(from);
980   for (auto dec_inst :
981        get_decoration_mgr()->GetDecorationsFor(type_inst->result_id(), false)) {
982     uint32_t decoration;
983     if (dec_inst->opcode() == spv::Op::OpMemberDecorate) {
984       if (dec_inst->GetSingleWordInOperand(1) != member_index) {
985         continue;
986       }
987 
988       decoration = dec_inst->GetSingleWordInOperand(2u);
989       switch (spv::Decoration(decoration)) {
990         case spv::Decoration::ArrayStride:
991         case spv::Decoration::Alignment:
992         case spv::Decoration::AlignmentId:
993         case spv::Decoration::MaxByteOffset:
994         case spv::Decoration::MaxByteOffsetId:
995         case spv::Decoration::RelaxedPrecision: {
996           std::unique_ptr<Instruction> new_dec_inst(
997               new Instruction(context(), spv::Op::OpDecorate, 0, 0, {}));
998           new_dec_inst->AddOperand(
999               Operand(SPV_OPERAND_TYPE_ID, {to->result_id()}));
1000           for (uint32_t i = 2; i < dec_inst->NumInOperandWords(); ++i) {
1001             new_dec_inst->AddOperand(Operand(dec_inst->GetInOperand(i)));
1002           }
1003           context()->AddAnnotationInst(std::move(new_dec_inst));
1004         } break;
1005         default:
1006           break;
1007       }
1008     }
1009   }
1010 }
1011 
1012 }  // namespace opt
1013 }  // namespace spvtools
1014