• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2018 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/opt/copy_prop_arrays.h"
16 
17 #include <utility>
18 
19 #include "source/opt/ir_builder.h"
20 
21 namespace spvtools {
22 namespace opt {
23 namespace {
24 
25 const uint32_t kLoadPointerInOperand = 0;
26 const uint32_t kStorePointerInOperand = 0;
27 const uint32_t kStoreObjectInOperand = 1;
28 const uint32_t kCompositeExtractObjectInOperand = 0;
29 const uint32_t kTypePointerStorageClassInIdx = 0;
30 const uint32_t kTypePointerPointeeInIdx = 1;
31 
IsDebugDeclareOrValue(Instruction * di)32 bool IsDebugDeclareOrValue(Instruction* di) {
33   auto dbg_opcode = di->GetCommonDebugOpcode();
34   return dbg_opcode == CommonDebugInfoDebugDeclare ||
35          dbg_opcode == CommonDebugInfoDebugValue;
36 }
37 
38 }  // namespace
39 
Process()40 Pass::Status CopyPropagateArrays::Process() {
41   bool modified = false;
42   for (Function& function : *get_module()) {
43     if (function.IsDeclaration()) {
44       continue;
45     }
46 
47     BasicBlock* entry_bb = &*function.begin();
48 
49     for (auto var_inst = entry_bb->begin(); var_inst->opcode() == SpvOpVariable;
50          ++var_inst) {
51       if (!IsPointerToArrayType(var_inst->type_id())) {
52         continue;
53       }
54 
55       // Find the only store to the entire memory location, if it exists.
56       Instruction* store_inst = FindStoreInstruction(&*var_inst);
57 
58       if (!store_inst) {
59         continue;
60       }
61 
62       std::unique_ptr<MemoryObject> source_object =
63           FindSourceObjectIfPossible(&*var_inst, store_inst);
64 
65       if (source_object != nullptr) {
66         if (CanUpdateUses(&*var_inst, source_object->GetPointerTypeId(this))) {
67           modified = true;
68           PropagateObject(&*var_inst, source_object.get(), store_inst);
69         }
70       }
71     }
72   }
73   return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
74 }
75 
76 std::unique_ptr<CopyPropagateArrays::MemoryObject>
FindSourceObjectIfPossible(Instruction * var_inst,Instruction * store_inst)77 CopyPropagateArrays::FindSourceObjectIfPossible(Instruction* var_inst,
78                                                 Instruction* store_inst) {
79   assert(var_inst->opcode() == SpvOpVariable && "Expecting a variable.");
80 
81   // Check that the variable is a composite object where |store_inst|
82   // dominates all of its loads.
83   if (!store_inst) {
84     return nullptr;
85   }
86 
87   // Look at the loads to ensure they are dominated by the store.
88   if (!HasValidReferencesOnly(var_inst, store_inst)) {
89     return nullptr;
90   }
91 
92   // If so, look at the store to see if it is the copy of an object.
93   std::unique_ptr<MemoryObject> source = GetSourceObjectIfAny(
94       store_inst->GetSingleWordInOperand(kStoreObjectInOperand));
95 
96   if (!source) {
97     return nullptr;
98   }
99 
100   // Ensure that |source| does not change between the point at which it is
101   // loaded, and the position in which |var_inst| is loaded.
102   //
103   // For now we will go with the easy to implement approach, and check that the
104   // entire variable (not just the specific component) is never written to.
105 
106   if (!HasNoStores(source->GetVariable())) {
107     return nullptr;
108   }
109   return source;
110 }
111 
// Returns the unique OpStore whose pointer operand is |var_inst| itself.
// Returns nullptr when there is no such store, or as soon as a second one is
// found.
Instruction* CopyPropagateArrays::FindStoreInstruction(
    const Instruction* var_inst) const {
  const uint32_t var_id = var_inst->result_id();
  Instruction* unique_store = nullptr;
  get_def_use_mgr()->WhileEachUser(
      var_inst, [&unique_store, var_id](Instruction* user) {
        const bool stores_to_var =
            user->opcode() == SpvOpStore &&
            user->GetSingleWordInOperand(kStorePointerInOperand) == var_id;
        if (!stores_to_var) return true;

        if (unique_store != nullptr) {
          // A second store: there is no single store, stop scanning.
          unique_store = nullptr;
          return false;
        }
        unique_store = user;
        return true;
      });
  return unique_store;
}
131 
PropagateObject(Instruction * var_inst,MemoryObject * source,Instruction * insertion_point)132 void CopyPropagateArrays::PropagateObject(Instruction* var_inst,
133                                           MemoryObject* source,
134                                           Instruction* insertion_point) {
135   assert(var_inst->opcode() == SpvOpVariable &&
136          "This function propagates variables.");
137 
138   Instruction* new_access_chain = BuildNewAccessChain(insertion_point, source);
139   context()->KillNamesAndDecorates(var_inst);
140   UpdateUses(var_inst, new_access_chain);
141 }
142 
BuildNewAccessChain(Instruction * insertion_point,CopyPropagateArrays::MemoryObject * source) const143 Instruction* CopyPropagateArrays::BuildNewAccessChain(
144     Instruction* insertion_point,
145     CopyPropagateArrays::MemoryObject* source) const {
146   InstructionBuilder builder(
147       context(), insertion_point,
148       IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
149 
150   if (source->AccessChain().size() == 0) {
151     return source->GetVariable();
152   }
153 
154   return builder.AddAccessChain(source->GetPointerTypeId(this),
155                                 source->GetVariable()->result_id(),
156                                 source->AccessChain());
157 }
158 
HasNoStores(Instruction * ptr_inst)159 bool CopyPropagateArrays::HasNoStores(Instruction* ptr_inst) {
160   return get_def_use_mgr()->WhileEachUser(ptr_inst, [this](Instruction* use) {
161     if (use->opcode() == SpvOpLoad) {
162       return true;
163     } else if (use->opcode() == SpvOpAccessChain) {
164       return HasNoStores(use);
165     } else if (use->IsDecoration() || use->opcode() == SpvOpName) {
166       return true;
167     } else if (use->opcode() == SpvOpStore) {
168       return false;
169     } else if (use->opcode() == SpvOpImageTexelPointer) {
170       return true;
171     }
172     // Some other instruction.  Be conservative.
173     return false;
174   });
175 }
176 
HasValidReferencesOnly(Instruction * ptr_inst,Instruction * store_inst)177 bool CopyPropagateArrays::HasValidReferencesOnly(Instruction* ptr_inst,
178                                                  Instruction* store_inst) {
179   BasicBlock* store_block = context()->get_instr_block(store_inst);
180   DominatorAnalysis* dominator_analysis =
181       context()->GetDominatorAnalysis(store_block->GetParent());
182 
183   return get_def_use_mgr()->WhileEachUser(
184       ptr_inst,
185       [this, store_inst, dominator_analysis, ptr_inst](Instruction* use) {
186         if (use->opcode() == SpvOpLoad ||
187             use->opcode() == SpvOpImageTexelPointer) {
188           // TODO: If there are many load in the same BB as |store_inst| the
189           // time to do the multiple traverses can add up.  Consider collecting
190           // those loads and doing a single traversal.
191           return dominator_analysis->Dominates(store_inst, use);
192         } else if (use->opcode() == SpvOpAccessChain) {
193           return HasValidReferencesOnly(use, store_inst);
194         } else if (use->IsDecoration() || use->opcode() == SpvOpName) {
195           return true;
196         } else if (use->opcode() == SpvOpStore) {
197           // If we are storing to part of the object it is not an candidate.
198           return ptr_inst->opcode() == SpvOpVariable &&
199                  store_inst->GetSingleWordInOperand(kStorePointerInOperand) ==
200                      ptr_inst->result_id();
201         } else if (IsDebugDeclareOrValue(use)) {
202           return true;
203         }
204         // Some other instruction.  Be conservative.
205         return false;
206       });
207 }
208 
209 std::unique_ptr<CopyPropagateArrays::MemoryObject>
GetSourceObjectIfAny(uint32_t result)210 CopyPropagateArrays::GetSourceObjectIfAny(uint32_t result) {
211   Instruction* result_inst = context()->get_def_use_mgr()->GetDef(result);
212 
213   switch (result_inst->opcode()) {
214     case SpvOpLoad:
215       return BuildMemoryObjectFromLoad(result_inst);
216     case SpvOpCompositeExtract:
217       return BuildMemoryObjectFromExtract(result_inst);
218     case SpvOpCompositeConstruct:
219       return BuildMemoryObjectFromCompositeConstruct(result_inst);
220     case SpvOpCopyObject:
221       return GetSourceObjectIfAny(result_inst->GetSingleWordInOperand(0));
222     case SpvOpCompositeInsert:
223       return BuildMemoryObjectFromInsert(result_inst);
224     default:
225       return nullptr;
226   }
227 }
228 
229 std::unique_ptr<CopyPropagateArrays::MemoryObject>
BuildMemoryObjectFromLoad(Instruction * load_inst)230 CopyPropagateArrays::BuildMemoryObjectFromLoad(Instruction* load_inst) {
231   std::vector<uint32_t> components_in_reverse;
232   analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
233 
234   Instruction* current_inst = def_use_mgr->GetDef(
235       load_inst->GetSingleWordInOperand(kLoadPointerInOperand));
236 
237   // Build the access chain for the memory object by collecting the indices used
238   // in the OpAccessChain instructions.  If we find a variable index, then
239   // return |nullptr| because we cannot know for sure which memory location is
240   // used.
241   //
242   // It is built in reverse order because the different |OpAccessChain|
243   // instructions are visited in reverse order from which they are applied.
244   while (current_inst->opcode() == SpvOpAccessChain) {
245     for (uint32_t i = current_inst->NumInOperands() - 1; i >= 1; --i) {
246       uint32_t element_index_id = current_inst->GetSingleWordInOperand(i);
247       components_in_reverse.push_back(element_index_id);
248     }
249     current_inst = def_use_mgr->GetDef(current_inst->GetSingleWordInOperand(0));
250   }
251 
252   // If the address in the load is not constructed from an |OpVariable|
253   // instruction followed by a series of |OpAccessChain| instructions, then
254   // return |nullptr| because we cannot identify the owner or access chain
255   // exactly.
256   if (current_inst->opcode() != SpvOpVariable) {
257     return nullptr;
258   }
259 
260   // Build the memory object.  Use |rbegin| and |rend| to put the access chain
261   // back in the correct order.
262   return std::unique_ptr<CopyPropagateArrays::MemoryObject>(
263       new MemoryObject(current_inst, components_in_reverse.rbegin(),
264                        components_in_reverse.rend()));
265 }
266 
267 std::unique_ptr<CopyPropagateArrays::MemoryObject>
BuildMemoryObjectFromExtract(Instruction * extract_inst)268 CopyPropagateArrays::BuildMemoryObjectFromExtract(Instruction* extract_inst) {
269   assert(extract_inst->opcode() == SpvOpCompositeExtract &&
270          "Expecting an OpCompositeExtract instruction.");
271   analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
272 
273   std::unique_ptr<MemoryObject> result = GetSourceObjectIfAny(
274       extract_inst->GetSingleWordInOperand(kCompositeExtractObjectInOperand));
275 
276   if (result) {
277     analysis::Integer int_type(32, false);
278     const analysis::Type* uint32_type =
279         context()->get_type_mgr()->GetRegisteredType(&int_type);
280 
281     std::vector<uint32_t> components;
282     // Convert the indices in the extract instruction to a series of ids that
283     // can be used by the |OpAccessChain| instruction.
284     for (uint32_t i = 1; i < extract_inst->NumInOperands(); ++i) {
285       uint32_t index = extract_inst->GetSingleWordInOperand(i);
286       const analysis::Constant* index_const =
287           const_mgr->GetConstant(uint32_type, {index});
288       components.push_back(
289           const_mgr->GetDefiningInstruction(index_const)->result_id());
290     }
291     result->GetMember(components);
292     return result;
293   }
294   return nullptr;
295 }
296 
297 std::unique_ptr<CopyPropagateArrays::MemoryObject>
BuildMemoryObjectFromCompositeConstruct(Instruction * conststruct_inst)298 CopyPropagateArrays::BuildMemoryObjectFromCompositeConstruct(
299     Instruction* conststruct_inst) {
300   assert(conststruct_inst->opcode() == SpvOpCompositeConstruct &&
301          "Expecting an OpCompositeConstruct instruction.");
302 
303   // If every operand in the instruction are part of the same memory object, and
304   // are being combined in the same order, then the result is the same as the
305   // parent.
306 
307   std::unique_ptr<MemoryObject> memory_object =
308       GetSourceObjectIfAny(conststruct_inst->GetSingleWordInOperand(0));
309 
310   if (!memory_object) {
311     return nullptr;
312   }
313 
314   if (!memory_object->IsMember()) {
315     return nullptr;
316   }
317 
318   analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
319   const analysis::Constant* last_access =
320       const_mgr->FindDeclaredConstant(memory_object->AccessChain().back());
321   if (!last_access || !last_access->type()->AsInteger()) {
322     return nullptr;
323   }
324 
325   if (last_access->GetU32() != 0) {
326     return nullptr;
327   }
328 
329   memory_object->GetParent();
330 
331   if (memory_object->GetNumberOfMembers() !=
332       conststruct_inst->NumInOperands()) {
333     return nullptr;
334   }
335 
336   for (uint32_t i = 1; i < conststruct_inst->NumInOperands(); ++i) {
337     std::unique_ptr<MemoryObject> member_object =
338         GetSourceObjectIfAny(conststruct_inst->GetSingleWordInOperand(i));
339 
340     if (!member_object) {
341       return nullptr;
342     }
343 
344     if (!member_object->IsMember()) {
345       return nullptr;
346     }
347 
348     if (!memory_object->Contains(member_object.get())) {
349       return nullptr;
350     }
351 
352     last_access =
353         const_mgr->FindDeclaredConstant(member_object->AccessChain().back());
354     if (!last_access || !last_access->type()->AsInteger()) {
355       return nullptr;
356     }
357 
358     if (last_access->GetU32() != i) {
359       return nullptr;
360     }
361   }
362   return memory_object;
363 }
364 
365 std::unique_ptr<CopyPropagateArrays::MemoryObject>
BuildMemoryObjectFromInsert(Instruction * insert_inst)366 CopyPropagateArrays::BuildMemoryObjectFromInsert(Instruction* insert_inst) {
367   assert(insert_inst->opcode() == SpvOpCompositeInsert &&
368          "Expecting an OpCompositeInsert instruction.");
369 
370   analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
371   analysis::TypeManager* type_mgr = context()->get_type_mgr();
372   analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
373   const analysis::Type* result_type = type_mgr->GetType(insert_inst->type_id());
374 
375   uint32_t number_of_elements = 0;
376   if (const analysis::Struct* struct_type = result_type->AsStruct()) {
377     number_of_elements =
378         static_cast<uint32_t>(struct_type->element_types().size());
379   } else if (const analysis::Array* array_type = result_type->AsArray()) {
380     const analysis::Constant* length_const =
381         const_mgr->FindDeclaredConstant(array_type->LengthId());
382     number_of_elements = length_const->GetU32();
383   } else if (const analysis::Vector* vector_type = result_type->AsVector()) {
384     number_of_elements = vector_type->element_count();
385   } else if (const analysis::Matrix* matrix_type = result_type->AsMatrix()) {
386     number_of_elements = matrix_type->element_count();
387   }
388 
389   if (number_of_elements == 0) {
390     return nullptr;
391   }
392 
393   if (insert_inst->NumInOperands() != 3) {
394     return nullptr;
395   }
396 
397   if (insert_inst->GetSingleWordInOperand(2) != number_of_elements - 1) {
398     return nullptr;
399   }
400 
401   std::unique_ptr<MemoryObject> memory_object =
402       GetSourceObjectIfAny(insert_inst->GetSingleWordInOperand(0));
403 
404   if (!memory_object) {
405     return nullptr;
406   }
407 
408   if (!memory_object->IsMember()) {
409     return nullptr;
410   }
411 
412   const analysis::Constant* last_access =
413       const_mgr->FindDeclaredConstant(memory_object->AccessChain().back());
414   if (!last_access || !last_access->type()->AsInteger()) {
415     return nullptr;
416   }
417 
418   if (last_access->GetU32() != number_of_elements - 1) {
419     return nullptr;
420   }
421 
422   memory_object->GetParent();
423 
424   Instruction* current_insert =
425       def_use_mgr->GetDef(insert_inst->GetSingleWordInOperand(1));
426   for (uint32_t i = number_of_elements - 1; i > 0; --i) {
427     if (current_insert->opcode() != SpvOpCompositeInsert) {
428       return nullptr;
429     }
430 
431     if (current_insert->NumInOperands() != 3) {
432       return nullptr;
433     }
434 
435     if (current_insert->GetSingleWordInOperand(2) != i - 1) {
436       return nullptr;
437     }
438 
439     std::unique_ptr<MemoryObject> current_memory_object =
440         GetSourceObjectIfAny(current_insert->GetSingleWordInOperand(0));
441 
442     if (!current_memory_object) {
443       return nullptr;
444     }
445 
446     if (!current_memory_object->IsMember()) {
447       return nullptr;
448     }
449 
450     if (memory_object->AccessChain().size() + 1 !=
451         current_memory_object->AccessChain().size()) {
452       return nullptr;
453     }
454 
455     if (!memory_object->Contains(current_memory_object.get())) {
456       return nullptr;
457     }
458 
459     const analysis::Constant* current_last_access =
460         const_mgr->FindDeclaredConstant(
461             current_memory_object->AccessChain().back());
462     if (!current_last_access || !current_last_access->type()->AsInteger()) {
463       return nullptr;
464     }
465 
466     if (current_last_access->GetU32() != i - 1) {
467       return nullptr;
468     }
469     current_insert =
470         def_use_mgr->GetDef(current_insert->GetSingleWordInOperand(1));
471   }
472 
473   return memory_object;
474 }
475 
IsPointerToArrayType(uint32_t type_id)476 bool CopyPropagateArrays::IsPointerToArrayType(uint32_t type_id) {
477   analysis::TypeManager* type_mgr = context()->get_type_mgr();
478   analysis::Pointer* pointer_type = type_mgr->GetType(type_id)->AsPointer();
479   if (pointer_type) {
480     return pointer_type->pointee_type()->kind() == analysis::Type::kArray ||
481            pointer_type->pointee_type()->kind() == analysis::Type::kImage;
482   }
483   return false;
484 }
485 
CanUpdateUses(Instruction * original_ptr_inst,uint32_t type_id)486 bool CopyPropagateArrays::CanUpdateUses(Instruction* original_ptr_inst,
487                                         uint32_t type_id) {
488   analysis::TypeManager* type_mgr = context()->get_type_mgr();
489   analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
490   analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
491 
492   analysis::Type* type = type_mgr->GetType(type_id);
493   if (type->AsRuntimeArray()) {
494     return false;
495   }
496 
497   if (!type->AsStruct() && !type->AsArray() && !type->AsPointer()) {
498     // If the type is not an aggregate, then the desired type must be the
499     // same as the current type.  No work to do, and we can do that.
500     return true;
501   }
502 
503   return def_use_mgr->WhileEachUse(original_ptr_inst, [this, type_mgr,
504                                                        const_mgr,
505                                                        type](Instruction* use,
506                                                              uint32_t) {
507     if (IsDebugDeclareOrValue(use)) return true;
508 
509     switch (use->opcode()) {
510       case SpvOpLoad: {
511         analysis::Pointer* pointer_type = type->AsPointer();
512         uint32_t new_type_id = type_mgr->GetId(pointer_type->pointee_type());
513 
514         if (new_type_id != use->type_id()) {
515           return CanUpdateUses(use, new_type_id);
516         }
517         return true;
518       }
519       case SpvOpAccessChain: {
520         analysis::Pointer* pointer_type = type->AsPointer();
521         const analysis::Type* pointee_type = pointer_type->pointee_type();
522 
523         std::vector<uint32_t> access_chain;
524         for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
525           const analysis::Constant* index_const =
526               const_mgr->FindDeclaredConstant(use->GetSingleWordInOperand(i));
527           if (index_const) {
528             access_chain.push_back(index_const->GetU32());
529           } else {
530             // Variable index means the type is a type where every element
531             // is the same type.  Use element 0 to get the type.
532             access_chain.push_back(0);
533           }
534         }
535 
536         const analysis::Type* new_pointee_type =
537             type_mgr->GetMemberType(pointee_type, access_chain);
538         analysis::Pointer pointerTy(new_pointee_type,
539                                     pointer_type->storage_class());
540         uint32_t new_pointer_type_id =
541             context()->get_type_mgr()->GetTypeInstruction(&pointerTy);
542         if (new_pointer_type_id == 0) {
543           return false;
544         }
545 
546         if (new_pointer_type_id != use->type_id()) {
547           return CanUpdateUses(use, new_pointer_type_id);
548         }
549         return true;
550       }
551       case SpvOpCompositeExtract: {
552         std::vector<uint32_t> access_chain;
553         for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
554           access_chain.push_back(use->GetSingleWordInOperand(i));
555         }
556 
557         const analysis::Type* new_type =
558             type_mgr->GetMemberType(type, access_chain);
559         uint32_t new_type_id = type_mgr->GetTypeInstruction(new_type);
560         if (new_type_id == 0) {
561           return false;
562         }
563 
564         if (new_type_id != use->type_id()) {
565           return CanUpdateUses(use, new_type_id);
566         }
567         return true;
568       }
569       case SpvOpStore:
570         // If needed, we can create an element-by-element copy to change the
571         // type of the value being stored.  This way we can always handled
572         // stores.
573         return true;
574       case SpvOpImageTexelPointer:
575       case SpvOpName:
576         return true;
577       default:
578         return use->IsDecoration();
579     }
580   });
581 }
582 
UpdateUses(Instruction * original_ptr_inst,Instruction * new_ptr_inst)583 void CopyPropagateArrays::UpdateUses(Instruction* original_ptr_inst,
584                                      Instruction* new_ptr_inst) {
585   analysis::TypeManager* type_mgr = context()->get_type_mgr();
586   analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
587   analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
588 
589   std::vector<std::pair<Instruction*, uint32_t> > uses;
590   def_use_mgr->ForEachUse(original_ptr_inst,
591                           [&uses](Instruction* use, uint32_t index) {
592                             uses.push_back({use, index});
593                           });
594 
595   for (auto pair : uses) {
596     Instruction* use = pair.first;
597     uint32_t index = pair.second;
598 
599     if (use->IsCommonDebugInstr()) {
600       switch (use->GetCommonDebugOpcode()) {
601         case CommonDebugInfoDebugDeclare: {
602           if (new_ptr_inst->opcode() == SpvOpVariable ||
603               new_ptr_inst->opcode() == SpvOpFunctionParameter) {
604             context()->ForgetUses(use);
605             use->SetOperand(index, {new_ptr_inst->result_id()});
606             context()->AnalyzeUses(use);
607           } else {
608             // Based on the spec, we cannot use a pointer other than OpVariable
609             // or OpFunctionParameter for DebugDeclare. We have to use
610             // DebugValue with Deref.
611 
612             context()->ForgetUses(use);
613 
614             // Change DebugDeclare to DebugValue.
615             use->SetOperand(index - 2,
616                             {static_cast<uint32_t>(CommonDebugInfoDebugValue)});
617             use->SetOperand(index, {new_ptr_inst->result_id()});
618 
619             // Add Deref operation.
620             Instruction* dbg_expr =
621                 def_use_mgr->GetDef(use->GetSingleWordOperand(index + 1));
622             auto* deref_expr_instr =
623                 context()->get_debug_info_mgr()->DerefDebugExpression(dbg_expr);
624             use->SetOperand(index + 1, {deref_expr_instr->result_id()});
625 
626             context()->AnalyzeUses(deref_expr_instr);
627             context()->AnalyzeUses(use);
628           }
629           break;
630         }
631         case CommonDebugInfoDebugValue:
632           context()->ForgetUses(use);
633           use->SetOperand(index, {new_ptr_inst->result_id()});
634           context()->AnalyzeUses(use);
635           break;
636         default:
637           assert(false && "Don't know how to rewrite instruction");
638           break;
639       }
640       continue;
641     }
642 
643     switch (use->opcode()) {
644       case SpvOpLoad: {
645         // Replace the actual use.
646         context()->ForgetUses(use);
647         use->SetOperand(index, {new_ptr_inst->result_id()});
648 
649         // Update the type.
650         Instruction* pointer_type_inst =
651             def_use_mgr->GetDef(new_ptr_inst->type_id());
652         uint32_t new_type_id =
653             pointer_type_inst->GetSingleWordInOperand(kTypePointerPointeeInIdx);
654         if (new_type_id != use->type_id()) {
655           use->SetResultType(new_type_id);
656           context()->AnalyzeUses(use);
657           UpdateUses(use, use);
658         } else {
659           context()->AnalyzeUses(use);
660         }
661       } break;
662       case SpvOpAccessChain: {
663         // Update the actual use.
664         context()->ForgetUses(use);
665         use->SetOperand(index, {new_ptr_inst->result_id()});
666 
667         // Convert the ids on the OpAccessChain to indices that can be used to
668         // get the specific member.
669         std::vector<uint32_t> access_chain;
670         for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
671           const analysis::Constant* index_const =
672               const_mgr->FindDeclaredConstant(use->GetSingleWordInOperand(i));
673           if (index_const) {
674             access_chain.push_back(index_const->GetU32());
675           } else {
676             // Variable index means the type is an type where every element
677             // is the same type.  Use element 0 to get the type.
678             access_chain.push_back(0);
679           }
680         }
681 
682         Instruction* pointer_type_inst =
683             get_def_use_mgr()->GetDef(new_ptr_inst->type_id());
684 
685         uint32_t new_pointee_type_id = GetMemberTypeId(
686             pointer_type_inst->GetSingleWordInOperand(kTypePointerPointeeInIdx),
687             access_chain);
688 
689         SpvStorageClass storage_class = static_cast<SpvStorageClass>(
690             pointer_type_inst->GetSingleWordInOperand(
691                 kTypePointerStorageClassInIdx));
692 
693         uint32_t new_pointer_type_id =
694             type_mgr->FindPointerToType(new_pointee_type_id, storage_class);
695 
696         if (new_pointer_type_id != use->type_id()) {
697           use->SetResultType(new_pointer_type_id);
698           context()->AnalyzeUses(use);
699           UpdateUses(use, use);
700         } else {
701           context()->AnalyzeUses(use);
702         }
703       } break;
704       case SpvOpCompositeExtract: {
705         // Update the actual use.
706         context()->ForgetUses(use);
707         use->SetOperand(index, {new_ptr_inst->result_id()});
708 
709         uint32_t new_type_id = new_ptr_inst->type_id();
710         std::vector<uint32_t> access_chain;
711         for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
712           access_chain.push_back(use->GetSingleWordInOperand(i));
713         }
714 
715         new_type_id = GetMemberTypeId(new_type_id, access_chain);
716 
717         if (new_type_id != use->type_id()) {
718           use->SetResultType(new_type_id);
719           context()->AnalyzeUses(use);
720           UpdateUses(use, use);
721         } else {
722           context()->AnalyzeUses(use);
723         }
724       } break;
725       case SpvOpStore:
726         // If the use is the pointer, then it is the single store to that
727         // variable.  We do not want to replace it.  Instead, it will become
728         // dead after all of the loads are removed, and ADCE will get rid of it.
729         //
730         // If the use is the object being stored, we will create a copy of the
731         // object turning it into the correct type. The copy is done by
732         // decomposing the object into the base type, which must be the same,
733         // and then rebuilding them.
734         if (index == 1) {
735           Instruction* target_pointer = def_use_mgr->GetDef(
736               use->GetSingleWordInOperand(kStorePointerInOperand));
737           Instruction* pointer_type =
738               def_use_mgr->GetDef(target_pointer->type_id());
739           uint32_t pointee_type_id =
740               pointer_type->GetSingleWordInOperand(kTypePointerPointeeInIdx);
741           uint32_t copy = GenerateCopy(original_ptr_inst, pointee_type_id, use);
742 
743           context()->ForgetUses(use);
744           use->SetInOperand(index, {copy});
745           context()->AnalyzeUses(use);
746         }
747         break;
748       case SpvOpDecorate:
749       // We treat an OpImageTexelPointer as a load.  The result type should
750       // always have the Image storage class, and should not need to be
751       // updated.
752       case SpvOpImageTexelPointer:
753         // Replace the actual use.
754         context()->ForgetUses(use);
755         use->SetOperand(index, {new_ptr_inst->result_id()});
756         context()->AnalyzeUses(use);
757         break;
758       default:
759         assert(false && "Don't know how to rewrite instruction");
760         break;
761     }
762   }
763 }
764 
GetMemberTypeId(uint32_t id,const std::vector<uint32_t> & access_chain) const765 uint32_t CopyPropagateArrays::GetMemberTypeId(
766     uint32_t id, const std::vector<uint32_t>& access_chain) const {
767   for (uint32_t element_index : access_chain) {
768     Instruction* type_inst = get_def_use_mgr()->GetDef(id);
769     switch (type_inst->opcode()) {
770       case SpvOpTypeArray:
771       case SpvOpTypeRuntimeArray:
772       case SpvOpTypeMatrix:
773       case SpvOpTypeVector:
774         id = type_inst->GetSingleWordInOperand(0);
775         break;
776       case SpvOpTypeStruct:
777         id = type_inst->GetSingleWordInOperand(element_index);
778         break;
779       default:
780         break;
781     }
782     assert(id != 0 &&
783            "Tried to extract from an object where it cannot be done.");
784   }
785   return id;
786 }
787 
GetMember(const std::vector<uint32_t> & access_chain)788 void CopyPropagateArrays::MemoryObject::GetMember(
789     const std::vector<uint32_t>& access_chain) {
790   access_chain_.insert(access_chain_.end(), access_chain.begin(),
791                        access_chain.end());
792 }
793 
GetNumberOfMembers()794 uint32_t CopyPropagateArrays::MemoryObject::GetNumberOfMembers() {
795   IRContext* context = variable_inst_->context();
796   analysis::TypeManager* type_mgr = context->get_type_mgr();
797 
798   const analysis::Type* type = type_mgr->GetType(variable_inst_->type_id());
799   type = type->AsPointer()->pointee_type();
800 
801   std::vector<uint32_t> access_indices = GetAccessIds();
802   type = type_mgr->GetMemberType(type, access_indices);
803 
804   if (const analysis::Struct* struct_type = type->AsStruct()) {
805     return static_cast<uint32_t>(struct_type->element_types().size());
806   } else if (const analysis::Array* array_type = type->AsArray()) {
807     const analysis::Constant* length_const =
808         context->get_constant_mgr()->FindDeclaredConstant(
809             array_type->LengthId());
810     assert(length_const->type()->AsInteger());
811     return length_const->GetU32();
812   } else if (const analysis::Vector* vector_type = type->AsVector()) {
813     return vector_type->element_count();
814   } else if (const analysis::Matrix* matrix_type = type->AsMatrix()) {
815     return matrix_type->element_count();
816   } else {
817     return 0;
818   }
819 }
820 
821 template <class iterator>
MemoryObject(Instruction * var_inst,iterator begin,iterator end)822 CopyPropagateArrays::MemoryObject::MemoryObject(Instruction* var_inst,
823                                                 iterator begin, iterator end)
824     : variable_inst_(var_inst), access_chain_(begin, end) {}
825 
GetAccessIds() const826 std::vector<uint32_t> CopyPropagateArrays::MemoryObject::GetAccessIds() const {
827   analysis::ConstantManager* const_mgr =
828       variable_inst_->context()->get_constant_mgr();
829 
830   std::vector<uint32_t> access_indices;
831   for (uint32_t id : AccessChain()) {
832     const analysis::Constant* element_index_const =
833         const_mgr->FindDeclaredConstant(id);
834     if (!element_index_const) {
835       access_indices.push_back(0);
836     } else {
837       access_indices.push_back(element_index_const->GetU32());
838     }
839   }
840   return access_indices;
841 }
842 
Contains(CopyPropagateArrays::MemoryObject * other)843 bool CopyPropagateArrays::MemoryObject::Contains(
844     CopyPropagateArrays::MemoryObject* other) {
845   if (this->GetVariable() != other->GetVariable()) {
846     return false;
847   }
848 
849   if (AccessChain().size() > other->AccessChain().size()) {
850     return false;
851   }
852 
853   for (uint32_t i = 0; i < AccessChain().size(); i++) {
854     if (AccessChain()[i] != other->AccessChain()[i]) {
855       return false;
856     }
857   }
858   return true;
859 }
860 
861 }  // namespace opt
862 }  // namespace spvtools
863