1 // Copyright (c) 2018 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/opt/copy_prop_arrays.h"
16
17 #include <utility>
18
19 #include "source/opt/ir_builder.h"
20
21 namespace spvtools {
22 namespace opt {
23 namespace {
24
25 const uint32_t kLoadPointerInOperand = 0;
26 const uint32_t kStorePointerInOperand = 0;
27 const uint32_t kStoreObjectInOperand = 1;
28 const uint32_t kCompositeExtractObjectInOperand = 0;
29 const uint32_t kTypePointerStorageClassInIdx = 0;
30 const uint32_t kTypePointerPointeeInIdx = 1;
31
IsOpenCL100DebugDeclareOrValue(Instruction * di)32 bool IsOpenCL100DebugDeclareOrValue(Instruction* di) {
33 auto dbg_opcode = di->GetOpenCL100DebugOpcode();
34 return dbg_opcode == OpenCLDebugInfo100DebugDeclare ||
35 dbg_opcode == OpenCLDebugInfo100DebugValue;
36 }
37
38 } // namespace
39
Process()40 Pass::Status CopyPropagateArrays::Process() {
41 bool modified = false;
42 for (Function& function : *get_module()) {
43 BasicBlock* entry_bb = &*function.begin();
44
45 for (auto var_inst = entry_bb->begin(); var_inst->opcode() == SpvOpVariable;
46 ++var_inst) {
47 if (!IsPointerToArrayType(var_inst->type_id())) {
48 continue;
49 }
50
51 // Find the only store to the entire memory location, if it exists.
52 Instruction* store_inst = FindStoreInstruction(&*var_inst);
53
54 if (!store_inst) {
55 continue;
56 }
57
58 std::unique_ptr<MemoryObject> source_object =
59 FindSourceObjectIfPossible(&*var_inst, store_inst);
60
61 if (source_object != nullptr) {
62 if (CanUpdateUses(&*var_inst, source_object->GetPointerTypeId(this))) {
63 modified = true;
64 PropagateObject(&*var_inst, source_object.get(), store_inst);
65 }
66 }
67 }
68 }
69 return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
70 }
71
72 std::unique_ptr<CopyPropagateArrays::MemoryObject>
FindSourceObjectIfPossible(Instruction * var_inst,Instruction * store_inst)73 CopyPropagateArrays::FindSourceObjectIfPossible(Instruction* var_inst,
74 Instruction* store_inst) {
75 assert(var_inst->opcode() == SpvOpVariable && "Expecting a variable.");
76
77 // Check that the variable is a composite object where |store_inst|
78 // dominates all of its loads.
79 if (!store_inst) {
80 return nullptr;
81 }
82
83 // Look at the loads to ensure they are dominated by the store.
84 if (!HasValidReferencesOnly(var_inst, store_inst)) {
85 return nullptr;
86 }
87
88 // If so, look at the store to see if it is the copy of an object.
89 std::unique_ptr<MemoryObject> source = GetSourceObjectIfAny(
90 store_inst->GetSingleWordInOperand(kStoreObjectInOperand));
91
92 if (!source) {
93 return nullptr;
94 }
95
96 // Ensure that |source| does not change between the point at which it is
97 // loaded, and the position in which |var_inst| is loaded.
98 //
99 // For now we will go with the easy to implement approach, and check that the
100 // entire variable (not just the specific component) is never written to.
101
102 if (!HasNoStores(source->GetVariable())) {
103 return nullptr;
104 }
105 return source;
106 }
107
// Returns the unique OpStore whose pointer operand is |var_inst|.  Returns
// |nullptr| when there is no such store or more than one.
Instruction* CopyPropagateArrays::FindStoreInstruction(
    const Instruction* var_inst) const {
  const uint32_t var_id = var_inst->result_id();
  Instruction* only_store = nullptr;
  bool saw_second_store = false;

  get_def_use_mgr()->WhileEachUser(
      var_inst,
      [&only_store, &saw_second_store, var_id](Instruction* user) {
        const bool stores_to_var =
            user->opcode() == SpvOpStore &&
            user->GetSingleWordInOperand(kStorePointerInOperand) == var_id;
        if (!stores_to_var) {
          return true;
        }
        if (only_store != nullptr) {
          // A second store was found; stop scanning.
          saw_second_store = true;
          return false;
        }
        only_store = user;
        return true;
      });

  return saw_second_store ? nullptr : only_store;
}
127
PropagateObject(Instruction * var_inst,MemoryObject * source,Instruction * insertion_point)128 void CopyPropagateArrays::PropagateObject(Instruction* var_inst,
129 MemoryObject* source,
130 Instruction* insertion_point) {
131 assert(var_inst->opcode() == SpvOpVariable &&
132 "This function propagates variables.");
133
134 Instruction* new_access_chain = BuildNewAccessChain(insertion_point, source);
135 context()->KillNamesAndDecorates(var_inst);
136 UpdateUses(var_inst, new_access_chain);
137 }
138
BuildNewAccessChain(Instruction * insertion_point,CopyPropagateArrays::MemoryObject * source) const139 Instruction* CopyPropagateArrays::BuildNewAccessChain(
140 Instruction* insertion_point,
141 CopyPropagateArrays::MemoryObject* source) const {
142 InstructionBuilder builder(
143 context(), insertion_point,
144 IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
145
146 if (source->AccessChain().size() == 0) {
147 return source->GetVariable();
148 }
149
150 return builder.AddAccessChain(source->GetPointerTypeId(this),
151 source->GetVariable()->result_id(),
152 source->AccessChain());
153 }
154
HasNoStores(Instruction * ptr_inst)155 bool CopyPropagateArrays::HasNoStores(Instruction* ptr_inst) {
156 return get_def_use_mgr()->WhileEachUser(ptr_inst, [this](Instruction* use) {
157 if (use->opcode() == SpvOpLoad) {
158 return true;
159 } else if (use->opcode() == SpvOpAccessChain) {
160 return HasNoStores(use);
161 } else if (use->IsDecoration() || use->opcode() == SpvOpName) {
162 return true;
163 } else if (use->opcode() == SpvOpStore) {
164 return false;
165 } else if (use->opcode() == SpvOpImageTexelPointer) {
166 return true;
167 }
168 // Some other instruction. Be conservative.
169 return false;
170 });
171 }
172
HasValidReferencesOnly(Instruction * ptr_inst,Instruction * store_inst)173 bool CopyPropagateArrays::HasValidReferencesOnly(Instruction* ptr_inst,
174 Instruction* store_inst) {
175 BasicBlock* store_block = context()->get_instr_block(store_inst);
176 DominatorAnalysis* dominator_analysis =
177 context()->GetDominatorAnalysis(store_block->GetParent());
178
179 return get_def_use_mgr()->WhileEachUser(
180 ptr_inst,
181 [this, store_inst, dominator_analysis, ptr_inst](Instruction* use) {
182 if (use->opcode() == SpvOpLoad ||
183 use->opcode() == SpvOpImageTexelPointer) {
184 // TODO: If there are many load in the same BB as |store_inst| the
185 // time to do the multiple traverses can add up. Consider collecting
186 // those loads and doing a single traversal.
187 return dominator_analysis->Dominates(store_inst, use);
188 } else if (use->opcode() == SpvOpAccessChain) {
189 return HasValidReferencesOnly(use, store_inst);
190 } else if (use->IsDecoration() || use->opcode() == SpvOpName) {
191 return true;
192 } else if (use->opcode() == SpvOpStore) {
193 // If we are storing to part of the object it is not an candidate.
194 return ptr_inst->opcode() == SpvOpVariable &&
195 store_inst->GetSingleWordInOperand(kStorePointerInOperand) ==
196 ptr_inst->result_id();
197 } else if (IsOpenCL100DebugDeclareOrValue(use)) {
198 return true;
199 }
200 // Some other instruction. Be conservative.
201 return false;
202 });
203 }
204
205 std::unique_ptr<CopyPropagateArrays::MemoryObject>
GetSourceObjectIfAny(uint32_t result)206 CopyPropagateArrays::GetSourceObjectIfAny(uint32_t result) {
207 Instruction* result_inst = context()->get_def_use_mgr()->GetDef(result);
208
209 switch (result_inst->opcode()) {
210 case SpvOpLoad:
211 return BuildMemoryObjectFromLoad(result_inst);
212 case SpvOpCompositeExtract:
213 return BuildMemoryObjectFromExtract(result_inst);
214 case SpvOpCompositeConstruct:
215 return BuildMemoryObjectFromCompositeConstruct(result_inst);
216 case SpvOpCopyObject:
217 return GetSourceObjectIfAny(result_inst->GetSingleWordInOperand(0));
218 case SpvOpCompositeInsert:
219 return BuildMemoryObjectFromInsert(result_inst);
220 default:
221 return nullptr;
222 }
223 }
224
225 std::unique_ptr<CopyPropagateArrays::MemoryObject>
BuildMemoryObjectFromLoad(Instruction * load_inst)226 CopyPropagateArrays::BuildMemoryObjectFromLoad(Instruction* load_inst) {
227 std::vector<uint32_t> components_in_reverse;
228 analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
229
230 Instruction* current_inst = def_use_mgr->GetDef(
231 load_inst->GetSingleWordInOperand(kLoadPointerInOperand));
232
233 // Build the access chain for the memory object by collecting the indices used
234 // in the OpAccessChain instructions. If we find a variable index, then
235 // return |nullptr| because we cannot know for sure which memory location is
236 // used.
237 //
238 // It is built in reverse order because the different |OpAccessChain|
239 // instructions are visited in reverse order from which they are applied.
240 while (current_inst->opcode() == SpvOpAccessChain) {
241 for (uint32_t i = current_inst->NumInOperands() - 1; i >= 1; --i) {
242 uint32_t element_index_id = current_inst->GetSingleWordInOperand(i);
243 components_in_reverse.push_back(element_index_id);
244 }
245 current_inst = def_use_mgr->GetDef(current_inst->GetSingleWordInOperand(0));
246 }
247
248 // If the address in the load is not constructed from an |OpVariable|
249 // instruction followed by a series of |OpAccessChain| instructions, then
250 // return |nullptr| because we cannot identify the owner or access chain
251 // exactly.
252 if (current_inst->opcode() != SpvOpVariable) {
253 return nullptr;
254 }
255
256 // Build the memory object. Use |rbegin| and |rend| to put the access chain
257 // back in the correct order.
258 return std::unique_ptr<CopyPropagateArrays::MemoryObject>(
259 new MemoryObject(current_inst, components_in_reverse.rbegin(),
260 components_in_reverse.rend()));
261 }
262
263 std::unique_ptr<CopyPropagateArrays::MemoryObject>
BuildMemoryObjectFromExtract(Instruction * extract_inst)264 CopyPropagateArrays::BuildMemoryObjectFromExtract(Instruction* extract_inst) {
265 assert(extract_inst->opcode() == SpvOpCompositeExtract &&
266 "Expecting an OpCompositeExtract instruction.");
267 analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
268
269 std::unique_ptr<MemoryObject> result = GetSourceObjectIfAny(
270 extract_inst->GetSingleWordInOperand(kCompositeExtractObjectInOperand));
271
272 if (result) {
273 analysis::Integer int_type(32, false);
274 const analysis::Type* uint32_type =
275 context()->get_type_mgr()->GetRegisteredType(&int_type);
276
277 std::vector<uint32_t> components;
278 // Convert the indices in the extract instruction to a series of ids that
279 // can be used by the |OpAccessChain| instruction.
280 for (uint32_t i = 1; i < extract_inst->NumInOperands(); ++i) {
281 uint32_t index = extract_inst->GetSingleWordInOperand(i);
282 const analysis::Constant* index_const =
283 const_mgr->GetConstant(uint32_type, {index});
284 components.push_back(
285 const_mgr->GetDefiningInstruction(index_const)->result_id());
286 }
287 result->GetMember(components);
288 return result;
289 }
290 return nullptr;
291 }
292
293 std::unique_ptr<CopyPropagateArrays::MemoryObject>
BuildMemoryObjectFromCompositeConstruct(Instruction * conststruct_inst)294 CopyPropagateArrays::BuildMemoryObjectFromCompositeConstruct(
295 Instruction* conststruct_inst) {
296 assert(conststruct_inst->opcode() == SpvOpCompositeConstruct &&
297 "Expecting an OpCompositeConstruct instruction.");
298
299 // If every operand in the instruction are part of the same memory object, and
300 // are being combined in the same order, then the result is the same as the
301 // parent.
302
303 std::unique_ptr<MemoryObject> memory_object =
304 GetSourceObjectIfAny(conststruct_inst->GetSingleWordInOperand(0));
305
306 if (!memory_object) {
307 return nullptr;
308 }
309
310 if (!memory_object->IsMember()) {
311 return nullptr;
312 }
313
314 analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
315 const analysis::Constant* last_access =
316 const_mgr->FindDeclaredConstant(memory_object->AccessChain().back());
317 if (!last_access || !last_access->type()->AsInteger()) {
318 return nullptr;
319 }
320
321 if (last_access->GetU32() != 0) {
322 return nullptr;
323 }
324
325 memory_object->GetParent();
326
327 if (memory_object->GetNumberOfMembers() !=
328 conststruct_inst->NumInOperands()) {
329 return nullptr;
330 }
331
332 for (uint32_t i = 1; i < conststruct_inst->NumInOperands(); ++i) {
333 std::unique_ptr<MemoryObject> member_object =
334 GetSourceObjectIfAny(conststruct_inst->GetSingleWordInOperand(i));
335
336 if (!member_object) {
337 return nullptr;
338 }
339
340 if (!member_object->IsMember()) {
341 return nullptr;
342 }
343
344 if (!memory_object->Contains(member_object.get())) {
345 return nullptr;
346 }
347
348 last_access =
349 const_mgr->FindDeclaredConstant(member_object->AccessChain().back());
350 if (!last_access || !last_access->type()->AsInteger()) {
351 return nullptr;
352 }
353
354 if (last_access->GetU32() != i) {
355 return nullptr;
356 }
357 }
358 return memory_object;
359 }
360
361 std::unique_ptr<CopyPropagateArrays::MemoryObject>
BuildMemoryObjectFromInsert(Instruction * insert_inst)362 CopyPropagateArrays::BuildMemoryObjectFromInsert(Instruction* insert_inst) {
363 assert(insert_inst->opcode() == SpvOpCompositeInsert &&
364 "Expecting an OpCompositeInsert instruction.");
365
366 analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
367 analysis::TypeManager* type_mgr = context()->get_type_mgr();
368 analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
369 const analysis::Type* result_type = type_mgr->GetType(insert_inst->type_id());
370
371 uint32_t number_of_elements = 0;
372 if (const analysis::Struct* struct_type = result_type->AsStruct()) {
373 number_of_elements =
374 static_cast<uint32_t>(struct_type->element_types().size());
375 } else if (const analysis::Array* array_type = result_type->AsArray()) {
376 const analysis::Constant* length_const =
377 const_mgr->FindDeclaredConstant(array_type->LengthId());
378 number_of_elements = length_const->GetU32();
379 } else if (const analysis::Vector* vector_type = result_type->AsVector()) {
380 number_of_elements = vector_type->element_count();
381 } else if (const analysis::Matrix* matrix_type = result_type->AsMatrix()) {
382 number_of_elements = matrix_type->element_count();
383 }
384
385 if (number_of_elements == 0) {
386 return nullptr;
387 }
388
389 if (insert_inst->NumInOperands() != 3) {
390 return nullptr;
391 }
392
393 if (insert_inst->GetSingleWordInOperand(2) != number_of_elements - 1) {
394 return nullptr;
395 }
396
397 std::unique_ptr<MemoryObject> memory_object =
398 GetSourceObjectIfAny(insert_inst->GetSingleWordInOperand(0));
399
400 if (!memory_object) {
401 return nullptr;
402 }
403
404 if (!memory_object->IsMember()) {
405 return nullptr;
406 }
407
408 const analysis::Constant* last_access =
409 const_mgr->FindDeclaredConstant(memory_object->AccessChain().back());
410 if (!last_access || !last_access->type()->AsInteger()) {
411 return nullptr;
412 }
413
414 if (last_access->GetU32() != number_of_elements - 1) {
415 return nullptr;
416 }
417
418 memory_object->GetParent();
419
420 Instruction* current_insert =
421 def_use_mgr->GetDef(insert_inst->GetSingleWordInOperand(1));
422 for (uint32_t i = number_of_elements - 1; i > 0; --i) {
423 if (current_insert->opcode() != SpvOpCompositeInsert) {
424 return nullptr;
425 }
426
427 if (current_insert->NumInOperands() != 3) {
428 return nullptr;
429 }
430
431 if (current_insert->GetSingleWordInOperand(2) != i - 1) {
432 return nullptr;
433 }
434
435 std::unique_ptr<MemoryObject> current_memory_object =
436 GetSourceObjectIfAny(current_insert->GetSingleWordInOperand(0));
437
438 if (!current_memory_object) {
439 return nullptr;
440 }
441
442 if (!current_memory_object->IsMember()) {
443 return nullptr;
444 }
445
446 if (memory_object->AccessChain().size() + 1 !=
447 current_memory_object->AccessChain().size()) {
448 return nullptr;
449 }
450
451 if (!memory_object->Contains(current_memory_object.get())) {
452 return nullptr;
453 }
454
455 const analysis::Constant* current_last_access =
456 const_mgr->FindDeclaredConstant(
457 current_memory_object->AccessChain().back());
458 if (!current_last_access || !current_last_access->type()->AsInteger()) {
459 return nullptr;
460 }
461
462 if (current_last_access->GetU32() != i - 1) {
463 return nullptr;
464 }
465 current_insert =
466 def_use_mgr->GetDef(current_insert->GetSingleWordInOperand(1));
467 }
468
469 return memory_object;
470 }
471
IsPointerToArrayType(uint32_t type_id)472 bool CopyPropagateArrays::IsPointerToArrayType(uint32_t type_id) {
473 analysis::TypeManager* type_mgr = context()->get_type_mgr();
474 analysis::Pointer* pointer_type = type_mgr->GetType(type_id)->AsPointer();
475 if (pointer_type) {
476 return pointer_type->pointee_type()->kind() == analysis::Type::kArray ||
477 pointer_type->pointee_type()->kind() == analysis::Type::kImage;
478 }
479 return false;
480 }
481
CanUpdateUses(Instruction * original_ptr_inst,uint32_t type_id)482 bool CopyPropagateArrays::CanUpdateUses(Instruction* original_ptr_inst,
483 uint32_t type_id) {
484 analysis::TypeManager* type_mgr = context()->get_type_mgr();
485 analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
486 analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
487
488 analysis::Type* type = type_mgr->GetType(type_id);
489 if (type->AsRuntimeArray()) {
490 return false;
491 }
492
493 if (!type->AsStruct() && !type->AsArray() && !type->AsPointer()) {
494 // If the type is not an aggregate, then the desired type must be the
495 // same as the current type. No work to do, and we can do that.
496 return true;
497 }
498
499 return def_use_mgr->WhileEachUse(original_ptr_inst, [this, type_mgr,
500 const_mgr,
501 type](Instruction* use,
502 uint32_t) {
503 if (IsOpenCL100DebugDeclareOrValue(use)) return true;
504
505 switch (use->opcode()) {
506 case SpvOpLoad: {
507 analysis::Pointer* pointer_type = type->AsPointer();
508 uint32_t new_type_id = type_mgr->GetId(pointer_type->pointee_type());
509
510 if (new_type_id != use->type_id()) {
511 return CanUpdateUses(use, new_type_id);
512 }
513 return true;
514 }
515 case SpvOpAccessChain: {
516 analysis::Pointer* pointer_type = type->AsPointer();
517 const analysis::Type* pointee_type = pointer_type->pointee_type();
518
519 std::vector<uint32_t> access_chain;
520 for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
521 const analysis::Constant* index_const =
522 const_mgr->FindDeclaredConstant(use->GetSingleWordInOperand(i));
523 if (index_const) {
524 access_chain.push_back(index_const->GetU32());
525 } else {
526 // Variable index means the type is a type where every element
527 // is the same type. Use element 0 to get the type.
528 access_chain.push_back(0);
529 }
530 }
531
532 const analysis::Type* new_pointee_type =
533 type_mgr->GetMemberType(pointee_type, access_chain);
534 analysis::Pointer pointerTy(new_pointee_type,
535 pointer_type->storage_class());
536 uint32_t new_pointer_type_id =
537 context()->get_type_mgr()->GetTypeInstruction(&pointerTy);
538 if (new_pointer_type_id == 0) {
539 return false;
540 }
541
542 if (new_pointer_type_id != use->type_id()) {
543 return CanUpdateUses(use, new_pointer_type_id);
544 }
545 return true;
546 }
547 case SpvOpCompositeExtract: {
548 std::vector<uint32_t> access_chain;
549 for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
550 access_chain.push_back(use->GetSingleWordInOperand(i));
551 }
552
553 const analysis::Type* new_type =
554 type_mgr->GetMemberType(type, access_chain);
555 uint32_t new_type_id = type_mgr->GetTypeInstruction(new_type);
556 if (new_type_id == 0) {
557 return false;
558 }
559
560 if (new_type_id != use->type_id()) {
561 return CanUpdateUses(use, new_type_id);
562 }
563 return true;
564 }
565 case SpvOpStore:
566 // If needed, we can create an element-by-element copy to change the
567 // type of the value being stored. This way we can always handled
568 // stores.
569 return true;
570 case SpvOpImageTexelPointer:
571 case SpvOpName:
572 return true;
573 default:
574 return use->IsDecoration();
575 }
576 });
577 }
578
// Rewrites every use of |original_ptr_inst| to refer to |new_ptr_inst|.
//
// If rewriting an OpLoad, OpAccessChain or OpCompositeExtract changes its
// result type, the user is retyped and its own uses are rewritten
// recursively (UpdateUses(use, use)).  When a retyped value is the object of
// an OpStore, a type-converting copy is built with GenerateCopy and stored
// instead.  The store *through* |original_ptr_inst| is left untouched.
// OpenCL.DebugInfo.100 DebugDeclare/DebugValue users are redirected as well.
// Any other user trips the "Don't know how to rewrite instruction" assert.
// Process() calls this only after CanUpdateUses() approved the rewrite.
void CopyPropagateArrays::UpdateUses(Instruction* original_ptr_inst,
                                     Instruction* new_ptr_inst) {
  analysis::TypeManager* type_mgr = context()->get_type_mgr();
  analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
  analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();

  // Snapshot the uses first.  The loop below forgets and re-analyzes uses,
  // which must not happen while the def-use manager is being iterated.
  std::vector<std::pair<Instruction*, uint32_t> > uses;
  def_use_mgr->ForEachUse(original_ptr_inst,
                          [&uses](Instruction* use, uint32_t index) {
                            uses.push_back({use, index});
                          });

  for (auto pair : uses) {
    Instruction* use = pair.first;
    // |index| is an operand index (not an in-operand index).
    uint32_t index = pair.second;

    if (use->IsOpenCL100DebugInstr()) {
      switch (use->GetOpenCL100DebugOpcode()) {
        case OpenCLDebugInfo100DebugDeclare: {
          if (new_ptr_inst->opcode() == SpvOpVariable ||
              new_ptr_inst->opcode() == SpvOpFunctionParameter) {
            context()->ForgetUses(use);
            use->SetOperand(index, {new_ptr_inst->result_id()});
            context()->AnalyzeUses(use);
          } else {
            // Based on the spec, we cannot use a pointer other than OpVariable
            // or OpFunctionParameter for DebugDeclare. We have to use
            // DebugValue with Deref.

            context()->ForgetUses(use);

            // Change DebugDeclare to DebugValue.  The extended-instruction
            // opcode operand sits two operands before the variable operand.
            use->SetOperand(
                index - 2,
                {static_cast<uint32_t>(OpenCLDebugInfo100DebugValue)});
            use->SetOperand(index, {new_ptr_inst->result_id()});

            // Add Deref operation.  The operand right after the variable is
            // the debug expression, which is replaced by a dereferencing
            // version.
            Instruction* dbg_expr =
                def_use_mgr->GetDef(use->GetSingleWordOperand(index + 1));
            auto* deref_expr_instr =
                context()->get_debug_info_mgr()->DerefDebugExpression(dbg_expr);
            use->SetOperand(index + 1, {deref_expr_instr->result_id()});

            context()->AnalyzeUses(deref_expr_instr);
            context()->AnalyzeUses(use);
          }
          break;
        }
        case OpenCLDebugInfo100DebugValue:
          context()->ForgetUses(use);
          use->SetOperand(index, {new_ptr_inst->result_id()});
          context()->AnalyzeUses(use);
          break;
        default:
          assert(false && "Don't know how to rewrite instruction");
          break;
      }
      continue;
    }

    switch (use->opcode()) {
      case SpvOpLoad: {
        // Replace the actual use.
        context()->ForgetUses(use);
        use->SetOperand(index, {new_ptr_inst->result_id()});

        // Update the type.  The loaded value now has the pointee type of
        // |new_ptr_inst|'s pointer type; if that differs, propagate the
        // retyping to the load's own users.
        Instruction* pointer_type_inst =
            def_use_mgr->GetDef(new_ptr_inst->type_id());
        uint32_t new_type_id =
            pointer_type_inst->GetSingleWordInOperand(kTypePointerPointeeInIdx);
        if (new_type_id != use->type_id()) {
          use->SetResultType(new_type_id);
          context()->AnalyzeUses(use);
          UpdateUses(use, use);
        } else {
          context()->AnalyzeUses(use);
        }
      } break;
      case SpvOpAccessChain: {
        // Update the actual use.
        context()->ForgetUses(use);
        use->SetOperand(index, {new_ptr_inst->result_id()});

        // Convert the ids on the OpAccessChain to indices that can be used to
        // get the specific member.
        std::vector<uint32_t> access_chain;
        for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
          const analysis::Constant* index_const =
              const_mgr->FindDeclaredConstant(use->GetSingleWordInOperand(i));
          if (index_const) {
            access_chain.push_back(index_const->GetU32());
          } else {
            // Variable index means the type is an type where every element
            // is the same type. Use element 0 to get the type.
            access_chain.push_back(0);
          }
        }

        Instruction* pointer_type_inst =
            get_def_use_mgr()->GetDef(new_ptr_inst->type_id());

        uint32_t new_pointee_type_id = GetMemberTypeId(
            pointer_type_inst->GetSingleWordInOperand(kTypePointerPointeeInIdx),
            access_chain);

        // The access chain keeps the storage class of the new base pointer.
        SpvStorageClass storage_class = static_cast<SpvStorageClass>(
            pointer_type_inst->GetSingleWordInOperand(
                kTypePointerStorageClassInIdx));

        uint32_t new_pointer_type_id =
            type_mgr->FindPointerToType(new_pointee_type_id, storage_class);

        if (new_pointer_type_id != use->type_id()) {
          use->SetResultType(new_pointer_type_id);
          context()->AnalyzeUses(use);
          UpdateUses(use, use);
        } else {
          context()->AnalyzeUses(use);
        }
      } break;
      case SpvOpCompositeExtract: {
        // Update the actual use.
        context()->ForgetUses(use);
        use->SetOperand(index, {new_ptr_inst->result_id()});

        // Here |new_ptr_inst| is a value (not a pointer); its type is walked
        // with the extract's literal indices to get the new result type.
        uint32_t new_type_id = new_ptr_inst->type_id();
        std::vector<uint32_t> access_chain;
        for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
          access_chain.push_back(use->GetSingleWordInOperand(i));
        }

        new_type_id = GetMemberTypeId(new_type_id, access_chain);

        if (new_type_id != use->type_id()) {
          use->SetResultType(new_type_id);
          context()->AnalyzeUses(use);
          UpdateUses(use, use);
        } else {
          context()->AnalyzeUses(use);
        }
      } break;
      case SpvOpStore:
        // If the use is the pointer, then it is the single store to that
        // variable. We do not want to replace it. Instead, it will become
        // dead after all of the loads are removed, and ADCE will get rid of it.
        //
        // If the use is the object being stored, we will create a copy of the
        // object turning it into the correct type. The copy is done by
        // decomposing the object into the base type, which must be the same,
        // and then rebuilding them.
        //
        // OpStore has no result type or id, so operand 1 is the Object
        // operand (same as in-operand 1).
        if (index == 1) {
          Instruction* target_pointer = def_use_mgr->GetDef(
              use->GetSingleWordInOperand(kStorePointerInOperand));
          Instruction* pointer_type =
              def_use_mgr->GetDef(target_pointer->type_id());
          uint32_t pointee_type_id =
              pointer_type->GetSingleWordInOperand(kTypePointerPointeeInIdx);
          uint32_t copy = GenerateCopy(original_ptr_inst, pointee_type_id, use);

          context()->ForgetUses(use);
          use->SetInOperand(index, {copy});
          context()->AnalyzeUses(use);
        }
        break;
      case SpvOpImageTexelPointer:
        // We treat an OpImageTexelPointer as a load. The result type should
        // always have the Image storage class, and should not need to be
        // updated.

        // Replace the actual use.
        context()->ForgetUses(use);
        use->SetOperand(index, {new_ptr_inst->result_id()});
        context()->AnalyzeUses(use);
        break;
      default:
        assert(false && "Don't know how to rewrite instruction");
        break;
    }
  }
}
761
GetMemberTypeId(uint32_t id,const std::vector<uint32_t> & access_chain) const762 uint32_t CopyPropagateArrays::GetMemberTypeId(
763 uint32_t id, const std::vector<uint32_t>& access_chain) const {
764 for (uint32_t element_index : access_chain) {
765 Instruction* type_inst = get_def_use_mgr()->GetDef(id);
766 switch (type_inst->opcode()) {
767 case SpvOpTypeArray:
768 case SpvOpTypeRuntimeArray:
769 case SpvOpTypeMatrix:
770 case SpvOpTypeVector:
771 id = type_inst->GetSingleWordInOperand(0);
772 break;
773 case SpvOpTypeStruct:
774 id = type_inst->GetSingleWordInOperand(element_index);
775 break;
776 default:
777 break;
778 }
779 assert(id != 0 &&
780 "Tried to extract from an object where it cannot be done.");
781 }
782 return id;
783 }
784
GetMember(const std::vector<uint32_t> & access_chain)785 void CopyPropagateArrays::MemoryObject::GetMember(
786 const std::vector<uint32_t>& access_chain) {
787 access_chain_.insert(access_chain_.end(), access_chain.begin(),
788 access_chain.end());
789 }
790
GetNumberOfMembers()791 uint32_t CopyPropagateArrays::MemoryObject::GetNumberOfMembers() {
792 IRContext* context = variable_inst_->context();
793 analysis::TypeManager* type_mgr = context->get_type_mgr();
794
795 const analysis::Type* type = type_mgr->GetType(variable_inst_->type_id());
796 type = type->AsPointer()->pointee_type();
797
798 std::vector<uint32_t> access_indices = GetAccessIds();
799 type = type_mgr->GetMemberType(type, access_indices);
800
801 if (const analysis::Struct* struct_type = type->AsStruct()) {
802 return static_cast<uint32_t>(struct_type->element_types().size());
803 } else if (const analysis::Array* array_type = type->AsArray()) {
804 const analysis::Constant* length_const =
805 context->get_constant_mgr()->FindDeclaredConstant(
806 array_type->LengthId());
807 assert(length_const->type()->AsInteger());
808 return length_const->GetU32();
809 } else if (const analysis::Vector* vector_type = type->AsVector()) {
810 return vector_type->element_count();
811 } else if (const analysis::Matrix* matrix_type = type->AsMatrix()) {
812 return matrix_type->element_count();
813 } else {
814 return 0;
815 }
816 }
817
818 template <class iterator>
MemoryObject(Instruction * var_inst,iterator begin,iterator end)819 CopyPropagateArrays::MemoryObject::MemoryObject(Instruction* var_inst,
820 iterator begin, iterator end)
821 : variable_inst_(var_inst), access_chain_(begin, end) {}
822
GetAccessIds() const823 std::vector<uint32_t> CopyPropagateArrays::MemoryObject::GetAccessIds() const {
824 analysis::ConstantManager* const_mgr =
825 variable_inst_->context()->get_constant_mgr();
826
827 std::vector<uint32_t> access_indices;
828 for (uint32_t id : AccessChain()) {
829 const analysis::Constant* element_index_const =
830 const_mgr->FindDeclaredConstant(id);
831 if (!element_index_const) {
832 access_indices.push_back(0);
833 } else {
834 access_indices.push_back(element_index_const->GetU32());
835 }
836 }
837 return access_indices;
838 }
839
Contains(CopyPropagateArrays::MemoryObject * other)840 bool CopyPropagateArrays::MemoryObject::Contains(
841 CopyPropagateArrays::MemoryObject* other) {
842 if (this->GetVariable() != other->GetVariable()) {
843 return false;
844 }
845
846 if (AccessChain().size() > other->AccessChain().size()) {
847 return false;
848 }
849
850 for (uint32_t i = 0; i < AccessChain().size(); i++) {
851 if (AccessChain()[i] != other->AccessChain()[i]) {
852 return false;
853 }
854 }
855 return true;
856 }
857
858 } // namespace opt
859 } // namespace spvtools
860