• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/opt/fold.h"
16 
17 #include <cassert>
18 #include <cstdint>
19 #include <vector>
20 
21 #include "source/opt/const_folding_rules.h"
22 #include "source/opt/def_use_manager.h"
23 #include "source/opt/folding_rules.h"
24 #include "source/opt/ir_context.h"
25 
26 namespace spvtools {
27 namespace opt {
namespace {

// Fallback definitions used only if <cstdint> does not already provide these
// limit macros.  INT32_MIN is spelled (-2147483647 - 1) because the literal
// 2147483648 does not fit in a 32-bit int, so "(-2147483648)" would be a
// negated long / long long rather than an int-typed constant.
#ifndef INT32_MIN
#define INT32_MIN (-2147483647 - 1)
#endif

#ifndef INT32_MAX
#define INT32_MAX 2147483647
#endif

#ifndef UINT32_MAX
#define UINT32_MAX 0xffffffffu /* 4294967295U */
#endif

}  // namespace
43 
UnaryOperate(spv::Op opcode,uint32_t operand) const44 uint32_t InstructionFolder::UnaryOperate(spv::Op opcode,
45                                          uint32_t operand) const {
46   switch (opcode) {
47     // Arthimetics
48     case spv::Op::OpSNegate: {
49       int32_t s_operand = static_cast<int32_t>(operand);
50       if (s_operand == std::numeric_limits<int32_t>::min()) {
51         return s_operand;
52       }
53       return -s_operand;
54     }
55     case spv::Op::OpNot:
56       return ~operand;
57     case spv::Op::OpLogicalNot:
58       return !static_cast<bool>(operand);
59     case spv::Op::OpUConvert:
60       return operand;
61     case spv::Op::OpSConvert:
62       return operand;
63     default:
64       assert(false &&
65              "Unsupported unary operation for OpSpecConstantOp instruction");
66       return 0u;
67   }
68 }
69 
BinaryOperate(spv::Op opcode,uint32_t a,uint32_t b) const70 uint32_t InstructionFolder::BinaryOperate(spv::Op opcode, uint32_t a,
71                                           uint32_t b) const {
72   switch (opcode) {
73     // Shifting
74     case spv::Op::OpShiftRightLogical:
75       if (b >= 32) {
76         // This is undefined behaviour when |b| > 32.  Choose 0 for consistency.
77         // When |b| == 32, doing the shift in C++ in undefined, but the result
78         // will be 0, so just return that value.
79         return 0;
80       }
81       return a >> b;
82     case spv::Op::OpShiftRightArithmetic:
83       if (b > 32) {
84         // This is undefined behaviour.  Choose 0 for consistency.
85         return 0;
86       }
87       if (b == 32) {
88         // Doing the shift in C++ is undefined, but the result is defined in the
89         // spir-v spec.  Find that value another way.
90         if (static_cast<int32_t>(a) >= 0) {
91           return 0;
92         } else {
93           return static_cast<uint32_t>(-1);
94         }
95       }
96       return (static_cast<int32_t>(a)) >> b;
97     case spv::Op::OpShiftLeftLogical:
98       if (b >= 32) {
99         // This is undefined behaviour when |b| > 32.  Choose 0 for consistency.
100         // When |b| == 32, doing the shift in C++ in undefined, but the result
101         // will be 0, so just return that value.
102         return 0;
103       }
104       return a << b;
105 
106     // Bitwise operations
107     case spv::Op::OpBitwiseOr:
108       return a | b;
109     case spv::Op::OpBitwiseAnd:
110       return a & b;
111     case spv::Op::OpBitwiseXor:
112       return a ^ b;
113 
114     // Logical
115     case spv::Op::OpLogicalEqual:
116       return (static_cast<bool>(a)) == (static_cast<bool>(b));
117     case spv::Op::OpLogicalNotEqual:
118       return (static_cast<bool>(a)) != (static_cast<bool>(b));
119     case spv::Op::OpLogicalOr:
120       return (static_cast<bool>(a)) || (static_cast<bool>(b));
121     case spv::Op::OpLogicalAnd:
122       return (static_cast<bool>(a)) && (static_cast<bool>(b));
123 
124     // Comparison
125     case spv::Op::OpIEqual:
126       return a == b;
127     case spv::Op::OpINotEqual:
128       return a != b;
129     case spv::Op::OpULessThan:
130       return a < b;
131     case spv::Op::OpSLessThan:
132       return (static_cast<int32_t>(a)) < (static_cast<int32_t>(b));
133     case spv::Op::OpUGreaterThan:
134       return a > b;
135     case spv::Op::OpSGreaterThan:
136       return (static_cast<int32_t>(a)) > (static_cast<int32_t>(b));
137     case spv::Op::OpULessThanEqual:
138       return a <= b;
139     case spv::Op::OpSLessThanEqual:
140       return (static_cast<int32_t>(a)) <= (static_cast<int32_t>(b));
141     case spv::Op::OpUGreaterThanEqual:
142       return a >= b;
143     case spv::Op::OpSGreaterThanEqual:
144       return (static_cast<int32_t>(a)) >= (static_cast<int32_t>(b));
145     default:
146       assert(false &&
147              "Unsupported binary operation for OpSpecConstantOp instruction");
148       return 0u;
149   }
150 }
151 
TernaryOperate(spv::Op opcode,uint32_t a,uint32_t b,uint32_t c) const152 uint32_t InstructionFolder::TernaryOperate(spv::Op opcode, uint32_t a,
153                                            uint32_t b, uint32_t c) const {
154   switch (opcode) {
155     case spv::Op::OpSelect:
156       return (static_cast<bool>(a)) ? b : c;
157     default:
158       assert(false &&
159              "Unsupported ternary operation for OpSpecConstantOp instruction");
160       return 0u;
161   }
162 }
163 
OperateWords(spv::Op opcode,const std::vector<uint32_t> & operand_words) const164 uint32_t InstructionFolder::OperateWords(
165     spv::Op opcode, const std::vector<uint32_t>& operand_words) const {
166   switch (operand_words.size()) {
167     case 1:
168       return UnaryOperate(opcode, operand_words.front());
169     case 2:
170       return BinaryOperate(opcode, operand_words.front(), operand_words.back());
171     case 3:
172       return TernaryOperate(opcode, operand_words[0], operand_words[1],
173                             operand_words[2]);
174     default:
175       assert(false && "Invalid number of operands");
176       return 0;
177   }
178 }
179 
FoldInstructionInternal(Instruction * inst) const180 bool InstructionFolder::FoldInstructionInternal(Instruction* inst) const {
181   auto identity_map = [](uint32_t id) { return id; };
182   Instruction* folded_inst = FoldInstructionToConstant(inst, identity_map);
183   if (folded_inst != nullptr) {
184     inst->SetOpcode(spv::Op::OpCopyObject);
185     inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {folded_inst->result_id()}}});
186     return true;
187   }
188 
189   analysis::ConstantManager* const_manager = context_->get_constant_mgr();
190   std::vector<const analysis::Constant*> constants =
191       const_manager->GetOperandConstants(inst);
192 
193   for (const FoldingRule& rule :
194        GetFoldingRules().GetRulesForInstruction(inst)) {
195     if (rule(context_, inst, constants)) {
196       return true;
197     }
198   }
199   return false;
200 }
201 
202 // Returns the result of performing an operation on scalar constant operands.
203 // This function extracts the operand values as 32 bit words and returns the
204 // result in 32 bit word. Scalar constants with longer than 32-bit width are
205 // not accepted in this function.
FoldScalars(spv::Op opcode,const std::vector<const analysis::Constant * > & operands) const206 uint32_t InstructionFolder::FoldScalars(
207     spv::Op opcode,
208     const std::vector<const analysis::Constant*>& operands) const {
209   assert(IsFoldableOpcode(opcode) &&
210          "Unhandled instruction opcode in FoldScalars");
211   std::vector<uint32_t> operand_values_in_raw_words;
212   for (const auto& operand : operands) {
213     if (const analysis::ScalarConstant* scalar = operand->AsScalarConstant()) {
214       const auto& scalar_words = scalar->words();
215       assert(scalar_words.size() == 1 &&
216              "Scalar constants with longer than 32-bit width are not allowed "
217              "in FoldScalars()");
218       operand_values_in_raw_words.push_back(scalar_words.front());
219     } else if (operand->AsNullConstant()) {
220       operand_values_in_raw_words.push_back(0u);
221     } else {
222       assert(false &&
223              "FoldScalars() only accepts ScalarConst or NullConst type of "
224              "constant");
225     }
226   }
227   return OperateWords(opcode, operand_values_in_raw_words);
228 }
229 
FoldBinaryIntegerOpToConstant(Instruction * inst,const std::function<uint32_t (uint32_t)> & id_map,uint32_t * result) const230 bool InstructionFolder::FoldBinaryIntegerOpToConstant(
231     Instruction* inst, const std::function<uint32_t(uint32_t)>& id_map,
232     uint32_t* result) const {
233   spv::Op opcode = inst->opcode();
234   analysis::ConstantManager* const_manger = context_->get_constant_mgr();
235 
236   uint32_t ids[2];
237   const analysis::IntConstant* constants[2];
238   for (uint32_t i = 0; i < 2; i++) {
239     const Operand* operand = &inst->GetInOperand(i);
240     if (operand->type != SPV_OPERAND_TYPE_ID) {
241       return false;
242     }
243     ids[i] = id_map(operand->words[0]);
244     const analysis::Constant* constant =
245         const_manger->FindDeclaredConstant(ids[i]);
246     constants[i] = (constant != nullptr ? constant->AsIntConstant() : nullptr);
247   }
248 
249   switch (opcode) {
250     // Arthimetics
251     case spv::Op::OpIMul:
252       for (uint32_t i = 0; i < 2; i++) {
253         if (constants[i] != nullptr && constants[i]->IsZero()) {
254           *result = 0;
255           return true;
256         }
257       }
258       break;
259     case spv::Op::OpUDiv:
260     case spv::Op::OpSDiv:
261     case spv::Op::OpSRem:
262     case spv::Op::OpSMod:
263     case spv::Op::OpUMod:
264       // This changes undefined behaviour (ie divide by 0) into a 0.
265       for (uint32_t i = 0; i < 2; i++) {
266         if (constants[i] != nullptr && constants[i]->IsZero()) {
267           *result = 0;
268           return true;
269         }
270       }
271       break;
272 
273     // Shifting
274     case spv::Op::OpShiftRightLogical:
275     case spv::Op::OpShiftLeftLogical:
276       if (constants[1] != nullptr) {
277         // When shifting by a value larger than the size of the result, the
278         // result is undefined.  We are setting the undefined behaviour to a
279         // result of 0.  If the shift amount is the same as the size of the
280         // result, then the result is defined, and it 0.
281         uint32_t shift_amount = constants[1]->GetU32BitValue();
282         if (shift_amount >= 32) {
283           *result = 0;
284           return true;
285         }
286       }
287       break;
288 
289     // Bitwise operations
290     case spv::Op::OpBitwiseOr:
291       for (uint32_t i = 0; i < 2; i++) {
292         if (constants[i] != nullptr) {
293           // TODO: Change the mask against a value based on the bit width of the
294           // instruction result type.  This way we can handle say 16-bit values
295           // as well.
296           uint32_t mask = constants[i]->GetU32BitValue();
297           if (mask == 0xFFFFFFFF) {
298             *result = 0xFFFFFFFF;
299             return true;
300           }
301         }
302       }
303       break;
304     case spv::Op::OpBitwiseAnd:
305       for (uint32_t i = 0; i < 2; i++) {
306         if (constants[i] != nullptr) {
307           if (constants[i]->IsZero()) {
308             *result = 0;
309             return true;
310           }
311         }
312       }
313       break;
314 
315     // Comparison
316     case spv::Op::OpULessThan:
317       if (constants[0] != nullptr &&
318           constants[0]->GetU32BitValue() == UINT32_MAX) {
319         *result = false;
320         return true;
321       }
322       if (constants[1] != nullptr && constants[1]->GetU32BitValue() == 0) {
323         *result = false;
324         return true;
325       }
326       break;
327     case spv::Op::OpSLessThan:
328       if (constants[0] != nullptr &&
329           constants[0]->GetS32BitValue() == INT32_MAX) {
330         *result = false;
331         return true;
332       }
333       if (constants[1] != nullptr &&
334           constants[1]->GetS32BitValue() == INT32_MIN) {
335         *result = false;
336         return true;
337       }
338       break;
339     case spv::Op::OpUGreaterThan:
340       if (constants[0] != nullptr && constants[0]->IsZero()) {
341         *result = false;
342         return true;
343       }
344       if (constants[1] != nullptr &&
345           constants[1]->GetU32BitValue() == UINT32_MAX) {
346         *result = false;
347         return true;
348       }
349       break;
350     case spv::Op::OpSGreaterThan:
351       if (constants[0] != nullptr &&
352           constants[0]->GetS32BitValue() == INT32_MIN) {
353         *result = false;
354         return true;
355       }
356       if (constants[1] != nullptr &&
357           constants[1]->GetS32BitValue() == INT32_MAX) {
358         *result = false;
359         return true;
360       }
361       break;
362     case spv::Op::OpULessThanEqual:
363       if (constants[0] != nullptr && constants[0]->IsZero()) {
364         *result = true;
365         return true;
366       }
367       if (constants[1] != nullptr &&
368           constants[1]->GetU32BitValue() == UINT32_MAX) {
369         *result = true;
370         return true;
371       }
372       break;
373     case spv::Op::OpSLessThanEqual:
374       if (constants[0] != nullptr &&
375           constants[0]->GetS32BitValue() == INT32_MIN) {
376         *result = true;
377         return true;
378       }
379       if (constants[1] != nullptr &&
380           constants[1]->GetS32BitValue() == INT32_MAX) {
381         *result = true;
382         return true;
383       }
384       break;
385     case spv::Op::OpUGreaterThanEqual:
386       if (constants[0] != nullptr &&
387           constants[0]->GetU32BitValue() == UINT32_MAX) {
388         *result = true;
389         return true;
390       }
391       if (constants[1] != nullptr && constants[1]->GetU32BitValue() == 0) {
392         *result = true;
393         return true;
394       }
395       break;
396     case spv::Op::OpSGreaterThanEqual:
397       if (constants[0] != nullptr &&
398           constants[0]->GetS32BitValue() == INT32_MAX) {
399         *result = true;
400         return true;
401       }
402       if (constants[1] != nullptr &&
403           constants[1]->GetS32BitValue() == INT32_MIN) {
404         *result = true;
405         return true;
406       }
407       break;
408     default:
409       break;
410   }
411   return false;
412 }
413 
FoldBinaryBooleanOpToConstant(Instruction * inst,const std::function<uint32_t (uint32_t)> & id_map,uint32_t * result) const414 bool InstructionFolder::FoldBinaryBooleanOpToConstant(
415     Instruction* inst, const std::function<uint32_t(uint32_t)>& id_map,
416     uint32_t* result) const {
417   spv::Op opcode = inst->opcode();
418   analysis::ConstantManager* const_manger = context_->get_constant_mgr();
419 
420   uint32_t ids[2];
421   const analysis::BoolConstant* constants[2];
422   for (uint32_t i = 0; i < 2; i++) {
423     const Operand* operand = &inst->GetInOperand(i);
424     if (operand->type != SPV_OPERAND_TYPE_ID) {
425       return false;
426     }
427     ids[i] = id_map(operand->words[0]);
428     const analysis::Constant* constant =
429         const_manger->FindDeclaredConstant(ids[i]);
430     constants[i] = (constant != nullptr ? constant->AsBoolConstant() : nullptr);
431   }
432 
433   switch (opcode) {
434     // Logical
435     case spv::Op::OpLogicalOr:
436       for (uint32_t i = 0; i < 2; i++) {
437         if (constants[i] != nullptr) {
438           if (constants[i]->value()) {
439             *result = true;
440             return true;
441           }
442         }
443       }
444       break;
445     case spv::Op::OpLogicalAnd:
446       for (uint32_t i = 0; i < 2; i++) {
447         if (constants[i] != nullptr) {
448           if (!constants[i]->value()) {
449             *result = false;
450             return true;
451           }
452         }
453       }
454       break;
455 
456     default:
457       break;
458   }
459   return false;
460 }
461 
FoldIntegerOpToConstant(Instruction * inst,const std::function<uint32_t (uint32_t)> & id_map,uint32_t * result) const462 bool InstructionFolder::FoldIntegerOpToConstant(
463     Instruction* inst, const std::function<uint32_t(uint32_t)>& id_map,
464     uint32_t* result) const {
465   assert(IsFoldableOpcode(inst->opcode()) &&
466          "Unhandled instruction opcode in FoldScalars");
467   switch (inst->NumInOperands()) {
468     case 2:
469       return FoldBinaryIntegerOpToConstant(inst, id_map, result) ||
470              FoldBinaryBooleanOpToConstant(inst, id_map, result);
471     default:
472       return false;
473   }
474 }
475 
FoldVectors(spv::Op opcode,uint32_t num_dims,const std::vector<const analysis::Constant * > & operands) const476 std::vector<uint32_t> InstructionFolder::FoldVectors(
477     spv::Op opcode, uint32_t num_dims,
478     const std::vector<const analysis::Constant*>& operands) const {
479   assert(IsFoldableOpcode(opcode) &&
480          "Unhandled instruction opcode in FoldVectors");
481   std::vector<uint32_t> result;
482   for (uint32_t d = 0; d < num_dims; d++) {
483     std::vector<uint32_t> operand_values_for_one_dimension;
484     for (const auto& operand : operands) {
485       if (const analysis::VectorConstant* vector_operand =
486               operand->AsVectorConstant()) {
487         // Extract the raw value of the scalar component constants
488         // in 32-bit words here. The reason of not using FoldScalars() here
489         // is that we do not create temporary null constants as components
490         // when the vector operand is a NullConstant because Constant creation
491         // may need extra checks for the validity and that is not managed in
492         // here.
493         if (const analysis::ScalarConstant* scalar_component =
494                 vector_operand->GetComponents().at(d)->AsScalarConstant()) {
495           const auto& scalar_words = scalar_component->words();
496           assert(
497               scalar_words.size() == 1 &&
498               "Vector components with longer than 32-bit width are not allowed "
499               "in FoldVectors()");
500           operand_values_for_one_dimension.push_back(scalar_words.front());
501         } else if (operand->AsNullConstant()) {
502           operand_values_for_one_dimension.push_back(0u);
503         } else {
504           assert(false &&
505                  "VectorConst should only has ScalarConst or NullConst as "
506                  "components");
507         }
508       } else if (operand->AsNullConstant()) {
509         operand_values_for_one_dimension.push_back(0u);
510       } else {
511         assert(false &&
512                "FoldVectors() only accepts VectorConst or NullConst type of "
513                "constant");
514       }
515     }
516     result.push_back(OperateWords(opcode, operand_values_for_one_dimension));
517   }
518   return result;
519 }
520 
IsFoldableOpcode(spv::Op opcode) const521 bool InstructionFolder::IsFoldableOpcode(spv::Op opcode) const {
522   // NOTE: Extend to more opcodes as new cases are handled in the folder
523   // functions.
524   switch (opcode) {
525     case spv::Op::OpBitwiseAnd:
526     case spv::Op::OpBitwiseOr:
527     case spv::Op::OpBitwiseXor:
528     case spv::Op::OpIAdd:
529     case spv::Op::OpIEqual:
530     case spv::Op::OpIMul:
531     case spv::Op::OpINotEqual:
532     case spv::Op::OpISub:
533     case spv::Op::OpLogicalAnd:
534     case spv::Op::OpLogicalEqual:
535     case spv::Op::OpLogicalNot:
536     case spv::Op::OpLogicalNotEqual:
537     case spv::Op::OpLogicalOr:
538     case spv::Op::OpNot:
539     case spv::Op::OpSDiv:
540     case spv::Op::OpSelect:
541     case spv::Op::OpSGreaterThan:
542     case spv::Op::OpSGreaterThanEqual:
543     case spv::Op::OpShiftLeftLogical:
544     case spv::Op::OpShiftRightArithmetic:
545     case spv::Op::OpShiftRightLogical:
546     case spv::Op::OpSLessThan:
547     case spv::Op::OpSLessThanEqual:
548     case spv::Op::OpSMod:
549     case spv::Op::OpSNegate:
550     case spv::Op::OpSRem:
551     case spv::Op::OpSConvert:
552     case spv::Op::OpUConvert:
553     case spv::Op::OpUDiv:
554     case spv::Op::OpUGreaterThan:
555     case spv::Op::OpUGreaterThanEqual:
556     case spv::Op::OpULessThan:
557     case spv::Op::OpULessThanEqual:
558     case spv::Op::OpUMod:
559       return true;
560     default:
561       return false;
562   }
563 }
564 
IsFoldableConstant(const analysis::Constant * cst) const565 bool InstructionFolder::IsFoldableConstant(
566     const analysis::Constant* cst) const {
567   // Currently supported constants are 32-bit values or null constants.
568   if (const analysis::ScalarConstant* scalar = cst->AsScalarConstant())
569     return scalar->words().size() == 1;
570   else
571     return cst->AsNullConstant() != nullptr;
572 }
573 
FoldInstructionToConstant(Instruction * inst,std::function<uint32_t (uint32_t)> id_map) const574 Instruction* InstructionFolder::FoldInstructionToConstant(
575     Instruction* inst, std::function<uint32_t(uint32_t)> id_map) const {
576   analysis::ConstantManager* const_mgr = context_->get_constant_mgr();
577 
578   if (!inst->IsFoldableByFoldScalar() && !inst->IsFoldableByFoldVector() &&
579       !GetConstantFoldingRules().HasFoldingRule(inst)) {
580     return nullptr;
581   }
582   // Collect the values of the constant parameters.
583   std::vector<const analysis::Constant*> constants;
584   bool missing_constants = false;
585   inst->ForEachInId([&constants, &missing_constants, const_mgr,
586                      &id_map](uint32_t* op_id) {
587     uint32_t id = id_map(*op_id);
588     const analysis::Constant* const_op = const_mgr->FindDeclaredConstant(id);
589     if (!const_op) {
590       constants.push_back(nullptr);
591       missing_constants = true;
592     } else {
593       constants.push_back(const_op);
594     }
595   });
596 
597   const analysis::Constant* folded_const = nullptr;
598   for (auto rule : GetConstantFoldingRules().GetRulesForInstruction(inst)) {
599     folded_const = rule(context_, inst, constants);
600     if (folded_const != nullptr) {
601       Instruction* const_inst =
602           const_mgr->GetDefiningInstruction(folded_const, inst->type_id());
603       if (const_inst == nullptr) {
604         return nullptr;
605       }
606       assert(const_inst->type_id() == inst->type_id());
607       // May be a new instruction that needs to be analysed.
608       context_->UpdateDefUse(const_inst);
609       return const_inst;
610     }
611   }
612 
613   bool successful = false;
614 
615   // If all parameters are constant, fold the instruction to a constant.
616   if (inst->IsFoldableByFoldScalar()) {
617     uint32_t result_val = 0;
618 
619     if (!missing_constants) {
620       result_val = FoldScalars(inst->opcode(), constants);
621       successful = true;
622     }
623 
624     if (!successful) {
625       successful = FoldIntegerOpToConstant(inst, id_map, &result_val);
626     }
627 
628     if (successful) {
629       const analysis::Constant* result_const =
630           const_mgr->GetConstant(const_mgr->GetType(inst), {result_val});
631       Instruction* folded_inst =
632           const_mgr->GetDefiningInstruction(result_const, inst->type_id());
633       return folded_inst;
634     }
635   } else if (inst->IsFoldableByFoldVector()) {
636     std::vector<uint32_t> result_val;
637 
638     if (!missing_constants) {
639       if (Instruction* inst_type =
640               context_->get_def_use_mgr()->GetDef(inst->type_id())) {
641         result_val = FoldVectors(
642             inst->opcode(), inst_type->GetSingleWordInOperand(1), constants);
643         successful = true;
644       }
645     }
646 
647     if (successful) {
648       const analysis::Constant* result_const =
649           const_mgr->GetNumericVectorConstantWithWords(
650               const_mgr->GetType(inst)->AsVector(), result_val);
651       Instruction* folded_inst =
652           const_mgr->GetDefiningInstruction(result_const, inst->type_id());
653       return folded_inst;
654     }
655   }
656 
657   return nullptr;
658 }
659 
IsFoldableType(Instruction * type_inst) const660 bool InstructionFolder::IsFoldableType(Instruction* type_inst) const {
661   return IsFoldableScalarType(type_inst) || IsFoldableVectorType(type_inst);
662 }
663 
IsFoldableScalarType(Instruction * type_inst) const664 bool InstructionFolder::IsFoldableScalarType(Instruction* type_inst) const {
665   // Support 32-bit integers.
666   if (type_inst->opcode() == spv::Op::OpTypeInt) {
667     return type_inst->GetSingleWordInOperand(0) == 32;
668   }
669   // Support booleans.
670   if (type_inst->opcode() == spv::Op::OpTypeBool) {
671     return true;
672   }
673   // Nothing else yet.
674   return false;
675 }
676 
IsFoldableVectorType(Instruction * type_inst) const677 bool InstructionFolder::IsFoldableVectorType(Instruction* type_inst) const {
678   // Support vectors with foldable components
679   if (type_inst->opcode() == spv::Op::OpTypeVector) {
680     uint32_t component_type_id = type_inst->GetSingleWordInOperand(0);
681     Instruction* def_component_type =
682         context_->get_def_use_mgr()->GetDef(component_type_id);
683     return def_component_type != nullptr &&
684            IsFoldableScalarType(def_component_type);
685   }
686   // Nothing else yet.
687   return false;
688 }
689 
FoldInstruction(Instruction * inst) const690 bool InstructionFolder::FoldInstruction(Instruction* inst) const {
691   bool modified = false;
692   Instruction* folded_inst(inst);
693   while (folded_inst->opcode() != spv::Op::OpCopyObject &&
694          FoldInstructionInternal(&*folded_inst)) {
695     modified = true;
696   }
697   return modified;
698 }
699 
700 }  // namespace opt
701 }  // namespace spvtools
702