• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/opt/fold.h"
16 
17 #include <cassert>
18 #include <cstdint>
19 #include <vector>
20 
21 #include "source/opt/const_folding_rules.h"
22 #include "source/opt/def_use_manager.h"
23 #include "source/opt/folding_rules.h"
24 #include "source/opt/ir_context.h"
25 
26 namespace spvtools {
27 namespace opt {
namespace {

// Fallback definitions for toolchains whose <cstdint> does not provide the
// 32-bit limit macros.  Each one is only defined when the standard header did
// not already define it.

#ifndef INT32_MIN
// Written as (-MAX - 1): the literal 2147483648 does not fit in a 32-bit int,
// so "-2147483648" would be the negation of a wider (long / long long) value
// instead of an int constant.
#define INT32_MIN (-2147483647 - 1)
#endif

#ifndef INT32_MAX
#define INT32_MAX 2147483647
#endif

#ifndef UINT32_MAX
#define UINT32_MAX 0xffffffffu /* 4294967295U */
#endif

}  // namespace
43 
UnaryOperate(spv::Op opcode,uint32_t operand) const44 uint32_t InstructionFolder::UnaryOperate(spv::Op opcode,
45                                          uint32_t operand) const {
46   switch (opcode) {
47     // Arthimetics
48     case spv::Op::OpSNegate: {
49       int32_t s_operand = static_cast<int32_t>(operand);
50       if (s_operand == std::numeric_limits<int32_t>::min()) {
51         return s_operand;
52       }
53       return -s_operand;
54     }
55     case spv::Op::OpNot:
56       return ~operand;
57     case spv::Op::OpLogicalNot:
58       return !static_cast<bool>(operand);
59     case spv::Op::OpUConvert:
60       return operand;
61     case spv::Op::OpSConvert:
62       return operand;
63     default:
64       assert(false &&
65              "Unsupported unary operation for OpSpecConstantOp instruction");
66       return 0u;
67   }
68 }
69 
BinaryOperate(spv::Op opcode,uint32_t a,uint32_t b) const70 uint32_t InstructionFolder::BinaryOperate(spv::Op opcode, uint32_t a,
71                                           uint32_t b) const {
72   switch (opcode) {
73     // Arthimetics
74     case spv::Op::OpIAdd:
75       return a + b;
76     case spv::Op::OpISub:
77       return a - b;
78     case spv::Op::OpIMul:
79       return a * b;
80     case spv::Op::OpUDiv:
81       if (b != 0) {
82         return a / b;
83       } else {
84         // Dividing by 0 is undefined, so we will just pick 0.
85         return 0;
86       }
87     case spv::Op::OpSDiv:
88       if (b != 0u) {
89         return (static_cast<int32_t>(a)) / (static_cast<int32_t>(b));
90       } else {
91         // Dividing by 0 is undefined, so we will just pick 0.
92         return 0;
93       }
94     case spv::Op::OpSRem: {
95       // The sign of non-zero result comes from the first operand: a. This is
96       // guaranteed by C++11 rules for integer division operator. The division
97       // result is rounded toward zero, so the result of '%' has the sign of
98       // the first operand.
99       if (b != 0u) {
100         return static_cast<int32_t>(a) % static_cast<int32_t>(b);
101       } else {
102         // Remainder when dividing with 0 is undefined, so we will just pick 0.
103         return 0;
104       }
105     }
106     case spv::Op::OpSMod: {
107       // The sign of non-zero result comes from the second operand: b
108       if (b != 0u) {
109         int32_t rem = BinaryOperate(spv::Op::OpSRem, a, b);
110         int32_t b_prim = static_cast<int32_t>(b);
111         return (rem + b_prim) % b_prim;
112       } else {
113         // Mod with 0 is undefined, so we will just pick 0.
114         return 0;
115       }
116     }
117     case spv::Op::OpUMod:
118       if (b != 0u) {
119         return (a % b);
120       } else {
121         // Mod with 0 is undefined, so we will just pick 0.
122         return 0;
123       }
124 
125     // Shifting
126     case spv::Op::OpShiftRightLogical:
127       if (b >= 32) {
128         // This is undefined behaviour when |b| > 32.  Choose 0 for consistency.
129         // When |b| == 32, doing the shift in C++ in undefined, but the result
130         // will be 0, so just return that value.
131         return 0;
132       }
133       return a >> b;
134     case spv::Op::OpShiftRightArithmetic:
135       if (b > 32) {
136         // This is undefined behaviour.  Choose 0 for consistency.
137         return 0;
138       }
139       if (b == 32) {
140         // Doing the shift in C++ is undefined, but the result is defined in the
141         // spir-v spec.  Find that value another way.
142         if (static_cast<int32_t>(a) >= 0) {
143           return 0;
144         } else {
145           return static_cast<uint32_t>(-1);
146         }
147       }
148       return (static_cast<int32_t>(a)) >> b;
149     case spv::Op::OpShiftLeftLogical:
150       if (b >= 32) {
151         // This is undefined behaviour when |b| > 32.  Choose 0 for consistency.
152         // When |b| == 32, doing the shift in C++ in undefined, but the result
153         // will be 0, so just return that value.
154         return 0;
155       }
156       return a << b;
157 
158     // Bitwise operations
159     case spv::Op::OpBitwiseOr:
160       return a | b;
161     case spv::Op::OpBitwiseAnd:
162       return a & b;
163     case spv::Op::OpBitwiseXor:
164       return a ^ b;
165 
166     // Logical
167     case spv::Op::OpLogicalEqual:
168       return (static_cast<bool>(a)) == (static_cast<bool>(b));
169     case spv::Op::OpLogicalNotEqual:
170       return (static_cast<bool>(a)) != (static_cast<bool>(b));
171     case spv::Op::OpLogicalOr:
172       return (static_cast<bool>(a)) || (static_cast<bool>(b));
173     case spv::Op::OpLogicalAnd:
174       return (static_cast<bool>(a)) && (static_cast<bool>(b));
175 
176     // Comparison
177     case spv::Op::OpIEqual:
178       return a == b;
179     case spv::Op::OpINotEqual:
180       return a != b;
181     case spv::Op::OpULessThan:
182       return a < b;
183     case spv::Op::OpSLessThan:
184       return (static_cast<int32_t>(a)) < (static_cast<int32_t>(b));
185     case spv::Op::OpUGreaterThan:
186       return a > b;
187     case spv::Op::OpSGreaterThan:
188       return (static_cast<int32_t>(a)) > (static_cast<int32_t>(b));
189     case spv::Op::OpULessThanEqual:
190       return a <= b;
191     case spv::Op::OpSLessThanEqual:
192       return (static_cast<int32_t>(a)) <= (static_cast<int32_t>(b));
193     case spv::Op::OpUGreaterThanEqual:
194       return a >= b;
195     case spv::Op::OpSGreaterThanEqual:
196       return (static_cast<int32_t>(a)) >= (static_cast<int32_t>(b));
197     default:
198       assert(false &&
199              "Unsupported binary operation for OpSpecConstantOp instruction");
200       return 0u;
201   }
202 }
203 
TernaryOperate(spv::Op opcode,uint32_t a,uint32_t b,uint32_t c) const204 uint32_t InstructionFolder::TernaryOperate(spv::Op opcode, uint32_t a,
205                                            uint32_t b, uint32_t c) const {
206   switch (opcode) {
207     case spv::Op::OpSelect:
208       return (static_cast<bool>(a)) ? b : c;
209     default:
210       assert(false &&
211              "Unsupported ternary operation for OpSpecConstantOp instruction");
212       return 0u;
213   }
214 }
215 
OperateWords(spv::Op opcode,const std::vector<uint32_t> & operand_words) const216 uint32_t InstructionFolder::OperateWords(
217     spv::Op opcode, const std::vector<uint32_t>& operand_words) const {
218   switch (operand_words.size()) {
219     case 1:
220       return UnaryOperate(opcode, operand_words.front());
221     case 2:
222       return BinaryOperate(opcode, operand_words.front(), operand_words.back());
223     case 3:
224       return TernaryOperate(opcode, operand_words[0], operand_words[1],
225                             operand_words[2]);
226     default:
227       assert(false && "Invalid number of operands");
228       return 0;
229   }
230 }
231 
FoldInstructionInternal(Instruction * inst) const232 bool InstructionFolder::FoldInstructionInternal(Instruction* inst) const {
233   auto identity_map = [](uint32_t id) { return id; };
234   Instruction* folded_inst = FoldInstructionToConstant(inst, identity_map);
235   if (folded_inst != nullptr) {
236     inst->SetOpcode(spv::Op::OpCopyObject);
237     inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {folded_inst->result_id()}}});
238     return true;
239   }
240 
241   analysis::ConstantManager* const_manager = context_->get_constant_mgr();
242   std::vector<const analysis::Constant*> constants =
243       const_manager->GetOperandConstants(inst);
244 
245   for (const FoldingRule& rule :
246        GetFoldingRules().GetRulesForInstruction(inst)) {
247     if (rule(context_, inst, constants)) {
248       return true;
249     }
250   }
251   return false;
252 }
253 
254 // Returns the result of performing an operation on scalar constant operands.
255 // This function extracts the operand values as 32 bit words and returns the
256 // result in 32 bit word. Scalar constants with longer than 32-bit width are
257 // not accepted in this function.
FoldScalars(spv::Op opcode,const std::vector<const analysis::Constant * > & operands) const258 uint32_t InstructionFolder::FoldScalars(
259     spv::Op opcode,
260     const std::vector<const analysis::Constant*>& operands) const {
261   assert(IsFoldableOpcode(opcode) &&
262          "Unhandled instruction opcode in FoldScalars");
263   std::vector<uint32_t> operand_values_in_raw_words;
264   for (const auto& operand : operands) {
265     if (const analysis::ScalarConstant* scalar = operand->AsScalarConstant()) {
266       const auto& scalar_words = scalar->words();
267       assert(scalar_words.size() == 1 &&
268              "Scalar constants with longer than 32-bit width are not allowed "
269              "in FoldScalars()");
270       operand_values_in_raw_words.push_back(scalar_words.front());
271     } else if (operand->AsNullConstant()) {
272       operand_values_in_raw_words.push_back(0u);
273     } else {
274       assert(false &&
275              "FoldScalars() only accepts ScalarConst or NullConst type of "
276              "constant");
277     }
278   }
279   return OperateWords(opcode, operand_values_in_raw_words);
280 }
281 
FoldBinaryIntegerOpToConstant(Instruction * inst,const std::function<uint32_t (uint32_t)> & id_map,uint32_t * result) const282 bool InstructionFolder::FoldBinaryIntegerOpToConstant(
283     Instruction* inst, const std::function<uint32_t(uint32_t)>& id_map,
284     uint32_t* result) const {
285   spv::Op opcode = inst->opcode();
286   analysis::ConstantManager* const_manger = context_->get_constant_mgr();
287 
288   uint32_t ids[2];
289   const analysis::IntConstant* constants[2];
290   for (uint32_t i = 0; i < 2; i++) {
291     const Operand* operand = &inst->GetInOperand(i);
292     if (operand->type != SPV_OPERAND_TYPE_ID) {
293       return false;
294     }
295     ids[i] = id_map(operand->words[0]);
296     const analysis::Constant* constant =
297         const_manger->FindDeclaredConstant(ids[i]);
298     constants[i] = (constant != nullptr ? constant->AsIntConstant() : nullptr);
299   }
300 
301   switch (opcode) {
302     // Arthimetics
303     case spv::Op::OpIMul:
304       for (uint32_t i = 0; i < 2; i++) {
305         if (constants[i] != nullptr && constants[i]->IsZero()) {
306           *result = 0;
307           return true;
308         }
309       }
310       break;
311     case spv::Op::OpUDiv:
312     case spv::Op::OpSDiv:
313     case spv::Op::OpSRem:
314     case spv::Op::OpSMod:
315     case spv::Op::OpUMod:
316       // This changes undefined behaviour (ie divide by 0) into a 0.
317       for (uint32_t i = 0; i < 2; i++) {
318         if (constants[i] != nullptr && constants[i]->IsZero()) {
319           *result = 0;
320           return true;
321         }
322       }
323       break;
324 
325     // Shifting
326     case spv::Op::OpShiftRightLogical:
327     case spv::Op::OpShiftLeftLogical:
328       if (constants[1] != nullptr) {
329         // When shifting by a value larger than the size of the result, the
330         // result is undefined.  We are setting the undefined behaviour to a
331         // result of 0.  If the shift amount is the same as the size of the
332         // result, then the result is defined, and it 0.
333         uint32_t shift_amount = constants[1]->GetU32BitValue();
334         if (shift_amount >= 32) {
335           *result = 0;
336           return true;
337         }
338       }
339       break;
340 
341     // Bitwise operations
342     case spv::Op::OpBitwiseOr:
343       for (uint32_t i = 0; i < 2; i++) {
344         if (constants[i] != nullptr) {
345           // TODO: Change the mask against a value based on the bit width of the
346           // instruction result type.  This way we can handle say 16-bit values
347           // as well.
348           uint32_t mask = constants[i]->GetU32BitValue();
349           if (mask == 0xFFFFFFFF) {
350             *result = 0xFFFFFFFF;
351             return true;
352           }
353         }
354       }
355       break;
356     case spv::Op::OpBitwiseAnd:
357       for (uint32_t i = 0; i < 2; i++) {
358         if (constants[i] != nullptr) {
359           if (constants[i]->IsZero()) {
360             *result = 0;
361             return true;
362           }
363         }
364       }
365       break;
366 
367     // Comparison
368     case spv::Op::OpULessThan:
369       if (constants[0] != nullptr &&
370           constants[0]->GetU32BitValue() == UINT32_MAX) {
371         *result = false;
372         return true;
373       }
374       if (constants[1] != nullptr && constants[1]->GetU32BitValue() == 0) {
375         *result = false;
376         return true;
377       }
378       break;
379     case spv::Op::OpSLessThan:
380       if (constants[0] != nullptr &&
381           constants[0]->GetS32BitValue() == INT32_MAX) {
382         *result = false;
383         return true;
384       }
385       if (constants[1] != nullptr &&
386           constants[1]->GetS32BitValue() == INT32_MIN) {
387         *result = false;
388         return true;
389       }
390       break;
391     case spv::Op::OpUGreaterThan:
392       if (constants[0] != nullptr && constants[0]->IsZero()) {
393         *result = false;
394         return true;
395       }
396       if (constants[1] != nullptr &&
397           constants[1]->GetU32BitValue() == UINT32_MAX) {
398         *result = false;
399         return true;
400       }
401       break;
402     case spv::Op::OpSGreaterThan:
403       if (constants[0] != nullptr &&
404           constants[0]->GetS32BitValue() == INT32_MIN) {
405         *result = false;
406         return true;
407       }
408       if (constants[1] != nullptr &&
409           constants[1]->GetS32BitValue() == INT32_MAX) {
410         *result = false;
411         return true;
412       }
413       break;
414     case spv::Op::OpULessThanEqual:
415       if (constants[0] != nullptr && constants[0]->IsZero()) {
416         *result = true;
417         return true;
418       }
419       if (constants[1] != nullptr &&
420           constants[1]->GetU32BitValue() == UINT32_MAX) {
421         *result = true;
422         return true;
423       }
424       break;
425     case spv::Op::OpSLessThanEqual:
426       if (constants[0] != nullptr &&
427           constants[0]->GetS32BitValue() == INT32_MIN) {
428         *result = true;
429         return true;
430       }
431       if (constants[1] != nullptr &&
432           constants[1]->GetS32BitValue() == INT32_MAX) {
433         *result = true;
434         return true;
435       }
436       break;
437     case spv::Op::OpUGreaterThanEqual:
438       if (constants[0] != nullptr &&
439           constants[0]->GetU32BitValue() == UINT32_MAX) {
440         *result = true;
441         return true;
442       }
443       if (constants[1] != nullptr && constants[1]->GetU32BitValue() == 0) {
444         *result = true;
445         return true;
446       }
447       break;
448     case spv::Op::OpSGreaterThanEqual:
449       if (constants[0] != nullptr &&
450           constants[0]->GetS32BitValue() == INT32_MAX) {
451         *result = true;
452         return true;
453       }
454       if (constants[1] != nullptr &&
455           constants[1]->GetS32BitValue() == INT32_MIN) {
456         *result = true;
457         return true;
458       }
459       break;
460     default:
461       break;
462   }
463   return false;
464 }
465 
FoldBinaryBooleanOpToConstant(Instruction * inst,const std::function<uint32_t (uint32_t)> & id_map,uint32_t * result) const466 bool InstructionFolder::FoldBinaryBooleanOpToConstant(
467     Instruction* inst, const std::function<uint32_t(uint32_t)>& id_map,
468     uint32_t* result) const {
469   spv::Op opcode = inst->opcode();
470   analysis::ConstantManager* const_manger = context_->get_constant_mgr();
471 
472   uint32_t ids[2];
473   const analysis::BoolConstant* constants[2];
474   for (uint32_t i = 0; i < 2; i++) {
475     const Operand* operand = &inst->GetInOperand(i);
476     if (operand->type != SPV_OPERAND_TYPE_ID) {
477       return false;
478     }
479     ids[i] = id_map(operand->words[0]);
480     const analysis::Constant* constant =
481         const_manger->FindDeclaredConstant(ids[i]);
482     constants[i] = (constant != nullptr ? constant->AsBoolConstant() : nullptr);
483   }
484 
485   switch (opcode) {
486     // Logical
487     case spv::Op::OpLogicalOr:
488       for (uint32_t i = 0; i < 2; i++) {
489         if (constants[i] != nullptr) {
490           if (constants[i]->value()) {
491             *result = true;
492             return true;
493           }
494         }
495       }
496       break;
497     case spv::Op::OpLogicalAnd:
498       for (uint32_t i = 0; i < 2; i++) {
499         if (constants[i] != nullptr) {
500           if (!constants[i]->value()) {
501             *result = false;
502             return true;
503           }
504         }
505       }
506       break;
507 
508     default:
509       break;
510   }
511   return false;
512 }
513 
FoldIntegerOpToConstant(Instruction * inst,const std::function<uint32_t (uint32_t)> & id_map,uint32_t * result) const514 bool InstructionFolder::FoldIntegerOpToConstant(
515     Instruction* inst, const std::function<uint32_t(uint32_t)>& id_map,
516     uint32_t* result) const {
517   assert(IsFoldableOpcode(inst->opcode()) &&
518          "Unhandled instruction opcode in FoldScalars");
519   switch (inst->NumInOperands()) {
520     case 2:
521       return FoldBinaryIntegerOpToConstant(inst, id_map, result) ||
522              FoldBinaryBooleanOpToConstant(inst, id_map, result);
523     default:
524       return false;
525   }
526 }
527 
FoldVectors(spv::Op opcode,uint32_t num_dims,const std::vector<const analysis::Constant * > & operands) const528 std::vector<uint32_t> InstructionFolder::FoldVectors(
529     spv::Op opcode, uint32_t num_dims,
530     const std::vector<const analysis::Constant*>& operands) const {
531   assert(IsFoldableOpcode(opcode) &&
532          "Unhandled instruction opcode in FoldVectors");
533   std::vector<uint32_t> result;
534   for (uint32_t d = 0; d < num_dims; d++) {
535     std::vector<uint32_t> operand_values_for_one_dimension;
536     for (const auto& operand : operands) {
537       if (const analysis::VectorConstant* vector_operand =
538               operand->AsVectorConstant()) {
539         // Extract the raw value of the scalar component constants
540         // in 32-bit words here. The reason of not using FoldScalars() here
541         // is that we do not create temporary null constants as components
542         // when the vector operand is a NullConstant because Constant creation
543         // may need extra checks for the validity and that is not managed in
544         // here.
545         if (const analysis::ScalarConstant* scalar_component =
546                 vector_operand->GetComponents().at(d)->AsScalarConstant()) {
547           const auto& scalar_words = scalar_component->words();
548           assert(
549               scalar_words.size() == 1 &&
550               "Vector components with longer than 32-bit width are not allowed "
551               "in FoldVectors()");
552           operand_values_for_one_dimension.push_back(scalar_words.front());
553         } else if (operand->AsNullConstant()) {
554           operand_values_for_one_dimension.push_back(0u);
555         } else {
556           assert(false &&
557                  "VectorConst should only has ScalarConst or NullConst as "
558                  "components");
559         }
560       } else if (operand->AsNullConstant()) {
561         operand_values_for_one_dimension.push_back(0u);
562       } else {
563         assert(false &&
564                "FoldVectors() only accepts VectorConst or NullConst type of "
565                "constant");
566       }
567     }
568     result.push_back(OperateWords(opcode, operand_values_for_one_dimension));
569   }
570   return result;
571 }
572 
IsFoldableOpcode(spv::Op opcode) const573 bool InstructionFolder::IsFoldableOpcode(spv::Op opcode) const {
574   // NOTE: Extend to more opcodes as new cases are handled in the folder
575   // functions.
576   switch (opcode) {
577     case spv::Op::OpBitwiseAnd:
578     case spv::Op::OpBitwiseOr:
579     case spv::Op::OpBitwiseXor:
580     case spv::Op::OpIAdd:
581     case spv::Op::OpIEqual:
582     case spv::Op::OpIMul:
583     case spv::Op::OpINotEqual:
584     case spv::Op::OpISub:
585     case spv::Op::OpLogicalAnd:
586     case spv::Op::OpLogicalEqual:
587     case spv::Op::OpLogicalNot:
588     case spv::Op::OpLogicalNotEqual:
589     case spv::Op::OpLogicalOr:
590     case spv::Op::OpNot:
591     case spv::Op::OpSDiv:
592     case spv::Op::OpSelect:
593     case spv::Op::OpSGreaterThan:
594     case spv::Op::OpSGreaterThanEqual:
595     case spv::Op::OpShiftLeftLogical:
596     case spv::Op::OpShiftRightArithmetic:
597     case spv::Op::OpShiftRightLogical:
598     case spv::Op::OpSLessThan:
599     case spv::Op::OpSLessThanEqual:
600     case spv::Op::OpSMod:
601     case spv::Op::OpSNegate:
602     case spv::Op::OpSRem:
603     case spv::Op::OpSConvert:
604     case spv::Op::OpUConvert:
605     case spv::Op::OpUDiv:
606     case spv::Op::OpUGreaterThan:
607     case spv::Op::OpUGreaterThanEqual:
608     case spv::Op::OpULessThan:
609     case spv::Op::OpULessThanEqual:
610     case spv::Op::OpUMod:
611       return true;
612     default:
613       return false;
614   }
615 }
616 
IsFoldableConstant(const analysis::Constant * cst) const617 bool InstructionFolder::IsFoldableConstant(
618     const analysis::Constant* cst) const {
619   // Currently supported constants are 32-bit values or null constants.
620   if (const analysis::ScalarConstant* scalar = cst->AsScalarConstant())
621     return scalar->words().size() == 1;
622   else
623     return cst->AsNullConstant() != nullptr;
624 }
625 
FoldInstructionToConstant(Instruction * inst,std::function<uint32_t (uint32_t)> id_map) const626 Instruction* InstructionFolder::FoldInstructionToConstant(
627     Instruction* inst, std::function<uint32_t(uint32_t)> id_map) const {
628   analysis::ConstantManager* const_mgr = context_->get_constant_mgr();
629 
630   if (!inst->IsFoldableByFoldScalar() && !inst->IsFoldableByFoldVector() &&
631       !GetConstantFoldingRules().HasFoldingRule(inst)) {
632     return nullptr;
633   }
634   // Collect the values of the constant parameters.
635   std::vector<const analysis::Constant*> constants;
636   bool missing_constants = false;
637   inst->ForEachInId([&constants, &missing_constants, const_mgr,
638                      &id_map](uint32_t* op_id) {
639     uint32_t id = id_map(*op_id);
640     const analysis::Constant* const_op = const_mgr->FindDeclaredConstant(id);
641     if (!const_op) {
642       constants.push_back(nullptr);
643       missing_constants = true;
644     } else {
645       constants.push_back(const_op);
646     }
647   });
648 
649   const analysis::Constant* folded_const = nullptr;
650   for (auto rule : GetConstantFoldingRules().GetRulesForInstruction(inst)) {
651     folded_const = rule(context_, inst, constants);
652     if (folded_const != nullptr) {
653       Instruction* const_inst =
654           const_mgr->GetDefiningInstruction(folded_const, inst->type_id());
655       if (const_inst == nullptr) {
656         return nullptr;
657       }
658       assert(const_inst->type_id() == inst->type_id());
659       // May be a new instruction that needs to be analysed.
660       context_->UpdateDefUse(const_inst);
661       return const_inst;
662     }
663   }
664 
665   bool successful = false;
666 
667   // If all parameters are constant, fold the instruction to a constant.
668   if (inst->IsFoldableByFoldScalar()) {
669     uint32_t result_val = 0;
670 
671     if (!missing_constants) {
672       result_val = FoldScalars(inst->opcode(), constants);
673       successful = true;
674     }
675 
676     if (!successful) {
677       successful = FoldIntegerOpToConstant(inst, id_map, &result_val);
678     }
679 
680     if (successful) {
681       const analysis::Constant* result_const =
682           const_mgr->GetConstant(const_mgr->GetType(inst), {result_val});
683       Instruction* folded_inst =
684           const_mgr->GetDefiningInstruction(result_const, inst->type_id());
685       return folded_inst;
686     }
687   } else if (inst->IsFoldableByFoldVector()) {
688     std::vector<uint32_t> result_val;
689 
690     if (!missing_constants) {
691       if (Instruction* inst_type =
692               context_->get_def_use_mgr()->GetDef(inst->type_id())) {
693         result_val = FoldVectors(
694             inst->opcode(), inst_type->GetSingleWordInOperand(1), constants);
695         successful = true;
696       }
697     }
698 
699     if (successful) {
700       const analysis::Constant* result_const =
701           const_mgr->GetNumericVectorConstantWithWords(
702               const_mgr->GetType(inst)->AsVector(), result_val);
703       Instruction* folded_inst =
704           const_mgr->GetDefiningInstruction(result_const, inst->type_id());
705       return folded_inst;
706     }
707   }
708 
709   return nullptr;
710 }
711 
IsFoldableType(Instruction * type_inst) const712 bool InstructionFolder::IsFoldableType(Instruction* type_inst) const {
713   return IsFoldableScalarType(type_inst) || IsFoldableVectorType(type_inst);
714 }
715 
IsFoldableScalarType(Instruction * type_inst) const716 bool InstructionFolder::IsFoldableScalarType(Instruction* type_inst) const {
717   // Support 32-bit integers.
718   if (type_inst->opcode() == spv::Op::OpTypeInt) {
719     return type_inst->GetSingleWordInOperand(0) == 32;
720   }
721   // Support booleans.
722   if (type_inst->opcode() == spv::Op::OpTypeBool) {
723     return true;
724   }
725   // Nothing else yet.
726   return false;
727 }
728 
IsFoldableVectorType(Instruction * type_inst) const729 bool InstructionFolder::IsFoldableVectorType(Instruction* type_inst) const {
730   // Support vectors with foldable components
731   if (type_inst->opcode() == spv::Op::OpTypeVector) {
732     uint32_t component_type_id = type_inst->GetSingleWordInOperand(0);
733     Instruction* def_component_type =
734         context_->get_def_use_mgr()->GetDef(component_type_id);
735     return def_component_type != nullptr &&
736            IsFoldableScalarType(def_component_type);
737   }
738   // Nothing else yet.
739   return false;
740 }
741 
FoldInstruction(Instruction * inst) const742 bool InstructionFolder::FoldInstruction(Instruction* inst) const {
743   bool modified = false;
744   Instruction* folded_inst(inst);
745   while (folded_inst->opcode() != spv::Op::OpCopyObject &&
746          FoldInstructionInternal(&*folded_inst)) {
747     modified = true;
748   }
749   return modified;
750 }
751 
752 }  // namespace opt
753 }  // namespace spvtools
754