• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2018 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/opt/copy_prop_arrays.h"
16 
17 #include <utility>
18 
19 #include "source/opt/ir_builder.h"
20 
21 namespace spvtools {
22 namespace opt {
23 namespace {
24 
25 constexpr uint32_t kLoadPointerInOperand = 0;
26 constexpr uint32_t kStorePointerInOperand = 0;
27 constexpr uint32_t kStoreObjectInOperand = 1;
28 constexpr uint32_t kCompositeExtractObjectInOperand = 0;
29 constexpr uint32_t kTypePointerStorageClassInIdx = 0;
30 constexpr uint32_t kTypePointerPointeeInIdx = 1;
31 
IsDebugDeclareOrValue(Instruction * di)32 bool IsDebugDeclareOrValue(Instruction* di) {
33   auto dbg_opcode = di->GetCommonDebugOpcode();
34   return dbg_opcode == CommonDebugInfoDebugDeclare ||
35          dbg_opcode == CommonDebugInfoDebugValue;
36 }
37 
38 }  // namespace
39 
Process()40 Pass::Status CopyPropagateArrays::Process() {
41   bool modified = false;
42   for (Function& function : *get_module()) {
43     if (function.IsDeclaration()) {
44       continue;
45     }
46 
47     BasicBlock* entry_bb = &*function.begin();
48 
49     for (auto var_inst = entry_bb->begin();
50          var_inst->opcode() == spv::Op::OpVariable; ++var_inst) {
51       if (!IsPointerToArrayType(var_inst->type_id())) {
52         continue;
53       }
54 
55       // Find the only store to the entire memory location, if it exists.
56       Instruction* store_inst = FindStoreInstruction(&*var_inst);
57 
58       if (!store_inst) {
59         continue;
60       }
61 
62       std::unique_ptr<MemoryObject> source_object =
63           FindSourceObjectIfPossible(&*var_inst, store_inst);
64 
65       if (source_object != nullptr) {
66         if (CanUpdateUses(&*var_inst, source_object->GetPointerTypeId(this))) {
67           modified = true;
68           PropagateObject(&*var_inst, source_object.get(), store_inst);
69         }
70       }
71     }
72   }
73   return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
74 }
75 
76 std::unique_ptr<CopyPropagateArrays::MemoryObject>
FindSourceObjectIfPossible(Instruction * var_inst,Instruction * store_inst)77 CopyPropagateArrays::FindSourceObjectIfPossible(Instruction* var_inst,
78                                                 Instruction* store_inst) {
79   assert(var_inst->opcode() == spv::Op::OpVariable && "Expecting a variable.");
80 
81   // Check that the variable is a composite object where |store_inst|
82   // dominates all of its loads.
83   if (!store_inst) {
84     return nullptr;
85   }
86 
87   // Look at the loads to ensure they are dominated by the store.
88   if (!HasValidReferencesOnly(var_inst, store_inst)) {
89     return nullptr;
90   }
91 
92   // If so, look at the store to see if it is the copy of an object.
93   std::unique_ptr<MemoryObject> source = GetSourceObjectIfAny(
94       store_inst->GetSingleWordInOperand(kStoreObjectInOperand));
95 
96   if (!source) {
97     return nullptr;
98   }
99 
100   // Ensure that |source| does not change between the point at which it is
101   // loaded, and the position in which |var_inst| is loaded.
102   //
103   // For now we will go with the easy to implement approach, and check that the
104   // entire variable (not just the specific component) is never written to.
105 
106   if (!HasNoStores(source->GetVariable())) {
107     return nullptr;
108   }
109   return source;
110 }
111 
// Returns the single OpStore whose pointer operand is |var_inst| itself.
// Returns nullptr when there is no such store or when there are two or
// more.  Stores made through access chains into |var_inst| are not counted.
Instruction* CopyPropagateArrays::FindStoreInstruction(
    const Instruction* var_inst) const {
  const uint32_t var_id = var_inst->result_id();
  Instruction* found = nullptr;

  const bool unique = get_def_use_mgr()->WhileEachUser(
      var_inst, [&found, var_id](Instruction* user) {
        const bool is_whole_store =
            user->opcode() == spv::Op::OpStore &&
            user->GetSingleWordInOperand(kStorePointerInOperand) == var_id;
        if (!is_whole_store) return true;
        if (found != nullptr) return false;  // A second store: give up.
        found = user;
        return true;
      });

  return unique ? found : nullptr;
}
131 
PropagateObject(Instruction * var_inst,MemoryObject * source,Instruction * insertion_point)132 void CopyPropagateArrays::PropagateObject(Instruction* var_inst,
133                                           MemoryObject* source,
134                                           Instruction* insertion_point) {
135   assert(var_inst->opcode() == spv::Op::OpVariable &&
136          "This function propagates variables.");
137 
138   Instruction* new_access_chain = BuildNewAccessChain(insertion_point, source);
139   context()->KillNamesAndDecorates(var_inst);
140   UpdateUses(var_inst, new_access_chain);
141 }
142 
BuildNewAccessChain(Instruction * insertion_point,CopyPropagateArrays::MemoryObject * source) const143 Instruction* CopyPropagateArrays::BuildNewAccessChain(
144     Instruction* insertion_point,
145     CopyPropagateArrays::MemoryObject* source) const {
146   InstructionBuilder builder(
147       context(), insertion_point,
148       IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
149 
150   if (source->AccessChain().size() == 0) {
151     return source->GetVariable();
152   }
153 
154   source->BuildConstants();
155   std::vector<uint32_t> access_ids(source->AccessChain().size());
156   std::transform(
157       source->AccessChain().cbegin(), source->AccessChain().cend(),
158       access_ids.begin(), [](const AccessChainEntry& entry) {
159         assert(entry.is_result_id && "Constants needs to be built first.");
160         return entry.result_id;
161       });
162 
163   return builder.AddAccessChain(source->GetPointerTypeId(this),
164                                 source->GetVariable()->result_id(), access_ids);
165 }
166 
HasNoStores(Instruction * ptr_inst)167 bool CopyPropagateArrays::HasNoStores(Instruction* ptr_inst) {
168   return get_def_use_mgr()->WhileEachUser(ptr_inst, [this](Instruction* use) {
169     if (use->opcode() == spv::Op::OpLoad) {
170       return true;
171     } else if (use->opcode() == spv::Op::OpAccessChain) {
172       return HasNoStores(use);
173     } else if (use->IsDecoration() || use->opcode() == spv::Op::OpName) {
174       return true;
175     } else if (use->opcode() == spv::Op::OpStore) {
176       return false;
177     } else if (use->opcode() == spv::Op::OpImageTexelPointer) {
178       return true;
179     } else if (use->opcode() == spv::Op::OpEntryPoint) {
180       return true;
181     }
182     // Some other instruction.  Be conservative.
183     return false;
184   });
185 }
186 
HasValidReferencesOnly(Instruction * ptr_inst,Instruction * store_inst)187 bool CopyPropagateArrays::HasValidReferencesOnly(Instruction* ptr_inst,
188                                                  Instruction* store_inst) {
189   BasicBlock* store_block = context()->get_instr_block(store_inst);
190   DominatorAnalysis* dominator_analysis =
191       context()->GetDominatorAnalysis(store_block->GetParent());
192 
193   return get_def_use_mgr()->WhileEachUser(
194       ptr_inst,
195       [this, store_inst, dominator_analysis, ptr_inst](Instruction* use) {
196         if (use->opcode() == spv::Op::OpLoad ||
197             use->opcode() == spv::Op::OpImageTexelPointer) {
198           // TODO: If there are many load in the same BB as |store_inst| the
199           // time to do the multiple traverses can add up.  Consider collecting
200           // those loads and doing a single traversal.
201           return dominator_analysis->Dominates(store_inst, use);
202         } else if (use->opcode() == spv::Op::OpAccessChain) {
203           return HasValidReferencesOnly(use, store_inst);
204         } else if (use->IsDecoration() || use->opcode() == spv::Op::OpName) {
205           return true;
206         } else if (use->opcode() == spv::Op::OpStore) {
207           // If we are storing to part of the object it is not an candidate.
208           return ptr_inst->opcode() == spv::Op::OpVariable &&
209                  store_inst->GetSingleWordInOperand(kStorePointerInOperand) ==
210                      ptr_inst->result_id();
211         } else if (IsDebugDeclareOrValue(use)) {
212           return true;
213         }
214         // Some other instruction.  Be conservative.
215         return false;
216       });
217 }
218 
219 std::unique_ptr<CopyPropagateArrays::MemoryObject>
GetSourceObjectIfAny(uint32_t result)220 CopyPropagateArrays::GetSourceObjectIfAny(uint32_t result) {
221   Instruction* result_inst = context()->get_def_use_mgr()->GetDef(result);
222 
223   switch (result_inst->opcode()) {
224     case spv::Op::OpLoad:
225       return BuildMemoryObjectFromLoad(result_inst);
226     case spv::Op::OpCompositeExtract:
227       return BuildMemoryObjectFromExtract(result_inst);
228     case spv::Op::OpCompositeConstruct:
229       return BuildMemoryObjectFromCompositeConstruct(result_inst);
230     case spv::Op::OpCopyObject:
231       return GetSourceObjectIfAny(result_inst->GetSingleWordInOperand(0));
232     case spv::Op::OpCompositeInsert:
233       return BuildMemoryObjectFromInsert(result_inst);
234     default:
235       return nullptr;
236   }
237 }
238 
239 std::unique_ptr<CopyPropagateArrays::MemoryObject>
BuildMemoryObjectFromLoad(Instruction * load_inst)240 CopyPropagateArrays::BuildMemoryObjectFromLoad(Instruction* load_inst) {
241   std::vector<uint32_t> components_in_reverse;
242   analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
243 
244   Instruction* current_inst = def_use_mgr->GetDef(
245       load_inst->GetSingleWordInOperand(kLoadPointerInOperand));
246 
247   // Build the access chain for the memory object by collecting the indices used
248   // in the OpAccessChain instructions.  If we find a variable index, then
249   // return |nullptr| because we cannot know for sure which memory location is
250   // used.
251   //
252   // It is built in reverse order because the different |OpAccessChain|
253   // instructions are visited in reverse order from which they are applied.
254   while (current_inst->opcode() == spv::Op::OpAccessChain) {
255     for (uint32_t i = current_inst->NumInOperands() - 1; i >= 1; --i) {
256       uint32_t element_index_id = current_inst->GetSingleWordInOperand(i);
257       components_in_reverse.push_back(element_index_id);
258     }
259     current_inst = def_use_mgr->GetDef(current_inst->GetSingleWordInOperand(0));
260   }
261 
262   // If the address in the load is not constructed from an |OpVariable|
263   // instruction followed by a series of |OpAccessChain| instructions, then
264   // return |nullptr| because we cannot identify the owner or access chain
265   // exactly.
266   if (current_inst->opcode() != spv::Op::OpVariable) {
267     return nullptr;
268   }
269 
270   // Build the memory object.  Use |rbegin| and |rend| to put the access chain
271   // back in the correct order.
272   return std::unique_ptr<CopyPropagateArrays::MemoryObject>(
273       new MemoryObject(current_inst, components_in_reverse.rbegin(),
274                        components_in_reverse.rend()));
275 }
276 
277 std::unique_ptr<CopyPropagateArrays::MemoryObject>
BuildMemoryObjectFromExtract(Instruction * extract_inst)278 CopyPropagateArrays::BuildMemoryObjectFromExtract(Instruction* extract_inst) {
279   assert(extract_inst->opcode() == spv::Op::OpCompositeExtract &&
280          "Expecting an OpCompositeExtract instruction.");
281   std::unique_ptr<MemoryObject> result = GetSourceObjectIfAny(
282       extract_inst->GetSingleWordInOperand(kCompositeExtractObjectInOperand));
283 
284   if (!result) {
285     return nullptr;
286   }
287 
288   // Copy the indices of the extract instruction to |OpAccessChain| indices.
289   std::vector<AccessChainEntry> components;
290   for (uint32_t i = 1; i < extract_inst->NumInOperands(); ++i) {
291     components.push_back({false, {extract_inst->GetSingleWordInOperand(i)}});
292   }
293   result->PushIndirection(components);
294   return result;
295 }
296 
297 std::unique_ptr<CopyPropagateArrays::MemoryObject>
BuildMemoryObjectFromCompositeConstruct(Instruction * conststruct_inst)298 CopyPropagateArrays::BuildMemoryObjectFromCompositeConstruct(
299     Instruction* conststruct_inst) {
300   assert(conststruct_inst->opcode() == spv::Op::OpCompositeConstruct &&
301          "Expecting an OpCompositeConstruct instruction.");
302 
303   // If every operand in the instruction are part of the same memory object, and
304   // are being combined in the same order, then the result is the same as the
305   // parent.
306 
307   std::unique_ptr<MemoryObject> memory_object =
308       GetSourceObjectIfAny(conststruct_inst->GetSingleWordInOperand(0));
309 
310   if (!memory_object) {
311     return nullptr;
312   }
313 
314   if (!memory_object->IsMember()) {
315     return nullptr;
316   }
317 
318   AccessChainEntry last_access = memory_object->AccessChain().back();
319   if (!IsAccessChainIndexValidAndEqualTo(last_access, 0)) {
320     return nullptr;
321   }
322 
323   memory_object->PopIndirection();
324   if (memory_object->GetNumberOfMembers() !=
325       conststruct_inst->NumInOperands()) {
326     return nullptr;
327   }
328 
329   for (uint32_t i = 1; i < conststruct_inst->NumInOperands(); ++i) {
330     std::unique_ptr<MemoryObject> member_object =
331         GetSourceObjectIfAny(conststruct_inst->GetSingleWordInOperand(i));
332 
333     if (!member_object) {
334       return nullptr;
335     }
336 
337     if (!member_object->IsMember()) {
338       return nullptr;
339     }
340 
341     if (!memory_object->Contains(member_object.get())) {
342       return nullptr;
343     }
344 
345     last_access = member_object->AccessChain().back();
346     if (!IsAccessChainIndexValidAndEqualTo(last_access, i)) {
347       return nullptr;
348     }
349   }
350   return memory_object;
351 }
352 
353 std::unique_ptr<CopyPropagateArrays::MemoryObject>
BuildMemoryObjectFromInsert(Instruction * insert_inst)354 CopyPropagateArrays::BuildMemoryObjectFromInsert(Instruction* insert_inst) {
355   assert(insert_inst->opcode() == spv::Op::OpCompositeInsert &&
356          "Expecting an OpCompositeInsert instruction.");
357 
358   analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
359   analysis::TypeManager* type_mgr = context()->get_type_mgr();
360   analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
361   const analysis::Type* result_type = type_mgr->GetType(insert_inst->type_id());
362 
363   uint32_t number_of_elements = 0;
364   if (const analysis::Struct* struct_type = result_type->AsStruct()) {
365     number_of_elements =
366         static_cast<uint32_t>(struct_type->element_types().size());
367   } else if (const analysis::Array* array_type = result_type->AsArray()) {
368     const analysis::Constant* length_const =
369         const_mgr->FindDeclaredConstant(array_type->LengthId());
370     number_of_elements = length_const->GetU32();
371   } else if (const analysis::Vector* vector_type = result_type->AsVector()) {
372     number_of_elements = vector_type->element_count();
373   } else if (const analysis::Matrix* matrix_type = result_type->AsMatrix()) {
374     number_of_elements = matrix_type->element_count();
375   }
376 
377   if (number_of_elements == 0) {
378     return nullptr;
379   }
380 
381   if (insert_inst->NumInOperands() != 3) {
382     return nullptr;
383   }
384 
385   if (insert_inst->GetSingleWordInOperand(2) != number_of_elements - 1) {
386     return nullptr;
387   }
388 
389   std::unique_ptr<MemoryObject> memory_object =
390       GetSourceObjectIfAny(insert_inst->GetSingleWordInOperand(0));
391 
392   if (!memory_object) {
393     return nullptr;
394   }
395 
396   if (!memory_object->IsMember()) {
397     return nullptr;
398   }
399 
400   AccessChainEntry last_access = memory_object->AccessChain().back();
401   if (!IsAccessChainIndexValidAndEqualTo(last_access, number_of_elements - 1)) {
402     return nullptr;
403   }
404 
405   memory_object->PopIndirection();
406 
407   Instruction* current_insert =
408       def_use_mgr->GetDef(insert_inst->GetSingleWordInOperand(1));
409   for (uint32_t i = number_of_elements - 1; i > 0; --i) {
410     if (current_insert->opcode() != spv::Op::OpCompositeInsert) {
411       return nullptr;
412     }
413 
414     if (current_insert->NumInOperands() != 3) {
415       return nullptr;
416     }
417 
418     if (current_insert->GetSingleWordInOperand(2) != i - 1) {
419       return nullptr;
420     }
421 
422     std::unique_ptr<MemoryObject> current_memory_object =
423         GetSourceObjectIfAny(current_insert->GetSingleWordInOperand(0));
424 
425     if (!current_memory_object) {
426       return nullptr;
427     }
428 
429     if (!current_memory_object->IsMember()) {
430       return nullptr;
431     }
432 
433     if (memory_object->AccessChain().size() + 1 !=
434         current_memory_object->AccessChain().size()) {
435       return nullptr;
436     }
437 
438     if (!memory_object->Contains(current_memory_object.get())) {
439       return nullptr;
440     }
441 
442     AccessChainEntry current_last_access =
443         current_memory_object->AccessChain().back();
444     if (!IsAccessChainIndexValidAndEqualTo(current_last_access, i - 1)) {
445       return nullptr;
446     }
447     current_insert =
448         def_use_mgr->GetDef(current_insert->GetSingleWordInOperand(1));
449   }
450 
451   return memory_object;
452 }
453 
IsAccessChainIndexValidAndEqualTo(const AccessChainEntry & entry,uint32_t value) const454 bool CopyPropagateArrays::IsAccessChainIndexValidAndEqualTo(
455     const AccessChainEntry& entry, uint32_t value) const {
456   if (!entry.is_result_id) {
457     return entry.immediate == value;
458   }
459 
460   analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
461   const analysis::Constant* constant =
462       const_mgr->FindDeclaredConstant(entry.result_id);
463   if (!constant || !constant->type()->AsInteger()) {
464     return false;
465   }
466   return constant->GetU32() == value;
467 }
468 
IsPointerToArrayType(uint32_t type_id)469 bool CopyPropagateArrays::IsPointerToArrayType(uint32_t type_id) {
470   analysis::TypeManager* type_mgr = context()->get_type_mgr();
471   analysis::Pointer* pointer_type = type_mgr->GetType(type_id)->AsPointer();
472   if (pointer_type) {
473     return pointer_type->pointee_type()->kind() == analysis::Type::kArray ||
474            pointer_type->pointee_type()->kind() == analysis::Type::kImage;
475   }
476   return false;
477 }
478 
CanUpdateUses(Instruction * original_ptr_inst,uint32_t type_id)479 bool CopyPropagateArrays::CanUpdateUses(Instruction* original_ptr_inst,
480                                         uint32_t type_id) {
481   analysis::TypeManager* type_mgr = context()->get_type_mgr();
482   analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
483   analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
484 
485   analysis::Type* type = type_mgr->GetType(type_id);
486   if (type->AsRuntimeArray()) {
487     return false;
488   }
489 
490   if (!type->AsStruct() && !type->AsArray() && !type->AsPointer()) {
491     // If the type is not an aggregate, then the desired type must be the
492     // same as the current type.  No work to do, and we can do that.
493     return true;
494   }
495 
496   return def_use_mgr->WhileEachUse(original_ptr_inst, [this, type_mgr,
497                                                        const_mgr,
498                                                        type](Instruction* use,
499                                                              uint32_t) {
500     if (IsDebugDeclareOrValue(use)) return true;
501 
502     switch (use->opcode()) {
503       case spv::Op::OpLoad: {
504         analysis::Pointer* pointer_type = type->AsPointer();
505         uint32_t new_type_id = type_mgr->GetId(pointer_type->pointee_type());
506 
507         if (new_type_id != use->type_id()) {
508           return CanUpdateUses(use, new_type_id);
509         }
510         return true;
511       }
512       case spv::Op::OpAccessChain: {
513         analysis::Pointer* pointer_type = type->AsPointer();
514         const analysis::Type* pointee_type = pointer_type->pointee_type();
515 
516         std::vector<uint32_t> access_chain;
517         for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
518           const analysis::Constant* index_const =
519               const_mgr->FindDeclaredConstant(use->GetSingleWordInOperand(i));
520           if (index_const) {
521             access_chain.push_back(index_const->GetU32());
522           } else {
523             // Variable index means the type is a type where every element
524             // is the same type.  Use element 0 to get the type.
525             access_chain.push_back(0);
526 
527             // We are trying to access a struct with variable indices.
528             // This cannot happen.
529             if (pointee_type->kind() == analysis::Type::kStruct) {
530               return false;
531             }
532           }
533         }
534 
535         const analysis::Type* new_pointee_type =
536             type_mgr->GetMemberType(pointee_type, access_chain);
537         analysis::Pointer pointerTy(new_pointee_type,
538                                     pointer_type->storage_class());
539         uint32_t new_pointer_type_id =
540             context()->get_type_mgr()->GetTypeInstruction(&pointerTy);
541         if (new_pointer_type_id == 0) {
542           return false;
543         }
544 
545         if (new_pointer_type_id != use->type_id()) {
546           return CanUpdateUses(use, new_pointer_type_id);
547         }
548         return true;
549       }
550       case spv::Op::OpCompositeExtract: {
551         std::vector<uint32_t> access_chain;
552         for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
553           access_chain.push_back(use->GetSingleWordInOperand(i));
554         }
555 
556         const analysis::Type* new_type =
557             type_mgr->GetMemberType(type, access_chain);
558         uint32_t new_type_id = type_mgr->GetTypeInstruction(new_type);
559         if (new_type_id == 0) {
560           return false;
561         }
562 
563         if (new_type_id != use->type_id()) {
564           return CanUpdateUses(use, new_type_id);
565         }
566         return true;
567       }
568       case spv::Op::OpStore:
569         // If needed, we can create an element-by-element copy to change the
570         // type of the value being stored.  This way we can always handled
571         // stores.
572         return true;
573       case spv::Op::OpImageTexelPointer:
574       case spv::Op::OpName:
575         return true;
576       default:
577         return use->IsDecoration();
578     }
579   });
580 }
581 
UpdateUses(Instruction * original_ptr_inst,Instruction * new_ptr_inst)582 void CopyPropagateArrays::UpdateUses(Instruction* original_ptr_inst,
583                                      Instruction* new_ptr_inst) {
584   analysis::TypeManager* type_mgr = context()->get_type_mgr();
585   analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
586   analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
587 
588   std::vector<std::pair<Instruction*, uint32_t> > uses;
589   def_use_mgr->ForEachUse(original_ptr_inst,
590                           [&uses](Instruction* use, uint32_t index) {
591                             uses.push_back({use, index});
592                           });
593 
594   for (auto pair : uses) {
595     Instruction* use = pair.first;
596     uint32_t index = pair.second;
597 
598     if (use->IsCommonDebugInstr()) {
599       switch (use->GetCommonDebugOpcode()) {
600         case CommonDebugInfoDebugDeclare: {
601           if (new_ptr_inst->opcode() == spv::Op::OpVariable ||
602               new_ptr_inst->opcode() == spv::Op::OpFunctionParameter) {
603             context()->ForgetUses(use);
604             use->SetOperand(index, {new_ptr_inst->result_id()});
605             context()->AnalyzeUses(use);
606           } else {
607             // Based on the spec, we cannot use a pointer other than OpVariable
608             // or OpFunctionParameter for DebugDeclare. We have to use
609             // DebugValue with Deref.
610 
611             context()->ForgetUses(use);
612 
613             // Change DebugDeclare to DebugValue.
614             use->SetOperand(index - 2,
615                             {static_cast<uint32_t>(CommonDebugInfoDebugValue)});
616             use->SetOperand(index, {new_ptr_inst->result_id()});
617 
618             // Add Deref operation.
619             Instruction* dbg_expr =
620                 def_use_mgr->GetDef(use->GetSingleWordOperand(index + 1));
621             auto* deref_expr_instr =
622                 context()->get_debug_info_mgr()->DerefDebugExpression(dbg_expr);
623             use->SetOperand(index + 1, {deref_expr_instr->result_id()});
624 
625             context()->AnalyzeUses(deref_expr_instr);
626             context()->AnalyzeUses(use);
627           }
628           break;
629         }
630         case CommonDebugInfoDebugValue:
631           context()->ForgetUses(use);
632           use->SetOperand(index, {new_ptr_inst->result_id()});
633           context()->AnalyzeUses(use);
634           break;
635         default:
636           assert(false && "Don't know how to rewrite instruction");
637           break;
638       }
639       continue;
640     }
641 
642     switch (use->opcode()) {
643       case spv::Op::OpLoad: {
644         // Replace the actual use.
645         context()->ForgetUses(use);
646         use->SetOperand(index, {new_ptr_inst->result_id()});
647 
648         // Update the type.
649         Instruction* pointer_type_inst =
650             def_use_mgr->GetDef(new_ptr_inst->type_id());
651         uint32_t new_type_id =
652             pointer_type_inst->GetSingleWordInOperand(kTypePointerPointeeInIdx);
653         if (new_type_id != use->type_id()) {
654           use->SetResultType(new_type_id);
655           context()->AnalyzeUses(use);
656           UpdateUses(use, use);
657         } else {
658           context()->AnalyzeUses(use);
659         }
660       } break;
661       case spv::Op::OpAccessChain: {
662         // Update the actual use.
663         context()->ForgetUses(use);
664         use->SetOperand(index, {new_ptr_inst->result_id()});
665 
666         // Convert the ids on the OpAccessChain to indices that can be used to
667         // get the specific member.
668         std::vector<uint32_t> access_chain;
669         for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
670           const analysis::Constant* index_const =
671               const_mgr->FindDeclaredConstant(use->GetSingleWordInOperand(i));
672           if (index_const) {
673             access_chain.push_back(index_const->GetU32());
674           } else {
675             // Variable index means the type is an type where every element
676             // is the same type.  Use element 0 to get the type.
677             access_chain.push_back(0);
678           }
679         }
680 
681         Instruction* pointer_type_inst =
682             get_def_use_mgr()->GetDef(new_ptr_inst->type_id());
683 
684         uint32_t new_pointee_type_id = GetMemberTypeId(
685             pointer_type_inst->GetSingleWordInOperand(kTypePointerPointeeInIdx),
686             access_chain);
687 
688         spv::StorageClass storage_class = static_cast<spv::StorageClass>(
689             pointer_type_inst->GetSingleWordInOperand(
690                 kTypePointerStorageClassInIdx));
691 
692         uint32_t new_pointer_type_id =
693             type_mgr->FindPointerToType(new_pointee_type_id, storage_class);
694 
695         if (new_pointer_type_id != use->type_id()) {
696           use->SetResultType(new_pointer_type_id);
697           context()->AnalyzeUses(use);
698           UpdateUses(use, use);
699         } else {
700           context()->AnalyzeUses(use);
701         }
702       } break;
703       case spv::Op::OpCompositeExtract: {
704         // Update the actual use.
705         context()->ForgetUses(use);
706         use->SetOperand(index, {new_ptr_inst->result_id()});
707 
708         uint32_t new_type_id = new_ptr_inst->type_id();
709         std::vector<uint32_t> access_chain;
710         for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
711           access_chain.push_back(use->GetSingleWordInOperand(i));
712         }
713 
714         new_type_id = GetMemberTypeId(new_type_id, access_chain);
715 
716         if (new_type_id != use->type_id()) {
717           use->SetResultType(new_type_id);
718           context()->AnalyzeUses(use);
719           UpdateUses(use, use);
720         } else {
721           context()->AnalyzeUses(use);
722         }
723       } break;
724       case spv::Op::OpStore:
725         // If the use is the pointer, then it is the single store to that
726         // variable.  We do not want to replace it.  Instead, it will become
727         // dead after all of the loads are removed, and ADCE will get rid of it.
728         //
729         // If the use is the object being stored, we will create a copy of the
730         // object turning it into the correct type. The copy is done by
731         // decomposing the object into the base type, which must be the same,
732         // and then rebuilding them.
733         if (index == 1) {
734           Instruction* target_pointer = def_use_mgr->GetDef(
735               use->GetSingleWordInOperand(kStorePointerInOperand));
736           Instruction* pointer_type =
737               def_use_mgr->GetDef(target_pointer->type_id());
738           uint32_t pointee_type_id =
739               pointer_type->GetSingleWordInOperand(kTypePointerPointeeInIdx);
740           uint32_t copy = GenerateCopy(original_ptr_inst, pointee_type_id, use);
741 
742           context()->ForgetUses(use);
743           use->SetInOperand(index, {copy});
744           context()->AnalyzeUses(use);
745         }
746         break;
747       case spv::Op::OpDecorate:
748       // We treat an OpImageTexelPointer as a load.  The result type should
749       // always have the Image storage class, and should not need to be
750       // updated.
751       case spv::Op::OpImageTexelPointer:
752         // Replace the actual use.
753         context()->ForgetUses(use);
754         use->SetOperand(index, {new_ptr_inst->result_id()});
755         context()->AnalyzeUses(use);
756         break;
757       default:
758         assert(false && "Don't know how to rewrite instruction");
759         break;
760     }
761   }
762 }
763 
GetMemberTypeId(uint32_t id,const std::vector<uint32_t> & access_chain) const764 uint32_t CopyPropagateArrays::GetMemberTypeId(
765     uint32_t id, const std::vector<uint32_t>& access_chain) const {
766   for (uint32_t element_index : access_chain) {
767     Instruction* type_inst = get_def_use_mgr()->GetDef(id);
768     switch (type_inst->opcode()) {
769       case spv::Op::OpTypeArray:
770       case spv::Op::OpTypeRuntimeArray:
771       case spv::Op::OpTypeMatrix:
772       case spv::Op::OpTypeVector:
773         id = type_inst->GetSingleWordInOperand(0);
774         break;
775       case spv::Op::OpTypeStruct:
776         id = type_inst->GetSingleWordInOperand(element_index);
777         break;
778       default:
779         break;
780     }
781     assert(id != 0 &&
782            "Tried to extract from an object where it cannot be done.");
783   }
784   return id;
785 }
786 
PushIndirection(const std::vector<AccessChainEntry> & access_chain)787 void CopyPropagateArrays::MemoryObject::PushIndirection(
788     const std::vector<AccessChainEntry>& access_chain) {
789   access_chain_.insert(access_chain_.end(), access_chain.begin(),
790                        access_chain.end());
791 }
792 
GetNumberOfMembers()793 uint32_t CopyPropagateArrays::MemoryObject::GetNumberOfMembers() {
794   IRContext* context = variable_inst_->context();
795   analysis::TypeManager* type_mgr = context->get_type_mgr();
796 
797   const analysis::Type* type = type_mgr->GetType(variable_inst_->type_id());
798   type = type->AsPointer()->pointee_type();
799 
800   std::vector<uint32_t> access_indices = GetAccessIds();
801   type = type_mgr->GetMemberType(type, access_indices);
802 
803   if (const analysis::Struct* struct_type = type->AsStruct()) {
804     return static_cast<uint32_t>(struct_type->element_types().size());
805   } else if (const analysis::Array* array_type = type->AsArray()) {
806     const analysis::Constant* length_const =
807         context->get_constant_mgr()->FindDeclaredConstant(
808             array_type->LengthId());
809     assert(length_const->type()->AsInteger());
810     return length_const->GetU32();
811   } else if (const analysis::Vector* vector_type = type->AsVector()) {
812     return vector_type->element_count();
813   } else if (const analysis::Matrix* matrix_type = type->AsMatrix()) {
814     return matrix_type->element_count();
815   } else {
816     return 0;
817   }
818 }
819 
820 template <class iterator>
MemoryObject(Instruction * var_inst,iterator begin,iterator end)821 CopyPropagateArrays::MemoryObject::MemoryObject(Instruction* var_inst,
822                                                 iterator begin, iterator end)
823     : variable_inst_(var_inst) {
824   std::transform(begin, end, std::back_inserter(access_chain_),
825                  [](uint32_t id) {
826                    return AccessChainEntry{true, {id}};
827                  });
828 }
829 
GetAccessIds() const830 std::vector<uint32_t> CopyPropagateArrays::MemoryObject::GetAccessIds() const {
831   analysis::ConstantManager* const_mgr =
832       variable_inst_->context()->get_constant_mgr();
833 
834   std::vector<uint32_t> indices(AccessChain().size());
835   std::transform(AccessChain().cbegin(), AccessChain().cend(), indices.begin(),
836                  [&const_mgr](const AccessChainEntry& entry) {
837                    if (entry.is_result_id) {
838                      const analysis::Constant* constant =
839                          const_mgr->FindDeclaredConstant(entry.result_id);
840                      return constant == nullptr ? 0 : constant->GetU32();
841                    }
842 
843                    return entry.immediate;
844                  });
845   return indices;
846 }
847 
Contains(CopyPropagateArrays::MemoryObject * other)848 bool CopyPropagateArrays::MemoryObject::Contains(
849     CopyPropagateArrays::MemoryObject* other) {
850   if (this->GetVariable() != other->GetVariable()) {
851     return false;
852   }
853 
854   if (AccessChain().size() > other->AccessChain().size()) {
855     return false;
856   }
857 
858   for (uint32_t i = 0; i < AccessChain().size(); i++) {
859     if (AccessChain()[i] != other->AccessChain()[i]) {
860       return false;
861     }
862   }
863   return true;
864 }
865 
BuildConstants()866 void CopyPropagateArrays::MemoryObject::BuildConstants() {
867   for (auto& entry : access_chain_) {
868     if (entry.is_result_id) {
869       continue;
870     }
871 
872     auto context = variable_inst_->context();
873     analysis::Integer int_type(32, false);
874     const analysis::Type* uint32_type =
875         context->get_type_mgr()->GetRegisteredType(&int_type);
876     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
877     const analysis::Constant* index_const =
878         const_mgr->GetConstant(uint32_type, {entry.immediate});
879     entry.result_id =
880         const_mgr->GetDefiningInstruction(index_const)->result_id();
881     entry.is_result_id = true;
882   }
883 }
884 
885 }  // namespace opt
886 }  // namespace spvtools
887