1 // Copyright (c) 2018 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/opt/copy_prop_arrays.h"
16
17 #include <utility>
18
19 #include "source/opt/ir_builder.h"
20
21 namespace spvtools {
22 namespace opt {
23 namespace {
24
25 constexpr uint32_t kLoadPointerInOperand = 0;
26 constexpr uint32_t kStorePointerInOperand = 0;
27 constexpr uint32_t kStoreObjectInOperand = 1;
28 constexpr uint32_t kCompositeExtractObjectInOperand = 0;
29 constexpr uint32_t kTypePointerStorageClassInIdx = 0;
30 constexpr uint32_t kTypePointerPointeeInIdx = 1;
31
IsDebugDeclareOrValue(Instruction * di)32 bool IsDebugDeclareOrValue(Instruction* di) {
33 auto dbg_opcode = di->GetCommonDebugOpcode();
34 return dbg_opcode == CommonDebugInfoDebugDeclare ||
35 dbg_opcode == CommonDebugInfoDebugValue;
36 }
37
38 } // namespace
39
Process()40 Pass::Status CopyPropagateArrays::Process() {
41 bool modified = false;
42 for (Function& function : *get_module()) {
43 if (function.IsDeclaration()) {
44 continue;
45 }
46
47 BasicBlock* entry_bb = &*function.begin();
48
49 for (auto var_inst = entry_bb->begin();
50 var_inst->opcode() == spv::Op::OpVariable; ++var_inst) {
51 if (!IsPointerToArrayType(var_inst->type_id())) {
52 continue;
53 }
54
55 // Find the only store to the entire memory location, if it exists.
56 Instruction* store_inst = FindStoreInstruction(&*var_inst);
57
58 if (!store_inst) {
59 continue;
60 }
61
62 std::unique_ptr<MemoryObject> source_object =
63 FindSourceObjectIfPossible(&*var_inst, store_inst);
64
65 if (source_object != nullptr) {
66 if (CanUpdateUses(&*var_inst, source_object->GetPointerTypeId(this))) {
67 modified = true;
68 PropagateObject(&*var_inst, source_object.get(), store_inst);
69 }
70 }
71 }
72 }
73 return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
74 }
75
76 std::unique_ptr<CopyPropagateArrays::MemoryObject>
FindSourceObjectIfPossible(Instruction * var_inst,Instruction * store_inst)77 CopyPropagateArrays::FindSourceObjectIfPossible(Instruction* var_inst,
78 Instruction* store_inst) {
79 assert(var_inst->opcode() == spv::Op::OpVariable && "Expecting a variable.");
80
81 // Check that the variable is a composite object where |store_inst|
82 // dominates all of its loads.
83 if (!store_inst) {
84 return nullptr;
85 }
86
87 // Look at the loads to ensure they are dominated by the store.
88 if (!HasValidReferencesOnly(var_inst, store_inst)) {
89 return nullptr;
90 }
91
92 // If so, look at the store to see if it is the copy of an object.
93 std::unique_ptr<MemoryObject> source = GetSourceObjectIfAny(
94 store_inst->GetSingleWordInOperand(kStoreObjectInOperand));
95
96 if (!source) {
97 return nullptr;
98 }
99
100 // Ensure that |source| does not change between the point at which it is
101 // loaded, and the position in which |var_inst| is loaded.
102 //
103 // For now we will go with the easy to implement approach, and check that the
104 // entire variable (not just the specific component) is never written to.
105
106 if (!HasNoStores(source->GetVariable())) {
107 return nullptr;
108 }
109 return source;
110 }
111
// Returns the OpStore whose pointer operand is |var_inst| itself, provided
// there is exactly one such store among the direct users of |var_inst|.
// Returns nullptr when there is none or when there are several.
Instruction* CopyPropagateArrays::FindStoreInstruction(
    const Instruction* var_inst) const {
  const uint32_t var_id = var_inst->result_id();
  Instruction* found = nullptr;

  const bool at_most_one = get_def_use_mgr()->WhileEachUser(
      var_inst, [&found, var_id](Instruction* user) {
        const bool stores_to_var =
            user->opcode() == spv::Op::OpStore &&
            user->GetSingleWordInOperand(kStorePointerInOperand) == var_id;
        if (!stores_to_var) return true;
        if (found != nullptr) return false;  // A second store: give up.
        found = user;
        return true;
      });

  return at_most_one ? found : nullptr;
}
131
PropagateObject(Instruction * var_inst,MemoryObject * source,Instruction * insertion_point)132 void CopyPropagateArrays::PropagateObject(Instruction* var_inst,
133 MemoryObject* source,
134 Instruction* insertion_point) {
135 assert(var_inst->opcode() == spv::Op::OpVariable &&
136 "This function propagates variables.");
137
138 Instruction* new_access_chain = BuildNewAccessChain(insertion_point, source);
139 context()->KillNamesAndDecorates(var_inst);
140 UpdateUses(var_inst, new_access_chain);
141 }
142
BuildNewAccessChain(Instruction * insertion_point,CopyPropagateArrays::MemoryObject * source) const143 Instruction* CopyPropagateArrays::BuildNewAccessChain(
144 Instruction* insertion_point,
145 CopyPropagateArrays::MemoryObject* source) const {
146 InstructionBuilder builder(
147 context(), insertion_point,
148 IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
149
150 if (source->AccessChain().size() == 0) {
151 return source->GetVariable();
152 }
153
154 source->BuildConstants();
155 std::vector<uint32_t> access_ids(source->AccessChain().size());
156 std::transform(
157 source->AccessChain().cbegin(), source->AccessChain().cend(),
158 access_ids.begin(), [](const AccessChainEntry& entry) {
159 assert(entry.is_result_id && "Constants needs to be built first.");
160 return entry.result_id;
161 });
162
163 return builder.AddAccessChain(source->GetPointerTypeId(this),
164 source->GetVariable()->result_id(), access_ids);
165 }
166
HasNoStores(Instruction * ptr_inst)167 bool CopyPropagateArrays::HasNoStores(Instruction* ptr_inst) {
168 return get_def_use_mgr()->WhileEachUser(ptr_inst, [this](Instruction* use) {
169 if (use->opcode() == spv::Op::OpLoad) {
170 return true;
171 } else if (use->opcode() == spv::Op::OpAccessChain) {
172 return HasNoStores(use);
173 } else if (use->IsDecoration() || use->opcode() == spv::Op::OpName) {
174 return true;
175 } else if (use->opcode() == spv::Op::OpStore) {
176 return false;
177 } else if (use->opcode() == spv::Op::OpImageTexelPointer) {
178 return true;
179 } else if (use->opcode() == spv::Op::OpEntryPoint) {
180 return true;
181 }
182 // Some other instruction. Be conservative.
183 return false;
184 });
185 }
186
HasValidReferencesOnly(Instruction * ptr_inst,Instruction * store_inst)187 bool CopyPropagateArrays::HasValidReferencesOnly(Instruction* ptr_inst,
188 Instruction* store_inst) {
189 BasicBlock* store_block = context()->get_instr_block(store_inst);
190 DominatorAnalysis* dominator_analysis =
191 context()->GetDominatorAnalysis(store_block->GetParent());
192
193 return get_def_use_mgr()->WhileEachUser(
194 ptr_inst,
195 [this, store_inst, dominator_analysis, ptr_inst](Instruction* use) {
196 if (use->opcode() == spv::Op::OpLoad ||
197 use->opcode() == spv::Op::OpImageTexelPointer) {
198 // TODO: If there are many load in the same BB as |store_inst| the
199 // time to do the multiple traverses can add up. Consider collecting
200 // those loads and doing a single traversal.
201 return dominator_analysis->Dominates(store_inst, use);
202 } else if (use->opcode() == spv::Op::OpAccessChain) {
203 return HasValidReferencesOnly(use, store_inst);
204 } else if (use->IsDecoration() || use->opcode() == spv::Op::OpName) {
205 return true;
206 } else if (use->opcode() == spv::Op::OpStore) {
207 // If we are storing to part of the object it is not an candidate.
208 return ptr_inst->opcode() == spv::Op::OpVariable &&
209 store_inst->GetSingleWordInOperand(kStorePointerInOperand) ==
210 ptr_inst->result_id();
211 } else if (IsDebugDeclareOrValue(use)) {
212 return true;
213 }
214 // Some other instruction. Be conservative.
215 return false;
216 });
217 }
218
219 std::unique_ptr<CopyPropagateArrays::MemoryObject>
GetSourceObjectIfAny(uint32_t result)220 CopyPropagateArrays::GetSourceObjectIfAny(uint32_t result) {
221 Instruction* result_inst = context()->get_def_use_mgr()->GetDef(result);
222
223 switch (result_inst->opcode()) {
224 case spv::Op::OpLoad:
225 return BuildMemoryObjectFromLoad(result_inst);
226 case spv::Op::OpCompositeExtract:
227 return BuildMemoryObjectFromExtract(result_inst);
228 case spv::Op::OpCompositeConstruct:
229 return BuildMemoryObjectFromCompositeConstruct(result_inst);
230 case spv::Op::OpCopyObject:
231 return GetSourceObjectIfAny(result_inst->GetSingleWordInOperand(0));
232 case spv::Op::OpCompositeInsert:
233 return BuildMemoryObjectFromInsert(result_inst);
234 default:
235 return nullptr;
236 }
237 }
238
239 std::unique_ptr<CopyPropagateArrays::MemoryObject>
BuildMemoryObjectFromLoad(Instruction * load_inst)240 CopyPropagateArrays::BuildMemoryObjectFromLoad(Instruction* load_inst) {
241 std::vector<uint32_t> components_in_reverse;
242 analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
243
244 Instruction* current_inst = def_use_mgr->GetDef(
245 load_inst->GetSingleWordInOperand(kLoadPointerInOperand));
246
247 // Build the access chain for the memory object by collecting the indices used
248 // in the OpAccessChain instructions. If we find a variable index, then
249 // return |nullptr| because we cannot know for sure which memory location is
250 // used.
251 //
252 // It is built in reverse order because the different |OpAccessChain|
253 // instructions are visited in reverse order from which they are applied.
254 while (current_inst->opcode() == spv::Op::OpAccessChain) {
255 for (uint32_t i = current_inst->NumInOperands() - 1; i >= 1; --i) {
256 uint32_t element_index_id = current_inst->GetSingleWordInOperand(i);
257 components_in_reverse.push_back(element_index_id);
258 }
259 current_inst = def_use_mgr->GetDef(current_inst->GetSingleWordInOperand(0));
260 }
261
262 // If the address in the load is not constructed from an |OpVariable|
263 // instruction followed by a series of |OpAccessChain| instructions, then
264 // return |nullptr| because we cannot identify the owner or access chain
265 // exactly.
266 if (current_inst->opcode() != spv::Op::OpVariable) {
267 return nullptr;
268 }
269
270 // Build the memory object. Use |rbegin| and |rend| to put the access chain
271 // back in the correct order.
272 return std::unique_ptr<CopyPropagateArrays::MemoryObject>(
273 new MemoryObject(current_inst, components_in_reverse.rbegin(),
274 components_in_reverse.rend()));
275 }
276
277 std::unique_ptr<CopyPropagateArrays::MemoryObject>
BuildMemoryObjectFromExtract(Instruction * extract_inst)278 CopyPropagateArrays::BuildMemoryObjectFromExtract(Instruction* extract_inst) {
279 assert(extract_inst->opcode() == spv::Op::OpCompositeExtract &&
280 "Expecting an OpCompositeExtract instruction.");
281 std::unique_ptr<MemoryObject> result = GetSourceObjectIfAny(
282 extract_inst->GetSingleWordInOperand(kCompositeExtractObjectInOperand));
283
284 if (!result) {
285 return nullptr;
286 }
287
288 // Copy the indices of the extract instruction to |OpAccessChain| indices.
289 std::vector<AccessChainEntry> components;
290 for (uint32_t i = 1; i < extract_inst->NumInOperands(); ++i) {
291 components.push_back({false, {extract_inst->GetSingleWordInOperand(i)}});
292 }
293 result->PushIndirection(components);
294 return result;
295 }
296
297 std::unique_ptr<CopyPropagateArrays::MemoryObject>
BuildMemoryObjectFromCompositeConstruct(Instruction * conststruct_inst)298 CopyPropagateArrays::BuildMemoryObjectFromCompositeConstruct(
299 Instruction* conststruct_inst) {
300 assert(conststruct_inst->opcode() == spv::Op::OpCompositeConstruct &&
301 "Expecting an OpCompositeConstruct instruction.");
302
303 // If every operand in the instruction are part of the same memory object, and
304 // are being combined in the same order, then the result is the same as the
305 // parent.
306
307 std::unique_ptr<MemoryObject> memory_object =
308 GetSourceObjectIfAny(conststruct_inst->GetSingleWordInOperand(0));
309
310 if (!memory_object) {
311 return nullptr;
312 }
313
314 if (!memory_object->IsMember()) {
315 return nullptr;
316 }
317
318 AccessChainEntry last_access = memory_object->AccessChain().back();
319 if (!IsAccessChainIndexValidAndEqualTo(last_access, 0)) {
320 return nullptr;
321 }
322
323 memory_object->PopIndirection();
324 if (memory_object->GetNumberOfMembers() !=
325 conststruct_inst->NumInOperands()) {
326 return nullptr;
327 }
328
329 for (uint32_t i = 1; i < conststruct_inst->NumInOperands(); ++i) {
330 std::unique_ptr<MemoryObject> member_object =
331 GetSourceObjectIfAny(conststruct_inst->GetSingleWordInOperand(i));
332
333 if (!member_object) {
334 return nullptr;
335 }
336
337 if (!member_object->IsMember()) {
338 return nullptr;
339 }
340
341 if (!memory_object->Contains(member_object.get())) {
342 return nullptr;
343 }
344
345 last_access = member_object->AccessChain().back();
346 if (!IsAccessChainIndexValidAndEqualTo(last_access, i)) {
347 return nullptr;
348 }
349 }
350 return memory_object;
351 }
352
353 std::unique_ptr<CopyPropagateArrays::MemoryObject>
BuildMemoryObjectFromInsert(Instruction * insert_inst)354 CopyPropagateArrays::BuildMemoryObjectFromInsert(Instruction* insert_inst) {
355 assert(insert_inst->opcode() == spv::Op::OpCompositeInsert &&
356 "Expecting an OpCompositeInsert instruction.");
357
358 analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
359 analysis::TypeManager* type_mgr = context()->get_type_mgr();
360 analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
361 const analysis::Type* result_type = type_mgr->GetType(insert_inst->type_id());
362
363 uint32_t number_of_elements = 0;
364 if (const analysis::Struct* struct_type = result_type->AsStruct()) {
365 number_of_elements =
366 static_cast<uint32_t>(struct_type->element_types().size());
367 } else if (const analysis::Array* array_type = result_type->AsArray()) {
368 const analysis::Constant* length_const =
369 const_mgr->FindDeclaredConstant(array_type->LengthId());
370 number_of_elements = length_const->GetU32();
371 } else if (const analysis::Vector* vector_type = result_type->AsVector()) {
372 number_of_elements = vector_type->element_count();
373 } else if (const analysis::Matrix* matrix_type = result_type->AsMatrix()) {
374 number_of_elements = matrix_type->element_count();
375 }
376
377 if (number_of_elements == 0) {
378 return nullptr;
379 }
380
381 if (insert_inst->NumInOperands() != 3) {
382 return nullptr;
383 }
384
385 if (insert_inst->GetSingleWordInOperand(2) != number_of_elements - 1) {
386 return nullptr;
387 }
388
389 std::unique_ptr<MemoryObject> memory_object =
390 GetSourceObjectIfAny(insert_inst->GetSingleWordInOperand(0));
391
392 if (!memory_object) {
393 return nullptr;
394 }
395
396 if (!memory_object->IsMember()) {
397 return nullptr;
398 }
399
400 AccessChainEntry last_access = memory_object->AccessChain().back();
401 if (!IsAccessChainIndexValidAndEqualTo(last_access, number_of_elements - 1)) {
402 return nullptr;
403 }
404
405 memory_object->PopIndirection();
406
407 Instruction* current_insert =
408 def_use_mgr->GetDef(insert_inst->GetSingleWordInOperand(1));
409 for (uint32_t i = number_of_elements - 1; i > 0; --i) {
410 if (current_insert->opcode() != spv::Op::OpCompositeInsert) {
411 return nullptr;
412 }
413
414 if (current_insert->NumInOperands() != 3) {
415 return nullptr;
416 }
417
418 if (current_insert->GetSingleWordInOperand(2) != i - 1) {
419 return nullptr;
420 }
421
422 std::unique_ptr<MemoryObject> current_memory_object =
423 GetSourceObjectIfAny(current_insert->GetSingleWordInOperand(0));
424
425 if (!current_memory_object) {
426 return nullptr;
427 }
428
429 if (!current_memory_object->IsMember()) {
430 return nullptr;
431 }
432
433 if (memory_object->AccessChain().size() + 1 !=
434 current_memory_object->AccessChain().size()) {
435 return nullptr;
436 }
437
438 if (!memory_object->Contains(current_memory_object.get())) {
439 return nullptr;
440 }
441
442 AccessChainEntry current_last_access =
443 current_memory_object->AccessChain().back();
444 if (!IsAccessChainIndexValidAndEqualTo(current_last_access, i - 1)) {
445 return nullptr;
446 }
447 current_insert =
448 def_use_mgr->GetDef(current_insert->GetSingleWordInOperand(1));
449 }
450
451 return memory_object;
452 }
453
IsAccessChainIndexValidAndEqualTo(const AccessChainEntry & entry,uint32_t value) const454 bool CopyPropagateArrays::IsAccessChainIndexValidAndEqualTo(
455 const AccessChainEntry& entry, uint32_t value) const {
456 if (!entry.is_result_id) {
457 return entry.immediate == value;
458 }
459
460 analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
461 const analysis::Constant* constant =
462 const_mgr->FindDeclaredConstant(entry.result_id);
463 if (!constant || !constant->type()->AsInteger()) {
464 return false;
465 }
466 return constant->GetU32() == value;
467 }
468
IsPointerToArrayType(uint32_t type_id)469 bool CopyPropagateArrays::IsPointerToArrayType(uint32_t type_id) {
470 analysis::TypeManager* type_mgr = context()->get_type_mgr();
471 analysis::Pointer* pointer_type = type_mgr->GetType(type_id)->AsPointer();
472 if (pointer_type) {
473 return pointer_type->pointee_type()->kind() == analysis::Type::kArray ||
474 pointer_type->pointee_type()->kind() == analysis::Type::kImage;
475 }
476 return false;
477 }
478
// Returns true if every use of |original_ptr_inst| can be rewritten by
// UpdateUses() when |original_ptr_inst| is replaced by a value of type
// |type_id|. Uses whose result type would change (OpLoad, OpAccessChain,
// OpCompositeExtract) are checked recursively against their new type.
// Stores, OpImageTexelPointer, OpName, decorations and
// DebugDeclare/DebugValue are always accepted; anything else is rejected.
bool CopyPropagateArrays::CanUpdateUses(Instruction* original_ptr_inst,
                                        uint32_t type_id) {
  analysis::TypeManager* type_mgr = context()->get_type_mgr();
  analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
  analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();

  analysis::Type* type = type_mgr->GetType(type_id);
  // Runtime arrays are never handled.
  if (type->AsRuntimeArray()) {
    return false;
  }

  if (!type->AsStruct() && !type->AsArray() && !type->AsPointer()) {
    // If the type is not an aggregate, then the desired type must be the
    // same as the current type. No work to do, and we can do that.
    return true;
  }

  return def_use_mgr->WhileEachUse(original_ptr_inst, [this, type_mgr,
                                                       const_mgr,
                                                       type](Instruction* use,
                                                             uint32_t) {
    // Debug declarations/values are rewritten in place by UpdateUses().
    if (IsDebugDeclareOrValue(use)) return true;

    switch (use->opcode()) {
      case spv::Op::OpLoad: {
        // NOTE(review): assumes |type| is a pointer type here (the operand of
        // a load is a pointer); the AsPointer() result is not null-checked.
        analysis::Pointer* pointer_type = type->AsPointer();
        uint32_t new_type_id = type_mgr->GetId(pointer_type->pointee_type());

        // If the loaded type changes, the load's own users must cope too.
        if (new_type_id != use->type_id()) {
          return CanUpdateUses(use, new_type_id);
        }
        return true;
      }
      case spv::Op::OpAccessChain: {
        analysis::Pointer* pointer_type = type->AsPointer();
        const analysis::Type* pointee_type = pointer_type->pointee_type();

        // Translate the index ids into literal indices so the member type
        // can be looked up.
        std::vector<uint32_t> access_chain;
        for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
          const analysis::Constant* index_const =
              const_mgr->FindDeclaredConstant(use->GetSingleWordInOperand(i));
          if (index_const) {
            access_chain.push_back(index_const->GetU32());
          } else {
            // Variable index means the type is a type where every element
            // is the same type. Use element 0 to get the type.
            access_chain.push_back(0);

            // We are trying to access a struct with variable indices.
            // This cannot happen.
            // NOTE(review): this tests the *base* pointee type, not the type
            // at the position of this index, so it is conservative for
            // chains whose first level is a struct.
            if (pointee_type->kind() == analysis::Type::kStruct) {
              return false;
            }
          }
        }

        // The access chain would need a pointer type with the same storage
        // class; give up if that type cannot be obtained.
        const analysis::Type* new_pointee_type =
            type_mgr->GetMemberType(pointee_type, access_chain);
        analysis::Pointer pointerTy(new_pointee_type,
                                    pointer_type->storage_class());
        uint32_t new_pointer_type_id =
            context()->get_type_mgr()->GetTypeInstruction(&pointerTy);
        if (new_pointer_type_id == 0) {
          return false;
        }

        if (new_pointer_type_id != use->type_id()) {
          return CanUpdateUses(use, new_pointer_type_id);
        }
        return true;
      }
      case spv::Op::OpCompositeExtract: {
        // Extract indices are literals and can be used directly.
        std::vector<uint32_t> access_chain;
        for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
          access_chain.push_back(use->GetSingleWordInOperand(i));
        }

        const analysis::Type* new_type =
            type_mgr->GetMemberType(type, access_chain);
        uint32_t new_type_id = type_mgr->GetTypeInstruction(new_type);
        if (new_type_id == 0) {
          return false;
        }

        if (new_type_id != use->type_id()) {
          return CanUpdateUses(use, new_type_id);
        }
        return true;
      }
      case spv::Op::OpStore:
        // If needed, we can create an element-by-element copy to change the
        // type of the value being stored. This way we can always handled
        // stores.
        return true;
      case spv::Op::OpImageTexelPointer:
      case spv::Op::OpName:
        return true;
      default:
        return use->IsDecoration();
    }
  });
}
581
UpdateUses(Instruction * original_ptr_inst,Instruction * new_ptr_inst)582 void CopyPropagateArrays::UpdateUses(Instruction* original_ptr_inst,
583 Instruction* new_ptr_inst) {
584 analysis::TypeManager* type_mgr = context()->get_type_mgr();
585 analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
586 analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
587
588 std::vector<std::pair<Instruction*, uint32_t> > uses;
589 def_use_mgr->ForEachUse(original_ptr_inst,
590 [&uses](Instruction* use, uint32_t index) {
591 uses.push_back({use, index});
592 });
593
594 for (auto pair : uses) {
595 Instruction* use = pair.first;
596 uint32_t index = pair.second;
597
598 if (use->IsCommonDebugInstr()) {
599 switch (use->GetCommonDebugOpcode()) {
600 case CommonDebugInfoDebugDeclare: {
601 if (new_ptr_inst->opcode() == spv::Op::OpVariable ||
602 new_ptr_inst->opcode() == spv::Op::OpFunctionParameter) {
603 context()->ForgetUses(use);
604 use->SetOperand(index, {new_ptr_inst->result_id()});
605 context()->AnalyzeUses(use);
606 } else {
607 // Based on the spec, we cannot use a pointer other than OpVariable
608 // or OpFunctionParameter for DebugDeclare. We have to use
609 // DebugValue with Deref.
610
611 context()->ForgetUses(use);
612
613 // Change DebugDeclare to DebugValue.
614 use->SetOperand(index - 2,
615 {static_cast<uint32_t>(CommonDebugInfoDebugValue)});
616 use->SetOperand(index, {new_ptr_inst->result_id()});
617
618 // Add Deref operation.
619 Instruction* dbg_expr =
620 def_use_mgr->GetDef(use->GetSingleWordOperand(index + 1));
621 auto* deref_expr_instr =
622 context()->get_debug_info_mgr()->DerefDebugExpression(dbg_expr);
623 use->SetOperand(index + 1, {deref_expr_instr->result_id()});
624
625 context()->AnalyzeUses(deref_expr_instr);
626 context()->AnalyzeUses(use);
627 }
628 break;
629 }
630 case CommonDebugInfoDebugValue:
631 context()->ForgetUses(use);
632 use->SetOperand(index, {new_ptr_inst->result_id()});
633 context()->AnalyzeUses(use);
634 break;
635 default:
636 assert(false && "Don't know how to rewrite instruction");
637 break;
638 }
639 continue;
640 }
641
642 switch (use->opcode()) {
643 case spv::Op::OpLoad: {
644 // Replace the actual use.
645 context()->ForgetUses(use);
646 use->SetOperand(index, {new_ptr_inst->result_id()});
647
648 // Update the type.
649 Instruction* pointer_type_inst =
650 def_use_mgr->GetDef(new_ptr_inst->type_id());
651 uint32_t new_type_id =
652 pointer_type_inst->GetSingleWordInOperand(kTypePointerPointeeInIdx);
653 if (new_type_id != use->type_id()) {
654 use->SetResultType(new_type_id);
655 context()->AnalyzeUses(use);
656 UpdateUses(use, use);
657 } else {
658 context()->AnalyzeUses(use);
659 }
660 } break;
661 case spv::Op::OpAccessChain: {
662 // Update the actual use.
663 context()->ForgetUses(use);
664 use->SetOperand(index, {new_ptr_inst->result_id()});
665
666 // Convert the ids on the OpAccessChain to indices that can be used to
667 // get the specific member.
668 std::vector<uint32_t> access_chain;
669 for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
670 const analysis::Constant* index_const =
671 const_mgr->FindDeclaredConstant(use->GetSingleWordInOperand(i));
672 if (index_const) {
673 access_chain.push_back(index_const->GetU32());
674 } else {
675 // Variable index means the type is an type where every element
676 // is the same type. Use element 0 to get the type.
677 access_chain.push_back(0);
678 }
679 }
680
681 Instruction* pointer_type_inst =
682 get_def_use_mgr()->GetDef(new_ptr_inst->type_id());
683
684 uint32_t new_pointee_type_id = GetMemberTypeId(
685 pointer_type_inst->GetSingleWordInOperand(kTypePointerPointeeInIdx),
686 access_chain);
687
688 spv::StorageClass storage_class = static_cast<spv::StorageClass>(
689 pointer_type_inst->GetSingleWordInOperand(
690 kTypePointerStorageClassInIdx));
691
692 uint32_t new_pointer_type_id =
693 type_mgr->FindPointerToType(new_pointee_type_id, storage_class);
694
695 if (new_pointer_type_id != use->type_id()) {
696 use->SetResultType(new_pointer_type_id);
697 context()->AnalyzeUses(use);
698 UpdateUses(use, use);
699 } else {
700 context()->AnalyzeUses(use);
701 }
702 } break;
703 case spv::Op::OpCompositeExtract: {
704 // Update the actual use.
705 context()->ForgetUses(use);
706 use->SetOperand(index, {new_ptr_inst->result_id()});
707
708 uint32_t new_type_id = new_ptr_inst->type_id();
709 std::vector<uint32_t> access_chain;
710 for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
711 access_chain.push_back(use->GetSingleWordInOperand(i));
712 }
713
714 new_type_id = GetMemberTypeId(new_type_id, access_chain);
715
716 if (new_type_id != use->type_id()) {
717 use->SetResultType(new_type_id);
718 context()->AnalyzeUses(use);
719 UpdateUses(use, use);
720 } else {
721 context()->AnalyzeUses(use);
722 }
723 } break;
724 case spv::Op::OpStore:
725 // If the use is the pointer, then it is the single store to that
726 // variable. We do not want to replace it. Instead, it will become
727 // dead after all of the loads are removed, and ADCE will get rid of it.
728 //
729 // If the use is the object being stored, we will create a copy of the
730 // object turning it into the correct type. The copy is done by
731 // decomposing the object into the base type, which must be the same,
732 // and then rebuilding them.
733 if (index == 1) {
734 Instruction* target_pointer = def_use_mgr->GetDef(
735 use->GetSingleWordInOperand(kStorePointerInOperand));
736 Instruction* pointer_type =
737 def_use_mgr->GetDef(target_pointer->type_id());
738 uint32_t pointee_type_id =
739 pointer_type->GetSingleWordInOperand(kTypePointerPointeeInIdx);
740 uint32_t copy = GenerateCopy(original_ptr_inst, pointee_type_id, use);
741
742 context()->ForgetUses(use);
743 use->SetInOperand(index, {copy});
744 context()->AnalyzeUses(use);
745 }
746 break;
747 case spv::Op::OpDecorate:
748 // We treat an OpImageTexelPointer as a load. The result type should
749 // always have the Image storage class, and should not need to be
750 // updated.
751 case spv::Op::OpImageTexelPointer:
752 // Replace the actual use.
753 context()->ForgetUses(use);
754 use->SetOperand(index, {new_ptr_inst->result_id()});
755 context()->AnalyzeUses(use);
756 break;
757 default:
758 assert(false && "Don't know how to rewrite instruction");
759 break;
760 }
761 }
762 }
763
GetMemberTypeId(uint32_t id,const std::vector<uint32_t> & access_chain) const764 uint32_t CopyPropagateArrays::GetMemberTypeId(
765 uint32_t id, const std::vector<uint32_t>& access_chain) const {
766 for (uint32_t element_index : access_chain) {
767 Instruction* type_inst = get_def_use_mgr()->GetDef(id);
768 switch (type_inst->opcode()) {
769 case spv::Op::OpTypeArray:
770 case spv::Op::OpTypeRuntimeArray:
771 case spv::Op::OpTypeMatrix:
772 case spv::Op::OpTypeVector:
773 id = type_inst->GetSingleWordInOperand(0);
774 break;
775 case spv::Op::OpTypeStruct:
776 id = type_inst->GetSingleWordInOperand(element_index);
777 break;
778 default:
779 break;
780 }
781 assert(id != 0 &&
782 "Tried to extract from an object where it cannot be done.");
783 }
784 return id;
785 }
786
PushIndirection(const std::vector<AccessChainEntry> & access_chain)787 void CopyPropagateArrays::MemoryObject::PushIndirection(
788 const std::vector<AccessChainEntry>& access_chain) {
789 access_chain_.insert(access_chain_.end(), access_chain.begin(),
790 access_chain.end());
791 }
792
GetNumberOfMembers()793 uint32_t CopyPropagateArrays::MemoryObject::GetNumberOfMembers() {
794 IRContext* context = variable_inst_->context();
795 analysis::TypeManager* type_mgr = context->get_type_mgr();
796
797 const analysis::Type* type = type_mgr->GetType(variable_inst_->type_id());
798 type = type->AsPointer()->pointee_type();
799
800 std::vector<uint32_t> access_indices = GetAccessIds();
801 type = type_mgr->GetMemberType(type, access_indices);
802
803 if (const analysis::Struct* struct_type = type->AsStruct()) {
804 return static_cast<uint32_t>(struct_type->element_types().size());
805 } else if (const analysis::Array* array_type = type->AsArray()) {
806 const analysis::Constant* length_const =
807 context->get_constant_mgr()->FindDeclaredConstant(
808 array_type->LengthId());
809 assert(length_const->type()->AsInteger());
810 return length_const->GetU32();
811 } else if (const analysis::Vector* vector_type = type->AsVector()) {
812 return vector_type->element_count();
813 } else if (const analysis::Matrix* matrix_type = type->AsMatrix()) {
814 return matrix_type->element_count();
815 } else {
816 return 0;
817 }
818 }
819
820 template <class iterator>
MemoryObject(Instruction * var_inst,iterator begin,iterator end)821 CopyPropagateArrays::MemoryObject::MemoryObject(Instruction* var_inst,
822 iterator begin, iterator end)
823 : variable_inst_(var_inst) {
824 std::transform(begin, end, std::back_inserter(access_chain_),
825 [](uint32_t id) {
826 return AccessChainEntry{true, {id}};
827 });
828 }
829
GetAccessIds() const830 std::vector<uint32_t> CopyPropagateArrays::MemoryObject::GetAccessIds() const {
831 analysis::ConstantManager* const_mgr =
832 variable_inst_->context()->get_constant_mgr();
833
834 std::vector<uint32_t> indices(AccessChain().size());
835 std::transform(AccessChain().cbegin(), AccessChain().cend(), indices.begin(),
836 [&const_mgr](const AccessChainEntry& entry) {
837 if (entry.is_result_id) {
838 const analysis::Constant* constant =
839 const_mgr->FindDeclaredConstant(entry.result_id);
840 return constant == nullptr ? 0 : constant->GetU32();
841 }
842
843 return entry.immediate;
844 });
845 return indices;
846 }
847
Contains(CopyPropagateArrays::MemoryObject * other)848 bool CopyPropagateArrays::MemoryObject::Contains(
849 CopyPropagateArrays::MemoryObject* other) {
850 if (this->GetVariable() != other->GetVariable()) {
851 return false;
852 }
853
854 if (AccessChain().size() > other->AccessChain().size()) {
855 return false;
856 }
857
858 for (uint32_t i = 0; i < AccessChain().size(); i++) {
859 if (AccessChain()[i] != other->AccessChain()[i]) {
860 return false;
861 }
862 }
863 return true;
864 }
865
BuildConstants()866 void CopyPropagateArrays::MemoryObject::BuildConstants() {
867 for (auto& entry : access_chain_) {
868 if (entry.is_result_id) {
869 continue;
870 }
871
872 auto context = variable_inst_->context();
873 analysis::Integer int_type(32, false);
874 const analysis::Type* uint32_type =
875 context->get_type_mgr()->GetRegisteredType(&int_type);
876 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
877 const analysis::Constant* index_const =
878 const_mgr->GetConstant(uint32_type, {entry.immediate});
879 entry.result_id =
880 const_mgr->GetDefiningInstruction(index_const)->result_id();
881 entry.is_result_id = true;
882 }
883 }
884
885 } // namespace opt
886 } // namespace spvtools
887