• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2018 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/opt/copy_prop_arrays.h"
16 
17 #include <utility>
18 
19 #include "source/opt/ir_builder.h"
20 
21 namespace spvtools {
22 namespace opt {
23 namespace {
24 
// In-operand indices of the instructions inspected by this pass.
// OpLoad: the pointer that is read.
constexpr uint32_t kLoadPointerInOperand{0};
// OpStore: the pointer written through, then the object being stored.
constexpr uint32_t kStorePointerInOperand{0};
constexpr uint32_t kStoreObjectInOperand{1};
// OpCompositeExtract: the composite being indexed.
constexpr uint32_t kCompositeExtractObjectInOperand{0};
// OpTypePointer: storage class, then pointee type.
constexpr uint32_t kTypePointerStorageClassInIdx{0};
constexpr uint32_t kTypePointerPointeeInIdx{1};
31 
IsDebugDeclareOrValue(Instruction * di)32 bool IsDebugDeclareOrValue(Instruction* di) {
33   auto dbg_opcode = di->GetCommonDebugOpcode();
34   return dbg_opcode == CommonDebugInfoDebugDeclare ||
35          dbg_opcode == CommonDebugInfoDebugValue;
36 }
37 
38 // Returns the number of members in |type|.  If |type| is not a composite type
39 // or the number of components is not known at compile time, the return value
40 // will be 0.
GetNumberOfMembers(const analysis::Type * type,IRContext * context)41 uint32_t GetNumberOfMembers(const analysis::Type* type, IRContext* context) {
42   if (const analysis::Struct* struct_type = type->AsStruct()) {
43     return static_cast<uint32_t>(struct_type->element_types().size());
44   } else if (const analysis::Array* array_type = type->AsArray()) {
45     const analysis::Constant* length_const =
46         context->get_constant_mgr()->FindDeclaredConstant(
47             array_type->LengthId());
48 
49     if (length_const == nullptr) {
50       // This can happen if the length is an OpSpecConstant.
51       return 0;
52     }
53     assert(length_const->type()->AsInteger());
54     return length_const->GetU32();
55   } else if (const analysis::Vector* vector_type = type->AsVector()) {
56     return vector_type->element_count();
57   } else if (const analysis::Matrix* matrix_type = type->AsMatrix()) {
58     return matrix_type->element_count();
59   } else {
60     return 0;
61   }
62 }
63 
64 }  // namespace
65 
Process()66 Pass::Status CopyPropagateArrays::Process() {
67   bool modified = false;
68   for (Function& function : *get_module()) {
69     if (function.IsDeclaration()) {
70       continue;
71     }
72 
73     BasicBlock* entry_bb = &*function.begin();
74 
75     for (auto var_inst = entry_bb->begin();
76          var_inst->opcode() == spv::Op::OpVariable; ++var_inst) {
77       if (!IsPointerToArrayType(var_inst->type_id())) {
78         continue;
79       }
80 
81       // Find the only store to the entire memory location, if it exists.
82       Instruction* store_inst = FindStoreInstruction(&*var_inst);
83 
84       if (!store_inst) {
85         continue;
86       }
87 
88       std::unique_ptr<MemoryObject> source_object =
89           FindSourceObjectIfPossible(&*var_inst, store_inst);
90 
91       if (source_object != nullptr) {
92         if (CanUpdateUses(&*var_inst, source_object->GetPointerTypeId(this))) {
93           modified = true;
94           PropagateObject(&*var_inst, source_object.get(), store_inst);
95         }
96       }
97     }
98   }
99   return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
100 }
101 
102 std::unique_ptr<CopyPropagateArrays::MemoryObject>
FindSourceObjectIfPossible(Instruction * var_inst,Instruction * store_inst)103 CopyPropagateArrays::FindSourceObjectIfPossible(Instruction* var_inst,
104                                                 Instruction* store_inst) {
105   assert(var_inst->opcode() == spv::Op::OpVariable && "Expecting a variable.");
106 
107   // Check that the variable is a composite object where |store_inst|
108   // dominates all of its loads.
109   if (!store_inst) {
110     return nullptr;
111   }
112 
113   // Look at the loads to ensure they are dominated by the store.
114   if (!HasValidReferencesOnly(var_inst, store_inst)) {
115     return nullptr;
116   }
117 
118   // If so, look at the store to see if it is the copy of an object.
119   std::unique_ptr<MemoryObject> source = GetSourceObjectIfAny(
120       store_inst->GetSingleWordInOperand(kStoreObjectInOperand));
121 
122   if (!source) {
123     return nullptr;
124   }
125 
126   // Ensure that |source| does not change between the point at which it is
127   // loaded, and the position in which |var_inst| is loaded.
128   //
129   // For now we will go with the easy to implement approach, and check that the
130   // entire variable (not just the specific component) is never written to.
131 
132   if (!HasNoStores(source->GetVariable())) {
133     return nullptr;
134   }
135   return source;
136 }
137 
// Returns the single OpStore whose pointer operand is |var_inst| itself.
// Returns nullptr if there is no such store, or if there is more than one.
Instruction* CopyPropagateArrays::FindStoreInstruction(
    const Instruction* var_inst) const {
  const uint32_t var_id = var_inst->result_id();
  Instruction* found = nullptr;

  const bool at_most_one = get_def_use_mgr()->WhileEachUser(
      var_inst, [&found, var_id](Instruction* user) {
        const bool stores_to_var =
            user->opcode() == spv::Op::OpStore &&
            user->GetSingleWordInOperand(kStorePointerInOperand) == var_id;
        if (!stores_to_var) return true;
        if (found != nullptr) return false;  // A second store: give up.
        found = user;
        return true;
      });

  return at_most_one ? found : nullptr;
}
157 
PropagateObject(Instruction * var_inst,MemoryObject * source,Instruction * insertion_point)158 void CopyPropagateArrays::PropagateObject(Instruction* var_inst,
159                                           MemoryObject* source,
160                                           Instruction* insertion_point) {
161   assert(var_inst->opcode() == spv::Op::OpVariable &&
162          "This function propagates variables.");
163 
164   Instruction* new_access_chain = BuildNewAccessChain(insertion_point, source);
165   context()->KillNamesAndDecorates(var_inst);
166   UpdateUses(var_inst, new_access_chain);
167 }
168 
BuildNewAccessChain(Instruction * insertion_point,CopyPropagateArrays::MemoryObject * source) const169 Instruction* CopyPropagateArrays::BuildNewAccessChain(
170     Instruction* insertion_point,
171     CopyPropagateArrays::MemoryObject* source) const {
172   InstructionBuilder builder(
173       context(), insertion_point,
174       IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
175 
176   if (source->AccessChain().size() == 0) {
177     return source->GetVariable();
178   }
179 
180   source->BuildConstants();
181   std::vector<uint32_t> access_ids(source->AccessChain().size());
182   std::transform(
183       source->AccessChain().cbegin(), source->AccessChain().cend(),
184       access_ids.begin(), [](const AccessChainEntry& entry) {
185         assert(entry.is_result_id && "Constants needs to be built first.");
186         return entry.result_id;
187       });
188 
189   return builder.AddAccessChain(source->GetPointerTypeId(this),
190                                 source->GetVariable()->result_id(), access_ids);
191 }
192 
HasNoStores(Instruction * ptr_inst)193 bool CopyPropagateArrays::HasNoStores(Instruction* ptr_inst) {
194   return get_def_use_mgr()->WhileEachUser(ptr_inst, [this](Instruction* use) {
195     if (use->opcode() == spv::Op::OpLoad) {
196       return true;
197     } else if (use->opcode() == spv::Op::OpAccessChain) {
198       return HasNoStores(use);
199     } else if (use->IsDecoration() || use->opcode() == spv::Op::OpName) {
200       return true;
201     } else if (use->opcode() == spv::Op::OpStore) {
202       return false;
203     } else if (use->opcode() == spv::Op::OpImageTexelPointer) {
204       return true;
205     } else if (use->opcode() == spv::Op::OpEntryPoint) {
206       return true;
207     }
208     // Some other instruction.  Be conservative.
209     return false;
210   });
211 }
212 
HasValidReferencesOnly(Instruction * ptr_inst,Instruction * store_inst)213 bool CopyPropagateArrays::HasValidReferencesOnly(Instruction* ptr_inst,
214                                                  Instruction* store_inst) {
215   BasicBlock* store_block = context()->get_instr_block(store_inst);
216   DominatorAnalysis* dominator_analysis =
217       context()->GetDominatorAnalysis(store_block->GetParent());
218 
219   return get_def_use_mgr()->WhileEachUser(
220       ptr_inst,
221       [this, store_inst, dominator_analysis, ptr_inst](Instruction* use) {
222         if (use->opcode() == spv::Op::OpLoad ||
223             use->opcode() == spv::Op::OpImageTexelPointer) {
224           // TODO: If there are many load in the same BB as |store_inst| the
225           // time to do the multiple traverses can add up.  Consider collecting
226           // those loads and doing a single traversal.
227           return dominator_analysis->Dominates(store_inst, use);
228         } else if (use->opcode() == spv::Op::OpAccessChain) {
229           return HasValidReferencesOnly(use, store_inst);
230         } else if (use->IsDecoration() || use->opcode() == spv::Op::OpName) {
231           return true;
232         } else if (use->opcode() == spv::Op::OpStore) {
233           // If we are storing to part of the object it is not an candidate.
234           return ptr_inst->opcode() == spv::Op::OpVariable &&
235                  store_inst->GetSingleWordInOperand(kStorePointerInOperand) ==
236                      ptr_inst->result_id();
237         } else if (IsDebugDeclareOrValue(use)) {
238           return true;
239         }
240         // Some other instruction.  Be conservative.
241         return false;
242       });
243 }
244 
245 std::unique_ptr<CopyPropagateArrays::MemoryObject>
GetSourceObjectIfAny(uint32_t result)246 CopyPropagateArrays::GetSourceObjectIfAny(uint32_t result) {
247   Instruction* result_inst = context()->get_def_use_mgr()->GetDef(result);
248 
249   switch (result_inst->opcode()) {
250     case spv::Op::OpLoad:
251       return BuildMemoryObjectFromLoad(result_inst);
252     case spv::Op::OpCompositeExtract:
253       return BuildMemoryObjectFromExtract(result_inst);
254     case spv::Op::OpCompositeConstruct:
255       return BuildMemoryObjectFromCompositeConstruct(result_inst);
256     case spv::Op::OpCopyObject:
257       return GetSourceObjectIfAny(result_inst->GetSingleWordInOperand(0));
258     case spv::Op::OpCompositeInsert:
259       return BuildMemoryObjectFromInsert(result_inst);
260     default:
261       return nullptr;
262   }
263 }
264 
265 std::unique_ptr<CopyPropagateArrays::MemoryObject>
BuildMemoryObjectFromLoad(Instruction * load_inst)266 CopyPropagateArrays::BuildMemoryObjectFromLoad(Instruction* load_inst) {
267   std::vector<uint32_t> components_in_reverse;
268   analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
269 
270   Instruction* current_inst = def_use_mgr->GetDef(
271       load_inst->GetSingleWordInOperand(kLoadPointerInOperand));
272 
273   // Build the access chain for the memory object by collecting the indices used
274   // in the OpAccessChain instructions.  If we find a variable index, then
275   // return |nullptr| because we cannot know for sure which memory location is
276   // used.
277   //
278   // It is built in reverse order because the different |OpAccessChain|
279   // instructions are visited in reverse order from which they are applied.
280   while (current_inst->opcode() == spv::Op::OpAccessChain) {
281     for (uint32_t i = current_inst->NumInOperands() - 1; i >= 1; --i) {
282       uint32_t element_index_id = current_inst->GetSingleWordInOperand(i);
283       components_in_reverse.push_back(element_index_id);
284     }
285     current_inst = def_use_mgr->GetDef(current_inst->GetSingleWordInOperand(0));
286   }
287 
288   // If the address in the load is not constructed from an |OpVariable|
289   // instruction followed by a series of |OpAccessChain| instructions, then
290   // return |nullptr| because we cannot identify the owner or access chain
291   // exactly.
292   if (current_inst->opcode() != spv::Op::OpVariable) {
293     return nullptr;
294   }
295 
296   // Build the memory object.  Use |rbegin| and |rend| to put the access chain
297   // back in the correct order.
298   return std::unique_ptr<CopyPropagateArrays::MemoryObject>(
299       new MemoryObject(current_inst, components_in_reverse.rbegin(),
300                        components_in_reverse.rend()));
301 }
302 
303 std::unique_ptr<CopyPropagateArrays::MemoryObject>
BuildMemoryObjectFromExtract(Instruction * extract_inst)304 CopyPropagateArrays::BuildMemoryObjectFromExtract(Instruction* extract_inst) {
305   assert(extract_inst->opcode() == spv::Op::OpCompositeExtract &&
306          "Expecting an OpCompositeExtract instruction.");
307   std::unique_ptr<MemoryObject> result = GetSourceObjectIfAny(
308       extract_inst->GetSingleWordInOperand(kCompositeExtractObjectInOperand));
309 
310   if (!result) {
311     return nullptr;
312   }
313 
314   // Copy the indices of the extract instruction to |OpAccessChain| indices.
315   std::vector<AccessChainEntry> components;
316   for (uint32_t i = 1; i < extract_inst->NumInOperands(); ++i) {
317     components.push_back({false, {extract_inst->GetSingleWordInOperand(i)}});
318   }
319   result->PushIndirection(components);
320   return result;
321 }
322 
323 std::unique_ptr<CopyPropagateArrays::MemoryObject>
BuildMemoryObjectFromCompositeConstruct(Instruction * conststruct_inst)324 CopyPropagateArrays::BuildMemoryObjectFromCompositeConstruct(
325     Instruction* conststruct_inst) {
326   assert(conststruct_inst->opcode() == spv::Op::OpCompositeConstruct &&
327          "Expecting an OpCompositeConstruct instruction.");
328 
329   // If every operand in the instruction are part of the same memory object, and
330   // are being combined in the same order, then the result is the same as the
331   // parent.
332 
333   std::unique_ptr<MemoryObject> memory_object =
334       GetSourceObjectIfAny(conststruct_inst->GetSingleWordInOperand(0));
335 
336   if (!memory_object) {
337     return nullptr;
338   }
339 
340   if (!memory_object->IsMember()) {
341     return nullptr;
342   }
343 
344   AccessChainEntry last_access = memory_object->AccessChain().back();
345   if (!IsAccessChainIndexValidAndEqualTo(last_access, 0)) {
346     return nullptr;
347   }
348 
349   memory_object->PopIndirection();
350   if (memory_object->GetNumberOfMembers() !=
351       conststruct_inst->NumInOperands()) {
352     return nullptr;
353   }
354 
355   for (uint32_t i = 1; i < conststruct_inst->NumInOperands(); ++i) {
356     std::unique_ptr<MemoryObject> member_object =
357         GetSourceObjectIfAny(conststruct_inst->GetSingleWordInOperand(i));
358 
359     if (!member_object) {
360       return nullptr;
361     }
362 
363     if (!member_object->IsMember()) {
364       return nullptr;
365     }
366 
367     if (!memory_object->Contains(member_object.get())) {
368       return nullptr;
369     }
370 
371     last_access = member_object->AccessChain().back();
372     if (!IsAccessChainIndexValidAndEqualTo(last_access, i)) {
373       return nullptr;
374     }
375   }
376   return memory_object;
377 }
378 
379 std::unique_ptr<CopyPropagateArrays::MemoryObject>
BuildMemoryObjectFromInsert(Instruction * insert_inst)380 CopyPropagateArrays::BuildMemoryObjectFromInsert(Instruction* insert_inst) {
381   assert(insert_inst->opcode() == spv::Op::OpCompositeInsert &&
382          "Expecting an OpCompositeInsert instruction.");
383 
384   analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
385   analysis::TypeManager* type_mgr = context()->get_type_mgr();
386   const analysis::Type* result_type = type_mgr->GetType(insert_inst->type_id());
387 
388   uint32_t number_of_elements = GetNumberOfMembers(result_type, context());
389 
390   if (number_of_elements == 0) {
391     return nullptr;
392   }
393 
394   if (insert_inst->NumInOperands() != 3) {
395     return nullptr;
396   }
397 
398   if (insert_inst->GetSingleWordInOperand(2) != number_of_elements - 1) {
399     return nullptr;
400   }
401 
402   std::unique_ptr<MemoryObject> memory_object =
403       GetSourceObjectIfAny(insert_inst->GetSingleWordInOperand(0));
404 
405   if (!memory_object) {
406     return nullptr;
407   }
408 
409   if (!memory_object->IsMember()) {
410     return nullptr;
411   }
412 
413   AccessChainEntry last_access = memory_object->AccessChain().back();
414   if (!IsAccessChainIndexValidAndEqualTo(last_access, number_of_elements - 1)) {
415     return nullptr;
416   }
417 
418   memory_object->PopIndirection();
419 
420   Instruction* current_insert =
421       def_use_mgr->GetDef(insert_inst->GetSingleWordInOperand(1));
422   for (uint32_t i = number_of_elements - 1; i > 0; --i) {
423     if (current_insert->opcode() != spv::Op::OpCompositeInsert) {
424       return nullptr;
425     }
426 
427     if (current_insert->NumInOperands() != 3) {
428       return nullptr;
429     }
430 
431     if (current_insert->GetSingleWordInOperand(2) != i - 1) {
432       return nullptr;
433     }
434 
435     std::unique_ptr<MemoryObject> current_memory_object =
436         GetSourceObjectIfAny(current_insert->GetSingleWordInOperand(0));
437 
438     if (!current_memory_object) {
439       return nullptr;
440     }
441 
442     if (!current_memory_object->IsMember()) {
443       return nullptr;
444     }
445 
446     if (memory_object->AccessChain().size() + 1 !=
447         current_memory_object->AccessChain().size()) {
448       return nullptr;
449     }
450 
451     if (!memory_object->Contains(current_memory_object.get())) {
452       return nullptr;
453     }
454 
455     AccessChainEntry current_last_access =
456         current_memory_object->AccessChain().back();
457     if (!IsAccessChainIndexValidAndEqualTo(current_last_access, i - 1)) {
458       return nullptr;
459     }
460     current_insert =
461         def_use_mgr->GetDef(current_insert->GetSingleWordInOperand(1));
462   }
463 
464   return memory_object;
465 }
466 
IsAccessChainIndexValidAndEqualTo(const AccessChainEntry & entry,uint32_t value) const467 bool CopyPropagateArrays::IsAccessChainIndexValidAndEqualTo(
468     const AccessChainEntry& entry, uint32_t value) const {
469   if (!entry.is_result_id) {
470     return entry.immediate == value;
471   }
472 
473   analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
474   const analysis::Constant* constant =
475       const_mgr->FindDeclaredConstant(entry.result_id);
476   if (!constant || !constant->type()->AsInteger()) {
477     return false;
478   }
479   return constant->GetU32() == value;
480 }
481 
IsPointerToArrayType(uint32_t type_id)482 bool CopyPropagateArrays::IsPointerToArrayType(uint32_t type_id) {
483   analysis::TypeManager* type_mgr = context()->get_type_mgr();
484   analysis::Pointer* pointer_type = type_mgr->GetType(type_id)->AsPointer();
485   if (pointer_type) {
486     return pointer_type->pointee_type()->kind() == analysis::Type::kArray ||
487            pointer_type->pointee_type()->kind() == analysis::Type::kImage;
488   }
489   return false;
490 }
491 
CanUpdateUses(Instruction * original_ptr_inst,uint32_t type_id)492 bool CopyPropagateArrays::CanUpdateUses(Instruction* original_ptr_inst,
493                                         uint32_t type_id) {
494   analysis::TypeManager* type_mgr = context()->get_type_mgr();
495   analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
496   analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
497 
498   analysis::Type* type = type_mgr->GetType(type_id);
499   if (type->AsRuntimeArray()) {
500     return false;
501   }
502 
503   if (!type->AsStruct() && !type->AsArray() && !type->AsPointer()) {
504     // If the type is not an aggregate, then the desired type must be the
505     // same as the current type.  No work to do, and we can do that.
506     return true;
507   }
508 
509   return def_use_mgr->WhileEachUse(original_ptr_inst, [this, type_mgr,
510                                                        const_mgr,
511                                                        type](Instruction* use,
512                                                              uint32_t) {
513     if (IsDebugDeclareOrValue(use)) return true;
514 
515     switch (use->opcode()) {
516       case spv::Op::OpLoad: {
517         analysis::Pointer* pointer_type = type->AsPointer();
518         uint32_t new_type_id = type_mgr->GetId(pointer_type->pointee_type());
519 
520         if (new_type_id != use->type_id()) {
521           return CanUpdateUses(use, new_type_id);
522         }
523         return true;
524       }
525       case spv::Op::OpAccessChain: {
526         analysis::Pointer* pointer_type = type->AsPointer();
527         const analysis::Type* pointee_type = pointer_type->pointee_type();
528 
529         std::vector<uint32_t> access_chain;
530         for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
531           const analysis::Constant* index_const =
532               const_mgr->FindDeclaredConstant(use->GetSingleWordInOperand(i));
533           if (index_const) {
534             access_chain.push_back(index_const->GetU32());
535           } else {
536             // Variable index means the type is a type where every element
537             // is the same type.  Use element 0 to get the type.
538             access_chain.push_back(0);
539 
540             // We are trying to access a struct with variable indices.
541             // This cannot happen.
542             if (pointee_type->kind() == analysis::Type::kStruct) {
543               return false;
544             }
545           }
546         }
547 
548         const analysis::Type* new_pointee_type =
549             type_mgr->GetMemberType(pointee_type, access_chain);
550         analysis::Pointer pointerTy(new_pointee_type,
551                                     pointer_type->storage_class());
552         uint32_t new_pointer_type_id =
553             context()->get_type_mgr()->GetTypeInstruction(&pointerTy);
554         if (new_pointer_type_id == 0) {
555           return false;
556         }
557 
558         if (new_pointer_type_id != use->type_id()) {
559           return CanUpdateUses(use, new_pointer_type_id);
560         }
561         return true;
562       }
563       case spv::Op::OpCompositeExtract: {
564         std::vector<uint32_t> access_chain;
565         for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
566           access_chain.push_back(use->GetSingleWordInOperand(i));
567         }
568 
569         const analysis::Type* new_type =
570             type_mgr->GetMemberType(type, access_chain);
571         uint32_t new_type_id = type_mgr->GetTypeInstruction(new_type);
572         if (new_type_id == 0) {
573           return false;
574         }
575 
576         if (new_type_id != use->type_id()) {
577           return CanUpdateUses(use, new_type_id);
578         }
579         return true;
580       }
581       case spv::Op::OpStore:
582         // If needed, we can create an element-by-element copy to change the
583         // type of the value being stored.  This way we can always handled
584         // stores.
585         return true;
586       case spv::Op::OpImageTexelPointer:
587       case spv::Op::OpName:
588         return true;
589       default:
590         return use->IsDecoration();
591     }
592   });
593 }
594 
UpdateUses(Instruction * original_ptr_inst,Instruction * new_ptr_inst)595 void CopyPropagateArrays::UpdateUses(Instruction* original_ptr_inst,
596                                      Instruction* new_ptr_inst) {
597   analysis::TypeManager* type_mgr = context()->get_type_mgr();
598   analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
599   analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
600 
601   std::vector<std::pair<Instruction*, uint32_t> > uses;
602   def_use_mgr->ForEachUse(original_ptr_inst,
603                           [&uses](Instruction* use, uint32_t index) {
604                             uses.push_back({use, index});
605                           });
606 
607   for (auto pair : uses) {
608     Instruction* use = pair.first;
609     uint32_t index = pair.second;
610 
611     if (use->IsCommonDebugInstr()) {
612       switch (use->GetCommonDebugOpcode()) {
613         case CommonDebugInfoDebugDeclare: {
614           if (new_ptr_inst->opcode() == spv::Op::OpVariable ||
615               new_ptr_inst->opcode() == spv::Op::OpFunctionParameter) {
616             context()->ForgetUses(use);
617             use->SetOperand(index, {new_ptr_inst->result_id()});
618             context()->AnalyzeUses(use);
619           } else {
620             // Based on the spec, we cannot use a pointer other than OpVariable
621             // or OpFunctionParameter for DebugDeclare. We have to use
622             // DebugValue with Deref.
623 
624             context()->ForgetUses(use);
625 
626             // Change DebugDeclare to DebugValue.
627             use->SetOperand(index - 2,
628                             {static_cast<uint32_t>(CommonDebugInfoDebugValue)});
629             use->SetOperand(index, {new_ptr_inst->result_id()});
630 
631             // Add Deref operation.
632             Instruction* dbg_expr =
633                 def_use_mgr->GetDef(use->GetSingleWordOperand(index + 1));
634             auto* deref_expr_instr =
635                 context()->get_debug_info_mgr()->DerefDebugExpression(dbg_expr);
636             use->SetOperand(index + 1, {deref_expr_instr->result_id()});
637 
638             context()->AnalyzeUses(deref_expr_instr);
639             context()->AnalyzeUses(use);
640           }
641           break;
642         }
643         case CommonDebugInfoDebugValue:
644           context()->ForgetUses(use);
645           use->SetOperand(index, {new_ptr_inst->result_id()});
646           context()->AnalyzeUses(use);
647           break;
648         default:
649           assert(false && "Don't know how to rewrite instruction");
650           break;
651       }
652       continue;
653     }
654 
655     switch (use->opcode()) {
656       case spv::Op::OpLoad: {
657         // Replace the actual use.
658         context()->ForgetUses(use);
659         use->SetOperand(index, {new_ptr_inst->result_id()});
660 
661         // Update the type.
662         Instruction* pointer_type_inst =
663             def_use_mgr->GetDef(new_ptr_inst->type_id());
664         uint32_t new_type_id =
665             pointer_type_inst->GetSingleWordInOperand(kTypePointerPointeeInIdx);
666         if (new_type_id != use->type_id()) {
667           use->SetResultType(new_type_id);
668           context()->AnalyzeUses(use);
669           UpdateUses(use, use);
670         } else {
671           context()->AnalyzeUses(use);
672         }
673       } break;
674       case spv::Op::OpAccessChain: {
675         // Update the actual use.
676         context()->ForgetUses(use);
677         use->SetOperand(index, {new_ptr_inst->result_id()});
678 
679         // Convert the ids on the OpAccessChain to indices that can be used to
680         // get the specific member.
681         std::vector<uint32_t> access_chain;
682         for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
683           const analysis::Constant* index_const =
684               const_mgr->FindDeclaredConstant(use->GetSingleWordInOperand(i));
685           if (index_const) {
686             access_chain.push_back(index_const->GetU32());
687           } else {
688             // Variable index means the type is an type where every element
689             // is the same type.  Use element 0 to get the type.
690             access_chain.push_back(0);
691           }
692         }
693 
694         Instruction* pointer_type_inst =
695             get_def_use_mgr()->GetDef(new_ptr_inst->type_id());
696 
697         uint32_t new_pointee_type_id = GetMemberTypeId(
698             pointer_type_inst->GetSingleWordInOperand(kTypePointerPointeeInIdx),
699             access_chain);
700 
701         spv::StorageClass storage_class = static_cast<spv::StorageClass>(
702             pointer_type_inst->GetSingleWordInOperand(
703                 kTypePointerStorageClassInIdx));
704 
705         uint32_t new_pointer_type_id =
706             type_mgr->FindPointerToType(new_pointee_type_id, storage_class);
707 
708         if (new_pointer_type_id != use->type_id()) {
709           use->SetResultType(new_pointer_type_id);
710           context()->AnalyzeUses(use);
711           UpdateUses(use, use);
712         } else {
713           context()->AnalyzeUses(use);
714         }
715       } break;
716       case spv::Op::OpCompositeExtract: {
717         // Update the actual use.
718         context()->ForgetUses(use);
719         use->SetOperand(index, {new_ptr_inst->result_id()});
720 
721         uint32_t new_type_id = new_ptr_inst->type_id();
722         std::vector<uint32_t> access_chain;
723         for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
724           access_chain.push_back(use->GetSingleWordInOperand(i));
725         }
726 
727         new_type_id = GetMemberTypeId(new_type_id, access_chain);
728 
729         if (new_type_id != use->type_id()) {
730           use->SetResultType(new_type_id);
731           context()->AnalyzeUses(use);
732           UpdateUses(use, use);
733         } else {
734           context()->AnalyzeUses(use);
735         }
736       } break;
737       case spv::Op::OpStore:
738         // If the use is the pointer, then it is the single store to that
739         // variable.  We do not want to replace it.  Instead, it will become
740         // dead after all of the loads are removed, and ADCE will get rid of it.
741         //
742         // If the use is the object being stored, we will create a copy of the
743         // object turning it into the correct type. The copy is done by
744         // decomposing the object into the base type, which must be the same,
745         // and then rebuilding them.
746         if (index == 1) {
747           Instruction* target_pointer = def_use_mgr->GetDef(
748               use->GetSingleWordInOperand(kStorePointerInOperand));
749           Instruction* pointer_type =
750               def_use_mgr->GetDef(target_pointer->type_id());
751           uint32_t pointee_type_id =
752               pointer_type->GetSingleWordInOperand(kTypePointerPointeeInIdx);
753           uint32_t copy = GenerateCopy(original_ptr_inst, pointee_type_id, use);
754 
755           context()->ForgetUses(use);
756           use->SetInOperand(index, {copy});
757           context()->AnalyzeUses(use);
758         }
759         break;
760       case spv::Op::OpDecorate:
761       // We treat an OpImageTexelPointer as a load.  The result type should
762       // always have the Image storage class, and should not need to be
763       // updated.
764       case spv::Op::OpImageTexelPointer:
765         // Replace the actual use.
766         context()->ForgetUses(use);
767         use->SetOperand(index, {new_ptr_inst->result_id()});
768         context()->AnalyzeUses(use);
769         break;
770       default:
771         assert(false && "Don't know how to rewrite instruction");
772         break;
773     }
774   }
775 }
776 
GetMemberTypeId(uint32_t id,const std::vector<uint32_t> & access_chain) const777 uint32_t CopyPropagateArrays::GetMemberTypeId(
778     uint32_t id, const std::vector<uint32_t>& access_chain) const {
779   for (uint32_t element_index : access_chain) {
780     Instruction* type_inst = get_def_use_mgr()->GetDef(id);
781     switch (type_inst->opcode()) {
782       case spv::Op::OpTypeArray:
783       case spv::Op::OpTypeRuntimeArray:
784       case spv::Op::OpTypeMatrix:
785       case spv::Op::OpTypeVector:
786         id = type_inst->GetSingleWordInOperand(0);
787         break;
788       case spv::Op::OpTypeStruct:
789         id = type_inst->GetSingleWordInOperand(element_index);
790         break;
791       default:
792         break;
793     }
794     assert(id != 0 &&
795            "Tried to extract from an object where it cannot be done.");
796   }
797   return id;
798 }
799 
PushIndirection(const std::vector<AccessChainEntry> & access_chain)800 void CopyPropagateArrays::MemoryObject::PushIndirection(
801     const std::vector<AccessChainEntry>& access_chain) {
802   access_chain_.insert(access_chain_.end(), access_chain.begin(),
803                        access_chain.end());
804 }
805 
GetNumberOfMembers()806 uint32_t CopyPropagateArrays::MemoryObject::GetNumberOfMembers() {
807   IRContext* context = variable_inst_->context();
808   analysis::TypeManager* type_mgr = context->get_type_mgr();
809 
810   const analysis::Type* type = type_mgr->GetType(variable_inst_->type_id());
811   type = type->AsPointer()->pointee_type();
812 
813   std::vector<uint32_t> access_indices = GetAccessIds();
814   type = type_mgr->GetMemberType(type, access_indices);
815 
816   return opt::GetNumberOfMembers(type, context);
817 }
818 template <class iterator>
MemoryObject(Instruction * var_inst,iterator begin,iterator end)819 CopyPropagateArrays::MemoryObject::MemoryObject(Instruction* var_inst,
820                                                 iterator begin, iterator end)
821     : variable_inst_(var_inst) {
822   std::transform(begin, end, std::back_inserter(access_chain_),
823                  [](uint32_t id) {
824                    return AccessChainEntry{true, {id}};
825                  });
826 }
827 
GetAccessIds() const828 std::vector<uint32_t> CopyPropagateArrays::MemoryObject::GetAccessIds() const {
829   analysis::ConstantManager* const_mgr =
830       variable_inst_->context()->get_constant_mgr();
831 
832   std::vector<uint32_t> indices(AccessChain().size());
833   std::transform(AccessChain().cbegin(), AccessChain().cend(), indices.begin(),
834                  [&const_mgr](const AccessChainEntry& entry) {
835                    if (entry.is_result_id) {
836                      const analysis::Constant* constant =
837                          const_mgr->FindDeclaredConstant(entry.result_id);
838                      return constant == nullptr ? 0 : constant->GetU32();
839                    }
840 
841                    return entry.immediate;
842                  });
843   return indices;
844 }
845 
Contains(CopyPropagateArrays::MemoryObject * other)846 bool CopyPropagateArrays::MemoryObject::Contains(
847     CopyPropagateArrays::MemoryObject* other) {
848   if (this->GetVariable() != other->GetVariable()) {
849     return false;
850   }
851 
852   if (AccessChain().size() > other->AccessChain().size()) {
853     return false;
854   }
855 
856   for (uint32_t i = 0; i < AccessChain().size(); i++) {
857     if (AccessChain()[i] != other->AccessChain()[i]) {
858       return false;
859     }
860   }
861   return true;
862 }
863 
BuildConstants()864 void CopyPropagateArrays::MemoryObject::BuildConstants() {
865   for (auto& entry : access_chain_) {
866     if (entry.is_result_id) {
867       continue;
868     }
869 
870     auto context = variable_inst_->context();
871     analysis::Integer int_type(32, false);
872     const analysis::Type* uint32_type =
873         context->get_type_mgr()->GetRegisteredType(&int_type);
874     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
875     const analysis::Constant* index_const =
876         const_mgr->GetConstant(uint32_type, {entry.immediate});
877     entry.result_id =
878         const_mgr->GetDefiningInstruction(index_const)->result_id();
879     entry.is_result_id = true;
880   }
881 }
882 
883 }  // namespace opt
884 }  // namespace spvtools
885