1 // Copyright (c) 2016 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/opt/fold_spec_constant_op_and_composite_pass.h"
16
17 #include <algorithm>
18 #include <tuple>
19
20 #include "source/opt/constants.h"
21 #include "source/util/make_unique.h"
22
23 namespace spvtools {
24 namespace opt {
25
Process()26 Pass::Status FoldSpecConstantOpAndCompositePass::Process() {
27 bool modified = false;
28 analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
29 // Traverse through all the constant defining instructions. For Normal
30 // Constants whose values are determined and do not depend on OpUndef
31 // instructions, records their values in two internal maps: id_to_const_val_
32 // and const_val_to_id_ so that we can use them to infer the value of Spec
33 // Constants later.
34 // For Spec Constants defined with OpSpecConstantComposite instructions, if
35 // all of their components are Normal Constants, they will be turned into
36 // Normal Constants too. For Spec Constants defined with OpSpecConstantOp
37 // instructions, we check if they only depends on Normal Constants and fold
38 // them when possible. The two maps for Normal Constants: id_to_const_val_
39 // and const_val_to_id_ will be updated along the traversal so that the new
40 // Normal Constants generated from folding can be used to fold following Spec
41 // Constants.
42 // This algorithm depends on the SSA property of SPIR-V when
43 // defining constants. The dependent constants must be defined before the
44 // dependee constants. So a dependent Spec Constant must be defined and
45 // will be processed before its dependee Spec Constant. When we encounter
46 // the dependee Spec Constants, all its dependent constants must have been
47 // processed and all its dependent Spec Constants should have been folded if
48 // possible.
49 Module::inst_iterator next_inst = context()->types_values_begin();
50 for (Module::inst_iterator inst_iter = next_inst;
51 // Need to re-evaluate the end iterator since we may modify the list of
52 // instructions in this section of the module as the process goes.
53 inst_iter != context()->types_values_end(); inst_iter = next_inst) {
54 ++next_inst;
55 Instruction* inst = &*inst_iter;
56 // Collect constant values of normal constants and process the
57 // OpSpecConstantOp and OpSpecConstantComposite instructions if possible.
58 // The constant values will be stored in analysis::Constant instances.
59 // OpConstantSampler instruction is not collected here because it cannot be
60 // used in OpSpecConstant{Composite|Op} instructions.
61 // TODO(qining): If the constant or its type has decoration, we may need
62 // to skip it.
63 if (const_mgr->GetType(inst) &&
64 !const_mgr->GetType(inst)->decoration_empty())
65 continue;
66 switch (spv::Op opcode = inst->opcode()) {
67 // Records the values of Normal Constants.
68 case spv::Op::OpConstantTrue:
69 case spv::Op::OpConstantFalse:
70 case spv::Op::OpConstant:
71 case spv::Op::OpConstantNull:
72 case spv::Op::OpConstantComposite:
73 case spv::Op::OpSpecConstantComposite: {
74 // A Constant instance will be created if the given instruction is a
75 // Normal Constant whose value(s) are fixed. Note that for a composite
76 // Spec Constant defined with OpSpecConstantComposite instruction, if
77 // all of its components are Normal Constants already, the Spec
78 // Constant will be turned in to a Normal Constant. In that case, a
79 // Constant instance should also be created successfully and recorded
80 // in the id_to_const_val_ and const_val_to_id_ mapps.
81 if (auto const_value = const_mgr->GetConstantFromInst(inst)) {
82 // Need to replace the OpSpecConstantComposite instruction with a
83 // corresponding OpConstantComposite instruction.
84 if (opcode == spv::Op::OpSpecConstantComposite) {
85 inst->SetOpcode(spv::Op::OpConstantComposite);
86 modified = true;
87 }
88 const_mgr->MapConstantToInst(const_value, inst);
89 }
90 break;
91 }
92 // For a Spec Constants defined with OpSpecConstantOp instruction, check
93 // if it only depends on Normal Constants. If so, the Spec Constant will
94 // be folded. The original Spec Constant defining instruction will be
95 // replaced by Normal Constant defining instructions, and the new Normal
96 // Constants will be added to id_to_const_val_ and const_val_to_id_ so
97 // that we can use the new Normal Constants when folding following Spec
98 // Constants.
99 case spv::Op::OpSpecConstantOp:
100 modified |= ProcessOpSpecConstantOp(&inst_iter);
101 break;
102 default:
103 break;
104 }
105 }
106 return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
107 }
108
ProcessOpSpecConstantOp(Module::inst_iterator * pos)109 bool FoldSpecConstantOpAndCompositePass::ProcessOpSpecConstantOp(
110 Module::inst_iterator* pos) {
111 Instruction* inst = &**pos;
112 Instruction* folded_inst = nullptr;
113 assert(inst->GetInOperand(0).type ==
114 SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER &&
115 "The first in-operand of OpSpecConstantOp instruction must be of "
116 "SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER type");
117
118 switch (static_cast<spv::Op>(inst->GetSingleWordInOperand(0))) {
119 case spv::Op::OpCompositeExtract:
120 case spv::Op::OpVectorShuffle:
121 case spv::Op::OpCompositeInsert:
122 case spv::Op::OpQuantizeToF16:
123 folded_inst = FoldWithInstructionFolder(pos);
124 break;
125 default:
126 // TODO: This should use the instruction folder as well, but some folding
127 // rules are missing.
128
129 // Component-wise operations.
130 folded_inst = DoComponentWiseOperation(pos);
131 break;
132 }
133 if (!folded_inst) return false;
134
135 // Replace the original constant with the new folded constant, kill the
136 // original constant.
137 uint32_t new_id = folded_inst->result_id();
138 uint32_t old_id = inst->result_id();
139 context()->ReplaceAllUsesWith(old_id, new_id);
140 context()->KillDef(old_id);
141 return true;
142 }
143
FoldWithInstructionFolder(Module::inst_iterator * inst_iter_ptr)144 Instruction* FoldSpecConstantOpAndCompositePass::FoldWithInstructionFolder(
145 Module::inst_iterator* inst_iter_ptr) {
146 analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
147 // If one of operands to the instruction is not a
148 // constant, then we cannot fold this spec constant.
149 for (uint32_t i = 1; i < (*inst_iter_ptr)->NumInOperands(); i++) {
150 const Operand& operand = (*inst_iter_ptr)->GetInOperand(i);
151 if (operand.type != SPV_OPERAND_TYPE_ID &&
152 operand.type != SPV_OPERAND_TYPE_OPTIONAL_ID) {
153 continue;
154 }
155 uint32_t id = operand.words[0];
156 if (const_mgr->FindDeclaredConstant(id) == nullptr) {
157 return nullptr;
158 }
159 }
160
161 // All of the operands are constant. Construct a regular version of the
162 // instruction and pass it to the instruction folder.
163 std::unique_ptr<Instruction> inst((*inst_iter_ptr)->Clone(context()));
164 inst->SetOpcode(
165 static_cast<spv::Op>((*inst_iter_ptr)->GetSingleWordInOperand(0)));
166 inst->RemoveOperand(2);
167
168 // We want the current instruction to be replaced by an |OpConstant*|
169 // instruction in the same position. We need to keep track of which constants
170 // the instruction folder creates, so we can move them into the correct place.
171 auto last_type_value_iter = (context()->types_values_end());
172 --last_type_value_iter;
173 Instruction* last_type_value = &*last_type_value_iter;
174
175 auto identity_map = [](uint32_t id) { return id; };
176 Instruction* new_const_inst =
177 context()->get_instruction_folder().FoldInstructionToConstant(
178 inst.get(), identity_map);
179
180 // new_const_inst == null indicates we cannot fold this spec constant
181 if (!new_const_inst) return nullptr;
182
183 // Get the instruction before |pos| to insert after. |pos| cannot be the
184 // first instruction in the list because its type has to come first.
185 Instruction* insert_pos = (*inst_iter_ptr)->PreviousNode();
186 assert(insert_pos != nullptr &&
187 "pos is the first instruction in the types and values.");
188 bool need_to_clone = true;
189 for (Instruction* i = last_type_value->NextNode(); i != nullptr;
190 i = last_type_value->NextNode()) {
191 if (i == new_const_inst) {
192 need_to_clone = false;
193 }
194 i->InsertAfter(insert_pos);
195 insert_pos = insert_pos->NextNode();
196 }
197
198 if (need_to_clone) {
199 new_const_inst = new_const_inst->Clone(context());
200 new_const_inst->SetResultId(TakeNextId());
201 new_const_inst->InsertAfter(insert_pos);
202 get_def_use_mgr()->AnalyzeInstDefUse(new_const_inst);
203 }
204 const_mgr->MapInst(new_const_inst);
205 return new_const_inst;
206 }
207
208 namespace {
209 // A helper function to check the type for component wise operations. Returns
210 // true if the type:
211 // 1) is bool type;
212 // 2) is 32-bit int type;
213 // 3) is vector of bool type;
214 // 4) is vector of 32-bit integer type.
215 // Otherwise returns false.
IsValidTypeForComponentWiseOperation(const analysis::Type * type)216 bool IsValidTypeForComponentWiseOperation(const analysis::Type* type) {
217 if (type->AsBool()) {
218 return true;
219 } else if (auto* it = type->AsInteger()) {
220 if (it->width() == 32) return true;
221 } else if (auto* vt = type->AsVector()) {
222 if (vt->element_type()->AsBool()) {
223 return true;
224 } else if (auto* vit = vt->element_type()->AsInteger()) {
225 if (vit->width() == 32) return true;
226 }
227 }
228 return false;
229 }
230
231 // Encodes the integer |value| of in a word vector format appropriate for
232 // representing this value as a operands for a constant definition. Performs
233 // zero-extension/sign-extension/truncation when needed, based on the signess of
234 // the given target type.
235 //
236 // Note: type |type| argument must be either Integer or Bool.
EncodeIntegerAsWords(const analysis::Type & type,uint32_t value)237 utils::SmallVector<uint32_t, 2> EncodeIntegerAsWords(const analysis::Type& type,
238 uint32_t value) {
239 const uint32_t all_ones = ~0;
240 uint32_t bit_width = 0;
241 uint32_t pad_value = 0;
242 bool result_type_signed = false;
243 if (auto* int_ty = type.AsInteger()) {
244 bit_width = int_ty->width();
245 result_type_signed = int_ty->IsSigned();
246 if (result_type_signed && static_cast<int32_t>(value) < 0) {
247 pad_value = all_ones;
248 }
249 } else if (type.AsBool()) {
250 bit_width = 1;
251 } else {
252 assert(false && "type must be Integer or Bool");
253 }
254
255 assert(bit_width > 0);
256 uint32_t first_word = value;
257 const uint32_t bits_per_word = 32;
258
259 // Truncate first_word if the |type| has width less than uint32.
260 if (bit_width < bits_per_word) {
261 const uint32_t num_high_bits_to_mask = bits_per_word - bit_width;
262 const bool is_negative_after_truncation =
263 result_type_signed &&
264 utils::IsBitAtPositionSet(first_word, bit_width - 1);
265
266 if (is_negative_after_truncation) {
267 // Truncate and sign-extend |first_word|. No padding words will be
268 // added and |pad_value| can be left as-is.
269 first_word = utils::SetHighBits(first_word, num_high_bits_to_mask);
270 } else {
271 first_word = utils::ClearHighBits(first_word, num_high_bits_to_mask);
272 }
273 }
274
275 utils::SmallVector<uint32_t, 2> words = {first_word};
276 for (uint32_t current_bit = bits_per_word; current_bit < bit_width;
277 current_bit += bits_per_word) {
278 words.push_back(pad_value);
279 }
280
281 return words;
282 }
283 } // namespace
284
DoComponentWiseOperation(Module::inst_iterator * pos)285 Instruction* FoldSpecConstantOpAndCompositePass::DoComponentWiseOperation(
286 Module::inst_iterator* pos) {
287 const Instruction* inst = &**pos;
288 analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
289 const analysis::Type* result_type = const_mgr->GetType(inst);
290 spv::Op spec_opcode = static_cast<spv::Op>(inst->GetSingleWordInOperand(0));
291 // Check and collect operands.
292 std::vector<const analysis::Constant*> operands;
293
294 if (!std::all_of(
295 inst->cbegin(), inst->cend(), [&operands, this](const Operand& o) {
296 // skip the operands that is not an id.
297 if (o.type != spv_operand_type_t::SPV_OPERAND_TYPE_ID) return true;
298 uint32_t id = o.words.front();
299 if (auto c =
300 context()->get_constant_mgr()->FindDeclaredConstant(id)) {
301 if (IsValidTypeForComponentWiseOperation(c->type())) {
302 operands.push_back(c);
303 return true;
304 }
305 }
306 return false;
307 }))
308 return nullptr;
309
310 if (result_type->AsInteger() || result_type->AsBool()) {
311 // Scalar operation
312 const uint32_t result_val =
313 context()->get_instruction_folder().FoldScalars(spec_opcode, operands);
314 auto result_const = const_mgr->GetConstant(
315 result_type, EncodeIntegerAsWords(*result_type, result_val));
316 return const_mgr->BuildInstructionAndAddToModule(result_const, pos);
317 } else if (result_type->AsVector()) {
318 // Vector operation
319 const analysis::Type* element_type =
320 result_type->AsVector()->element_type();
321 uint32_t num_dims = result_type->AsVector()->element_count();
322 std::vector<uint32_t> result_vec =
323 context()->get_instruction_folder().FoldVectors(spec_opcode, num_dims,
324 operands);
325 std::vector<const analysis::Constant*> result_vector_components;
326 for (const uint32_t r : result_vec) {
327 if (auto rc = const_mgr->GetConstant(
328 element_type, EncodeIntegerAsWords(*element_type, r))) {
329 result_vector_components.push_back(rc);
330 if (!const_mgr->BuildInstructionAndAddToModule(rc, pos)) {
331 assert(false &&
332 "Failed to build and insert constant declaring instruction "
333 "for the given vector component constant");
334 }
335 } else {
336 assert(false && "Failed to create constants with 32-bit word");
337 }
338 }
339 auto new_vec_const = MakeUnique<analysis::VectorConstant>(
340 result_type->AsVector(), result_vector_components);
341 auto reg_vec_const = const_mgr->RegisterConstant(std::move(new_vec_const));
342 return const_mgr->BuildInstructionAndAddToModule(reg_vec_const, pos);
343 } else {
344 // Cannot process invalid component wise operation. The result of component
345 // wise operation must be of integer or bool scalar or vector of
346 // integer/bool type.
347 return nullptr;
348 }
349 }
350
351 } // namespace opt
352 } // namespace spvtools
353