• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/opt/scalar_replacement_pass.h"
16 
17 #include <algorithm>
18 #include <queue>
19 #include <tuple>
20 #include <utility>
21 
22 #include "source/enum_string_mapping.h"
23 #include "source/extensions.h"
24 #include "source/opt/reflect.h"
25 #include "source/opt/types.h"
26 #include "source/util/make_unique.h"
27 
// Operand positions (counted over all operands, as used with
// Get/SetOperand) inside OpenCL.DebugInfo.100 DebugValue/DebugDeclare
// instructions: the value/variable operand and the expression operand.
static constexpr uint32_t kDebugValueOperandValueIndex = 5;
static constexpr uint32_t kDebugValueOperandExpressionIndex = 6;
30 
31 namespace spvtools {
32 namespace opt {
33 
Process()34 Pass::Status ScalarReplacementPass::Process() {
35   Status status = Status::SuccessWithoutChange;
36   for (auto& f : *get_module()) {
37     Status functionStatus = ProcessFunction(&f);
38     if (functionStatus == Status::Failure)
39       return functionStatus;
40     else if (functionStatus == Status::SuccessWithChange)
41       status = functionStatus;
42   }
43 
44   return status;
45 }
46 
ProcessFunction(Function * function)47 Pass::Status ScalarReplacementPass::ProcessFunction(Function* function) {
48   std::queue<Instruction*> worklist;
49   BasicBlock& entry = *function->begin();
50   for (auto iter = entry.begin(); iter != entry.end(); ++iter) {
51     // Function storage class OpVariables must appear as the first instructions
52     // of the entry block.
53     if (iter->opcode() != SpvOpVariable) break;
54 
55     Instruction* varInst = &*iter;
56     if (CanReplaceVariable(varInst)) {
57       worklist.push(varInst);
58     }
59   }
60 
61   Status status = Status::SuccessWithoutChange;
62   while (!worklist.empty()) {
63     Instruction* varInst = worklist.front();
64     worklist.pop();
65 
66     Status var_status = ReplaceVariable(varInst, &worklist);
67     if (var_status == Status::Failure)
68       return var_status;
69     else if (var_status == Status::SuccessWithChange)
70       status = var_status;
71   }
72 
73   return status;
74 }
75 
ReplaceVariable(Instruction * inst,std::queue<Instruction * > * worklist)76 Pass::Status ScalarReplacementPass::ReplaceVariable(
77     Instruction* inst, std::queue<Instruction*>* worklist) {
78   std::vector<Instruction*> replacements;
79   if (!CreateReplacementVariables(inst, &replacements)) {
80     return Status::Failure;
81   }
82 
83   std::vector<Instruction*> dead;
84   bool replaced_all_uses = get_def_use_mgr()->WhileEachUser(
85       inst, [this, &replacements, &dead](Instruction* user) {
86         if (user->GetOpenCL100DebugOpcode() == OpenCLDebugInfo100DebugDeclare) {
87           if (ReplaceWholeDebugDeclare(user, replacements)) {
88             dead.push_back(user);
89             return true;
90           }
91           return false;
92         }
93         if (user->GetOpenCL100DebugOpcode() == OpenCLDebugInfo100DebugValue) {
94           if (ReplaceWholeDebugValue(user, replacements)) {
95             dead.push_back(user);
96             return true;
97           }
98           return false;
99         }
100         if (!IsAnnotationInst(user->opcode())) {
101           switch (user->opcode()) {
102             case SpvOpLoad:
103               if (ReplaceWholeLoad(user, replacements)) {
104                 dead.push_back(user);
105               } else {
106                 return false;
107               }
108               break;
109             case SpvOpStore:
110               if (ReplaceWholeStore(user, replacements)) {
111                 dead.push_back(user);
112               } else {
113                 return false;
114               }
115               break;
116             case SpvOpAccessChain:
117             case SpvOpInBoundsAccessChain:
118               if (ReplaceAccessChain(user, replacements))
119                 dead.push_back(user);
120               else
121                 return false;
122               break;
123             case SpvOpName:
124             case SpvOpMemberName:
125               break;
126             default:
127               assert(false && "Unexpected opcode");
128               break;
129           }
130         }
131         return true;
132       });
133 
134   if (replaced_all_uses) {
135     dead.push_back(inst);
136   } else {
137     return Status::Failure;
138   }
139 
140   // If there are no dead instructions to clean up, return with no changes.
141   if (dead.empty()) return Status::SuccessWithoutChange;
142 
143   // Clean up some dead code.
144   while (!dead.empty()) {
145     Instruction* toKill = dead.back();
146     dead.pop_back();
147     context()->KillInst(toKill);
148   }
149 
150   // Attempt to further scalarize.
151   for (auto var : replacements) {
152     if (var->opcode() == SpvOpVariable) {
153       if (get_def_use_mgr()->NumUsers(var) == 0) {
154         context()->KillInst(var);
155       } else if (CanReplaceVariable(var)) {
156         worklist->push(var);
157       }
158     }
159   }
160 
161   return Status::SuccessWithChange;
162 }
163 
ReplaceWholeDebugDeclare(Instruction * dbg_decl,const std::vector<Instruction * > & replacements)164 bool ScalarReplacementPass::ReplaceWholeDebugDeclare(
165     Instruction* dbg_decl, const std::vector<Instruction*>& replacements) {
166   // Insert Deref operation to the front of the operation list of |dbg_decl|.
167   Instruction* dbg_expr = context()->get_def_use_mgr()->GetDef(
168       dbg_decl->GetSingleWordOperand(kDebugValueOperandExpressionIndex));
169   auto* deref_expr =
170       context()->get_debug_info_mgr()->DerefDebugExpression(dbg_expr);
171 
172   // Add DebugValue instruction with Indexes operand and Deref operation.
173   int32_t idx = 0;
174   for (const auto* var : replacements) {
175     Instruction* added_dbg_value =
176         context()->get_debug_info_mgr()->AddDebugValueForDecl(
177             dbg_decl, /*value_id=*/var->result_id(),
178             /*insert_before=*/var->NextNode(), /*scope_and_line=*/dbg_decl);
179 
180     if (added_dbg_value == nullptr) return false;
181     added_dbg_value->AddOperand(
182         {SPV_OPERAND_TYPE_ID,
183          {context()->get_constant_mgr()->GetSIntConst(idx)}});
184     added_dbg_value->SetOperand(kDebugValueOperandExpressionIndex,
185                                 {deref_expr->result_id()});
186     if (context()->AreAnalysesValid(IRContext::Analysis::kAnalysisDefUse)) {
187       context()->get_def_use_mgr()->AnalyzeInstUse(added_dbg_value);
188     }
189     ++idx;
190   }
191   return true;
192 }
193 
ReplaceWholeDebugValue(Instruction * dbg_value,const std::vector<Instruction * > & replacements)194 bool ScalarReplacementPass::ReplaceWholeDebugValue(
195     Instruction* dbg_value, const std::vector<Instruction*>& replacements) {
196   int32_t idx = 0;
197   BasicBlock* block = context()->get_instr_block(dbg_value);
198   for (auto var : replacements) {
199     // Clone the DebugValue.
200     std::unique_ptr<Instruction> new_dbg_value(dbg_value->Clone(context()));
201     uint32_t new_id = TakeNextId();
202     if (new_id == 0) return false;
203     new_dbg_value->SetResultId(new_id);
204     // Update 'Value' operand to the |replacements|.
205     new_dbg_value->SetOperand(kDebugValueOperandValueIndex, {var->result_id()});
206     // Append 'Indexes' operand.
207     new_dbg_value->AddOperand(
208         {SPV_OPERAND_TYPE_ID,
209          {context()->get_constant_mgr()->GetSIntConst(idx)}});
210     // Insert the new DebugValue to the basic block.
211     auto* added_instr = dbg_value->InsertBefore(std::move(new_dbg_value));
212     get_def_use_mgr()->AnalyzeInstDefUse(added_instr);
213     context()->set_instr_block(added_instr, block);
214     ++idx;
215   }
216   return true;
217 }
218 
ReplaceWholeLoad(Instruction * load,const std::vector<Instruction * > & replacements)219 bool ScalarReplacementPass::ReplaceWholeLoad(
220     Instruction* load, const std::vector<Instruction*>& replacements) {
221   // Replaces the load of the entire composite with a load from each replacement
222   // variable followed by a composite construction.
223   BasicBlock* block = context()->get_instr_block(load);
224   std::vector<Instruction*> loads;
225   loads.reserve(replacements.size());
226   BasicBlock::iterator where(load);
227   for (auto var : replacements) {
228     // Create a load of each replacement variable.
229     if (var->opcode() != SpvOpVariable) {
230       loads.push_back(var);
231       continue;
232     }
233 
234     Instruction* type = GetStorageType(var);
235     uint32_t loadId = TakeNextId();
236     if (loadId == 0) {
237       return false;
238     }
239     std::unique_ptr<Instruction> newLoad(
240         new Instruction(context(), SpvOpLoad, type->result_id(), loadId,
241                         std::initializer_list<Operand>{
242                             {SPV_OPERAND_TYPE_ID, {var->result_id()}}}));
243     // Copy memory access attributes which start at index 1. Index 0 is the
244     // pointer to load.
245     for (uint32_t i = 1; i < load->NumInOperands(); ++i) {
246       Operand copy(load->GetInOperand(i));
247       newLoad->AddOperand(std::move(copy));
248     }
249     where = where.InsertBefore(std::move(newLoad));
250     get_def_use_mgr()->AnalyzeInstDefUse(&*where);
251     context()->set_instr_block(&*where, block);
252     where->UpdateDebugInfoFrom(load);
253     loads.push_back(&*where);
254   }
255 
256   // Construct a new composite.
257   uint32_t compositeId = TakeNextId();
258   if (compositeId == 0) {
259     return false;
260   }
261   where = load;
262   std::unique_ptr<Instruction> compositeConstruct(new Instruction(
263       context(), SpvOpCompositeConstruct, load->type_id(), compositeId, {}));
264   for (auto l : loads) {
265     Operand op(SPV_OPERAND_TYPE_ID,
266                std::initializer_list<uint32_t>{l->result_id()});
267     compositeConstruct->AddOperand(std::move(op));
268   }
269   where = where.InsertBefore(std::move(compositeConstruct));
270   get_def_use_mgr()->AnalyzeInstDefUse(&*where);
271   where->UpdateDebugInfoFrom(load);
272   context()->set_instr_block(&*where, block);
273   context()->ReplaceAllUsesWith(load->result_id(), compositeId);
274   return true;
275 }
276 
ReplaceWholeStore(Instruction * store,const std::vector<Instruction * > & replacements)277 bool ScalarReplacementPass::ReplaceWholeStore(
278     Instruction* store, const std::vector<Instruction*>& replacements) {
279   // Replaces a store to the whole composite with a series of extract and stores
280   // to each element.
281   uint32_t storeInput = store->GetSingleWordInOperand(1u);
282   BasicBlock* block = context()->get_instr_block(store);
283   BasicBlock::iterator where(store);
284   uint32_t elementIndex = 0;
285   for (auto var : replacements) {
286     // Create the extract.
287     if (var->opcode() != SpvOpVariable) {
288       elementIndex++;
289       continue;
290     }
291 
292     Instruction* type = GetStorageType(var);
293     uint32_t extractId = TakeNextId();
294     if (extractId == 0) {
295       return false;
296     }
297     std::unique_ptr<Instruction> extract(new Instruction(
298         context(), SpvOpCompositeExtract, type->result_id(), extractId,
299         std::initializer_list<Operand>{
300             {SPV_OPERAND_TYPE_ID, {storeInput}},
301             {SPV_OPERAND_TYPE_LITERAL_INTEGER, {elementIndex++}}}));
302     auto iter = where.InsertBefore(std::move(extract));
303     iter->UpdateDebugInfoFrom(store);
304     get_def_use_mgr()->AnalyzeInstDefUse(&*iter);
305     context()->set_instr_block(&*iter, block);
306 
307     // Create the store.
308     std::unique_ptr<Instruction> newStore(
309         new Instruction(context(), SpvOpStore, 0, 0,
310                         std::initializer_list<Operand>{
311                             {SPV_OPERAND_TYPE_ID, {var->result_id()}},
312                             {SPV_OPERAND_TYPE_ID, {extractId}}}));
313     // Copy memory access attributes which start at index 2. Index 0 is the
314     // pointer and index 1 is the data.
315     for (uint32_t i = 2; i < store->NumInOperands(); ++i) {
316       Operand copy(store->GetInOperand(i));
317       newStore->AddOperand(std::move(copy));
318     }
319     iter = where.InsertBefore(std::move(newStore));
320     iter->UpdateDebugInfoFrom(store);
321     get_def_use_mgr()->AnalyzeInstDefUse(&*iter);
322     context()->set_instr_block(&*iter, block);
323   }
324   return true;
325 }
326 
ReplaceAccessChain(Instruction * chain,const std::vector<Instruction * > & replacements)327 bool ScalarReplacementPass::ReplaceAccessChain(
328     Instruction* chain, const std::vector<Instruction*>& replacements) {
329   // Replaces the access chain with either another access chain (with one fewer
330   // indexes) or a direct use of the replacement variable.
331   uint32_t indexId = chain->GetSingleWordInOperand(1u);
332   const Instruction* index = get_def_use_mgr()->GetDef(indexId);
333   int64_t indexValue = context()
334                            ->get_constant_mgr()
335                            ->GetConstantFromInst(index)
336                            ->GetSignExtendedValue();
337   if (indexValue < 0 ||
338       indexValue >= static_cast<int64_t>(replacements.size())) {
339     // Out of bounds access, this is illegal IR.  Notice that OpAccessChain
340     // indexing is 0-based, so we should also reject index == size-of-array.
341     return false;
342   } else {
343     const Instruction* var = replacements[static_cast<size_t>(indexValue)];
344     if (chain->NumInOperands() > 2) {
345       // Replace input access chain with another access chain.
346       BasicBlock::iterator chainIter(chain);
347       uint32_t replacementId = TakeNextId();
348       if (replacementId == 0) {
349         return false;
350       }
351       std::unique_ptr<Instruction> replacementChain(new Instruction(
352           context(), chain->opcode(), chain->type_id(), replacementId,
353           std::initializer_list<Operand>{
354               {SPV_OPERAND_TYPE_ID, {var->result_id()}}}));
355       // Add the remaining indexes.
356       for (uint32_t i = 2; i < chain->NumInOperands(); ++i) {
357         Operand copy(chain->GetInOperand(i));
358         replacementChain->AddOperand(std::move(copy));
359       }
360       replacementChain->UpdateDebugInfoFrom(chain);
361       auto iter = chainIter.InsertBefore(std::move(replacementChain));
362       get_def_use_mgr()->AnalyzeInstDefUse(&*iter);
363       context()->set_instr_block(&*iter, context()->get_instr_block(chain));
364       context()->ReplaceAllUsesWith(chain->result_id(), replacementId);
365     } else {
366       // Replace with a use of the variable.
367       context()->ReplaceAllUsesWith(chain->result_id(), var->result_id());
368     }
369   }
370 
371   return true;
372 }
373 
CreateReplacementVariables(Instruction * inst,std::vector<Instruction * > * replacements)374 bool ScalarReplacementPass::CreateReplacementVariables(
375     Instruction* inst, std::vector<Instruction*>* replacements) {
376   Instruction* type = GetStorageType(inst);
377 
378   std::unique_ptr<std::unordered_set<int64_t>> components_used =
379       GetUsedComponents(inst);
380 
381   uint32_t elem = 0;
382   switch (type->opcode()) {
383     case SpvOpTypeStruct:
384       type->ForEachInOperand(
385           [this, inst, &elem, replacements, &components_used](uint32_t* id) {
386             if (!components_used || components_used->count(elem)) {
387               CreateVariable(*id, inst, elem, replacements);
388             } else {
389               replacements->push_back(CreateNullConstant(*id));
390             }
391             elem++;
392           });
393       break;
394     case SpvOpTypeArray:
395       for (uint32_t i = 0; i != GetArrayLength(type); ++i) {
396         if (!components_used || components_used->count(i)) {
397           CreateVariable(type->GetSingleWordInOperand(0u), inst, i,
398                          replacements);
399         } else {
400           replacements->push_back(
401               CreateNullConstant(type->GetSingleWordInOperand(0u)));
402         }
403       }
404       break;
405 
406     case SpvOpTypeMatrix:
407     case SpvOpTypeVector:
408       for (uint32_t i = 0; i != GetNumElements(type); ++i) {
409         CreateVariable(type->GetSingleWordInOperand(0u), inst, i, replacements);
410       }
411       break;
412 
413     default:
414       assert(false && "Unexpected type.");
415       break;
416   }
417 
418   TransferAnnotations(inst, replacements);
419   return std::find(replacements->begin(), replacements->end(), nullptr) ==
420          replacements->end();
421 }
422 
TransferAnnotations(const Instruction * source,std::vector<Instruction * > * replacements)423 void ScalarReplacementPass::TransferAnnotations(
424     const Instruction* source, std::vector<Instruction*>* replacements) {
425   // Only transfer invariant and restrict decorations on the variable. There are
426   // no type or member decorations that are necessary to transfer.
427   for (auto inst :
428        get_decoration_mgr()->GetDecorationsFor(source->result_id(), false)) {
429     assert(inst->opcode() == SpvOpDecorate);
430     uint32_t decoration = inst->GetSingleWordInOperand(1u);
431     if (decoration == SpvDecorationInvariant ||
432         decoration == SpvDecorationRestrict) {
433       for (auto var : *replacements) {
434         if (var == nullptr) {
435           continue;
436         }
437 
438         std::unique_ptr<Instruction> annotation(
439             new Instruction(context(), SpvOpDecorate, 0, 0,
440                             std::initializer_list<Operand>{
441                                 {SPV_OPERAND_TYPE_ID, {var->result_id()}},
442                                 {SPV_OPERAND_TYPE_DECORATION, {decoration}}}));
443         for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
444           Operand copy(inst->GetInOperand(i));
445           annotation->AddOperand(std::move(copy));
446         }
447         context()->AddAnnotationInst(std::move(annotation));
448         get_def_use_mgr()->AnalyzeInstUse(&*--context()->annotation_end());
449       }
450     }
451   }
452 }
453 
CreateVariable(uint32_t typeId,Instruction * varInst,uint32_t index,std::vector<Instruction * > * replacements)454 void ScalarReplacementPass::CreateVariable(
455     uint32_t typeId, Instruction* varInst, uint32_t index,
456     std::vector<Instruction*>* replacements) {
457   uint32_t ptrId = GetOrCreatePointerType(typeId);
458   uint32_t id = TakeNextId();
459 
460   if (id == 0) {
461     replacements->push_back(nullptr);
462   }
463 
464   std::unique_ptr<Instruction> variable(new Instruction(
465       context(), SpvOpVariable, ptrId, id,
466       std::initializer_list<Operand>{
467           {SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}}}));
468 
469   BasicBlock* block = context()->get_instr_block(varInst);
470   block->begin().InsertBefore(std::move(variable));
471   Instruction* inst = &*block->begin();
472 
473   // If varInst was initialized, make sure to initialize its replacement.
474   GetOrCreateInitialValue(varInst, index, inst);
475   get_def_use_mgr()->AnalyzeInstDefUse(inst);
476   context()->set_instr_block(inst, block);
477 
478   // Copy decorations from the member to the new variable.
479   Instruction* typeInst = GetStorageType(varInst);
480   for (auto dec_inst :
481        get_decoration_mgr()->GetDecorationsFor(typeInst->result_id(), false)) {
482     uint32_t decoration;
483     if (dec_inst->opcode() != SpvOpMemberDecorate) {
484       continue;
485     }
486 
487     if (dec_inst->GetSingleWordInOperand(1) != index) {
488       continue;
489     }
490 
491     decoration = dec_inst->GetSingleWordInOperand(2u);
492     switch (decoration) {
493       case SpvDecorationRelaxedPrecision: {
494         std::unique_ptr<Instruction> new_dec_inst(
495             new Instruction(context(), SpvOpDecorate, 0, 0, {}));
496         new_dec_inst->AddOperand(Operand(SPV_OPERAND_TYPE_ID, {id}));
497         for (uint32_t i = 2; i < dec_inst->NumInOperandWords(); ++i) {
498           new_dec_inst->AddOperand(Operand(dec_inst->GetInOperand(i)));
499         }
500         context()->AddAnnotationInst(std::move(new_dec_inst));
501       } break;
502       default:
503         break;
504     }
505   }
506 
507   // Update the OpenCL.DebugInfo.100 debug information.
508   inst->UpdateDebugInfoFrom(varInst);
509 
510   replacements->push_back(inst);
511 }
512 
GetOrCreatePointerType(uint32_t id)513 uint32_t ScalarReplacementPass::GetOrCreatePointerType(uint32_t id) {
514   auto iter = pointee_to_pointer_.find(id);
515   if (iter != pointee_to_pointer_.end()) return iter->second;
516 
517   analysis::Type* pointeeTy;
518   std::unique_ptr<analysis::Pointer> pointerTy;
519   std::tie(pointeeTy, pointerTy) =
520       context()->get_type_mgr()->GetTypeAndPointerType(id,
521                                                        SpvStorageClassFunction);
522   uint32_t ptrId = 0;
523   if (pointeeTy->IsUniqueType()) {
524     // Non-ambiguous type, just ask the type manager for an id.
525     ptrId = context()->get_type_mgr()->GetTypeInstruction(pointerTy.get());
526     pointee_to_pointer_[id] = ptrId;
527     return ptrId;
528   }
529 
530   // Ambiguous type. We must perform a linear search to try and find the right
531   // type.
532   for (auto global : context()->types_values()) {
533     if (global.opcode() == SpvOpTypePointer &&
534         global.GetSingleWordInOperand(0u) == SpvStorageClassFunction &&
535         global.GetSingleWordInOperand(1u) == id) {
536       if (get_decoration_mgr()->GetDecorationsFor(id, false).empty()) {
537         // Only reuse a decoration-less pointer of the correct type.
538         ptrId = global.result_id();
539         break;
540       }
541     }
542   }
543 
544   if (ptrId != 0) {
545     pointee_to_pointer_[id] = ptrId;
546     return ptrId;
547   }
548 
549   ptrId = TakeNextId();
550   context()->AddType(MakeUnique<Instruction>(
551       context(), SpvOpTypePointer, 0, ptrId,
552       std::initializer_list<Operand>{
553           {SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}},
554           {SPV_OPERAND_TYPE_ID, {id}}}));
555   Instruction* ptr = &*--context()->types_values_end();
556   get_def_use_mgr()->AnalyzeInstDefUse(ptr);
557   pointee_to_pointer_[id] = ptrId;
558   // Register with the type manager if necessary.
559   context()->get_type_mgr()->RegisterType(ptrId, *pointerTy);
560 
561   return ptrId;
562 }
563 
GetOrCreateInitialValue(Instruction * source,uint32_t index,Instruction * newVar)564 void ScalarReplacementPass::GetOrCreateInitialValue(Instruction* source,
565                                                     uint32_t index,
566                                                     Instruction* newVar) {
567   assert(source->opcode() == SpvOpVariable);
568   if (source->NumInOperands() < 2) return;
569 
570   uint32_t initId = source->GetSingleWordInOperand(1u);
571   uint32_t storageId = GetStorageType(newVar)->result_id();
572   Instruction* init = get_def_use_mgr()->GetDef(initId);
573   uint32_t newInitId = 0;
574   // TODO(dnovillo): Refactor this with constant propagation.
575   if (init->opcode() == SpvOpConstantNull) {
576     // Initialize to appropriate NULL.
577     auto iter = type_to_null_.find(storageId);
578     if (iter == type_to_null_.end()) {
579       newInitId = TakeNextId();
580       type_to_null_[storageId] = newInitId;
581       context()->AddGlobalValue(
582           MakeUnique<Instruction>(context(), SpvOpConstantNull, storageId,
583                                   newInitId, std::initializer_list<Operand>{}));
584       Instruction* newNull = &*--context()->types_values_end();
585       get_def_use_mgr()->AnalyzeInstDefUse(newNull);
586     } else {
587       newInitId = iter->second;
588     }
589   } else if (IsSpecConstantInst(init->opcode())) {
590     // Create a new constant extract.
591     newInitId = TakeNextId();
592     context()->AddGlobalValue(MakeUnique<Instruction>(
593         context(), SpvOpSpecConstantOp, storageId, newInitId,
594         std::initializer_list<Operand>{
595             {SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER, {SpvOpCompositeExtract}},
596             {SPV_OPERAND_TYPE_ID, {init->result_id()}},
597             {SPV_OPERAND_TYPE_LITERAL_INTEGER, {index}}}));
598     Instruction* newSpecConst = &*--context()->types_values_end();
599     get_def_use_mgr()->AnalyzeInstDefUse(newSpecConst);
600   } else if (init->opcode() == SpvOpConstantComposite) {
601     // Get the appropriate index constant.
602     newInitId = init->GetSingleWordInOperand(index);
603     Instruction* element = get_def_use_mgr()->GetDef(newInitId);
604     if (element->opcode() == SpvOpUndef) {
605       // Undef is not a valid initializer for a variable.
606       newInitId = 0;
607     }
608   } else {
609     assert(false);
610   }
611 
612   if (newInitId != 0) {
613     newVar->AddOperand({SPV_OPERAND_TYPE_ID, {newInitId}});
614   }
615 }
616 
GetArrayLength(const Instruction * arrayType) const617 uint64_t ScalarReplacementPass::GetArrayLength(
618     const Instruction* arrayType) const {
619   assert(arrayType->opcode() == SpvOpTypeArray);
620   const Instruction* length =
621       get_def_use_mgr()->GetDef(arrayType->GetSingleWordInOperand(1u));
622   return context()
623       ->get_constant_mgr()
624       ->GetConstantFromInst(length)
625       ->GetZeroExtendedValue();
626 }
627 
GetNumElements(const Instruction * type) const628 uint64_t ScalarReplacementPass::GetNumElements(const Instruction* type) const {
629   assert(type->opcode() == SpvOpTypeVector ||
630          type->opcode() == SpvOpTypeMatrix);
631   const Operand& op = type->GetInOperand(1u);
632   assert(op.words.size() <= 2);
633   uint64_t len = 0;
634   for (size_t i = 0; i != op.words.size(); ++i) {
635     len |= (static_cast<uint64_t>(op.words[i]) << (32ull * i));
636   }
637   return len;
638 }
639 
IsSpecConstant(uint32_t id) const640 bool ScalarReplacementPass::IsSpecConstant(uint32_t id) const {
641   const Instruction* inst = get_def_use_mgr()->GetDef(id);
642   assert(inst);
643   return spvOpcodeIsSpecConstant(inst->opcode());
644 }
645 
// Returns the type instruction pointed to by the OpVariable |inst|. The
// variable's result type is a pointer type whose in-operand 1 is the pointee.
Instruction* ScalarReplacementPass::GetStorageType(
    const Instruction* inst) const {
  assert(inst->opcode() == SpvOpVariable);
  const Instruction* pointer_type = get_def_use_mgr()->GetDef(inst->type_id());
  const uint32_t pointee_id = pointer_type->GetSingleWordInOperand(1u);
  return get_def_use_mgr()->GetDef(pointee_id);
}
655 
CanReplaceVariable(const Instruction * varInst) const656 bool ScalarReplacementPass::CanReplaceVariable(
657     const Instruction* varInst) const {
658   assert(varInst->opcode() == SpvOpVariable);
659 
660   // Can only replace function scope variables.
661   if (varInst->GetSingleWordInOperand(0u) != SpvStorageClassFunction) {
662     return false;
663   }
664 
665   if (!CheckTypeAnnotations(get_def_use_mgr()->GetDef(varInst->type_id()))) {
666     return false;
667   }
668 
669   const Instruction* typeInst = GetStorageType(varInst);
670   if (!CheckType(typeInst)) {
671     return false;
672   }
673 
674   if (!CheckAnnotations(varInst)) {
675     return false;
676   }
677 
678   if (!CheckUses(varInst)) {
679     return false;
680   }
681 
682   return true;
683 }
684 
CheckType(const Instruction * typeInst) const685 bool ScalarReplacementPass::CheckType(const Instruction* typeInst) const {
686   if (!CheckTypeAnnotations(typeInst)) {
687     return false;
688   }
689 
690   switch (typeInst->opcode()) {
691     case SpvOpTypeStruct:
692       // Don't bother with empty structs or very large structs.
693       if (typeInst->NumInOperands() == 0 ||
694           IsLargerThanSizeLimit(typeInst->NumInOperands())) {
695         return false;
696       }
697       return true;
698     case SpvOpTypeArray:
699       if (IsSpecConstant(typeInst->GetSingleWordInOperand(1u))) {
700         return false;
701       }
702       if (IsLargerThanSizeLimit(GetArrayLength(typeInst))) {
703         return false;
704       }
705       return true;
706       // TODO(alanbaker): Develop some heuristics for when this should be
707       // re-enabled.
708       //// Specifically including matrix and vector in an attempt to reduce the
709       //// number of vector registers required.
710       // case SpvOpTypeMatrix:
711       // case SpvOpTypeVector:
712       //  if (IsLargerThanSizeLimit(GetNumElements(typeInst))) return false;
713       //  return true;
714 
715     case SpvOpTypeRuntimeArray:
716     default:
717       return false;
718   }
719 }
720 
CheckTypeAnnotations(const Instruction * typeInst) const721 bool ScalarReplacementPass::CheckTypeAnnotations(
722     const Instruction* typeInst) const {
723   for (auto inst :
724        get_decoration_mgr()->GetDecorationsFor(typeInst->result_id(), false)) {
725     uint32_t decoration;
726     if (inst->opcode() == SpvOpDecorate) {
727       decoration = inst->GetSingleWordInOperand(1u);
728     } else {
729       assert(inst->opcode() == SpvOpMemberDecorate);
730       decoration = inst->GetSingleWordInOperand(2u);
731     }
732 
733     switch (decoration) {
734       case SpvDecorationRowMajor:
735       case SpvDecorationColMajor:
736       case SpvDecorationArrayStride:
737       case SpvDecorationMatrixStride:
738       case SpvDecorationCPacked:
739       case SpvDecorationInvariant:
740       case SpvDecorationRestrict:
741       case SpvDecorationOffset:
742       case SpvDecorationAlignment:
743       case SpvDecorationAlignmentId:
744       case SpvDecorationMaxByteOffset:
745       case SpvDecorationRelaxedPrecision:
746         break;
747       default:
748         return false;
749     }
750   }
751 
752   return true;
753 }
754 
CheckAnnotations(const Instruction * varInst) const755 bool ScalarReplacementPass::CheckAnnotations(const Instruction* varInst) const {
756   for (auto inst :
757        get_decoration_mgr()->GetDecorationsFor(varInst->result_id(), false)) {
758     assert(inst->opcode() == SpvOpDecorate);
759     uint32_t decoration = inst->GetSingleWordInOperand(1u);
760     switch (decoration) {
761       case SpvDecorationInvariant:
762       case SpvDecorationRestrict:
763       case SpvDecorationAlignment:
764       case SpvDecorationAlignmentId:
765       case SpvDecorationMaxByteOffset:
766         break;
767       default:
768         return false;
769     }
770   }
771 
772   return true;
773 }
774 
CheckUses(const Instruction * inst) const775 bool ScalarReplacementPass::CheckUses(const Instruction* inst) const {
776   VariableStats stats = {0, 0};
777   bool ok = CheckUses(inst, &stats);
778 
779   // TODO(alanbaker/greg-lunarg): Add some meaningful heuristics about when
780   // SRoA is costly, such as when the structure has many (unaccessed?)
781   // members.
782 
783   return ok;
784 }
785 
CheckUses(const Instruction * inst,VariableStats * stats) const786 bool ScalarReplacementPass::CheckUses(const Instruction* inst,
787                                       VariableStats* stats) const {
788   uint64_t max_legal_index = GetMaxLegalIndex(inst);
789 
790   bool ok = true;
791   get_def_use_mgr()->ForEachUse(inst, [this, max_legal_index, stats, &ok](
792                                           const Instruction* user,
793                                           uint32_t index) {
794     if (user->GetOpenCL100DebugOpcode() == OpenCLDebugInfo100DebugDeclare ||
795         user->GetOpenCL100DebugOpcode() == OpenCLDebugInfo100DebugValue) {
796       // TODO: include num_partial_accesses if it uses Fragment operation or
797       // DebugValue has Indexes operand.
798       stats->num_full_accesses++;
799       return;
800     }
801 
802     // Annotations are check as a group separately.
803     if (!IsAnnotationInst(user->opcode())) {
804       switch (user->opcode()) {
805         case SpvOpAccessChain:
806         case SpvOpInBoundsAccessChain:
807           if (index == 2u && user->NumInOperands() > 1) {
808             uint32_t id = user->GetSingleWordInOperand(1u);
809             const Instruction* opInst = get_def_use_mgr()->GetDef(id);
810             const auto* constant =
811                 context()->get_constant_mgr()->GetConstantFromInst(opInst);
812             if (!constant) {
813               ok = false;
814             } else if (constant->GetZeroExtendedValue() >= max_legal_index) {
815               ok = false;
816             } else {
817               if (!CheckUsesRelaxed(user)) ok = false;
818             }
819             stats->num_partial_accesses++;
820           } else {
821             ok = false;
822           }
823           break;
824         case SpvOpLoad:
825           if (!CheckLoad(user, index)) ok = false;
826           stats->num_full_accesses++;
827           break;
828         case SpvOpStore:
829           if (!CheckStore(user, index)) ok = false;
830           stats->num_full_accesses++;
831           break;
832         case SpvOpName:
833         case SpvOpMemberName:
834           break;
835         default:
836           ok = false;
837           break;
838       }
839     }
840   });
841 
842   return ok;
843 }
844 
CheckUsesRelaxed(const Instruction * inst) const845 bool ScalarReplacementPass::CheckUsesRelaxed(const Instruction* inst) const {
846   bool ok = true;
847   get_def_use_mgr()->ForEachUse(
848       inst, [this, &ok](const Instruction* user, uint32_t index) {
849         switch (user->opcode()) {
850           case SpvOpAccessChain:
851           case SpvOpInBoundsAccessChain:
852             if (index != 2u) {
853               ok = false;
854             } else {
855               if (!CheckUsesRelaxed(user)) ok = false;
856             }
857             break;
858           case SpvOpLoad:
859             if (!CheckLoad(user, index)) ok = false;
860             break;
861           case SpvOpStore:
862             if (!CheckStore(user, index)) ok = false;
863             break;
864           case SpvOpImageTexelPointer:
865             if (!CheckImageTexelPointer(index)) ok = false;
866             break;
867           default:
868             ok = false;
869             break;
870         }
871       });
872 
873   return ok;
874 }
875 
CheckImageTexelPointer(uint32_t index) const876 bool ScalarReplacementPass::CheckImageTexelPointer(uint32_t index) const {
877   return index == 2u;
878 }
879 
CheckLoad(const Instruction * inst,uint32_t index) const880 bool ScalarReplacementPass::CheckLoad(const Instruction* inst,
881                                       uint32_t index) const {
882   if (index != 2u) return false;
883   if (inst->NumInOperands() >= 2 &&
884       inst->GetSingleWordInOperand(1u) & SpvMemoryAccessVolatileMask)
885     return false;
886   return true;
887 }
888 
CheckStore(const Instruction * inst,uint32_t index) const889 bool ScalarReplacementPass::CheckStore(const Instruction* inst,
890                                        uint32_t index) const {
891   if (index != 0u) return false;
892   if (inst->NumInOperands() >= 3 &&
893       inst->GetSingleWordInOperand(2u) & SpvMemoryAccessVolatileMask)
894     return false;
895   return true;
896 }
IsLargerThanSizeLimit(uint64_t length) const897 bool ScalarReplacementPass::IsLargerThanSizeLimit(uint64_t length) const {
898   if (max_num_elements_ == 0) {
899     return false;
900   }
901   return length > max_num_elements_;
902 }
903 
904 std::unique_ptr<std::unordered_set<int64_t>>
GetUsedComponents(Instruction * inst)905 ScalarReplacementPass::GetUsedComponents(Instruction* inst) {
906   std::unique_ptr<std::unordered_set<int64_t>> result(
907       new std::unordered_set<int64_t>());
908 
909   analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
910 
911   def_use_mgr->WhileEachUser(inst, [&result, def_use_mgr,
912                                     this](Instruction* use) {
913     switch (use->opcode()) {
914       case SpvOpLoad: {
915         // Look for extract from the load.
916         std::vector<uint32_t> t;
917         if (def_use_mgr->WhileEachUser(use, [&t](Instruction* use2) {
918               if (use2->opcode() != SpvOpCompositeExtract ||
919                   use2->NumInOperands() <= 1) {
920                 return false;
921               }
922               t.push_back(use2->GetSingleWordInOperand(1));
923               return true;
924             })) {
925           result->insert(t.begin(), t.end());
926           return true;
927         } else {
928           result.reset(nullptr);
929           return false;
930         }
931       }
932       case SpvOpName:
933       case SpvOpMemberName:
934       case SpvOpStore:
935         // No components are used.
936         return true;
937       case SpvOpAccessChain:
938       case SpvOpInBoundsAccessChain: {
939         // Add the first index it if is a constant.
940         // TODO: Could be improved by checking if the address is used in a load.
941         analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
942         uint32_t index_id = use->GetSingleWordInOperand(1);
943         const analysis::Constant* index_const =
944             const_mgr->FindDeclaredConstant(index_id);
945         if (index_const) {
946           result->insert(index_const->GetSignExtendedValue());
947           return true;
948         } else {
949           // Could be any element.  Assuming all are used.
950           result.reset(nullptr);
951           return false;
952         }
953       }
954       default:
955         // We do not know what is happening.  Have to assume the worst.
956         result.reset(nullptr);
957         return false;
958     }
959   });
960 
961   return result;
962 }
963 
CreateNullConstant(uint32_t type_id)964 Instruction* ScalarReplacementPass::CreateNullConstant(uint32_t type_id) {
965   analysis::TypeManager* type_mgr = context()->get_type_mgr();
966   analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
967 
968   const analysis::Type* type = type_mgr->GetType(type_id);
969   const analysis::Constant* null_const = const_mgr->GetConstant(type, {});
970   Instruction* null_inst =
971       const_mgr->GetDefiningInstruction(null_const, type_id);
972   if (null_inst != nullptr) {
973     context()->UpdateDefUse(null_inst);
974   }
975   return null_inst;
976 }
977 
GetMaxLegalIndex(const Instruction * var_inst) const978 uint64_t ScalarReplacementPass::GetMaxLegalIndex(
979     const Instruction* var_inst) const {
980   assert(var_inst->opcode() == SpvOpVariable &&
981          "|var_inst| must be a variable instruction.");
982   Instruction* type = GetStorageType(var_inst);
983   switch (type->opcode()) {
984     case SpvOpTypeStruct:
985       return type->NumInOperands();
986     case SpvOpTypeArray:
987       return GetArrayLength(type);
988     case SpvOpTypeMatrix:
989     case SpvOpTypeVector:
990       return GetNumElements(type);
991     default:
992       return 0;
993   }
994   return 0;
995 }
996 
997 }  // namespace opt
998 }  // namespace spvtools
999