1 // Copyright (c) 2018 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/opt/copy_prop_arrays.h"
16
17 #include <utility>
18
19 #include "source/opt/ir_builder.h"
20
21 namespace spvtools {
22 namespace opt {
23 namespace {
24
// In-operand indices (as passed to GetSingleWordInOperand) of the operands
// this pass inspects.
constexpr uint32_t kLoadPointerInOperand = 0;             // OpLoad: pointer
constexpr uint32_t kStorePointerInOperand = 0;            // OpStore: pointer
constexpr uint32_t kStoreObjectInOperand = 1;             // OpStore: object
constexpr uint32_t kCompositeExtractObjectInOperand = 0;  // OpCompositeExtract: composite
constexpr uint32_t kTypePointerStorageClassInIdx = 0;     // OpTypePointer: storage class
constexpr uint32_t kTypePointerPointeeInIdx = 1;          // OpTypePointer: pointee type
31
IsDebugDeclareOrValue(Instruction * di)32 bool IsDebugDeclareOrValue(Instruction* di) {
33 auto dbg_opcode = di->GetCommonDebugOpcode();
34 return dbg_opcode == CommonDebugInfoDebugDeclare ||
35 dbg_opcode == CommonDebugInfoDebugValue;
36 }
37
38 } // namespace
39
Process()40 Pass::Status CopyPropagateArrays::Process() {
41 bool modified = false;
42 for (Function& function : *get_module()) {
43 if (function.IsDeclaration()) {
44 continue;
45 }
46
47 BasicBlock* entry_bb = &*function.begin();
48
49 for (auto var_inst = entry_bb->begin(); var_inst->opcode() == SpvOpVariable;
50 ++var_inst) {
51 if (!IsPointerToArrayType(var_inst->type_id())) {
52 continue;
53 }
54
55 // Find the only store to the entire memory location, if it exists.
56 Instruction* store_inst = FindStoreInstruction(&*var_inst);
57
58 if (!store_inst) {
59 continue;
60 }
61
62 std::unique_ptr<MemoryObject> source_object =
63 FindSourceObjectIfPossible(&*var_inst, store_inst);
64
65 if (source_object != nullptr) {
66 if (CanUpdateUses(&*var_inst, source_object->GetPointerTypeId(this))) {
67 modified = true;
68 PropagateObject(&*var_inst, source_object.get(), store_inst);
69 }
70 }
71 }
72 }
73 return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
74 }
75
76 std::unique_ptr<CopyPropagateArrays::MemoryObject>
FindSourceObjectIfPossible(Instruction * var_inst,Instruction * store_inst)77 CopyPropagateArrays::FindSourceObjectIfPossible(Instruction* var_inst,
78 Instruction* store_inst) {
79 assert(var_inst->opcode() == SpvOpVariable && "Expecting a variable.");
80
81 // Check that the variable is a composite object where |store_inst|
82 // dominates all of its loads.
83 if (!store_inst) {
84 return nullptr;
85 }
86
87 // Look at the loads to ensure they are dominated by the store.
88 if (!HasValidReferencesOnly(var_inst, store_inst)) {
89 return nullptr;
90 }
91
92 // If so, look at the store to see if it is the copy of an object.
93 std::unique_ptr<MemoryObject> source = GetSourceObjectIfAny(
94 store_inst->GetSingleWordInOperand(kStoreObjectInOperand));
95
96 if (!source) {
97 return nullptr;
98 }
99
100 // Ensure that |source| does not change between the point at which it is
101 // loaded, and the position in which |var_inst| is loaded.
102 //
103 // For now we will go with the easy to implement approach, and check that the
104 // entire variable (not just the specific component) is never written to.
105
106 if (!HasNoStores(source->GetVariable())) {
107 return nullptr;
108 }
109 return source;
110 }
111
// Returns the unique OpStore whose pointer operand is |var_inst| itself.
// Returns nullptr if there is no such store or if there is more than one.
Instruction* CopyPropagateArrays::FindStoreInstruction(
    const Instruction* var_inst) const {
  const uint32_t var_id = var_inst->result_id();
  Instruction* unique_store = nullptr;
  bool saw_second_store = false;

  get_def_use_mgr()->WhileEachUser(var_inst, [&](Instruction* user) {
    const bool writes_var =
        user->opcode() == SpvOpStore &&
        user->GetSingleWordInOperand(kStorePointerInOperand) == var_id;
    if (!writes_var) return true;

    if (unique_store != nullptr) {
      // A second whole-variable store: no single store exists; stop early.
      saw_second_store = true;
      return false;
    }
    unique_store = user;
    return true;
  });

  return saw_second_store ? nullptr : unique_store;
}
131
PropagateObject(Instruction * var_inst,MemoryObject * source,Instruction * insertion_point)132 void CopyPropagateArrays::PropagateObject(Instruction* var_inst,
133 MemoryObject* source,
134 Instruction* insertion_point) {
135 assert(var_inst->opcode() == SpvOpVariable &&
136 "This function propagates variables.");
137
138 Instruction* new_access_chain = BuildNewAccessChain(insertion_point, source);
139 context()->KillNamesAndDecorates(var_inst);
140 UpdateUses(var_inst, new_access_chain);
141 }
142
BuildNewAccessChain(Instruction * insertion_point,CopyPropagateArrays::MemoryObject * source) const143 Instruction* CopyPropagateArrays::BuildNewAccessChain(
144 Instruction* insertion_point,
145 CopyPropagateArrays::MemoryObject* source) const {
146 InstructionBuilder builder(
147 context(), insertion_point,
148 IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
149
150 if (source->AccessChain().size() == 0) {
151 return source->GetVariable();
152 }
153
154 return builder.AddAccessChain(source->GetPointerTypeId(this),
155 source->GetVariable()->result_id(),
156 source->AccessChain());
157 }
158
HasNoStores(Instruction * ptr_inst)159 bool CopyPropagateArrays::HasNoStores(Instruction* ptr_inst) {
160 return get_def_use_mgr()->WhileEachUser(ptr_inst, [this](Instruction* use) {
161 if (use->opcode() == SpvOpLoad) {
162 return true;
163 } else if (use->opcode() == SpvOpAccessChain) {
164 return HasNoStores(use);
165 } else if (use->IsDecoration() || use->opcode() == SpvOpName) {
166 return true;
167 } else if (use->opcode() == SpvOpStore) {
168 return false;
169 } else if (use->opcode() == SpvOpImageTexelPointer) {
170 return true;
171 }
172 // Some other instruction. Be conservative.
173 return false;
174 });
175 }
176
HasValidReferencesOnly(Instruction * ptr_inst,Instruction * store_inst)177 bool CopyPropagateArrays::HasValidReferencesOnly(Instruction* ptr_inst,
178 Instruction* store_inst) {
179 BasicBlock* store_block = context()->get_instr_block(store_inst);
180 DominatorAnalysis* dominator_analysis =
181 context()->GetDominatorAnalysis(store_block->GetParent());
182
183 return get_def_use_mgr()->WhileEachUser(
184 ptr_inst,
185 [this, store_inst, dominator_analysis, ptr_inst](Instruction* use) {
186 if (use->opcode() == SpvOpLoad ||
187 use->opcode() == SpvOpImageTexelPointer) {
188 // TODO: If there are many load in the same BB as |store_inst| the
189 // time to do the multiple traverses can add up. Consider collecting
190 // those loads and doing a single traversal.
191 return dominator_analysis->Dominates(store_inst, use);
192 } else if (use->opcode() == SpvOpAccessChain) {
193 return HasValidReferencesOnly(use, store_inst);
194 } else if (use->IsDecoration() || use->opcode() == SpvOpName) {
195 return true;
196 } else if (use->opcode() == SpvOpStore) {
197 // If we are storing to part of the object it is not an candidate.
198 return ptr_inst->opcode() == SpvOpVariable &&
199 store_inst->GetSingleWordInOperand(kStorePointerInOperand) ==
200 ptr_inst->result_id();
201 } else if (IsDebugDeclareOrValue(use)) {
202 return true;
203 }
204 // Some other instruction. Be conservative.
205 return false;
206 });
207 }
208
209 std::unique_ptr<CopyPropagateArrays::MemoryObject>
GetSourceObjectIfAny(uint32_t result)210 CopyPropagateArrays::GetSourceObjectIfAny(uint32_t result) {
211 Instruction* result_inst = context()->get_def_use_mgr()->GetDef(result);
212
213 switch (result_inst->opcode()) {
214 case SpvOpLoad:
215 return BuildMemoryObjectFromLoad(result_inst);
216 case SpvOpCompositeExtract:
217 return BuildMemoryObjectFromExtract(result_inst);
218 case SpvOpCompositeConstruct:
219 return BuildMemoryObjectFromCompositeConstruct(result_inst);
220 case SpvOpCopyObject:
221 return GetSourceObjectIfAny(result_inst->GetSingleWordInOperand(0));
222 case SpvOpCompositeInsert:
223 return BuildMemoryObjectFromInsert(result_inst);
224 default:
225 return nullptr;
226 }
227 }
228
229 std::unique_ptr<CopyPropagateArrays::MemoryObject>
BuildMemoryObjectFromLoad(Instruction * load_inst)230 CopyPropagateArrays::BuildMemoryObjectFromLoad(Instruction* load_inst) {
231 std::vector<uint32_t> components_in_reverse;
232 analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
233
234 Instruction* current_inst = def_use_mgr->GetDef(
235 load_inst->GetSingleWordInOperand(kLoadPointerInOperand));
236
237 // Build the access chain for the memory object by collecting the indices used
238 // in the OpAccessChain instructions. If we find a variable index, then
239 // return |nullptr| because we cannot know for sure which memory location is
240 // used.
241 //
242 // It is built in reverse order because the different |OpAccessChain|
243 // instructions are visited in reverse order from which they are applied.
244 while (current_inst->opcode() == SpvOpAccessChain) {
245 for (uint32_t i = current_inst->NumInOperands() - 1; i >= 1; --i) {
246 uint32_t element_index_id = current_inst->GetSingleWordInOperand(i);
247 components_in_reverse.push_back(element_index_id);
248 }
249 current_inst = def_use_mgr->GetDef(current_inst->GetSingleWordInOperand(0));
250 }
251
252 // If the address in the load is not constructed from an |OpVariable|
253 // instruction followed by a series of |OpAccessChain| instructions, then
254 // return |nullptr| because we cannot identify the owner or access chain
255 // exactly.
256 if (current_inst->opcode() != SpvOpVariable) {
257 return nullptr;
258 }
259
260 // Build the memory object. Use |rbegin| and |rend| to put the access chain
261 // back in the correct order.
262 return std::unique_ptr<CopyPropagateArrays::MemoryObject>(
263 new MemoryObject(current_inst, components_in_reverse.rbegin(),
264 components_in_reverse.rend()));
265 }
266
267 std::unique_ptr<CopyPropagateArrays::MemoryObject>
BuildMemoryObjectFromExtract(Instruction * extract_inst)268 CopyPropagateArrays::BuildMemoryObjectFromExtract(Instruction* extract_inst) {
269 assert(extract_inst->opcode() == SpvOpCompositeExtract &&
270 "Expecting an OpCompositeExtract instruction.");
271 analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
272
273 std::unique_ptr<MemoryObject> result = GetSourceObjectIfAny(
274 extract_inst->GetSingleWordInOperand(kCompositeExtractObjectInOperand));
275
276 if (result) {
277 analysis::Integer int_type(32, false);
278 const analysis::Type* uint32_type =
279 context()->get_type_mgr()->GetRegisteredType(&int_type);
280
281 std::vector<uint32_t> components;
282 // Convert the indices in the extract instruction to a series of ids that
283 // can be used by the |OpAccessChain| instruction.
284 for (uint32_t i = 1; i < extract_inst->NumInOperands(); ++i) {
285 uint32_t index = extract_inst->GetSingleWordInOperand(i);
286 const analysis::Constant* index_const =
287 const_mgr->GetConstant(uint32_type, {index});
288 components.push_back(
289 const_mgr->GetDefiningInstruction(index_const)->result_id());
290 }
291 result->GetMember(components);
292 return result;
293 }
294 return nullptr;
295 }
296
297 std::unique_ptr<CopyPropagateArrays::MemoryObject>
BuildMemoryObjectFromCompositeConstruct(Instruction * conststruct_inst)298 CopyPropagateArrays::BuildMemoryObjectFromCompositeConstruct(
299 Instruction* conststruct_inst) {
300 assert(conststruct_inst->opcode() == SpvOpCompositeConstruct &&
301 "Expecting an OpCompositeConstruct instruction.");
302
303 // If every operand in the instruction are part of the same memory object, and
304 // are being combined in the same order, then the result is the same as the
305 // parent.
306
307 std::unique_ptr<MemoryObject> memory_object =
308 GetSourceObjectIfAny(conststruct_inst->GetSingleWordInOperand(0));
309
310 if (!memory_object) {
311 return nullptr;
312 }
313
314 if (!memory_object->IsMember()) {
315 return nullptr;
316 }
317
318 analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
319 const analysis::Constant* last_access =
320 const_mgr->FindDeclaredConstant(memory_object->AccessChain().back());
321 if (!last_access || !last_access->type()->AsInteger()) {
322 return nullptr;
323 }
324
325 if (last_access->GetU32() != 0) {
326 return nullptr;
327 }
328
329 memory_object->GetParent();
330
331 if (memory_object->GetNumberOfMembers() !=
332 conststruct_inst->NumInOperands()) {
333 return nullptr;
334 }
335
336 for (uint32_t i = 1; i < conststruct_inst->NumInOperands(); ++i) {
337 std::unique_ptr<MemoryObject> member_object =
338 GetSourceObjectIfAny(conststruct_inst->GetSingleWordInOperand(i));
339
340 if (!member_object) {
341 return nullptr;
342 }
343
344 if (!member_object->IsMember()) {
345 return nullptr;
346 }
347
348 if (!memory_object->Contains(member_object.get())) {
349 return nullptr;
350 }
351
352 last_access =
353 const_mgr->FindDeclaredConstant(member_object->AccessChain().back());
354 if (!last_access || !last_access->type()->AsInteger()) {
355 return nullptr;
356 }
357
358 if (last_access->GetU32() != i) {
359 return nullptr;
360 }
361 }
362 return memory_object;
363 }
364
365 std::unique_ptr<CopyPropagateArrays::MemoryObject>
BuildMemoryObjectFromInsert(Instruction * insert_inst)366 CopyPropagateArrays::BuildMemoryObjectFromInsert(Instruction* insert_inst) {
367 assert(insert_inst->opcode() == SpvOpCompositeInsert &&
368 "Expecting an OpCompositeInsert instruction.");
369
370 analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
371 analysis::TypeManager* type_mgr = context()->get_type_mgr();
372 analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
373 const analysis::Type* result_type = type_mgr->GetType(insert_inst->type_id());
374
375 uint32_t number_of_elements = 0;
376 if (const analysis::Struct* struct_type = result_type->AsStruct()) {
377 number_of_elements =
378 static_cast<uint32_t>(struct_type->element_types().size());
379 } else if (const analysis::Array* array_type = result_type->AsArray()) {
380 const analysis::Constant* length_const =
381 const_mgr->FindDeclaredConstant(array_type->LengthId());
382 number_of_elements = length_const->GetU32();
383 } else if (const analysis::Vector* vector_type = result_type->AsVector()) {
384 number_of_elements = vector_type->element_count();
385 } else if (const analysis::Matrix* matrix_type = result_type->AsMatrix()) {
386 number_of_elements = matrix_type->element_count();
387 }
388
389 if (number_of_elements == 0) {
390 return nullptr;
391 }
392
393 if (insert_inst->NumInOperands() != 3) {
394 return nullptr;
395 }
396
397 if (insert_inst->GetSingleWordInOperand(2) != number_of_elements - 1) {
398 return nullptr;
399 }
400
401 std::unique_ptr<MemoryObject> memory_object =
402 GetSourceObjectIfAny(insert_inst->GetSingleWordInOperand(0));
403
404 if (!memory_object) {
405 return nullptr;
406 }
407
408 if (!memory_object->IsMember()) {
409 return nullptr;
410 }
411
412 const analysis::Constant* last_access =
413 const_mgr->FindDeclaredConstant(memory_object->AccessChain().back());
414 if (!last_access || !last_access->type()->AsInteger()) {
415 return nullptr;
416 }
417
418 if (last_access->GetU32() != number_of_elements - 1) {
419 return nullptr;
420 }
421
422 memory_object->GetParent();
423
424 Instruction* current_insert =
425 def_use_mgr->GetDef(insert_inst->GetSingleWordInOperand(1));
426 for (uint32_t i = number_of_elements - 1; i > 0; --i) {
427 if (current_insert->opcode() != SpvOpCompositeInsert) {
428 return nullptr;
429 }
430
431 if (current_insert->NumInOperands() != 3) {
432 return nullptr;
433 }
434
435 if (current_insert->GetSingleWordInOperand(2) != i - 1) {
436 return nullptr;
437 }
438
439 std::unique_ptr<MemoryObject> current_memory_object =
440 GetSourceObjectIfAny(current_insert->GetSingleWordInOperand(0));
441
442 if (!current_memory_object) {
443 return nullptr;
444 }
445
446 if (!current_memory_object->IsMember()) {
447 return nullptr;
448 }
449
450 if (memory_object->AccessChain().size() + 1 !=
451 current_memory_object->AccessChain().size()) {
452 return nullptr;
453 }
454
455 if (!memory_object->Contains(current_memory_object.get())) {
456 return nullptr;
457 }
458
459 const analysis::Constant* current_last_access =
460 const_mgr->FindDeclaredConstant(
461 current_memory_object->AccessChain().back());
462 if (!current_last_access || !current_last_access->type()->AsInteger()) {
463 return nullptr;
464 }
465
466 if (current_last_access->GetU32() != i - 1) {
467 return nullptr;
468 }
469 current_insert =
470 def_use_mgr->GetDef(current_insert->GetSingleWordInOperand(1));
471 }
472
473 return memory_object;
474 }
475
IsPointerToArrayType(uint32_t type_id)476 bool CopyPropagateArrays::IsPointerToArrayType(uint32_t type_id) {
477 analysis::TypeManager* type_mgr = context()->get_type_mgr();
478 analysis::Pointer* pointer_type = type_mgr->GetType(type_id)->AsPointer();
479 if (pointer_type) {
480 return pointer_type->pointee_type()->kind() == analysis::Type::kArray ||
481 pointer_type->pointee_type()->kind() == analysis::Type::kImage;
482 }
483 return false;
484 }
485
CanUpdateUses(Instruction * original_ptr_inst,uint32_t type_id)486 bool CopyPropagateArrays::CanUpdateUses(Instruction* original_ptr_inst,
487 uint32_t type_id) {
488 analysis::TypeManager* type_mgr = context()->get_type_mgr();
489 analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
490 analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
491
492 analysis::Type* type = type_mgr->GetType(type_id);
493 if (type->AsRuntimeArray()) {
494 return false;
495 }
496
497 if (!type->AsStruct() && !type->AsArray() && !type->AsPointer()) {
498 // If the type is not an aggregate, then the desired type must be the
499 // same as the current type. No work to do, and we can do that.
500 return true;
501 }
502
503 return def_use_mgr->WhileEachUse(original_ptr_inst, [this, type_mgr,
504 const_mgr,
505 type](Instruction* use,
506 uint32_t) {
507 if (IsDebugDeclareOrValue(use)) return true;
508
509 switch (use->opcode()) {
510 case SpvOpLoad: {
511 analysis::Pointer* pointer_type = type->AsPointer();
512 uint32_t new_type_id = type_mgr->GetId(pointer_type->pointee_type());
513
514 if (new_type_id != use->type_id()) {
515 return CanUpdateUses(use, new_type_id);
516 }
517 return true;
518 }
519 case SpvOpAccessChain: {
520 analysis::Pointer* pointer_type = type->AsPointer();
521 const analysis::Type* pointee_type = pointer_type->pointee_type();
522
523 std::vector<uint32_t> access_chain;
524 for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
525 const analysis::Constant* index_const =
526 const_mgr->FindDeclaredConstant(use->GetSingleWordInOperand(i));
527 if (index_const) {
528 access_chain.push_back(index_const->GetU32());
529 } else {
530 // Variable index means the type is a type where every element
531 // is the same type. Use element 0 to get the type.
532 access_chain.push_back(0);
533 }
534 }
535
536 const analysis::Type* new_pointee_type =
537 type_mgr->GetMemberType(pointee_type, access_chain);
538 analysis::Pointer pointerTy(new_pointee_type,
539 pointer_type->storage_class());
540 uint32_t new_pointer_type_id =
541 context()->get_type_mgr()->GetTypeInstruction(&pointerTy);
542 if (new_pointer_type_id == 0) {
543 return false;
544 }
545
546 if (new_pointer_type_id != use->type_id()) {
547 return CanUpdateUses(use, new_pointer_type_id);
548 }
549 return true;
550 }
551 case SpvOpCompositeExtract: {
552 std::vector<uint32_t> access_chain;
553 for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
554 access_chain.push_back(use->GetSingleWordInOperand(i));
555 }
556
557 const analysis::Type* new_type =
558 type_mgr->GetMemberType(type, access_chain);
559 uint32_t new_type_id = type_mgr->GetTypeInstruction(new_type);
560 if (new_type_id == 0) {
561 return false;
562 }
563
564 if (new_type_id != use->type_id()) {
565 return CanUpdateUses(use, new_type_id);
566 }
567 return true;
568 }
569 case SpvOpStore:
570 // If needed, we can create an element-by-element copy to change the
571 // type of the value being stored. This way we can always handled
572 // stores.
573 return true;
574 case SpvOpImageTexelPointer:
575 case SpvOpName:
576 return true;
577 default:
578 return use->IsDecoration();
579 }
580 });
581 }
582
UpdateUses(Instruction * original_ptr_inst,Instruction * new_ptr_inst)583 void CopyPropagateArrays::UpdateUses(Instruction* original_ptr_inst,
584 Instruction* new_ptr_inst) {
585 analysis::TypeManager* type_mgr = context()->get_type_mgr();
586 analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
587 analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
588
589 std::vector<std::pair<Instruction*, uint32_t> > uses;
590 def_use_mgr->ForEachUse(original_ptr_inst,
591 [&uses](Instruction* use, uint32_t index) {
592 uses.push_back({use, index});
593 });
594
595 for (auto pair : uses) {
596 Instruction* use = pair.first;
597 uint32_t index = pair.second;
598
599 if (use->IsCommonDebugInstr()) {
600 switch (use->GetCommonDebugOpcode()) {
601 case CommonDebugInfoDebugDeclare: {
602 if (new_ptr_inst->opcode() == SpvOpVariable ||
603 new_ptr_inst->opcode() == SpvOpFunctionParameter) {
604 context()->ForgetUses(use);
605 use->SetOperand(index, {new_ptr_inst->result_id()});
606 context()->AnalyzeUses(use);
607 } else {
608 // Based on the spec, we cannot use a pointer other than OpVariable
609 // or OpFunctionParameter for DebugDeclare. We have to use
610 // DebugValue with Deref.
611
612 context()->ForgetUses(use);
613
614 // Change DebugDeclare to DebugValue.
615 use->SetOperand(index - 2,
616 {static_cast<uint32_t>(CommonDebugInfoDebugValue)});
617 use->SetOperand(index, {new_ptr_inst->result_id()});
618
619 // Add Deref operation.
620 Instruction* dbg_expr =
621 def_use_mgr->GetDef(use->GetSingleWordOperand(index + 1));
622 auto* deref_expr_instr =
623 context()->get_debug_info_mgr()->DerefDebugExpression(dbg_expr);
624 use->SetOperand(index + 1, {deref_expr_instr->result_id()});
625
626 context()->AnalyzeUses(deref_expr_instr);
627 context()->AnalyzeUses(use);
628 }
629 break;
630 }
631 case CommonDebugInfoDebugValue:
632 context()->ForgetUses(use);
633 use->SetOperand(index, {new_ptr_inst->result_id()});
634 context()->AnalyzeUses(use);
635 break;
636 default:
637 assert(false && "Don't know how to rewrite instruction");
638 break;
639 }
640 continue;
641 }
642
643 switch (use->opcode()) {
644 case SpvOpLoad: {
645 // Replace the actual use.
646 context()->ForgetUses(use);
647 use->SetOperand(index, {new_ptr_inst->result_id()});
648
649 // Update the type.
650 Instruction* pointer_type_inst =
651 def_use_mgr->GetDef(new_ptr_inst->type_id());
652 uint32_t new_type_id =
653 pointer_type_inst->GetSingleWordInOperand(kTypePointerPointeeInIdx);
654 if (new_type_id != use->type_id()) {
655 use->SetResultType(new_type_id);
656 context()->AnalyzeUses(use);
657 UpdateUses(use, use);
658 } else {
659 context()->AnalyzeUses(use);
660 }
661 } break;
662 case SpvOpAccessChain: {
663 // Update the actual use.
664 context()->ForgetUses(use);
665 use->SetOperand(index, {new_ptr_inst->result_id()});
666
667 // Convert the ids on the OpAccessChain to indices that can be used to
668 // get the specific member.
669 std::vector<uint32_t> access_chain;
670 for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
671 const analysis::Constant* index_const =
672 const_mgr->FindDeclaredConstant(use->GetSingleWordInOperand(i));
673 if (index_const) {
674 access_chain.push_back(index_const->GetU32());
675 } else {
676 // Variable index means the type is an type where every element
677 // is the same type. Use element 0 to get the type.
678 access_chain.push_back(0);
679 }
680 }
681
682 Instruction* pointer_type_inst =
683 get_def_use_mgr()->GetDef(new_ptr_inst->type_id());
684
685 uint32_t new_pointee_type_id = GetMemberTypeId(
686 pointer_type_inst->GetSingleWordInOperand(kTypePointerPointeeInIdx),
687 access_chain);
688
689 SpvStorageClass storage_class = static_cast<SpvStorageClass>(
690 pointer_type_inst->GetSingleWordInOperand(
691 kTypePointerStorageClassInIdx));
692
693 uint32_t new_pointer_type_id =
694 type_mgr->FindPointerToType(new_pointee_type_id, storage_class);
695
696 if (new_pointer_type_id != use->type_id()) {
697 use->SetResultType(new_pointer_type_id);
698 context()->AnalyzeUses(use);
699 UpdateUses(use, use);
700 } else {
701 context()->AnalyzeUses(use);
702 }
703 } break;
704 case SpvOpCompositeExtract: {
705 // Update the actual use.
706 context()->ForgetUses(use);
707 use->SetOperand(index, {new_ptr_inst->result_id()});
708
709 uint32_t new_type_id = new_ptr_inst->type_id();
710 std::vector<uint32_t> access_chain;
711 for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
712 access_chain.push_back(use->GetSingleWordInOperand(i));
713 }
714
715 new_type_id = GetMemberTypeId(new_type_id, access_chain);
716
717 if (new_type_id != use->type_id()) {
718 use->SetResultType(new_type_id);
719 context()->AnalyzeUses(use);
720 UpdateUses(use, use);
721 } else {
722 context()->AnalyzeUses(use);
723 }
724 } break;
725 case SpvOpStore:
726 // If the use is the pointer, then it is the single store to that
727 // variable. We do not want to replace it. Instead, it will become
728 // dead after all of the loads are removed, and ADCE will get rid of it.
729 //
730 // If the use is the object being stored, we will create a copy of the
731 // object turning it into the correct type. The copy is done by
732 // decomposing the object into the base type, which must be the same,
733 // and then rebuilding them.
734 if (index == 1) {
735 Instruction* target_pointer = def_use_mgr->GetDef(
736 use->GetSingleWordInOperand(kStorePointerInOperand));
737 Instruction* pointer_type =
738 def_use_mgr->GetDef(target_pointer->type_id());
739 uint32_t pointee_type_id =
740 pointer_type->GetSingleWordInOperand(kTypePointerPointeeInIdx);
741 uint32_t copy = GenerateCopy(original_ptr_inst, pointee_type_id, use);
742
743 context()->ForgetUses(use);
744 use->SetInOperand(index, {copy});
745 context()->AnalyzeUses(use);
746 }
747 break;
748 case SpvOpDecorate:
749 // We treat an OpImageTexelPointer as a load. The result type should
750 // always have the Image storage class, and should not need to be
751 // updated.
752 case SpvOpImageTexelPointer:
753 // Replace the actual use.
754 context()->ForgetUses(use);
755 use->SetOperand(index, {new_ptr_inst->result_id()});
756 context()->AnalyzeUses(use);
757 break;
758 default:
759 assert(false && "Don't know how to rewrite instruction");
760 break;
761 }
762 }
763 }
764
GetMemberTypeId(uint32_t id,const std::vector<uint32_t> & access_chain) const765 uint32_t CopyPropagateArrays::GetMemberTypeId(
766 uint32_t id, const std::vector<uint32_t>& access_chain) const {
767 for (uint32_t element_index : access_chain) {
768 Instruction* type_inst = get_def_use_mgr()->GetDef(id);
769 switch (type_inst->opcode()) {
770 case SpvOpTypeArray:
771 case SpvOpTypeRuntimeArray:
772 case SpvOpTypeMatrix:
773 case SpvOpTypeVector:
774 id = type_inst->GetSingleWordInOperand(0);
775 break;
776 case SpvOpTypeStruct:
777 id = type_inst->GetSingleWordInOperand(element_index);
778 break;
779 default:
780 break;
781 }
782 assert(id != 0 &&
783 "Tried to extract from an object where it cannot be done.");
784 }
785 return id;
786 }
787
GetMember(const std::vector<uint32_t> & access_chain)788 void CopyPropagateArrays::MemoryObject::GetMember(
789 const std::vector<uint32_t>& access_chain) {
790 access_chain_.insert(access_chain_.end(), access_chain.begin(),
791 access_chain.end());
792 }
793
GetNumberOfMembers()794 uint32_t CopyPropagateArrays::MemoryObject::GetNumberOfMembers() {
795 IRContext* context = variable_inst_->context();
796 analysis::TypeManager* type_mgr = context->get_type_mgr();
797
798 const analysis::Type* type = type_mgr->GetType(variable_inst_->type_id());
799 type = type->AsPointer()->pointee_type();
800
801 std::vector<uint32_t> access_indices = GetAccessIds();
802 type = type_mgr->GetMemberType(type, access_indices);
803
804 if (const analysis::Struct* struct_type = type->AsStruct()) {
805 return static_cast<uint32_t>(struct_type->element_types().size());
806 } else if (const analysis::Array* array_type = type->AsArray()) {
807 const analysis::Constant* length_const =
808 context->get_constant_mgr()->FindDeclaredConstant(
809 array_type->LengthId());
810 assert(length_const->type()->AsInteger());
811 return length_const->GetU32();
812 } else if (const analysis::Vector* vector_type = type->AsVector()) {
813 return vector_type->element_count();
814 } else if (const analysis::Matrix* matrix_type = type->AsMatrix()) {
815 return matrix_type->element_count();
816 } else {
817 return 0;
818 }
819 }
820
821 template <class iterator>
MemoryObject(Instruction * var_inst,iterator begin,iterator end)822 CopyPropagateArrays::MemoryObject::MemoryObject(Instruction* var_inst,
823 iterator begin, iterator end)
824 : variable_inst_(var_inst), access_chain_(begin, end) {}
825
GetAccessIds() const826 std::vector<uint32_t> CopyPropagateArrays::MemoryObject::GetAccessIds() const {
827 analysis::ConstantManager* const_mgr =
828 variable_inst_->context()->get_constant_mgr();
829
830 std::vector<uint32_t> access_indices;
831 for (uint32_t id : AccessChain()) {
832 const analysis::Constant* element_index_const =
833 const_mgr->FindDeclaredConstant(id);
834 if (!element_index_const) {
835 access_indices.push_back(0);
836 } else {
837 access_indices.push_back(element_index_const->GetU32());
838 }
839 }
840 return access_indices;
841 }
842
Contains(CopyPropagateArrays::MemoryObject * other)843 bool CopyPropagateArrays::MemoryObject::Contains(
844 CopyPropagateArrays::MemoryObject* other) {
845 if (this->GetVariable() != other->GetVariable()) {
846 return false;
847 }
848
849 if (AccessChain().size() > other->AccessChain().size()) {
850 return false;
851 }
852
853 for (uint32_t i = 0; i < AccessChain().size(); i++) {
854 if (AccessChain()[i] != other->AccessChain()[i]) {
855 return false;
856 }
857 }
858 return true;
859 }
860
861 } // namespace opt
862 } // namespace spvtools
863