1 // Copyright (c) 2016 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "fold_spec_constant_op_and_composite_pass.h"
16
17 #include <algorithm>
18 #include <initializer_list>
19 #include <tuple>
20
21 #include "constants.h"
22 #include "make_unique.h"
23
24 namespace spvtools {
25 namespace opt {
26
27 namespace {
28 // Returns the single-word result from performing the given unary operation on
29 // the operand value which is passed in as a 32-bit word.
UnaryOperate(SpvOp opcode,uint32_t operand)30 uint32_t UnaryOperate(SpvOp opcode, uint32_t operand) {
31 switch (opcode) {
32 // Arthimetics
33 case SpvOp::SpvOpSNegate:
34 return -static_cast<int32_t>(operand);
35 case SpvOp::SpvOpNot:
36 return ~operand;
37 case SpvOp::SpvOpLogicalNot:
38 return !static_cast<bool>(operand);
39 default:
40 assert(false &&
41 "Unsupported unary operation for OpSpecConstantOp instruction");
42 return 0u;
43 }
44 }
45
46 // Returns the single-word result from performing the given binary operation on
47 // the operand values which are passed in as two 32-bit word.
BinaryOperate(SpvOp opcode,uint32_t a,uint32_t b)48 uint32_t BinaryOperate(SpvOp opcode, uint32_t a, uint32_t b) {
49 switch (opcode) {
50 // Arthimetics
51 case SpvOp::SpvOpIAdd:
52 return a + b;
53 case SpvOp::SpvOpISub:
54 return a - b;
55 case SpvOp::SpvOpIMul:
56 return a * b;
57 case SpvOp::SpvOpUDiv:
58 assert(b != 0);
59 return a / b;
60 case SpvOp::SpvOpSDiv:
61 assert(b != 0u);
62 return (static_cast<int32_t>(a)) / (static_cast<int32_t>(b));
63 case SpvOp::SpvOpSRem: {
64 // The sign of non-zero result comes from the first operand: a. This is
65 // guaranteed by C++11 rules for integer division operator. The division
66 // result is rounded toward zero, so the result of '%' has the sign of
67 // the first operand.
68 assert(b != 0u);
69 return static_cast<int32_t>(a) % static_cast<int32_t>(b);
70 }
71 case SpvOp::SpvOpSMod: {
72 // The sign of non-zero result comes from the second operand: b
73 assert(b != 0u);
74 int32_t rem = BinaryOperate(SpvOp::SpvOpSRem, a, b);
75 int32_t b_prim = static_cast<int32_t>(b);
76 return (rem + b_prim) % b_prim;
77 }
78 case SpvOp::SpvOpUMod:
79 assert(b != 0u);
80 return (a % b);
81
82 // Shifting
83 case SpvOp::SpvOpShiftRightLogical: {
84 return a >> b;
85 }
86 case SpvOp::SpvOpShiftRightArithmetic:
87 return (static_cast<int32_t>(a)) >> b;
88 case SpvOp::SpvOpShiftLeftLogical:
89 return a << b;
90
91 // Bitwise operations
92 case SpvOp::SpvOpBitwiseOr:
93 return a | b;
94 case SpvOp::SpvOpBitwiseAnd:
95 return a & b;
96 case SpvOp::SpvOpBitwiseXor:
97 return a ^ b;
98
99 // Logical
100 case SpvOp::SpvOpLogicalEqual:
101 return (static_cast<bool>(a)) == (static_cast<bool>(b));
102 case SpvOp::SpvOpLogicalNotEqual:
103 return (static_cast<bool>(a)) != (static_cast<bool>(b));
104 case SpvOp::SpvOpLogicalOr:
105 return (static_cast<bool>(a)) || (static_cast<bool>(b));
106 case SpvOp::SpvOpLogicalAnd:
107 return (static_cast<bool>(a)) && (static_cast<bool>(b));
108
109 // Comparison
110 case SpvOp::SpvOpIEqual:
111 return a == b;
112 case SpvOp::SpvOpINotEqual:
113 return a != b;
114 case SpvOp::SpvOpULessThan:
115 return a < b;
116 case SpvOp::SpvOpSLessThan:
117 return (static_cast<int32_t>(a)) < (static_cast<int32_t>(b));
118 case SpvOp::SpvOpUGreaterThan:
119 return a > b;
120 case SpvOp::SpvOpSGreaterThan:
121 return (static_cast<int32_t>(a)) > (static_cast<int32_t>(b));
122 case SpvOp::SpvOpULessThanEqual:
123 return a <= b;
124 case SpvOp::SpvOpSLessThanEqual:
125 return (static_cast<int32_t>(a)) <= (static_cast<int32_t>(b));
126 case SpvOp::SpvOpUGreaterThanEqual:
127 return a >= b;
128 case SpvOp::SpvOpSGreaterThanEqual:
129 return (static_cast<int32_t>(a)) >= (static_cast<int32_t>(b));
130 default:
131 assert(false &&
132 "Unsupported binary operation for OpSpecConstantOp instruction");
133 return 0u;
134 }
135 }
136
137 // Returns the single-word result from performing the given ternary operation
138 // on the operand values which are passed in as three 32-bit word.
TernaryOperate(SpvOp opcode,uint32_t a,uint32_t b,uint32_t c)139 uint32_t TernaryOperate(SpvOp opcode, uint32_t a, uint32_t b, uint32_t c) {
140 switch (opcode) {
141 case SpvOp::SpvOpSelect:
142 return (static_cast<bool>(a)) ? b : c;
143 default:
144 assert(false &&
145 "Unsupported ternary operation for OpSpecConstantOp instruction");
146 return 0u;
147 }
148 }
149
150 // Returns the single-word result from performing the given operation on the
151 // operand words. This only works with 32-bit operations and uses boolean
152 // convention that 0u is false, and anything else is boolean true.
153 // TODO(qining): Support operands other than 32-bit wide.
OperateWords(SpvOp opcode,const std::vector<uint32_t> & operand_words)154 uint32_t OperateWords(SpvOp opcode,
155 const std::vector<uint32_t>& operand_words) {
156 switch (operand_words.size()) {
157 case 1:
158 return UnaryOperate(opcode, operand_words.front());
159 case 2:
160 return BinaryOperate(opcode, operand_words.front(), operand_words.back());
161 case 3:
162 return TernaryOperate(opcode, operand_words[0], operand_words[1],
163 operand_words[2]);
164 default:
165 assert(false && "Invalid number of operands");
166 return 0;
167 }
168 }
169
170 // Returns the result of performing an operation on scalar constant operands.
171 // This function extracts the operand values as 32 bit words and returns the
172 // result in 32 bit word. Scalar constants with longer than 32-bit width are
173 // not accepted in this function.
OperateScalars(SpvOp opcode,const std::vector<analysis::Constant * > & operands)174 uint32_t OperateScalars(SpvOp opcode,
175 const std::vector<analysis::Constant*>& operands) {
176 std::vector<uint32_t> operand_values_in_raw_words;
177 for (analysis::Constant* operand : operands) {
178 if (analysis::ScalarConstant* scalar = operand->AsScalarConstant()) {
179 const auto& scalar_words = scalar->words();
180 assert(scalar_words.size() == 1 &&
181 "Scalar constants with longer than 32-bit width are not allowed "
182 "in OperateScalars()");
183 operand_values_in_raw_words.push_back(scalar_words.front());
184 } else if (operand->AsNullConstant()) {
185 operand_values_in_raw_words.push_back(0u);
186 } else {
187 assert(false &&
188 "OperateScalars() only accepts ScalarConst or NullConst type of "
189 "constant");
190 }
191 }
192 return OperateWords(opcode, operand_values_in_raw_words);
193 }
194
195 // Returns the result of performing an operation over constant vectors. This
196 // function iterates through the given vector type constant operands and
197 // calculates the result for each element of the result vector to return.
198 // Vectors with longer than 32-bit scalar components are not accepted in this
199 // function.
OperateVectors(SpvOp opcode,uint32_t num_dims,const std::vector<analysis::Constant * > & operands)200 std::vector<uint32_t> OperateVectors(
201 SpvOp opcode, uint32_t num_dims,
202 const std::vector<analysis::Constant*>& operands) {
203 std::vector<uint32_t> result;
204 for (uint32_t d = 0; d < num_dims; d++) {
205 std::vector<uint32_t> operand_values_for_one_dimension;
206 for (analysis::Constant* operand : operands) {
207 if (analysis::VectorConstant* vector_operand =
208 operand->AsVectorConstant()) {
209 // Extract the raw value of the scalar component constants
210 // in 32-bit words here. The reason of not using OperateScalars() here
211 // is that we do not create temporary null constants as components
212 // when the vector operand is a NullConstant because Constant creation
213 // may need extra checks for the validity and that is not manageed in
214 // here.
215 if (const analysis::ScalarConstant* scalar_component =
216 vector_operand->GetComponents().at(d)->AsScalarConstant()) {
217 const auto& scalar_words = scalar_component->words();
218 assert(
219 scalar_words.size() == 1 &&
220 "Vector components with longer than 32-bit width are not allowed "
221 "in OperateVectors()");
222 operand_values_for_one_dimension.push_back(scalar_words.front());
223 } else if (operand->AsNullConstant()) {
224 operand_values_for_one_dimension.push_back(0u);
225 } else {
226 assert(false &&
227 "VectorConst should only has ScalarConst or NullConst as "
228 "components");
229 }
230 } else if (operand->AsNullConstant()) {
231 operand_values_for_one_dimension.push_back(0u);
232 } else {
233 assert(false &&
234 "OperateVectors() only accepts VectorConst or NullConst type of "
235 "constant");
236 }
237 }
238 result.push_back(OperateWords(opcode, operand_values_for_one_dimension));
239 }
240 return result;
241 }
242 } // anonymous namespace
243
FoldSpecConstantOpAndCompositePass()244 FoldSpecConstantOpAndCompositePass::FoldSpecConstantOpAndCompositePass()
245 : max_id_(0),
246 module_(nullptr),
247 def_use_mgr_(nullptr),
248 type_mgr_(nullptr),
249 id_to_const_val_() {}
250
Process(ir::Module * module)251 Pass::Status FoldSpecConstantOpAndCompositePass::Process(ir::Module* module) {
252 Initialize(module);
253 return ProcessImpl(module);
254 }
255
Initialize(ir::Module * module)256 void FoldSpecConstantOpAndCompositePass::Initialize(ir::Module* module) {
257 type_mgr_.reset(new analysis::TypeManager(consumer(), *module));
258 def_use_mgr_.reset(new analysis::DefUseManager(consumer(), module));
259 for (const auto& id_def : def_use_mgr_->id_to_defs()) {
260 max_id_ = std::max(max_id_, id_def.first);
261 }
262 module_ = module;
263 };
264
ProcessImpl(ir::Module * module)265 Pass::Status FoldSpecConstantOpAndCompositePass::ProcessImpl(
266 ir::Module* module) {
267 bool modified = false;
268 // Traverse through all the constant defining instructions. For Normal
269 // Constants whose values are determined and do not depend on OpUndef
270 // instructions, records their values in two internal maps: id_to_const_val_
271 // and const_val_to_id_ so that we can use them to infer the value of Spec
272 // Constants later.
273 // For Spec Constants defined with OpSpecConstantComposite instructions, if
274 // all of their components are Normal Constants, they will be turned into
275 // Normal Constants too. For Spec Constants defined with OpSpecConstantOp
276 // instructions, we check if they only depends on Normal Constants and fold
277 // them when possible. The two maps for Normal Constants: id_to_const_val_
278 // and const_val_to_id_ will be updated along the traversal so that the new
279 // Normal Constants generated from folding can be used to fold following Spec
280 // Constants.
281 // This algorithm depends on the SSA property of SPIR-V when
282 // defining constants. The dependent constants must be defined before the
283 // dependee constants. So a dependent Spec Constant must be defined and
284 // will be processed before its dependee Spec Constant. When we encounter
285 // the dependee Spec Constants, all its dependent constants must have been
286 // processed and all its dependent Spec Constants should have been folded if
287 // possible.
288 for (ir::Module::inst_iterator inst_iter = module->types_values_begin();
289 // Need to re-evaluate the end iterator since we may modify the list of
290 // instructions in this section of the module as the process goes.
291 inst_iter != module->types_values_end(); ++inst_iter) {
292 ir::Instruction* inst = &*inst_iter;
293 // Collect constant values of normal constants and process the
294 // OpSpecConstantOp and OpSpecConstantComposite instructions if possible.
295 // The constant values will be stored in analysis::Constant instances.
296 // OpConstantSampler instruction is not collected here because it cannot be
297 // used in OpSpecConstant{Composite|Op} instructions.
298 // TODO(qining): If the constant or its type has decoration, we may need
299 // to skip it.
300 if (GetType(inst) && !GetType(inst)->decoration_empty()) continue;
301 switch (SpvOp opcode = inst->opcode()) {
302 // Records the values of Normal Constants.
303 case SpvOp::SpvOpConstantTrue:
304 case SpvOp::SpvOpConstantFalse:
305 case SpvOp::SpvOpConstant:
306 case SpvOp::SpvOpConstantNull:
307 case SpvOp::SpvOpConstantComposite:
308 case SpvOp::SpvOpSpecConstantComposite: {
309 // A Constant instance will be created if the given instruction is a
310 // Normal Constant whose value(s) are fixed. Note that for a composite
311 // Spec Constant defined with OpSpecConstantComposite instruction, if
312 // all of its components are Normal Constants already, the Spec
313 // Constant will be turned in to a Normal Constant. In that case, a
314 // Constant instance should also be created successfully and recorded
315 // in the id_to_const_val_ and const_val_to_id_ mapps.
316 if (auto const_value = CreateConstFromInst(inst)) {
317 // Need to replace the OpSpecConstantComposite instruction with a
318 // corresponding OpConstantComposite instruction.
319 if (opcode == SpvOp::SpvOpSpecConstantComposite) {
320 inst->SetOpcode(SpvOp::SpvOpConstantComposite);
321 modified = true;
322 }
323 const_val_to_id_[const_value.get()] = inst->result_id();
324 id_to_const_val_[inst->result_id()] = std::move(const_value);
325 }
326 break;
327 }
328 // For a Spec Constants defined with OpSpecConstantOp instruction, check
329 // if it only depends on Normal Constants. If so, the Spec Constant will
330 // be folded. The original Spec Constant defining instruction will be
331 // replaced by Normal Constant defining instructions, and the new Normal
332 // Constants will be added to id_to_const_val_ and const_val_to_id_ so
333 // that we can use the new Normal Constants when folding following Spec
334 // Constants.
335 case SpvOp::SpvOpSpecConstantOp:
336 modified |= ProcessOpSpecConstantOp(&inst_iter);
337 break;
338 default:
339 break;
340 }
341 }
342 return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
343 }
344
ProcessOpSpecConstantOp(ir::Module::inst_iterator * pos)345 bool FoldSpecConstantOpAndCompositePass::ProcessOpSpecConstantOp(
346 ir::Module::inst_iterator* pos) {
347 ir::Instruction* inst = &**pos;
348 ir::Instruction* folded_inst = nullptr;
349 assert(inst->GetInOperand(0).type ==
350 SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER &&
351 "The first in-operand of OpSpecContantOp instruction must be of "
352 "SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER type");
353
354 switch (static_cast<SpvOp>(inst->GetSingleWordInOperand(0))) {
355 case SpvOp::SpvOpCompositeExtract:
356 folded_inst = DoCompositeExtract(pos);
357 break;
358 case SpvOp::SpvOpVectorShuffle:
359 folded_inst = DoVectorShuffle(pos);
360 break;
361
362 case SpvOp::SpvOpCompositeInsert:
363 // Current Glslang does not generate code with OpSpecConstantOp
364 // CompositeInsert instruction, so this is not implmented so far.
365 // TODO(qining): Implement CompositeInsert case.
366 return false;
367
368 default:
369 // Component-wise operations.
370 folded_inst = DoComponentWiseOperation(pos);
371 break;
372 }
373 if (!folded_inst) return false;
374
375 // Replace the original constant with the new folded constant, kill the
376 // original constant.
377 uint32_t new_id = folded_inst->result_id();
378 uint32_t old_id = inst->result_id();
379 def_use_mgr_->ReplaceAllUsesWith(old_id, new_id);
380 def_use_mgr_->KillDef(old_id);
381 return true;
382 }
383
DoCompositeExtract(ir::Module::inst_iterator * pos)384 ir::Instruction* FoldSpecConstantOpAndCompositePass::DoCompositeExtract(
385 ir::Module::inst_iterator* pos) {
386 ir::Instruction* inst = &**pos;
387 assert(inst->NumInOperands() - 1 >= 2 &&
388 "OpSpecConstantOp CompositeExtract requires at least two non-type "
389 "non-opcode operands.");
390 assert(inst->GetInOperand(1).type == SPV_OPERAND_TYPE_ID &&
391 "The vector operand must have a SPV_OPERAND_TYPE_ID type");
392 assert(
393 inst->GetInOperand(2).type == SPV_OPERAND_TYPE_LITERAL_INTEGER &&
394 "The literal operand must have a SPV_OPERAND_TYPE_LITERAL_INTEGER type");
395
396 // Note that for OpSpecConstantOp, the second in-operand is the first id
397 // operand. The first in-operand is the spec opcode.
398 analysis::Constant* first_operand_const =
399 FindRecordedConst(inst->GetSingleWordInOperand(1));
400 if (!first_operand_const) return nullptr;
401
402 const analysis::Constant* current_const = first_operand_const;
403 for (uint32_t i = 2; i < inst->NumInOperands(); i++) {
404 uint32_t literal = inst->GetSingleWordInOperand(i);
405 if (const analysis::CompositeConstant* composite_const =
406 current_const->AsCompositeConstant()) {
407 // Case 1: current constant is a non-null composite type constant.
408 assert(literal < composite_const->GetComponents().size() &&
409 "Literal index out of bound of the composite constant");
410 current_const = composite_const->GetComponents().at(literal);
411 } else if (current_const->AsNullConstant()) {
412 // Case 2: current constant is a constant created with OpConstantNull.
413 // Because components of a NullConstant are always NullConstants, we can
414 // return early with a NullConstant in the result type.
415 return BuildInstructionAndAddToModule(CreateConst(GetType(inst), {}),
416 pos);
417 } else {
418 // Dereferencing a non-composite constant. Invalid case.
419 return nullptr;
420 }
421 }
422 return BuildInstructionAndAddToModule(current_const->Copy(), pos);
423 }
424
DoVectorShuffle(ir::Module::inst_iterator * pos)425 ir::Instruction* FoldSpecConstantOpAndCompositePass::DoVectorShuffle(
426 ir::Module::inst_iterator* pos) {
427 ir::Instruction* inst = &**pos;
428 analysis::Vector* result_vec_type = GetType(inst)->AsVector();
429 assert(inst->NumInOperands() - 1 > 2 &&
430 "OpSpecConstantOp DoVectorShuffle instruction requires more than 2 "
431 "operands (2 vector ids and at least one literal operand");
432 assert(result_vec_type &&
433 "The result of VectorShuffle must be of type vector");
434
435 // A temporary null constants that can be used as the components fo the
436 // result vector. This is needed when any one of the vector operands are null
437 // constant.
438 std::unique_ptr<analysis::Constant> null_component_constants;
439
440 // Get a concatenated vector of scalar constants. The vector should be built
441 // with the components from the first and the second operand of VectorShuffle.
442 std::vector<const analysis::Constant*> concatenated_components;
443 // Note that for OpSpecConstantOp, the second in-operand is the first id
444 // operand. The first in-operand is the spec opcode.
445 for (uint32_t i : {1, 2}) {
446 assert(inst->GetInOperand(i).type == SPV_OPERAND_TYPE_ID &&
447 "The vector operand must have a SPV_OPERAND_TYPE_ID type");
448 uint32_t operand_id = inst->GetSingleWordInOperand(i);
449 analysis::Constant* operand_const = FindRecordedConst(operand_id);
450 if (!operand_const) return nullptr;
451 const analysis::Type* operand_type = operand_const->type();
452 assert(operand_type->AsVector() &&
453 "The first two operand of VectorShuffle must be of vector type");
454 if (analysis::VectorConstant* vec_const =
455 operand_const->AsVectorConstant()) {
456 // case 1: current operand is a non-null vector constant.
457 concatenated_components.insert(concatenated_components.end(),
458 vec_const->GetComponents().begin(),
459 vec_const->GetComponents().end());
460 } else if (operand_const->AsNullConstant()) {
461 // case 2: current operand is a null vector constant. Create a temporary
462 // null scalar constant as the component.
463 if (!null_component_constants) {
464 const analysis::Type* component_type =
465 operand_type->AsVector()->element_type();
466 null_component_constants = CreateConst(component_type, {});
467 }
468 // Append the null scalar consts to the concatenated components
469 // vector.
470 concatenated_components.insert(concatenated_components.end(),
471 operand_type->AsVector()->element_count(),
472 null_component_constants.get());
473 } else {
474 // no other valid cases
475 return nullptr;
476 }
477 }
478 // Create null component constants if there are any. The component constants
479 // must be added to the module before the dependee composite constants to
480 // satisfy SSA def-use dominance.
481 if (null_component_constants) {
482 BuildInstructionAndAddToModule(std::move(null_component_constants), pos);
483 }
484 // Create the new vector constant with the selected components.
485 std::vector<const analysis::Constant*> selected_components;
486 for (uint32_t i = 3; i < inst->NumInOperands(); i++) {
487 assert(inst->GetInOperand(i).type == SPV_OPERAND_TYPE_LITERAL_INTEGER &&
488 "The literal operand must of type SPV_OPERAND_TYPE_LITERAL_INTEGER");
489 uint32_t literal = inst->GetSingleWordInOperand(i);
490 assert(literal < concatenated_components.size() &&
491 "Literal index out of bound of the concatenated vector");
492 selected_components.push_back(concatenated_components[literal]);
493 }
494 auto new_vec_const = MakeUnique<analysis::VectorConstant>(
495 result_vec_type, selected_components);
496 return BuildInstructionAndAddToModule(std::move(new_vec_const), pos);
497 }
498
499 namespace {
500 // A helper function to check the type for component wise operations. Returns
501 // true if the type:
502 // 1) is bool type;
503 // 2) is 32-bit int type;
504 // 3) is vector of bool type;
505 // 4) is vector of 32-bit integer type.
506 // Otherwise returns false.
IsValidTypeForComponentWiseOperation(const analysis::Type * type)507 bool IsValidTypeForComponentWiseOperation(const analysis::Type* type) {
508 if (type->AsBool()) {
509 return true;
510 } else if (auto* it = type->AsInteger()) {
511 if (it->width() == 32) return true;
512 } else if (auto* vt = type->AsVector()) {
513 if (vt->element_type()->AsBool())
514 return true;
515 else if (auto* vit = vt->element_type()->AsInteger()) {
516 if (vit->width() == 32) return true;
517 }
518 }
519 return false;
520 }
521 }
522
DoComponentWiseOperation(ir::Module::inst_iterator * pos)523 ir::Instruction* FoldSpecConstantOpAndCompositePass::DoComponentWiseOperation(
524 ir::Module::inst_iterator* pos) {
525 const ir::Instruction* inst = &**pos;
526 const analysis::Type* result_type = GetType(inst);
527 SpvOp spec_opcode = static_cast<SpvOp>(inst->GetSingleWordInOperand(0));
528 // Check and collect operands.
529 std::vector<analysis::Constant*> operands;
530
531 if (!std::all_of(inst->cbegin(), inst->cend(),
532 [&operands, this](const ir::Operand& o) {
533 // skip the operands that is not an id.
534 if (o.type != spv_operand_type_t::SPV_OPERAND_TYPE_ID)
535 return true;
536 uint32_t id = o.words.front();
537 if (analysis::Constant* c = FindRecordedConst(id)) {
538 if (IsValidTypeForComponentWiseOperation(c->type())) {
539 operands.push_back(c);
540 return true;
541 }
542 }
543 return false;
544 }))
545 return nullptr;
546
547 if (result_type->AsInteger() || result_type->AsBool()) {
548 // Scalar operation
549 uint32_t result_val = OperateScalars(spec_opcode, operands);
550 auto result_const = CreateConst(result_type, {result_val});
551 return BuildInstructionAndAddToModule(std::move(result_const), pos);
552 } else if (result_type->AsVector()) {
553 // Vector operation
554 const analysis::Type* element_type =
555 result_type->AsVector()->element_type();
556 uint32_t num_dims = result_type->AsVector()->element_count();
557 std::vector<uint32_t> result_vec =
558 OperateVectors(spec_opcode, num_dims, operands);
559 std::vector<const analysis::Constant*> result_vector_components;
560 for (uint32_t r : result_vec) {
561 if (auto rc = CreateConst(element_type, {r})) {
562 result_vector_components.push_back(rc.get());
563 if (!BuildInstructionAndAddToModule(std::move(rc), pos)) {
564 assert(false &&
565 "Failed to build and insert constant declaring instruction "
566 "for the given vector component constant");
567 }
568 } else {
569 assert(false && "Failed to create constants with 32-bit word");
570 }
571 }
572 auto new_vec_const = MakeUnique<analysis::VectorConstant>(
573 result_type->AsVector(), result_vector_components);
574 return BuildInstructionAndAddToModule(std::move(new_vec_const), pos);
575 } else {
576 // Cannot process invalid component wise operation. The result of component
577 // wise operation must be of integer or bool scalar or vector of
578 // integer/bool type.
579 return nullptr;
580 }
581 }
582
583 ir::Instruction*
BuildInstructionAndAddToModule(std::unique_ptr<analysis::Constant> c,ir::Module::inst_iterator * pos)584 FoldSpecConstantOpAndCompositePass::BuildInstructionAndAddToModule(
585 std::unique_ptr<analysis::Constant> c, ir::Module::inst_iterator* pos) {
586 analysis::Constant* new_const = c.get();
587 uint32_t new_id = ++max_id_;
588 module_->SetIdBound(new_id + 1);
589 const_val_to_id_[new_const] = new_id;
590 id_to_const_val_[new_id] = std::move(c);
591 auto new_inst = CreateInstruction(new_id, new_const);
592 if (!new_inst) return nullptr;
593 auto* new_inst_ptr = new_inst.get();
594 *pos = pos->InsertBefore(std::move(new_inst));
595 (*pos)++;
596 def_use_mgr_->AnalyzeInstDefUse(new_inst_ptr);
597 return new_inst_ptr;
598 }
599
600 std::unique_ptr<analysis::Constant>
CreateConstFromInst(ir::Instruction * inst)601 FoldSpecConstantOpAndCompositePass::CreateConstFromInst(ir::Instruction* inst) {
602 std::vector<uint32_t> literal_words_or_ids;
603 std::unique_ptr<analysis::Constant> new_const;
604 // Collect the constant defining literals or component ids.
605 for (uint32_t i = 0; i < inst->NumInOperands(); i++) {
606 literal_words_or_ids.insert(literal_words_or_ids.end(),
607 inst->GetInOperand(i).words.begin(),
608 inst->GetInOperand(i).words.end());
609 }
610 switch (inst->opcode()) {
611 // OpConstant{True|Flase} have the value embedded in the opcode. So they
612 // are not handled by the for-loop above. Here we add the value explicitly.
613 case SpvOp::SpvOpConstantTrue:
614 literal_words_or_ids.push_back(true);
615 break;
616 case SpvOp::SpvOpConstantFalse:
617 literal_words_or_ids.push_back(false);
618 break;
619 case SpvOp::SpvOpConstantNull:
620 case SpvOp::SpvOpConstant:
621 case SpvOp::SpvOpConstantComposite:
622 case SpvOp::SpvOpSpecConstantComposite:
623 break;
624 default:
625 return nullptr;
626 }
627 return CreateConst(GetType(inst), literal_words_or_ids);
628 }
629
FindRecordedConst(uint32_t id)630 analysis::Constant* FoldSpecConstantOpAndCompositePass::FindRecordedConst(
631 uint32_t id) {
632 auto iter = id_to_const_val_.find(id);
633 if (iter == id_to_const_val_.end()) {
634 return nullptr;
635 } else {
636 return iter->second.get();
637 }
638 }
639
FindRecordedConst(const analysis::Constant * c)640 uint32_t FoldSpecConstantOpAndCompositePass::FindRecordedConst(
641 const analysis::Constant* c) {
642 auto iter = const_val_to_id_.find(c);
643 if (iter == const_val_to_id_.end()) {
644 return 0;
645 } else {
646 return iter->second;
647 }
648 }
649
650 std::vector<const analysis::Constant*>
GetConstsFromIds(const std::vector<uint32_t> & ids)651 FoldSpecConstantOpAndCompositePass::GetConstsFromIds(
652 const std::vector<uint32_t>& ids) {
653 std::vector<const analysis::Constant*> constants;
654 for (uint32_t id : ids) {
655 if (analysis::Constant* c = FindRecordedConst(id)) {
656 constants.push_back(c);
657 } else {
658 return {};
659 }
660 }
661 return constants;
662 }
663
664 std::unique_ptr<analysis::Constant>
CreateConst(const analysis::Type * type,const std::vector<uint32_t> & literal_words_or_ids)665 FoldSpecConstantOpAndCompositePass::CreateConst(
666 const analysis::Type* type,
667 const std::vector<uint32_t>& literal_words_or_ids) {
668 std::unique_ptr<analysis::Constant> new_const;
669 if (literal_words_or_ids.size() == 0) {
670 // Constant declared with OpConstantNull
671 return MakeUnique<analysis::NullConstant>(type);
672 } else if (auto* bt = type->AsBool()) {
673 assert(literal_words_or_ids.size() == 1 &&
674 "Bool constant should be declared with one operand");
675 return MakeUnique<analysis::BoolConstant>(bt, literal_words_or_ids.front());
676 } else if (auto* it = type->AsInteger()) {
677 return MakeUnique<analysis::IntConstant>(it, literal_words_or_ids);
678 } else if (auto* ft = type->AsFloat()) {
679 return MakeUnique<analysis::FloatConstant>(ft, literal_words_or_ids);
680 } else if (auto* vt = type->AsVector()) {
681 auto components = GetConstsFromIds(literal_words_or_ids);
682 if (components.empty()) return nullptr;
683 // All components of VectorConstant must be of type Bool, Integer or Float.
684 if (!std::all_of(components.begin(), components.end(),
685 [](const analysis::Constant* c) {
686 if (c->type()->AsBool() || c->type()->AsInteger() ||
687 c->type()->AsFloat()) {
688 return true;
689 } else {
690 return false;
691 }
692 }))
693 return nullptr;
694 // All components of VectorConstant must be in the same type.
695 const auto* component_type = components.front()->type();
696 if (!std::all_of(components.begin(), components.end(),
697 [&component_type](const analysis::Constant* c) {
698 if (c->type() == component_type) return true;
699 return false;
700 }))
701 return nullptr;
702 return MakeUnique<analysis::VectorConstant>(vt, components);
703 } else if (auto* st = type->AsStruct()) {
704 auto components = GetConstsFromIds(literal_words_or_ids);
705 if (components.empty()) return nullptr;
706 return MakeUnique<analysis::StructConstant>(st, components);
707 } else if (auto* at = type->AsArray()) {
708 auto components = GetConstsFromIds(literal_words_or_ids);
709 if (components.empty()) return nullptr;
710 return MakeUnique<analysis::ArrayConstant>(at, components);
711 } else {
712 return nullptr;
713 }
714 }
715
BuildOperandsFromIds(const std::vector<uint32_t> & ids)716 std::vector<ir::Operand> BuildOperandsFromIds(
717 const std::vector<uint32_t>& ids) {
718 std::vector<ir::Operand> operands;
719 for (uint32_t id : ids) {
720 operands.emplace_back(spv_operand_type_t::SPV_OPERAND_TYPE_ID,
721 std::initializer_list<uint32_t>{id});
722 }
723 return operands;
724 }
725
726 std::unique_ptr<ir::Instruction>
CreateInstruction(uint32_t id,analysis::Constant * c)727 FoldSpecConstantOpAndCompositePass::CreateInstruction(uint32_t id,
728 analysis::Constant* c) {
729 if (c->AsNullConstant()) {
730 return MakeUnique<ir::Instruction>(SpvOp::SpvOpConstantNull,
731 type_mgr_->GetId(c->type()), id,
732 std::initializer_list<ir::Operand>{});
733 } else if (analysis::BoolConstant* bc = c->AsBoolConstant()) {
734 return MakeUnique<ir::Instruction>(
735 bc->value() ? SpvOp::SpvOpConstantTrue : SpvOp::SpvOpConstantFalse,
736 type_mgr_->GetId(c->type()), id, std::initializer_list<ir::Operand>{});
737 } else if (analysis::IntConstant* ic = c->AsIntConstant()) {
738 return MakeUnique<ir::Instruction>(
739 SpvOp::SpvOpConstant, type_mgr_->GetId(c->type()), id,
740 std::initializer_list<ir::Operand>{ir::Operand(
741 spv_operand_type_t::SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER,
742 ic->words())});
743 } else if (analysis::FloatConstant* fc = c->AsFloatConstant()) {
744 return MakeUnique<ir::Instruction>(
745 SpvOp::SpvOpConstant, type_mgr_->GetId(c->type()), id,
746 std::initializer_list<ir::Operand>{ir::Operand(
747 spv_operand_type_t::SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER,
748 fc->words())});
749 } else if (analysis::CompositeConstant* cc = c->AsCompositeConstant()) {
750 return CreateCompositeInstruction(id, cc);
751 } else {
752 return nullptr;
753 }
754 }
755
756 std::unique_ptr<ir::Instruction>
CreateCompositeInstruction(uint32_t result_id,analysis::CompositeConstant * cc)757 FoldSpecConstantOpAndCompositePass::CreateCompositeInstruction(
758 uint32_t result_id, analysis::CompositeConstant* cc) {
759 std::vector<ir::Operand> operands;
760 for (const analysis::Constant* component_const : cc->GetComponents()) {
761 uint32_t id = FindRecordedConst(component_const);
762 if (id == 0) {
763 // Cannot get the id of the component constant, while all components
764 // should have been added to the module prior to the composite constant.
765 // Cannot create OpConstantComposite instruction in this case.
766 return nullptr;
767 }
768 operands.emplace_back(spv_operand_type_t::SPV_OPERAND_TYPE_ID,
769 std::initializer_list<uint32_t>{id});
770 }
771 return MakeUnique<ir::Instruction>(SpvOp::SpvOpConstantComposite,
772 type_mgr_->GetId(cc->type()), result_id,
773 std::move(operands));
774 }
775
776 } // namespace opt
777 } // namespace spvtools
778