• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2019 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/opt/eliminate_dead_members_pass.h"
16 
17 #include "ir_builder.h"
18 #include "source/opt/ir_context.h"
19 
20 namespace spvtools {
21 namespace opt {
namespace {
// Sentinel produced by GetNewMemberIndex for a member that has been
// eliminated and therefore has no index in the rewritten struct.
constexpr uint32_t kRemovedMember = ~uint32_t{0};

// In-operand position of the wrapped opcode of an OpSpecConstantOp.
constexpr uint32_t kSpecConstOpOpcodeIdx = 0;

// In-operand position of the element type of OpTypeArray and
// OpTypeRuntimeArray.
constexpr uint32_t kArrayElementTypeIdx = 0;
}  // namespace
27 
Process()28 Pass::Status EliminateDeadMembersPass::Process() {
29   if (!context()->get_feature_mgr()->HasCapability(spv::Capability::Shader))
30     return Status::SuccessWithoutChange;
31 
32   FindLiveMembers();
33   if (RemoveDeadMembers()) {
34     return Status::SuccessWithChange;
35   }
36   return Status::SuccessWithoutChange;
37 }
38 
FindLiveMembers()39 void EliminateDeadMembersPass::FindLiveMembers() {
40   // Until we have implemented the rewriting of OpSpecConsantOp instructions,
41   // we have to mark them as fully used just to be safe.
42   for (auto& inst : get_module()->types_values()) {
43     if (inst.opcode() == spv::Op::OpSpecConstantOp) {
44       switch (spv::Op(inst.GetSingleWordInOperand(kSpecConstOpOpcodeIdx))) {
45         case spv::Op::OpCompositeExtract:
46           MarkMembersAsLiveForExtract(&inst);
47           break;
48         case spv::Op::OpCompositeInsert:
49           // Nothing specific to do.
50           break;
51         case spv::Op::OpAccessChain:
52         case spv::Op::OpInBoundsAccessChain:
53         case spv::Op::OpPtrAccessChain:
54         case spv::Op::OpInBoundsPtrAccessChain:
55           assert(false && "Not implemented yet.");
56           break;
57         default:
58           break;
59       }
60     } else if (inst.opcode() == spv::Op::OpVariable) {
61       switch (spv::StorageClass(inst.GetSingleWordInOperand(0))) {
62         case spv::StorageClass::Input:
63         case spv::StorageClass::Output:
64           MarkPointeeTypeAsFullUsed(inst.type_id());
65           break;
66         default:
67           // Ignore structured buffers as layout(offset) qualifiers cannot be
68           // applied to structure fields
69           if (inst.IsVulkanStorageBufferVariable())
70             MarkPointeeTypeAsFullUsed(inst.type_id());
71           break;
72       }
73     } else if (inst.opcode() == spv::Op::OpTypePointer) {
74       uint32_t storage_class = inst.GetSingleWordInOperand(0);
75       if (storage_class == uint32_t(spv::StorageClass::PhysicalStorageBuffer)) {
76         MarkTypeAsFullyUsed(inst.GetSingleWordInOperand(1));
77       }
78     }
79   }
80 
81   for (const Function& func : *get_module()) {
82     FindLiveMembers(func);
83   }
84 }
85 
FindLiveMembers(const Function & function)86 void EliminateDeadMembersPass::FindLiveMembers(const Function& function) {
87   function.ForEachInst(
88       [this](const Instruction* inst) { FindLiveMembers(inst); });
89 }
90 
FindLiveMembers(const Instruction * inst)91 void EliminateDeadMembersPass::FindLiveMembers(const Instruction* inst) {
92   switch (inst->opcode()) {
93     case spv::Op::OpStore:
94       MarkMembersAsLiveForStore(inst);
95       break;
96     case spv::Op::OpCopyMemory:
97     case spv::Op::OpCopyMemorySized:
98       MarkMembersAsLiveForCopyMemory(inst);
99       break;
100     case spv::Op::OpCompositeExtract:
101       MarkMembersAsLiveForExtract(inst);
102       break;
103     case spv::Op::OpAccessChain:
104     case spv::Op::OpInBoundsAccessChain:
105     case spv::Op::OpPtrAccessChain:
106     case spv::Op::OpInBoundsPtrAccessChain:
107       MarkMembersAsLiveForAccessChain(inst);
108       break;
109     case spv::Op::OpReturnValue:
110       // This should be an issue only if we are returning from the entry point.
111       // However, for now I will keep it more conservative because functions are
112       // often inlined leaving only the entry points.
113       MarkOperandTypeAsFullyUsed(inst, 0);
114       break;
115     case spv::Op::OpArrayLength:
116       MarkMembersAsLiveForArrayLength(inst);
117       break;
118     case spv::Op::OpLoad:
119     case spv::Op::OpCompositeInsert:
120     case spv::Op::OpCompositeConstruct:
121       break;
122     default:
123       // This path is here for safety.  All instructions that can reference
124       // structs in a function body should be handled above.  However, this will
125       // keep the pass valid, but not optimal, as new instructions get added
126       // or if something was missed.
127       MarkStructOperandsAsFullyUsed(inst);
128       break;
129   }
130 }
131 
MarkMembersAsLiveForStore(const Instruction * inst)132 void EliminateDeadMembersPass::MarkMembersAsLiveForStore(
133     const Instruction* inst) {
134   // We should only have to mark the members as live if the store is to
135   // memory that is read outside of the shader.  Other passes can remove all
136   // store to memory that is not visible outside of the shader, so we do not
137   // complicate the code for now.
138   assert(inst->opcode() == spv::Op::OpStore);
139   uint32_t object_id = inst->GetSingleWordInOperand(1);
140   Instruction* object_inst = context()->get_def_use_mgr()->GetDef(object_id);
141   uint32_t object_type_id = object_inst->type_id();
142   MarkTypeAsFullyUsed(object_type_id);
143 }
144 
MarkTypeAsFullyUsed(uint32_t type_id)145 void EliminateDeadMembersPass::MarkTypeAsFullyUsed(uint32_t type_id) {
146   Instruction* type_inst = get_def_use_mgr()->GetDef(type_id);
147   assert(type_inst != nullptr);
148 
149   switch (type_inst->opcode()) {
150     case spv::Op::OpTypeStruct:
151       // Mark every member and its type as fully used.
152       for (uint32_t i = 0; i < type_inst->NumInOperands(); ++i) {
153         used_members_[type_id].insert(i);
154         MarkTypeAsFullyUsed(type_inst->GetSingleWordInOperand(i));
155       }
156       break;
157     case spv::Op::OpTypeArray:
158     case spv::Op::OpTypeRuntimeArray:
159       MarkTypeAsFullyUsed(
160           type_inst->GetSingleWordInOperand(kArrayElementTypeIdx));
161       break;
162     default:
163       break;
164   }
165 }
166 
MarkPointeeTypeAsFullUsed(uint32_t ptr_type_id)167 void EliminateDeadMembersPass::MarkPointeeTypeAsFullUsed(uint32_t ptr_type_id) {
168   Instruction* ptr_type_inst = get_def_use_mgr()->GetDef(ptr_type_id);
169   assert(ptr_type_inst->opcode() == spv::Op::OpTypePointer);
170   MarkTypeAsFullyUsed(ptr_type_inst->GetSingleWordInOperand(1));
171 }
172 
MarkMembersAsLiveForCopyMemory(const Instruction * inst)173 void EliminateDeadMembersPass::MarkMembersAsLiveForCopyMemory(
174     const Instruction* inst) {
175   uint32_t target_id = inst->GetSingleWordInOperand(0);
176   Instruction* target_inst = get_def_use_mgr()->GetDef(target_id);
177   uint32_t pointer_type_id = target_inst->type_id();
178   Instruction* pointer_type_inst = get_def_use_mgr()->GetDef(pointer_type_id);
179   uint32_t type_id = pointer_type_inst->GetSingleWordInOperand(1);
180   MarkTypeAsFullyUsed(type_id);
181 }
182 
MarkMembersAsLiveForExtract(const Instruction * inst)183 void EliminateDeadMembersPass::MarkMembersAsLiveForExtract(
184     const Instruction* inst) {
185   assert(inst->opcode() == spv::Op::OpCompositeExtract ||
186          (inst->opcode() == spv::Op::OpSpecConstantOp &&
187           spv::Op(inst->GetSingleWordInOperand(kSpecConstOpOpcodeIdx)) ==
188               spv::Op::OpCompositeExtract));
189 
190   uint32_t first_operand =
191       (inst->opcode() == spv::Op::OpSpecConstantOp ? 1 : 0);
192   uint32_t composite_id = inst->GetSingleWordInOperand(first_operand);
193   Instruction* composite_inst = get_def_use_mgr()->GetDef(composite_id);
194   uint32_t type_id = composite_inst->type_id();
195 
196   for (uint32_t i = first_operand + 1; i < inst->NumInOperands(); ++i) {
197     Instruction* type_inst = get_def_use_mgr()->GetDef(type_id);
198     uint32_t member_idx = inst->GetSingleWordInOperand(i);
199     switch (type_inst->opcode()) {
200       case spv::Op::OpTypeStruct:
201         used_members_[type_id].insert(member_idx);
202         type_id = type_inst->GetSingleWordInOperand(member_idx);
203         break;
204       case spv::Op::OpTypeArray:
205       case spv::Op::OpTypeRuntimeArray:
206       case spv::Op::OpTypeVector:
207       case spv::Op::OpTypeMatrix:
208       case spv::Op::OpTypeCooperativeMatrixNV:
209       case spv::Op::OpTypeCooperativeMatrixKHR:
210         type_id = type_inst->GetSingleWordInOperand(0);
211         break;
212       default:
213         assert(false);
214     }
215   }
216 }
217 
MarkMembersAsLiveForAccessChain(const Instruction * inst)218 void EliminateDeadMembersPass::MarkMembersAsLiveForAccessChain(
219     const Instruction* inst) {
220   assert(inst->opcode() == spv::Op::OpAccessChain ||
221          inst->opcode() == spv::Op::OpInBoundsAccessChain ||
222          inst->opcode() == spv::Op::OpPtrAccessChain ||
223          inst->opcode() == spv::Op::OpInBoundsPtrAccessChain);
224 
225   uint32_t pointer_id = inst->GetSingleWordInOperand(0);
226   Instruction* pointer_inst = get_def_use_mgr()->GetDef(pointer_id);
227   uint32_t pointer_type_id = pointer_inst->type_id();
228   Instruction* pointer_type_inst = get_def_use_mgr()->GetDef(pointer_type_id);
229   uint32_t type_id = pointer_type_inst->GetSingleWordInOperand(1);
230 
231   analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
232 
233   // For a pointer access chain, we need to skip the |element| index.  It is not
234   // a reference to the member of a struct, and it does not change the type.
235   uint32_t i = (inst->opcode() == spv::Op::OpAccessChain ||
236                         inst->opcode() == spv::Op::OpInBoundsAccessChain
237                     ? 1
238                     : 2);
239   for (; i < inst->NumInOperands(); ++i) {
240     Instruction* type_inst = get_def_use_mgr()->GetDef(type_id);
241     switch (type_inst->opcode()) {
242       case spv::Op::OpTypeStruct: {
243         const analysis::IntConstant* member_idx =
244             const_mgr->FindDeclaredConstant(inst->GetSingleWordInOperand(i))
245                 ->AsIntConstant();
246         assert(member_idx);
247         uint32_t index =
248             static_cast<uint32_t>(member_idx->GetZeroExtendedValue());
249         used_members_[type_id].insert(index);
250         type_id = type_inst->GetSingleWordInOperand(index);
251       } break;
252       case spv::Op::OpTypeArray:
253       case spv::Op::OpTypeRuntimeArray:
254       case spv::Op::OpTypeVector:
255       case spv::Op::OpTypeMatrix:
256       case spv::Op::OpTypeCooperativeMatrixNV:
257       case spv::Op::OpTypeCooperativeMatrixKHR:
258         type_id = type_inst->GetSingleWordInOperand(0);
259         break;
260       default:
261         assert(false);
262     }
263   }
264 }
265 
MarkOperandTypeAsFullyUsed(const Instruction * inst,uint32_t in_idx)266 void EliminateDeadMembersPass::MarkOperandTypeAsFullyUsed(
267     const Instruction* inst, uint32_t in_idx) {
268   uint32_t op_id = inst->GetSingleWordInOperand(in_idx);
269   Instruction* op_inst = get_def_use_mgr()->GetDef(op_id);
270   MarkTypeAsFullyUsed(op_inst->type_id());
271 }
272 
MarkMembersAsLiveForArrayLength(const Instruction * inst)273 void EliminateDeadMembersPass::MarkMembersAsLiveForArrayLength(
274     const Instruction* inst) {
275   assert(inst->opcode() == spv::Op::OpArrayLength);
276   uint32_t object_id = inst->GetSingleWordInOperand(0);
277   Instruction* object_inst = get_def_use_mgr()->GetDef(object_id);
278   uint32_t pointer_type_id = object_inst->type_id();
279   Instruction* pointer_type_inst = get_def_use_mgr()->GetDef(pointer_type_id);
280   uint32_t type_id = pointer_type_inst->GetSingleWordInOperand(1);
281   used_members_[type_id].insert(inst->GetSingleWordInOperand(1));
282 }
283 
RemoveDeadMembers()284 bool EliminateDeadMembersPass::RemoveDeadMembers() {
285   bool modified = false;
286 
287   // First update all of the OpTypeStruct instructions.
288   get_module()->ForEachInst([&modified, this](Instruction* inst) {
289     switch (inst->opcode()) {
290       case spv::Op::OpTypeStruct:
291         modified |= UpdateOpTypeStruct(inst);
292         break;
293       default:
294         break;
295     }
296   });
297 
298   // Now update all of the instructions that reference the OpTypeStructs.
299   get_module()->ForEachInst([&modified, this](Instruction* inst) {
300     switch (inst->opcode()) {
301       case spv::Op::OpMemberName:
302         modified |= UpdateOpMemberNameOrDecorate(inst);
303         break;
304       case spv::Op::OpMemberDecorate:
305         modified |= UpdateOpMemberNameOrDecorate(inst);
306         break;
307       case spv::Op::OpGroupMemberDecorate:
308         modified |= UpdateOpGroupMemberDecorate(inst);
309         break;
310       case spv::Op::OpSpecConstantComposite:
311       case spv::Op::OpConstantComposite:
312       case spv::Op::OpCompositeConstruct:
313         modified |= UpdateConstantComposite(inst);
314         break;
315       case spv::Op::OpAccessChain:
316       case spv::Op::OpInBoundsAccessChain:
317       case spv::Op::OpPtrAccessChain:
318       case spv::Op::OpInBoundsPtrAccessChain:
319         modified |= UpdateAccessChain(inst);
320         break;
321       case spv::Op::OpCompositeExtract:
322         modified |= UpdateCompsiteExtract(inst);
323         break;
324       case spv::Op::OpCompositeInsert:
325         modified |= UpdateCompositeInsert(inst);
326         break;
327       case spv::Op::OpArrayLength:
328         modified |= UpdateOpArrayLength(inst);
329         break;
330       case spv::Op::OpSpecConstantOp:
331         switch (spv::Op(inst->GetSingleWordInOperand(kSpecConstOpOpcodeIdx))) {
332           case spv::Op::OpCompositeExtract:
333             modified |= UpdateCompsiteExtract(inst);
334             break;
335           case spv::Op::OpCompositeInsert:
336             modified |= UpdateCompositeInsert(inst);
337             break;
338           case spv::Op::OpAccessChain:
339           case spv::Op::OpInBoundsAccessChain:
340           case spv::Op::OpPtrAccessChain:
341           case spv::Op::OpInBoundsPtrAccessChain:
342             assert(false && "Not implemented yet.");
343             break;
344           default:
345             break;
346         }
347         break;
348       default:
349         break;
350     }
351   });
352   return modified;
353 }
354 
UpdateOpTypeStruct(Instruction * inst)355 bool EliminateDeadMembersPass::UpdateOpTypeStruct(Instruction* inst) {
356   assert(inst->opcode() == spv::Op::OpTypeStruct);
357 
358   const auto& live_members = used_members_[inst->result_id()];
359   if (live_members.size() == inst->NumInOperands()) {
360     return false;
361   }
362 
363   Instruction::OperandList new_operands;
364   for (uint32_t idx : live_members) {
365     new_operands.emplace_back(inst->GetInOperand(idx));
366   }
367 
368   inst->SetInOperands(std::move(new_operands));
369   context()->UpdateDefUse(inst);
370   return true;
371 }
372 
UpdateOpMemberNameOrDecorate(Instruction * inst)373 bool EliminateDeadMembersPass::UpdateOpMemberNameOrDecorate(Instruction* inst) {
374   assert(inst->opcode() == spv::Op::OpMemberName ||
375          inst->opcode() == spv::Op::OpMemberDecorate);
376 
377   uint32_t type_id = inst->GetSingleWordInOperand(0);
378   auto live_members = used_members_.find(type_id);
379   if (live_members == used_members_.end()) {
380     return false;
381   }
382 
383   uint32_t orig_member_idx = inst->GetSingleWordInOperand(1);
384   uint32_t new_member_idx = GetNewMemberIndex(type_id, orig_member_idx);
385 
386   if (new_member_idx == kRemovedMember) {
387     context()->KillInst(inst);
388     return true;
389   }
390 
391   if (new_member_idx == orig_member_idx) {
392     return false;
393   }
394 
395   inst->SetInOperand(1, {new_member_idx});
396   return true;
397 }
398 
UpdateOpGroupMemberDecorate(Instruction * inst)399 bool EliminateDeadMembersPass::UpdateOpGroupMemberDecorate(Instruction* inst) {
400   assert(inst->opcode() == spv::Op::OpGroupMemberDecorate);
401 
402   bool modified = false;
403 
404   Instruction::OperandList new_operands;
405   new_operands.emplace_back(inst->GetInOperand(0));
406   for (uint32_t i = 1; i < inst->NumInOperands(); i += 2) {
407     uint32_t type_id = inst->GetSingleWordInOperand(i);
408     uint32_t member_idx = inst->GetSingleWordInOperand(i + 1);
409     uint32_t new_member_idx = GetNewMemberIndex(type_id, member_idx);
410 
411     if (new_member_idx == kRemovedMember) {
412       modified = true;
413       continue;
414     }
415 
416     new_operands.emplace_back(inst->GetOperand(i));
417     if (new_member_idx != member_idx) {
418       new_operands.emplace_back(
419           Operand({SPV_OPERAND_TYPE_LITERAL_INTEGER, {new_member_idx}}));
420       modified = true;
421     } else {
422       new_operands.emplace_back(inst->GetOperand(i + 1));
423     }
424   }
425 
426   if (!modified) {
427     return false;
428   }
429 
430   if (new_operands.size() == 1) {
431     context()->KillInst(inst);
432     return true;
433   }
434 
435   inst->SetInOperands(std::move(new_operands));
436   context()->UpdateDefUse(inst);
437   return true;
438 }
439 
UpdateConstantComposite(Instruction * inst)440 bool EliminateDeadMembersPass::UpdateConstantComposite(Instruction* inst) {
441   assert(inst->opcode() == spv::Op::OpSpecConstantComposite ||
442          inst->opcode() == spv::Op::OpConstantComposite ||
443          inst->opcode() == spv::Op::OpCompositeConstruct);
444   uint32_t type_id = inst->type_id();
445 
446   bool modified = false;
447   Instruction::OperandList new_operands;
448   for (uint32_t i = 0; i < inst->NumInOperands(); ++i) {
449     uint32_t new_idx = GetNewMemberIndex(type_id, i);
450     if (new_idx == kRemovedMember) {
451       modified = true;
452     } else {
453       new_operands.emplace_back(inst->GetInOperand(i));
454     }
455   }
456   inst->SetInOperands(std::move(new_operands));
457   context()->UpdateDefUse(inst);
458   return modified;
459 }
460 
UpdateAccessChain(Instruction * inst)461 bool EliminateDeadMembersPass::UpdateAccessChain(Instruction* inst) {
462   assert(inst->opcode() == spv::Op::OpAccessChain ||
463          inst->opcode() == spv::Op::OpInBoundsAccessChain ||
464          inst->opcode() == spv::Op::OpPtrAccessChain ||
465          inst->opcode() == spv::Op::OpInBoundsPtrAccessChain);
466 
467   uint32_t pointer_id = inst->GetSingleWordInOperand(0);
468   Instruction* pointer_inst = get_def_use_mgr()->GetDef(pointer_id);
469   uint32_t pointer_type_id = pointer_inst->type_id();
470   Instruction* pointer_type_inst = get_def_use_mgr()->GetDef(pointer_type_id);
471   uint32_t type_id = pointer_type_inst->GetSingleWordInOperand(1);
472 
473   analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
474   Instruction::OperandList new_operands;
475   bool modified = false;
476   new_operands.emplace_back(inst->GetInOperand(0));
477 
478   // For pointer access chains we want to copy the element operand.
479   if (inst->opcode() == spv::Op::OpPtrAccessChain ||
480       inst->opcode() == spv::Op::OpInBoundsPtrAccessChain) {
481     new_operands.emplace_back(inst->GetInOperand(1));
482   }
483 
484   for (uint32_t i = static_cast<uint32_t>(new_operands.size());
485        i < inst->NumInOperands(); ++i) {
486     Instruction* type_inst = get_def_use_mgr()->GetDef(type_id);
487     switch (type_inst->opcode()) {
488       case spv::Op::OpTypeStruct: {
489         const analysis::IntConstant* member_idx =
490             const_mgr->FindDeclaredConstant(inst->GetSingleWordInOperand(i))
491                 ->AsIntConstant();
492         assert(member_idx);
493         uint32_t orig_member_idx =
494             static_cast<uint32_t>(member_idx->GetZeroExtendedValue());
495         uint32_t new_member_idx = GetNewMemberIndex(type_id, orig_member_idx);
496         assert(new_member_idx != kRemovedMember);
497         if (orig_member_idx != new_member_idx) {
498           InstructionBuilder ir_builder(
499               context(), inst,
500               IRContext::kAnalysisDefUse |
501                   IRContext::kAnalysisInstrToBlockMapping);
502           uint32_t const_id =
503               ir_builder.GetUintConstant(new_member_idx)->result_id();
504           new_operands.emplace_back(Operand({SPV_OPERAND_TYPE_ID, {const_id}}));
505           modified = true;
506         } else {
507           new_operands.emplace_back(inst->GetInOperand(i));
508         }
509         // The type will have already been rewritten, so use the new member
510         // index.
511         type_id = type_inst->GetSingleWordInOperand(new_member_idx);
512       } break;
513       case spv::Op::OpTypeArray:
514       case spv::Op::OpTypeRuntimeArray:
515       case spv::Op::OpTypeVector:
516       case spv::Op::OpTypeMatrix:
517       case spv::Op::OpTypeCooperativeMatrixNV:
518       case spv::Op::OpTypeCooperativeMatrixKHR:
519         new_operands.emplace_back(inst->GetInOperand(i));
520         type_id = type_inst->GetSingleWordInOperand(0);
521         break;
522       default:
523         assert(false);
524         break;
525     }
526   }
527 
528   if (!modified) {
529     return false;
530   }
531   inst->SetInOperands(std::move(new_operands));
532   context()->UpdateDefUse(inst);
533   return true;
534 }
535 
GetNewMemberIndex(uint32_t type_id,uint32_t member_idx)536 uint32_t EliminateDeadMembersPass::GetNewMemberIndex(uint32_t type_id,
537                                                      uint32_t member_idx) {
538   auto live_members = used_members_.find(type_id);
539   if (live_members == used_members_.end()) {
540     return member_idx;
541   }
542 
543   auto current_member = live_members->second.find(member_idx);
544   if (current_member == live_members->second.end()) {
545     return kRemovedMember;
546   }
547 
548   return static_cast<uint32_t>(
549       std::distance(live_members->second.begin(), current_member));
550 }
551 
UpdateCompsiteExtract(Instruction * inst)552 bool EliminateDeadMembersPass::UpdateCompsiteExtract(Instruction* inst) {
553   assert(inst->opcode() == spv::Op::OpCompositeExtract ||
554          (inst->opcode() == spv::Op::OpSpecConstantOp &&
555           spv::Op(inst->GetSingleWordInOperand(kSpecConstOpOpcodeIdx)) ==
556               spv::Op::OpCompositeExtract));
557 
558   uint32_t first_operand = 0;
559   if (inst->opcode() == spv::Op::OpSpecConstantOp) {
560     first_operand = 1;
561   }
562   uint32_t object_id = inst->GetSingleWordInOperand(first_operand);
563   Instruction* object_inst = get_def_use_mgr()->GetDef(object_id);
564   uint32_t type_id = object_inst->type_id();
565 
566   Instruction::OperandList new_operands;
567   bool modified = false;
568   for (uint32_t i = 0; i < first_operand + 1; i++) {
569     new_operands.emplace_back(inst->GetInOperand(i));
570   }
571   for (uint32_t i = first_operand + 1; i < inst->NumInOperands(); ++i) {
572     uint32_t member_idx = inst->GetSingleWordInOperand(i);
573     uint32_t new_member_idx = GetNewMemberIndex(type_id, member_idx);
574     assert(new_member_idx != kRemovedMember);
575     if (member_idx != new_member_idx) {
576       modified = true;
577     }
578     new_operands.emplace_back(
579         Operand({SPV_OPERAND_TYPE_LITERAL_INTEGER, {new_member_idx}}));
580 
581     Instruction* type_inst = get_def_use_mgr()->GetDef(type_id);
582     switch (type_inst->opcode()) {
583       case spv::Op::OpTypeStruct:
584         // The type will have already been rewritten, so use the new member
585         // index.
586         type_id = type_inst->GetSingleWordInOperand(new_member_idx);
587         break;
588       case spv::Op::OpTypeArray:
589       case spv::Op::OpTypeRuntimeArray:
590       case spv::Op::OpTypeVector:
591       case spv::Op::OpTypeMatrix:
592       case spv::Op::OpTypeCooperativeMatrixNV:
593       case spv::Op::OpTypeCooperativeMatrixKHR:
594         type_id = type_inst->GetSingleWordInOperand(0);
595         break;
596       default:
597         assert(false);
598     }
599   }
600 
601   if (!modified) {
602     return false;
603   }
604   inst->SetInOperands(std::move(new_operands));
605   context()->UpdateDefUse(inst);
606   return true;
607 }
608 
UpdateCompositeInsert(Instruction * inst)609 bool EliminateDeadMembersPass::UpdateCompositeInsert(Instruction* inst) {
610   assert(inst->opcode() == spv::Op::OpCompositeInsert ||
611          (inst->opcode() == spv::Op::OpSpecConstantOp &&
612           spv::Op(inst->GetSingleWordInOperand(kSpecConstOpOpcodeIdx)) ==
613               spv::Op::OpCompositeInsert));
614 
615   uint32_t first_operand = 0;
616   if (inst->opcode() == spv::Op::OpSpecConstantOp) {
617     first_operand = 1;
618   }
619 
620   uint32_t composite_id = inst->GetSingleWordInOperand(first_operand + 1);
621   Instruction* composite_inst = get_def_use_mgr()->GetDef(composite_id);
622   uint32_t type_id = composite_inst->type_id();
623 
624   Instruction::OperandList new_operands;
625   bool modified = false;
626 
627   for (uint32_t i = 0; i < first_operand + 2; ++i) {
628     new_operands.emplace_back(inst->GetInOperand(i));
629   }
630   for (uint32_t i = first_operand + 2; i < inst->NumInOperands(); ++i) {
631     uint32_t member_idx = inst->GetSingleWordInOperand(i);
632     uint32_t new_member_idx = GetNewMemberIndex(type_id, member_idx);
633     if (new_member_idx == kRemovedMember) {
634       context()->KillInst(inst);
635       return true;
636     }
637 
638     if (member_idx != new_member_idx) {
639       modified = true;
640     }
641     new_operands.emplace_back(
642         Operand({SPV_OPERAND_TYPE_LITERAL_INTEGER, {new_member_idx}}));
643 
644     Instruction* type_inst = get_def_use_mgr()->GetDef(type_id);
645     switch (type_inst->opcode()) {
646       case spv::Op::OpTypeStruct:
647         // The type will have already been rewritten, so use the new member
648         // index.
649         type_id = type_inst->GetSingleWordInOperand(new_member_idx);
650         break;
651       case spv::Op::OpTypeArray:
652       case spv::Op::OpTypeRuntimeArray:
653       case spv::Op::OpTypeVector:
654       case spv::Op::OpTypeMatrix:
655       case spv::Op::OpTypeCooperativeMatrixNV:
656       case spv::Op::OpTypeCooperativeMatrixKHR:
657         type_id = type_inst->GetSingleWordInOperand(0);
658         break;
659       default:
660         assert(false);
661     }
662   }
663 
664   if (!modified) {
665     return false;
666   }
667   inst->SetInOperands(std::move(new_operands));
668   context()->UpdateDefUse(inst);
669   return true;
670 }
671 
UpdateOpArrayLength(Instruction * inst)672 bool EliminateDeadMembersPass::UpdateOpArrayLength(Instruction* inst) {
673   uint32_t struct_id = inst->GetSingleWordInOperand(0);
674   Instruction* struct_inst = get_def_use_mgr()->GetDef(struct_id);
675   uint32_t pointer_type_id = struct_inst->type_id();
676   Instruction* pointer_type_inst = get_def_use_mgr()->GetDef(pointer_type_id);
677   uint32_t type_id = pointer_type_inst->GetSingleWordInOperand(1);
678 
679   uint32_t member_idx = inst->GetSingleWordInOperand(1);
680   uint32_t new_member_idx = GetNewMemberIndex(type_id, member_idx);
681   assert(new_member_idx != kRemovedMember);
682 
683   if (member_idx == new_member_idx) {
684     return false;
685   }
686 
687   inst->SetInOperand(1, {new_member_idx});
688   context()->UpdateDefUse(inst);
689   return true;
690 }
691 
MarkStructOperandsAsFullyUsed(const Instruction * inst)692 void EliminateDeadMembersPass::MarkStructOperandsAsFullyUsed(
693     const Instruction* inst) {
694   if (inst->type_id() != 0) {
695     MarkTypeAsFullyUsed(inst->type_id());
696   }
697 
698   inst->ForEachInId([this](const uint32_t* id) {
699     Instruction* instruction = get_def_use_mgr()->GetDef(*id);
700     if (instruction->type_id() != 0) {
701       MarkTypeAsFullyUsed(instruction->type_id());
702     }
703   });
704 }
705 }  // namespace opt
706 }  // namespace spvtools
707