• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2018 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/opt/combine_access_chains.h"
16 
17 #include <utility>
18 
19 #include "source/opt/constants.h"
20 #include "source/opt/ir_builder.h"
21 #include "source/opt/ir_context.h"
22 
23 namespace spvtools {
24 namespace opt {
25 
Process()26 Pass::Status CombineAccessChains::Process() {
27   bool modified = false;
28 
29   for (auto& function : *get_module()) {
30     modified |= ProcessFunction(function);
31   }
32 
33   return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
34 }
35 
ProcessFunction(Function & function)36 bool CombineAccessChains::ProcessFunction(Function& function) {
37   if (function.IsDeclaration()) {
38     return false;
39   }
40 
41   bool modified = false;
42 
43   cfg()->ForEachBlockInReversePostOrder(
44       function.entry().get(), [&modified, this](BasicBlock* block) {
45         block->ForEachInst([&modified, this](Instruction* inst) {
46           switch (inst->opcode()) {
47             case SpvOpAccessChain:
48             case SpvOpInBoundsAccessChain:
49             case SpvOpPtrAccessChain:
50             case SpvOpInBoundsPtrAccessChain:
51               modified |= CombineAccessChain(inst);
52               break;
53             default:
54               break;
55           }
56         });
57       });
58 
59   return modified;
60 }
61 
GetConstantValue(const analysis::Constant * constant_inst)62 uint32_t CombineAccessChains::GetConstantValue(
63     const analysis::Constant* constant_inst) {
64   if (constant_inst->type()->AsInteger()->width() <= 32) {
65     if (constant_inst->type()->AsInteger()->IsSigned()) {
66       return static_cast<uint32_t>(constant_inst->GetS32());
67     } else {
68       return constant_inst->GetU32();
69     }
70   } else {
71     assert(false);
72     return 0u;
73   }
74 }
75 
GetArrayStride(const Instruction * inst)76 uint32_t CombineAccessChains::GetArrayStride(const Instruction* inst) {
77   uint32_t array_stride = 0;
78   context()->get_decoration_mgr()->WhileEachDecoration(
79       inst->type_id(), SpvDecorationArrayStride,
80       [&array_stride](const Instruction& decoration) {
81         assert(decoration.opcode() != SpvOpDecorateId);
82         if (decoration.opcode() == SpvOpDecorate) {
83           array_stride = decoration.GetSingleWordInOperand(1);
84         } else {
85           array_stride = decoration.GetSingleWordInOperand(2);
86         }
87         return false;
88       });
89   return array_stride;
90 }
91 
GetIndexedType(Instruction * inst)92 const analysis::Type* CombineAccessChains::GetIndexedType(Instruction* inst) {
93   analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
94   analysis::TypeManager* type_mgr = context()->get_type_mgr();
95 
96   Instruction* base_ptr = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
97   const analysis::Type* type = type_mgr->GetType(base_ptr->type_id());
98   assert(type->AsPointer());
99   type = type->AsPointer()->pointee_type();
100   std::vector<uint32_t> element_indices;
101   uint32_t starting_index = 1;
102   if (IsPtrAccessChain(inst->opcode())) {
103     // Skip the first index of OpPtrAccessChain as it does not affect type
104     // resolution.
105     starting_index = 2;
106   }
107   for (uint32_t i = starting_index; i < inst->NumInOperands(); ++i) {
108     Instruction* index_inst =
109         def_use_mgr->GetDef(inst->GetSingleWordInOperand(i));
110     const analysis::Constant* index_constant =
111         context()->get_constant_mgr()->GetConstantFromInst(index_inst);
112     if (index_constant) {
113       uint32_t index_value = GetConstantValue(index_constant);
114       element_indices.push_back(index_value);
115     } else {
116       // This index must not matter to resolve the type in valid SPIR-V.
117       element_indices.push_back(0);
118     }
119   }
120   type = type_mgr->GetMemberType(type, element_indices);
121   return type;
122 }
123 
CombineIndices(Instruction * ptr_input,Instruction * inst,std::vector<Operand> * new_operands)124 bool CombineAccessChains::CombineIndices(Instruction* ptr_input,
125                                          Instruction* inst,
126                                          std::vector<Operand>* new_operands) {
127   analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
128   analysis::ConstantManager* constant_mgr = context()->get_constant_mgr();
129 
130   Instruction* last_index_inst = def_use_mgr->GetDef(
131       ptr_input->GetSingleWordInOperand(ptr_input->NumInOperands() - 1));
132   const analysis::Constant* last_index_constant =
133       constant_mgr->GetConstantFromInst(last_index_inst);
134 
135   Instruction* element_inst =
136       def_use_mgr->GetDef(inst->GetSingleWordInOperand(1));
137   const analysis::Constant* element_constant =
138       constant_mgr->GetConstantFromInst(element_inst);
139 
140   // Combine the last index of the AccessChain (|ptr_inst|) with the element
141   // operand of the PtrAccessChain (|inst|).
142   const bool combining_element_operands =
143       IsPtrAccessChain(inst->opcode()) &&
144       IsPtrAccessChain(ptr_input->opcode()) && ptr_input->NumInOperands() == 2;
145   uint32_t new_value_id = 0;
146   const analysis::Type* type = GetIndexedType(ptr_input);
147   if (last_index_constant && element_constant) {
148     // Combine the constants.
149     uint32_t new_value = GetConstantValue(last_index_constant) +
150                          GetConstantValue(element_constant);
151     const analysis::Constant* new_value_constant =
152         constant_mgr->GetConstant(last_index_constant->type(), {new_value});
153     Instruction* new_value_inst =
154         constant_mgr->GetDefiningInstruction(new_value_constant);
155     new_value_id = new_value_inst->result_id();
156   } else if (!type->AsStruct() || combining_element_operands) {
157     // Generate an addition of the two indices.
158     InstructionBuilder builder(
159         context(), inst,
160         IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
161     Instruction* addition = builder.AddIAdd(last_index_inst->type_id(),
162                                             last_index_inst->result_id(),
163                                             element_inst->result_id());
164     new_value_id = addition->result_id();
165   } else {
166     // Indexing into structs must be constant, so bail out here.
167     return false;
168   }
169   new_operands->push_back({SPV_OPERAND_TYPE_ID, {new_value_id}});
170   return true;
171 }
172 
CreateNewInputOperands(Instruction * ptr_input,Instruction * inst,std::vector<Operand> * new_operands)173 bool CombineAccessChains::CreateNewInputOperands(
174     Instruction* ptr_input, Instruction* inst,
175     std::vector<Operand>* new_operands) {
176   // Start by copying all the input operands of the feeder access chain.
177   for (uint32_t i = 0; i != ptr_input->NumInOperands() - 1; ++i) {
178     new_operands->push_back(ptr_input->GetInOperand(i));
179   }
180 
181   // Deal with the last index of the feeder access chain.
182   if (IsPtrAccessChain(inst->opcode())) {
183     // The last index of the feeder should be combined with the element operand
184     // of |inst|.
185     if (!CombineIndices(ptr_input, inst, new_operands)) return false;
186   } else {
187     // The indices aren't being combined so now add the last index operand of
188     // |ptr_input|.
189     new_operands->push_back(
190         ptr_input->GetInOperand(ptr_input->NumInOperands() - 1));
191   }
192 
193   // Copy the remaining index operands.
194   uint32_t starting_index = IsPtrAccessChain(inst->opcode()) ? 2 : 1;
195   for (uint32_t i = starting_index; i < inst->NumInOperands(); ++i) {
196     new_operands->push_back(inst->GetInOperand(i));
197   }
198 
199   return true;
200 }
201 
CombineAccessChain(Instruction * inst)202 bool CombineAccessChains::CombineAccessChain(Instruction* inst) {
203   assert((inst->opcode() == SpvOpPtrAccessChain ||
204           inst->opcode() == SpvOpAccessChain ||
205           inst->opcode() == SpvOpInBoundsAccessChain ||
206           inst->opcode() == SpvOpInBoundsPtrAccessChain) &&
207          "Wrong opcode. Expected an access chain.");
208 
209   Instruction* ptr_input =
210       context()->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0));
211   if (ptr_input->opcode() != SpvOpAccessChain &&
212       ptr_input->opcode() != SpvOpInBoundsAccessChain &&
213       ptr_input->opcode() != SpvOpPtrAccessChain &&
214       ptr_input->opcode() != SpvOpInBoundsPtrAccessChain) {
215     return false;
216   }
217 
218   if (Has64BitIndices(inst) || Has64BitIndices(ptr_input)) return false;
219 
220   // Handles the following cases:
221   // 1. |ptr_input| is an index-less access chain. Replace the pointer
222   //    in |inst| with |ptr_input|'s pointer.
223   // 2. |inst| is a index-less access chain. Change |inst| to an
224   //    OpCopyObject.
225   // 3. |inst| is not a pointer access chain.
226   //    |inst|'s indices are appended to |ptr_input|'s indices.
227   // 4. |ptr_input| is not pointer access chain.
228   //    |inst| is a pointer access chain.
229   //    |inst|'s element operand is combined with the last index in
230   //    |ptr_input| to form a new operand.
231   // 5. |ptr_input| is a pointer access chain.
232   //    Like the above scenario, |inst|'s element operand is combined
233   //    with |ptr_input|'s last index. This results is either a
234   //    combined element operand or combined regular index.
235 
236   // TODO(alan-baker): Support this properly. Requires analyzing the
237   // size/alignment of the type and converting the stride into an element
238   // index.
239   uint32_t array_stride = GetArrayStride(ptr_input);
240   if (array_stride != 0) return false;
241 
242   if (ptr_input->NumInOperands() == 1) {
243     // The input is effectively a no-op.
244     inst->SetInOperand(0, {ptr_input->GetSingleWordInOperand(0)});
245     context()->AnalyzeUses(inst);
246   } else if (inst->NumInOperands() == 1) {
247     // |inst| is a no-op, change it to a copy. Instruction simplification will
248     // clean it up.
249     inst->SetOpcode(SpvOpCopyObject);
250   } else {
251     std::vector<Operand> new_operands;
252     if (!CreateNewInputOperands(ptr_input, inst, &new_operands)) return false;
253 
254     // Update the instruction.
255     inst->SetOpcode(UpdateOpcode(inst->opcode(), ptr_input->opcode()));
256     inst->SetInOperands(std::move(new_operands));
257     context()->AnalyzeUses(inst);
258   }
259   return true;
260 }
261 
UpdateOpcode(SpvOp base_opcode,SpvOp input_opcode)262 SpvOp CombineAccessChains::UpdateOpcode(SpvOp base_opcode, SpvOp input_opcode) {
263   auto IsInBounds = [](SpvOp opcode) {
264     return opcode == SpvOpInBoundsPtrAccessChain ||
265            opcode == SpvOpInBoundsAccessChain;
266   };
267 
268   if (input_opcode == SpvOpInBoundsPtrAccessChain) {
269     if (!IsInBounds(base_opcode)) return SpvOpPtrAccessChain;
270   } else if (input_opcode == SpvOpInBoundsAccessChain) {
271     if (!IsInBounds(base_opcode)) return SpvOpAccessChain;
272   }
273 
274   return input_opcode;
275 }
276 
IsPtrAccessChain(SpvOp opcode)277 bool CombineAccessChains::IsPtrAccessChain(SpvOp opcode) {
278   return opcode == SpvOpPtrAccessChain || opcode == SpvOpInBoundsPtrAccessChain;
279 }
280 
Has64BitIndices(Instruction * inst)281 bool CombineAccessChains::Has64BitIndices(Instruction* inst) {
282   for (uint32_t i = 1; i < inst->NumInOperands(); ++i) {
283     Instruction* index_inst =
284         context()->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(i));
285     const analysis::Type* index_type =
286         context()->get_type_mgr()->GetType(index_inst->type_id());
287     if (!index_type->AsInteger() || index_type->AsInteger()->width() != 32)
288       return true;
289   }
290   return false;
291 }
292 
293 }  // namespace opt
294 }  // namespace spvtools
295