• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2018 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/opt/copy_prop_arrays.h"
16 
17 #include <utility>
18 
19 #include "source/opt/ir_builder.h"
20 
21 namespace spvtools {
22 namespace opt {
namespace {

// In-operand positions used by this pass when decoding instructions.

// OpLoad: the pointer being read.
constexpr uint32_t kLoadPointerInOperand = 0;
// OpStore: the pointer being written, followed by the object that is stored.
constexpr uint32_t kStorePointerInOperand = 0;
constexpr uint32_t kStoreObjectInOperand = 1;
// OpCompositeExtract: the composite a member is taken from.
constexpr uint32_t kCompositeExtractObjectInOperand = 0;
// OpTypePointer: the storage class, followed by the pointee type.
constexpr uint32_t kTypePointerStorageClassInIdx = 0;
constexpr uint32_t kTypePointerPointeeInIdx = 1;

}  // namespace
33 
Process()34 Pass::Status CopyPropagateArrays::Process() {
35   bool modified = false;
36   for (Function& function : *get_module()) {
37     BasicBlock* entry_bb = &*function.begin();
38 
39     for (auto var_inst = entry_bb->begin(); var_inst->opcode() == SpvOpVariable;
40          ++var_inst) {
41       if (!IsPointerToArrayType(var_inst->type_id())) {
42         continue;
43       }
44 
45       // Find the only store to the entire memory location, if it exists.
46       Instruction* store_inst = FindStoreInstruction(&*var_inst);
47 
48       if (!store_inst) {
49         continue;
50       }
51 
52       std::unique_ptr<MemoryObject> source_object =
53           FindSourceObjectIfPossible(&*var_inst, store_inst);
54 
55       if (source_object != nullptr) {
56         if (CanUpdateUses(&*var_inst, source_object->GetPointerTypeId(this))) {
57           modified = true;
58           PropagateObject(&*var_inst, source_object.get(), store_inst);
59         }
60       }
61     }
62   }
63   return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
64 }
65 
66 std::unique_ptr<CopyPropagateArrays::MemoryObject>
FindSourceObjectIfPossible(Instruction * var_inst,Instruction * store_inst)67 CopyPropagateArrays::FindSourceObjectIfPossible(Instruction* var_inst,
68                                                 Instruction* store_inst) {
69   assert(var_inst->opcode() == SpvOpVariable && "Expecting a variable.");
70 
71   // Check that the variable is a composite object where |store_inst|
72   // dominates all of its loads.
73   if (!store_inst) {
74     return nullptr;
75   }
76 
77   // Look at the loads to ensure they are dominated by the store.
78   if (!HasValidReferencesOnly(var_inst, store_inst)) {
79     return nullptr;
80   }
81 
82   // If so, look at the store to see if it is the copy of an object.
83   std::unique_ptr<MemoryObject> source = GetSourceObjectIfAny(
84       store_inst->GetSingleWordInOperand(kStoreObjectInOperand));
85 
86   if (!source) {
87     return nullptr;
88   }
89 
90   // Ensure that |source| does not change between the point at which it is
91   // loaded, and the position in which |var_inst| is loaded.
92   //
93   // For now we will go with the easy to implement approach, and check that the
94   // entire variable (not just the specific component) is never written to.
95 
96   if (!HasNoStores(source->GetVariable())) {
97     return nullptr;
98   }
99   return source;
100 }
101 
// Returns the unique OpStore whose pointer operand is |var_inst| itself.
// Returns |nullptr| if there is no such store or if there is more than one.
Instruction* CopyPropagateArrays::FindStoreInstruction(
    const Instruction* var_inst) const {
  Instruction* only_store = nullptr;
  get_def_use_mgr()->WhileEachUser(
      var_inst, [&only_store, var_inst](Instruction* user) {
        const bool writes_whole_variable =
            user->opcode() == SpvOpStore &&
            user->GetSingleWordInOperand(kStorePointerInOperand) ==
                var_inst->result_id();
        if (!writes_whole_variable) {
          return true;
        }

        if (only_store != nullptr) {
          // A second store was found: there is no unique store.  Stop the
          // traversal.
          only_store = nullptr;
          return false;
        }

        only_store = user;
        return true;
      });
  return only_store;
}
121 
PropagateObject(Instruction * var_inst,MemoryObject * source,Instruction * insertion_point)122 void CopyPropagateArrays::PropagateObject(Instruction* var_inst,
123                                           MemoryObject* source,
124                                           Instruction* insertion_point) {
125   assert(var_inst->opcode() == SpvOpVariable &&
126          "This function propagates variables.");
127 
128   Instruction* new_access_chain = BuildNewAccessChain(insertion_point, source);
129   context()->KillNamesAndDecorates(var_inst);
130   UpdateUses(var_inst, new_access_chain);
131 }
132 
BuildNewAccessChain(Instruction * insertion_point,CopyPropagateArrays::MemoryObject * source) const133 Instruction* CopyPropagateArrays::BuildNewAccessChain(
134     Instruction* insertion_point,
135     CopyPropagateArrays::MemoryObject* source) const {
136   InstructionBuilder builder(
137       context(), insertion_point,
138       IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
139 
140   if (source->AccessChain().size() == 0) {
141     return source->GetVariable();
142   }
143 
144   return builder.AddAccessChain(source->GetPointerTypeId(this),
145                                 source->GetVariable()->result_id(),
146                                 source->AccessChain());
147 }
148 
HasNoStores(Instruction * ptr_inst)149 bool CopyPropagateArrays::HasNoStores(Instruction* ptr_inst) {
150   return get_def_use_mgr()->WhileEachUser(ptr_inst, [this](Instruction* use) {
151     if (use->opcode() == SpvOpLoad) {
152       return true;
153     } else if (use->opcode() == SpvOpAccessChain) {
154       return HasNoStores(use);
155     } else if (use->IsDecoration() || use->opcode() == SpvOpName) {
156       return true;
157     } else if (use->opcode() == SpvOpStore) {
158       return false;
159     } else if (use->opcode() == SpvOpImageTexelPointer) {
160       return true;
161     }
162     // Some other instruction.  Be conservative.
163     return false;
164   });
165 }
166 
HasValidReferencesOnly(Instruction * ptr_inst,Instruction * store_inst)167 bool CopyPropagateArrays::HasValidReferencesOnly(Instruction* ptr_inst,
168                                                  Instruction* store_inst) {
169   BasicBlock* store_block = context()->get_instr_block(store_inst);
170   DominatorAnalysis* dominator_analysis =
171       context()->GetDominatorAnalysis(store_block->GetParent());
172 
173   return get_def_use_mgr()->WhileEachUser(
174       ptr_inst,
175       [this, store_inst, dominator_analysis, ptr_inst](Instruction* use) {
176         if (use->opcode() == SpvOpLoad ||
177             use->opcode() == SpvOpImageTexelPointer) {
178           // TODO: If there are many load in the same BB as |store_inst| the
179           // time to do the multiple traverses can add up.  Consider collecting
180           // those loads and doing a single traversal.
181           return dominator_analysis->Dominates(store_inst, use);
182         } else if (use->opcode() == SpvOpAccessChain) {
183           return HasValidReferencesOnly(use, store_inst);
184         } else if (use->IsDecoration() || use->opcode() == SpvOpName) {
185           return true;
186         } else if (use->opcode() == SpvOpStore) {
187           // If we are storing to part of the object it is not an candidate.
188           return ptr_inst->opcode() == SpvOpVariable &&
189                  store_inst->GetSingleWordInOperand(kStorePointerInOperand) ==
190                      ptr_inst->result_id();
191         }
192         // Some other instruction.  Be conservative.
193         return false;
194       });
195 }
196 
197 std::unique_ptr<CopyPropagateArrays::MemoryObject>
GetSourceObjectIfAny(uint32_t result)198 CopyPropagateArrays::GetSourceObjectIfAny(uint32_t result) {
199   Instruction* result_inst = context()->get_def_use_mgr()->GetDef(result);
200 
201   switch (result_inst->opcode()) {
202     case SpvOpLoad:
203       return BuildMemoryObjectFromLoad(result_inst);
204     case SpvOpCompositeExtract:
205       return BuildMemoryObjectFromExtract(result_inst);
206     case SpvOpCompositeConstruct:
207       return BuildMemoryObjectFromCompositeConstruct(result_inst);
208     case SpvOpCopyObject:
209       return GetSourceObjectIfAny(result_inst->GetSingleWordInOperand(0));
210     case SpvOpCompositeInsert:
211       return BuildMemoryObjectFromInsert(result_inst);
212     default:
213       return nullptr;
214   }
215 }
216 
217 std::unique_ptr<CopyPropagateArrays::MemoryObject>
BuildMemoryObjectFromLoad(Instruction * load_inst)218 CopyPropagateArrays::BuildMemoryObjectFromLoad(Instruction* load_inst) {
219   std::vector<uint32_t> components_in_reverse;
220   analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
221 
222   Instruction* current_inst = def_use_mgr->GetDef(
223       load_inst->GetSingleWordInOperand(kLoadPointerInOperand));
224 
225   // Build the access chain for the memory object by collecting the indices used
226   // in the OpAccessChain instructions.  If we find a variable index, then
227   // return |nullptr| because we cannot know for sure which memory location is
228   // used.
229   //
230   // It is built in reverse order because the different |OpAccessChain|
231   // instructions are visited in reverse order from which they are applied.
232   while (current_inst->opcode() == SpvOpAccessChain) {
233     for (uint32_t i = current_inst->NumInOperands() - 1; i >= 1; --i) {
234       uint32_t element_index_id = current_inst->GetSingleWordInOperand(i);
235       components_in_reverse.push_back(element_index_id);
236     }
237     current_inst = def_use_mgr->GetDef(current_inst->GetSingleWordInOperand(0));
238   }
239 
240   // If the address in the load is not constructed from an |OpVariable|
241   // instruction followed by a series of |OpAccessChain| instructions, then
242   // return |nullptr| because we cannot identify the owner or access chain
243   // exactly.
244   if (current_inst->opcode() != SpvOpVariable) {
245     return nullptr;
246   }
247 
248   // Build the memory object.  Use |rbegin| and |rend| to put the access chain
249   // back in the correct order.
250   return std::unique_ptr<CopyPropagateArrays::MemoryObject>(
251       new MemoryObject(current_inst, components_in_reverse.rbegin(),
252                        components_in_reverse.rend()));
253 }
254 
255 std::unique_ptr<CopyPropagateArrays::MemoryObject>
BuildMemoryObjectFromExtract(Instruction * extract_inst)256 CopyPropagateArrays::BuildMemoryObjectFromExtract(Instruction* extract_inst) {
257   assert(extract_inst->opcode() == SpvOpCompositeExtract &&
258          "Expecting an OpCompositeExtract instruction.");
259   analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
260 
261   std::unique_ptr<MemoryObject> result = GetSourceObjectIfAny(
262       extract_inst->GetSingleWordInOperand(kCompositeExtractObjectInOperand));
263 
264   if (result) {
265     analysis::Integer int_type(32, false);
266     const analysis::Type* uint32_type =
267         context()->get_type_mgr()->GetRegisteredType(&int_type);
268 
269     std::vector<uint32_t> components;
270     // Convert the indices in the extract instruction to a series of ids that
271     // can be used by the |OpAccessChain| instruction.
272     for (uint32_t i = 1; i < extract_inst->NumInOperands(); ++i) {
273       uint32_t index = extract_inst->GetSingleWordInOperand(i);
274       const analysis::Constant* index_const =
275           const_mgr->GetConstant(uint32_type, {index});
276       components.push_back(
277           const_mgr->GetDefiningInstruction(index_const)->result_id());
278     }
279     result->GetMember(components);
280     return result;
281   }
282   return nullptr;
283 }
284 
285 std::unique_ptr<CopyPropagateArrays::MemoryObject>
BuildMemoryObjectFromCompositeConstruct(Instruction * conststruct_inst)286 CopyPropagateArrays::BuildMemoryObjectFromCompositeConstruct(
287     Instruction* conststruct_inst) {
288   assert(conststruct_inst->opcode() == SpvOpCompositeConstruct &&
289          "Expecting an OpCompositeConstruct instruction.");
290 
291   // If every operand in the instruction are part of the same memory object, and
292   // are being combined in the same order, then the result is the same as the
293   // parent.
294 
295   std::unique_ptr<MemoryObject> memory_object =
296       GetSourceObjectIfAny(conststruct_inst->GetSingleWordInOperand(0));
297 
298   if (!memory_object) {
299     return nullptr;
300   }
301 
302   if (!memory_object->IsMember()) {
303     return nullptr;
304   }
305 
306   analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
307   const analysis::Constant* last_access =
308       const_mgr->FindDeclaredConstant(memory_object->AccessChain().back());
309   if (!last_access || !last_access->type()->AsInteger()) {
310     return nullptr;
311   }
312 
313   if (last_access->GetU32() != 0) {
314     return nullptr;
315   }
316 
317   memory_object->GetParent();
318 
319   if (memory_object->GetNumberOfMembers() !=
320       conststruct_inst->NumInOperands()) {
321     return nullptr;
322   }
323 
324   for (uint32_t i = 1; i < conststruct_inst->NumInOperands(); ++i) {
325     std::unique_ptr<MemoryObject> member_object =
326         GetSourceObjectIfAny(conststruct_inst->GetSingleWordInOperand(i));
327 
328     if (!member_object) {
329       return nullptr;
330     }
331 
332     if (!member_object->IsMember()) {
333       return nullptr;
334     }
335 
336     if (!memory_object->Contains(member_object.get())) {
337       return nullptr;
338     }
339 
340     last_access =
341         const_mgr->FindDeclaredConstant(member_object->AccessChain().back());
342     if (!last_access || !last_access->type()->AsInteger()) {
343       return nullptr;
344     }
345 
346     if (last_access->GetU32() != i) {
347       return nullptr;
348     }
349   }
350   return memory_object;
351 }
352 
353 std::unique_ptr<CopyPropagateArrays::MemoryObject>
BuildMemoryObjectFromInsert(Instruction * insert_inst)354 CopyPropagateArrays::BuildMemoryObjectFromInsert(Instruction* insert_inst) {
355   assert(insert_inst->opcode() == SpvOpCompositeInsert &&
356          "Expecting an OpCompositeInsert instruction.");
357 
358   analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
359   analysis::TypeManager* type_mgr = context()->get_type_mgr();
360   analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
361   const analysis::Type* result_type = type_mgr->GetType(insert_inst->type_id());
362 
363   uint32_t number_of_elements = 0;
364   if (const analysis::Struct* struct_type = result_type->AsStruct()) {
365     number_of_elements =
366         static_cast<uint32_t>(struct_type->element_types().size());
367   } else if (const analysis::Array* array_type = result_type->AsArray()) {
368     const analysis::Constant* length_const =
369         const_mgr->FindDeclaredConstant(array_type->LengthId());
370     number_of_elements = length_const->GetU32();
371   } else if (const analysis::Vector* vector_type = result_type->AsVector()) {
372     number_of_elements = vector_type->element_count();
373   } else if (const analysis::Matrix* matrix_type = result_type->AsMatrix()) {
374     number_of_elements = matrix_type->element_count();
375   }
376 
377   if (number_of_elements == 0) {
378     return nullptr;
379   }
380 
381   if (insert_inst->NumInOperands() != 3) {
382     return nullptr;
383   }
384 
385   if (insert_inst->GetSingleWordInOperand(2) != number_of_elements - 1) {
386     return nullptr;
387   }
388 
389   std::unique_ptr<MemoryObject> memory_object =
390       GetSourceObjectIfAny(insert_inst->GetSingleWordInOperand(0));
391 
392   if (!memory_object) {
393     return nullptr;
394   }
395 
396   if (!memory_object->IsMember()) {
397     return nullptr;
398   }
399 
400   const analysis::Constant* last_access =
401       const_mgr->FindDeclaredConstant(memory_object->AccessChain().back());
402   if (!last_access || !last_access->type()->AsInteger()) {
403     return nullptr;
404   }
405 
406   if (last_access->GetU32() != number_of_elements - 1) {
407     return nullptr;
408   }
409 
410   memory_object->GetParent();
411 
412   Instruction* current_insert =
413       def_use_mgr->GetDef(insert_inst->GetSingleWordInOperand(1));
414   for (uint32_t i = number_of_elements - 1; i > 0; --i) {
415     if (current_insert->opcode() != SpvOpCompositeInsert) {
416       return nullptr;
417     }
418 
419     if (current_insert->NumInOperands() != 3) {
420       return nullptr;
421     }
422 
423     if (current_insert->GetSingleWordInOperand(2) != i - 1) {
424       return nullptr;
425     }
426 
427     std::unique_ptr<MemoryObject> current_memory_object =
428         GetSourceObjectIfAny(current_insert->GetSingleWordInOperand(0));
429 
430     if (!current_memory_object) {
431       return nullptr;
432     }
433 
434     if (!current_memory_object->IsMember()) {
435       return nullptr;
436     }
437 
438     if (memory_object->AccessChain().size() + 1 !=
439         current_memory_object->AccessChain().size()) {
440       return nullptr;
441     }
442 
443     if (!memory_object->Contains(current_memory_object.get())) {
444       return nullptr;
445     }
446 
447     const analysis::Constant* current_last_access =
448         const_mgr->FindDeclaredConstant(
449             current_memory_object->AccessChain().back());
450     if (!current_last_access || !current_last_access->type()->AsInteger()) {
451       return nullptr;
452     }
453 
454     if (current_last_access->GetU32() != i - 1) {
455       return nullptr;
456     }
457     current_insert =
458         def_use_mgr->GetDef(current_insert->GetSingleWordInOperand(1));
459   }
460 
461   return memory_object;
462 }
463 
IsPointerToArrayType(uint32_t type_id)464 bool CopyPropagateArrays::IsPointerToArrayType(uint32_t type_id) {
465   analysis::TypeManager* type_mgr = context()->get_type_mgr();
466   analysis::Pointer* pointer_type = type_mgr->GetType(type_id)->AsPointer();
467   if (pointer_type) {
468     return pointer_type->pointee_type()->kind() == analysis::Type::kArray ||
469            pointer_type->pointee_type()->kind() == analysis::Type::kImage;
470   }
471   return false;
472 }
473 
CanUpdateUses(Instruction * original_ptr_inst,uint32_t type_id)474 bool CopyPropagateArrays::CanUpdateUses(Instruction* original_ptr_inst,
475                                         uint32_t type_id) {
476   analysis::TypeManager* type_mgr = context()->get_type_mgr();
477   analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
478   analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
479 
480   analysis::Type* type = type_mgr->GetType(type_id);
481   if (type->AsRuntimeArray()) {
482     return false;
483   }
484 
485   if (!type->AsStruct() && !type->AsArray() && !type->AsPointer()) {
486     // If the type is not an aggregate, then the desired type must be the
487     // same as the current type.  No work to do, and we can do that.
488     return true;
489   }
490 
491   return def_use_mgr->WhileEachUse(original_ptr_inst, [this, type_mgr,
492                                                        const_mgr,
493                                                        type](Instruction* use,
494                                                              uint32_t) {
495     switch (use->opcode()) {
496       case SpvOpLoad: {
497         analysis::Pointer* pointer_type = type->AsPointer();
498         uint32_t new_type_id = type_mgr->GetId(pointer_type->pointee_type());
499 
500         if (new_type_id != use->type_id()) {
501           return CanUpdateUses(use, new_type_id);
502         }
503         return true;
504       }
505       case SpvOpAccessChain: {
506         analysis::Pointer* pointer_type = type->AsPointer();
507         const analysis::Type* pointee_type = pointer_type->pointee_type();
508 
509         std::vector<uint32_t> access_chain;
510         for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
511           const analysis::Constant* index_const =
512               const_mgr->FindDeclaredConstant(use->GetSingleWordInOperand(i));
513           if (index_const) {
514             access_chain.push_back(index_const->GetU32());
515           } else {
516             // Variable index means the type is a type where every element
517             // is the same type.  Use element 0 to get the type.
518             access_chain.push_back(0);
519           }
520         }
521 
522         const analysis::Type* new_pointee_type =
523             type_mgr->GetMemberType(pointee_type, access_chain);
524         analysis::Pointer pointerTy(new_pointee_type,
525                                     pointer_type->storage_class());
526         uint32_t new_pointer_type_id =
527             context()->get_type_mgr()->GetTypeInstruction(&pointerTy);
528         if (new_pointer_type_id == 0) {
529           return false;
530         }
531 
532         if (new_pointer_type_id != use->type_id()) {
533           return CanUpdateUses(use, new_pointer_type_id);
534         }
535         return true;
536       }
537       case SpvOpCompositeExtract: {
538         std::vector<uint32_t> access_chain;
539         for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
540           access_chain.push_back(use->GetSingleWordInOperand(i));
541         }
542 
543         const analysis::Type* new_type =
544             type_mgr->GetMemberType(type, access_chain);
545         uint32_t new_type_id = type_mgr->GetTypeInstruction(new_type);
546         if (new_type_id == 0) {
547           return false;
548         }
549 
550         if (new_type_id != use->type_id()) {
551           return CanUpdateUses(use, new_type_id);
552         }
553         return true;
554       }
555       case SpvOpStore:
556         // If needed, we can create an element-by-element copy to change the
557         // type of the value being stored.  This way we can always handled
558         // stores.
559         return true;
560       case SpvOpImageTexelPointer:
561       case SpvOpName:
562         return true;
563       default:
564         return use->IsDecoration();
565     }
566   });
567 }
UpdateUses(Instruction * original_ptr_inst,Instruction * new_ptr_inst)568 void CopyPropagateArrays::UpdateUses(Instruction* original_ptr_inst,
569                                      Instruction* new_ptr_inst) {
570   analysis::TypeManager* type_mgr = context()->get_type_mgr();
571   analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
572   analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
573 
574   std::vector<std::pair<Instruction*, uint32_t> > uses;
575   def_use_mgr->ForEachUse(original_ptr_inst,
576                           [&uses](Instruction* use, uint32_t index) {
577                             uses.push_back({use, index});
578                           });
579 
580   for (auto pair : uses) {
581     Instruction* use = pair.first;
582     uint32_t index = pair.second;
583     switch (use->opcode()) {
584       case SpvOpLoad: {
585         // Replace the actual use.
586         context()->ForgetUses(use);
587         use->SetOperand(index, {new_ptr_inst->result_id()});
588 
589         // Update the type.
590         Instruction* pointer_type_inst =
591             def_use_mgr->GetDef(new_ptr_inst->type_id());
592         uint32_t new_type_id =
593             pointer_type_inst->GetSingleWordInOperand(kTypePointerPointeeInIdx);
594         if (new_type_id != use->type_id()) {
595           use->SetResultType(new_type_id);
596           context()->AnalyzeUses(use);
597           UpdateUses(use, use);
598         } else {
599           context()->AnalyzeUses(use);
600         }
601       } break;
602       case SpvOpAccessChain: {
603         // Update the actual use.
604         context()->ForgetUses(use);
605         use->SetOperand(index, {new_ptr_inst->result_id()});
606 
607         // Convert the ids on the OpAccessChain to indices that can be used to
608         // get the specific member.
609         std::vector<uint32_t> access_chain;
610         for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
611           const analysis::Constant* index_const =
612               const_mgr->FindDeclaredConstant(use->GetSingleWordInOperand(i));
613           if (index_const) {
614             access_chain.push_back(index_const->GetU32());
615           } else {
616             // Variable index means the type is an type where every element
617             // is the same type.  Use element 0 to get the type.
618             access_chain.push_back(0);
619           }
620         }
621 
622         Instruction* pointer_type_inst =
623             get_def_use_mgr()->GetDef(new_ptr_inst->type_id());
624 
625         uint32_t new_pointee_type_id = GetMemberTypeId(
626             pointer_type_inst->GetSingleWordInOperand(kTypePointerPointeeInIdx),
627             access_chain);
628 
629         SpvStorageClass storage_class = static_cast<SpvStorageClass>(
630             pointer_type_inst->GetSingleWordInOperand(
631                 kTypePointerStorageClassInIdx));
632 
633         uint32_t new_pointer_type_id =
634             type_mgr->FindPointerToType(new_pointee_type_id, storage_class);
635 
636         if (new_pointer_type_id != use->type_id()) {
637           use->SetResultType(new_pointer_type_id);
638           context()->AnalyzeUses(use);
639           UpdateUses(use, use);
640         } else {
641           context()->AnalyzeUses(use);
642         }
643       } break;
644       case SpvOpCompositeExtract: {
645         // Update the actual use.
646         context()->ForgetUses(use);
647         use->SetOperand(index, {new_ptr_inst->result_id()});
648 
649         uint32_t new_type_id = new_ptr_inst->type_id();
650         std::vector<uint32_t> access_chain;
651         for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
652           access_chain.push_back(use->GetSingleWordInOperand(i));
653         }
654 
655         new_type_id = GetMemberTypeId(new_type_id, access_chain);
656 
657         if (new_type_id != use->type_id()) {
658           use->SetResultType(new_type_id);
659           context()->AnalyzeUses(use);
660           UpdateUses(use, use);
661         } else {
662           context()->AnalyzeUses(use);
663         }
664       } break;
665       case SpvOpStore:
666         // If the use is the pointer, then it is the single store to that
667         // variable.  We do not want to replace it.  Instead, it will become
668         // dead after all of the loads are removed, and ADCE will get rid of it.
669         //
670         // If the use is the object being stored, we will create a copy of the
671         // object turning it into the correct type. The copy is done by
672         // decomposing the object into the base type, which must be the same,
673         // and then rebuilding them.
674         if (index == 1) {
675           Instruction* target_pointer = def_use_mgr->GetDef(
676               use->GetSingleWordInOperand(kStorePointerInOperand));
677           Instruction* pointer_type =
678               def_use_mgr->GetDef(target_pointer->type_id());
679           uint32_t pointee_type_id =
680               pointer_type->GetSingleWordInOperand(kTypePointerPointeeInIdx);
681           uint32_t copy = GenerateCopy(original_ptr_inst, pointee_type_id, use);
682 
683           context()->ForgetUses(use);
684           use->SetInOperand(index, {copy});
685           context()->AnalyzeUses(use);
686         }
687         break;
688       case SpvOpImageTexelPointer:
689         // We treat an OpImageTexelPointer as a load.  The result type should
690         // always have the Image storage class, and should not need to be
691         // updated.
692 
693         // Replace the actual use.
694         context()->ForgetUses(use);
695         use->SetOperand(index, {new_ptr_inst->result_id()});
696         context()->AnalyzeUses(use);
697         break;
698       default:
699         assert(false && "Don't know how to rewrite instruction");
700         break;
701     }
702   }
703 }
704 
GetMemberTypeId(uint32_t id,const std::vector<uint32_t> & access_chain) const705 uint32_t CopyPropagateArrays::GetMemberTypeId(
706     uint32_t id, const std::vector<uint32_t>& access_chain) const {
707   for (uint32_t element_index : access_chain) {
708     Instruction* type_inst = get_def_use_mgr()->GetDef(id);
709     switch (type_inst->opcode()) {
710       case SpvOpTypeArray:
711       case SpvOpTypeRuntimeArray:
712       case SpvOpTypeMatrix:
713       case SpvOpTypeVector:
714         id = type_inst->GetSingleWordInOperand(0);
715         break;
716       case SpvOpTypeStruct:
717         id = type_inst->GetSingleWordInOperand(element_index);
718         break;
719       default:
720         break;
721     }
722     assert(id != 0 &&
723            "Tried to extract from an object where it cannot be done.");
724   }
725   return id;
726 }
727 
GetMember(const std::vector<uint32_t> & access_chain)728 void CopyPropagateArrays::MemoryObject::GetMember(
729     const std::vector<uint32_t>& access_chain) {
730   access_chain_.insert(access_chain_.end(), access_chain.begin(),
731                        access_chain.end());
732 }
733 
GetNumberOfMembers()734 uint32_t CopyPropagateArrays::MemoryObject::GetNumberOfMembers() {
735   IRContext* context = variable_inst_->context();
736   analysis::TypeManager* type_mgr = context->get_type_mgr();
737 
738   const analysis::Type* type = type_mgr->GetType(variable_inst_->type_id());
739   type = type->AsPointer()->pointee_type();
740 
741   std::vector<uint32_t> access_indices = GetAccessIds();
742   type = type_mgr->GetMemberType(type, access_indices);
743 
744   if (const analysis::Struct* struct_type = type->AsStruct()) {
745     return static_cast<uint32_t>(struct_type->element_types().size());
746   } else if (const analysis::Array* array_type = type->AsArray()) {
747     const analysis::Constant* length_const =
748         context->get_constant_mgr()->FindDeclaredConstant(
749             array_type->LengthId());
750     assert(length_const->type()->AsInteger());
751     return length_const->GetU32();
752   } else if (const analysis::Vector* vector_type = type->AsVector()) {
753     return vector_type->element_count();
754   } else if (const analysis::Matrix* matrix_type = type->AsMatrix()) {
755     return matrix_type->element_count();
756   } else {
757     return 0;
758   }
759 }
760 
761 template <class iterator>
MemoryObject(Instruction * var_inst,iterator begin,iterator end)762 CopyPropagateArrays::MemoryObject::MemoryObject(Instruction* var_inst,
763                                                 iterator begin, iterator end)
764     : variable_inst_(var_inst), access_chain_(begin, end) {}
765 
GetAccessIds() const766 std::vector<uint32_t> CopyPropagateArrays::MemoryObject::GetAccessIds() const {
767   analysis::ConstantManager* const_mgr =
768       variable_inst_->context()->get_constant_mgr();
769 
770   std::vector<uint32_t> access_indices;
771   for (uint32_t id : AccessChain()) {
772     const analysis::Constant* element_index_const =
773         const_mgr->FindDeclaredConstant(id);
774     if (!element_index_const) {
775       access_indices.push_back(0);
776     } else {
777       access_indices.push_back(element_index_const->GetU32());
778     }
779   }
780   return access_indices;
781 }
782 
Contains(CopyPropagateArrays::MemoryObject * other)783 bool CopyPropagateArrays::MemoryObject::Contains(
784     CopyPropagateArrays::MemoryObject* other) {
785   if (this->GetVariable() != other->GetVariable()) {
786     return false;
787   }
788 
789   if (AccessChain().size() > other->AccessChain().size()) {
790     return false;
791   }
792 
793   for (uint32_t i = 0; i < AccessChain().size(); i++) {
794     if (AccessChain()[i] != other->AccessChain()[i]) {
795       return false;
796     }
797   }
798   return true;
799 }
800 
801 }  // namespace opt
802 }  // namespace spvtools
803