• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2018 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/opt/folding_rules.h"
16 
17 #include <limits>
18 #include <memory>
19 #include <utility>
20 
21 #include "ir_builder.h"
22 #include "source/latest_version_glsl_std_450_header.h"
23 #include "source/opt/ir_context.h"
24 
25 namespace spvtools {
26 namespace opt {
27 namespace {
28 
// In-operand indices (operand positions that do not count the result type or
// result id) used by the folding rules in this file. Each name spells out the
// instruction and the operand it selects.
constexpr uint32_t kExtractCompositeIdInIdx = 0;  // OpCompositeExtract: Composite
constexpr uint32_t kInsertObjectIdInIdx = 0;      // OpCompositeInsert: Object
constexpr uint32_t kInsertCompositeIdInIdx = 1;   // OpCompositeInsert: Composite
constexpr uint32_t kExtInstSetIdInIdx = 0;        // OpExtInst: Set
constexpr uint32_t kExtInstInstructionInIdx = 1;  // OpExtInst: Instruction
// FMix arguments follow the OpExtInst set and instruction operands.
constexpr uint32_t kFMixXIdInIdx = 2;
constexpr uint32_t kFMixYIdInIdx = 3;
constexpr uint32_t kFMixAIdInIdx = 4;
constexpr uint32_t kStoreObjectInIdx = 1;  // OpStore: Object
38 
39 // Some image instructions may contain an "image operands" argument.
40 // Returns the operand index for the "image operands".
41 // Returns -1 if the instruction does not have image operands.
ImageOperandsMaskInOperandIndex(Instruction * inst)42 int32_t ImageOperandsMaskInOperandIndex(Instruction* inst) {
43   const auto opcode = inst->opcode();
44   switch (opcode) {
45     case SpvOpImageSampleImplicitLod:
46     case SpvOpImageSampleExplicitLod:
47     case SpvOpImageSampleProjImplicitLod:
48     case SpvOpImageSampleProjExplicitLod:
49     case SpvOpImageFetch:
50     case SpvOpImageRead:
51     case SpvOpImageSparseSampleImplicitLod:
52     case SpvOpImageSparseSampleExplicitLod:
53     case SpvOpImageSparseSampleProjImplicitLod:
54     case SpvOpImageSparseSampleProjExplicitLod:
55     case SpvOpImageSparseFetch:
56     case SpvOpImageSparseRead:
57       return inst->NumOperands() > 4 ? 2 : -1;
58     case SpvOpImageSampleDrefImplicitLod:
59     case SpvOpImageSampleDrefExplicitLod:
60     case SpvOpImageSampleProjDrefImplicitLod:
61     case SpvOpImageSampleProjDrefExplicitLod:
62     case SpvOpImageGather:
63     case SpvOpImageDrefGather:
64     case SpvOpImageSparseSampleDrefImplicitLod:
65     case SpvOpImageSparseSampleDrefExplicitLod:
66     case SpvOpImageSparseSampleProjDrefImplicitLod:
67     case SpvOpImageSparseSampleProjDrefExplicitLod:
68     case SpvOpImageSparseGather:
69     case SpvOpImageSparseDrefGather:
70       return inst->NumOperands() > 5 ? 3 : -1;
71     case SpvOpImageWrite:
72       return inst->NumOperands() > 3 ? 3 : -1;
73     default:
74       return -1;
75   }
76 }
77 
78 // Returns the element width of |type|.
ElementWidth(const analysis::Type * type)79 uint32_t ElementWidth(const analysis::Type* type) {
80   if (const analysis::Vector* vec_type = type->AsVector()) {
81     return ElementWidth(vec_type->element_type());
82   } else if (const analysis::Float* float_type = type->AsFloat()) {
83     return float_type->width();
84   } else {
85     assert(type->AsInteger());
86     return type->AsInteger()->width();
87   }
88 }
89 
90 // Returns true if |type| is Float or a vector of Float.
HasFloatingPoint(const analysis::Type * type)91 bool HasFloatingPoint(const analysis::Type* type) {
92   if (type->AsFloat()) {
93     return true;
94   } else if (const analysis::Vector* vec_type = type->AsVector()) {
95     return vec_type->element_type()->AsFloat() != nullptr;
96   }
97 
98   return false;
99 }
100 
// Returns true if the floating-point value |val| is usable as a folded
// constant, i.e. it is neither NaN, infinite, nor subnormal. Zeros and normal
// numbers are valid.
template <typename T>
bool IsValidResult(T val) {
  const int category = std::fpclassify(val);
  return category != FP_NAN && category != FP_INFINITE &&
         category != FP_SUBNORMAL;
}
114 
ConstInput(const std::vector<const analysis::Constant * > & constants)115 const analysis::Constant* ConstInput(
116     const std::vector<const analysis::Constant*>& constants) {
117   return constants[0] ? constants[0] : constants[1];
118 }
119 
NonConstInput(IRContext * context,const analysis::Constant * c,Instruction * inst)120 Instruction* NonConstInput(IRContext* context, const analysis::Constant* c,
121                            Instruction* inst) {
122   uint32_t in_op = c ? 1u : 0u;
123   return context->get_def_use_mgr()->GetDef(
124       inst->GetSingleWordInOperand(in_op));
125 }
126 
127 // Returns the negation of |c|. |c| must be a 32 or 64 bit floating point
128 // constant.
NegateFloatingPointConstant(analysis::ConstantManager * const_mgr,const analysis::Constant * c)129 uint32_t NegateFloatingPointConstant(analysis::ConstantManager* const_mgr,
130                                      const analysis::Constant* c) {
131   assert(c);
132   assert(c->type()->AsFloat());
133   uint32_t width = c->type()->AsFloat()->width();
134   assert(width == 32 || width == 64);
135   std::vector<uint32_t> words;
136   if (width == 64) {
137     utils::FloatProxy<double> result(c->GetDouble() * -1.0);
138     words = result.GetWords();
139   } else {
140     utils::FloatProxy<float> result(c->GetFloat() * -1.0f);
141     words = result.GetWords();
142   }
143 
144   const analysis::Constant* negated_const =
145       const_mgr->GetConstant(c->type(), std::move(words));
146   return const_mgr->GetDefiningInstruction(negated_const)->result_id();
147 }
148 
// Splits the 64-bit |val| into two 32-bit words, low-order word first.
std::vector<uint32_t> ExtractInts(uint64_t val) {
  const uint32_t low = static_cast<uint32_t>(val & 0xFFFFFFFFu);
  const uint32_t high = static_cast<uint32_t>(val >> 32);
  return {low, high};
}
155 
156 // Negates the integer constant |c|. Returns the id of the defining instruction.
NegateIntegerConstant(analysis::ConstantManager * const_mgr,const analysis::Constant * c)157 uint32_t NegateIntegerConstant(analysis::ConstantManager* const_mgr,
158                                const analysis::Constant* c) {
159   assert(c);
160   assert(c->type()->AsInteger());
161   uint32_t width = c->type()->AsInteger()->width();
162   assert(width == 32 || width == 64);
163   std::vector<uint32_t> words;
164   if (width == 64) {
165     uint64_t uval = static_cast<uint64_t>(0 - c->GetU64());
166     words = ExtractInts(uval);
167   } else {
168     words.push_back(static_cast<uint32_t>(0 - c->GetU32()));
169   }
170 
171   const analysis::Constant* negated_const =
172       const_mgr->GetConstant(c->type(), std::move(words));
173   return const_mgr->GetDefiningInstruction(negated_const)->result_id();
174 }
175 
176 // Negates the vector constant |c|. Returns the id of the defining instruction.
NegateVectorConstant(analysis::ConstantManager * const_mgr,const analysis::Constant * c)177 uint32_t NegateVectorConstant(analysis::ConstantManager* const_mgr,
178                               const analysis::Constant* c) {
179   assert(const_mgr && c);
180   assert(c->type()->AsVector());
181   if (c->AsNullConstant()) {
182     // 0.0 vs -0.0 shouldn't matter.
183     return const_mgr->GetDefiningInstruction(c)->result_id();
184   } else {
185     const analysis::Type* component_type =
186         c->AsVectorConstant()->component_type();
187     std::vector<uint32_t> words;
188     for (auto& comp : c->AsVectorConstant()->GetComponents()) {
189       if (component_type->AsFloat()) {
190         words.push_back(NegateFloatingPointConstant(const_mgr, comp));
191       } else {
192         assert(component_type->AsInteger());
193         words.push_back(NegateIntegerConstant(const_mgr, comp));
194       }
195     }
196 
197     const analysis::Constant* negated_const =
198         const_mgr->GetConstant(c->type(), std::move(words));
199     return const_mgr->GetDefiningInstruction(negated_const)->result_id();
200   }
201 }
202 
203 // Negates |c|. Returns the id of the defining instruction.
NegateConstant(analysis::ConstantManager * const_mgr,const analysis::Constant * c)204 uint32_t NegateConstant(analysis::ConstantManager* const_mgr,
205                         const analysis::Constant* c) {
206   if (c->type()->AsVector()) {
207     return NegateVectorConstant(const_mgr, c);
208   } else if (c->type()->AsFloat()) {
209     return NegateFloatingPointConstant(const_mgr, c);
210   } else {
211     assert(c->type()->AsInteger());
212     return NegateIntegerConstant(const_mgr, c);
213   }
214 }
215 
216 // Takes the reciprocal of |c|. |c|'s type must be Float or a vector of Float.
217 // Returns 0 if the reciprocal is NaN, infinite or subnormal.
Reciprocal(analysis::ConstantManager * const_mgr,const analysis::Constant * c)218 uint32_t Reciprocal(analysis::ConstantManager* const_mgr,
219                     const analysis::Constant* c) {
220   assert(const_mgr && c);
221   assert(c->type()->AsFloat());
222 
223   uint32_t width = c->type()->AsFloat()->width();
224   assert(width == 32 || width == 64);
225   std::vector<uint32_t> words;
226   if (width == 64) {
227     spvtools::utils::FloatProxy<double> result(1.0 / c->GetDouble());
228     if (!IsValidResult(result.getAsFloat())) return 0;
229     words = result.GetWords();
230   } else {
231     spvtools::utils::FloatProxy<float> result(1.0f / c->GetFloat());
232     if (!IsValidResult(result.getAsFloat())) return 0;
233     words = result.GetWords();
234   }
235 
236   const analysis::Constant* negated_const =
237       const_mgr->GetConstant(c->type(), std::move(words));
238   return const_mgr->GetDefiningInstruction(negated_const)->result_id();
239 }
240 
241 // Replaces fdiv where second operand is constant with fmul.
ReciprocalFDiv()242 FoldingRule ReciprocalFDiv() {
243   return [](IRContext* context, Instruction* inst,
244             const std::vector<const analysis::Constant*>& constants) {
245     assert(inst->opcode() == SpvOpFDiv);
246     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
247     const analysis::Type* type =
248         context->get_type_mgr()->GetType(inst->type_id());
249     if (!inst->IsFloatingPointFoldingAllowed()) return false;
250 
251     uint32_t width = ElementWidth(type);
252     if (width != 32 && width != 64) return false;
253 
254     if (constants[1] != nullptr) {
255       uint32_t id = 0;
256       if (const analysis::VectorConstant* vector_const =
257               constants[1]->AsVectorConstant()) {
258         std::vector<uint32_t> neg_ids;
259         for (auto& comp : vector_const->GetComponents()) {
260           id = Reciprocal(const_mgr, comp);
261           if (id == 0) return false;
262           neg_ids.push_back(id);
263         }
264         const analysis::Constant* negated_const =
265             const_mgr->GetConstant(constants[1]->type(), std::move(neg_ids));
266         id = const_mgr->GetDefiningInstruction(negated_const)->result_id();
267       } else if (constants[1]->AsFloatConstant()) {
268         id = Reciprocal(const_mgr, constants[1]);
269         if (id == 0) return false;
270       } else {
271         // Don't fold a null constant.
272         return false;
273       }
274       inst->SetOpcode(SpvOpFMul);
275       inst->SetInOperands(
276           {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0u)}},
277            {SPV_OPERAND_TYPE_ID, {id}}});
278       return true;
279     }
280 
281     return false;
282   };
283 }
284 
285 // Elides consecutive negate instructions.
MergeNegateArithmetic()286 FoldingRule MergeNegateArithmetic() {
287   return [](IRContext* context, Instruction* inst,
288             const std::vector<const analysis::Constant*>& constants) {
289     assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate);
290     (void)constants;
291     const analysis::Type* type =
292         context->get_type_mgr()->GetType(inst->type_id());
293     if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
294       return false;
295 
296     Instruction* op_inst =
297         context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
298     if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
299       return false;
300 
301     if (op_inst->opcode() == inst->opcode()) {
302       // Elide negates.
303       inst->SetOpcode(SpvOpCopyObject);
304       inst->SetInOperands(
305           {{SPV_OPERAND_TYPE_ID, {op_inst->GetSingleWordInOperand(0u)}}});
306       return true;
307     }
308 
309     return false;
310   };
311 }
312 
313 // Merges negate into a mul or div operation if that operation contains a
314 // constant operand.
315 // Cases:
316 // -(x * 2) = x * -2
317 // -(2 * x) = x * -2
318 // -(x / 2) = x / -2
319 // -(2 / x) = -2 / x
MergeNegateMulDivArithmetic()320 FoldingRule MergeNegateMulDivArithmetic() {
321   return [](IRContext* context, Instruction* inst,
322             const std::vector<const analysis::Constant*>& constants) {
323     assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate);
324     (void)constants;
325     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
326     const analysis::Type* type =
327         context->get_type_mgr()->GetType(inst->type_id());
328     if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
329       return false;
330 
331     Instruction* op_inst =
332         context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
333     if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
334       return false;
335 
336     uint32_t width = ElementWidth(type);
337     if (width != 32 && width != 64) return false;
338 
339     SpvOp opcode = op_inst->opcode();
340     if (opcode == SpvOpFMul || opcode == SpvOpFDiv || opcode == SpvOpIMul ||
341         opcode == SpvOpSDiv || opcode == SpvOpUDiv) {
342       std::vector<const analysis::Constant*> op_constants =
343           const_mgr->GetOperandConstants(op_inst);
344       // Merge negate into mul or div if one operand is constant.
345       if (op_constants[0] || op_constants[1]) {
346         bool zero_is_variable = op_constants[0] == nullptr;
347         const analysis::Constant* c = ConstInput(op_constants);
348         uint32_t neg_id = NegateConstant(const_mgr, c);
349         uint32_t non_const_id = zero_is_variable
350                                     ? op_inst->GetSingleWordInOperand(0u)
351                                     : op_inst->GetSingleWordInOperand(1u);
352         // Change this instruction to a mul/div.
353         inst->SetOpcode(op_inst->opcode());
354         if (opcode == SpvOpFDiv || opcode == SpvOpUDiv || opcode == SpvOpSDiv) {
355           uint32_t op0 = zero_is_variable ? non_const_id : neg_id;
356           uint32_t op1 = zero_is_variable ? neg_id : non_const_id;
357           inst->SetInOperands(
358               {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}});
359         } else {
360           inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
361                                {SPV_OPERAND_TYPE_ID, {neg_id}}});
362         }
363         return true;
364       }
365     }
366 
367     return false;
368   };
369 }
370 
371 // Merges negate into a add or sub operation if that operation contains a
372 // constant operand.
373 // Cases:
374 // -(x + 2) = -2 - x
375 // -(2 + x) = -2 - x
376 // -(x - 2) = 2 - x
377 // -(2 - x) = x - 2
MergeNegateAddSubArithmetic()378 FoldingRule MergeNegateAddSubArithmetic() {
379   return [](IRContext* context, Instruction* inst,
380             const std::vector<const analysis::Constant*>& constants) {
381     assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate);
382     (void)constants;
383     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
384     const analysis::Type* type =
385         context->get_type_mgr()->GetType(inst->type_id());
386     if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
387       return false;
388 
389     Instruction* op_inst =
390         context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
391     if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
392       return false;
393 
394     uint32_t width = ElementWidth(type);
395     if (width != 32 && width != 64) return false;
396 
397     if (op_inst->opcode() == SpvOpFAdd || op_inst->opcode() == SpvOpFSub ||
398         op_inst->opcode() == SpvOpIAdd || op_inst->opcode() == SpvOpISub) {
399       std::vector<const analysis::Constant*> op_constants =
400           const_mgr->GetOperandConstants(op_inst);
401       if (op_constants[0] || op_constants[1]) {
402         bool zero_is_variable = op_constants[0] == nullptr;
403         bool is_add = (op_inst->opcode() == SpvOpFAdd) ||
404                       (op_inst->opcode() == SpvOpIAdd);
405         bool swap_operands = !is_add || zero_is_variable;
406         bool negate_const = is_add;
407         const analysis::Constant* c = ConstInput(op_constants);
408         uint32_t const_id = 0;
409         if (negate_const) {
410           const_id = NegateConstant(const_mgr, c);
411         } else {
412           const_id = zero_is_variable ? op_inst->GetSingleWordInOperand(1u)
413                                       : op_inst->GetSingleWordInOperand(0u);
414         }
415 
416         // Swap operands if necessary and make the instruction a subtraction.
417         uint32_t op0 =
418             zero_is_variable ? op_inst->GetSingleWordInOperand(0u) : const_id;
419         uint32_t op1 =
420             zero_is_variable ? const_id : op_inst->GetSingleWordInOperand(1u);
421         if (swap_operands) std::swap(op0, op1);
422         inst->SetOpcode(HasFloatingPoint(type) ? SpvOpFSub : SpvOpISub);
423         inst->SetInOperands(
424             {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}});
425         return true;
426       }
427     }
428 
429     return false;
430   };
431 }
432 
433 // Returns true if |c| has a zero element.
HasZero(const analysis::Constant * c)434 bool HasZero(const analysis::Constant* c) {
435   if (c->AsNullConstant()) {
436     return true;
437   }
438   if (const analysis::VectorConstant* vec_const = c->AsVectorConstant()) {
439     for (auto& comp : vec_const->GetComponents())
440       if (HasZero(comp)) return true;
441   } else {
442     assert(c->AsScalarConstant());
443     return c->AsScalarConstant()->IsZero();
444   }
445 
446   return false;
447 }
448 
449 // Performs |input1| |opcode| |input2| and returns the merged constant result
450 // id. Returns 0 if the result is not a valid value. The input types must be
451 // Float.
PerformFloatingPointOperation(analysis::ConstantManager * const_mgr,SpvOp opcode,const analysis::Constant * input1,const analysis::Constant * input2)452 uint32_t PerformFloatingPointOperation(analysis::ConstantManager* const_mgr,
453                                        SpvOp opcode,
454                                        const analysis::Constant* input1,
455                                        const analysis::Constant* input2) {
456   const analysis::Type* type = input1->type();
457   assert(type->AsFloat());
458   uint32_t width = type->AsFloat()->width();
459   assert(width == 32 || width == 64);
460   std::vector<uint32_t> words;
461 #define FOLD_OP(op)                                                          \
462   if (width == 64) {                                                         \
463     utils::FloatProxy<double> val =                                          \
464         input1->GetDouble() op input2->GetDouble();                          \
465     double dval = val.getAsFloat();                                          \
466     if (!IsValidResult(dval)) return 0;                                      \
467     words = val.GetWords();                                                  \
468   } else {                                                                   \
469     utils::FloatProxy<float> val = input1->GetFloat() op input2->GetFloat(); \
470     float fval = val.getAsFloat();                                           \
471     if (!IsValidResult(fval)) return 0;                                      \
472     words = val.GetWords();                                                  \
473   }
474   switch (opcode) {
475     case SpvOpFMul:
476       FOLD_OP(*);
477       break;
478     case SpvOpFDiv:
479       if (HasZero(input2)) return 0;
480       FOLD_OP(/);
481       break;
482     case SpvOpFAdd:
483       FOLD_OP(+);
484       break;
485     case SpvOpFSub:
486       FOLD_OP(-);
487       break;
488     default:
489       assert(false && "Unexpected operation");
490       break;
491   }
492 #undef FOLD_OP
493   const analysis::Constant* merged_const = const_mgr->GetConstant(type, words);
494   return const_mgr->GetDefiningInstruction(merged_const)->result_id();
495 }
496 
497 // Performs |input1| |opcode| |input2| and returns the merged constant result
498 // id. Returns 0 if the result is not a valid value. The input types must be
499 // Integers.
PerformIntegerOperation(analysis::ConstantManager * const_mgr,SpvOp opcode,const analysis::Constant * input1,const analysis::Constant * input2)500 uint32_t PerformIntegerOperation(analysis::ConstantManager* const_mgr,
501                                  SpvOp opcode, const analysis::Constant* input1,
502                                  const analysis::Constant* input2) {
503   assert(input1->type()->AsInteger());
504   const analysis::Integer* type = input1->type()->AsInteger();
505   uint32_t width = type->AsInteger()->width();
506   assert(width == 32 || width == 64);
507   std::vector<uint32_t> words;
508 #define FOLD_OP(op)                                        \
509   if (width == 64) {                                       \
510     if (type->IsSigned()) {                                \
511       int64_t val = input1->GetS64() op input2->GetS64();  \
512       words = ExtractInts(static_cast<uint64_t>(val));     \
513     } else {                                               \
514       uint64_t val = input1->GetU64() op input2->GetU64(); \
515       words = ExtractInts(val);                            \
516     }                                                      \
517   } else {                                                 \
518     if (type->IsSigned()) {                                \
519       int32_t val = input1->GetS32() op input2->GetS32();  \
520       words.push_back(static_cast<uint32_t>(val));         \
521     } else {                                               \
522       uint32_t val = input1->GetU32() op input2->GetU32(); \
523       words.push_back(val);                                \
524     }                                                      \
525   }
526   switch (opcode) {
527     case SpvOpIMul:
528       FOLD_OP(*);
529       break;
530     case SpvOpSDiv:
531     case SpvOpUDiv:
532       assert(false && "Should not merge integer division");
533       break;
534     case SpvOpIAdd:
535       FOLD_OP(+);
536       break;
537     case SpvOpISub:
538       FOLD_OP(-);
539       break;
540     default:
541       assert(false && "Unexpected operation");
542       break;
543   }
544 #undef FOLD_OP
545   const analysis::Constant* merged_const = const_mgr->GetConstant(type, words);
546   return const_mgr->GetDefiningInstruction(merged_const)->result_id();
547 }
548 
549 // Performs |input1| |opcode| |input2| and returns the merged constant result
550 // id. Returns 0 if the result is not a valid value. The input types must be
551 // Integers, Floats or Vectors of such.
PerformOperation(analysis::ConstantManager * const_mgr,SpvOp opcode,const analysis::Constant * input1,const analysis::Constant * input2)552 uint32_t PerformOperation(analysis::ConstantManager* const_mgr, SpvOp opcode,
553                           const analysis::Constant* input1,
554                           const analysis::Constant* input2) {
555   assert(input1 && input2);
556   const analysis::Type* type = input1->type();
557   std::vector<uint32_t> words;
558   if (const analysis::Vector* vector_type = type->AsVector()) {
559     const analysis::Type* ele_type = vector_type->element_type();
560     for (uint32_t i = 0; i != vector_type->element_count(); ++i) {
561       uint32_t id = 0;
562 
563       const analysis::Constant* input1_comp = nullptr;
564       if (const analysis::VectorConstant* input1_vector =
565               input1->AsVectorConstant()) {
566         input1_comp = input1_vector->GetComponents()[i];
567       } else {
568         assert(input1->AsNullConstant());
569         input1_comp = const_mgr->GetConstant(ele_type, {});
570       }
571 
572       const analysis::Constant* input2_comp = nullptr;
573       if (const analysis::VectorConstant* input2_vector =
574               input2->AsVectorConstant()) {
575         input2_comp = input2_vector->GetComponents()[i];
576       } else {
577         assert(input2->AsNullConstant());
578         input2_comp = const_mgr->GetConstant(ele_type, {});
579       }
580 
581       if (ele_type->AsFloat()) {
582         id = PerformFloatingPointOperation(const_mgr, opcode, input1_comp,
583                                            input2_comp);
584       } else {
585         assert(ele_type->AsInteger());
586         id = PerformIntegerOperation(const_mgr, opcode, input1_comp,
587                                      input2_comp);
588       }
589       if (id == 0) return 0;
590       words.push_back(id);
591     }
592     const analysis::Constant* merged_const =
593         const_mgr->GetConstant(type, words);
594     return const_mgr->GetDefiningInstruction(merged_const)->result_id();
595   } else if (type->AsFloat()) {
596     return PerformFloatingPointOperation(const_mgr, opcode, input1, input2);
597   } else {
598     assert(type->AsInteger());
599     return PerformIntegerOperation(const_mgr, opcode, input1, input2);
600   }
601 }
602 
603 // Merges consecutive multiplies where each contains one constant operand.
604 // Cases:
605 // 2 * (x * 2) = x * 4
606 // 2 * (2 * x) = x * 4
607 // (x * 2) * 2 = x * 4
608 // (2 * x) * 2 = x * 4
MergeMulMulArithmetic()609 FoldingRule MergeMulMulArithmetic() {
610   return [](IRContext* context, Instruction* inst,
611             const std::vector<const analysis::Constant*>& constants) {
612     assert(inst->opcode() == SpvOpFMul || inst->opcode() == SpvOpIMul);
613     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
614     const analysis::Type* type =
615         context->get_type_mgr()->GetType(inst->type_id());
616     if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
617       return false;
618 
619     uint32_t width = ElementWidth(type);
620     if (width != 32 && width != 64) return false;
621 
622     // Determine the constant input and the variable input in |inst|.
623     const analysis::Constant* const_input1 = ConstInput(constants);
624     if (!const_input1) return false;
625     Instruction* other_inst = NonConstInput(context, constants[0], inst);
626     if (HasFloatingPoint(type) && !other_inst->IsFloatingPointFoldingAllowed())
627       return false;
628 
629     if (other_inst->opcode() == inst->opcode()) {
630       std::vector<const analysis::Constant*> other_constants =
631           const_mgr->GetOperandConstants(other_inst);
632       const analysis::Constant* const_input2 = ConstInput(other_constants);
633       if (!const_input2) return false;
634 
635       bool other_first_is_variable = other_constants[0] == nullptr;
636       uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
637                                             const_input1, const_input2);
638       if (merged_id == 0) return false;
639 
640       uint32_t non_const_id = other_first_is_variable
641                                   ? other_inst->GetSingleWordInOperand(0u)
642                                   : other_inst->GetSingleWordInOperand(1u);
643       inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
644                            {SPV_OPERAND_TYPE_ID, {merged_id}}});
645       return true;
646     }
647 
648     return false;
649   };
650 }
651 
652 // Merges divides into subsequent multiplies if each instruction contains one
653 // constant operand. Does not support integer operations.
654 // Cases:
655 // 2 * (x / 2) = x * 1
656 // 2 * (2 / x) = 4 / x
657 // (x / 2) * 2 = x * 1
658 // (2 / x) * 2 = 4 / x
659 // (y / x) * x = y
660 // x * (y / x) = y
MergeMulDivArithmetic()661 FoldingRule MergeMulDivArithmetic() {
662   return [](IRContext* context, Instruction* inst,
663             const std::vector<const analysis::Constant*>& constants) {
664     assert(inst->opcode() == SpvOpFMul);
665     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
666     analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
667 
668     const analysis::Type* type =
669         context->get_type_mgr()->GetType(inst->type_id());
670     if (!inst->IsFloatingPointFoldingAllowed()) return false;
671 
672     uint32_t width = ElementWidth(type);
673     if (width != 32 && width != 64) return false;
674 
675     for (uint32_t i = 0; i < 2; i++) {
676       uint32_t op_id = inst->GetSingleWordInOperand(i);
677       Instruction* op_inst = def_use_mgr->GetDef(op_id);
678       if (op_inst->opcode() == SpvOpFDiv) {
679         if (op_inst->GetSingleWordInOperand(1) ==
680             inst->GetSingleWordInOperand(1 - i)) {
681           inst->SetOpcode(SpvOpCopyObject);
682           inst->SetInOperands(
683               {{SPV_OPERAND_TYPE_ID, {op_inst->GetSingleWordInOperand(0)}}});
684           return true;
685         }
686       }
687     }
688 
689     const analysis::Constant* const_input1 = ConstInput(constants);
690     if (!const_input1) return false;
691     Instruction* other_inst = NonConstInput(context, constants[0], inst);
692     if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
693 
694     if (other_inst->opcode() == SpvOpFDiv) {
695       std::vector<const analysis::Constant*> other_constants =
696           const_mgr->GetOperandConstants(other_inst);
697       const analysis::Constant* const_input2 = ConstInput(other_constants);
698       if (!const_input2 || HasZero(const_input2)) return false;
699 
700       bool other_first_is_variable = other_constants[0] == nullptr;
701       // If the variable value is the second operand of the divide, multiply
702       // the constants together. Otherwise divide the constants.
703       uint32_t merged_id = PerformOperation(
704           const_mgr,
705           other_first_is_variable ? other_inst->opcode() : inst->opcode(),
706           const_input1, const_input2);
707       if (merged_id == 0) return false;
708 
709       uint32_t non_const_id = other_first_is_variable
710                                   ? other_inst->GetSingleWordInOperand(0u)
711                                   : other_inst->GetSingleWordInOperand(1u);
712 
713       // If the variable value is on the second operand of the div, then this
714       // operation is a div. Otherwise it should be a multiply.
715       inst->SetOpcode(other_first_is_variable ? inst->opcode()
716                                               : other_inst->opcode());
717       if (other_first_is_variable) {
718         inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
719                              {SPV_OPERAND_TYPE_ID, {merged_id}}});
720       } else {
721         inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {merged_id}},
722                              {SPV_OPERAND_TYPE_ID, {non_const_id}}});
723       }
724       return true;
725     }
726 
727     return false;
728   };
729 }
730 
731 // Merges multiply of constant and negation.
732 // Cases:
733 // (-x) * 2 = x * -2
734 // 2 * (-x) = x * -2
MergeMulNegateArithmetic()735 FoldingRule MergeMulNegateArithmetic() {
736   return [](IRContext* context, Instruction* inst,
737             const std::vector<const analysis::Constant*>& constants) {
738     assert(inst->opcode() == SpvOpFMul || inst->opcode() == SpvOpIMul);
739     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
740     const analysis::Type* type =
741         context->get_type_mgr()->GetType(inst->type_id());
742     bool uses_float = HasFloatingPoint(type);
743     if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
744 
745     uint32_t width = ElementWidth(type);
746     if (width != 32 && width != 64) return false;
747 
748     const analysis::Constant* const_input1 = ConstInput(constants);
749     if (!const_input1) return false;
750     Instruction* other_inst = NonConstInput(context, constants[0], inst);
751     if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
752       return false;
753 
754     if (other_inst->opcode() == SpvOpFNegate ||
755         other_inst->opcode() == SpvOpSNegate) {
756       uint32_t neg_id = NegateConstant(const_mgr, const_input1);
757 
758       inst->SetInOperands(
759           {{SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}},
760            {SPV_OPERAND_TYPE_ID, {neg_id}}});
761       return true;
762     }
763 
764     return false;
765   };
766 }
767 
768 // Merges consecutive divides if each instruction contains one constant operand.
769 // Does not support integer division.
770 // Cases:
771 // 2 / (x / 2) = 4 / x
772 // 4 / (2 / x) = 2 * x
773 // (4 / x) / 2 = 2 / x
774 // (x / 2) / 2 = x / 4
MergeDivDivArithmetic()775 FoldingRule MergeDivDivArithmetic() {
776   return [](IRContext* context, Instruction* inst,
777             const std::vector<const analysis::Constant*>& constants) {
778     assert(inst->opcode() == SpvOpFDiv);
779     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
780     const analysis::Type* type =
781         context->get_type_mgr()->GetType(inst->type_id());
782     if (!inst->IsFloatingPointFoldingAllowed()) return false;
783 
784     uint32_t width = ElementWidth(type);
785     if (width != 32 && width != 64) return false;
786 
787     const analysis::Constant* const_input1 = ConstInput(constants);
788     if (!const_input1 || HasZero(const_input1)) return false;
789     Instruction* other_inst = NonConstInput(context, constants[0], inst);
790     if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
791 
792     bool first_is_variable = constants[0] == nullptr;
793     if (other_inst->opcode() == inst->opcode()) {
794       std::vector<const analysis::Constant*> other_constants =
795           const_mgr->GetOperandConstants(other_inst);
796       const analysis::Constant* const_input2 = ConstInput(other_constants);
797       if (!const_input2 || HasZero(const_input2)) return false;
798 
799       bool other_first_is_variable = other_constants[0] == nullptr;
800 
801       SpvOp merge_op = inst->opcode();
802       if (other_first_is_variable) {
803         // Constants magnify.
804         merge_op = SpvOpFMul;
805       }
806 
807       // This is an x / (*) case. Swap the inputs. Doesn't harm multiply
808       // because it is commutative.
809       if (first_is_variable) std::swap(const_input1, const_input2);
810       uint32_t merged_id =
811           PerformOperation(const_mgr, merge_op, const_input1, const_input2);
812       if (merged_id == 0) return false;
813 
814       uint32_t non_const_id = other_first_is_variable
815                                   ? other_inst->GetSingleWordInOperand(0u)
816                                   : other_inst->GetSingleWordInOperand(1u);
817 
818       SpvOp op = inst->opcode();
819       if (!first_is_variable && !other_first_is_variable) {
820         // Effectively div of 1/x, so change to multiply.
821         op = SpvOpFMul;
822       }
823 
824       uint32_t op1 = merged_id;
825       uint32_t op2 = non_const_id;
826       if (first_is_variable && other_first_is_variable) std::swap(op1, op2);
827       inst->SetOpcode(op);
828       inst->SetInOperands(
829           {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
830       return true;
831     }
832 
833     return false;
834   };
835 }
836 
837 // Fold multiplies succeeded by divides where each instruction contains a
838 // constant operand. Does not support integer divide.
839 // Cases:
840 // 4 / (x * 2) = 2 / x
841 // 4 / (2 * x) = 2 / x
842 // (x * 4) / 2 = x * 2
843 // (4 * x) / 2 = x * 2
844 // (x * y) / x = y
845 // (y * x) / x = y
MergeDivMulArithmetic()846 FoldingRule MergeDivMulArithmetic() {
847   return [](IRContext* context, Instruction* inst,
848             const std::vector<const analysis::Constant*>& constants) {
849     assert(inst->opcode() == SpvOpFDiv);
850     analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
851     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
852 
853     const analysis::Type* type =
854         context->get_type_mgr()->GetType(inst->type_id());
855     if (!inst->IsFloatingPointFoldingAllowed()) return false;
856 
857     uint32_t width = ElementWidth(type);
858     if (width != 32 && width != 64) return false;
859 
860     uint32_t op_id = inst->GetSingleWordInOperand(0);
861     Instruction* op_inst = def_use_mgr->GetDef(op_id);
862 
863     if (op_inst->opcode() == SpvOpFMul) {
864       for (uint32_t i = 0; i < 2; i++) {
865         if (op_inst->GetSingleWordInOperand(i) ==
866             inst->GetSingleWordInOperand(1)) {
867           inst->SetOpcode(SpvOpCopyObject);
868           inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
869                                 {op_inst->GetSingleWordInOperand(1 - i)}}});
870           return true;
871         }
872       }
873     }
874 
875     const analysis::Constant* const_input1 = ConstInput(constants);
876     if (!const_input1 || HasZero(const_input1)) return false;
877     Instruction* other_inst = NonConstInput(context, constants[0], inst);
878     if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
879 
880     bool first_is_variable = constants[0] == nullptr;
881     if (other_inst->opcode() == SpvOpFMul) {
882       std::vector<const analysis::Constant*> other_constants =
883           const_mgr->GetOperandConstants(other_inst);
884       const analysis::Constant* const_input2 = ConstInput(other_constants);
885       if (!const_input2) return false;
886 
887       bool other_first_is_variable = other_constants[0] == nullptr;
888 
889       // This is an x / (*) case. Swap the inputs.
890       if (first_is_variable) std::swap(const_input1, const_input2);
891       uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
892                                             const_input1, const_input2);
893       if (merged_id == 0) return false;
894 
895       uint32_t non_const_id = other_first_is_variable
896                                   ? other_inst->GetSingleWordInOperand(0u)
897                                   : other_inst->GetSingleWordInOperand(1u);
898 
899       uint32_t op1 = merged_id;
900       uint32_t op2 = non_const_id;
901       if (first_is_variable) std::swap(op1, op2);
902 
903       // Convert to multiply
904       if (first_is_variable) inst->SetOpcode(other_inst->opcode());
905       inst->SetInOperands(
906           {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
907       return true;
908     }
909 
910     return false;
911   };
912 }
913 
914 // Fold divides of a constant and a negation.
915 // Cases:
916 // (-x) / 2 = x / -2
917 // 2 / (-x) = 2 / -x
MergeDivNegateArithmetic()918 FoldingRule MergeDivNegateArithmetic() {
919   return [](IRContext* context, Instruction* inst,
920             const std::vector<const analysis::Constant*>& constants) {
921     assert(inst->opcode() == SpvOpFDiv || inst->opcode() == SpvOpSDiv ||
922            inst->opcode() == SpvOpUDiv);
923     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
924     const analysis::Type* type =
925         context->get_type_mgr()->GetType(inst->type_id());
926     bool uses_float = HasFloatingPoint(type);
927     if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
928 
929     uint32_t width = ElementWidth(type);
930     if (width != 32 && width != 64) return false;
931 
932     const analysis::Constant* const_input1 = ConstInput(constants);
933     if (!const_input1) return false;
934     Instruction* other_inst = NonConstInput(context, constants[0], inst);
935     if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
936       return false;
937 
938     bool first_is_variable = constants[0] == nullptr;
939     if (other_inst->opcode() == SpvOpFNegate ||
940         other_inst->opcode() == SpvOpSNegate) {
941       uint32_t neg_id = NegateConstant(const_mgr, const_input1);
942 
943       if (first_is_variable) {
944         inst->SetInOperands(
945             {{SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}},
946              {SPV_OPERAND_TYPE_ID, {neg_id}}});
947       } else {
948         inst->SetInOperands(
949             {{SPV_OPERAND_TYPE_ID, {neg_id}},
950              {SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}});
951       }
952       return true;
953     }
954 
955     return false;
956   };
957 }
958 
959 // Folds addition of a constant and a negation.
960 // Cases:
961 // (-x) + 2 = 2 - x
962 // 2 + (-x) = 2 - x
MergeAddNegateArithmetic()963 FoldingRule MergeAddNegateArithmetic() {
964   return [](IRContext* context, Instruction* inst,
965             const std::vector<const analysis::Constant*>& constants) {
966     assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
967     const analysis::Type* type =
968         context->get_type_mgr()->GetType(inst->type_id());
969     bool uses_float = HasFloatingPoint(type);
970     if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
971 
972     const analysis::Constant* const_input1 = ConstInput(constants);
973     if (!const_input1) return false;
974     Instruction* other_inst = NonConstInput(context, constants[0], inst);
975     if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
976       return false;
977 
978     if (other_inst->opcode() == SpvOpSNegate ||
979         other_inst->opcode() == SpvOpFNegate) {
980       inst->SetOpcode(HasFloatingPoint(type) ? SpvOpFSub : SpvOpISub);
981       uint32_t const_id = constants[0] ? inst->GetSingleWordInOperand(0u)
982                                        : inst->GetSingleWordInOperand(1u);
983       inst->SetInOperands(
984           {{SPV_OPERAND_TYPE_ID, {const_id}},
985            {SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}});
986       return true;
987     }
988     return false;
989   };
990 }
991 
992 // Folds subtraction of a constant and a negation.
993 // Cases:
994 // (-x) - 2 = -2 - x
995 // 2 - (-x) = x + 2
MergeSubNegateArithmetic()996 FoldingRule MergeSubNegateArithmetic() {
997   return [](IRContext* context, Instruction* inst,
998             const std::vector<const analysis::Constant*>& constants) {
999     assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
1000     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1001     const analysis::Type* type =
1002         context->get_type_mgr()->GetType(inst->type_id());
1003     bool uses_float = HasFloatingPoint(type);
1004     if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1005 
1006     uint32_t width = ElementWidth(type);
1007     if (width != 32 && width != 64) return false;
1008 
1009     const analysis::Constant* const_input1 = ConstInput(constants);
1010     if (!const_input1) return false;
1011     Instruction* other_inst = NonConstInput(context, constants[0], inst);
1012     if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1013       return false;
1014 
1015     if (other_inst->opcode() == SpvOpSNegate ||
1016         other_inst->opcode() == SpvOpFNegate) {
1017       uint32_t op1 = 0;
1018       uint32_t op2 = 0;
1019       SpvOp opcode = inst->opcode();
1020       if (constants[0] != nullptr) {
1021         op1 = other_inst->GetSingleWordInOperand(0u);
1022         op2 = inst->GetSingleWordInOperand(0u);
1023         opcode = HasFloatingPoint(type) ? SpvOpFAdd : SpvOpIAdd;
1024       } else {
1025         op1 = NegateConstant(const_mgr, const_input1);
1026         op2 = other_inst->GetSingleWordInOperand(0u);
1027       }
1028 
1029       inst->SetOpcode(opcode);
1030       inst->SetInOperands(
1031           {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
1032       return true;
1033     }
1034     return false;
1035   };
1036 }
1037 
1038 // Folds addition of an addition where each operation has a constant operand.
1039 // Cases:
1040 // (x + 2) + 2 = x + 4
1041 // (2 + x) + 2 = x + 4
1042 // 2 + (x + 2) = x + 4
1043 // 2 + (2 + x) = x + 4
MergeAddAddArithmetic()1044 FoldingRule MergeAddAddArithmetic() {
1045   return [](IRContext* context, Instruction* inst,
1046             const std::vector<const analysis::Constant*>& constants) {
1047     assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
1048     const analysis::Type* type =
1049         context->get_type_mgr()->GetType(inst->type_id());
1050     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1051     bool uses_float = HasFloatingPoint(type);
1052     if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1053 
1054     uint32_t width = ElementWidth(type);
1055     if (width != 32 && width != 64) return false;
1056 
1057     const analysis::Constant* const_input1 = ConstInput(constants);
1058     if (!const_input1) return false;
1059     Instruction* other_inst = NonConstInput(context, constants[0], inst);
1060     if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1061       return false;
1062 
1063     if (other_inst->opcode() == SpvOpFAdd ||
1064         other_inst->opcode() == SpvOpIAdd) {
1065       std::vector<const analysis::Constant*> other_constants =
1066           const_mgr->GetOperandConstants(other_inst);
1067       const analysis::Constant* const_input2 = ConstInput(other_constants);
1068       if (!const_input2) return false;
1069 
1070       Instruction* non_const_input =
1071           NonConstInput(context, other_constants[0], other_inst);
1072       uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
1073                                             const_input1, const_input2);
1074       if (merged_id == 0) return false;
1075 
1076       inst->SetInOperands(
1077           {{SPV_OPERAND_TYPE_ID, {non_const_input->result_id()}},
1078            {SPV_OPERAND_TYPE_ID, {merged_id}}});
1079       return true;
1080     }
1081     return false;
1082   };
1083 }
1084 
1085 // Folds addition of a subtraction where each operation has a constant operand.
1086 // Cases:
1087 // (x - 2) + 2 = x + 0
1088 // (2 - x) + 2 = 4 - x
1089 // 2 + (x - 2) = x + 0
1090 // 2 + (2 - x) = 4 - x
MergeAddSubArithmetic()1091 FoldingRule MergeAddSubArithmetic() {
1092   return [](IRContext* context, Instruction* inst,
1093             const std::vector<const analysis::Constant*>& constants) {
1094     assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
1095     const analysis::Type* type =
1096         context->get_type_mgr()->GetType(inst->type_id());
1097     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1098     bool uses_float = HasFloatingPoint(type);
1099     if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1100 
1101     uint32_t width = ElementWidth(type);
1102     if (width != 32 && width != 64) return false;
1103 
1104     const analysis::Constant* const_input1 = ConstInput(constants);
1105     if (!const_input1) return false;
1106     Instruction* other_inst = NonConstInput(context, constants[0], inst);
1107     if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1108       return false;
1109 
1110     if (other_inst->opcode() == SpvOpFSub ||
1111         other_inst->opcode() == SpvOpISub) {
1112       std::vector<const analysis::Constant*> other_constants =
1113           const_mgr->GetOperandConstants(other_inst);
1114       const analysis::Constant* const_input2 = ConstInput(other_constants);
1115       if (!const_input2) return false;
1116 
1117       bool first_is_variable = other_constants[0] == nullptr;
1118       SpvOp op = inst->opcode();
1119       uint32_t op1 = 0;
1120       uint32_t op2 = 0;
1121       if (first_is_variable) {
1122         // Subtract constants. Non-constant operand is first.
1123         op1 = other_inst->GetSingleWordInOperand(0u);
1124         op2 = PerformOperation(const_mgr, other_inst->opcode(), const_input1,
1125                                const_input2);
1126       } else {
1127         // Add constants. Constant operand is first. Change the opcode.
1128         op1 = PerformOperation(const_mgr, inst->opcode(), const_input1,
1129                                const_input2);
1130         op2 = other_inst->GetSingleWordInOperand(1u);
1131         op = other_inst->opcode();
1132       }
1133       if (op1 == 0 || op2 == 0) return false;
1134 
1135       inst->SetOpcode(op);
1136       inst->SetInOperands(
1137           {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
1138       return true;
1139     }
1140     return false;
1141   };
1142 }
1143 
1144 // Folds subtraction of an addition where each operand has a constant operand.
1145 // Cases:
1146 // (x + 2) - 2 = x + 0
1147 // (2 + x) - 2 = x + 0
1148 // 2 - (x + 2) = 0 - x
1149 // 2 - (2 + x) = 0 - x
MergeSubAddArithmetic()1150 FoldingRule MergeSubAddArithmetic() {
1151   return [](IRContext* context, Instruction* inst,
1152             const std::vector<const analysis::Constant*>& constants) {
1153     assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
1154     const analysis::Type* type =
1155         context->get_type_mgr()->GetType(inst->type_id());
1156     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1157     bool uses_float = HasFloatingPoint(type);
1158     if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1159 
1160     uint32_t width = ElementWidth(type);
1161     if (width != 32 && width != 64) return false;
1162 
1163     const analysis::Constant* const_input1 = ConstInput(constants);
1164     if (!const_input1) return false;
1165     Instruction* other_inst = NonConstInput(context, constants[0], inst);
1166     if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1167       return false;
1168 
1169     if (other_inst->opcode() == SpvOpFAdd ||
1170         other_inst->opcode() == SpvOpIAdd) {
1171       std::vector<const analysis::Constant*> other_constants =
1172           const_mgr->GetOperandConstants(other_inst);
1173       const analysis::Constant* const_input2 = ConstInput(other_constants);
1174       if (!const_input2) return false;
1175 
1176       Instruction* non_const_input =
1177           NonConstInput(context, other_constants[0], other_inst);
1178 
1179       // If the first operand of the sub is not a constant, swap the constants
1180       // so the subtraction has the correct operands.
1181       if (constants[0] == nullptr) std::swap(const_input1, const_input2);
1182       // Subtract the constants.
1183       uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
1184                                             const_input1, const_input2);
1185       SpvOp op = inst->opcode();
1186       uint32_t op1 = 0;
1187       uint32_t op2 = 0;
1188       if (constants[0] == nullptr) {
1189         // Non-constant operand is first. Change the opcode.
1190         op1 = non_const_input->result_id();
1191         op2 = merged_id;
1192         op = other_inst->opcode();
1193       } else {
1194         // Constant operand is first.
1195         op1 = merged_id;
1196         op2 = non_const_input->result_id();
1197       }
1198       if (op1 == 0 || op2 == 0) return false;
1199 
1200       inst->SetOpcode(op);
1201       inst->SetInOperands(
1202           {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
1203       return true;
1204     }
1205     return false;
1206   };
1207 }
1208 
1209 // Folds subtraction of a subtraction where each operand has a constant operand.
1210 // Cases:
1211 // (x - 2) - 2 = x - 4
1212 // (2 - x) - 2 = 0 - x
1213 // 2 - (x - 2) = 4 - x
1214 // 2 - (2 - x) = x + 0
MergeSubSubArithmetic()1215 FoldingRule MergeSubSubArithmetic() {
1216   return [](IRContext* context, Instruction* inst,
1217             const std::vector<const analysis::Constant*>& constants) {
1218     assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
1219     const analysis::Type* type =
1220         context->get_type_mgr()->GetType(inst->type_id());
1221     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1222     bool uses_float = HasFloatingPoint(type);
1223     if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1224 
1225     uint32_t width = ElementWidth(type);
1226     if (width != 32 && width != 64) return false;
1227 
1228     const analysis::Constant* const_input1 = ConstInput(constants);
1229     if (!const_input1) return false;
1230     Instruction* other_inst = NonConstInput(context, constants[0], inst);
1231     if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1232       return false;
1233 
1234     if (other_inst->opcode() == SpvOpFSub ||
1235         other_inst->opcode() == SpvOpISub) {
1236       std::vector<const analysis::Constant*> other_constants =
1237           const_mgr->GetOperandConstants(other_inst);
1238       const analysis::Constant* const_input2 = ConstInput(other_constants);
1239       if (!const_input2) return false;
1240 
1241       Instruction* non_const_input =
1242           NonConstInput(context, other_constants[0], other_inst);
1243 
1244       // Merge the constants.
1245       uint32_t merged_id = 0;
1246       SpvOp merge_op = inst->opcode();
1247       if (other_constants[0] == nullptr) {
1248         merge_op = uses_float ? SpvOpFAdd : SpvOpIAdd;
1249       } else if (constants[0] == nullptr) {
1250         std::swap(const_input1, const_input2);
1251       }
1252       merged_id =
1253           PerformOperation(const_mgr, merge_op, const_input1, const_input2);
1254       if (merged_id == 0) return false;
1255 
1256       SpvOp op = inst->opcode();
1257       if (constants[0] != nullptr && other_constants[0] != nullptr) {
1258         // Change the operation.
1259         op = uses_float ? SpvOpFAdd : SpvOpIAdd;
1260       }
1261 
1262       uint32_t op1 = 0;
1263       uint32_t op2 = 0;
1264       if ((constants[0] == nullptr) ^ (other_constants[0] == nullptr)) {
1265         op1 = merged_id;
1266         op2 = non_const_input->result_id();
1267       } else {
1268         op1 = non_const_input->result_id();
1269         op2 = merged_id;
1270       }
1271 
1272       inst->SetOpcode(op);
1273       inst->SetInOperands(
1274           {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
1275       return true;
1276     }
1277     return false;
1278   };
1279 }
1280 
1281 // Helper function for MergeGenericAddSubArithmetic. If |addend| and
1282 // subtrahend of |sub| is the same, merge to copy of minuend of |sub|.
MergeGenericAddendSub(uint32_t addend,uint32_t sub,Instruction * inst)1283 bool MergeGenericAddendSub(uint32_t addend, uint32_t sub, Instruction* inst) {
1284   IRContext* context = inst->context();
1285   analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1286   Instruction* sub_inst = def_use_mgr->GetDef(sub);
1287   if (sub_inst->opcode() != SpvOpFSub && sub_inst->opcode() != SpvOpISub)
1288     return false;
1289   if (sub_inst->opcode() == SpvOpFSub &&
1290       !sub_inst->IsFloatingPointFoldingAllowed())
1291     return false;
1292   if (addend != sub_inst->GetSingleWordInOperand(1)) return false;
1293   inst->SetOpcode(SpvOpCopyObject);
1294   inst->SetInOperands(
1295       {{SPV_OPERAND_TYPE_ID, {sub_inst->GetSingleWordInOperand(0)}}});
1296   context->UpdateDefUse(inst);
1297   return true;
1298 }
1299 
1300 // Folds addition of a subtraction where the subtrahend is equal to the
1301 // other addend. Return a copy of the minuend. Accepts generic (const and
1302 // non-const) operands.
1303 // Cases:
1304 // (a - b) + b = a
1305 // b + (a - b) = a
MergeGenericAddSubArithmetic()1306 FoldingRule MergeGenericAddSubArithmetic() {
1307   return [](IRContext* context, Instruction* inst,
1308             const std::vector<const analysis::Constant*>&) {
1309     assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
1310     const analysis::Type* type =
1311         context->get_type_mgr()->GetType(inst->type_id());
1312     bool uses_float = HasFloatingPoint(type);
1313     if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1314 
1315     uint32_t width = ElementWidth(type);
1316     if (width != 32 && width != 64) return false;
1317 
1318     uint32_t add_op0 = inst->GetSingleWordInOperand(0);
1319     uint32_t add_op1 = inst->GetSingleWordInOperand(1);
1320     if (MergeGenericAddendSub(add_op0, add_op1, inst)) return true;
1321     return MergeGenericAddendSub(add_op1, add_op0, inst);
1322   };
1323 }
1324 
1325 // Helper function for FactorAddMuls. If |factor0_0| is the same as |factor1_0|,
1326 // generate |factor0_0| * (|factor0_1| + |factor1_1|).
FactorAddMulsOpnds(uint32_t factor0_0,uint32_t factor0_1,uint32_t factor1_0,uint32_t factor1_1,Instruction * inst)1327 bool FactorAddMulsOpnds(uint32_t factor0_0, uint32_t factor0_1,
1328                         uint32_t factor1_0, uint32_t factor1_1,
1329                         Instruction* inst) {
1330   IRContext* context = inst->context();
1331   if (factor0_0 != factor1_0) return false;
1332   InstructionBuilder ir_builder(
1333       context, inst,
1334       IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
1335   Instruction* new_add_inst = ir_builder.AddBinaryOp(
1336       inst->type_id(), inst->opcode(), factor0_1, factor1_1);
1337   inst->SetOpcode(inst->opcode() == SpvOpFAdd ? SpvOpFMul : SpvOpIMul);
1338   inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {factor0_0}},
1339                        {SPV_OPERAND_TYPE_ID, {new_add_inst->result_id()}}});
1340   context->UpdateDefUse(inst);
1341   return true;
1342 }
1343 
1344 // Perform the following factoring identity, handling all operand order
1345 // combinations: (a * b) + (a * c) = a * (b + c)
FactorAddMuls()1346 FoldingRule FactorAddMuls() {
1347   return [](IRContext* context, Instruction* inst,
1348             const std::vector<const analysis::Constant*>&) {
1349     assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
1350     const analysis::Type* type =
1351         context->get_type_mgr()->GetType(inst->type_id());
1352     bool uses_float = HasFloatingPoint(type);
1353     if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1354 
1355     analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1356     uint32_t add_op0 = inst->GetSingleWordInOperand(0);
1357     Instruction* add_op0_inst = def_use_mgr->GetDef(add_op0);
1358     if (add_op0_inst->opcode() != SpvOpFMul &&
1359         add_op0_inst->opcode() != SpvOpIMul)
1360       return false;
1361     uint32_t add_op1 = inst->GetSingleWordInOperand(1);
1362     Instruction* add_op1_inst = def_use_mgr->GetDef(add_op1);
1363     if (add_op1_inst->opcode() != SpvOpFMul &&
1364         add_op1_inst->opcode() != SpvOpIMul)
1365       return false;
1366 
1367     // Only perform this optimization if both of the muls only have one use.
1368     // Otherwise this is a deoptimization in size and performance.
1369     if (def_use_mgr->NumUses(add_op0_inst) > 1) return false;
1370     if (def_use_mgr->NumUses(add_op1_inst) > 1) return false;
1371 
1372     if (add_op0_inst->opcode() == SpvOpFMul &&
1373         (!add_op0_inst->IsFloatingPointFoldingAllowed() ||
1374          !add_op1_inst->IsFloatingPointFoldingAllowed()))
1375       return false;
1376 
1377     for (int i = 0; i < 2; i++) {
1378       for (int j = 0; j < 2; j++) {
1379         // Check if operand i in add_op0_inst matches operand j in add_op1_inst.
1380         if (FactorAddMulsOpnds(add_op0_inst->GetSingleWordInOperand(i),
1381                                add_op0_inst->GetSingleWordInOperand(1 - i),
1382                                add_op1_inst->GetSingleWordInOperand(j),
1383                                add_op1_inst->GetSingleWordInOperand(1 - j),
1384                                inst))
1385           return true;
1386       }
1387     }
1388     return false;
1389   };
1390 }
1391 
IntMultipleBy1()1392 FoldingRule IntMultipleBy1() {
1393   return [](IRContext*, Instruction* inst,
1394             const std::vector<const analysis::Constant*>& constants) {
1395     assert(inst->opcode() == SpvOpIMul && "Wrong opcode.  Should be OpIMul.");
1396     for (uint32_t i = 0; i < 2; i++) {
1397       if (constants[i] == nullptr) {
1398         continue;
1399       }
1400       const analysis::IntConstant* int_constant = constants[i]->AsIntConstant();
1401       if (int_constant) {
1402         uint32_t width = ElementWidth(int_constant->type());
1403         if (width != 32 && width != 64) return false;
1404         bool is_one = (width == 32) ? int_constant->GetU32BitValue() == 1u
1405                                     : int_constant->GetU64BitValue() == 1ull;
1406         if (is_one) {
1407           inst->SetOpcode(SpvOpCopyObject);
1408           inst->SetInOperands(
1409               {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1 - i)}}});
1410           return true;
1411         }
1412       }
1413     }
1414     return false;
1415   };
1416 }
1417 
CompositeConstructFeedingExtract()1418 FoldingRule CompositeConstructFeedingExtract() {
1419   return [](IRContext* context, Instruction* inst,
1420             const std::vector<const analysis::Constant*>&) {
1421     // If the input to an OpCompositeExtract is an OpCompositeConstruct,
1422     // then we can simply use the appropriate element in the construction.
1423     assert(inst->opcode() == SpvOpCompositeExtract &&
1424            "Wrong opcode.  Should be OpCompositeExtract.");
1425     analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1426     analysis::TypeManager* type_mgr = context->get_type_mgr();
1427 
1428     // If there are no index operands, then this rule cannot do anything.
1429     if (inst->NumInOperands() <= 1) {
1430       return false;
1431     }
1432 
1433     uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1434     Instruction* cinst = def_use_mgr->GetDef(cid);
1435 
1436     if (cinst->opcode() != SpvOpCompositeConstruct) {
1437       return false;
1438     }
1439 
1440     std::vector<Operand> operands;
1441     analysis::Type* composite_type = type_mgr->GetType(cinst->type_id());
1442     if (composite_type->AsVector() == nullptr) {
1443       // Get the element being extracted from the OpCompositeConstruct
1444       // Since it is not a vector, it is simple to extract the single element.
1445       uint32_t element_index = inst->GetSingleWordInOperand(1);
1446       uint32_t element_id = cinst->GetSingleWordInOperand(element_index);
1447       operands.push_back({SPV_OPERAND_TYPE_ID, {element_id}});
1448 
1449       // Add the remaining indices for extraction.
1450       for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
1451         operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
1452                             {inst->GetSingleWordInOperand(i)}});
1453       }
1454 
1455     } else {
1456       // With vectors we have to handle the case where it is concatenating
1457       // vectors.
1458       assert(inst->NumInOperands() == 2 &&
1459              "Expecting a vector of scalar values.");
1460 
1461       uint32_t element_index = inst->GetSingleWordInOperand(1);
1462       for (uint32_t construct_index = 0;
1463            construct_index < cinst->NumInOperands(); ++construct_index) {
1464         uint32_t element_id = cinst->GetSingleWordInOperand(construct_index);
1465         Instruction* element_def = def_use_mgr->GetDef(element_id);
1466         analysis::Vector* element_type =
1467             type_mgr->GetType(element_def->type_id())->AsVector();
1468         if (element_type) {
1469           uint32_t vector_size = element_type->element_count();
1470           if (vector_size < element_index) {
1471             // The element we want comes after this vector.
1472             element_index -= vector_size;
1473           } else {
1474             // We want an element of this vector.
1475             operands.push_back({SPV_OPERAND_TYPE_ID, {element_id}});
1476             operands.push_back(
1477                 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {element_index}});
1478             break;
1479           }
1480         } else {
1481           if (element_index == 0) {
1482             // This is a scalar, and we this is the element we are extracting.
1483             operands.push_back({SPV_OPERAND_TYPE_ID, {element_id}});
1484             break;
1485           } else {
1486             // Skip over this scalar value.
1487             --element_index;
1488           }
1489         }
1490       }
1491     }
1492 
1493     // If there were no extra indices, then we have the final object.  No need
1494     // to extract even more.
1495     if (operands.size() == 1) {
1496       inst->SetOpcode(SpvOpCopyObject);
1497     }
1498 
1499     inst->SetInOperands(std::move(operands));
1500     return true;
1501   };
1502 }
1503 
1504 // If the OpCompositeConstruct is simply putting back together elements that
1505 // where extracted from the same source, we can simply reuse the source.
1506 //
1507 // This is a common code pattern because of the way that scalar replacement
1508 // works.
CompositeExtractFeedingConstruct(IRContext * context,Instruction * inst,const std::vector<const analysis::Constant * > &)1509 bool CompositeExtractFeedingConstruct(
1510     IRContext* context, Instruction* inst,
1511     const std::vector<const analysis::Constant*>&) {
1512   assert(inst->opcode() == SpvOpCompositeConstruct &&
1513          "Wrong opcode.  Should be OpCompositeConstruct.");
1514   analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1515   uint32_t original_id = 0;
1516 
1517   if (inst->NumInOperands() == 0) {
1518     // The struct being constructed has no members.
1519     return false;
1520   }
1521 
1522   // Check each element to make sure they are:
1523   // - extractions
1524   // - extracting the same position they are inserting
1525   // - all extract from the same id.
1526   for (uint32_t i = 0; i < inst->NumInOperands(); ++i) {
1527     const uint32_t element_id = inst->GetSingleWordInOperand(i);
1528     Instruction* element_inst = def_use_mgr->GetDef(element_id);
1529 
1530     if (element_inst->opcode() != SpvOpCompositeExtract) {
1531       return false;
1532     }
1533 
1534     if (element_inst->NumInOperands() != 2) {
1535       return false;
1536     }
1537 
1538     if (element_inst->GetSingleWordInOperand(1) != i) {
1539       return false;
1540     }
1541 
1542     if (i == 0) {
1543       original_id =
1544           element_inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1545     } else if (original_id !=
1546                element_inst->GetSingleWordInOperand(kExtractCompositeIdInIdx)) {
1547       return false;
1548     }
1549   }
1550 
1551   // The last check it to see that the object being extracted from is the
1552   // correct type.
1553   Instruction* original_inst = def_use_mgr->GetDef(original_id);
1554   if (original_inst->type_id() != inst->type_id()) {
1555     return false;
1556   }
1557 
1558   // Simplify by using the original object.
1559   inst->SetOpcode(SpvOpCopyObject);
1560   inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {original_id}}});
1561   return true;
1562 }
1563 
InsertFeedingExtract()1564 FoldingRule InsertFeedingExtract() {
1565   return [](IRContext* context, Instruction* inst,
1566             const std::vector<const analysis::Constant*>&) {
1567     assert(inst->opcode() == SpvOpCompositeExtract &&
1568            "Wrong opcode.  Should be OpCompositeExtract.");
1569     analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1570     uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1571     Instruction* cinst = def_use_mgr->GetDef(cid);
1572 
1573     if (cinst->opcode() != SpvOpCompositeInsert) {
1574       return false;
1575     }
1576 
1577     // Find the first position where the list of insert and extract indicies
1578     // differ, if at all.
1579     uint32_t i;
1580     for (i = 1; i < inst->NumInOperands(); ++i) {
1581       if (i + 1 >= cinst->NumInOperands()) {
1582         break;
1583       }
1584 
1585       if (inst->GetSingleWordInOperand(i) !=
1586           cinst->GetSingleWordInOperand(i + 1)) {
1587         break;
1588       }
1589     }
1590 
1591     // We are extracting the element that was inserted.
1592     if (i == inst->NumInOperands() && i + 1 == cinst->NumInOperands()) {
1593       inst->SetOpcode(SpvOpCopyObject);
1594       inst->SetInOperands(
1595           {{SPV_OPERAND_TYPE_ID,
1596             {cinst->GetSingleWordInOperand(kInsertObjectIdInIdx)}}});
1597       return true;
1598     }
1599 
1600     // Extracting the value that was inserted along with values for the base
1601     // composite.  Cannot do anything.
1602     if (i == inst->NumInOperands()) {
1603       return false;
1604     }
1605 
1606     // Extracting an element of the value that was inserted.  Extract from
1607     // that value directly.
1608     if (i + 1 == cinst->NumInOperands()) {
1609       std::vector<Operand> operands;
1610       operands.push_back(
1611           {SPV_OPERAND_TYPE_ID,
1612            {cinst->GetSingleWordInOperand(kInsertObjectIdInIdx)}});
1613       for (; i < inst->NumInOperands(); ++i) {
1614         operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
1615                             {inst->GetSingleWordInOperand(i)}});
1616       }
1617       inst->SetInOperands(std::move(operands));
1618       return true;
1619     }
1620 
1621     // Extracting a value that is disjoint from the element being inserted.
1622     // Rewrite the extract to use the composite input to the insert.
1623     std::vector<Operand> operands;
1624     operands.push_back(
1625         {SPV_OPERAND_TYPE_ID,
1626          {cinst->GetSingleWordInOperand(kInsertCompositeIdInIdx)}});
1627     for (i = 1; i < inst->NumInOperands(); ++i) {
1628       operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
1629                           {inst->GetSingleWordInOperand(i)}});
1630     }
1631     inst->SetInOperands(std::move(operands));
1632     return true;
1633   };
1634 }
1635 
1636 // When a VectorShuffle is feeding an Extract, we can extract from one of the
1637 // operands of the VectorShuffle.  We just need to adjust the index in the
1638 // extract instruction.
VectorShuffleFeedingExtract()1639 FoldingRule VectorShuffleFeedingExtract() {
1640   return [](IRContext* context, Instruction* inst,
1641             const std::vector<const analysis::Constant*>&) {
1642     assert(inst->opcode() == SpvOpCompositeExtract &&
1643            "Wrong opcode.  Should be OpCompositeExtract.");
1644     analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1645     analysis::TypeManager* type_mgr = context->get_type_mgr();
1646     uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1647     Instruction* cinst = def_use_mgr->GetDef(cid);
1648 
1649     if (cinst->opcode() != SpvOpVectorShuffle) {
1650       return false;
1651     }
1652 
1653     // Find the size of the first vector operand of the VectorShuffle
1654     Instruction* first_input =
1655         def_use_mgr->GetDef(cinst->GetSingleWordInOperand(0));
1656     analysis::Type* first_input_type =
1657         type_mgr->GetType(first_input->type_id());
1658     assert(first_input_type->AsVector() &&
1659            "Input to vector shuffle should be vectors.");
1660     uint32_t first_input_size = first_input_type->AsVector()->element_count();
1661 
1662     // Get index of the element the vector shuffle is placing in the position
1663     // being extracted.
1664     uint32_t new_index =
1665         cinst->GetSingleWordInOperand(2 + inst->GetSingleWordInOperand(1));
1666 
1667     // Extracting an undefined value so fold this extract into an undef.
1668     const uint32_t undef_literal_value = 0xffffffff;
1669     if (new_index == undef_literal_value) {
1670       inst->SetOpcode(SpvOpUndef);
1671       inst->SetInOperands({});
1672       return true;
1673     }
1674 
1675     // Get the id of the of the vector the elemtent comes from, and update the
1676     // index if needed.
1677     uint32_t new_vector = 0;
1678     if (new_index < first_input_size) {
1679       new_vector = cinst->GetSingleWordInOperand(0);
1680     } else {
1681       new_vector = cinst->GetSingleWordInOperand(1);
1682       new_index -= first_input_size;
1683     }
1684 
1685     // Update the extract instruction.
1686     inst->SetInOperand(kExtractCompositeIdInIdx, {new_vector});
1687     inst->SetInOperand(1, {new_index});
1688     return true;
1689   };
1690 }
1691 
1692 // When an FMix with is feeding an Extract that extracts an element whose
1693 // corresponding |a| in the FMix is 0 or 1, we can extract from one of the
1694 // operands of the FMix.
FMixFeedingExtract()1695 FoldingRule FMixFeedingExtract() {
1696   return [](IRContext* context, Instruction* inst,
1697             const std::vector<const analysis::Constant*>&) {
1698     assert(inst->opcode() == SpvOpCompositeExtract &&
1699            "Wrong opcode.  Should be OpCompositeExtract.");
1700     analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1701     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1702 
1703     uint32_t composite_id =
1704         inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1705     Instruction* composite_inst = def_use_mgr->GetDef(composite_id);
1706 
1707     if (composite_inst->opcode() != SpvOpExtInst) {
1708       return false;
1709     }
1710 
1711     uint32_t inst_set_id =
1712         context->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
1713 
1714     if (composite_inst->GetSingleWordInOperand(kExtInstSetIdInIdx) !=
1715             inst_set_id ||
1716         composite_inst->GetSingleWordInOperand(kExtInstInstructionInIdx) !=
1717             GLSLstd450FMix) {
1718       return false;
1719     }
1720 
1721     // Get the |a| for the FMix instruction.
1722     uint32_t a_id = composite_inst->GetSingleWordInOperand(kFMixAIdInIdx);
1723     std::unique_ptr<Instruction> a(inst->Clone(context));
1724     a->SetInOperand(kExtractCompositeIdInIdx, {a_id});
1725     context->get_instruction_folder().FoldInstruction(a.get());
1726 
1727     if (a->opcode() != SpvOpCopyObject) {
1728       return false;
1729     }
1730 
1731     const analysis::Constant* a_const =
1732         const_mgr->FindDeclaredConstant(a->GetSingleWordInOperand(0));
1733 
1734     if (!a_const) {
1735       return false;
1736     }
1737 
1738     bool use_x = false;
1739 
1740     assert(a_const->type()->AsFloat());
1741     double element_value = a_const->GetValueAsDouble();
1742     if (element_value == 0.0) {
1743       use_x = true;
1744     } else if (element_value == 1.0) {
1745       use_x = false;
1746     } else {
1747       return false;
1748     }
1749 
1750     // Get the id of the of the vector the element comes from.
1751     uint32_t new_vector = 0;
1752     if (use_x) {
1753       new_vector = composite_inst->GetSingleWordInOperand(kFMixXIdInIdx);
1754     } else {
1755       new_vector = composite_inst->GetSingleWordInOperand(kFMixYIdInIdx);
1756     }
1757 
1758     // Update the extract instruction.
1759     inst->SetInOperand(kExtractCompositeIdInIdx, {new_vector});
1760     return true;
1761   };
1762 }
1763 
RedundantPhi()1764 FoldingRule RedundantPhi() {
1765   // An OpPhi instruction where all values are the same or the result of the phi
1766   // itself, can be replaced by the value itself.
1767   return [](IRContext*, Instruction* inst,
1768             const std::vector<const analysis::Constant*>&) {
1769     assert(inst->opcode() == SpvOpPhi && "Wrong opcode.  Should be OpPhi.");
1770 
1771     uint32_t incoming_value = 0;
1772 
1773     for (uint32_t i = 0; i < inst->NumInOperands(); i += 2) {
1774       uint32_t op_id = inst->GetSingleWordInOperand(i);
1775       if (op_id == inst->result_id()) {
1776         continue;
1777       }
1778 
1779       if (incoming_value == 0) {
1780         incoming_value = op_id;
1781       } else if (op_id != incoming_value) {
1782         // Found two possible value.  Can't simplify.
1783         return false;
1784       }
1785     }
1786 
1787     if (incoming_value == 0) {
1788       // Code looks invalid.  Don't do anything.
1789       return false;
1790     }
1791 
1792     // We have a single incoming value.  Simplify using that value.
1793     inst->SetOpcode(SpvOpCopyObject);
1794     inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {incoming_value}}});
1795     return true;
1796   };
1797 }
1798 
RedundantSelect()1799 FoldingRule RedundantSelect() {
1800   // An OpSelect instruction where both values are the same or the condition is
1801   // constant can be replaced by one of the values
1802   return [](IRContext*, Instruction* inst,
1803             const std::vector<const analysis::Constant*>& constants) {
1804     assert(inst->opcode() == SpvOpSelect &&
1805            "Wrong opcode.  Should be OpSelect.");
1806     assert(inst->NumInOperands() == 3);
1807     assert(constants.size() == 3);
1808 
1809     uint32_t true_id = inst->GetSingleWordInOperand(1);
1810     uint32_t false_id = inst->GetSingleWordInOperand(2);
1811 
1812     if (true_id == false_id) {
1813       // Both results are the same, condition doesn't matter
1814       inst->SetOpcode(SpvOpCopyObject);
1815       inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}});
1816       return true;
1817     } else if (constants[0]) {
1818       const analysis::Type* type = constants[0]->type();
1819       if (type->AsBool()) {
1820         // Scalar constant value, select the corresponding value.
1821         inst->SetOpcode(SpvOpCopyObject);
1822         if (constants[0]->AsNullConstant() ||
1823             !constants[0]->AsBoolConstant()->value()) {
1824           inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {false_id}}});
1825         } else {
1826           inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}});
1827         }
1828         return true;
1829       } else {
1830         assert(type->AsVector());
1831         if (constants[0]->AsNullConstant()) {
1832           // All values come from false id.
1833           inst->SetOpcode(SpvOpCopyObject);
1834           inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {false_id}}});
1835           return true;
1836         } else {
1837           // Convert to a vector shuffle.
1838           std::vector<Operand> ops;
1839           ops.push_back({SPV_OPERAND_TYPE_ID, {true_id}});
1840           ops.push_back({SPV_OPERAND_TYPE_ID, {false_id}});
1841           const analysis::VectorConstant* vector_const =
1842               constants[0]->AsVectorConstant();
1843           uint32_t size =
1844               static_cast<uint32_t>(vector_const->GetComponents().size());
1845           for (uint32_t i = 0; i != size; ++i) {
1846             const analysis::Constant* component =
1847                 vector_const->GetComponents()[i];
1848             if (component->AsNullConstant() ||
1849                 !component->AsBoolConstant()->value()) {
1850               // Selecting from the false vector which is the second input
1851               // vector to the shuffle. Offset the index by |size|.
1852               ops.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i + size}});
1853             } else {
1854               // Selecting from true vector which is the first input vector to
1855               // the shuffle.
1856               ops.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}});
1857             }
1858           }
1859 
1860           inst->SetOpcode(SpvOpVectorShuffle);
1861           inst->SetInOperands(std::move(ops));
1862           return true;
1863         }
1864       }
1865     }
1866 
1867     return false;
1868   };
1869 }
1870 
// Classification of a floating-point constant (scalar or vector), as produced
// by getFloatConstantKind and consumed by the redundant-arithmetic rules.
enum class FloatConstantKind {
  Unknown,  // Not a constant, or not uniformly 0.0 / 1.0.
  Zero,     // Every component is 0.0 (OpConstantNull included).
  One       // Every component is 1.0.
};
1872 
getFloatConstantKind(const analysis::Constant * constant)1873 FloatConstantKind getFloatConstantKind(const analysis::Constant* constant) {
1874   if (constant == nullptr) {
1875     return FloatConstantKind::Unknown;
1876   }
1877 
1878   assert(HasFloatingPoint(constant->type()) && "Unexpected constant type");
1879 
1880   if (constant->AsNullConstant()) {
1881     return FloatConstantKind::Zero;
1882   } else if (const analysis::VectorConstant* vc =
1883                  constant->AsVectorConstant()) {
1884     const std::vector<const analysis::Constant*>& components =
1885         vc->GetComponents();
1886     assert(!components.empty());
1887 
1888     FloatConstantKind kind = getFloatConstantKind(components[0]);
1889 
1890     for (size_t i = 1; i < components.size(); ++i) {
1891       if (getFloatConstantKind(components[i]) != kind) {
1892         return FloatConstantKind::Unknown;
1893       }
1894     }
1895 
1896     return kind;
1897   } else if (const analysis::FloatConstant* fc = constant->AsFloatConstant()) {
1898     if (fc->IsZero()) return FloatConstantKind::Zero;
1899 
1900     uint32_t width = fc->type()->AsFloat()->width();
1901     if (width != 32 && width != 64) return FloatConstantKind::Unknown;
1902 
1903     double value = (width == 64) ? fc->GetDoubleValue() : fc->GetFloatValue();
1904 
1905     if (value == 0.0) {
1906       return FloatConstantKind::Zero;
1907     } else if (value == 1.0) {
1908       return FloatConstantKind::One;
1909     } else {
1910       return FloatConstantKind::Unknown;
1911     }
1912   } else {
1913     return FloatConstantKind::Unknown;
1914   }
1915 }
1916 
RedundantFAdd()1917 FoldingRule RedundantFAdd() {
1918   return [](IRContext*, Instruction* inst,
1919             const std::vector<const analysis::Constant*>& constants) {
1920     assert(inst->opcode() == SpvOpFAdd && "Wrong opcode.  Should be OpFAdd.");
1921     assert(constants.size() == 2);
1922 
1923     if (!inst->IsFloatingPointFoldingAllowed()) {
1924       return false;
1925     }
1926 
1927     FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
1928     FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
1929 
1930     if (kind0 == FloatConstantKind::Zero || kind1 == FloatConstantKind::Zero) {
1931       inst->SetOpcode(SpvOpCopyObject);
1932       inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
1933                             {inst->GetSingleWordInOperand(
1934                                 kind0 == FloatConstantKind::Zero ? 1 : 0)}}});
1935       return true;
1936     }
1937 
1938     return false;
1939   };
1940 }
1941 
RedundantFSub()1942 FoldingRule RedundantFSub() {
1943   return [](IRContext*, Instruction* inst,
1944             const std::vector<const analysis::Constant*>& constants) {
1945     assert(inst->opcode() == SpvOpFSub && "Wrong opcode.  Should be OpFSub.");
1946     assert(constants.size() == 2);
1947 
1948     if (!inst->IsFloatingPointFoldingAllowed()) {
1949       return false;
1950     }
1951 
1952     FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
1953     FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
1954 
1955     if (kind0 == FloatConstantKind::Zero) {
1956       inst->SetOpcode(SpvOpFNegate);
1957       inst->SetInOperands(
1958           {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1)}}});
1959       return true;
1960     }
1961 
1962     if (kind1 == FloatConstantKind::Zero) {
1963       inst->SetOpcode(SpvOpCopyObject);
1964       inst->SetInOperands(
1965           {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
1966       return true;
1967     }
1968 
1969     return false;
1970   };
1971 }
1972 
RedundantFMul()1973 FoldingRule RedundantFMul() {
1974   return [](IRContext*, Instruction* inst,
1975             const std::vector<const analysis::Constant*>& constants) {
1976     assert(inst->opcode() == SpvOpFMul && "Wrong opcode.  Should be OpFMul.");
1977     assert(constants.size() == 2);
1978 
1979     if (!inst->IsFloatingPointFoldingAllowed()) {
1980       return false;
1981     }
1982 
1983     FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
1984     FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
1985 
1986     if (kind0 == FloatConstantKind::Zero || kind1 == FloatConstantKind::Zero) {
1987       inst->SetOpcode(SpvOpCopyObject);
1988       inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
1989                             {inst->GetSingleWordInOperand(
1990                                 kind0 == FloatConstantKind::Zero ? 0 : 1)}}});
1991       return true;
1992     }
1993 
1994     if (kind0 == FloatConstantKind::One || kind1 == FloatConstantKind::One) {
1995       inst->SetOpcode(SpvOpCopyObject);
1996       inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
1997                             {inst->GetSingleWordInOperand(
1998                                 kind0 == FloatConstantKind::One ? 1 : 0)}}});
1999       return true;
2000     }
2001 
2002     return false;
2003   };
2004 }
2005 
RedundantFDiv()2006 FoldingRule RedundantFDiv() {
2007   return [](IRContext*, Instruction* inst,
2008             const std::vector<const analysis::Constant*>& constants) {
2009     assert(inst->opcode() == SpvOpFDiv && "Wrong opcode.  Should be OpFDiv.");
2010     assert(constants.size() == 2);
2011 
2012     if (!inst->IsFloatingPointFoldingAllowed()) {
2013       return false;
2014     }
2015 
2016     FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
2017     FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
2018 
2019     if (kind0 == FloatConstantKind::Zero) {
2020       inst->SetOpcode(SpvOpCopyObject);
2021       inst->SetInOperands(
2022           {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
2023       return true;
2024     }
2025 
2026     if (kind1 == FloatConstantKind::One) {
2027       inst->SetOpcode(SpvOpCopyObject);
2028       inst->SetInOperands(
2029           {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
2030       return true;
2031     }
2032 
2033     return false;
2034   };
2035 }
2036 
RedundantFMix()2037 FoldingRule RedundantFMix() {
2038   return [](IRContext* context, Instruction* inst,
2039             const std::vector<const analysis::Constant*>& constants) {
2040     assert(inst->opcode() == SpvOpExtInst &&
2041            "Wrong opcode.  Should be OpExtInst.");
2042 
2043     if (!inst->IsFloatingPointFoldingAllowed()) {
2044       return false;
2045     }
2046 
2047     uint32_t instSetId =
2048         context->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
2049 
2050     if (inst->GetSingleWordInOperand(kExtInstSetIdInIdx) == instSetId &&
2051         inst->GetSingleWordInOperand(kExtInstInstructionInIdx) ==
2052             GLSLstd450FMix) {
2053       assert(constants.size() == 5);
2054 
2055       FloatConstantKind kind4 = getFloatConstantKind(constants[4]);
2056 
2057       if (kind4 == FloatConstantKind::Zero || kind4 == FloatConstantKind::One) {
2058         inst->SetOpcode(SpvOpCopyObject);
2059         inst->SetInOperands(
2060             {{SPV_OPERAND_TYPE_ID,
2061               {inst->GetSingleWordInOperand(kind4 == FloatConstantKind::Zero
2062                                                 ? kFMixXIdInIdx
2063                                                 : kFMixYIdInIdx)}}});
2064         return true;
2065       }
2066     }
2067 
2068     return false;
2069   };
2070 }
2071 
2072 // This rule handles addition of zero for integers.
RedundantIAdd()2073 FoldingRule RedundantIAdd() {
2074   return [](IRContext* context, Instruction* inst,
2075             const std::vector<const analysis::Constant*>& constants) {
2076     assert(inst->opcode() == SpvOpIAdd && "Wrong opcode. Should be OpIAdd.");
2077 
2078     uint32_t operand = std::numeric_limits<uint32_t>::max();
2079     const analysis::Type* operand_type = nullptr;
2080     if (constants[0] && constants[0]->IsZero()) {
2081       operand = inst->GetSingleWordInOperand(1);
2082       operand_type = constants[0]->type();
2083     } else if (constants[1] && constants[1]->IsZero()) {
2084       operand = inst->GetSingleWordInOperand(0);
2085       operand_type = constants[1]->type();
2086     }
2087 
2088     if (operand != std::numeric_limits<uint32_t>::max()) {
2089       const analysis::Type* inst_type =
2090           context->get_type_mgr()->GetType(inst->type_id());
2091       if (inst_type->IsSame(operand_type)) {
2092         inst->SetOpcode(SpvOpCopyObject);
2093       } else {
2094         inst->SetOpcode(SpvOpBitcast);
2095       }
2096       inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {operand}}});
2097       return true;
2098     }
2099     return false;
2100   };
2101 }
2102 
2103 // This rule look for a dot with a constant vector containing a single 1 and
2104 // the rest 0s.  This is the same as doing an extract.
DotProductDoingExtract()2105 FoldingRule DotProductDoingExtract() {
2106   return [](IRContext* context, Instruction* inst,
2107             const std::vector<const analysis::Constant*>& constants) {
2108     assert(inst->opcode() == SpvOpDot && "Wrong opcode.  Should be OpDot.");
2109 
2110     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
2111 
2112     if (!inst->IsFloatingPointFoldingAllowed()) {
2113       return false;
2114     }
2115 
2116     for (int i = 0; i < 2; ++i) {
2117       if (!constants[i]) {
2118         continue;
2119       }
2120 
2121       const analysis::Vector* vector_type = constants[i]->type()->AsVector();
2122       assert(vector_type && "Inputs to OpDot must be vectors.");
2123       const analysis::Float* element_type =
2124           vector_type->element_type()->AsFloat();
2125       assert(element_type && "Inputs to OpDot must be vectors of floats.");
2126       uint32_t element_width = element_type->width();
2127       if (element_width != 32 && element_width != 64) {
2128         return false;
2129       }
2130 
2131       std::vector<const analysis::Constant*> components;
2132       components = constants[i]->GetVectorComponents(const_mgr);
2133 
2134       const uint32_t kNotFound = std::numeric_limits<uint32_t>::max();
2135 
2136       uint32_t component_with_one = kNotFound;
2137       bool all_others_zero = true;
2138       for (uint32_t j = 0; j < components.size(); ++j) {
2139         const analysis::Constant* element = components[j];
2140         double value =
2141             (element_width == 32 ? element->GetFloat() : element->GetDouble());
2142         if (value == 0.0) {
2143           continue;
2144         } else if (value == 1.0) {
2145           if (component_with_one == kNotFound) {
2146             component_with_one = j;
2147           } else {
2148             component_with_one = kNotFound;
2149             break;
2150           }
2151         } else {
2152           all_others_zero = false;
2153           break;
2154         }
2155       }
2156 
2157       if (!all_others_zero || component_with_one == kNotFound) {
2158         continue;
2159       }
2160 
2161       std::vector<Operand> operands;
2162       operands.push_back(
2163           {SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1u - i)}});
2164       operands.push_back(
2165           {SPV_OPERAND_TYPE_LITERAL_INTEGER, {component_with_one}});
2166 
2167       inst->SetOpcode(SpvOpCompositeExtract);
2168       inst->SetInOperands(std::move(operands));
2169       return true;
2170     }
2171     return false;
2172   };
2173 }
2174 
2175 // If we are storing an undef, then we can remove the store.
2176 //
2177 // TODO: We can do something similar for OpImageWrite, but checking for volatile
2178 // is complicated.  Waiting to see if it is needed.
StoringUndef()2179 FoldingRule StoringUndef() {
2180   return [](IRContext* context, Instruction* inst,
2181             const std::vector<const analysis::Constant*>&) {
2182     assert(inst->opcode() == SpvOpStore && "Wrong opcode.  Should be OpStore.");
2183 
2184     analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
2185 
2186     // If this is a volatile store, the store cannot be removed.
2187     if (inst->NumInOperands() == 3) {
2188       if (inst->GetSingleWordInOperand(2) & SpvMemoryAccessVolatileMask) {
2189         return false;
2190       }
2191     }
2192 
2193     uint32_t object_id = inst->GetSingleWordInOperand(kStoreObjectInIdx);
2194     Instruction* object_inst = def_use_mgr->GetDef(object_id);
2195     if (object_inst->opcode() == SpvOpUndef) {
2196       inst->ToNop();
2197       return true;
2198     }
2199     return false;
2200   };
2201 }
2202 
VectorShuffleFeedingShuffle()2203 FoldingRule VectorShuffleFeedingShuffle() {
2204   return [](IRContext* context, Instruction* inst,
2205             const std::vector<const analysis::Constant*>&) {
2206     assert(inst->opcode() == SpvOpVectorShuffle &&
2207            "Wrong opcode.  Should be OpVectorShuffle.");
2208 
2209     analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
2210     analysis::TypeManager* type_mgr = context->get_type_mgr();
2211 
2212     Instruction* feeding_shuffle_inst =
2213         def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
2214     analysis::Vector* op0_type =
2215         type_mgr->GetType(feeding_shuffle_inst->type_id())->AsVector();
2216     uint32_t op0_length = op0_type->element_count();
2217 
2218     bool feeder_is_op0 = true;
2219     if (feeding_shuffle_inst->opcode() != SpvOpVectorShuffle) {
2220       feeding_shuffle_inst =
2221           def_use_mgr->GetDef(inst->GetSingleWordInOperand(1));
2222       feeder_is_op0 = false;
2223     }
2224 
2225     if (feeding_shuffle_inst->opcode() != SpvOpVectorShuffle) {
2226       return false;
2227     }
2228 
2229     Instruction* feeder2 =
2230         def_use_mgr->GetDef(feeding_shuffle_inst->GetSingleWordInOperand(0));
2231     analysis::Vector* feeder_op0_type =
2232         type_mgr->GetType(feeder2->type_id())->AsVector();
2233     uint32_t feeder_op0_length = feeder_op0_type->element_count();
2234 
2235     uint32_t new_feeder_id = 0;
2236     std::vector<Operand> new_operands;
2237     new_operands.resize(
2238         2, {SPV_OPERAND_TYPE_ID, {0}});  // Place holders for vector operands.
2239     const uint32_t undef_literal = 0xffffffff;
2240     for (uint32_t op = 2; op < inst->NumInOperands(); ++op) {
2241       uint32_t component_index = inst->GetSingleWordInOperand(op);
2242 
2243       // Do not interpret the undefined value literal as coming from operand 1.
2244       if (component_index != undef_literal &&
2245           feeder_is_op0 == (component_index < op0_length)) {
2246         // This component comes from the feeding_shuffle_inst.  Update
2247         // |component_index| to be the index into the operand of the feeder.
2248 
2249         // Adjust component_index to get the index into the operands of the
2250         // feeding_shuffle_inst.
2251         if (component_index >= op0_length) {
2252           component_index -= op0_length;
2253         }
2254         component_index =
2255             feeding_shuffle_inst->GetSingleWordInOperand(component_index + 2);
2256 
2257         // Check if we are using a component from the first or second operand of
2258         // the feeding instruction.
2259         if (component_index < feeder_op0_length) {
2260           if (new_feeder_id == 0) {
2261             // First time through, save the id of the operand the element comes
2262             // from.
2263             new_feeder_id = feeding_shuffle_inst->GetSingleWordInOperand(0);
2264           } else if (new_feeder_id !=
2265                      feeding_shuffle_inst->GetSingleWordInOperand(0)) {
2266             // We need both elements of the feeding_shuffle_inst, so we cannot
2267             // fold.
2268             return false;
2269           }
2270         } else {
2271           if (new_feeder_id == 0) {
2272             // First time through, save the id of the operand the element comes
2273             // from.
2274             new_feeder_id = feeding_shuffle_inst->GetSingleWordInOperand(1);
2275           } else if (new_feeder_id !=
2276                      feeding_shuffle_inst->GetSingleWordInOperand(1)) {
2277             // We need both elements of the feeding_shuffle_inst, so we cannot
2278             // fold.
2279             return false;
2280           }
2281           component_index -= feeder_op0_length;
2282         }
2283 
2284         if (!feeder_is_op0) {
2285           component_index += op0_length;
2286         }
2287       }
2288       new_operands.push_back(
2289           {SPV_OPERAND_TYPE_LITERAL_INTEGER, {component_index}});
2290     }
2291 
2292     if (new_feeder_id == 0) {
2293       analysis::ConstantManager* const_mgr = context->get_constant_mgr();
2294       const analysis::Type* type =
2295           type_mgr->GetType(feeding_shuffle_inst->type_id());
2296       const analysis::Constant* null_const = const_mgr->GetConstant(type, {});
2297       new_feeder_id =
2298           const_mgr->GetDefiningInstruction(null_const, 0)->result_id();
2299     }
2300 
2301     if (feeder_is_op0) {
2302       // If the size of the first vector operand changed then the indices
2303       // referring to the second operand need to be adjusted.
2304       Instruction* new_feeder_inst = def_use_mgr->GetDef(new_feeder_id);
2305       analysis::Type* new_feeder_type =
2306           type_mgr->GetType(new_feeder_inst->type_id());
2307       uint32_t new_op0_size = new_feeder_type->AsVector()->element_count();
2308       int32_t adjustment = op0_length - new_op0_size;
2309 
2310       if (adjustment != 0) {
2311         for (uint32_t i = 2; i < new_operands.size(); i++) {
2312           if (inst->GetSingleWordInOperand(i) >= op0_length) {
2313             new_operands[i].words[0] -= adjustment;
2314           }
2315         }
2316       }
2317 
2318       new_operands[0].words[0] = new_feeder_id;
2319       new_operands[1] = inst->GetInOperand(1);
2320     } else {
2321       new_operands[1].words[0] = new_feeder_id;
2322       new_operands[0] = inst->GetInOperand(0);
2323     }
2324 
2325     inst->SetInOperands(std::move(new_operands));
2326     return true;
2327   };
2328 }
2329 
2330 // Removes duplicate ids from the interface list of an OpEntryPoint
2331 // instruction.
RemoveRedundantOperands()2332 FoldingRule RemoveRedundantOperands() {
2333   return [](IRContext*, Instruction* inst,
2334             const std::vector<const analysis::Constant*>&) {
2335     assert(inst->opcode() == SpvOpEntryPoint &&
2336            "Wrong opcode.  Should be OpEntryPoint.");
2337     bool has_redundant_operand = false;
2338     std::unordered_set<uint32_t> seen_operands;
2339     std::vector<Operand> new_operands;
2340 
2341     new_operands.emplace_back(inst->GetOperand(0));
2342     new_operands.emplace_back(inst->GetOperand(1));
2343     new_operands.emplace_back(inst->GetOperand(2));
2344     for (uint32_t i = 3; i < inst->NumOperands(); ++i) {
2345       if (seen_operands.insert(inst->GetSingleWordOperand(i)).second) {
2346         new_operands.emplace_back(inst->GetOperand(i));
2347       } else {
2348         has_redundant_operand = true;
2349       }
2350     }
2351 
2352     if (!has_redundant_operand) {
2353       return false;
2354     }
2355 
2356     inst->SetInOperands(std::move(new_operands));
2357     return true;
2358   };
2359 }
2360 
2361 // If an image instruction's operand is a constant, updates the image operand
2362 // flag from Offset to ConstOffset.
UpdateImageOperands()2363 FoldingRule UpdateImageOperands() {
2364   return [](IRContext*, Instruction* inst,
2365             const std::vector<const analysis::Constant*>& constants) {
2366     const auto opcode = inst->opcode();
2367     (void)opcode;
2368     assert((opcode == SpvOpImageSampleImplicitLod ||
2369             opcode == SpvOpImageSampleExplicitLod ||
2370             opcode == SpvOpImageSampleDrefImplicitLod ||
2371             opcode == SpvOpImageSampleDrefExplicitLod ||
2372             opcode == SpvOpImageSampleProjImplicitLod ||
2373             opcode == SpvOpImageSampleProjExplicitLod ||
2374             opcode == SpvOpImageSampleProjDrefImplicitLod ||
2375             opcode == SpvOpImageSampleProjDrefExplicitLod ||
2376             opcode == SpvOpImageFetch || opcode == SpvOpImageGather ||
2377             opcode == SpvOpImageDrefGather || opcode == SpvOpImageRead ||
2378             opcode == SpvOpImageWrite ||
2379             opcode == SpvOpImageSparseSampleImplicitLod ||
2380             opcode == SpvOpImageSparseSampleExplicitLod ||
2381             opcode == SpvOpImageSparseSampleDrefImplicitLod ||
2382             opcode == SpvOpImageSparseSampleDrefExplicitLod ||
2383             opcode == SpvOpImageSparseSampleProjImplicitLod ||
2384             opcode == SpvOpImageSparseSampleProjExplicitLod ||
2385             opcode == SpvOpImageSparseSampleProjDrefImplicitLod ||
2386             opcode == SpvOpImageSparseSampleProjDrefExplicitLod ||
2387             opcode == SpvOpImageSparseFetch ||
2388             opcode == SpvOpImageSparseGather ||
2389             opcode == SpvOpImageSparseDrefGather ||
2390             opcode == SpvOpImageSparseRead) &&
2391            "Wrong opcode.  Should be an image instruction.");
2392 
2393     int32_t operand_index = ImageOperandsMaskInOperandIndex(inst);
2394     if (operand_index >= 0) {
2395       auto image_operands = inst->GetSingleWordInOperand(operand_index);
2396       if (image_operands & SpvImageOperandsOffsetMask) {
2397         uint32_t offset_operand_index = operand_index + 1;
2398         if (image_operands & SpvImageOperandsBiasMask) offset_operand_index++;
2399         if (image_operands & SpvImageOperandsLodMask) offset_operand_index++;
2400         if (image_operands & SpvImageOperandsGradMask)
2401           offset_operand_index += 2;
2402         assert(((image_operands & SpvImageOperandsConstOffsetMask) == 0) &&
2403                "Offset and ConstOffset may not be used together");
2404         if (offset_operand_index < inst->NumOperands()) {
2405           if (constants[offset_operand_index]) {
2406             image_operands = image_operands | SpvImageOperandsConstOffsetMask;
2407             image_operands = image_operands & ~SpvImageOperandsOffsetMask;
2408             inst->SetInOperand(operand_index, {image_operands});
2409             return true;
2410           }
2411         }
2412       }
2413     }
2414 
2415     return false;
2416   };
2417 }
2418 
2419 }  // namespace
2420 
AddFoldingRules()2421 void FoldingRules::AddFoldingRules() {
2422   // Add all folding rules to the list for the opcodes to which they apply.
2423   // Note that the order in which rules are added to the list matters. If a rule
2424   // applies to the instruction, the rest of the rules will not be attempted.
2425   // Take that into consideration.
2426   rules_[SpvOpCompositeConstruct].push_back(CompositeExtractFeedingConstruct);
2427 
2428   rules_[SpvOpCompositeExtract].push_back(InsertFeedingExtract());
2429   rules_[SpvOpCompositeExtract].push_back(CompositeConstructFeedingExtract());
2430   rules_[SpvOpCompositeExtract].push_back(VectorShuffleFeedingExtract());
2431   rules_[SpvOpCompositeExtract].push_back(FMixFeedingExtract());
2432 
2433   rules_[SpvOpDot].push_back(DotProductDoingExtract());
2434 
2435   rules_[SpvOpEntryPoint].push_back(RemoveRedundantOperands());
2436 
2437   rules_[SpvOpFAdd].push_back(RedundantFAdd());
2438   rules_[SpvOpFAdd].push_back(MergeAddNegateArithmetic());
2439   rules_[SpvOpFAdd].push_back(MergeAddAddArithmetic());
2440   rules_[SpvOpFAdd].push_back(MergeAddSubArithmetic());
2441   rules_[SpvOpFAdd].push_back(MergeGenericAddSubArithmetic());
2442   rules_[SpvOpFAdd].push_back(FactorAddMuls());
2443 
2444   rules_[SpvOpFDiv].push_back(RedundantFDiv());
2445   rules_[SpvOpFDiv].push_back(ReciprocalFDiv());
2446   rules_[SpvOpFDiv].push_back(MergeDivDivArithmetic());
2447   rules_[SpvOpFDiv].push_back(MergeDivMulArithmetic());
2448   rules_[SpvOpFDiv].push_back(MergeDivNegateArithmetic());
2449 
2450   rules_[SpvOpFMul].push_back(RedundantFMul());
2451   rules_[SpvOpFMul].push_back(MergeMulMulArithmetic());
2452   rules_[SpvOpFMul].push_back(MergeMulDivArithmetic());
2453   rules_[SpvOpFMul].push_back(MergeMulNegateArithmetic());
2454 
2455   rules_[SpvOpFNegate].push_back(MergeNegateArithmetic());
2456   rules_[SpvOpFNegate].push_back(MergeNegateAddSubArithmetic());
2457   rules_[SpvOpFNegate].push_back(MergeNegateMulDivArithmetic());
2458 
2459   rules_[SpvOpFSub].push_back(RedundantFSub());
2460   rules_[SpvOpFSub].push_back(MergeSubNegateArithmetic());
2461   rules_[SpvOpFSub].push_back(MergeSubAddArithmetic());
2462   rules_[SpvOpFSub].push_back(MergeSubSubArithmetic());
2463 
2464   rules_[SpvOpIAdd].push_back(RedundantIAdd());
2465   rules_[SpvOpIAdd].push_back(MergeAddNegateArithmetic());
2466   rules_[SpvOpIAdd].push_back(MergeAddAddArithmetic());
2467   rules_[SpvOpIAdd].push_back(MergeAddSubArithmetic());
2468   rules_[SpvOpIAdd].push_back(MergeGenericAddSubArithmetic());
2469   rules_[SpvOpIAdd].push_back(FactorAddMuls());
2470 
2471   rules_[SpvOpIMul].push_back(IntMultipleBy1());
2472   rules_[SpvOpIMul].push_back(MergeMulMulArithmetic());
2473   rules_[SpvOpIMul].push_back(MergeMulNegateArithmetic());
2474 
2475   rules_[SpvOpISub].push_back(MergeSubNegateArithmetic());
2476   rules_[SpvOpISub].push_back(MergeSubAddArithmetic());
2477   rules_[SpvOpISub].push_back(MergeSubSubArithmetic());
2478 
2479   rules_[SpvOpPhi].push_back(RedundantPhi());
2480 
2481   rules_[SpvOpSDiv].push_back(MergeDivNegateArithmetic());
2482 
2483   rules_[SpvOpSNegate].push_back(MergeNegateArithmetic());
2484   rules_[SpvOpSNegate].push_back(MergeNegateMulDivArithmetic());
2485   rules_[SpvOpSNegate].push_back(MergeNegateAddSubArithmetic());
2486 
2487   rules_[SpvOpSelect].push_back(RedundantSelect());
2488 
2489   rules_[SpvOpStore].push_back(StoringUndef());
2490 
2491   rules_[SpvOpUDiv].push_back(MergeDivNegateArithmetic());
2492 
2493   rules_[SpvOpVectorShuffle].push_back(VectorShuffleFeedingShuffle());
2494 
2495   rules_[SpvOpImageSampleImplicitLod].push_back(UpdateImageOperands());
2496   rules_[SpvOpImageSampleExplicitLod].push_back(UpdateImageOperands());
2497   rules_[SpvOpImageSampleDrefImplicitLod].push_back(UpdateImageOperands());
2498   rules_[SpvOpImageSampleDrefExplicitLod].push_back(UpdateImageOperands());
2499   rules_[SpvOpImageSampleProjImplicitLod].push_back(UpdateImageOperands());
2500   rules_[SpvOpImageSampleProjExplicitLod].push_back(UpdateImageOperands());
2501   rules_[SpvOpImageSampleProjDrefImplicitLod].push_back(UpdateImageOperands());
2502   rules_[SpvOpImageSampleProjDrefExplicitLod].push_back(UpdateImageOperands());
2503   rules_[SpvOpImageFetch].push_back(UpdateImageOperands());
2504   rules_[SpvOpImageGather].push_back(UpdateImageOperands());
2505   rules_[SpvOpImageDrefGather].push_back(UpdateImageOperands());
2506   rules_[SpvOpImageRead].push_back(UpdateImageOperands());
2507   rules_[SpvOpImageWrite].push_back(UpdateImageOperands());
2508   rules_[SpvOpImageSparseSampleImplicitLod].push_back(UpdateImageOperands());
2509   rules_[SpvOpImageSparseSampleExplicitLod].push_back(UpdateImageOperands());
2510   rules_[SpvOpImageSparseSampleDrefImplicitLod].push_back(
2511       UpdateImageOperands());
2512   rules_[SpvOpImageSparseSampleDrefExplicitLod].push_back(
2513       UpdateImageOperands());
2514   rules_[SpvOpImageSparseSampleProjImplicitLod].push_back(
2515       UpdateImageOperands());
2516   rules_[SpvOpImageSparseSampleProjExplicitLod].push_back(
2517       UpdateImageOperands());
2518   rules_[SpvOpImageSparseSampleProjDrefImplicitLod].push_back(
2519       UpdateImageOperands());
2520   rules_[SpvOpImageSparseSampleProjDrefExplicitLod].push_back(
2521       UpdateImageOperands());
2522   rules_[SpvOpImageSparseFetch].push_back(UpdateImageOperands());
2523   rules_[SpvOpImageSparseGather].push_back(UpdateImageOperands());
2524   rules_[SpvOpImageSparseDrefGather].push_back(UpdateImageOperands());
2525   rules_[SpvOpImageSparseRead].push_back(UpdateImageOperands());
2526 
2527   FeatureManager* feature_manager = context_->get_feature_mgr();
2528   // Add rules for GLSLstd450
2529   uint32_t ext_inst_glslstd450_id =
2530       feature_manager->GetExtInstImportId_GLSLstd450();
2531   if (ext_inst_glslstd450_id != 0) {
2532     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMix}].push_back(
2533         RedundantFMix());
2534   }
2535 }
2536 }  // namespace opt
2537 }  // namespace spvtools
2538