• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2016 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "fold_spec_constant_op_and_composite_pass.h"
16 
17 #include <algorithm>
18 #include <initializer_list>
19 #include <tuple>
20 
21 #include "constants.h"
22 #include "make_unique.h"
23 
24 namespace spvtools {
25 namespace opt {
26 
27 namespace {
28 // Returns the single-word result from performing the given unary operation on
29 // the operand value which is passed in as a 32-bit word.
UnaryOperate(SpvOp opcode,uint32_t operand)30 uint32_t UnaryOperate(SpvOp opcode, uint32_t operand) {
31   switch (opcode) {
32     // Arthimetics
33     case SpvOp::SpvOpSNegate:
34       return -static_cast<int32_t>(operand);
35     case SpvOp::SpvOpNot:
36       return ~operand;
37     case SpvOp::SpvOpLogicalNot:
38       return !static_cast<bool>(operand);
39     default:
40       assert(false &&
41              "Unsupported unary operation for OpSpecConstantOp instruction");
42       return 0u;
43   }
44 }
45 
46 // Returns the single-word result from performing the given binary operation on
47 // the operand values which are passed in as two 32-bit word.
BinaryOperate(SpvOp opcode,uint32_t a,uint32_t b)48 uint32_t BinaryOperate(SpvOp opcode, uint32_t a, uint32_t b) {
49   switch (opcode) {
50     // Arthimetics
51     case SpvOp::SpvOpIAdd:
52       return a + b;
53     case SpvOp::SpvOpISub:
54       return a - b;
55     case SpvOp::SpvOpIMul:
56       return a * b;
57     case SpvOp::SpvOpUDiv:
58       assert(b != 0);
59       return a / b;
60     case SpvOp::SpvOpSDiv:
61       assert(b != 0u);
62       return (static_cast<int32_t>(a)) / (static_cast<int32_t>(b));
63     case SpvOp::SpvOpSRem: {
64       // The sign of non-zero result comes from the first operand: a. This is
65       // guaranteed by C++11 rules for integer division operator. The division
66       // result is rounded toward zero, so the result of '%' has the sign of
67       // the first operand.
68       assert(b != 0u);
69       return static_cast<int32_t>(a) % static_cast<int32_t>(b);
70     }
71     case SpvOp::SpvOpSMod: {
72       // The sign of non-zero result comes from the second operand: b
73       assert(b != 0u);
74       int32_t rem = BinaryOperate(SpvOp::SpvOpSRem, a, b);
75       int32_t b_prim = static_cast<int32_t>(b);
76       return (rem + b_prim) % b_prim;
77     }
78     case SpvOp::SpvOpUMod:
79       assert(b != 0u);
80       return (a % b);
81 
82     // Shifting
83     case SpvOp::SpvOpShiftRightLogical: {
84       return a >> b;
85     }
86     case SpvOp::SpvOpShiftRightArithmetic:
87       return (static_cast<int32_t>(a)) >> b;
88     case SpvOp::SpvOpShiftLeftLogical:
89       return a << b;
90 
91     // Bitwise operations
92     case SpvOp::SpvOpBitwiseOr:
93       return a | b;
94     case SpvOp::SpvOpBitwiseAnd:
95       return a & b;
96     case SpvOp::SpvOpBitwiseXor:
97       return a ^ b;
98 
99     // Logical
100     case SpvOp::SpvOpLogicalEqual:
101       return (static_cast<bool>(a)) == (static_cast<bool>(b));
102     case SpvOp::SpvOpLogicalNotEqual:
103       return (static_cast<bool>(a)) != (static_cast<bool>(b));
104     case SpvOp::SpvOpLogicalOr:
105       return (static_cast<bool>(a)) || (static_cast<bool>(b));
106     case SpvOp::SpvOpLogicalAnd:
107       return (static_cast<bool>(a)) && (static_cast<bool>(b));
108 
109     // Comparison
110     case SpvOp::SpvOpIEqual:
111       return a == b;
112     case SpvOp::SpvOpINotEqual:
113       return a != b;
114     case SpvOp::SpvOpULessThan:
115       return a < b;
116     case SpvOp::SpvOpSLessThan:
117       return (static_cast<int32_t>(a)) < (static_cast<int32_t>(b));
118     case SpvOp::SpvOpUGreaterThan:
119       return a > b;
120     case SpvOp::SpvOpSGreaterThan:
121       return (static_cast<int32_t>(a)) > (static_cast<int32_t>(b));
122     case SpvOp::SpvOpULessThanEqual:
123       return a <= b;
124     case SpvOp::SpvOpSLessThanEqual:
125       return (static_cast<int32_t>(a)) <= (static_cast<int32_t>(b));
126     case SpvOp::SpvOpUGreaterThanEqual:
127       return a >= b;
128     case SpvOp::SpvOpSGreaterThanEqual:
129       return (static_cast<int32_t>(a)) >= (static_cast<int32_t>(b));
130     default:
131       assert(false &&
132              "Unsupported binary operation for OpSpecConstantOp instruction");
133       return 0u;
134   }
135 }
136 
137 // Returns the single-word result from performing the given ternary operation
138 // on the operand values which are passed in as three 32-bit word.
TernaryOperate(SpvOp opcode,uint32_t a,uint32_t b,uint32_t c)139 uint32_t TernaryOperate(SpvOp opcode, uint32_t a, uint32_t b, uint32_t c) {
140   switch (opcode) {
141     case SpvOp::SpvOpSelect:
142       return (static_cast<bool>(a)) ? b : c;
143     default:
144       assert(false &&
145              "Unsupported ternary operation for OpSpecConstantOp instruction");
146       return 0u;
147   }
148 }
149 
150 // Returns the single-word result from performing the given operation on the
151 // operand words. This only works with 32-bit operations and uses boolean
152 // convention that 0u is false, and anything else is boolean true.
153 // TODO(qining): Support operands other than 32-bit wide.
OperateWords(SpvOp opcode,const std::vector<uint32_t> & operand_words)154 uint32_t OperateWords(SpvOp opcode,
155                       const std::vector<uint32_t>& operand_words) {
156   switch (operand_words.size()) {
157     case 1:
158       return UnaryOperate(opcode, operand_words.front());
159     case 2:
160       return BinaryOperate(opcode, operand_words.front(), operand_words.back());
161     case 3:
162       return TernaryOperate(opcode, operand_words[0], operand_words[1],
163                             operand_words[2]);
164     default:
165       assert(false && "Invalid number of operands");
166       return 0;
167   }
168 }
169 
170 // Returns the result of performing an operation on scalar constant operands.
171 // This function extracts the operand values as 32 bit words and returns the
172 // result in 32 bit word. Scalar constants with longer than 32-bit width are
173 // not accepted in this function.
OperateScalars(SpvOp opcode,const std::vector<analysis::Constant * > & operands)174 uint32_t OperateScalars(SpvOp opcode,
175                         const std::vector<analysis::Constant*>& operands) {
176   std::vector<uint32_t> operand_values_in_raw_words;
177   for (analysis::Constant* operand : operands) {
178     if (analysis::ScalarConstant* scalar = operand->AsScalarConstant()) {
179       const auto& scalar_words = scalar->words();
180       assert(scalar_words.size() == 1 &&
181              "Scalar constants with longer than 32-bit width are not allowed "
182              "in OperateScalars()");
183       operand_values_in_raw_words.push_back(scalar_words.front());
184     } else if (operand->AsNullConstant()) {
185       operand_values_in_raw_words.push_back(0u);
186     } else {
187       assert(false &&
188              "OperateScalars() only accepts ScalarConst or NullConst type of "
189              "constant");
190     }
191   }
192   return OperateWords(opcode, operand_values_in_raw_words);
193 }
194 
195 // Returns the result of performing an operation over constant vectors. This
196 // function iterates through the given vector type constant operands and
197 // calculates the result for each element of the result vector to return.
198 // Vectors with longer than 32-bit scalar components are not accepted in this
199 // function.
OperateVectors(SpvOp opcode,uint32_t num_dims,const std::vector<analysis::Constant * > & operands)200 std::vector<uint32_t> OperateVectors(
201     SpvOp opcode, uint32_t num_dims,
202     const std::vector<analysis::Constant*>& operands) {
203   std::vector<uint32_t> result;
204   for (uint32_t d = 0; d < num_dims; d++) {
205     std::vector<uint32_t> operand_values_for_one_dimension;
206     for (analysis::Constant* operand : operands) {
207       if (analysis::VectorConstant* vector_operand =
208               operand->AsVectorConstant()) {
209         // Extract the raw value of the scalar component constants
210         // in 32-bit words here. The reason of not using OperateScalars() here
211         // is that we do not create temporary null constants as components
212         // when the vector operand is a NullConstant because Constant creation
213         // may need extra checks for the validity and that is not manageed in
214         // here.
215         if (const analysis::ScalarConstant* scalar_component =
216                 vector_operand->GetComponents().at(d)->AsScalarConstant()) {
217           const auto& scalar_words = scalar_component->words();
218           assert(
219               scalar_words.size() == 1 &&
220               "Vector components with longer than 32-bit width are not allowed "
221               "in OperateVectors()");
222           operand_values_for_one_dimension.push_back(scalar_words.front());
223         } else if (operand->AsNullConstant()) {
224           operand_values_for_one_dimension.push_back(0u);
225         } else {
226           assert(false &&
227                  "VectorConst should only has ScalarConst or NullConst as "
228                  "components");
229         }
230       } else if (operand->AsNullConstant()) {
231         operand_values_for_one_dimension.push_back(0u);
232       } else {
233         assert(false &&
234                "OperateVectors() only accepts VectorConst or NullConst type of "
235                "constant");
236       }
237     }
238     result.push_back(OperateWords(opcode, operand_values_for_one_dimension));
239   }
240   return result;
241 }
242 }  // anonymous namespace
243 
FoldSpecConstantOpAndCompositePass()244 FoldSpecConstantOpAndCompositePass::FoldSpecConstantOpAndCompositePass()
245     : max_id_(0),
246       module_(nullptr),
247       def_use_mgr_(nullptr),
248       type_mgr_(nullptr),
249       id_to_const_val_() {}
250 
Process(ir::Module * module)251 Pass::Status FoldSpecConstantOpAndCompositePass::Process(ir::Module* module) {
252   Initialize(module);
253   return ProcessImpl(module);
254 }
255 
Initialize(ir::Module * module)256 void FoldSpecConstantOpAndCompositePass::Initialize(ir::Module* module) {
257   type_mgr_.reset(new analysis::TypeManager(consumer(), *module));
258   def_use_mgr_.reset(new analysis::DefUseManager(consumer(), module));
259   for (const auto& id_def : def_use_mgr_->id_to_defs()) {
260     max_id_ = std::max(max_id_, id_def.first);
261   }
262   module_ = module;
263 };
264 
ProcessImpl(ir::Module * module)265 Pass::Status FoldSpecConstantOpAndCompositePass::ProcessImpl(
266     ir::Module* module) {
267   bool modified = false;
268   // Traverse through all the constant defining instructions. For Normal
269   // Constants whose values are determined and do not depend on OpUndef
270   // instructions, records their values in two internal maps: id_to_const_val_
271   // and const_val_to_id_ so that we can use them to infer the value of Spec
272   // Constants later.
273   // For Spec Constants defined with OpSpecConstantComposite instructions, if
274   // all of their components are Normal Constants, they will be turned into
275   // Normal Constants too. For Spec Constants defined with OpSpecConstantOp
276   // instructions, we check if they only depends on Normal Constants and fold
277   // them when possible. The two maps for Normal Constants: id_to_const_val_
278   // and const_val_to_id_ will be updated along the traversal so that the new
279   // Normal Constants generated from folding can be used to fold following Spec
280   // Constants.
281   // This algorithm depends on the SSA property of SPIR-V when
282   // defining constants. The dependent constants must be defined before the
283   // dependee constants. So a dependent Spec Constant must be defined and
284   // will be processed before its dependee Spec Constant. When we encounter
285   // the dependee Spec Constants, all its dependent constants must have been
286   // processed and all its dependent Spec Constants should have been folded if
287   // possible.
288   for (ir::Module::inst_iterator inst_iter = module->types_values_begin();
289        // Need to re-evaluate the end iterator since we may modify the list of
290        // instructions in this section of the module as the process goes.
291        inst_iter != module->types_values_end(); ++inst_iter) {
292     ir::Instruction* inst = &*inst_iter;
293     // Collect constant values of normal constants and process the
294     // OpSpecConstantOp and OpSpecConstantComposite instructions if possible.
295     // The constant values will be stored in analysis::Constant instances.
296     // OpConstantSampler instruction is not collected here because it cannot be
297     // used in OpSpecConstant{Composite|Op} instructions.
298     // TODO(qining): If the constant or its type has decoration, we may need
299     // to skip it.
300     if (GetType(inst) && !GetType(inst)->decoration_empty()) continue;
301     switch (SpvOp opcode = inst->opcode()) {
302       // Records the values of Normal Constants.
303       case SpvOp::SpvOpConstantTrue:
304       case SpvOp::SpvOpConstantFalse:
305       case SpvOp::SpvOpConstant:
306       case SpvOp::SpvOpConstantNull:
307       case SpvOp::SpvOpConstantComposite:
308       case SpvOp::SpvOpSpecConstantComposite: {
309         // A Constant instance will be created if the given instruction is a
310         // Normal Constant whose value(s) are fixed. Note that for a composite
311         // Spec Constant defined with OpSpecConstantComposite instruction, if
312         // all of its components are Normal Constants already, the Spec
313         // Constant will be turned in to a Normal Constant. In that case, a
314         // Constant instance should also be created successfully and recorded
315         // in the id_to_const_val_ and const_val_to_id_ mapps.
316         if (auto const_value = CreateConstFromInst(inst)) {
317           // Need to replace the OpSpecConstantComposite instruction with a
318           // corresponding OpConstantComposite instruction.
319           if (opcode == SpvOp::SpvOpSpecConstantComposite) {
320             inst->SetOpcode(SpvOp::SpvOpConstantComposite);
321             modified = true;
322           }
323           const_val_to_id_[const_value.get()] = inst->result_id();
324           id_to_const_val_[inst->result_id()] = std::move(const_value);
325         }
326         break;
327       }
328       // For a Spec Constants defined with OpSpecConstantOp instruction, check
329       // if it only depends on Normal Constants. If so, the Spec Constant will
330       // be folded. The original Spec Constant defining instruction will be
331       // replaced by Normal Constant defining instructions, and the new Normal
332       // Constants will be added to id_to_const_val_ and const_val_to_id_ so
333       // that we can use the new Normal Constants when folding following Spec
334       // Constants.
335       case SpvOp::SpvOpSpecConstantOp:
336         modified |= ProcessOpSpecConstantOp(&inst_iter);
337         break;
338       default:
339         break;
340     }
341   }
342   return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
343 }
344 
ProcessOpSpecConstantOp(ir::Module::inst_iterator * pos)345 bool FoldSpecConstantOpAndCompositePass::ProcessOpSpecConstantOp(
346     ir::Module::inst_iterator* pos) {
347   ir::Instruction* inst = &**pos;
348   ir::Instruction* folded_inst = nullptr;
349   assert(inst->GetInOperand(0).type ==
350              SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER &&
351          "The first in-operand of OpSpecContantOp instruction must be of "
352          "SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER type");
353 
354   switch (static_cast<SpvOp>(inst->GetSingleWordInOperand(0))) {
355     case SpvOp::SpvOpCompositeExtract:
356       folded_inst = DoCompositeExtract(pos);
357       break;
358     case SpvOp::SpvOpVectorShuffle:
359       folded_inst = DoVectorShuffle(pos);
360       break;
361 
362     case SpvOp::SpvOpCompositeInsert:
363       // Current Glslang does not generate code with OpSpecConstantOp
364       // CompositeInsert instruction, so this is not implmented so far.
365       // TODO(qining): Implement CompositeInsert case.
366       return false;
367 
368     default:
369       // Component-wise operations.
370       folded_inst = DoComponentWiseOperation(pos);
371       break;
372   }
373   if (!folded_inst) return false;
374 
375   // Replace the original constant with the new folded constant, kill the
376   // original constant.
377   uint32_t new_id = folded_inst->result_id();
378   uint32_t old_id = inst->result_id();
379   def_use_mgr_->ReplaceAllUsesWith(old_id, new_id);
380   def_use_mgr_->KillDef(old_id);
381   return true;
382 }
383 
DoCompositeExtract(ir::Module::inst_iterator * pos)384 ir::Instruction* FoldSpecConstantOpAndCompositePass::DoCompositeExtract(
385     ir::Module::inst_iterator* pos) {
386   ir::Instruction* inst = &**pos;
387   assert(inst->NumInOperands() - 1 >= 2 &&
388          "OpSpecConstantOp CompositeExtract requires at least two non-type "
389          "non-opcode operands.");
390   assert(inst->GetInOperand(1).type == SPV_OPERAND_TYPE_ID &&
391          "The vector operand must have a SPV_OPERAND_TYPE_ID type");
392   assert(
393       inst->GetInOperand(2).type == SPV_OPERAND_TYPE_LITERAL_INTEGER &&
394       "The literal operand must have a SPV_OPERAND_TYPE_LITERAL_INTEGER type");
395 
396   // Note that for OpSpecConstantOp, the second in-operand is the first id
397   // operand. The first in-operand is the spec opcode.
398   analysis::Constant* first_operand_const =
399       FindRecordedConst(inst->GetSingleWordInOperand(1));
400   if (!first_operand_const) return nullptr;
401 
402   const analysis::Constant* current_const = first_operand_const;
403   for (uint32_t i = 2; i < inst->NumInOperands(); i++) {
404     uint32_t literal = inst->GetSingleWordInOperand(i);
405     if (const analysis::CompositeConstant* composite_const =
406             current_const->AsCompositeConstant()) {
407       // Case 1: current constant is a non-null composite type constant.
408       assert(literal < composite_const->GetComponents().size() &&
409              "Literal index out of bound of the composite constant");
410       current_const = composite_const->GetComponents().at(literal);
411     } else if (current_const->AsNullConstant()) {
412       // Case 2: current constant is a constant created with OpConstantNull.
413       // Because components of a NullConstant are always NullConstants, we can
414       // return early with a NullConstant in the result type.
415       return BuildInstructionAndAddToModule(CreateConst(GetType(inst), {}),
416                                             pos);
417     } else {
418       // Dereferencing a non-composite constant. Invalid case.
419       return nullptr;
420     }
421   }
422   return BuildInstructionAndAddToModule(current_const->Copy(), pos);
423 }
424 
DoVectorShuffle(ir::Module::inst_iterator * pos)425 ir::Instruction* FoldSpecConstantOpAndCompositePass::DoVectorShuffle(
426     ir::Module::inst_iterator* pos) {
427   ir::Instruction* inst = &**pos;
428   analysis::Vector* result_vec_type = GetType(inst)->AsVector();
429   assert(inst->NumInOperands() - 1 > 2 &&
430          "OpSpecConstantOp DoVectorShuffle instruction requires more than 2 "
431          "operands (2 vector ids and at least one literal operand");
432   assert(result_vec_type &&
433          "The result of VectorShuffle must be of type vector");
434 
435   // A temporary null constants that can be used as the components fo the
436   // result vector. This is needed when any one of the vector operands are null
437   // constant.
438   std::unique_ptr<analysis::Constant> null_component_constants;
439 
440   // Get a concatenated vector of scalar constants. The vector should be built
441   // with the components from the first and the second operand of VectorShuffle.
442   std::vector<const analysis::Constant*> concatenated_components;
443   // Note that for OpSpecConstantOp, the second in-operand is the first id
444   // operand. The first in-operand is the spec opcode.
445   for (uint32_t i : {1, 2}) {
446     assert(inst->GetInOperand(i).type == SPV_OPERAND_TYPE_ID &&
447            "The vector operand must have a SPV_OPERAND_TYPE_ID type");
448     uint32_t operand_id = inst->GetSingleWordInOperand(i);
449     analysis::Constant* operand_const = FindRecordedConst(operand_id);
450     if (!operand_const) return nullptr;
451     const analysis::Type* operand_type = operand_const->type();
452     assert(operand_type->AsVector() &&
453            "The first two operand of VectorShuffle must be of vector type");
454     if (analysis::VectorConstant* vec_const =
455             operand_const->AsVectorConstant()) {
456       // case 1: current operand is a non-null vector constant.
457       concatenated_components.insert(concatenated_components.end(),
458                                      vec_const->GetComponents().begin(),
459                                      vec_const->GetComponents().end());
460     } else if (operand_const->AsNullConstant()) {
461       // case 2: current operand is a null vector constant. Create a temporary
462       // null scalar constant as the component.
463       if (!null_component_constants) {
464         const analysis::Type* component_type =
465             operand_type->AsVector()->element_type();
466         null_component_constants = CreateConst(component_type, {});
467       }
468       // Append the null scalar consts to the concatenated components
469       // vector.
470       concatenated_components.insert(concatenated_components.end(),
471                                      operand_type->AsVector()->element_count(),
472                                      null_component_constants.get());
473     } else {
474       // no other valid cases
475       return nullptr;
476     }
477   }
478   // Create null component constants if there are any. The component constants
479   // must be added to the module before the dependee composite constants to
480   // satisfy SSA def-use dominance.
481   if (null_component_constants) {
482     BuildInstructionAndAddToModule(std::move(null_component_constants), pos);
483   }
484   // Create the new vector constant with the selected components.
485   std::vector<const analysis::Constant*> selected_components;
486   for (uint32_t i = 3; i < inst->NumInOperands(); i++) {
487     assert(inst->GetInOperand(i).type == SPV_OPERAND_TYPE_LITERAL_INTEGER &&
488            "The literal operand must of type SPV_OPERAND_TYPE_LITERAL_INTEGER");
489     uint32_t literal = inst->GetSingleWordInOperand(i);
490     assert(literal < concatenated_components.size() &&
491            "Literal index out of bound of the concatenated vector");
492     selected_components.push_back(concatenated_components[literal]);
493   }
494   auto new_vec_const = MakeUnique<analysis::VectorConstant>(
495       result_vec_type, selected_components);
496   return BuildInstructionAndAddToModule(std::move(new_vec_const), pos);
497 }
498 
499 namespace {
500 // A helper function to check the type for component wise operations. Returns
501 // true if the type:
502 //  1) is bool type;
503 //  2) is 32-bit int type;
504 //  3) is vector of bool type;
505 //  4) is vector of 32-bit integer type.
506 // Otherwise returns false.
IsValidTypeForComponentWiseOperation(const analysis::Type * type)507 bool IsValidTypeForComponentWiseOperation(const analysis::Type* type) {
508   if (type->AsBool()) {
509     return true;
510   } else if (auto* it = type->AsInteger()) {
511     if (it->width() == 32) return true;
512   } else if (auto* vt = type->AsVector()) {
513     if (vt->element_type()->AsBool())
514       return true;
515     else if (auto* vit = vt->element_type()->AsInteger()) {
516       if (vit->width() == 32) return true;
517     }
518   }
519   return false;
520 }
521 }
522 
DoComponentWiseOperation(ir::Module::inst_iterator * pos)523 ir::Instruction* FoldSpecConstantOpAndCompositePass::DoComponentWiseOperation(
524     ir::Module::inst_iterator* pos) {
525   const ir::Instruction* inst = &**pos;
526   const analysis::Type* result_type = GetType(inst);
527   SpvOp spec_opcode = static_cast<SpvOp>(inst->GetSingleWordInOperand(0));
528   // Check and collect operands.
529   std::vector<analysis::Constant*> operands;
530 
531   if (!std::all_of(inst->cbegin(), inst->cend(),
532                    [&operands, this](const ir::Operand& o) {
533                      // skip the operands that is not an id.
534                      if (o.type != spv_operand_type_t::SPV_OPERAND_TYPE_ID)
535                        return true;
536                      uint32_t id = o.words.front();
537                      if (analysis::Constant* c = FindRecordedConst(id)) {
538                        if (IsValidTypeForComponentWiseOperation(c->type())) {
539                          operands.push_back(c);
540                          return true;
541                        }
542                      }
543                      return false;
544                    }))
545     return nullptr;
546 
547   if (result_type->AsInteger() || result_type->AsBool()) {
548     // Scalar operation
549     uint32_t result_val = OperateScalars(spec_opcode, operands);
550     auto result_const = CreateConst(result_type, {result_val});
551     return BuildInstructionAndAddToModule(std::move(result_const), pos);
552   } else if (result_type->AsVector()) {
553     // Vector operation
554     const analysis::Type* element_type =
555         result_type->AsVector()->element_type();
556     uint32_t num_dims = result_type->AsVector()->element_count();
557     std::vector<uint32_t> result_vec =
558         OperateVectors(spec_opcode, num_dims, operands);
559     std::vector<const analysis::Constant*> result_vector_components;
560     for (uint32_t r : result_vec) {
561       if (auto rc = CreateConst(element_type, {r})) {
562         result_vector_components.push_back(rc.get());
563         if (!BuildInstructionAndAddToModule(std::move(rc), pos)) {
564           assert(false &&
565                  "Failed to build and insert constant declaring instruction "
566                  "for the given vector component constant");
567         }
568       } else {
569         assert(false && "Failed to create constants with 32-bit word");
570       }
571     }
572     auto new_vec_const = MakeUnique<analysis::VectorConstant>(
573         result_type->AsVector(), result_vector_components);
574     return BuildInstructionAndAddToModule(std::move(new_vec_const), pos);
575   } else {
576     // Cannot process invalid component wise operation. The result of component
577     // wise operation must be of integer or bool scalar or vector of
578     // integer/bool type.
579     return nullptr;
580   }
581 }
582 
583 ir::Instruction*
BuildInstructionAndAddToModule(std::unique_ptr<analysis::Constant> c,ir::Module::inst_iterator * pos)584 FoldSpecConstantOpAndCompositePass::BuildInstructionAndAddToModule(
585     std::unique_ptr<analysis::Constant> c, ir::Module::inst_iterator* pos) {
586   analysis::Constant* new_const = c.get();
587   uint32_t new_id = ++max_id_;
588   module_->SetIdBound(new_id + 1);
589   const_val_to_id_[new_const] = new_id;
590   id_to_const_val_[new_id] = std::move(c);
591   auto new_inst = CreateInstruction(new_id, new_const);
592   if (!new_inst) return nullptr;
593   auto* new_inst_ptr = new_inst.get();
594   *pos = pos->InsertBefore(std::move(new_inst));
595   (*pos)++;
596   def_use_mgr_->AnalyzeInstDefUse(new_inst_ptr);
597   return new_inst_ptr;
598 }
599 
600 std::unique_ptr<analysis::Constant>
CreateConstFromInst(ir::Instruction * inst)601 FoldSpecConstantOpAndCompositePass::CreateConstFromInst(ir::Instruction* inst) {
602   std::vector<uint32_t> literal_words_or_ids;
603   std::unique_ptr<analysis::Constant> new_const;
604   // Collect the constant defining literals or component ids.
605   for (uint32_t i = 0; i < inst->NumInOperands(); i++) {
606     literal_words_or_ids.insert(literal_words_or_ids.end(),
607                                 inst->GetInOperand(i).words.begin(),
608                                 inst->GetInOperand(i).words.end());
609   }
610   switch (inst->opcode()) {
611     // OpConstant{True|Flase} have the value embedded in the opcode. So they
612     // are not handled by the for-loop above. Here we add the value explicitly.
613     case SpvOp::SpvOpConstantTrue:
614       literal_words_or_ids.push_back(true);
615       break;
616     case SpvOp::SpvOpConstantFalse:
617       literal_words_or_ids.push_back(false);
618       break;
619     case SpvOp::SpvOpConstantNull:
620     case SpvOp::SpvOpConstant:
621     case SpvOp::SpvOpConstantComposite:
622     case SpvOp::SpvOpSpecConstantComposite:
623       break;
624     default:
625       return nullptr;
626   }
627   return CreateConst(GetType(inst), literal_words_or_ids);
628 }
629 
FindRecordedConst(uint32_t id)630 analysis::Constant* FoldSpecConstantOpAndCompositePass::FindRecordedConst(
631     uint32_t id) {
632   auto iter = id_to_const_val_.find(id);
633   if (iter == id_to_const_val_.end()) {
634     return nullptr;
635   } else {
636     return iter->second.get();
637   }
638 }
639 
FindRecordedConst(const analysis::Constant * c)640 uint32_t FoldSpecConstantOpAndCompositePass::FindRecordedConst(
641     const analysis::Constant* c) {
642   auto iter = const_val_to_id_.find(c);
643   if (iter == const_val_to_id_.end()) {
644     return 0;
645   } else {
646     return iter->second;
647   }
648 }
649 
650 std::vector<const analysis::Constant*>
GetConstsFromIds(const std::vector<uint32_t> & ids)651 FoldSpecConstantOpAndCompositePass::GetConstsFromIds(
652     const std::vector<uint32_t>& ids) {
653   std::vector<const analysis::Constant*> constants;
654   for (uint32_t id : ids) {
655     if (analysis::Constant* c = FindRecordedConst(id)) {
656       constants.push_back(c);
657     } else {
658       return {};
659     }
660   }
661   return constants;
662 }
663 
664 std::unique_ptr<analysis::Constant>
CreateConst(const analysis::Type * type,const std::vector<uint32_t> & literal_words_or_ids)665 FoldSpecConstantOpAndCompositePass::CreateConst(
666     const analysis::Type* type,
667     const std::vector<uint32_t>& literal_words_or_ids) {
668   std::unique_ptr<analysis::Constant> new_const;
669   if (literal_words_or_ids.size() == 0) {
670     // Constant declared with OpConstantNull
671     return MakeUnique<analysis::NullConstant>(type);
672   } else if (auto* bt = type->AsBool()) {
673     assert(literal_words_or_ids.size() == 1 &&
674            "Bool constant should be declared with one operand");
675     return MakeUnique<analysis::BoolConstant>(bt, literal_words_or_ids.front());
676   } else if (auto* it = type->AsInteger()) {
677     return MakeUnique<analysis::IntConstant>(it, literal_words_or_ids);
678   } else if (auto* ft = type->AsFloat()) {
679     return MakeUnique<analysis::FloatConstant>(ft, literal_words_or_ids);
680   } else if (auto* vt = type->AsVector()) {
681     auto components = GetConstsFromIds(literal_words_or_ids);
682     if (components.empty()) return nullptr;
683     // All components of VectorConstant must be of type Bool, Integer or Float.
684     if (!std::all_of(components.begin(), components.end(),
685                      [](const analysis::Constant* c) {
686                        if (c->type()->AsBool() || c->type()->AsInteger() ||
687                            c->type()->AsFloat()) {
688                          return true;
689                        } else {
690                          return false;
691                        }
692                      }))
693       return nullptr;
694     // All components of VectorConstant must be in the same type.
695     const auto* component_type = components.front()->type();
696     if (!std::all_of(components.begin(), components.end(),
697                      [&component_type](const analysis::Constant* c) {
698                        if (c->type() == component_type) return true;
699                        return false;
700                      }))
701       return nullptr;
702     return MakeUnique<analysis::VectorConstant>(vt, components);
703   } else if (auto* st = type->AsStruct()) {
704     auto components = GetConstsFromIds(literal_words_or_ids);
705     if (components.empty()) return nullptr;
706     return MakeUnique<analysis::StructConstant>(st, components);
707   } else if (auto* at = type->AsArray()) {
708     auto components = GetConstsFromIds(literal_words_or_ids);
709     if (components.empty()) return nullptr;
710     return MakeUnique<analysis::ArrayConstant>(at, components);
711   } else {
712     return nullptr;
713   }
714 }
715 
BuildOperandsFromIds(const std::vector<uint32_t> & ids)716 std::vector<ir::Operand> BuildOperandsFromIds(
717     const std::vector<uint32_t>& ids) {
718   std::vector<ir::Operand> operands;
719   for (uint32_t id : ids) {
720     operands.emplace_back(spv_operand_type_t::SPV_OPERAND_TYPE_ID,
721                           std::initializer_list<uint32_t>{id});
722   }
723   return operands;
724 }
725 
726 std::unique_ptr<ir::Instruction>
CreateInstruction(uint32_t id,analysis::Constant * c)727 FoldSpecConstantOpAndCompositePass::CreateInstruction(uint32_t id,
728                                                       analysis::Constant* c) {
729   if (c->AsNullConstant()) {
730     return MakeUnique<ir::Instruction>(SpvOp::SpvOpConstantNull,
731                                        type_mgr_->GetId(c->type()), id,
732                                        std::initializer_list<ir::Operand>{});
733   } else if (analysis::BoolConstant* bc = c->AsBoolConstant()) {
734     return MakeUnique<ir::Instruction>(
735         bc->value() ? SpvOp::SpvOpConstantTrue : SpvOp::SpvOpConstantFalse,
736         type_mgr_->GetId(c->type()), id, std::initializer_list<ir::Operand>{});
737   } else if (analysis::IntConstant* ic = c->AsIntConstant()) {
738     return MakeUnique<ir::Instruction>(
739         SpvOp::SpvOpConstant, type_mgr_->GetId(c->type()), id,
740         std::initializer_list<ir::Operand>{ir::Operand(
741             spv_operand_type_t::SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER,
742             ic->words())});
743   } else if (analysis::FloatConstant* fc = c->AsFloatConstant()) {
744     return MakeUnique<ir::Instruction>(
745         SpvOp::SpvOpConstant, type_mgr_->GetId(c->type()), id,
746         std::initializer_list<ir::Operand>{ir::Operand(
747             spv_operand_type_t::SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER,
748             fc->words())});
749   } else if (analysis::CompositeConstant* cc = c->AsCompositeConstant()) {
750     return CreateCompositeInstruction(id, cc);
751   } else {
752     return nullptr;
753   }
754 }
755 
756 std::unique_ptr<ir::Instruction>
CreateCompositeInstruction(uint32_t result_id,analysis::CompositeConstant * cc)757 FoldSpecConstantOpAndCompositePass::CreateCompositeInstruction(
758     uint32_t result_id, analysis::CompositeConstant* cc) {
759   std::vector<ir::Operand> operands;
760   for (const analysis::Constant* component_const : cc->GetComponents()) {
761     uint32_t id = FindRecordedConst(component_const);
762     if (id == 0) {
763       // Cannot get the id of the component constant, while all components
764       // should have been added to the module prior to the composite constant.
765       // Cannot create OpConstantComposite instruction in this case.
766       return nullptr;
767     }
768     operands.emplace_back(spv_operand_type_t::SPV_OPERAND_TYPE_ID,
769                           std::initializer_list<uint32_t>{id});
770   }
771   return MakeUnique<ir::Instruction>(SpvOp::SpvOpConstantComposite,
772                                      type_mgr_->GetId(cc->type()), result_id,
773                                      std::move(operands));
774 }
775 
776 }  // namespace opt
777 }  // namespace spvtools
778