• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2018 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/opt/copy_prop_arrays.h"
16 
17 #include <utility>
18 
19 #include "source/opt/ir_builder.h"
20 
21 namespace spvtools {
22 namespace opt {
23 namespace {
24 
// In-operand positions (operands counted after any result type / result id)
// of the fields this pass reads from the instructions it inspects.
constexpr uint32_t kLoadPointerInOperand = 0;             // OpLoad: Pointer
constexpr uint32_t kStorePointerInOperand = 0;            // OpStore: Pointer
constexpr uint32_t kStoreObjectInOperand = 1;             // OpStore: Object
constexpr uint32_t kCompositeExtractObjectInOperand = 0;  // OpCompositeExtract: Composite
constexpr uint32_t kTypePointerStorageClassInIdx = 0;     // OpTypePointer: Storage Class
constexpr uint32_t kTypePointerPointeeInIdx = 1;          // OpTypePointer: Type
31 
IsOpenCL100DebugDeclareOrValue(Instruction * di)32 bool IsOpenCL100DebugDeclareOrValue(Instruction* di) {
33   auto dbg_opcode = di->GetOpenCL100DebugOpcode();
34   return dbg_opcode == OpenCLDebugInfo100DebugDeclare ||
35          dbg_opcode == OpenCLDebugInfo100DebugValue;
36 }
37 
38 }  // namespace
39 
Process()40 Pass::Status CopyPropagateArrays::Process() {
41   bool modified = false;
42   for (Function& function : *get_module()) {
43     BasicBlock* entry_bb = &*function.begin();
44 
45     for (auto var_inst = entry_bb->begin(); var_inst->opcode() == SpvOpVariable;
46          ++var_inst) {
47       if (!IsPointerToArrayType(var_inst->type_id())) {
48         continue;
49       }
50 
51       // Find the only store to the entire memory location, if it exists.
52       Instruction* store_inst = FindStoreInstruction(&*var_inst);
53 
54       if (!store_inst) {
55         continue;
56       }
57 
58       std::unique_ptr<MemoryObject> source_object =
59           FindSourceObjectIfPossible(&*var_inst, store_inst);
60 
61       if (source_object != nullptr) {
62         if (CanUpdateUses(&*var_inst, source_object->GetPointerTypeId(this))) {
63           modified = true;
64           PropagateObject(&*var_inst, source_object.get(), store_inst);
65         }
66       }
67     }
68   }
69   return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
70 }
71 
72 std::unique_ptr<CopyPropagateArrays::MemoryObject>
FindSourceObjectIfPossible(Instruction * var_inst,Instruction * store_inst)73 CopyPropagateArrays::FindSourceObjectIfPossible(Instruction* var_inst,
74                                                 Instruction* store_inst) {
75   assert(var_inst->opcode() == SpvOpVariable && "Expecting a variable.");
76 
77   // Check that the variable is a composite object where |store_inst|
78   // dominates all of its loads.
79   if (!store_inst) {
80     return nullptr;
81   }
82 
83   // Look at the loads to ensure they are dominated by the store.
84   if (!HasValidReferencesOnly(var_inst, store_inst)) {
85     return nullptr;
86   }
87 
88   // If so, look at the store to see if it is the copy of an object.
89   std::unique_ptr<MemoryObject> source = GetSourceObjectIfAny(
90       store_inst->GetSingleWordInOperand(kStoreObjectInOperand));
91 
92   if (!source) {
93     return nullptr;
94   }
95 
96   // Ensure that |source| does not change between the point at which it is
97   // loaded, and the position in which |var_inst| is loaded.
98   //
99   // For now we will go with the easy to implement approach, and check that the
100   // entire variable (not just the specific component) is never written to.
101 
102   if (!HasNoStores(source->GetVariable())) {
103     return nullptr;
104   }
105   return source;
106 }
107 
// Returns the single OpStore whose pointer operand is |var_inst| itself, or
// nullptr when there is no such store or there is more than one.  Stores
// through access chains are not users of |var_inst| and are not considered.
Instruction* CopyPropagateArrays::FindStoreInstruction(
    const Instruction* var_inst) const {
  const uint32_t var_id = var_inst->result_id();
  Instruction* unique_store = nullptr;

  const bool at_most_one_store = get_def_use_mgr()->WhileEachUser(
      var_inst, [&unique_store, var_id](Instruction* user) {
        const bool stores_whole_var =
            user->opcode() == SpvOpStore &&
            user->GetSingleWordInOperand(kStorePointerInOperand) == var_id;
        if (!stores_whole_var) {
          return true;
        }
        if (unique_store != nullptr) {
          // A second store: stop scanning, the answer is "none".
          return false;
        }
        unique_store = user;
        return true;
      });

  return at_most_one_store ? unique_store : nullptr;
}
127 
PropagateObject(Instruction * var_inst,MemoryObject * source,Instruction * insertion_point)128 void CopyPropagateArrays::PropagateObject(Instruction* var_inst,
129                                           MemoryObject* source,
130                                           Instruction* insertion_point) {
131   assert(var_inst->opcode() == SpvOpVariable &&
132          "This function propagates variables.");
133 
134   Instruction* new_access_chain = BuildNewAccessChain(insertion_point, source);
135   context()->KillNamesAndDecorates(var_inst);
136   UpdateUses(var_inst, new_access_chain);
137 }
138 
BuildNewAccessChain(Instruction * insertion_point,CopyPropagateArrays::MemoryObject * source) const139 Instruction* CopyPropagateArrays::BuildNewAccessChain(
140     Instruction* insertion_point,
141     CopyPropagateArrays::MemoryObject* source) const {
142   InstructionBuilder builder(
143       context(), insertion_point,
144       IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
145 
146   if (source->AccessChain().size() == 0) {
147     return source->GetVariable();
148   }
149 
150   return builder.AddAccessChain(source->GetPointerTypeId(this),
151                                 source->GetVariable()->result_id(),
152                                 source->AccessChain());
153 }
154 
HasNoStores(Instruction * ptr_inst)155 bool CopyPropagateArrays::HasNoStores(Instruction* ptr_inst) {
156   return get_def_use_mgr()->WhileEachUser(ptr_inst, [this](Instruction* use) {
157     if (use->opcode() == SpvOpLoad) {
158       return true;
159     } else if (use->opcode() == SpvOpAccessChain) {
160       return HasNoStores(use);
161     } else if (use->IsDecoration() || use->opcode() == SpvOpName) {
162       return true;
163     } else if (use->opcode() == SpvOpStore) {
164       return false;
165     } else if (use->opcode() == SpvOpImageTexelPointer) {
166       return true;
167     }
168     // Some other instruction.  Be conservative.
169     return false;
170   });
171 }
172 
HasValidReferencesOnly(Instruction * ptr_inst,Instruction * store_inst)173 bool CopyPropagateArrays::HasValidReferencesOnly(Instruction* ptr_inst,
174                                                  Instruction* store_inst) {
175   BasicBlock* store_block = context()->get_instr_block(store_inst);
176   DominatorAnalysis* dominator_analysis =
177       context()->GetDominatorAnalysis(store_block->GetParent());
178 
179   return get_def_use_mgr()->WhileEachUser(
180       ptr_inst,
181       [this, store_inst, dominator_analysis, ptr_inst](Instruction* use) {
182         if (use->opcode() == SpvOpLoad ||
183             use->opcode() == SpvOpImageTexelPointer) {
184           // TODO: If there are many load in the same BB as |store_inst| the
185           // time to do the multiple traverses can add up.  Consider collecting
186           // those loads and doing a single traversal.
187           return dominator_analysis->Dominates(store_inst, use);
188         } else if (use->opcode() == SpvOpAccessChain) {
189           return HasValidReferencesOnly(use, store_inst);
190         } else if (use->IsDecoration() || use->opcode() == SpvOpName) {
191           return true;
192         } else if (use->opcode() == SpvOpStore) {
193           // If we are storing to part of the object it is not an candidate.
194           return ptr_inst->opcode() == SpvOpVariable &&
195                  store_inst->GetSingleWordInOperand(kStorePointerInOperand) ==
196                      ptr_inst->result_id();
197         } else if (IsOpenCL100DebugDeclareOrValue(use)) {
198           return true;
199         }
200         // Some other instruction.  Be conservative.
201         return false;
202       });
203 }
204 
205 std::unique_ptr<CopyPropagateArrays::MemoryObject>
GetSourceObjectIfAny(uint32_t result)206 CopyPropagateArrays::GetSourceObjectIfAny(uint32_t result) {
207   Instruction* result_inst = context()->get_def_use_mgr()->GetDef(result);
208 
209   switch (result_inst->opcode()) {
210     case SpvOpLoad:
211       return BuildMemoryObjectFromLoad(result_inst);
212     case SpvOpCompositeExtract:
213       return BuildMemoryObjectFromExtract(result_inst);
214     case SpvOpCompositeConstruct:
215       return BuildMemoryObjectFromCompositeConstruct(result_inst);
216     case SpvOpCopyObject:
217       return GetSourceObjectIfAny(result_inst->GetSingleWordInOperand(0));
218     case SpvOpCompositeInsert:
219       return BuildMemoryObjectFromInsert(result_inst);
220     default:
221       return nullptr;
222   }
223 }
224 
225 std::unique_ptr<CopyPropagateArrays::MemoryObject>
BuildMemoryObjectFromLoad(Instruction * load_inst)226 CopyPropagateArrays::BuildMemoryObjectFromLoad(Instruction* load_inst) {
227   std::vector<uint32_t> components_in_reverse;
228   analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
229 
230   Instruction* current_inst = def_use_mgr->GetDef(
231       load_inst->GetSingleWordInOperand(kLoadPointerInOperand));
232 
233   // Build the access chain for the memory object by collecting the indices used
234   // in the OpAccessChain instructions.  If we find a variable index, then
235   // return |nullptr| because we cannot know for sure which memory location is
236   // used.
237   //
238   // It is built in reverse order because the different |OpAccessChain|
239   // instructions are visited in reverse order from which they are applied.
240   while (current_inst->opcode() == SpvOpAccessChain) {
241     for (uint32_t i = current_inst->NumInOperands() - 1; i >= 1; --i) {
242       uint32_t element_index_id = current_inst->GetSingleWordInOperand(i);
243       components_in_reverse.push_back(element_index_id);
244     }
245     current_inst = def_use_mgr->GetDef(current_inst->GetSingleWordInOperand(0));
246   }
247 
248   // If the address in the load is not constructed from an |OpVariable|
249   // instruction followed by a series of |OpAccessChain| instructions, then
250   // return |nullptr| because we cannot identify the owner or access chain
251   // exactly.
252   if (current_inst->opcode() != SpvOpVariable) {
253     return nullptr;
254   }
255 
256   // Build the memory object.  Use |rbegin| and |rend| to put the access chain
257   // back in the correct order.
258   return std::unique_ptr<CopyPropagateArrays::MemoryObject>(
259       new MemoryObject(current_inst, components_in_reverse.rbegin(),
260                        components_in_reverse.rend()));
261 }
262 
263 std::unique_ptr<CopyPropagateArrays::MemoryObject>
BuildMemoryObjectFromExtract(Instruction * extract_inst)264 CopyPropagateArrays::BuildMemoryObjectFromExtract(Instruction* extract_inst) {
265   assert(extract_inst->opcode() == SpvOpCompositeExtract &&
266          "Expecting an OpCompositeExtract instruction.");
267   analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
268 
269   std::unique_ptr<MemoryObject> result = GetSourceObjectIfAny(
270       extract_inst->GetSingleWordInOperand(kCompositeExtractObjectInOperand));
271 
272   if (result) {
273     analysis::Integer int_type(32, false);
274     const analysis::Type* uint32_type =
275         context()->get_type_mgr()->GetRegisteredType(&int_type);
276 
277     std::vector<uint32_t> components;
278     // Convert the indices in the extract instruction to a series of ids that
279     // can be used by the |OpAccessChain| instruction.
280     for (uint32_t i = 1; i < extract_inst->NumInOperands(); ++i) {
281       uint32_t index = extract_inst->GetSingleWordInOperand(i);
282       const analysis::Constant* index_const =
283           const_mgr->GetConstant(uint32_type, {index});
284       components.push_back(
285           const_mgr->GetDefiningInstruction(index_const)->result_id());
286     }
287     result->GetMember(components);
288     return result;
289   }
290   return nullptr;
291 }
292 
293 std::unique_ptr<CopyPropagateArrays::MemoryObject>
BuildMemoryObjectFromCompositeConstruct(Instruction * conststruct_inst)294 CopyPropagateArrays::BuildMemoryObjectFromCompositeConstruct(
295     Instruction* conststruct_inst) {
296   assert(conststruct_inst->opcode() == SpvOpCompositeConstruct &&
297          "Expecting an OpCompositeConstruct instruction.");
298 
299   // If every operand in the instruction are part of the same memory object, and
300   // are being combined in the same order, then the result is the same as the
301   // parent.
302 
303   std::unique_ptr<MemoryObject> memory_object =
304       GetSourceObjectIfAny(conststruct_inst->GetSingleWordInOperand(0));
305 
306   if (!memory_object) {
307     return nullptr;
308   }
309 
310   if (!memory_object->IsMember()) {
311     return nullptr;
312   }
313 
314   analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
315   const analysis::Constant* last_access =
316       const_mgr->FindDeclaredConstant(memory_object->AccessChain().back());
317   if (!last_access || !last_access->type()->AsInteger()) {
318     return nullptr;
319   }
320 
321   if (last_access->GetU32() != 0) {
322     return nullptr;
323   }
324 
325   memory_object->GetParent();
326 
327   if (memory_object->GetNumberOfMembers() !=
328       conststruct_inst->NumInOperands()) {
329     return nullptr;
330   }
331 
332   for (uint32_t i = 1; i < conststruct_inst->NumInOperands(); ++i) {
333     std::unique_ptr<MemoryObject> member_object =
334         GetSourceObjectIfAny(conststruct_inst->GetSingleWordInOperand(i));
335 
336     if (!member_object) {
337       return nullptr;
338     }
339 
340     if (!member_object->IsMember()) {
341       return nullptr;
342     }
343 
344     if (!memory_object->Contains(member_object.get())) {
345       return nullptr;
346     }
347 
348     last_access =
349         const_mgr->FindDeclaredConstant(member_object->AccessChain().back());
350     if (!last_access || !last_access->type()->AsInteger()) {
351       return nullptr;
352     }
353 
354     if (last_access->GetU32() != i) {
355       return nullptr;
356     }
357   }
358   return memory_object;
359 }
360 
361 std::unique_ptr<CopyPropagateArrays::MemoryObject>
BuildMemoryObjectFromInsert(Instruction * insert_inst)362 CopyPropagateArrays::BuildMemoryObjectFromInsert(Instruction* insert_inst) {
363   assert(insert_inst->opcode() == SpvOpCompositeInsert &&
364          "Expecting an OpCompositeInsert instruction.");
365 
366   analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
367   analysis::TypeManager* type_mgr = context()->get_type_mgr();
368   analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
369   const analysis::Type* result_type = type_mgr->GetType(insert_inst->type_id());
370 
371   uint32_t number_of_elements = 0;
372   if (const analysis::Struct* struct_type = result_type->AsStruct()) {
373     number_of_elements =
374         static_cast<uint32_t>(struct_type->element_types().size());
375   } else if (const analysis::Array* array_type = result_type->AsArray()) {
376     const analysis::Constant* length_const =
377         const_mgr->FindDeclaredConstant(array_type->LengthId());
378     number_of_elements = length_const->GetU32();
379   } else if (const analysis::Vector* vector_type = result_type->AsVector()) {
380     number_of_elements = vector_type->element_count();
381   } else if (const analysis::Matrix* matrix_type = result_type->AsMatrix()) {
382     number_of_elements = matrix_type->element_count();
383   }
384 
385   if (number_of_elements == 0) {
386     return nullptr;
387   }
388 
389   if (insert_inst->NumInOperands() != 3) {
390     return nullptr;
391   }
392 
393   if (insert_inst->GetSingleWordInOperand(2) != number_of_elements - 1) {
394     return nullptr;
395   }
396 
397   std::unique_ptr<MemoryObject> memory_object =
398       GetSourceObjectIfAny(insert_inst->GetSingleWordInOperand(0));
399 
400   if (!memory_object) {
401     return nullptr;
402   }
403 
404   if (!memory_object->IsMember()) {
405     return nullptr;
406   }
407 
408   const analysis::Constant* last_access =
409       const_mgr->FindDeclaredConstant(memory_object->AccessChain().back());
410   if (!last_access || !last_access->type()->AsInteger()) {
411     return nullptr;
412   }
413 
414   if (last_access->GetU32() != number_of_elements - 1) {
415     return nullptr;
416   }
417 
418   memory_object->GetParent();
419 
420   Instruction* current_insert =
421       def_use_mgr->GetDef(insert_inst->GetSingleWordInOperand(1));
422   for (uint32_t i = number_of_elements - 1; i > 0; --i) {
423     if (current_insert->opcode() != SpvOpCompositeInsert) {
424       return nullptr;
425     }
426 
427     if (current_insert->NumInOperands() != 3) {
428       return nullptr;
429     }
430 
431     if (current_insert->GetSingleWordInOperand(2) != i - 1) {
432       return nullptr;
433     }
434 
435     std::unique_ptr<MemoryObject> current_memory_object =
436         GetSourceObjectIfAny(current_insert->GetSingleWordInOperand(0));
437 
438     if (!current_memory_object) {
439       return nullptr;
440     }
441 
442     if (!current_memory_object->IsMember()) {
443       return nullptr;
444     }
445 
446     if (memory_object->AccessChain().size() + 1 !=
447         current_memory_object->AccessChain().size()) {
448       return nullptr;
449     }
450 
451     if (!memory_object->Contains(current_memory_object.get())) {
452       return nullptr;
453     }
454 
455     const analysis::Constant* current_last_access =
456         const_mgr->FindDeclaredConstant(
457             current_memory_object->AccessChain().back());
458     if (!current_last_access || !current_last_access->type()->AsInteger()) {
459       return nullptr;
460     }
461 
462     if (current_last_access->GetU32() != i - 1) {
463       return nullptr;
464     }
465     current_insert =
466         def_use_mgr->GetDef(current_insert->GetSingleWordInOperand(1));
467   }
468 
469   return memory_object;
470 }
471 
IsPointerToArrayType(uint32_t type_id)472 bool CopyPropagateArrays::IsPointerToArrayType(uint32_t type_id) {
473   analysis::TypeManager* type_mgr = context()->get_type_mgr();
474   analysis::Pointer* pointer_type = type_mgr->GetType(type_id)->AsPointer();
475   if (pointer_type) {
476     return pointer_type->pointee_type()->kind() == analysis::Type::kArray ||
477            pointer_type->pointee_type()->kind() == analysis::Type::kImage;
478   }
479   return false;
480 }
481 
CanUpdateUses(Instruction * original_ptr_inst,uint32_t type_id)482 bool CopyPropagateArrays::CanUpdateUses(Instruction* original_ptr_inst,
483                                         uint32_t type_id) {
484   analysis::TypeManager* type_mgr = context()->get_type_mgr();
485   analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
486   analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
487 
488   analysis::Type* type = type_mgr->GetType(type_id);
489   if (type->AsRuntimeArray()) {
490     return false;
491   }
492 
493   if (!type->AsStruct() && !type->AsArray() && !type->AsPointer()) {
494     // If the type is not an aggregate, then the desired type must be the
495     // same as the current type.  No work to do, and we can do that.
496     return true;
497   }
498 
499   return def_use_mgr->WhileEachUse(original_ptr_inst, [this, type_mgr,
500                                                        const_mgr,
501                                                        type](Instruction* use,
502                                                              uint32_t) {
503     if (IsOpenCL100DebugDeclareOrValue(use)) return true;
504 
505     switch (use->opcode()) {
506       case SpvOpLoad: {
507         analysis::Pointer* pointer_type = type->AsPointer();
508         uint32_t new_type_id = type_mgr->GetId(pointer_type->pointee_type());
509 
510         if (new_type_id != use->type_id()) {
511           return CanUpdateUses(use, new_type_id);
512         }
513         return true;
514       }
515       case SpvOpAccessChain: {
516         analysis::Pointer* pointer_type = type->AsPointer();
517         const analysis::Type* pointee_type = pointer_type->pointee_type();
518 
519         std::vector<uint32_t> access_chain;
520         for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
521           const analysis::Constant* index_const =
522               const_mgr->FindDeclaredConstant(use->GetSingleWordInOperand(i));
523           if (index_const) {
524             access_chain.push_back(index_const->GetU32());
525           } else {
526             // Variable index means the type is a type where every element
527             // is the same type.  Use element 0 to get the type.
528             access_chain.push_back(0);
529           }
530         }
531 
532         const analysis::Type* new_pointee_type =
533             type_mgr->GetMemberType(pointee_type, access_chain);
534         analysis::Pointer pointerTy(new_pointee_type,
535                                     pointer_type->storage_class());
536         uint32_t new_pointer_type_id =
537             context()->get_type_mgr()->GetTypeInstruction(&pointerTy);
538         if (new_pointer_type_id == 0) {
539           return false;
540         }
541 
542         if (new_pointer_type_id != use->type_id()) {
543           return CanUpdateUses(use, new_pointer_type_id);
544         }
545         return true;
546       }
547       case SpvOpCompositeExtract: {
548         std::vector<uint32_t> access_chain;
549         for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
550           access_chain.push_back(use->GetSingleWordInOperand(i));
551         }
552 
553         const analysis::Type* new_type =
554             type_mgr->GetMemberType(type, access_chain);
555         uint32_t new_type_id = type_mgr->GetTypeInstruction(new_type);
556         if (new_type_id == 0) {
557           return false;
558         }
559 
560         if (new_type_id != use->type_id()) {
561           return CanUpdateUses(use, new_type_id);
562         }
563         return true;
564       }
565       case SpvOpStore:
566         // If needed, we can create an element-by-element copy to change the
567         // type of the value being stored.  This way we can always handled
568         // stores.
569         return true;
570       case SpvOpImageTexelPointer:
571       case SpvOpName:
572         return true;
573       default:
574         return use->IsDecoration();
575     }
576   });
577 }
578 
UpdateUses(Instruction * original_ptr_inst,Instruction * new_ptr_inst)579 void CopyPropagateArrays::UpdateUses(Instruction* original_ptr_inst,
580                                      Instruction* new_ptr_inst) {
581   analysis::TypeManager* type_mgr = context()->get_type_mgr();
582   analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
583   analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
584 
585   std::vector<std::pair<Instruction*, uint32_t> > uses;
586   def_use_mgr->ForEachUse(original_ptr_inst,
587                           [&uses](Instruction* use, uint32_t index) {
588                             uses.push_back({use, index});
589                           });
590 
591   for (auto pair : uses) {
592     Instruction* use = pair.first;
593     uint32_t index = pair.second;
594 
595     if (use->IsOpenCL100DebugInstr()) {
596       switch (use->GetOpenCL100DebugOpcode()) {
597         case OpenCLDebugInfo100DebugDeclare: {
598           if (new_ptr_inst->opcode() == SpvOpVariable ||
599               new_ptr_inst->opcode() == SpvOpFunctionParameter) {
600             context()->ForgetUses(use);
601             use->SetOperand(index, {new_ptr_inst->result_id()});
602             context()->AnalyzeUses(use);
603           } else {
604             // Based on the spec, we cannot use a pointer other than OpVariable
605             // or OpFunctionParameter for DebugDeclare. We have to use
606             // DebugValue with Deref.
607 
608             context()->ForgetUses(use);
609 
610             // Change DebugDeclare to DebugValue.
611             use->SetOperand(
612                 index - 2,
613                 {static_cast<uint32_t>(OpenCLDebugInfo100DebugValue)});
614             use->SetOperand(index, {new_ptr_inst->result_id()});
615 
616             // Add Deref operation.
617             Instruction* dbg_expr =
618                 def_use_mgr->GetDef(use->GetSingleWordOperand(index + 1));
619             auto* deref_expr_instr =
620                 context()->get_debug_info_mgr()->DerefDebugExpression(dbg_expr);
621             use->SetOperand(index + 1, {deref_expr_instr->result_id()});
622 
623             context()->AnalyzeUses(deref_expr_instr);
624             context()->AnalyzeUses(use);
625           }
626           break;
627         }
628         case OpenCLDebugInfo100DebugValue:
629           context()->ForgetUses(use);
630           use->SetOperand(index, {new_ptr_inst->result_id()});
631           context()->AnalyzeUses(use);
632           break;
633         default:
634           assert(false && "Don't know how to rewrite instruction");
635           break;
636       }
637       continue;
638     }
639 
640     switch (use->opcode()) {
641       case SpvOpLoad: {
642         // Replace the actual use.
643         context()->ForgetUses(use);
644         use->SetOperand(index, {new_ptr_inst->result_id()});
645 
646         // Update the type.
647         Instruction* pointer_type_inst =
648             def_use_mgr->GetDef(new_ptr_inst->type_id());
649         uint32_t new_type_id =
650             pointer_type_inst->GetSingleWordInOperand(kTypePointerPointeeInIdx);
651         if (new_type_id != use->type_id()) {
652           use->SetResultType(new_type_id);
653           context()->AnalyzeUses(use);
654           UpdateUses(use, use);
655         } else {
656           context()->AnalyzeUses(use);
657         }
658       } break;
659       case SpvOpAccessChain: {
660         // Update the actual use.
661         context()->ForgetUses(use);
662         use->SetOperand(index, {new_ptr_inst->result_id()});
663 
664         // Convert the ids on the OpAccessChain to indices that can be used to
665         // get the specific member.
666         std::vector<uint32_t> access_chain;
667         for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
668           const analysis::Constant* index_const =
669               const_mgr->FindDeclaredConstant(use->GetSingleWordInOperand(i));
670           if (index_const) {
671             access_chain.push_back(index_const->GetU32());
672           } else {
673             // Variable index means the type is an type where every element
674             // is the same type.  Use element 0 to get the type.
675             access_chain.push_back(0);
676           }
677         }
678 
679         Instruction* pointer_type_inst =
680             get_def_use_mgr()->GetDef(new_ptr_inst->type_id());
681 
682         uint32_t new_pointee_type_id = GetMemberTypeId(
683             pointer_type_inst->GetSingleWordInOperand(kTypePointerPointeeInIdx),
684             access_chain);
685 
686         SpvStorageClass storage_class = static_cast<SpvStorageClass>(
687             pointer_type_inst->GetSingleWordInOperand(
688                 kTypePointerStorageClassInIdx));
689 
690         uint32_t new_pointer_type_id =
691             type_mgr->FindPointerToType(new_pointee_type_id, storage_class);
692 
693         if (new_pointer_type_id != use->type_id()) {
694           use->SetResultType(new_pointer_type_id);
695           context()->AnalyzeUses(use);
696           UpdateUses(use, use);
697         } else {
698           context()->AnalyzeUses(use);
699         }
700       } break;
701       case SpvOpCompositeExtract: {
702         // Update the actual use.
703         context()->ForgetUses(use);
704         use->SetOperand(index, {new_ptr_inst->result_id()});
705 
706         uint32_t new_type_id = new_ptr_inst->type_id();
707         std::vector<uint32_t> access_chain;
708         for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
709           access_chain.push_back(use->GetSingleWordInOperand(i));
710         }
711 
712         new_type_id = GetMemberTypeId(new_type_id, access_chain);
713 
714         if (new_type_id != use->type_id()) {
715           use->SetResultType(new_type_id);
716           context()->AnalyzeUses(use);
717           UpdateUses(use, use);
718         } else {
719           context()->AnalyzeUses(use);
720         }
721       } break;
722       case SpvOpStore:
723         // If the use is the pointer, then it is the single store to that
724         // variable.  We do not want to replace it.  Instead, it will become
725         // dead after all of the loads are removed, and ADCE will get rid of it.
726         //
727         // If the use is the object being stored, we will create a copy of the
728         // object turning it into the correct type. The copy is done by
729         // decomposing the object into the base type, which must be the same,
730         // and then rebuilding them.
731         if (index == 1) {
732           Instruction* target_pointer = def_use_mgr->GetDef(
733               use->GetSingleWordInOperand(kStorePointerInOperand));
734           Instruction* pointer_type =
735               def_use_mgr->GetDef(target_pointer->type_id());
736           uint32_t pointee_type_id =
737               pointer_type->GetSingleWordInOperand(kTypePointerPointeeInIdx);
738           uint32_t copy = GenerateCopy(original_ptr_inst, pointee_type_id, use);
739 
740           context()->ForgetUses(use);
741           use->SetInOperand(index, {copy});
742           context()->AnalyzeUses(use);
743         }
744         break;
745       case SpvOpImageTexelPointer:
746         // We treat an OpImageTexelPointer as a load.  The result type should
747         // always have the Image storage class, and should not need to be
748         // updated.
749 
750         // Replace the actual use.
751         context()->ForgetUses(use);
752         use->SetOperand(index, {new_ptr_inst->result_id()});
753         context()->AnalyzeUses(use);
754         break;
755       default:
756         assert(false && "Don't know how to rewrite instruction");
757         break;
758     }
759   }
760 }
761 
GetMemberTypeId(uint32_t id,const std::vector<uint32_t> & access_chain) const762 uint32_t CopyPropagateArrays::GetMemberTypeId(
763     uint32_t id, const std::vector<uint32_t>& access_chain) const {
764   for (uint32_t element_index : access_chain) {
765     Instruction* type_inst = get_def_use_mgr()->GetDef(id);
766     switch (type_inst->opcode()) {
767       case SpvOpTypeArray:
768       case SpvOpTypeRuntimeArray:
769       case SpvOpTypeMatrix:
770       case SpvOpTypeVector:
771         id = type_inst->GetSingleWordInOperand(0);
772         break;
773       case SpvOpTypeStruct:
774         id = type_inst->GetSingleWordInOperand(element_index);
775         break;
776       default:
777         break;
778     }
779     assert(id != 0 &&
780            "Tried to extract from an object where it cannot be done.");
781   }
782   return id;
783 }
784 
GetMember(const std::vector<uint32_t> & access_chain)785 void CopyPropagateArrays::MemoryObject::GetMember(
786     const std::vector<uint32_t>& access_chain) {
787   access_chain_.insert(access_chain_.end(), access_chain.begin(),
788                        access_chain.end());
789 }
790 
GetNumberOfMembers()791 uint32_t CopyPropagateArrays::MemoryObject::GetNumberOfMembers() {
792   IRContext* context = variable_inst_->context();
793   analysis::TypeManager* type_mgr = context->get_type_mgr();
794 
795   const analysis::Type* type = type_mgr->GetType(variable_inst_->type_id());
796   type = type->AsPointer()->pointee_type();
797 
798   std::vector<uint32_t> access_indices = GetAccessIds();
799   type = type_mgr->GetMemberType(type, access_indices);
800 
801   if (const analysis::Struct* struct_type = type->AsStruct()) {
802     return static_cast<uint32_t>(struct_type->element_types().size());
803   } else if (const analysis::Array* array_type = type->AsArray()) {
804     const analysis::Constant* length_const =
805         context->get_constant_mgr()->FindDeclaredConstant(
806             array_type->LengthId());
807     assert(length_const->type()->AsInteger());
808     return length_const->GetU32();
809   } else if (const analysis::Vector* vector_type = type->AsVector()) {
810     return vector_type->element_count();
811   } else if (const analysis::Matrix* matrix_type = type->AsMatrix()) {
812     return matrix_type->element_count();
813   } else {
814     return 0;
815   }
816 }
817 
818 template <class iterator>
MemoryObject(Instruction * var_inst,iterator begin,iterator end)819 CopyPropagateArrays::MemoryObject::MemoryObject(Instruction* var_inst,
820                                                 iterator begin, iterator end)
821     : variable_inst_(var_inst), access_chain_(begin, end) {}
822 
GetAccessIds() const823 std::vector<uint32_t> CopyPropagateArrays::MemoryObject::GetAccessIds() const {
824   analysis::ConstantManager* const_mgr =
825       variable_inst_->context()->get_constant_mgr();
826 
827   std::vector<uint32_t> access_indices;
828   for (uint32_t id : AccessChain()) {
829     const analysis::Constant* element_index_const =
830         const_mgr->FindDeclaredConstant(id);
831     if (!element_index_const) {
832       access_indices.push_back(0);
833     } else {
834       access_indices.push_back(element_index_const->GetU32());
835     }
836   }
837   return access_indices;
838 }
839 
Contains(CopyPropagateArrays::MemoryObject * other)840 bool CopyPropagateArrays::MemoryObject::Contains(
841     CopyPropagateArrays::MemoryObject* other) {
842   if (this->GetVariable() != other->GetVariable()) {
843     return false;
844   }
845 
846   if (AccessChain().size() > other->AccessChain().size()) {
847     return false;
848   }
849 
850   for (uint32_t i = 0; i < AccessChain().size(); i++) {
851     if (AccessChain()[i] != other->AccessChain()[i]) {
852       return false;
853     }
854   }
855   return true;
856 }
857 
858 }  // namespace opt
859 }  // namespace spvtools
860