• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2018 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/opt/folding_rules.h"
16 
17 #include <limits>
18 #include <memory>
19 #include <utility>
20 
21 #include "source/latest_version_glsl_std_450_header.h"
22 #include "source/opt/ir_context.h"
23 
24 namespace spvtools {
25 namespace opt {
26 namespace {
27 
// Named in-operand positions (the "InIdx" suffix) of specific operands of
// particular instructions.
constexpr uint32_t kExtractCompositeIdInIdx = 0;
constexpr uint32_t kInsertObjectIdInIdx = 0;
constexpr uint32_t kInsertCompositeIdInIdx = 1;
constexpr uint32_t kExtInstSetIdInIdx = 0;
constexpr uint32_t kExtInstInstructionInIdx = 1;
constexpr uint32_t kFMixXIdInIdx = 2;
constexpr uint32_t kFMixYIdInIdx = 3;
constexpr uint32_t kFMixAIdInIdx = 4;
constexpr uint32_t kStoreObjectInIdx = 1;
37 
38 // Returns the element width of |type|.
ElementWidth(const analysis::Type * type)39 uint32_t ElementWidth(const analysis::Type* type) {
40   if (const analysis::Vector* vec_type = type->AsVector()) {
41     return ElementWidth(vec_type->element_type());
42   } else if (const analysis::Float* float_type = type->AsFloat()) {
43     return float_type->width();
44   } else {
45     assert(type->AsInteger());
46     return type->AsInteger()->width();
47   }
48 }
49 
50 // Returns true if |type| is Float or a vector of Float.
HasFloatingPoint(const analysis::Type * type)51 bool HasFloatingPoint(const analysis::Type* type) {
52   if (type->AsFloat()) {
53     return true;
54   } else if (const analysis::Vector* vec_type = type->AsVector()) {
55     return vec_type->element_type()->AsFloat() != nullptr;
56   }
57 
58   return false;
59 }
60 
// Reports whether |val| is acceptable as a folded result. NaN, infinite and
// subnormal values are rejected; every other classification (zero, normal)
// is accepted.
template <typename T>
bool IsValidResult(T val) {
  const int category = std::fpclassify(val);
  const bool rejected = category == FP_NAN || category == FP_INFINITE ||
                        category == FP_SUBNORMAL;
  return !rejected;
}
74 
ConstInput(const std::vector<const analysis::Constant * > & constants)75 const analysis::Constant* ConstInput(
76     const std::vector<const analysis::Constant*>& constants) {
77   return constants[0] ? constants[0] : constants[1];
78 }
79 
NonConstInput(IRContext * context,const analysis::Constant * c,Instruction * inst)80 Instruction* NonConstInput(IRContext* context, const analysis::Constant* c,
81                            Instruction* inst) {
82   uint32_t in_op = c ? 1u : 0u;
83   return context->get_def_use_mgr()->GetDef(
84       inst->GetSingleWordInOperand(in_op));
85 }
86 
87 // Returns the negation of |c|. |c| must be a 32 or 64 bit floating point
88 // constant.
NegateFloatingPointConstant(analysis::ConstantManager * const_mgr,const analysis::Constant * c)89 uint32_t NegateFloatingPointConstant(analysis::ConstantManager* const_mgr,
90                                      const analysis::Constant* c) {
91   assert(c);
92   assert(c->type()->AsFloat());
93   uint32_t width = c->type()->AsFloat()->width();
94   assert(width == 32 || width == 64);
95   std::vector<uint32_t> words;
96   if (width == 64) {
97     utils::FloatProxy<double> result(c->GetDouble() * -1.0);
98     words = result.GetWords();
99   } else {
100     utils::FloatProxy<float> result(c->GetFloat() * -1.0f);
101     words = result.GetWords();
102   }
103 
104   const analysis::Constant* negated_const =
105       const_mgr->GetConstant(c->type(), std::move(words));
106   return const_mgr->GetDefiningInstruction(negated_const)->result_id();
107 }
108 
// Splits the 64-bit value |val| into two 32-bit words, least-significant word
// first.
std::vector<uint32_t> ExtractInts(uint64_t val) {
  const uint32_t low_word = static_cast<uint32_t>(val);
  const uint32_t high_word = static_cast<uint32_t>(val >> 32);
  return std::vector<uint32_t>{low_word, high_word};
}
115 
116 // Negates the integer constant |c|. Returns the id of the defining instruction.
NegateIntegerConstant(analysis::ConstantManager * const_mgr,const analysis::Constant * c)117 uint32_t NegateIntegerConstant(analysis::ConstantManager* const_mgr,
118                                const analysis::Constant* c) {
119   assert(c);
120   assert(c->type()->AsInteger());
121   uint32_t width = c->type()->AsInteger()->width();
122   assert(width == 32 || width == 64);
123   std::vector<uint32_t> words;
124   if (width == 64) {
125     uint64_t uval = static_cast<uint64_t>(0 - c->GetU64());
126     words = ExtractInts(uval);
127   } else {
128     words.push_back(static_cast<uint32_t>(0 - c->GetU32()));
129   }
130 
131   const analysis::Constant* negated_const =
132       const_mgr->GetConstant(c->type(), std::move(words));
133   return const_mgr->GetDefiningInstruction(negated_const)->result_id();
134 }
135 
136 // Negates the vector constant |c|. Returns the id of the defining instruction.
NegateVectorConstant(analysis::ConstantManager * const_mgr,const analysis::Constant * c)137 uint32_t NegateVectorConstant(analysis::ConstantManager* const_mgr,
138                               const analysis::Constant* c) {
139   assert(const_mgr && c);
140   assert(c->type()->AsVector());
141   if (c->AsNullConstant()) {
142     // 0.0 vs -0.0 shouldn't matter.
143     return const_mgr->GetDefiningInstruction(c)->result_id();
144   } else {
145     const analysis::Type* component_type =
146         c->AsVectorConstant()->component_type();
147     std::vector<uint32_t> words;
148     for (auto& comp : c->AsVectorConstant()->GetComponents()) {
149       if (component_type->AsFloat()) {
150         words.push_back(NegateFloatingPointConstant(const_mgr, comp));
151       } else {
152         assert(component_type->AsInteger());
153         words.push_back(NegateIntegerConstant(const_mgr, comp));
154       }
155     }
156 
157     const analysis::Constant* negated_const =
158         const_mgr->GetConstant(c->type(), std::move(words));
159     return const_mgr->GetDefiningInstruction(negated_const)->result_id();
160   }
161 }
162 
163 // Negates |c|. Returns the id of the defining instruction.
NegateConstant(analysis::ConstantManager * const_mgr,const analysis::Constant * c)164 uint32_t NegateConstant(analysis::ConstantManager* const_mgr,
165                         const analysis::Constant* c) {
166   if (c->type()->AsVector()) {
167     return NegateVectorConstant(const_mgr, c);
168   } else if (c->type()->AsFloat()) {
169     return NegateFloatingPointConstant(const_mgr, c);
170   } else {
171     assert(c->type()->AsInteger());
172     return NegateIntegerConstant(const_mgr, c);
173   }
174 }
175 
176 // Takes the reciprocal of |c|. |c|'s type must be Float or a vector of Float.
177 // Returns 0 if the reciprocal is NaN, infinite or subnormal.
Reciprocal(analysis::ConstantManager * const_mgr,const analysis::Constant * c)178 uint32_t Reciprocal(analysis::ConstantManager* const_mgr,
179                     const analysis::Constant* c) {
180   assert(const_mgr && c);
181   assert(c->type()->AsFloat());
182 
183   uint32_t width = c->type()->AsFloat()->width();
184   assert(width == 32 || width == 64);
185   std::vector<uint32_t> words;
186   if (width == 64) {
187     spvtools::utils::FloatProxy<double> result(1.0 / c->GetDouble());
188     if (!IsValidResult(result.getAsFloat())) return 0;
189     words = result.GetWords();
190   } else {
191     spvtools::utils::FloatProxy<float> result(1.0f / c->GetFloat());
192     if (!IsValidResult(result.getAsFloat())) return 0;
193     words = result.GetWords();
194   }
195 
196   const analysis::Constant* negated_const =
197       const_mgr->GetConstant(c->type(), std::move(words));
198   return const_mgr->GetDefiningInstruction(negated_const)->result_id();
199 }
200 
201 // Replaces fdiv where second operand is constant with fmul.
ReciprocalFDiv()202 FoldingRule ReciprocalFDiv() {
203   return [](IRContext* context, Instruction* inst,
204             const std::vector<const analysis::Constant*>& constants) {
205     assert(inst->opcode() == SpvOpFDiv);
206     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
207     const analysis::Type* type =
208         context->get_type_mgr()->GetType(inst->type_id());
209     if (!inst->IsFloatingPointFoldingAllowed()) return false;
210 
211     uint32_t width = ElementWidth(type);
212     if (width != 32 && width != 64) return false;
213 
214     if (constants[1] != nullptr) {
215       uint32_t id = 0;
216       if (const analysis::VectorConstant* vector_const =
217               constants[1]->AsVectorConstant()) {
218         std::vector<uint32_t> neg_ids;
219         for (auto& comp : vector_const->GetComponents()) {
220           id = Reciprocal(const_mgr, comp);
221           if (id == 0) return false;
222           neg_ids.push_back(id);
223         }
224         const analysis::Constant* negated_const =
225             const_mgr->GetConstant(constants[1]->type(), std::move(neg_ids));
226         id = const_mgr->GetDefiningInstruction(negated_const)->result_id();
227       } else if (constants[1]->AsFloatConstant()) {
228         id = Reciprocal(const_mgr, constants[1]);
229         if (id == 0) return false;
230       } else {
231         // Don't fold a null constant.
232         return false;
233       }
234       inst->SetOpcode(SpvOpFMul);
235       inst->SetInOperands(
236           {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0u)}},
237            {SPV_OPERAND_TYPE_ID, {id}}});
238       return true;
239     }
240 
241     return false;
242   };
243 }
244 
245 // Elides consecutive negate instructions.
MergeNegateArithmetic()246 FoldingRule MergeNegateArithmetic() {
247   return [](IRContext* context, Instruction* inst,
248             const std::vector<const analysis::Constant*>& constants) {
249     assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate);
250     (void)constants;
251     const analysis::Type* type =
252         context->get_type_mgr()->GetType(inst->type_id());
253     if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
254       return false;
255 
256     Instruction* op_inst =
257         context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
258     if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
259       return false;
260 
261     if (op_inst->opcode() == inst->opcode()) {
262       // Elide negates.
263       inst->SetOpcode(SpvOpCopyObject);
264       inst->SetInOperands(
265           {{SPV_OPERAND_TYPE_ID, {op_inst->GetSingleWordInOperand(0u)}}});
266       return true;
267     }
268 
269     return false;
270   };
271 }
272 
273 // Merges negate into a mul or div operation if that operation contains a
274 // constant operand.
275 // Cases:
276 // -(x * 2) = x * -2
277 // -(2 * x) = x * -2
278 // -(x / 2) = x / -2
279 // -(2 / x) = -2 / x
MergeNegateMulDivArithmetic()280 FoldingRule MergeNegateMulDivArithmetic() {
281   return [](IRContext* context, Instruction* inst,
282             const std::vector<const analysis::Constant*>& constants) {
283     assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate);
284     (void)constants;
285     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
286     const analysis::Type* type =
287         context->get_type_mgr()->GetType(inst->type_id());
288     if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
289       return false;
290 
291     Instruction* op_inst =
292         context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
293     if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
294       return false;
295 
296     uint32_t width = ElementWidth(type);
297     if (width != 32 && width != 64) return false;
298 
299     SpvOp opcode = op_inst->opcode();
300     if (opcode == SpvOpFMul || opcode == SpvOpFDiv || opcode == SpvOpIMul ||
301         opcode == SpvOpSDiv || opcode == SpvOpUDiv) {
302       std::vector<const analysis::Constant*> op_constants =
303           const_mgr->GetOperandConstants(op_inst);
304       // Merge negate into mul or div if one operand is constant.
305       if (op_constants[0] || op_constants[1]) {
306         bool zero_is_variable = op_constants[0] == nullptr;
307         const analysis::Constant* c = ConstInput(op_constants);
308         uint32_t neg_id = NegateConstant(const_mgr, c);
309         uint32_t non_const_id = zero_is_variable
310                                     ? op_inst->GetSingleWordInOperand(0u)
311                                     : op_inst->GetSingleWordInOperand(1u);
312         // Change this instruction to a mul/div.
313         inst->SetOpcode(op_inst->opcode());
314         if (opcode == SpvOpFDiv || opcode == SpvOpUDiv || opcode == SpvOpSDiv) {
315           uint32_t op0 = zero_is_variable ? non_const_id : neg_id;
316           uint32_t op1 = zero_is_variable ? neg_id : non_const_id;
317           inst->SetInOperands(
318               {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}});
319         } else {
320           inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
321                                {SPV_OPERAND_TYPE_ID, {neg_id}}});
322         }
323         return true;
324       }
325     }
326 
327     return false;
328   };
329 }
330 
331 // Merges negate into a add or sub operation if that operation contains a
332 // constant operand.
333 // Cases:
334 // -(x + 2) = -2 - x
335 // -(2 + x) = -2 - x
336 // -(x - 2) = 2 - x
337 // -(2 - x) = x - 2
MergeNegateAddSubArithmetic()338 FoldingRule MergeNegateAddSubArithmetic() {
339   return [](IRContext* context, Instruction* inst,
340             const std::vector<const analysis::Constant*>& constants) {
341     assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate);
342     (void)constants;
343     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
344     const analysis::Type* type =
345         context->get_type_mgr()->GetType(inst->type_id());
346     if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
347       return false;
348 
349     Instruction* op_inst =
350         context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
351     if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
352       return false;
353 
354     uint32_t width = ElementWidth(type);
355     if (width != 32 && width != 64) return false;
356 
357     if (op_inst->opcode() == SpvOpFAdd || op_inst->opcode() == SpvOpFSub ||
358         op_inst->opcode() == SpvOpIAdd || op_inst->opcode() == SpvOpISub) {
359       std::vector<const analysis::Constant*> op_constants =
360           const_mgr->GetOperandConstants(op_inst);
361       if (op_constants[0] || op_constants[1]) {
362         bool zero_is_variable = op_constants[0] == nullptr;
363         bool is_add = (op_inst->opcode() == SpvOpFAdd) ||
364                       (op_inst->opcode() == SpvOpIAdd);
365         bool swap_operands = !is_add || zero_is_variable;
366         bool negate_const = is_add;
367         const analysis::Constant* c = ConstInput(op_constants);
368         uint32_t const_id = 0;
369         if (negate_const) {
370           const_id = NegateConstant(const_mgr, c);
371         } else {
372           const_id = zero_is_variable ? op_inst->GetSingleWordInOperand(1u)
373                                       : op_inst->GetSingleWordInOperand(0u);
374         }
375 
376         // Swap operands if necessary and make the instruction a subtraction.
377         uint32_t op0 =
378             zero_is_variable ? op_inst->GetSingleWordInOperand(0u) : const_id;
379         uint32_t op1 =
380             zero_is_variable ? const_id : op_inst->GetSingleWordInOperand(1u);
381         if (swap_operands) std::swap(op0, op1);
382         inst->SetOpcode(HasFloatingPoint(type) ? SpvOpFSub : SpvOpISub);
383         inst->SetInOperands(
384             {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}});
385         return true;
386       }
387     }
388 
389     return false;
390   };
391 }
392 
393 // Returns true if |c| has a zero element.
HasZero(const analysis::Constant * c)394 bool HasZero(const analysis::Constant* c) {
395   if (c->AsNullConstant()) {
396     return true;
397   }
398   if (const analysis::VectorConstant* vec_const = c->AsVectorConstant()) {
399     for (auto& comp : vec_const->GetComponents())
400       if (HasZero(comp)) return true;
401   } else {
402     assert(c->AsScalarConstant());
403     return c->AsScalarConstant()->IsZero();
404   }
405 
406   return false;
407 }
408 
409 // Performs |input1| |opcode| |input2| and returns the merged constant result
410 // id. Returns 0 if the result is not a valid value. The input types must be
411 // Float.
PerformFloatingPointOperation(analysis::ConstantManager * const_mgr,SpvOp opcode,const analysis::Constant * input1,const analysis::Constant * input2)412 uint32_t PerformFloatingPointOperation(analysis::ConstantManager* const_mgr,
413                                        SpvOp opcode,
414                                        const analysis::Constant* input1,
415                                        const analysis::Constant* input2) {
416   const analysis::Type* type = input1->type();
417   assert(type->AsFloat());
418   uint32_t width = type->AsFloat()->width();
419   assert(width == 32 || width == 64);
420   std::vector<uint32_t> words;
421 #define FOLD_OP(op)                                                          \
422   if (width == 64) {                                                         \
423     utils::FloatProxy<double> val =                                          \
424         input1->GetDouble() op input2->GetDouble();                          \
425     double dval = val.getAsFloat();                                          \
426     if (!IsValidResult(dval)) return 0;                                      \
427     words = val.GetWords();                                                  \
428   } else {                                                                   \
429     utils::FloatProxy<float> val = input1->GetFloat() op input2->GetFloat(); \
430     float fval = val.getAsFloat();                                           \
431     if (!IsValidResult(fval)) return 0;                                      \
432     words = val.GetWords();                                                  \
433   }
434   switch (opcode) {
435     case SpvOpFMul:
436       FOLD_OP(*);
437       break;
438     case SpvOpFDiv:
439       if (HasZero(input2)) return 0;
440       FOLD_OP(/);
441       break;
442     case SpvOpFAdd:
443       FOLD_OP(+);
444       break;
445     case SpvOpFSub:
446       FOLD_OP(-);
447       break;
448     default:
449       assert(false && "Unexpected operation");
450       break;
451   }
452 #undef FOLD_OP
453   const analysis::Constant* merged_const = const_mgr->GetConstant(type, words);
454   return const_mgr->GetDefiningInstruction(merged_const)->result_id();
455 }
456 
457 // Performs |input1| |opcode| |input2| and returns the merged constant result
458 // id. Returns 0 if the result is not a valid value. The input types must be
459 // Integers.
PerformIntegerOperation(analysis::ConstantManager * const_mgr,SpvOp opcode,const analysis::Constant * input1,const analysis::Constant * input2)460 uint32_t PerformIntegerOperation(analysis::ConstantManager* const_mgr,
461                                  SpvOp opcode, const analysis::Constant* input1,
462                                  const analysis::Constant* input2) {
463   assert(input1->type()->AsInteger());
464   const analysis::Integer* type = input1->type()->AsInteger();
465   uint32_t width = type->AsInteger()->width();
466   assert(width == 32 || width == 64);
467   std::vector<uint32_t> words;
468 #define FOLD_OP(op)                                        \
469   if (width == 64) {                                       \
470     if (type->IsSigned()) {                                \
471       int64_t val = input1->GetS64() op input2->GetS64();  \
472       words = ExtractInts(static_cast<uint64_t>(val));     \
473     } else {                                               \
474       uint64_t val = input1->GetU64() op input2->GetU64(); \
475       words = ExtractInts(val);                            \
476     }                                                      \
477   } else {                                                 \
478     if (type->IsSigned()) {                                \
479       int32_t val = input1->GetS32() op input2->GetS32();  \
480       words.push_back(static_cast<uint32_t>(val));         \
481     } else {                                               \
482       uint32_t val = input1->GetU32() op input2->GetU32(); \
483       words.push_back(val);                                \
484     }                                                      \
485   }
486   switch (opcode) {
487     case SpvOpIMul:
488       FOLD_OP(*);
489       break;
490     case SpvOpSDiv:
491     case SpvOpUDiv:
492       assert(false && "Should not merge integer division");
493       break;
494     case SpvOpIAdd:
495       FOLD_OP(+);
496       break;
497     case SpvOpISub:
498       FOLD_OP(-);
499       break;
500     default:
501       assert(false && "Unexpected operation");
502       break;
503   }
504 #undef FOLD_OP
505   const analysis::Constant* merged_const = const_mgr->GetConstant(type, words);
506   return const_mgr->GetDefiningInstruction(merged_const)->result_id();
507 }
508 
509 // Performs |input1| |opcode| |input2| and returns the merged constant result
510 // id. Returns 0 if the result is not a valid value. The input types must be
511 // Integers, Floats or Vectors of such.
PerformOperation(analysis::ConstantManager * const_mgr,SpvOp opcode,const analysis::Constant * input1,const analysis::Constant * input2)512 uint32_t PerformOperation(analysis::ConstantManager* const_mgr, SpvOp opcode,
513                           const analysis::Constant* input1,
514                           const analysis::Constant* input2) {
515   assert(input1 && input2);
516   assert(input1->type() == input2->type());
517   const analysis::Type* type = input1->type();
518   std::vector<uint32_t> words;
519   if (const analysis::Vector* vector_type = type->AsVector()) {
520     const analysis::Type* ele_type = vector_type->element_type();
521     for (uint32_t i = 0; i != vector_type->element_count(); ++i) {
522       uint32_t id = 0;
523 
524       const analysis::Constant* input1_comp = nullptr;
525       if (const analysis::VectorConstant* input1_vector =
526               input1->AsVectorConstant()) {
527         input1_comp = input1_vector->GetComponents()[i];
528       } else {
529         assert(input1->AsNullConstant());
530         input1_comp = const_mgr->GetConstant(ele_type, {});
531       }
532 
533       const analysis::Constant* input2_comp = nullptr;
534       if (const analysis::VectorConstant* input2_vector =
535               input2->AsVectorConstant()) {
536         input2_comp = input2_vector->GetComponents()[i];
537       } else {
538         assert(input2->AsNullConstant());
539         input2_comp = const_mgr->GetConstant(ele_type, {});
540       }
541 
542       if (ele_type->AsFloat()) {
543         id = PerformFloatingPointOperation(const_mgr, opcode, input1_comp,
544                                            input2_comp);
545       } else {
546         assert(ele_type->AsInteger());
547         id = PerformIntegerOperation(const_mgr, opcode, input1_comp,
548                                      input2_comp);
549       }
550       if (id == 0) return 0;
551       words.push_back(id);
552     }
553     const analysis::Constant* merged_const =
554         const_mgr->GetConstant(type, words);
555     return const_mgr->GetDefiningInstruction(merged_const)->result_id();
556   } else if (type->AsFloat()) {
557     return PerformFloatingPointOperation(const_mgr, opcode, input1, input2);
558   } else {
559     assert(type->AsInteger());
560     return PerformIntegerOperation(const_mgr, opcode, input1, input2);
561   }
562 }
563 
564 // Merges consecutive multiplies where each contains one constant operand.
565 // Cases:
566 // 2 * (x * 2) = x * 4
567 // 2 * (2 * x) = x * 4
568 // (x * 2) * 2 = x * 4
569 // (2 * x) * 2 = x * 4
MergeMulMulArithmetic()570 FoldingRule MergeMulMulArithmetic() {
571   return [](IRContext* context, Instruction* inst,
572             const std::vector<const analysis::Constant*>& constants) {
573     assert(inst->opcode() == SpvOpFMul || inst->opcode() == SpvOpIMul);
574     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
575     const analysis::Type* type =
576         context->get_type_mgr()->GetType(inst->type_id());
577     if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
578       return false;
579 
580     uint32_t width = ElementWidth(type);
581     if (width != 32 && width != 64) return false;
582 
583     // Determine the constant input and the variable input in |inst|.
584     const analysis::Constant* const_input1 = ConstInput(constants);
585     if (!const_input1) return false;
586     Instruction* other_inst = NonConstInput(context, constants[0], inst);
587     if (HasFloatingPoint(type) && !other_inst->IsFloatingPointFoldingAllowed())
588       return false;
589 
590     if (other_inst->opcode() == inst->opcode()) {
591       std::vector<const analysis::Constant*> other_constants =
592           const_mgr->GetOperandConstants(other_inst);
593       const analysis::Constant* const_input2 = ConstInput(other_constants);
594       if (!const_input2) return false;
595 
596       bool other_first_is_variable = other_constants[0] == nullptr;
597       uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
598                                             const_input1, const_input2);
599       if (merged_id == 0) return false;
600 
601       uint32_t non_const_id = other_first_is_variable
602                                   ? other_inst->GetSingleWordInOperand(0u)
603                                   : other_inst->GetSingleWordInOperand(1u);
604       inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
605                            {SPV_OPERAND_TYPE_ID, {merged_id}}});
606       return true;
607     }
608 
609     return false;
610   };
611 }
612 
613 // Merges divides into subsequent multiplies if each instruction contains one
614 // constant operand. Does not support integer operations.
615 // Cases:
616 // 2 * (x / 2) = x * 1
617 // 2 * (2 / x) = 4 / x
618 // (x / 2) * 2 = x * 1
619 // (2 / x) * 2 = 4 / x
620 // (y / x) * x = y
621 // x * (y / x) = y
MergeMulDivArithmetic()622 FoldingRule MergeMulDivArithmetic() {
623   return [](IRContext* context, Instruction* inst,
624             const std::vector<const analysis::Constant*>& constants) {
625     assert(inst->opcode() == SpvOpFMul);
626     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
627     analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
628 
629     const analysis::Type* type =
630         context->get_type_mgr()->GetType(inst->type_id());
631     if (!inst->IsFloatingPointFoldingAllowed()) return false;
632 
633     uint32_t width = ElementWidth(type);
634     if (width != 32 && width != 64) return false;
635 
636     for (uint32_t i = 0; i < 2; i++) {
637       uint32_t op_id = inst->GetSingleWordInOperand(i);
638       Instruction* op_inst = def_use_mgr->GetDef(op_id);
639       if (op_inst->opcode() == SpvOpFDiv) {
640         if (op_inst->GetSingleWordInOperand(1) ==
641             inst->GetSingleWordInOperand(1 - i)) {
642           inst->SetOpcode(SpvOpCopyObject);
643           inst->SetInOperands(
644               {{SPV_OPERAND_TYPE_ID, {op_inst->GetSingleWordInOperand(0)}}});
645           return true;
646         }
647       }
648     }
649 
650     const analysis::Constant* const_input1 = ConstInput(constants);
651     if (!const_input1) return false;
652     Instruction* other_inst = NonConstInput(context, constants[0], inst);
653     if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
654 
655     if (other_inst->opcode() == SpvOpFDiv) {
656       std::vector<const analysis::Constant*> other_constants =
657           const_mgr->GetOperandConstants(other_inst);
658       const analysis::Constant* const_input2 = ConstInput(other_constants);
659       if (!const_input2 || HasZero(const_input2)) return false;
660 
661       bool other_first_is_variable = other_constants[0] == nullptr;
662       // If the variable value is the second operand of the divide, multiply
663       // the constants together. Otherwise divide the constants.
664       uint32_t merged_id = PerformOperation(
665           const_mgr,
666           other_first_is_variable ? other_inst->opcode() : inst->opcode(),
667           const_input1, const_input2);
668       if (merged_id == 0) return false;
669 
670       uint32_t non_const_id = other_first_is_variable
671                                   ? other_inst->GetSingleWordInOperand(0u)
672                                   : other_inst->GetSingleWordInOperand(1u);
673 
674       // If the variable value is on the second operand of the div, then this
675       // operation is a div. Otherwise it should be a multiply.
676       inst->SetOpcode(other_first_is_variable ? inst->opcode()
677                                               : other_inst->opcode());
678       if (other_first_is_variable) {
679         inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
680                              {SPV_OPERAND_TYPE_ID, {merged_id}}});
681       } else {
682         inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {merged_id}},
683                              {SPV_OPERAND_TYPE_ID, {non_const_id}}});
684       }
685       return true;
686     }
687 
688     return false;
689   };
690 }
691 
692 // Merges multiply of constant and negation.
693 // Cases:
694 // (-x) * 2 = x * -2
695 // 2 * (-x) = x * -2
MergeMulNegateArithmetic()696 FoldingRule MergeMulNegateArithmetic() {
697   return [](IRContext* context, Instruction* inst,
698             const std::vector<const analysis::Constant*>& constants) {
699     assert(inst->opcode() == SpvOpFMul || inst->opcode() == SpvOpIMul);
700     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
701     const analysis::Type* type =
702         context->get_type_mgr()->GetType(inst->type_id());
703     bool uses_float = HasFloatingPoint(type);
704     if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
705 
706     uint32_t width = ElementWidth(type);
707     if (width != 32 && width != 64) return false;
708 
709     const analysis::Constant* const_input1 = ConstInput(constants);
710     if (!const_input1) return false;
711     Instruction* other_inst = NonConstInput(context, constants[0], inst);
712     if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
713       return false;
714 
715     if (other_inst->opcode() == SpvOpFNegate ||
716         other_inst->opcode() == SpvOpSNegate) {
717       uint32_t neg_id = NegateConstant(const_mgr, const_input1);
718 
719       inst->SetInOperands(
720           {{SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}},
721            {SPV_OPERAND_TYPE_ID, {neg_id}}});
722       return true;
723     }
724 
725     return false;
726   };
727 }
728 
729 // Merges consecutive divides if each instruction contains one constant operand.
730 // Does not support integer division.
731 // Cases:
732 // 2 / (x / 2) = 4 / x
733 // 4 / (2 / x) = 2 * x
734 // (4 / x) / 2 = 2 / x
735 // (x / 2) / 2 = x / 4
MergeDivDivArithmetic()736 FoldingRule MergeDivDivArithmetic() {
737   return [](IRContext* context, Instruction* inst,
738             const std::vector<const analysis::Constant*>& constants) {
739     assert(inst->opcode() == SpvOpFDiv);
740     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
741     const analysis::Type* type =
742         context->get_type_mgr()->GetType(inst->type_id());
743     if (!inst->IsFloatingPointFoldingAllowed()) return false;
744 
745     uint32_t width = ElementWidth(type);
746     if (width != 32 && width != 64) return false;
747 
748     const analysis::Constant* const_input1 = ConstInput(constants);
749     if (!const_input1 || HasZero(const_input1)) return false;
750     Instruction* other_inst = NonConstInput(context, constants[0], inst);
751     if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
752 
753     bool first_is_variable = constants[0] == nullptr;
754     if (other_inst->opcode() == inst->opcode()) {
755       std::vector<const analysis::Constant*> other_constants =
756           const_mgr->GetOperandConstants(other_inst);
757       const analysis::Constant* const_input2 = ConstInput(other_constants);
758       if (!const_input2 || HasZero(const_input2)) return false;
759 
760       bool other_first_is_variable = other_constants[0] == nullptr;
761 
762       SpvOp merge_op = inst->opcode();
763       if (other_first_is_variable) {
764         // Constants magnify.
765         merge_op = SpvOpFMul;
766       }
767 
768       // This is an x / (*) case. Swap the inputs. Doesn't harm multiply
769       // because it is commutative.
770       if (first_is_variable) std::swap(const_input1, const_input2);
771       uint32_t merged_id =
772           PerformOperation(const_mgr, merge_op, const_input1, const_input2);
773       if (merged_id == 0) return false;
774 
775       uint32_t non_const_id = other_first_is_variable
776                                   ? other_inst->GetSingleWordInOperand(0u)
777                                   : other_inst->GetSingleWordInOperand(1u);
778 
779       SpvOp op = inst->opcode();
780       if (!first_is_variable && !other_first_is_variable) {
781         // Effectively div of 1/x, so change to multiply.
782         op = SpvOpFMul;
783       }
784 
785       uint32_t op1 = merged_id;
786       uint32_t op2 = non_const_id;
787       if (first_is_variable && other_first_is_variable) std::swap(op1, op2);
788       inst->SetOpcode(op);
789       inst->SetInOperands(
790           {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
791       return true;
792     }
793 
794     return false;
795   };
796 }
797 
798 // Fold multiplies succeeded by divides where each instruction contains a
799 // constant operand. Does not support integer divide.
800 // Cases:
801 // 4 / (x * 2) = 2 / x
802 // 4 / (2 * x) = 2 / x
803 // (x * 4) / 2 = x * 2
804 // (4 * x) / 2 = x * 2
805 // (x * y) / x = y
806 // (y * x) / x = y
MergeDivMulArithmetic()807 FoldingRule MergeDivMulArithmetic() {
808   return [](IRContext* context, Instruction* inst,
809             const std::vector<const analysis::Constant*>& constants) {
810     assert(inst->opcode() == SpvOpFDiv);
811     analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
812     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
813 
814     const analysis::Type* type =
815         context->get_type_mgr()->GetType(inst->type_id());
816     if (!inst->IsFloatingPointFoldingAllowed()) return false;
817 
818     uint32_t width = ElementWidth(type);
819     if (width != 32 && width != 64) return false;
820 
821     uint32_t op_id = inst->GetSingleWordInOperand(0);
822     Instruction* op_inst = def_use_mgr->GetDef(op_id);
823 
824     if (op_inst->opcode() == SpvOpFMul) {
825       for (uint32_t i = 0; i < 2; i++) {
826         if (op_inst->GetSingleWordInOperand(i) ==
827             inst->GetSingleWordInOperand(1)) {
828           inst->SetOpcode(SpvOpCopyObject);
829           inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
830                                 {op_inst->GetSingleWordInOperand(1 - i)}}});
831           return true;
832         }
833       }
834     }
835 
836     const analysis::Constant* const_input1 = ConstInput(constants);
837     if (!const_input1 || HasZero(const_input1)) return false;
838     Instruction* other_inst = NonConstInput(context, constants[0], inst);
839     if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
840 
841     bool first_is_variable = constants[0] == nullptr;
842     if (other_inst->opcode() == SpvOpFMul) {
843       std::vector<const analysis::Constant*> other_constants =
844           const_mgr->GetOperandConstants(other_inst);
845       const analysis::Constant* const_input2 = ConstInput(other_constants);
846       if (!const_input2) return false;
847 
848       bool other_first_is_variable = other_constants[0] == nullptr;
849 
850       // This is an x / (*) case. Swap the inputs.
851       if (first_is_variable) std::swap(const_input1, const_input2);
852       uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
853                                             const_input1, const_input2);
854       if (merged_id == 0) return false;
855 
856       uint32_t non_const_id = other_first_is_variable
857                                   ? other_inst->GetSingleWordInOperand(0u)
858                                   : other_inst->GetSingleWordInOperand(1u);
859 
860       uint32_t op1 = merged_id;
861       uint32_t op2 = non_const_id;
862       if (first_is_variable) std::swap(op1, op2);
863 
864       // Convert to multiply
865       if (first_is_variable) inst->SetOpcode(other_inst->opcode());
866       inst->SetInOperands(
867           {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
868       return true;
869     }
870 
871     return false;
872   };
873 }
874 
875 // Fold divides of a constant and a negation.
876 // Cases:
877 // (-x) / 2 = x / -2
878 // 2 / (-x) = 2 / -x
MergeDivNegateArithmetic()879 FoldingRule MergeDivNegateArithmetic() {
880   return [](IRContext* context, Instruction* inst,
881             const std::vector<const analysis::Constant*>& constants) {
882     assert(inst->opcode() == SpvOpFDiv || inst->opcode() == SpvOpSDiv ||
883            inst->opcode() == SpvOpUDiv);
884     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
885     const analysis::Type* type =
886         context->get_type_mgr()->GetType(inst->type_id());
887     bool uses_float = HasFloatingPoint(type);
888     if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
889 
890     uint32_t width = ElementWidth(type);
891     if (width != 32 && width != 64) return false;
892 
893     const analysis::Constant* const_input1 = ConstInput(constants);
894     if (!const_input1) return false;
895     Instruction* other_inst = NonConstInput(context, constants[0], inst);
896     if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
897       return false;
898 
899     bool first_is_variable = constants[0] == nullptr;
900     if (other_inst->opcode() == SpvOpFNegate ||
901         other_inst->opcode() == SpvOpSNegate) {
902       uint32_t neg_id = NegateConstant(const_mgr, const_input1);
903 
904       if (first_is_variable) {
905         inst->SetInOperands(
906             {{SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}},
907              {SPV_OPERAND_TYPE_ID, {neg_id}}});
908       } else {
909         inst->SetInOperands(
910             {{SPV_OPERAND_TYPE_ID, {neg_id}},
911              {SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}});
912       }
913       return true;
914     }
915 
916     return false;
917   };
918 }
919 
920 // Folds addition of a constant and a negation.
921 // Cases:
922 // (-x) + 2 = 2 - x
923 // 2 + (-x) = 2 - x
MergeAddNegateArithmetic()924 FoldingRule MergeAddNegateArithmetic() {
925   return [](IRContext* context, Instruction* inst,
926             const std::vector<const analysis::Constant*>& constants) {
927     assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
928     const analysis::Type* type =
929         context->get_type_mgr()->GetType(inst->type_id());
930     bool uses_float = HasFloatingPoint(type);
931     if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
932 
933     const analysis::Constant* const_input1 = ConstInput(constants);
934     if (!const_input1) return false;
935     Instruction* other_inst = NonConstInput(context, constants[0], inst);
936     if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
937       return false;
938 
939     if (other_inst->opcode() == SpvOpSNegate ||
940         other_inst->opcode() == SpvOpFNegate) {
941       inst->SetOpcode(HasFloatingPoint(type) ? SpvOpFSub : SpvOpISub);
942       uint32_t const_id = constants[0] ? inst->GetSingleWordInOperand(0u)
943                                        : inst->GetSingleWordInOperand(1u);
944       inst->SetInOperands(
945           {{SPV_OPERAND_TYPE_ID, {const_id}},
946            {SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}});
947       return true;
948     }
949     return false;
950   };
951 }
952 
953 // Folds subtraction of a constant and a negation.
954 // Cases:
955 // (-x) - 2 = -2 - x
956 // 2 - (-x) = x + 2
MergeSubNegateArithmetic()957 FoldingRule MergeSubNegateArithmetic() {
958   return [](IRContext* context, Instruction* inst,
959             const std::vector<const analysis::Constant*>& constants) {
960     assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
961     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
962     const analysis::Type* type =
963         context->get_type_mgr()->GetType(inst->type_id());
964     bool uses_float = HasFloatingPoint(type);
965     if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
966 
967     uint32_t width = ElementWidth(type);
968     if (width != 32 && width != 64) return false;
969 
970     const analysis::Constant* const_input1 = ConstInput(constants);
971     if (!const_input1) return false;
972     Instruction* other_inst = NonConstInput(context, constants[0], inst);
973     if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
974       return false;
975 
976     if (other_inst->opcode() == SpvOpSNegate ||
977         other_inst->opcode() == SpvOpFNegate) {
978       uint32_t op1 = 0;
979       uint32_t op2 = 0;
980       SpvOp opcode = inst->opcode();
981       if (constants[0] != nullptr) {
982         op1 = other_inst->GetSingleWordInOperand(0u);
983         op2 = inst->GetSingleWordInOperand(0u);
984         opcode = HasFloatingPoint(type) ? SpvOpFAdd : SpvOpIAdd;
985       } else {
986         op1 = NegateConstant(const_mgr, const_input1);
987         op2 = other_inst->GetSingleWordInOperand(0u);
988       }
989 
990       inst->SetOpcode(opcode);
991       inst->SetInOperands(
992           {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
993       return true;
994     }
995     return false;
996   };
997 }
998 
999 // Folds addition of an addition where each operation has a constant operand.
1000 // Cases:
1001 // (x + 2) + 2 = x + 4
1002 // (2 + x) + 2 = x + 4
1003 // 2 + (x + 2) = x + 4
1004 // 2 + (2 + x) = x + 4
MergeAddAddArithmetic()1005 FoldingRule MergeAddAddArithmetic() {
1006   return [](IRContext* context, Instruction* inst,
1007             const std::vector<const analysis::Constant*>& constants) {
1008     assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
1009     const analysis::Type* type =
1010         context->get_type_mgr()->GetType(inst->type_id());
1011     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1012     bool uses_float = HasFloatingPoint(type);
1013     if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1014 
1015     uint32_t width = ElementWidth(type);
1016     if (width != 32 && width != 64) return false;
1017 
1018     const analysis::Constant* const_input1 = ConstInput(constants);
1019     if (!const_input1) return false;
1020     Instruction* other_inst = NonConstInput(context, constants[0], inst);
1021     if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1022       return false;
1023 
1024     if (other_inst->opcode() == SpvOpFAdd ||
1025         other_inst->opcode() == SpvOpIAdd) {
1026       std::vector<const analysis::Constant*> other_constants =
1027           const_mgr->GetOperandConstants(other_inst);
1028       const analysis::Constant* const_input2 = ConstInput(other_constants);
1029       if (!const_input2) return false;
1030 
1031       Instruction* non_const_input =
1032           NonConstInput(context, other_constants[0], other_inst);
1033       uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
1034                                             const_input1, const_input2);
1035       if (merged_id == 0) return false;
1036 
1037       inst->SetInOperands(
1038           {{SPV_OPERAND_TYPE_ID, {non_const_input->result_id()}},
1039            {SPV_OPERAND_TYPE_ID, {merged_id}}});
1040       return true;
1041     }
1042     return false;
1043   };
1044 }
1045 
1046 // Folds addition of a subtraction where each operation has a constant operand.
1047 // Cases:
1048 // (x - 2) + 2 = x + 0
1049 // (2 - x) + 2 = 4 - x
1050 // 2 + (x - 2) = x + 0
1051 // 2 + (2 - x) = 4 - x
MergeAddSubArithmetic()1052 FoldingRule MergeAddSubArithmetic() {
1053   return [](IRContext* context, Instruction* inst,
1054             const std::vector<const analysis::Constant*>& constants) {
1055     assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
1056     const analysis::Type* type =
1057         context->get_type_mgr()->GetType(inst->type_id());
1058     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1059     bool uses_float = HasFloatingPoint(type);
1060     if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1061 
1062     uint32_t width = ElementWidth(type);
1063     if (width != 32 && width != 64) return false;
1064 
1065     const analysis::Constant* const_input1 = ConstInput(constants);
1066     if (!const_input1) return false;
1067     Instruction* other_inst = NonConstInput(context, constants[0], inst);
1068     if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1069       return false;
1070 
1071     if (other_inst->opcode() == SpvOpFSub ||
1072         other_inst->opcode() == SpvOpISub) {
1073       std::vector<const analysis::Constant*> other_constants =
1074           const_mgr->GetOperandConstants(other_inst);
1075       const analysis::Constant* const_input2 = ConstInput(other_constants);
1076       if (!const_input2) return false;
1077 
1078       bool first_is_variable = other_constants[0] == nullptr;
1079       SpvOp op = inst->opcode();
1080       uint32_t op1 = 0;
1081       uint32_t op2 = 0;
1082       if (first_is_variable) {
1083         // Subtract constants. Non-constant operand is first.
1084         op1 = other_inst->GetSingleWordInOperand(0u);
1085         op2 = PerformOperation(const_mgr, other_inst->opcode(), const_input1,
1086                                const_input2);
1087       } else {
1088         // Add constants. Constant operand is first. Change the opcode.
1089         op1 = PerformOperation(const_mgr, inst->opcode(), const_input1,
1090                                const_input2);
1091         op2 = other_inst->GetSingleWordInOperand(1u);
1092         op = other_inst->opcode();
1093       }
1094       if (op1 == 0 || op2 == 0) return false;
1095 
1096       inst->SetOpcode(op);
1097       inst->SetInOperands(
1098           {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
1099       return true;
1100     }
1101     return false;
1102   };
1103 }
1104 
1105 // Folds subtraction of an addition where each operand has a constant operand.
1106 // Cases:
1107 // (x + 2) - 2 = x + 0
1108 // (2 + x) - 2 = x + 0
1109 // 2 - (x + 2) = 0 - x
1110 // 2 - (2 + x) = 0 - x
MergeSubAddArithmetic()1111 FoldingRule MergeSubAddArithmetic() {
1112   return [](IRContext* context, Instruction* inst,
1113             const std::vector<const analysis::Constant*>& constants) {
1114     assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
1115     const analysis::Type* type =
1116         context->get_type_mgr()->GetType(inst->type_id());
1117     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1118     bool uses_float = HasFloatingPoint(type);
1119     if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1120 
1121     uint32_t width = ElementWidth(type);
1122     if (width != 32 && width != 64) return false;
1123 
1124     const analysis::Constant* const_input1 = ConstInput(constants);
1125     if (!const_input1) return false;
1126     Instruction* other_inst = NonConstInput(context, constants[0], inst);
1127     if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1128       return false;
1129 
1130     if (other_inst->opcode() == SpvOpFAdd ||
1131         other_inst->opcode() == SpvOpIAdd) {
1132       std::vector<const analysis::Constant*> other_constants =
1133           const_mgr->GetOperandConstants(other_inst);
1134       const analysis::Constant* const_input2 = ConstInput(other_constants);
1135       if (!const_input2) return false;
1136 
1137       Instruction* non_const_input =
1138           NonConstInput(context, other_constants[0], other_inst);
1139 
1140       // If the first operand of the sub is not a constant, swap the constants
1141       // so the subtraction has the correct operands.
1142       if (constants[0] == nullptr) std::swap(const_input1, const_input2);
1143       // Subtract the constants.
1144       uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
1145                                             const_input1, const_input2);
1146       SpvOp op = inst->opcode();
1147       uint32_t op1 = 0;
1148       uint32_t op2 = 0;
1149       if (constants[0] == nullptr) {
1150         // Non-constant operand is first. Change the opcode.
1151         op1 = non_const_input->result_id();
1152         op2 = merged_id;
1153         op = other_inst->opcode();
1154       } else {
1155         // Constant operand is first.
1156         op1 = merged_id;
1157         op2 = non_const_input->result_id();
1158       }
1159       if (op1 == 0 || op2 == 0) return false;
1160 
1161       inst->SetOpcode(op);
1162       inst->SetInOperands(
1163           {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
1164       return true;
1165     }
1166     return false;
1167   };
1168 }
1169 
1170 // Folds subtraction of a subtraction where each operand has a constant operand.
1171 // Cases:
1172 // (x - 2) - 2 = x - 4
1173 // (2 - x) - 2 = 0 - x
1174 // 2 - (x - 2) = 4 - x
1175 // 2 - (2 - x) = x + 0
MergeSubSubArithmetic()1176 FoldingRule MergeSubSubArithmetic() {
1177   return [](IRContext* context, Instruction* inst,
1178             const std::vector<const analysis::Constant*>& constants) {
1179     assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
1180     const analysis::Type* type =
1181         context->get_type_mgr()->GetType(inst->type_id());
1182     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1183     bool uses_float = HasFloatingPoint(type);
1184     if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1185 
1186     uint32_t width = ElementWidth(type);
1187     if (width != 32 && width != 64) return false;
1188 
1189     const analysis::Constant* const_input1 = ConstInput(constants);
1190     if (!const_input1) return false;
1191     Instruction* other_inst = NonConstInput(context, constants[0], inst);
1192     if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1193       return false;
1194 
1195     if (other_inst->opcode() == SpvOpFSub ||
1196         other_inst->opcode() == SpvOpISub) {
1197       std::vector<const analysis::Constant*> other_constants =
1198           const_mgr->GetOperandConstants(other_inst);
1199       const analysis::Constant* const_input2 = ConstInput(other_constants);
1200       if (!const_input2) return false;
1201 
1202       Instruction* non_const_input =
1203           NonConstInput(context, other_constants[0], other_inst);
1204 
1205       // Merge the constants.
1206       uint32_t merged_id = 0;
1207       SpvOp merge_op = inst->opcode();
1208       if (other_constants[0] == nullptr) {
1209         merge_op = uses_float ? SpvOpFAdd : SpvOpIAdd;
1210       } else if (constants[0] == nullptr) {
1211         std::swap(const_input1, const_input2);
1212       }
1213       merged_id =
1214           PerformOperation(const_mgr, merge_op, const_input1, const_input2);
1215       if (merged_id == 0) return false;
1216 
1217       SpvOp op = inst->opcode();
1218       if (constants[0] != nullptr && other_constants[0] != nullptr) {
1219         // Change the operation.
1220         op = uses_float ? SpvOpFAdd : SpvOpIAdd;
1221       }
1222 
1223       uint32_t op1 = 0;
1224       uint32_t op2 = 0;
1225       if ((constants[0] == nullptr) ^ (other_constants[0] == nullptr)) {
1226         op1 = merged_id;
1227         op2 = non_const_input->result_id();
1228       } else {
1229         op1 = non_const_input->result_id();
1230         op2 = merged_id;
1231       }
1232 
1233       inst->SetOpcode(op);
1234       inst->SetInOperands(
1235           {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
1236       return true;
1237     }
1238     return false;
1239   };
1240 }
1241 
IntMultipleBy1()1242 FoldingRule IntMultipleBy1() {
1243   return [](IRContext*, Instruction* inst,
1244             const std::vector<const analysis::Constant*>& constants) {
1245     assert(inst->opcode() == SpvOpIMul && "Wrong opcode.  Should be OpIMul.");
1246     for (uint32_t i = 0; i < 2; i++) {
1247       if (constants[i] == nullptr) {
1248         continue;
1249       }
1250       const analysis::IntConstant* int_constant = constants[i]->AsIntConstant();
1251       if (int_constant) {
1252         uint32_t width = ElementWidth(int_constant->type());
1253         if (width != 32 && width != 64) return false;
1254         bool is_one = (width == 32) ? int_constant->GetU32BitValue() == 1u
1255                                     : int_constant->GetU64BitValue() == 1ull;
1256         if (is_one) {
1257           inst->SetOpcode(SpvOpCopyObject);
1258           inst->SetInOperands(
1259               {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1 - i)}}});
1260           return true;
1261         }
1262       }
1263     }
1264     return false;
1265   };
1266 }
1267 
CompositeConstructFeedingExtract()1268 FoldingRule CompositeConstructFeedingExtract() {
1269   return [](IRContext* context, Instruction* inst,
1270             const std::vector<const analysis::Constant*>&) {
1271     // If the input to an OpCompositeExtract is an OpCompositeConstruct,
1272     // then we can simply use the appropriate element in the construction.
1273     assert(inst->opcode() == SpvOpCompositeExtract &&
1274            "Wrong opcode.  Should be OpCompositeExtract.");
1275     analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1276     analysis::TypeManager* type_mgr = context->get_type_mgr();
1277     uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1278     Instruction* cinst = def_use_mgr->GetDef(cid);
1279 
1280     if (cinst->opcode() != SpvOpCompositeConstruct) {
1281       return false;
1282     }
1283 
1284     std::vector<Operand> operands;
1285     analysis::Type* composite_type = type_mgr->GetType(cinst->type_id());
1286     if (composite_type->AsVector() == nullptr) {
1287       // Get the element being extracted from the OpCompositeConstruct
1288       // Since it is not a vector, it is simple to extract the single element.
1289       uint32_t element_index = inst->GetSingleWordInOperand(1);
1290       uint32_t element_id = cinst->GetSingleWordInOperand(element_index);
1291       operands.push_back({SPV_OPERAND_TYPE_ID, {element_id}});
1292 
1293       // Add the remaining indices for extraction.
1294       for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
1295         operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
1296                             {inst->GetSingleWordInOperand(i)}});
1297       }
1298 
1299     } else {
1300       // With vectors we have to handle the case where it is concatenating
1301       // vectors.
1302       assert(inst->NumInOperands() == 2 &&
1303              "Expecting a vector of scalar values.");
1304 
1305       uint32_t element_index = inst->GetSingleWordInOperand(1);
1306       for (uint32_t construct_index = 0;
1307            construct_index < cinst->NumInOperands(); ++construct_index) {
1308         uint32_t element_id = cinst->GetSingleWordInOperand(construct_index);
1309         Instruction* element_def = def_use_mgr->GetDef(element_id);
1310         analysis::Vector* element_type =
1311             type_mgr->GetType(element_def->type_id())->AsVector();
1312         if (element_type) {
1313           uint32_t vector_size = element_type->element_count();
1314           if (vector_size < element_index) {
1315             // The element we want comes after this vector.
1316             element_index -= vector_size;
1317           } else {
1318             // We want an element of this vector.
1319             operands.push_back({SPV_OPERAND_TYPE_ID, {element_id}});
1320             operands.push_back(
1321                 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {element_index}});
1322             break;
1323           }
1324         } else {
1325           if (element_index == 0) {
1326             // This is a scalar, and we this is the element we are extracting.
1327             operands.push_back({SPV_OPERAND_TYPE_ID, {element_id}});
1328             break;
1329           } else {
1330             // Skip over this scalar value.
1331             --element_index;
1332           }
1333         }
1334       }
1335     }
1336 
1337     // If there were no extra indices, then we have the final object.  No need
1338     // to extract even more.
1339     if (operands.size() == 1) {
1340       inst->SetOpcode(SpvOpCopyObject);
1341     }
1342 
1343     inst->SetInOperands(std::move(operands));
1344     return true;
1345   };
1346 }
1347 
CompositeExtractFeedingConstruct()1348 FoldingRule CompositeExtractFeedingConstruct() {
1349   // If the OpCompositeConstruct is simply putting back together elements that
1350   // where extracted from the same souce, we can simlpy reuse the source.
1351   //
1352   // This is a common code pattern because of the way that scalar replacement
1353   // works.
1354   return [](IRContext* context, Instruction* inst,
1355             const std::vector<const analysis::Constant*>&) {
1356     assert(inst->opcode() == SpvOpCompositeConstruct &&
1357            "Wrong opcode.  Should be OpCompositeConstruct.");
1358     analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1359     uint32_t original_id = 0;
1360 
1361     // Check each element to make sure they are:
1362     // - extractions
1363     // - extracting the same position they are inserting
1364     // - all extract from the same id.
1365     for (uint32_t i = 0; i < inst->NumInOperands(); ++i) {
1366       uint32_t element_id = inst->GetSingleWordInOperand(i);
1367       Instruction* element_inst = def_use_mgr->GetDef(element_id);
1368 
1369       if (element_inst->opcode() != SpvOpCompositeExtract) {
1370         return false;
1371       }
1372 
1373       if (element_inst->NumInOperands() != 2) {
1374         return false;
1375       }
1376 
1377       if (element_inst->GetSingleWordInOperand(1) != i) {
1378         return false;
1379       }
1380 
1381       if (i == 0) {
1382         original_id =
1383             element_inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1384       } else if (original_id != element_inst->GetSingleWordInOperand(
1385                                     kExtractCompositeIdInIdx)) {
1386         return false;
1387       }
1388     }
1389 
1390     // The last check it to see that the object being extracted from is the
1391     // correct type.
1392     Instruction* original_inst = def_use_mgr->GetDef(original_id);
1393     if (original_inst->type_id() != inst->type_id()) {
1394       return false;
1395     }
1396 
1397     // Simplify by using the original object.
1398     inst->SetOpcode(SpvOpCopyObject);
1399     inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {original_id}}});
1400     return true;
1401   };
1402 }
1403 
InsertFeedingExtract()1404 FoldingRule InsertFeedingExtract() {
1405   return [](IRContext* context, Instruction* inst,
1406             const std::vector<const analysis::Constant*>&) {
1407     assert(inst->opcode() == SpvOpCompositeExtract &&
1408            "Wrong opcode.  Should be OpCompositeExtract.");
1409     analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1410     uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1411     Instruction* cinst = def_use_mgr->GetDef(cid);
1412 
1413     if (cinst->opcode() != SpvOpCompositeInsert) {
1414       return false;
1415     }
1416 
1417     // Find the first position where the list of insert and extract indicies
1418     // differ, if at all.
1419     uint32_t i;
1420     for (i = 1; i < inst->NumInOperands(); ++i) {
1421       if (i + 1 >= cinst->NumInOperands()) {
1422         break;
1423       }
1424 
1425       if (inst->GetSingleWordInOperand(i) !=
1426           cinst->GetSingleWordInOperand(i + 1)) {
1427         break;
1428       }
1429     }
1430 
1431     // We are extracting the element that was inserted.
1432     if (i == inst->NumInOperands() && i + 1 == cinst->NumInOperands()) {
1433       inst->SetOpcode(SpvOpCopyObject);
1434       inst->SetInOperands(
1435           {{SPV_OPERAND_TYPE_ID,
1436             {cinst->GetSingleWordInOperand(kInsertObjectIdInIdx)}}});
1437       return true;
1438     }
1439 
1440     // Extracting the value that was inserted along with values for the base
1441     // composite.  Cannot do anything.
1442     if (i == inst->NumInOperands()) {
1443       return false;
1444     }
1445 
1446     // Extracting an element of the value that was inserted.  Extract from
1447     // that value directly.
1448     if (i + 1 == cinst->NumInOperands()) {
1449       std::vector<Operand> operands;
1450       operands.push_back(
1451           {SPV_OPERAND_TYPE_ID,
1452            {cinst->GetSingleWordInOperand(kInsertObjectIdInIdx)}});
1453       for (; i < inst->NumInOperands(); ++i) {
1454         operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
1455                             {inst->GetSingleWordInOperand(i)}});
1456       }
1457       inst->SetInOperands(std::move(operands));
1458       return true;
1459     }
1460 
1461     // Extracting a value that is disjoint from the element being inserted.
1462     // Rewrite the extract to use the composite input to the insert.
1463     std::vector<Operand> operands;
1464     operands.push_back(
1465         {SPV_OPERAND_TYPE_ID,
1466          {cinst->GetSingleWordInOperand(kInsertCompositeIdInIdx)}});
1467     for (i = 1; i < inst->NumInOperands(); ++i) {
1468       operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
1469                           {inst->GetSingleWordInOperand(i)}});
1470     }
1471     inst->SetInOperands(std::move(operands));
1472     return true;
1473   };
1474 }
1475 
1476 // When a VectorShuffle is feeding an Extract, we can extract from one of the
1477 // operands of the VectorShuffle.  We just need to adjust the index in the
1478 // extract instruction.
VectorShuffleFeedingExtract()1479 FoldingRule VectorShuffleFeedingExtract() {
1480   return [](IRContext* context, Instruction* inst,
1481             const std::vector<const analysis::Constant*>&) {
1482     assert(inst->opcode() == SpvOpCompositeExtract &&
1483            "Wrong opcode.  Should be OpCompositeExtract.");
1484     analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1485     analysis::TypeManager* type_mgr = context->get_type_mgr();
1486     uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1487     Instruction* cinst = def_use_mgr->GetDef(cid);
1488 
1489     if (cinst->opcode() != SpvOpVectorShuffle) {
1490       return false;
1491     }
1492 
1493     // Find the size of the first vector operand of the VectorShuffle
1494     Instruction* first_input =
1495         def_use_mgr->GetDef(cinst->GetSingleWordInOperand(0));
1496     analysis::Type* first_input_type =
1497         type_mgr->GetType(first_input->type_id());
1498     assert(first_input_type->AsVector() &&
1499            "Input to vector shuffle should be vectors.");
1500     uint32_t first_input_size = first_input_type->AsVector()->element_count();
1501 
1502     // Get index of the element the vector shuffle is placing in the position
1503     // being extracted.
1504     uint32_t new_index =
1505         cinst->GetSingleWordInOperand(2 + inst->GetSingleWordInOperand(1));
1506 
1507     // Extracting an undefined value so fold this extract into an undef.
1508     const uint32_t undef_literal_value = 0xffffffff;
1509     if (new_index == undef_literal_value) {
1510       inst->SetOpcode(SpvOpUndef);
1511       inst->SetInOperands({});
1512       return true;
1513     }
1514 
1515     // Get the id of the of the vector the elemtent comes from, and update the
1516     // index if needed.
1517     uint32_t new_vector = 0;
1518     if (new_index < first_input_size) {
1519       new_vector = cinst->GetSingleWordInOperand(0);
1520     } else {
1521       new_vector = cinst->GetSingleWordInOperand(1);
1522       new_index -= first_input_size;
1523     }
1524 
1525     // Update the extract instruction.
1526     inst->SetInOperand(kExtractCompositeIdInIdx, {new_vector});
1527     inst->SetInOperand(1, {new_index});
1528     return true;
1529   };
1530 }
1531 
1532 // When an FMix with is feeding an Extract that extracts an element whose
1533 // corresponding |a| in the FMix is 0 or 1, we can extract from one of the
1534 // operands of the FMix.
FMixFeedingExtract()1535 FoldingRule FMixFeedingExtract() {
1536   return [](IRContext* context, Instruction* inst,
1537             const std::vector<const analysis::Constant*>&) {
1538     assert(inst->opcode() == SpvOpCompositeExtract &&
1539            "Wrong opcode.  Should be OpCompositeExtract.");
1540     analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1541     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1542 
1543     uint32_t composite_id =
1544         inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1545     Instruction* composite_inst = def_use_mgr->GetDef(composite_id);
1546 
1547     if (composite_inst->opcode() != SpvOpExtInst) {
1548       return false;
1549     }
1550 
1551     uint32_t inst_set_id =
1552         context->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
1553 
1554     if (composite_inst->GetSingleWordInOperand(kExtInstSetIdInIdx) !=
1555             inst_set_id ||
1556         composite_inst->GetSingleWordInOperand(kExtInstInstructionInIdx) !=
1557             GLSLstd450FMix) {
1558       return false;
1559     }
1560 
1561     // Get the |a| for the FMix instruction.
1562     uint32_t a_id = composite_inst->GetSingleWordInOperand(kFMixAIdInIdx);
1563     std::unique_ptr<Instruction> a(inst->Clone(context));
1564     a->SetInOperand(kExtractCompositeIdInIdx, {a_id});
1565     context->get_instruction_folder().FoldInstruction(a.get());
1566 
1567     if (a->opcode() != SpvOpCopyObject) {
1568       return false;
1569     }
1570 
1571     const analysis::Constant* a_const =
1572         const_mgr->FindDeclaredConstant(a->GetSingleWordInOperand(0));
1573 
1574     if (!a_const) {
1575       return false;
1576     }
1577 
1578     bool use_x = false;
1579 
1580     assert(a_const->type()->AsFloat());
1581     double element_value = a_const->GetValueAsDouble();
1582     if (element_value == 0.0) {
1583       use_x = true;
1584     } else if (element_value == 1.0) {
1585       use_x = false;
1586     } else {
1587       return false;
1588     }
1589 
1590     // Get the id of the of the vector the element comes from.
1591     uint32_t new_vector = 0;
1592     if (use_x) {
1593       new_vector = composite_inst->GetSingleWordInOperand(kFMixXIdInIdx);
1594     } else {
1595       new_vector = composite_inst->GetSingleWordInOperand(kFMixYIdInIdx);
1596     }
1597 
1598     // Update the extract instruction.
1599     inst->SetInOperand(kExtractCompositeIdInIdx, {new_vector});
1600     return true;
1601   };
1602 }
1603 
RedundantPhi()1604 FoldingRule RedundantPhi() {
1605   // An OpPhi instruction where all values are the same or the result of the phi
1606   // itself, can be replaced by the value itself.
1607   return [](IRContext*, Instruction* inst,
1608             const std::vector<const analysis::Constant*>&) {
1609     assert(inst->opcode() == SpvOpPhi && "Wrong opcode.  Should be OpPhi.");
1610 
1611     uint32_t incoming_value = 0;
1612 
1613     for (uint32_t i = 0; i < inst->NumInOperands(); i += 2) {
1614       uint32_t op_id = inst->GetSingleWordInOperand(i);
1615       if (op_id == inst->result_id()) {
1616         continue;
1617       }
1618 
1619       if (incoming_value == 0) {
1620         incoming_value = op_id;
1621       } else if (op_id != incoming_value) {
1622         // Found two possible value.  Can't simplify.
1623         return false;
1624       }
1625     }
1626 
1627     if (incoming_value == 0) {
1628       // Code looks invalid.  Don't do anything.
1629       return false;
1630     }
1631 
1632     // We have a single incoming value.  Simplify using that value.
1633     inst->SetOpcode(SpvOpCopyObject);
1634     inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {incoming_value}}});
1635     return true;
1636   };
1637 }
1638 
RedundantSelect()1639 FoldingRule RedundantSelect() {
1640   // An OpSelect instruction where both values are the same or the condition is
1641   // constant can be replaced by one of the values
1642   return [](IRContext*, Instruction* inst,
1643             const std::vector<const analysis::Constant*>& constants) {
1644     assert(inst->opcode() == SpvOpSelect &&
1645            "Wrong opcode.  Should be OpSelect.");
1646     assert(inst->NumInOperands() == 3);
1647     assert(constants.size() == 3);
1648 
1649     uint32_t true_id = inst->GetSingleWordInOperand(1);
1650     uint32_t false_id = inst->GetSingleWordInOperand(2);
1651 
1652     if (true_id == false_id) {
1653       // Both results are the same, condition doesn't matter
1654       inst->SetOpcode(SpvOpCopyObject);
1655       inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}});
1656       return true;
1657     } else if (constants[0]) {
1658       const analysis::Type* type = constants[0]->type();
1659       if (type->AsBool()) {
1660         // Scalar constant value, select the corresponding value.
1661         inst->SetOpcode(SpvOpCopyObject);
1662         if (constants[0]->AsNullConstant() ||
1663             !constants[0]->AsBoolConstant()->value()) {
1664           inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {false_id}}});
1665         } else {
1666           inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}});
1667         }
1668         return true;
1669       } else {
1670         assert(type->AsVector());
1671         if (constants[0]->AsNullConstant()) {
1672           // All values come from false id.
1673           inst->SetOpcode(SpvOpCopyObject);
1674           inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {false_id}}});
1675           return true;
1676         } else {
1677           // Convert to a vector shuffle.
1678           std::vector<Operand> ops;
1679           ops.push_back({SPV_OPERAND_TYPE_ID, {true_id}});
1680           ops.push_back({SPV_OPERAND_TYPE_ID, {false_id}});
1681           const analysis::VectorConstant* vector_const =
1682               constants[0]->AsVectorConstant();
1683           uint32_t size =
1684               static_cast<uint32_t>(vector_const->GetComponents().size());
1685           for (uint32_t i = 0; i != size; ++i) {
1686             const analysis::Constant* component =
1687                 vector_const->GetComponents()[i];
1688             if (component->AsNullConstant() ||
1689                 !component->AsBoolConstant()->value()) {
1690               // Selecting from the false vector which is the second input
1691               // vector to the shuffle. Offset the index by |size|.
1692               ops.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i + size}});
1693             } else {
1694               // Selecting from true vector which is the first input vector to
1695               // the shuffle.
1696               ops.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}});
1697             }
1698           }
1699 
1700           inst->SetOpcode(SpvOpVectorShuffle);
1701           inst->SetInOperands(std::move(ops));
1702           return true;
1703         }
1704       }
1705     }
1706 
1707     return false;
1708   };
1709 }
1710 
1711 enum class FloatConstantKind { Unknown, Zero, One };
1712 
getFloatConstantKind(const analysis::Constant * constant)1713 FloatConstantKind getFloatConstantKind(const analysis::Constant* constant) {
1714   if (constant == nullptr) {
1715     return FloatConstantKind::Unknown;
1716   }
1717 
1718   assert(HasFloatingPoint(constant->type()) && "Unexpected constant type");
1719 
1720   if (constant->AsNullConstant()) {
1721     return FloatConstantKind::Zero;
1722   } else if (const analysis::VectorConstant* vc =
1723                  constant->AsVectorConstant()) {
1724     const std::vector<const analysis::Constant*>& components =
1725         vc->GetComponents();
1726     assert(!components.empty());
1727 
1728     FloatConstantKind kind = getFloatConstantKind(components[0]);
1729 
1730     for (size_t i = 1; i < components.size(); ++i) {
1731       if (getFloatConstantKind(components[i]) != kind) {
1732         return FloatConstantKind::Unknown;
1733       }
1734     }
1735 
1736     return kind;
1737   } else if (const analysis::FloatConstant* fc = constant->AsFloatConstant()) {
1738     if (fc->IsZero()) return FloatConstantKind::Zero;
1739 
1740     uint32_t width = fc->type()->AsFloat()->width();
1741     if (width != 32 && width != 64) return FloatConstantKind::Unknown;
1742 
1743     double value = (width == 64) ? fc->GetDoubleValue() : fc->GetFloatValue();
1744 
1745     if (value == 0.0) {
1746       return FloatConstantKind::Zero;
1747     } else if (value == 1.0) {
1748       return FloatConstantKind::One;
1749     } else {
1750       return FloatConstantKind::Unknown;
1751     }
1752   } else {
1753     return FloatConstantKind::Unknown;
1754   }
1755 }
1756 
RedundantFAdd()1757 FoldingRule RedundantFAdd() {
1758   return [](IRContext*, Instruction* inst,
1759             const std::vector<const analysis::Constant*>& constants) {
1760     assert(inst->opcode() == SpvOpFAdd && "Wrong opcode.  Should be OpFAdd.");
1761     assert(constants.size() == 2);
1762 
1763     if (!inst->IsFloatingPointFoldingAllowed()) {
1764       return false;
1765     }
1766 
1767     FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
1768     FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
1769 
1770     if (kind0 == FloatConstantKind::Zero || kind1 == FloatConstantKind::Zero) {
1771       inst->SetOpcode(SpvOpCopyObject);
1772       inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
1773                             {inst->GetSingleWordInOperand(
1774                                 kind0 == FloatConstantKind::Zero ? 1 : 0)}}});
1775       return true;
1776     }
1777 
1778     return false;
1779   };
1780 }
1781 
RedundantFSub()1782 FoldingRule RedundantFSub() {
1783   return [](IRContext*, Instruction* inst,
1784             const std::vector<const analysis::Constant*>& constants) {
1785     assert(inst->opcode() == SpvOpFSub && "Wrong opcode.  Should be OpFSub.");
1786     assert(constants.size() == 2);
1787 
1788     if (!inst->IsFloatingPointFoldingAllowed()) {
1789       return false;
1790     }
1791 
1792     FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
1793     FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
1794 
1795     if (kind0 == FloatConstantKind::Zero) {
1796       inst->SetOpcode(SpvOpFNegate);
1797       inst->SetInOperands(
1798           {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1)}}});
1799       return true;
1800     }
1801 
1802     if (kind1 == FloatConstantKind::Zero) {
1803       inst->SetOpcode(SpvOpCopyObject);
1804       inst->SetInOperands(
1805           {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
1806       return true;
1807     }
1808 
1809     return false;
1810   };
1811 }
1812 
RedundantFMul()1813 FoldingRule RedundantFMul() {
1814   return [](IRContext*, Instruction* inst,
1815             const std::vector<const analysis::Constant*>& constants) {
1816     assert(inst->opcode() == SpvOpFMul && "Wrong opcode.  Should be OpFMul.");
1817     assert(constants.size() == 2);
1818 
1819     if (!inst->IsFloatingPointFoldingAllowed()) {
1820       return false;
1821     }
1822 
1823     FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
1824     FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
1825 
1826     if (kind0 == FloatConstantKind::Zero || kind1 == FloatConstantKind::Zero) {
1827       inst->SetOpcode(SpvOpCopyObject);
1828       inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
1829                             {inst->GetSingleWordInOperand(
1830                                 kind0 == FloatConstantKind::Zero ? 0 : 1)}}});
1831       return true;
1832     }
1833 
1834     if (kind0 == FloatConstantKind::One || kind1 == FloatConstantKind::One) {
1835       inst->SetOpcode(SpvOpCopyObject);
1836       inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
1837                             {inst->GetSingleWordInOperand(
1838                                 kind0 == FloatConstantKind::One ? 1 : 0)}}});
1839       return true;
1840     }
1841 
1842     return false;
1843   };
1844 }
1845 
RedundantFDiv()1846 FoldingRule RedundantFDiv() {
1847   return [](IRContext*, Instruction* inst,
1848             const std::vector<const analysis::Constant*>& constants) {
1849     assert(inst->opcode() == SpvOpFDiv && "Wrong opcode.  Should be OpFDiv.");
1850     assert(constants.size() == 2);
1851 
1852     if (!inst->IsFloatingPointFoldingAllowed()) {
1853       return false;
1854     }
1855 
1856     FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
1857     FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
1858 
1859     if (kind0 == FloatConstantKind::Zero) {
1860       inst->SetOpcode(SpvOpCopyObject);
1861       inst->SetInOperands(
1862           {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
1863       return true;
1864     }
1865 
1866     if (kind1 == FloatConstantKind::One) {
1867       inst->SetOpcode(SpvOpCopyObject);
1868       inst->SetInOperands(
1869           {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
1870       return true;
1871     }
1872 
1873     return false;
1874   };
1875 }
1876 
RedundantFMix()1877 FoldingRule RedundantFMix() {
1878   return [](IRContext* context, Instruction* inst,
1879             const std::vector<const analysis::Constant*>& constants) {
1880     assert(inst->opcode() == SpvOpExtInst &&
1881            "Wrong opcode.  Should be OpExtInst.");
1882 
1883     if (!inst->IsFloatingPointFoldingAllowed()) {
1884       return false;
1885     }
1886 
1887     uint32_t instSetId =
1888         context->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
1889 
1890     if (inst->GetSingleWordInOperand(kExtInstSetIdInIdx) == instSetId &&
1891         inst->GetSingleWordInOperand(kExtInstInstructionInIdx) ==
1892             GLSLstd450FMix) {
1893       assert(constants.size() == 5);
1894 
1895       FloatConstantKind kind4 = getFloatConstantKind(constants[4]);
1896 
1897       if (kind4 == FloatConstantKind::Zero || kind4 == FloatConstantKind::One) {
1898         inst->SetOpcode(SpvOpCopyObject);
1899         inst->SetInOperands(
1900             {{SPV_OPERAND_TYPE_ID,
1901               {inst->GetSingleWordInOperand(kind4 == FloatConstantKind::Zero
1902                                                 ? kFMixXIdInIdx
1903                                                 : kFMixYIdInIdx)}}});
1904         return true;
1905       }
1906     }
1907 
1908     return false;
1909   };
1910 }
1911 
1912 // This rule handles addition of zero for integers.
RedundantIAdd()1913 FoldingRule RedundantIAdd() {
1914   return [](IRContext* context, Instruction* inst,
1915             const std::vector<const analysis::Constant*>& constants) {
1916     assert(inst->opcode() == SpvOpIAdd && "Wrong opcode. Should be OpIAdd.");
1917 
1918     uint32_t operand = std::numeric_limits<uint32_t>::max();
1919     const analysis::Type* operand_type = nullptr;
1920     if (constants[0] && constants[0]->IsZero()) {
1921       operand = inst->GetSingleWordInOperand(1);
1922       operand_type = constants[0]->type();
1923     } else if (constants[1] && constants[1]->IsZero()) {
1924       operand = inst->GetSingleWordInOperand(0);
1925       operand_type = constants[1]->type();
1926     }
1927 
1928     if (operand != std::numeric_limits<uint32_t>::max()) {
1929       const analysis::Type* inst_type =
1930           context->get_type_mgr()->GetType(inst->type_id());
1931       if (inst_type->IsSame(operand_type)) {
1932         inst->SetOpcode(SpvOpCopyObject);
1933       } else {
1934         inst->SetOpcode(SpvOpBitcast);
1935       }
1936       inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {operand}}});
1937       return true;
1938     }
1939     return false;
1940   };
1941 }
1942 
1943 // This rule look for a dot with a constant vector containing a single 1 and
1944 // the rest 0s.  This is the same as doing an extract.
DotProductDoingExtract()1945 FoldingRule DotProductDoingExtract() {
1946   return [](IRContext* context, Instruction* inst,
1947             const std::vector<const analysis::Constant*>& constants) {
1948     assert(inst->opcode() == SpvOpDot && "Wrong opcode.  Should be OpDot.");
1949 
1950     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1951 
1952     if (!inst->IsFloatingPointFoldingAllowed()) {
1953       return false;
1954     }
1955 
1956     for (int i = 0; i < 2; ++i) {
1957       if (!constants[i]) {
1958         continue;
1959       }
1960 
1961       const analysis::Vector* vector_type = constants[i]->type()->AsVector();
1962       assert(vector_type && "Inputs to OpDot must be vectors.");
1963       const analysis::Float* element_type =
1964           vector_type->element_type()->AsFloat();
1965       assert(element_type && "Inputs to OpDot must be vectors of floats.");
1966       uint32_t element_width = element_type->width();
1967       if (element_width != 32 && element_width != 64) {
1968         return false;
1969       }
1970 
1971       std::vector<const analysis::Constant*> components;
1972       components = constants[i]->GetVectorComponents(const_mgr);
1973 
1974       const uint32_t kNotFound = std::numeric_limits<uint32_t>::max();
1975 
1976       uint32_t component_with_one = kNotFound;
1977       bool all_others_zero = true;
1978       for (uint32_t j = 0; j < components.size(); ++j) {
1979         const analysis::Constant* element = components[j];
1980         double value =
1981             (element_width == 32 ? element->GetFloat() : element->GetDouble());
1982         if (value == 0.0) {
1983           continue;
1984         } else if (value == 1.0) {
1985           if (component_with_one == kNotFound) {
1986             component_with_one = j;
1987           } else {
1988             component_with_one = kNotFound;
1989             break;
1990           }
1991         } else {
1992           all_others_zero = false;
1993           break;
1994         }
1995       }
1996 
1997       if (!all_others_zero || component_with_one == kNotFound) {
1998         continue;
1999       }
2000 
2001       std::vector<Operand> operands;
2002       operands.push_back(
2003           {SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1u - i)}});
2004       operands.push_back(
2005           {SPV_OPERAND_TYPE_LITERAL_INTEGER, {component_with_one}});
2006 
2007       inst->SetOpcode(SpvOpCompositeExtract);
2008       inst->SetInOperands(std::move(operands));
2009       return true;
2010     }
2011     return false;
2012   };
2013 }
2014 
2015 // If we are storing an undef, then we can remove the store.
2016 //
2017 // TODO: We can do something similar for OpImageWrite, but checking for volatile
2018 // is complicated.  Waiting to see if it is needed.
StoringUndef()2019 FoldingRule StoringUndef() {
2020   return [](IRContext* context, Instruction* inst,
2021             const std::vector<const analysis::Constant*>&) {
2022     assert(inst->opcode() == SpvOpStore && "Wrong opcode.  Should be OpStore.");
2023 
2024     analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
2025 
2026     // If this is a volatile store, the store cannot be removed.
2027     if (inst->NumInOperands() == 3) {
2028       if (inst->GetSingleWordInOperand(2) & SpvMemoryAccessVolatileMask) {
2029         return false;
2030       }
2031     }
2032 
2033     uint32_t object_id = inst->GetSingleWordInOperand(kStoreObjectInIdx);
2034     Instruction* object_inst = def_use_mgr->GetDef(object_id);
2035     if (object_inst->opcode() == SpvOpUndef) {
2036       inst->ToNop();
2037       return true;
2038     }
2039     return false;
2040   };
2041 }
2042 
VectorShuffleFeedingShuffle()2043 FoldingRule VectorShuffleFeedingShuffle() {
2044   return [](IRContext* context, Instruction* inst,
2045             const std::vector<const analysis::Constant*>&) {
2046     assert(inst->opcode() == SpvOpVectorShuffle &&
2047            "Wrong opcode.  Should be OpVectorShuffle.");
2048 
2049     analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
2050     analysis::TypeManager* type_mgr = context->get_type_mgr();
2051 
2052     Instruction* feeding_shuffle_inst =
2053         def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
2054     analysis::Vector* op0_type =
2055         type_mgr->GetType(feeding_shuffle_inst->type_id())->AsVector();
2056     uint32_t op0_length = op0_type->element_count();
2057 
2058     bool feeder_is_op0 = true;
2059     if (feeding_shuffle_inst->opcode() != SpvOpVectorShuffle) {
2060       feeding_shuffle_inst =
2061           def_use_mgr->GetDef(inst->GetSingleWordInOperand(1));
2062       feeder_is_op0 = false;
2063     }
2064 
2065     if (feeding_shuffle_inst->opcode() != SpvOpVectorShuffle) {
2066       return false;
2067     }
2068 
2069     Instruction* feeder2 =
2070         def_use_mgr->GetDef(feeding_shuffle_inst->GetSingleWordInOperand(0));
2071     analysis::Vector* feeder_op0_type =
2072         type_mgr->GetType(feeder2->type_id())->AsVector();
2073     uint32_t feeder_op0_length = feeder_op0_type->element_count();
2074 
2075     uint32_t new_feeder_id = 0;
2076     std::vector<Operand> new_operands;
2077     new_operands.resize(
2078         2, {SPV_OPERAND_TYPE_ID, {0}});  // Place holders for vector operands.
2079     const uint32_t undef_literal = 0xffffffff;
2080     for (uint32_t op = 2; op < inst->NumInOperands(); ++op) {
2081       uint32_t component_index = inst->GetSingleWordInOperand(op);
2082 
2083       // Do not interpret the undefined value literal as coming from operand 1.
2084       if (component_index != undef_literal &&
2085           feeder_is_op0 == (component_index < op0_length)) {
2086         // This component comes from the feeding_shuffle_inst.  Update
2087         // |component_index| to be the index into the operand of the feeder.
2088 
2089         // Adjust component_index to get the index into the operands of the
2090         // feeding_shuffle_inst.
2091         if (component_index >= op0_length) {
2092           component_index -= op0_length;
2093         }
2094         component_index =
2095             feeding_shuffle_inst->GetSingleWordInOperand(component_index + 2);
2096 
2097         // Check if we are using a component from the first or second operand of
2098         // the feeding instruction.
2099         if (component_index < feeder_op0_length) {
2100           if (new_feeder_id == 0) {
2101             // First time through, save the id of the operand the element comes
2102             // from.
2103             new_feeder_id = feeding_shuffle_inst->GetSingleWordInOperand(0);
2104           } else if (new_feeder_id !=
2105                      feeding_shuffle_inst->GetSingleWordInOperand(0)) {
2106             // We need both elements of the feeding_shuffle_inst, so we cannot
2107             // fold.
2108             return false;
2109           }
2110         } else {
2111           if (new_feeder_id == 0) {
2112             // First time through, save the id of the operand the element comes
2113             // from.
2114             new_feeder_id = feeding_shuffle_inst->GetSingleWordInOperand(1);
2115           } else if (new_feeder_id !=
2116                      feeding_shuffle_inst->GetSingleWordInOperand(1)) {
2117             // We need both elements of the feeding_shuffle_inst, so we cannot
2118             // fold.
2119             return false;
2120           }
2121           component_index -= feeder_op0_length;
2122         }
2123 
2124         if (!feeder_is_op0) {
2125           component_index += op0_length;
2126         }
2127       }
2128       new_operands.push_back(
2129           {SPV_OPERAND_TYPE_LITERAL_INTEGER, {component_index}});
2130     }
2131 
2132     if (new_feeder_id == 0) {
2133       analysis::ConstantManager* const_mgr = context->get_constant_mgr();
2134       const analysis::Type* type =
2135           type_mgr->GetType(feeding_shuffle_inst->type_id());
2136       const analysis::Constant* null_const = const_mgr->GetConstant(type, {});
2137       new_feeder_id =
2138           const_mgr->GetDefiningInstruction(null_const, 0)->result_id();
2139     }
2140 
2141     if (feeder_is_op0) {
2142       // If the size of the first vector operand changed then the indices
2143       // referring to the second operand need to be adjusted.
2144       Instruction* new_feeder_inst = def_use_mgr->GetDef(new_feeder_id);
2145       analysis::Type* new_feeder_type =
2146           type_mgr->GetType(new_feeder_inst->type_id());
2147       uint32_t new_op0_size = new_feeder_type->AsVector()->element_count();
2148       int32_t adjustment = op0_length - new_op0_size;
2149 
2150       if (adjustment != 0) {
2151         for (uint32_t i = 2; i < new_operands.size(); i++) {
2152           if (inst->GetSingleWordInOperand(i) >= op0_length) {
2153             new_operands[i].words[0] -= adjustment;
2154           }
2155         }
2156       }
2157 
2158       new_operands[0].words[0] = new_feeder_id;
2159       new_operands[1] = inst->GetInOperand(1);
2160     } else {
2161       new_operands[1].words[0] = new_feeder_id;
2162       new_operands[0] = inst->GetInOperand(0);
2163     }
2164 
2165     inst->SetInOperands(std::move(new_operands));
2166     return true;
2167   };
2168 }
2169 
2170 }  // namespace
2171 
FoldingRules()2172 FoldingRules::FoldingRules() {
2173   // Add all folding rules to the list for the opcodes to which they apply.
2174   // Note that the order in which rules are added to the list matters. If a rule
2175   // applies to the instruction, the rest of the rules will not be attempted.
2176   // Take that into consideration.
2177   rules_[SpvOpCompositeConstruct].push_back(CompositeExtractFeedingConstruct());
2178 
2179   rules_[SpvOpCompositeExtract].push_back(InsertFeedingExtract());
2180   rules_[SpvOpCompositeExtract].push_back(CompositeConstructFeedingExtract());
2181   rules_[SpvOpCompositeExtract].push_back(VectorShuffleFeedingExtract());
2182   rules_[SpvOpCompositeExtract].push_back(FMixFeedingExtract());
2183 
2184   rules_[SpvOpDot].push_back(DotProductDoingExtract());
2185 
2186   rules_[SpvOpExtInst].push_back(RedundantFMix());
2187 
2188   rules_[SpvOpFAdd].push_back(RedundantFAdd());
2189   rules_[SpvOpFAdd].push_back(MergeAddNegateArithmetic());
2190   rules_[SpvOpFAdd].push_back(MergeAddAddArithmetic());
2191   rules_[SpvOpFAdd].push_back(MergeAddSubArithmetic());
2192 
2193   rules_[SpvOpFDiv].push_back(RedundantFDiv());
2194   rules_[SpvOpFDiv].push_back(ReciprocalFDiv());
2195   rules_[SpvOpFDiv].push_back(MergeDivDivArithmetic());
2196   rules_[SpvOpFDiv].push_back(MergeDivMulArithmetic());
2197   rules_[SpvOpFDiv].push_back(MergeDivNegateArithmetic());
2198 
2199   rules_[SpvOpFMul].push_back(RedundantFMul());
2200   rules_[SpvOpFMul].push_back(MergeMulMulArithmetic());
2201   rules_[SpvOpFMul].push_back(MergeMulDivArithmetic());
2202   rules_[SpvOpFMul].push_back(MergeMulNegateArithmetic());
2203 
2204   rules_[SpvOpFNegate].push_back(MergeNegateArithmetic());
2205   rules_[SpvOpFNegate].push_back(MergeNegateAddSubArithmetic());
2206   rules_[SpvOpFNegate].push_back(MergeNegateMulDivArithmetic());
2207 
2208   rules_[SpvOpFSub].push_back(RedundantFSub());
2209   rules_[SpvOpFSub].push_back(MergeSubNegateArithmetic());
2210   rules_[SpvOpFSub].push_back(MergeSubAddArithmetic());
2211   rules_[SpvOpFSub].push_back(MergeSubSubArithmetic());
2212 
2213   rules_[SpvOpIAdd].push_back(RedundantIAdd());
2214   rules_[SpvOpIAdd].push_back(MergeAddNegateArithmetic());
2215   rules_[SpvOpIAdd].push_back(MergeAddAddArithmetic());
2216   rules_[SpvOpIAdd].push_back(MergeAddSubArithmetic());
2217 
2218   rules_[SpvOpIMul].push_back(IntMultipleBy1());
2219   rules_[SpvOpIMul].push_back(MergeMulMulArithmetic());
2220   rules_[SpvOpIMul].push_back(MergeMulNegateArithmetic());
2221 
2222   rules_[SpvOpISub].push_back(MergeSubNegateArithmetic());
2223   rules_[SpvOpISub].push_back(MergeSubAddArithmetic());
2224   rules_[SpvOpISub].push_back(MergeSubSubArithmetic());
2225 
2226   rules_[SpvOpPhi].push_back(RedundantPhi());
2227 
2228   rules_[SpvOpSDiv].push_back(MergeDivNegateArithmetic());
2229 
2230   rules_[SpvOpSNegate].push_back(MergeNegateArithmetic());
2231   rules_[SpvOpSNegate].push_back(MergeNegateMulDivArithmetic());
2232   rules_[SpvOpSNegate].push_back(MergeNegateAddSubArithmetic());
2233 
2234   rules_[SpvOpSelect].push_back(RedundantSelect());
2235 
2236   rules_[SpvOpStore].push_back(StoringUndef());
2237 
2238   rules_[SpvOpUDiv].push_back(MergeDivNegateArithmetic());
2239 
2240   rules_[SpvOpVectorShuffle].push_back(VectorShuffleFeedingShuffle());
2241 }
2242 }  // namespace opt
2243 }  // namespace spvtools
2244