• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2018 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/opt/folding_rules.h"
16 
17 #include <limits>
18 #include <memory>
19 #include <utility>
20 
21 #include "ir_builder.h"
22 #include "source/latest_version_glsl_std_450_header.h"
23 #include "source/opt/ir_context.h"
24 
25 namespace spvtools {
26 namespace opt {
27 namespace {
28 
// In-operand indices into specific SPIR-V instructions, used by the rules.
// OpCompositeExtract: the composite being read.
constexpr uint32_t kExtractCompositeIdInIdx = 0;
// OpCompositeInsert: the object being inserted, then the target composite.
constexpr uint32_t kInsertObjectIdInIdx = 0;
constexpr uint32_t kInsertCompositeIdInIdx = 1;
// OpExtInst: the extended-instruction-set id, then the instruction number.
constexpr uint32_t kExtInstSetIdInIdx = 0;
constexpr uint32_t kExtInstInstructionInIdx = 1;
// OpExtInst for GLSL.std.450 FMix: the x, y and a operands follow the set
// id and instruction number.
constexpr uint32_t kFMixXIdInIdx = 2;
constexpr uint32_t kFMixYIdInIdx = 3;
constexpr uint32_t kFMixAIdInIdx = 4;
// OpStore: the object being stored (in-operand 0 is the pointer).
constexpr uint32_t kStoreObjectInIdx = 1;
38 
39 // Some image instructions may contain an "image operands" argument.
40 // Returns the operand index for the "image operands".
41 // Returns -1 if the instruction does not have image operands.
ImageOperandsMaskInOperandIndex(Instruction * inst)42 int32_t ImageOperandsMaskInOperandIndex(Instruction* inst) {
43   const auto opcode = inst->opcode();
44   switch (opcode) {
45     case SpvOpImageSampleImplicitLod:
46     case SpvOpImageSampleExplicitLod:
47     case SpvOpImageSampleProjImplicitLod:
48     case SpvOpImageSampleProjExplicitLod:
49     case SpvOpImageFetch:
50     case SpvOpImageRead:
51     case SpvOpImageSparseSampleImplicitLod:
52     case SpvOpImageSparseSampleExplicitLod:
53     case SpvOpImageSparseSampleProjImplicitLod:
54     case SpvOpImageSparseSampleProjExplicitLod:
55     case SpvOpImageSparseFetch:
56     case SpvOpImageSparseRead:
57       return inst->NumOperands() > 4 ? 2 : -1;
58     case SpvOpImageSampleDrefImplicitLod:
59     case SpvOpImageSampleDrefExplicitLod:
60     case SpvOpImageSampleProjDrefImplicitLod:
61     case SpvOpImageSampleProjDrefExplicitLod:
62     case SpvOpImageGather:
63     case SpvOpImageDrefGather:
64     case SpvOpImageSparseSampleDrefImplicitLod:
65     case SpvOpImageSparseSampleDrefExplicitLod:
66     case SpvOpImageSparseSampleProjDrefImplicitLod:
67     case SpvOpImageSparseSampleProjDrefExplicitLod:
68     case SpvOpImageSparseGather:
69     case SpvOpImageSparseDrefGather:
70       return inst->NumOperands() > 5 ? 3 : -1;
71     case SpvOpImageWrite:
72       return inst->NumOperands() > 3 ? 3 : -1;
73     default:
74       return -1;
75   }
76 }
77 
78 // Returns the element width of |type|.
ElementWidth(const analysis::Type * type)79 uint32_t ElementWidth(const analysis::Type* type) {
80   if (const analysis::Vector* vec_type = type->AsVector()) {
81     return ElementWidth(vec_type->element_type());
82   } else if (const analysis::Float* float_type = type->AsFloat()) {
83     return float_type->width();
84   } else {
85     assert(type->AsInteger());
86     return type->AsInteger()->width();
87   }
88 }
89 
90 // Returns true if |type| is Float or a vector of Float.
HasFloatingPoint(const analysis::Type * type)91 bool HasFloatingPoint(const analysis::Type* type) {
92   if (type->AsFloat()) {
93     return true;
94   } else if (const analysis::Vector* vec_type = type->AsVector()) {
95     return vec_type->element_type()->AsFloat() != nullptr;
96   }
97 
98   return false;
99 }
100 
// Returns true when |val| is zero or a normal finite value; returns false if
// it is NaN, infinite or subnormal (as classified by std::fpclassify).
template <typename T>
bool IsValidResult(T val) {
  const int category = std::fpclassify(val);
  const bool rejected = category == FP_NAN || category == FP_INFINITE ||
                        category == FP_SUBNORMAL;
  return !rejected;
}
114 
ConstInput(const std::vector<const analysis::Constant * > & constants)115 const analysis::Constant* ConstInput(
116     const std::vector<const analysis::Constant*>& constants) {
117   return constants[0] ? constants[0] : constants[1];
118 }
119 
NonConstInput(IRContext * context,const analysis::Constant * c,Instruction * inst)120 Instruction* NonConstInput(IRContext* context, const analysis::Constant* c,
121                            Instruction* inst) {
122   uint32_t in_op = c ? 1u : 0u;
123   return context->get_def_use_mgr()->GetDef(
124       inst->GetSingleWordInOperand(in_op));
125 }
126 
// Splits |val| into two 32-bit words, low-order word first.
std::vector<uint32_t> ExtractInts(uint64_t val) {
  const uint32_t low_word = static_cast<uint32_t>(val & 0xFFFFFFFFull);
  const uint32_t high_word = static_cast<uint32_t>(val >> 32);
  return {low_word, high_word};
}
133 
GetWordsFromScalarIntConstant(const analysis::IntConstant * c)134 std::vector<uint32_t> GetWordsFromScalarIntConstant(
135     const analysis::IntConstant* c) {
136   assert(c != nullptr);
137   uint32_t width = c->type()->AsInteger()->width();
138   assert(width == 32 || width == 64);
139   if (width == 64) {
140     uint64_t uval = static_cast<uint64_t>(c->GetU64());
141     return ExtractInts(uval);
142   }
143   return {c->GetU32()};
144 }
145 
GetWordsFromScalarFloatConstant(const analysis::FloatConstant * c)146 std::vector<uint32_t> GetWordsFromScalarFloatConstant(
147     const analysis::FloatConstant* c) {
148   assert(c != nullptr);
149   uint32_t width = c->type()->AsFloat()->width();
150   assert(width == 32 || width == 64);
151   if (width == 64) {
152     utils::FloatProxy<double> result(c->GetDouble());
153     return result.GetWords();
154   }
155   utils::FloatProxy<float> result(c->GetFloat());
156   return result.GetWords();
157 }
158 
GetWordsFromNumericScalarOrVectorConstant(analysis::ConstantManager * const_mgr,const analysis::Constant * c)159 std::vector<uint32_t> GetWordsFromNumericScalarOrVectorConstant(
160     analysis::ConstantManager* const_mgr, const analysis::Constant* c) {
161   if (const auto* float_constant = c->AsFloatConstant()) {
162     return GetWordsFromScalarFloatConstant(float_constant);
163   } else if (const auto* int_constant = c->AsIntConstant()) {
164     return GetWordsFromScalarIntConstant(int_constant);
165   } else if (const auto* vec_constant = c->AsVectorConstant()) {
166     std::vector<uint32_t> words;
167     for (const auto* comp : vec_constant->GetComponents()) {
168       auto comp_in_words =
169           GetWordsFromNumericScalarOrVectorConstant(const_mgr, comp);
170       words.insert(words.end(), comp_in_words.begin(), comp_in_words.end());
171     }
172     return words;
173   }
174   return {};
175 }
176 
ConvertWordsToNumericScalarOrVectorConstant(analysis::ConstantManager * const_mgr,const std::vector<uint32_t> & words,const analysis::Type * type)177 const analysis::Constant* ConvertWordsToNumericScalarOrVectorConstant(
178     analysis::ConstantManager* const_mgr, const std::vector<uint32_t>& words,
179     const analysis::Type* type) {
180   if (type->AsInteger() || type->AsFloat())
181     return const_mgr->GetConstant(type, words);
182   if (const auto* vec_type = type->AsVector())
183     return const_mgr->GetNumericVectorConstantWithWords(vec_type, words);
184   return nullptr;
185 }
186 
187 // Returns the negation of |c|. |c| must be a 32 or 64 bit floating point
188 // constant.
NegateFloatingPointConstant(analysis::ConstantManager * const_mgr,const analysis::Constant * c)189 uint32_t NegateFloatingPointConstant(analysis::ConstantManager* const_mgr,
190                                      const analysis::Constant* c) {
191   assert(c);
192   assert(c->type()->AsFloat());
193   uint32_t width = c->type()->AsFloat()->width();
194   assert(width == 32 || width == 64);
195   std::vector<uint32_t> words;
196   if (width == 64) {
197     utils::FloatProxy<double> result(c->GetDouble() * -1.0);
198     words = result.GetWords();
199   } else {
200     utils::FloatProxy<float> result(c->GetFloat() * -1.0f);
201     words = result.GetWords();
202   }
203 
204   const analysis::Constant* negated_const =
205       const_mgr->GetConstant(c->type(), std::move(words));
206   return const_mgr->GetDefiningInstruction(negated_const)->result_id();
207 }
208 
209 // Negates the integer constant |c|. Returns the id of the defining instruction.
NegateIntegerConstant(analysis::ConstantManager * const_mgr,const analysis::Constant * c)210 uint32_t NegateIntegerConstant(analysis::ConstantManager* const_mgr,
211                                const analysis::Constant* c) {
212   assert(c);
213   assert(c->type()->AsInteger());
214   uint32_t width = c->type()->AsInteger()->width();
215   assert(width == 32 || width == 64);
216   std::vector<uint32_t> words;
217   if (width == 64) {
218     uint64_t uval = static_cast<uint64_t>(0 - c->GetU64());
219     words = ExtractInts(uval);
220   } else {
221     words.push_back(static_cast<uint32_t>(0 - c->GetU32()));
222   }
223 
224   const analysis::Constant* negated_const =
225       const_mgr->GetConstant(c->type(), std::move(words));
226   return const_mgr->GetDefiningInstruction(negated_const)->result_id();
227 }
228 
229 // Negates the vector constant |c|. Returns the id of the defining instruction.
NegateVectorConstant(analysis::ConstantManager * const_mgr,const analysis::Constant * c)230 uint32_t NegateVectorConstant(analysis::ConstantManager* const_mgr,
231                               const analysis::Constant* c) {
232   assert(const_mgr && c);
233   assert(c->type()->AsVector());
234   if (c->AsNullConstant()) {
235     // 0.0 vs -0.0 shouldn't matter.
236     return const_mgr->GetDefiningInstruction(c)->result_id();
237   } else {
238     const analysis::Type* component_type =
239         c->AsVectorConstant()->component_type();
240     std::vector<uint32_t> words;
241     for (auto& comp : c->AsVectorConstant()->GetComponents()) {
242       if (component_type->AsFloat()) {
243         words.push_back(NegateFloatingPointConstant(const_mgr, comp));
244       } else {
245         assert(component_type->AsInteger());
246         words.push_back(NegateIntegerConstant(const_mgr, comp));
247       }
248     }
249 
250     const analysis::Constant* negated_const =
251         const_mgr->GetConstant(c->type(), std::move(words));
252     return const_mgr->GetDefiningInstruction(negated_const)->result_id();
253   }
254 }
255 
256 // Negates |c|. Returns the id of the defining instruction.
NegateConstant(analysis::ConstantManager * const_mgr,const analysis::Constant * c)257 uint32_t NegateConstant(analysis::ConstantManager* const_mgr,
258                         const analysis::Constant* c) {
259   if (c->type()->AsVector()) {
260     return NegateVectorConstant(const_mgr, c);
261   } else if (c->type()->AsFloat()) {
262     return NegateFloatingPointConstant(const_mgr, c);
263   } else {
264     assert(c->type()->AsInteger());
265     return NegateIntegerConstant(const_mgr, c);
266   }
267 }
268 
269 // Takes the reciprocal of |c|. |c|'s type must be Float or a vector of Float.
270 // Returns 0 if the reciprocal is NaN, infinite or subnormal.
Reciprocal(analysis::ConstantManager * const_mgr,const analysis::Constant * c)271 uint32_t Reciprocal(analysis::ConstantManager* const_mgr,
272                     const analysis::Constant* c) {
273   assert(const_mgr && c);
274   assert(c->type()->AsFloat());
275 
276   uint32_t width = c->type()->AsFloat()->width();
277   assert(width == 32 || width == 64);
278   std::vector<uint32_t> words;
279   if (width == 64) {
280     spvtools::utils::FloatProxy<double> result(1.0 / c->GetDouble());
281     if (!IsValidResult(result.getAsFloat())) return 0;
282     words = result.GetWords();
283   } else {
284     spvtools::utils::FloatProxy<float> result(1.0f / c->GetFloat());
285     if (!IsValidResult(result.getAsFloat())) return 0;
286     words = result.GetWords();
287   }
288 
289   const analysis::Constant* negated_const =
290       const_mgr->GetConstant(c->type(), std::move(words));
291   return const_mgr->GetDefiningInstruction(negated_const)->result_id();
292 }
293 
294 // Replaces fdiv where second operand is constant with fmul.
ReciprocalFDiv()295 FoldingRule ReciprocalFDiv() {
296   return [](IRContext* context, Instruction* inst,
297             const std::vector<const analysis::Constant*>& constants) {
298     assert(inst->opcode() == SpvOpFDiv);
299     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
300     const analysis::Type* type =
301         context->get_type_mgr()->GetType(inst->type_id());
302     if (!inst->IsFloatingPointFoldingAllowed()) return false;
303 
304     uint32_t width = ElementWidth(type);
305     if (width != 32 && width != 64) return false;
306 
307     if (constants[1] != nullptr) {
308       uint32_t id = 0;
309       if (const analysis::VectorConstant* vector_const =
310               constants[1]->AsVectorConstant()) {
311         std::vector<uint32_t> neg_ids;
312         for (auto& comp : vector_const->GetComponents()) {
313           id = Reciprocal(const_mgr, comp);
314           if (id == 0) return false;
315           neg_ids.push_back(id);
316         }
317         const analysis::Constant* negated_const =
318             const_mgr->GetConstant(constants[1]->type(), std::move(neg_ids));
319         id = const_mgr->GetDefiningInstruction(negated_const)->result_id();
320       } else if (constants[1]->AsFloatConstant()) {
321         id = Reciprocal(const_mgr, constants[1]);
322         if (id == 0) return false;
323       } else {
324         // Don't fold a null constant.
325         return false;
326       }
327       inst->SetOpcode(SpvOpFMul);
328       inst->SetInOperands(
329           {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0u)}},
330            {SPV_OPERAND_TYPE_ID, {id}}});
331       return true;
332     }
333 
334     return false;
335   };
336 }
337 
338 // Elides consecutive negate instructions.
MergeNegateArithmetic()339 FoldingRule MergeNegateArithmetic() {
340   return [](IRContext* context, Instruction* inst,
341             const std::vector<const analysis::Constant*>& constants) {
342     assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate);
343     (void)constants;
344     const analysis::Type* type =
345         context->get_type_mgr()->GetType(inst->type_id());
346     if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
347       return false;
348 
349     Instruction* op_inst =
350         context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
351     if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
352       return false;
353 
354     if (op_inst->opcode() == inst->opcode()) {
355       // Elide negates.
356       inst->SetOpcode(SpvOpCopyObject);
357       inst->SetInOperands(
358           {{SPV_OPERAND_TYPE_ID, {op_inst->GetSingleWordInOperand(0u)}}});
359       return true;
360     }
361 
362     return false;
363   };
364 }
365 
366 // Merges negate into a mul or div operation if that operation contains a
367 // constant operand.
368 // Cases:
369 // -(x * 2) = x * -2
370 // -(2 * x) = x * -2
371 // -(x / 2) = x / -2
372 // -(2 / x) = -2 / x
MergeNegateMulDivArithmetic()373 FoldingRule MergeNegateMulDivArithmetic() {
374   return [](IRContext* context, Instruction* inst,
375             const std::vector<const analysis::Constant*>& constants) {
376     assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate);
377     (void)constants;
378     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
379     const analysis::Type* type =
380         context->get_type_mgr()->GetType(inst->type_id());
381     if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
382       return false;
383 
384     Instruction* op_inst =
385         context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
386     if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
387       return false;
388 
389     uint32_t width = ElementWidth(type);
390     if (width != 32 && width != 64) return false;
391 
392     SpvOp opcode = op_inst->opcode();
393     if (opcode == SpvOpFMul || opcode == SpvOpFDiv || opcode == SpvOpIMul ||
394         opcode == SpvOpSDiv || opcode == SpvOpUDiv) {
395       std::vector<const analysis::Constant*> op_constants =
396           const_mgr->GetOperandConstants(op_inst);
397       // Merge negate into mul or div if one operand is constant.
398       if (op_constants[0] || op_constants[1]) {
399         bool zero_is_variable = op_constants[0] == nullptr;
400         const analysis::Constant* c = ConstInput(op_constants);
401         uint32_t neg_id = NegateConstant(const_mgr, c);
402         uint32_t non_const_id = zero_is_variable
403                                     ? op_inst->GetSingleWordInOperand(0u)
404                                     : op_inst->GetSingleWordInOperand(1u);
405         // Change this instruction to a mul/div.
406         inst->SetOpcode(op_inst->opcode());
407         if (opcode == SpvOpFDiv || opcode == SpvOpUDiv || opcode == SpvOpSDiv) {
408           uint32_t op0 = zero_is_variable ? non_const_id : neg_id;
409           uint32_t op1 = zero_is_variable ? neg_id : non_const_id;
410           inst->SetInOperands(
411               {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}});
412         } else {
413           inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
414                                {SPV_OPERAND_TYPE_ID, {neg_id}}});
415         }
416         return true;
417       }
418     }
419 
420     return false;
421   };
422 }
423 
424 // Merges negate into a add or sub operation if that operation contains a
425 // constant operand.
426 // Cases:
427 // -(x + 2) = -2 - x
428 // -(2 + x) = -2 - x
429 // -(x - 2) = 2 - x
430 // -(2 - x) = x - 2
MergeNegateAddSubArithmetic()431 FoldingRule MergeNegateAddSubArithmetic() {
432   return [](IRContext* context, Instruction* inst,
433             const std::vector<const analysis::Constant*>& constants) {
434     assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate);
435     (void)constants;
436     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
437     const analysis::Type* type =
438         context->get_type_mgr()->GetType(inst->type_id());
439     if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
440       return false;
441 
442     Instruction* op_inst =
443         context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
444     if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
445       return false;
446 
447     uint32_t width = ElementWidth(type);
448     if (width != 32 && width != 64) return false;
449 
450     if (op_inst->opcode() == SpvOpFAdd || op_inst->opcode() == SpvOpFSub ||
451         op_inst->opcode() == SpvOpIAdd || op_inst->opcode() == SpvOpISub) {
452       std::vector<const analysis::Constant*> op_constants =
453           const_mgr->GetOperandConstants(op_inst);
454       if (op_constants[0] || op_constants[1]) {
455         bool zero_is_variable = op_constants[0] == nullptr;
456         bool is_add = (op_inst->opcode() == SpvOpFAdd) ||
457                       (op_inst->opcode() == SpvOpIAdd);
458         bool swap_operands = !is_add || zero_is_variable;
459         bool negate_const = is_add;
460         const analysis::Constant* c = ConstInput(op_constants);
461         uint32_t const_id = 0;
462         if (negate_const) {
463           const_id = NegateConstant(const_mgr, c);
464         } else {
465           const_id = zero_is_variable ? op_inst->GetSingleWordInOperand(1u)
466                                       : op_inst->GetSingleWordInOperand(0u);
467         }
468 
469         // Swap operands if necessary and make the instruction a subtraction.
470         uint32_t op0 =
471             zero_is_variable ? op_inst->GetSingleWordInOperand(0u) : const_id;
472         uint32_t op1 =
473             zero_is_variable ? const_id : op_inst->GetSingleWordInOperand(1u);
474         if (swap_operands) std::swap(op0, op1);
475         inst->SetOpcode(HasFloatingPoint(type) ? SpvOpFSub : SpvOpISub);
476         inst->SetInOperands(
477             {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}});
478         return true;
479       }
480     }
481 
482     return false;
483   };
484 }
485 
486 // Returns true if |c| has a zero element.
HasZero(const analysis::Constant * c)487 bool HasZero(const analysis::Constant* c) {
488   if (c->AsNullConstant()) {
489     return true;
490   }
491   if (const analysis::VectorConstant* vec_const = c->AsVectorConstant()) {
492     for (auto& comp : vec_const->GetComponents())
493       if (HasZero(comp)) return true;
494   } else {
495     assert(c->AsScalarConstant());
496     return c->AsScalarConstant()->IsZero();
497   }
498 
499   return false;
500 }
501 
502 // Performs |input1| |opcode| |input2| and returns the merged constant result
503 // id. Returns 0 if the result is not a valid value. The input types must be
504 // Float.
PerformFloatingPointOperation(analysis::ConstantManager * const_mgr,SpvOp opcode,const analysis::Constant * input1,const analysis::Constant * input2)505 uint32_t PerformFloatingPointOperation(analysis::ConstantManager* const_mgr,
506                                        SpvOp opcode,
507                                        const analysis::Constant* input1,
508                                        const analysis::Constant* input2) {
509   const analysis::Type* type = input1->type();
510   assert(type->AsFloat());
511   uint32_t width = type->AsFloat()->width();
512   assert(width == 32 || width == 64);
513   std::vector<uint32_t> words;
514 #define FOLD_OP(op)                                                          \
515   if (width == 64) {                                                         \
516     utils::FloatProxy<double> val =                                          \
517         input1->GetDouble() op input2->GetDouble();                          \
518     double dval = val.getAsFloat();                                          \
519     if (!IsValidResult(dval)) return 0;                                      \
520     words = val.GetWords();                                                  \
521   } else {                                                                   \
522     utils::FloatProxy<float> val = input1->GetFloat() op input2->GetFloat(); \
523     float fval = val.getAsFloat();                                           \
524     if (!IsValidResult(fval)) return 0;                                      \
525     words = val.GetWords();                                                  \
526   } static_assert(true, "require extra semicolon")
527   switch (opcode) {
528     case SpvOpFMul:
529       FOLD_OP(*);
530       break;
531     case SpvOpFDiv:
532       if (HasZero(input2)) return 0;
533       FOLD_OP(/);
534       break;
535     case SpvOpFAdd:
536       FOLD_OP(+);
537       break;
538     case SpvOpFSub:
539       FOLD_OP(-);
540       break;
541     default:
542       assert(false && "Unexpected operation");
543       break;
544   }
545 #undef FOLD_OP
546   const analysis::Constant* merged_const = const_mgr->GetConstant(type, words);
547   return const_mgr->GetDefiningInstruction(merged_const)->result_id();
548 }
549 
550 // Performs |input1| |opcode| |input2| and returns the merged constant result
551 // id. Returns 0 if the result is not a valid value. The input types must be
552 // Integers.
PerformIntegerOperation(analysis::ConstantManager * const_mgr,SpvOp opcode,const analysis::Constant * input1,const analysis::Constant * input2)553 uint32_t PerformIntegerOperation(analysis::ConstantManager* const_mgr,
554                                  SpvOp opcode, const analysis::Constant* input1,
555                                  const analysis::Constant* input2) {
556   assert(input1->type()->AsInteger());
557   const analysis::Integer* type = input1->type()->AsInteger();
558   uint32_t width = type->AsInteger()->width();
559   assert(width == 32 || width == 64);
560   std::vector<uint32_t> words;
561 #define FOLD_OP(op)                                        \
562   if (width == 64) {                                       \
563     if (type->IsSigned()) {                                \
564       int64_t val = input1->GetS64() op input2->GetS64();  \
565       words = ExtractInts(static_cast<uint64_t>(val));     \
566     } else {                                               \
567       uint64_t val = input1->GetU64() op input2->GetU64(); \
568       words = ExtractInts(val);                            \
569     }                                                      \
570   } else {                                                 \
571     if (type->IsSigned()) {                                \
572       int32_t val = input1->GetS32() op input2->GetS32();  \
573       words.push_back(static_cast<uint32_t>(val));         \
574     } else {                                               \
575       uint32_t val = input1->GetU32() op input2->GetU32(); \
576       words.push_back(val);                                \
577     }                                                      \
578   } static_assert(true, "require extra semicalon")
579   switch (opcode) {
580     case SpvOpIMul:
581       FOLD_OP(*);
582       break;
583     case SpvOpSDiv:
584     case SpvOpUDiv:
585       assert(false && "Should not merge integer division");
586       break;
587     case SpvOpIAdd:
588       FOLD_OP(+);
589       break;
590     case SpvOpISub:
591       FOLD_OP(-);
592       break;
593     default:
594       assert(false && "Unexpected operation");
595       break;
596   }
597 #undef FOLD_OP
598   const analysis::Constant* merged_const = const_mgr->GetConstant(type, words);
599   return const_mgr->GetDefiningInstruction(merged_const)->result_id();
600 }
601 
602 // Performs |input1| |opcode| |input2| and returns the merged constant result
603 // id. Returns 0 if the result is not a valid value. The input types must be
604 // Integers, Floats or Vectors of such.
PerformOperation(analysis::ConstantManager * const_mgr,SpvOp opcode,const analysis::Constant * input1,const analysis::Constant * input2)605 uint32_t PerformOperation(analysis::ConstantManager* const_mgr, SpvOp opcode,
606                           const analysis::Constant* input1,
607                           const analysis::Constant* input2) {
608   assert(input1 && input2);
609   const analysis::Type* type = input1->type();
610   std::vector<uint32_t> words;
611   if (const analysis::Vector* vector_type = type->AsVector()) {
612     const analysis::Type* ele_type = vector_type->element_type();
613     for (uint32_t i = 0; i != vector_type->element_count(); ++i) {
614       uint32_t id = 0;
615 
616       const analysis::Constant* input1_comp = nullptr;
617       if (const analysis::VectorConstant* input1_vector =
618               input1->AsVectorConstant()) {
619         input1_comp = input1_vector->GetComponents()[i];
620       } else {
621         assert(input1->AsNullConstant());
622         input1_comp = const_mgr->GetConstant(ele_type, {});
623       }
624 
625       const analysis::Constant* input2_comp = nullptr;
626       if (const analysis::VectorConstant* input2_vector =
627               input2->AsVectorConstant()) {
628         input2_comp = input2_vector->GetComponents()[i];
629       } else {
630         assert(input2->AsNullConstant());
631         input2_comp = const_mgr->GetConstant(ele_type, {});
632       }
633 
634       if (ele_type->AsFloat()) {
635         id = PerformFloatingPointOperation(const_mgr, opcode, input1_comp,
636                                            input2_comp);
637       } else {
638         assert(ele_type->AsInteger());
639         id = PerformIntegerOperation(const_mgr, opcode, input1_comp,
640                                      input2_comp);
641       }
642       if (id == 0) return 0;
643       words.push_back(id);
644     }
645     const analysis::Constant* merged_const =
646         const_mgr->GetConstant(type, words);
647     return const_mgr->GetDefiningInstruction(merged_const)->result_id();
648   } else if (type->AsFloat()) {
649     return PerformFloatingPointOperation(const_mgr, opcode, input1, input2);
650   } else {
651     assert(type->AsInteger());
652     return PerformIntegerOperation(const_mgr, opcode, input1, input2);
653   }
654 }
655 
656 // Merges consecutive multiplies where each contains one constant operand.
657 // Cases:
658 // 2 * (x * 2) = x * 4
659 // 2 * (2 * x) = x * 4
660 // (x * 2) * 2 = x * 4
661 // (2 * x) * 2 = x * 4
MergeMulMulArithmetic()662 FoldingRule MergeMulMulArithmetic() {
663   return [](IRContext* context, Instruction* inst,
664             const std::vector<const analysis::Constant*>& constants) {
665     assert(inst->opcode() == SpvOpFMul || inst->opcode() == SpvOpIMul);
666     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
667     const analysis::Type* type =
668         context->get_type_mgr()->GetType(inst->type_id());
669     if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
670       return false;
671 
672     uint32_t width = ElementWidth(type);
673     if (width != 32 && width != 64) return false;
674 
675     // Determine the constant input and the variable input in |inst|.
676     const analysis::Constant* const_input1 = ConstInput(constants);
677     if (!const_input1) return false;
678     Instruction* other_inst = NonConstInput(context, constants[0], inst);
679     if (HasFloatingPoint(type) && !other_inst->IsFloatingPointFoldingAllowed())
680       return false;
681 
682     if (other_inst->opcode() == inst->opcode()) {
683       std::vector<const analysis::Constant*> other_constants =
684           const_mgr->GetOperandConstants(other_inst);
685       const analysis::Constant* const_input2 = ConstInput(other_constants);
686       if (!const_input2) return false;
687 
688       bool other_first_is_variable = other_constants[0] == nullptr;
689       uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
690                                             const_input1, const_input2);
691       if (merged_id == 0) return false;
692 
693       uint32_t non_const_id = other_first_is_variable
694                                   ? other_inst->GetSingleWordInOperand(0u)
695                                   : other_inst->GetSingleWordInOperand(1u);
696       inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
697                            {SPV_OPERAND_TYPE_ID, {merged_id}}});
698       return true;
699     }
700 
701     return false;
702   };
703 }
704 
705 // Merges divides into subsequent multiplies if each instruction contains one
706 // constant operand. Does not support integer operations.
707 // Cases:
708 // 2 * (x / 2) = x * 1
709 // 2 * (2 / x) = 4 / x
710 // (x / 2) * 2 = x * 1
711 // (2 / x) * 2 = 4 / x
712 // (y / x) * x = y
713 // x * (y / x) = y
MergeMulDivArithmetic()714 FoldingRule MergeMulDivArithmetic() {
715   return [](IRContext* context, Instruction* inst,
716             const std::vector<const analysis::Constant*>& constants) {
717     assert(inst->opcode() == SpvOpFMul);
718     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
719     analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
720 
721     const analysis::Type* type =
722         context->get_type_mgr()->GetType(inst->type_id());
723     if (!inst->IsFloatingPointFoldingAllowed()) return false;
724 
725     uint32_t width = ElementWidth(type);
726     if (width != 32 && width != 64) return false;
727 
728     for (uint32_t i = 0; i < 2; i++) {
729       uint32_t op_id = inst->GetSingleWordInOperand(i);
730       Instruction* op_inst = def_use_mgr->GetDef(op_id);
731       if (op_inst->opcode() == SpvOpFDiv) {
732         if (op_inst->GetSingleWordInOperand(1) ==
733             inst->GetSingleWordInOperand(1 - i)) {
734           inst->SetOpcode(SpvOpCopyObject);
735           inst->SetInOperands(
736               {{SPV_OPERAND_TYPE_ID, {op_inst->GetSingleWordInOperand(0)}}});
737           return true;
738         }
739       }
740     }
741 
742     const analysis::Constant* const_input1 = ConstInput(constants);
743     if (!const_input1) return false;
744     Instruction* other_inst = NonConstInput(context, constants[0], inst);
745     if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
746 
747     if (other_inst->opcode() == SpvOpFDiv) {
748       std::vector<const analysis::Constant*> other_constants =
749           const_mgr->GetOperandConstants(other_inst);
750       const analysis::Constant* const_input2 = ConstInput(other_constants);
751       if (!const_input2 || HasZero(const_input2)) return false;
752 
753       bool other_first_is_variable = other_constants[0] == nullptr;
754       // If the variable value is the second operand of the divide, multiply
755       // the constants together. Otherwise divide the constants.
756       uint32_t merged_id = PerformOperation(
757           const_mgr,
758           other_first_is_variable ? other_inst->opcode() : inst->opcode(),
759           const_input1, const_input2);
760       if (merged_id == 0) return false;
761 
762       uint32_t non_const_id = other_first_is_variable
763                                   ? other_inst->GetSingleWordInOperand(0u)
764                                   : other_inst->GetSingleWordInOperand(1u);
765 
766       // If the variable value is on the second operand of the div, then this
767       // operation is a div. Otherwise it should be a multiply.
768       inst->SetOpcode(other_first_is_variable ? inst->opcode()
769                                               : other_inst->opcode());
770       if (other_first_is_variable) {
771         inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
772                              {SPV_OPERAND_TYPE_ID, {merged_id}}});
773       } else {
774         inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {merged_id}},
775                              {SPV_OPERAND_TYPE_ID, {non_const_id}}});
776       }
777       return true;
778     }
779 
780     return false;
781   };
782 }
783 
784 // Merges multiply of constant and negation.
785 // Cases:
786 // (-x) * 2 = x * -2
787 // 2 * (-x) = x * -2
MergeMulNegateArithmetic()788 FoldingRule MergeMulNegateArithmetic() {
789   return [](IRContext* context, Instruction* inst,
790             const std::vector<const analysis::Constant*>& constants) {
791     assert(inst->opcode() == SpvOpFMul || inst->opcode() == SpvOpIMul);
792     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
793     const analysis::Type* type =
794         context->get_type_mgr()->GetType(inst->type_id());
795     bool uses_float = HasFloatingPoint(type);
796     if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
797 
798     uint32_t width = ElementWidth(type);
799     if (width != 32 && width != 64) return false;
800 
801     const analysis::Constant* const_input1 = ConstInput(constants);
802     if (!const_input1) return false;
803     Instruction* other_inst = NonConstInput(context, constants[0], inst);
804     if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
805       return false;
806 
807     if (other_inst->opcode() == SpvOpFNegate ||
808         other_inst->opcode() == SpvOpSNegate) {
809       uint32_t neg_id = NegateConstant(const_mgr, const_input1);
810 
811       inst->SetInOperands(
812           {{SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}},
813            {SPV_OPERAND_TYPE_ID, {neg_id}}});
814       return true;
815     }
816 
817     return false;
818   };
819 }
820 
821 // Merges consecutive divides if each instruction contains one constant operand.
822 // Does not support integer division.
823 // Cases:
824 // 2 / (x / 2) = 4 / x
825 // 4 / (2 / x) = 2 * x
826 // (4 / x) / 2 = 2 / x
827 // (x / 2) / 2 = x / 4
MergeDivDivArithmetic()828 FoldingRule MergeDivDivArithmetic() {
829   return [](IRContext* context, Instruction* inst,
830             const std::vector<const analysis::Constant*>& constants) {
831     assert(inst->opcode() == SpvOpFDiv);
832     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
833     const analysis::Type* type =
834         context->get_type_mgr()->GetType(inst->type_id());
835     if (!inst->IsFloatingPointFoldingAllowed()) return false;
836 
837     uint32_t width = ElementWidth(type);
838     if (width != 32 && width != 64) return false;
839 
840     const analysis::Constant* const_input1 = ConstInput(constants);
841     if (!const_input1 || HasZero(const_input1)) return false;
842     Instruction* other_inst = NonConstInput(context, constants[0], inst);
843     if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
844 
845     bool first_is_variable = constants[0] == nullptr;
846     if (other_inst->opcode() == inst->opcode()) {
847       std::vector<const analysis::Constant*> other_constants =
848           const_mgr->GetOperandConstants(other_inst);
849       const analysis::Constant* const_input2 = ConstInput(other_constants);
850       if (!const_input2 || HasZero(const_input2)) return false;
851 
852       bool other_first_is_variable = other_constants[0] == nullptr;
853 
854       SpvOp merge_op = inst->opcode();
855       if (other_first_is_variable) {
856         // Constants magnify.
857         merge_op = SpvOpFMul;
858       }
859 
860       // This is an x / (*) case. Swap the inputs. Doesn't harm multiply
861       // because it is commutative.
862       if (first_is_variable) std::swap(const_input1, const_input2);
863       uint32_t merged_id =
864           PerformOperation(const_mgr, merge_op, const_input1, const_input2);
865       if (merged_id == 0) return false;
866 
867       uint32_t non_const_id = other_first_is_variable
868                                   ? other_inst->GetSingleWordInOperand(0u)
869                                   : other_inst->GetSingleWordInOperand(1u);
870 
871       SpvOp op = inst->opcode();
872       if (!first_is_variable && !other_first_is_variable) {
873         // Effectively div of 1/x, so change to multiply.
874         op = SpvOpFMul;
875       }
876 
877       uint32_t op1 = merged_id;
878       uint32_t op2 = non_const_id;
879       if (first_is_variable && other_first_is_variable) std::swap(op1, op2);
880       inst->SetOpcode(op);
881       inst->SetInOperands(
882           {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
883       return true;
884     }
885 
886     return false;
887   };
888 }
889 
890 // Fold multiplies succeeded by divides where each instruction contains a
891 // constant operand. Does not support integer divide.
892 // Cases:
893 // 4 / (x * 2) = 2 / x
894 // 4 / (2 * x) = 2 / x
895 // (x * 4) / 2 = x * 2
896 // (4 * x) / 2 = x * 2
897 // (x * y) / x = y
898 // (y * x) / x = y
MergeDivMulArithmetic()899 FoldingRule MergeDivMulArithmetic() {
900   return [](IRContext* context, Instruction* inst,
901             const std::vector<const analysis::Constant*>& constants) {
902     assert(inst->opcode() == SpvOpFDiv);
903     analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
904     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
905 
906     const analysis::Type* type =
907         context->get_type_mgr()->GetType(inst->type_id());
908     if (!inst->IsFloatingPointFoldingAllowed()) return false;
909 
910     uint32_t width = ElementWidth(type);
911     if (width != 32 && width != 64) return false;
912 
913     uint32_t op_id = inst->GetSingleWordInOperand(0);
914     Instruction* op_inst = def_use_mgr->GetDef(op_id);
915 
916     if (op_inst->opcode() == SpvOpFMul) {
917       for (uint32_t i = 0; i < 2; i++) {
918         if (op_inst->GetSingleWordInOperand(i) ==
919             inst->GetSingleWordInOperand(1)) {
920           inst->SetOpcode(SpvOpCopyObject);
921           inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
922                                 {op_inst->GetSingleWordInOperand(1 - i)}}});
923           return true;
924         }
925       }
926     }
927 
928     const analysis::Constant* const_input1 = ConstInput(constants);
929     if (!const_input1 || HasZero(const_input1)) return false;
930     Instruction* other_inst = NonConstInput(context, constants[0], inst);
931     if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
932 
933     bool first_is_variable = constants[0] == nullptr;
934     if (other_inst->opcode() == SpvOpFMul) {
935       std::vector<const analysis::Constant*> other_constants =
936           const_mgr->GetOperandConstants(other_inst);
937       const analysis::Constant* const_input2 = ConstInput(other_constants);
938       if (!const_input2) return false;
939 
940       bool other_first_is_variable = other_constants[0] == nullptr;
941 
942       // This is an x / (*) case. Swap the inputs.
943       if (first_is_variable) std::swap(const_input1, const_input2);
944       uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
945                                             const_input1, const_input2);
946       if (merged_id == 0) return false;
947 
948       uint32_t non_const_id = other_first_is_variable
949                                   ? other_inst->GetSingleWordInOperand(0u)
950                                   : other_inst->GetSingleWordInOperand(1u);
951 
952       uint32_t op1 = merged_id;
953       uint32_t op2 = non_const_id;
954       if (first_is_variable) std::swap(op1, op2);
955 
956       // Convert to multiply
957       if (first_is_variable) inst->SetOpcode(other_inst->opcode());
958       inst->SetInOperands(
959           {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
960       return true;
961     }
962 
963     return false;
964   };
965 }
966 
967 // Fold divides of a constant and a negation.
968 // Cases:
969 // (-x) / 2 = x / -2
970 // 2 / (-x) = 2 / -x
MergeDivNegateArithmetic()971 FoldingRule MergeDivNegateArithmetic() {
972   return [](IRContext* context, Instruction* inst,
973             const std::vector<const analysis::Constant*>& constants) {
974     assert(inst->opcode() == SpvOpFDiv || inst->opcode() == SpvOpSDiv ||
975            inst->opcode() == SpvOpUDiv);
976     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
977     const analysis::Type* type =
978         context->get_type_mgr()->GetType(inst->type_id());
979     bool uses_float = HasFloatingPoint(type);
980     if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
981 
982     uint32_t width = ElementWidth(type);
983     if (width != 32 && width != 64) return false;
984 
985     const analysis::Constant* const_input1 = ConstInput(constants);
986     if (!const_input1) return false;
987     Instruction* other_inst = NonConstInput(context, constants[0], inst);
988     if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
989       return false;
990 
991     bool first_is_variable = constants[0] == nullptr;
992     if (other_inst->opcode() == SpvOpFNegate ||
993         other_inst->opcode() == SpvOpSNegate) {
994       uint32_t neg_id = NegateConstant(const_mgr, const_input1);
995 
996       if (first_is_variable) {
997         inst->SetInOperands(
998             {{SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}},
999              {SPV_OPERAND_TYPE_ID, {neg_id}}});
1000       } else {
1001         inst->SetInOperands(
1002             {{SPV_OPERAND_TYPE_ID, {neg_id}},
1003              {SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}});
1004       }
1005       return true;
1006     }
1007 
1008     return false;
1009   };
1010 }
1011 
1012 // Folds addition of a constant and a negation.
1013 // Cases:
1014 // (-x) + 2 = 2 - x
1015 // 2 + (-x) = 2 - x
MergeAddNegateArithmetic()1016 FoldingRule MergeAddNegateArithmetic() {
1017   return [](IRContext* context, Instruction* inst,
1018             const std::vector<const analysis::Constant*>& constants) {
1019     assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
1020     const analysis::Type* type =
1021         context->get_type_mgr()->GetType(inst->type_id());
1022     bool uses_float = HasFloatingPoint(type);
1023     if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1024 
1025     const analysis::Constant* const_input1 = ConstInput(constants);
1026     if (!const_input1) return false;
1027     Instruction* other_inst = NonConstInput(context, constants[0], inst);
1028     if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1029       return false;
1030 
1031     if (other_inst->opcode() == SpvOpSNegate ||
1032         other_inst->opcode() == SpvOpFNegate) {
1033       inst->SetOpcode(HasFloatingPoint(type) ? SpvOpFSub : SpvOpISub);
1034       uint32_t const_id = constants[0] ? inst->GetSingleWordInOperand(0u)
1035                                        : inst->GetSingleWordInOperand(1u);
1036       inst->SetInOperands(
1037           {{SPV_OPERAND_TYPE_ID, {const_id}},
1038            {SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}});
1039       return true;
1040     }
1041     return false;
1042   };
1043 }
1044 
1045 // Folds subtraction of a constant and a negation.
1046 // Cases:
1047 // (-x) - 2 = -2 - x
1048 // 2 - (-x) = x + 2
MergeSubNegateArithmetic()1049 FoldingRule MergeSubNegateArithmetic() {
1050   return [](IRContext* context, Instruction* inst,
1051             const std::vector<const analysis::Constant*>& constants) {
1052     assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
1053     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1054     const analysis::Type* type =
1055         context->get_type_mgr()->GetType(inst->type_id());
1056     bool uses_float = HasFloatingPoint(type);
1057     if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1058 
1059     uint32_t width = ElementWidth(type);
1060     if (width != 32 && width != 64) return false;
1061 
1062     const analysis::Constant* const_input1 = ConstInput(constants);
1063     if (!const_input1) return false;
1064     Instruction* other_inst = NonConstInput(context, constants[0], inst);
1065     if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1066       return false;
1067 
1068     if (other_inst->opcode() == SpvOpSNegate ||
1069         other_inst->opcode() == SpvOpFNegate) {
1070       uint32_t op1 = 0;
1071       uint32_t op2 = 0;
1072       SpvOp opcode = inst->opcode();
1073       if (constants[0] != nullptr) {
1074         op1 = other_inst->GetSingleWordInOperand(0u);
1075         op2 = inst->GetSingleWordInOperand(0u);
1076         opcode = HasFloatingPoint(type) ? SpvOpFAdd : SpvOpIAdd;
1077       } else {
1078         op1 = NegateConstant(const_mgr, const_input1);
1079         op2 = other_inst->GetSingleWordInOperand(0u);
1080       }
1081 
1082       inst->SetOpcode(opcode);
1083       inst->SetInOperands(
1084           {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
1085       return true;
1086     }
1087     return false;
1088   };
1089 }
1090 
1091 // Folds addition of an addition where each operation has a constant operand.
1092 // Cases:
1093 // (x + 2) + 2 = x + 4
1094 // (2 + x) + 2 = x + 4
1095 // 2 + (x + 2) = x + 4
1096 // 2 + (2 + x) = x + 4
MergeAddAddArithmetic()1097 FoldingRule MergeAddAddArithmetic() {
1098   return [](IRContext* context, Instruction* inst,
1099             const std::vector<const analysis::Constant*>& constants) {
1100     assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
1101     const analysis::Type* type =
1102         context->get_type_mgr()->GetType(inst->type_id());
1103     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1104     bool uses_float = HasFloatingPoint(type);
1105     if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1106 
1107     uint32_t width = ElementWidth(type);
1108     if (width != 32 && width != 64) return false;
1109 
1110     const analysis::Constant* const_input1 = ConstInput(constants);
1111     if (!const_input1) return false;
1112     Instruction* other_inst = NonConstInput(context, constants[0], inst);
1113     if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1114       return false;
1115 
1116     if (other_inst->opcode() == SpvOpFAdd ||
1117         other_inst->opcode() == SpvOpIAdd) {
1118       std::vector<const analysis::Constant*> other_constants =
1119           const_mgr->GetOperandConstants(other_inst);
1120       const analysis::Constant* const_input2 = ConstInput(other_constants);
1121       if (!const_input2) return false;
1122 
1123       Instruction* non_const_input =
1124           NonConstInput(context, other_constants[0], other_inst);
1125       uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
1126                                             const_input1, const_input2);
1127       if (merged_id == 0) return false;
1128 
1129       inst->SetInOperands(
1130           {{SPV_OPERAND_TYPE_ID, {non_const_input->result_id()}},
1131            {SPV_OPERAND_TYPE_ID, {merged_id}}});
1132       return true;
1133     }
1134     return false;
1135   };
1136 }
1137 
1138 // Folds addition of a subtraction where each operation has a constant operand.
1139 // Cases:
1140 // (x - 2) + 2 = x + 0
1141 // (2 - x) + 2 = 4 - x
1142 // 2 + (x - 2) = x + 0
1143 // 2 + (2 - x) = 4 - x
MergeAddSubArithmetic()1144 FoldingRule MergeAddSubArithmetic() {
1145   return [](IRContext* context, Instruction* inst,
1146             const std::vector<const analysis::Constant*>& constants) {
1147     assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
1148     const analysis::Type* type =
1149         context->get_type_mgr()->GetType(inst->type_id());
1150     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1151     bool uses_float = HasFloatingPoint(type);
1152     if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1153 
1154     uint32_t width = ElementWidth(type);
1155     if (width != 32 && width != 64) return false;
1156 
1157     const analysis::Constant* const_input1 = ConstInput(constants);
1158     if (!const_input1) return false;
1159     Instruction* other_inst = NonConstInput(context, constants[0], inst);
1160     if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1161       return false;
1162 
1163     if (other_inst->opcode() == SpvOpFSub ||
1164         other_inst->opcode() == SpvOpISub) {
1165       std::vector<const analysis::Constant*> other_constants =
1166           const_mgr->GetOperandConstants(other_inst);
1167       const analysis::Constant* const_input2 = ConstInput(other_constants);
1168       if (!const_input2) return false;
1169 
1170       bool first_is_variable = other_constants[0] == nullptr;
1171       SpvOp op = inst->opcode();
1172       uint32_t op1 = 0;
1173       uint32_t op2 = 0;
1174       if (first_is_variable) {
1175         // Subtract constants. Non-constant operand is first.
1176         op1 = other_inst->GetSingleWordInOperand(0u);
1177         op2 = PerformOperation(const_mgr, other_inst->opcode(), const_input1,
1178                                const_input2);
1179       } else {
1180         // Add constants. Constant operand is first. Change the opcode.
1181         op1 = PerformOperation(const_mgr, inst->opcode(), const_input1,
1182                                const_input2);
1183         op2 = other_inst->GetSingleWordInOperand(1u);
1184         op = other_inst->opcode();
1185       }
1186       if (op1 == 0 || op2 == 0) return false;
1187 
1188       inst->SetOpcode(op);
1189       inst->SetInOperands(
1190           {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
1191       return true;
1192     }
1193     return false;
1194   };
1195 }
1196 
1197 // Folds subtraction of an addition where each operand has a constant operand.
1198 // Cases:
1199 // (x + 2) - 2 = x + 0
1200 // (2 + x) - 2 = x + 0
1201 // 2 - (x + 2) = 0 - x
1202 // 2 - (2 + x) = 0 - x
MergeSubAddArithmetic()1203 FoldingRule MergeSubAddArithmetic() {
1204   return [](IRContext* context, Instruction* inst,
1205             const std::vector<const analysis::Constant*>& constants) {
1206     assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
1207     const analysis::Type* type =
1208         context->get_type_mgr()->GetType(inst->type_id());
1209     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1210     bool uses_float = HasFloatingPoint(type);
1211     if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1212 
1213     uint32_t width = ElementWidth(type);
1214     if (width != 32 && width != 64) return false;
1215 
1216     const analysis::Constant* const_input1 = ConstInput(constants);
1217     if (!const_input1) return false;
1218     Instruction* other_inst = NonConstInput(context, constants[0], inst);
1219     if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1220       return false;
1221 
1222     if (other_inst->opcode() == SpvOpFAdd ||
1223         other_inst->opcode() == SpvOpIAdd) {
1224       std::vector<const analysis::Constant*> other_constants =
1225           const_mgr->GetOperandConstants(other_inst);
1226       const analysis::Constant* const_input2 = ConstInput(other_constants);
1227       if (!const_input2) return false;
1228 
1229       Instruction* non_const_input =
1230           NonConstInput(context, other_constants[0], other_inst);
1231 
1232       // If the first operand of the sub is not a constant, swap the constants
1233       // so the subtraction has the correct operands.
1234       if (constants[0] == nullptr) std::swap(const_input1, const_input2);
1235       // Subtract the constants.
1236       uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
1237                                             const_input1, const_input2);
1238       SpvOp op = inst->opcode();
1239       uint32_t op1 = 0;
1240       uint32_t op2 = 0;
1241       if (constants[0] == nullptr) {
1242         // Non-constant operand is first. Change the opcode.
1243         op1 = non_const_input->result_id();
1244         op2 = merged_id;
1245         op = other_inst->opcode();
1246       } else {
1247         // Constant operand is first.
1248         op1 = merged_id;
1249         op2 = non_const_input->result_id();
1250       }
1251       if (op1 == 0 || op2 == 0) return false;
1252 
1253       inst->SetOpcode(op);
1254       inst->SetInOperands(
1255           {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
1256       return true;
1257     }
1258     return false;
1259   };
1260 }
1261 
1262 // Folds subtraction of a subtraction where each operand has a constant operand.
1263 // Cases:
1264 // (x - 2) - 2 = x - 4
1265 // (2 - x) - 2 = 0 - x
1266 // 2 - (x - 2) = 4 - x
1267 // 2 - (2 - x) = x + 0
MergeSubSubArithmetic()1268 FoldingRule MergeSubSubArithmetic() {
1269   return [](IRContext* context, Instruction* inst,
1270             const std::vector<const analysis::Constant*>& constants) {
1271     assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
1272     const analysis::Type* type =
1273         context->get_type_mgr()->GetType(inst->type_id());
1274     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1275     bool uses_float = HasFloatingPoint(type);
1276     if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1277 
1278     uint32_t width = ElementWidth(type);
1279     if (width != 32 && width != 64) return false;
1280 
1281     const analysis::Constant* const_input1 = ConstInput(constants);
1282     if (!const_input1) return false;
1283     Instruction* other_inst = NonConstInput(context, constants[0], inst);
1284     if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1285       return false;
1286 
1287     if (other_inst->opcode() == SpvOpFSub ||
1288         other_inst->opcode() == SpvOpISub) {
1289       std::vector<const analysis::Constant*> other_constants =
1290           const_mgr->GetOperandConstants(other_inst);
1291       const analysis::Constant* const_input2 = ConstInput(other_constants);
1292       if (!const_input2) return false;
1293 
1294       Instruction* non_const_input =
1295           NonConstInput(context, other_constants[0], other_inst);
1296 
1297       // Merge the constants.
1298       uint32_t merged_id = 0;
1299       SpvOp merge_op = inst->opcode();
1300       if (other_constants[0] == nullptr) {
1301         merge_op = uses_float ? SpvOpFAdd : SpvOpIAdd;
1302       } else if (constants[0] == nullptr) {
1303         std::swap(const_input1, const_input2);
1304       }
1305       merged_id =
1306           PerformOperation(const_mgr, merge_op, const_input1, const_input2);
1307       if (merged_id == 0) return false;
1308 
1309       SpvOp op = inst->opcode();
1310       if (constants[0] != nullptr && other_constants[0] != nullptr) {
1311         // Change the operation.
1312         op = uses_float ? SpvOpFAdd : SpvOpIAdd;
1313       }
1314 
1315       uint32_t op1 = 0;
1316       uint32_t op2 = 0;
1317       if ((constants[0] == nullptr) ^ (other_constants[0] == nullptr)) {
1318         op1 = merged_id;
1319         op2 = non_const_input->result_id();
1320       } else {
1321         op1 = non_const_input->result_id();
1322         op2 = merged_id;
1323       }
1324 
1325       inst->SetOpcode(op);
1326       inst->SetInOperands(
1327           {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
1328       return true;
1329     }
1330     return false;
1331   };
1332 }
1333 
1334 // Helper function for MergeGenericAddSubArithmetic. If |addend| and
1335 // subtrahend of |sub| is the same, merge to copy of minuend of |sub|.
MergeGenericAddendSub(uint32_t addend,uint32_t sub,Instruction * inst)1336 bool MergeGenericAddendSub(uint32_t addend, uint32_t sub, Instruction* inst) {
1337   IRContext* context = inst->context();
1338   analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1339   Instruction* sub_inst = def_use_mgr->GetDef(sub);
1340   if (sub_inst->opcode() != SpvOpFSub && sub_inst->opcode() != SpvOpISub)
1341     return false;
1342   if (sub_inst->opcode() == SpvOpFSub &&
1343       !sub_inst->IsFloatingPointFoldingAllowed())
1344     return false;
1345   if (addend != sub_inst->GetSingleWordInOperand(1)) return false;
1346   inst->SetOpcode(SpvOpCopyObject);
1347   inst->SetInOperands(
1348       {{SPV_OPERAND_TYPE_ID, {sub_inst->GetSingleWordInOperand(0)}}});
1349   context->UpdateDefUse(inst);
1350   return true;
1351 }
1352 
1353 // Folds addition of a subtraction where the subtrahend is equal to the
1354 // other addend. Return a copy of the minuend. Accepts generic (const and
1355 // non-const) operands.
1356 // Cases:
1357 // (a - b) + b = a
1358 // b + (a - b) = a
MergeGenericAddSubArithmetic()1359 FoldingRule MergeGenericAddSubArithmetic() {
1360   return [](IRContext* context, Instruction* inst,
1361             const std::vector<const analysis::Constant*>&) {
1362     assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
1363     const analysis::Type* type =
1364         context->get_type_mgr()->GetType(inst->type_id());
1365     bool uses_float = HasFloatingPoint(type);
1366     if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1367 
1368     uint32_t width = ElementWidth(type);
1369     if (width != 32 && width != 64) return false;
1370 
1371     uint32_t add_op0 = inst->GetSingleWordInOperand(0);
1372     uint32_t add_op1 = inst->GetSingleWordInOperand(1);
1373     if (MergeGenericAddendSub(add_op0, add_op1, inst)) return true;
1374     return MergeGenericAddendSub(add_op1, add_op0, inst);
1375   };
1376 }
1377 
1378 // Helper function for FactorAddMuls. If |factor0_0| is the same as |factor1_0|,
1379 // generate |factor0_0| * (|factor0_1| + |factor1_1|).
FactorAddMulsOpnds(uint32_t factor0_0,uint32_t factor0_1,uint32_t factor1_0,uint32_t factor1_1,Instruction * inst)1380 bool FactorAddMulsOpnds(uint32_t factor0_0, uint32_t factor0_1,
1381                         uint32_t factor1_0, uint32_t factor1_1,
1382                         Instruction* inst) {
1383   IRContext* context = inst->context();
1384   if (factor0_0 != factor1_0) return false;
1385   InstructionBuilder ir_builder(
1386       context, inst,
1387       IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
1388   Instruction* new_add_inst = ir_builder.AddBinaryOp(
1389       inst->type_id(), inst->opcode(), factor0_1, factor1_1);
1390   inst->SetOpcode(inst->opcode() == SpvOpFAdd ? SpvOpFMul : SpvOpIMul);
1391   inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {factor0_0}},
1392                        {SPV_OPERAND_TYPE_ID, {new_add_inst->result_id()}}});
1393   context->UpdateDefUse(inst);
1394   return true;
1395 }
1396 
1397 // Perform the following factoring identity, handling all operand order
1398 // combinations: (a * b) + (a * c) = a * (b + c)
FactorAddMuls()1399 FoldingRule FactorAddMuls() {
1400   return [](IRContext* context, Instruction* inst,
1401             const std::vector<const analysis::Constant*>&) {
1402     assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
1403     const analysis::Type* type =
1404         context->get_type_mgr()->GetType(inst->type_id());
1405     bool uses_float = HasFloatingPoint(type);
1406     if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1407 
1408     analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1409     uint32_t add_op0 = inst->GetSingleWordInOperand(0);
1410     Instruction* add_op0_inst = def_use_mgr->GetDef(add_op0);
1411     if (add_op0_inst->opcode() != SpvOpFMul &&
1412         add_op0_inst->opcode() != SpvOpIMul)
1413       return false;
1414     uint32_t add_op1 = inst->GetSingleWordInOperand(1);
1415     Instruction* add_op1_inst = def_use_mgr->GetDef(add_op1);
1416     if (add_op1_inst->opcode() != SpvOpFMul &&
1417         add_op1_inst->opcode() != SpvOpIMul)
1418       return false;
1419 
1420     // Only perform this optimization if both of the muls only have one use.
1421     // Otherwise this is a deoptimization in size and performance.
1422     if (def_use_mgr->NumUses(add_op0_inst) > 1) return false;
1423     if (def_use_mgr->NumUses(add_op1_inst) > 1) return false;
1424 
1425     if (add_op0_inst->opcode() == SpvOpFMul &&
1426         (!add_op0_inst->IsFloatingPointFoldingAllowed() ||
1427          !add_op1_inst->IsFloatingPointFoldingAllowed()))
1428       return false;
1429 
1430     for (int i = 0; i < 2; i++) {
1431       for (int j = 0; j < 2; j++) {
1432         // Check if operand i in add_op0_inst matches operand j in add_op1_inst.
1433         if (FactorAddMulsOpnds(add_op0_inst->GetSingleWordInOperand(i),
1434                                add_op0_inst->GetSingleWordInOperand(1 - i),
1435                                add_op1_inst->GetSingleWordInOperand(j),
1436                                add_op1_inst->GetSingleWordInOperand(1 - j),
1437                                inst))
1438           return true;
1439       }
1440     }
1441     return false;
1442   };
1443 }
1444 
IntMultipleBy1()1445 FoldingRule IntMultipleBy1() {
1446   return [](IRContext*, Instruction* inst,
1447             const std::vector<const analysis::Constant*>& constants) {
1448     assert(inst->opcode() == SpvOpIMul && "Wrong opcode.  Should be OpIMul.");
1449     for (uint32_t i = 0; i < 2; i++) {
1450       if (constants[i] == nullptr) {
1451         continue;
1452       }
1453       const analysis::IntConstant* int_constant = constants[i]->AsIntConstant();
1454       if (int_constant) {
1455         uint32_t width = ElementWidth(int_constant->type());
1456         if (width != 32 && width != 64) return false;
1457         bool is_one = (width == 32) ? int_constant->GetU32BitValue() == 1u
1458                                     : int_constant->GetU64BitValue() == 1ull;
1459         if (is_one) {
1460           inst->SetOpcode(SpvOpCopyObject);
1461           inst->SetInOperands(
1462               {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1 - i)}}});
1463           return true;
1464         }
1465       }
1466     }
1467     return false;
1468   };
1469 }
1470 
CompositeConstructFeedingExtract()1471 FoldingRule CompositeConstructFeedingExtract() {
1472   return [](IRContext* context, Instruction* inst,
1473             const std::vector<const analysis::Constant*>&) {
1474     // If the input to an OpCompositeExtract is an OpCompositeConstruct,
1475     // then we can simply use the appropriate element in the construction.
1476     assert(inst->opcode() == SpvOpCompositeExtract &&
1477            "Wrong opcode.  Should be OpCompositeExtract.");
1478     analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1479     analysis::TypeManager* type_mgr = context->get_type_mgr();
1480 
1481     // If there are no index operands, then this rule cannot do anything.
1482     if (inst->NumInOperands() <= 1) {
1483       return false;
1484     }
1485 
1486     uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1487     Instruction* cinst = def_use_mgr->GetDef(cid);
1488 
1489     if (cinst->opcode() != SpvOpCompositeConstruct) {
1490       return false;
1491     }
1492 
1493     std::vector<Operand> operands;
1494     analysis::Type* composite_type = type_mgr->GetType(cinst->type_id());
1495     if (composite_type->AsVector() == nullptr) {
1496       // Get the element being extracted from the OpCompositeConstruct
1497       // Since it is not a vector, it is simple to extract the single element.
1498       uint32_t element_index = inst->GetSingleWordInOperand(1);
1499       uint32_t element_id = cinst->GetSingleWordInOperand(element_index);
1500       operands.push_back({SPV_OPERAND_TYPE_ID, {element_id}});
1501 
1502       // Add the remaining indices for extraction.
1503       for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
1504         operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
1505                             {inst->GetSingleWordInOperand(i)}});
1506       }
1507 
1508     } else {
1509       // With vectors we have to handle the case where it is concatenating
1510       // vectors.
1511       assert(inst->NumInOperands() == 2 &&
1512              "Expecting a vector of scalar values.");
1513 
1514       uint32_t element_index = inst->GetSingleWordInOperand(1);
1515       for (uint32_t construct_index = 0;
1516            construct_index < cinst->NumInOperands(); ++construct_index) {
1517         uint32_t element_id = cinst->GetSingleWordInOperand(construct_index);
1518         Instruction* element_def = def_use_mgr->GetDef(element_id);
1519         analysis::Vector* element_type =
1520             type_mgr->GetType(element_def->type_id())->AsVector();
1521         if (element_type) {
1522           uint32_t vector_size = element_type->element_count();
1523           if (vector_size <= element_index) {
1524             // The element we want comes after this vector.
1525             element_index -= vector_size;
1526           } else {
1527             // We want an element of this vector.
1528             operands.push_back({SPV_OPERAND_TYPE_ID, {element_id}});
1529             operands.push_back(
1530                 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {element_index}});
1531             break;
1532           }
1533         } else {
1534           if (element_index == 0) {
1535             // This is a scalar, and we this is the element we are extracting.
1536             operands.push_back({SPV_OPERAND_TYPE_ID, {element_id}});
1537             break;
1538           } else {
1539             // Skip over this scalar value.
1540             --element_index;
1541           }
1542         }
1543       }
1544     }
1545 
1546     // If there were no extra indices, then we have the final object.  No need
1547     // to extract even more.
1548     if (operands.size() == 1) {
1549       inst->SetOpcode(SpvOpCopyObject);
1550     }
1551 
1552     inst->SetInOperands(std::move(operands));
1553     return true;
1554   };
1555 }
1556 
1557 // If the OpCompositeConstruct is simply putting back together elements that
1558 // where extracted from the same source, we can simply reuse the source.
1559 //
1560 // This is a common code pattern because of the way that scalar replacement
1561 // works.
CompositeExtractFeedingConstruct(IRContext * context,Instruction * inst,const std::vector<const analysis::Constant * > &)1562 bool CompositeExtractFeedingConstruct(
1563     IRContext* context, Instruction* inst,
1564     const std::vector<const analysis::Constant*>&) {
1565   assert(inst->opcode() == SpvOpCompositeConstruct &&
1566          "Wrong opcode.  Should be OpCompositeConstruct.");
1567   analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1568   uint32_t original_id = 0;
1569 
1570   if (inst->NumInOperands() == 0) {
1571     // The struct being constructed has no members.
1572     return false;
1573   }
1574 
1575   // Check each element to make sure they are:
1576   // - extractions
1577   // - extracting the same position they are inserting
1578   // - all extract from the same id.
1579   for (uint32_t i = 0; i < inst->NumInOperands(); ++i) {
1580     const uint32_t element_id = inst->GetSingleWordInOperand(i);
1581     Instruction* element_inst = def_use_mgr->GetDef(element_id);
1582 
1583     if (element_inst->opcode() != SpvOpCompositeExtract) {
1584       return false;
1585     }
1586 
1587     if (element_inst->NumInOperands() != 2) {
1588       return false;
1589     }
1590 
1591     if (element_inst->GetSingleWordInOperand(1) != i) {
1592       return false;
1593     }
1594 
1595     if (i == 0) {
1596       original_id =
1597           element_inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1598     } else if (original_id !=
1599                element_inst->GetSingleWordInOperand(kExtractCompositeIdInIdx)) {
1600       return false;
1601     }
1602   }
1603 
1604   // The last check it to see that the object being extracted from is the
1605   // correct type.
1606   Instruction* original_inst = def_use_mgr->GetDef(original_id);
1607   if (original_inst->type_id() != inst->type_id()) {
1608     return false;
1609   }
1610 
1611   // Simplify by using the original object.
1612   inst->SetOpcode(SpvOpCopyObject);
1613   inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {original_id}}});
1614   return true;
1615 }
1616 
InsertFeedingExtract()1617 FoldingRule InsertFeedingExtract() {
1618   return [](IRContext* context, Instruction* inst,
1619             const std::vector<const analysis::Constant*>&) {
1620     assert(inst->opcode() == SpvOpCompositeExtract &&
1621            "Wrong opcode.  Should be OpCompositeExtract.");
1622     analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1623     uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1624     Instruction* cinst = def_use_mgr->GetDef(cid);
1625 
1626     if (cinst->opcode() != SpvOpCompositeInsert) {
1627       return false;
1628     }
1629 
1630     // Find the first position where the list of insert and extract indicies
1631     // differ, if at all.
1632     uint32_t i;
1633     for (i = 1; i < inst->NumInOperands(); ++i) {
1634       if (i + 1 >= cinst->NumInOperands()) {
1635         break;
1636       }
1637 
1638       if (inst->GetSingleWordInOperand(i) !=
1639           cinst->GetSingleWordInOperand(i + 1)) {
1640         break;
1641       }
1642     }
1643 
1644     // We are extracting the element that was inserted.
1645     if (i == inst->NumInOperands() && i + 1 == cinst->NumInOperands()) {
1646       inst->SetOpcode(SpvOpCopyObject);
1647       inst->SetInOperands(
1648           {{SPV_OPERAND_TYPE_ID,
1649             {cinst->GetSingleWordInOperand(kInsertObjectIdInIdx)}}});
1650       return true;
1651     }
1652 
1653     // Extracting the value that was inserted along with values for the base
1654     // composite.  Cannot do anything.
1655     if (i == inst->NumInOperands()) {
1656       return false;
1657     }
1658 
1659     // Extracting an element of the value that was inserted.  Extract from
1660     // that value directly.
1661     if (i + 1 == cinst->NumInOperands()) {
1662       std::vector<Operand> operands;
1663       operands.push_back(
1664           {SPV_OPERAND_TYPE_ID,
1665            {cinst->GetSingleWordInOperand(kInsertObjectIdInIdx)}});
1666       for (; i < inst->NumInOperands(); ++i) {
1667         operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
1668                             {inst->GetSingleWordInOperand(i)}});
1669       }
1670       inst->SetInOperands(std::move(operands));
1671       return true;
1672     }
1673 
1674     // Extracting a value that is disjoint from the element being inserted.
1675     // Rewrite the extract to use the composite input to the insert.
1676     std::vector<Operand> operands;
1677     operands.push_back(
1678         {SPV_OPERAND_TYPE_ID,
1679          {cinst->GetSingleWordInOperand(kInsertCompositeIdInIdx)}});
1680     for (i = 1; i < inst->NumInOperands(); ++i) {
1681       operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
1682                           {inst->GetSingleWordInOperand(i)}});
1683     }
1684     inst->SetInOperands(std::move(operands));
1685     return true;
1686   };
1687 }
1688 
1689 // When a VectorShuffle is feeding an Extract, we can extract from one of the
1690 // operands of the VectorShuffle.  We just need to adjust the index in the
1691 // extract instruction.
VectorShuffleFeedingExtract()1692 FoldingRule VectorShuffleFeedingExtract() {
1693   return [](IRContext* context, Instruction* inst,
1694             const std::vector<const analysis::Constant*>&) {
1695     assert(inst->opcode() == SpvOpCompositeExtract &&
1696            "Wrong opcode.  Should be OpCompositeExtract.");
1697     analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1698     analysis::TypeManager* type_mgr = context->get_type_mgr();
1699     uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1700     Instruction* cinst = def_use_mgr->GetDef(cid);
1701 
1702     if (cinst->opcode() != SpvOpVectorShuffle) {
1703       return false;
1704     }
1705 
1706     // Find the size of the first vector operand of the VectorShuffle
1707     Instruction* first_input =
1708         def_use_mgr->GetDef(cinst->GetSingleWordInOperand(0));
1709     analysis::Type* first_input_type =
1710         type_mgr->GetType(first_input->type_id());
1711     assert(first_input_type->AsVector() &&
1712            "Input to vector shuffle should be vectors.");
1713     uint32_t first_input_size = first_input_type->AsVector()->element_count();
1714 
1715     // Get index of the element the vector shuffle is placing in the position
1716     // being extracted.
1717     uint32_t new_index =
1718         cinst->GetSingleWordInOperand(2 + inst->GetSingleWordInOperand(1));
1719 
1720     // Extracting an undefined value so fold this extract into an undef.
1721     const uint32_t undef_literal_value = 0xffffffff;
1722     if (new_index == undef_literal_value) {
1723       inst->SetOpcode(SpvOpUndef);
1724       inst->SetInOperands({});
1725       return true;
1726     }
1727 
1728     // Get the id of the of the vector the elemtent comes from, and update the
1729     // index if needed.
1730     uint32_t new_vector = 0;
1731     if (new_index < first_input_size) {
1732       new_vector = cinst->GetSingleWordInOperand(0);
1733     } else {
1734       new_vector = cinst->GetSingleWordInOperand(1);
1735       new_index -= first_input_size;
1736     }
1737 
1738     // Update the extract instruction.
1739     inst->SetInOperand(kExtractCompositeIdInIdx, {new_vector});
1740     inst->SetInOperand(1, {new_index});
1741     return true;
1742   };
1743 }
1744 
1745 // When an FMix with is feeding an Extract that extracts an element whose
1746 // corresponding |a| in the FMix is 0 or 1, we can extract from one of the
1747 // operands of the FMix.
FMixFeedingExtract()1748 FoldingRule FMixFeedingExtract() {
1749   return [](IRContext* context, Instruction* inst,
1750             const std::vector<const analysis::Constant*>&) {
1751     assert(inst->opcode() == SpvOpCompositeExtract &&
1752            "Wrong opcode.  Should be OpCompositeExtract.");
1753     analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1754     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1755 
1756     uint32_t composite_id =
1757         inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1758     Instruction* composite_inst = def_use_mgr->GetDef(composite_id);
1759 
1760     if (composite_inst->opcode() != SpvOpExtInst) {
1761       return false;
1762     }
1763 
1764     uint32_t inst_set_id =
1765         context->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
1766 
1767     if (composite_inst->GetSingleWordInOperand(kExtInstSetIdInIdx) !=
1768             inst_set_id ||
1769         composite_inst->GetSingleWordInOperand(kExtInstInstructionInIdx) !=
1770             GLSLstd450FMix) {
1771       return false;
1772     }
1773 
1774     // Get the |a| for the FMix instruction.
1775     uint32_t a_id = composite_inst->GetSingleWordInOperand(kFMixAIdInIdx);
1776     std::unique_ptr<Instruction> a(inst->Clone(context));
1777     a->SetInOperand(kExtractCompositeIdInIdx, {a_id});
1778     context->get_instruction_folder().FoldInstruction(a.get());
1779 
1780     if (a->opcode() != SpvOpCopyObject) {
1781       return false;
1782     }
1783 
1784     const analysis::Constant* a_const =
1785         const_mgr->FindDeclaredConstant(a->GetSingleWordInOperand(0));
1786 
1787     if (!a_const) {
1788       return false;
1789     }
1790 
1791     bool use_x = false;
1792 
1793     assert(a_const->type()->AsFloat());
1794     double element_value = a_const->GetValueAsDouble();
1795     if (element_value == 0.0) {
1796       use_x = true;
1797     } else if (element_value == 1.0) {
1798       use_x = false;
1799     } else {
1800       return false;
1801     }
1802 
1803     // Get the id of the of the vector the element comes from.
1804     uint32_t new_vector = 0;
1805     if (use_x) {
1806       new_vector = composite_inst->GetSingleWordInOperand(kFMixXIdInIdx);
1807     } else {
1808       new_vector = composite_inst->GetSingleWordInOperand(kFMixYIdInIdx);
1809     }
1810 
1811     // Update the extract instruction.
1812     inst->SetInOperand(kExtractCompositeIdInIdx, {new_vector});
1813     return true;
1814   };
1815 }
1816 
RedundantPhi()1817 FoldingRule RedundantPhi() {
1818   // An OpPhi instruction where all values are the same or the result of the phi
1819   // itself, can be replaced by the value itself.
1820   return [](IRContext*, Instruction* inst,
1821             const std::vector<const analysis::Constant*>&) {
1822     assert(inst->opcode() == SpvOpPhi && "Wrong opcode.  Should be OpPhi.");
1823 
1824     uint32_t incoming_value = 0;
1825 
1826     for (uint32_t i = 0; i < inst->NumInOperands(); i += 2) {
1827       uint32_t op_id = inst->GetSingleWordInOperand(i);
1828       if (op_id == inst->result_id()) {
1829         continue;
1830       }
1831 
1832       if (incoming_value == 0) {
1833         incoming_value = op_id;
1834       } else if (op_id != incoming_value) {
1835         // Found two possible value.  Can't simplify.
1836         return false;
1837       }
1838     }
1839 
1840     if (incoming_value == 0) {
1841       // Code looks invalid.  Don't do anything.
1842       return false;
1843     }
1844 
1845     // We have a single incoming value.  Simplify using that value.
1846     inst->SetOpcode(SpvOpCopyObject);
1847     inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {incoming_value}}});
1848     return true;
1849   };
1850 }
1851 
BitCastScalarOrVector()1852 FoldingRule BitCastScalarOrVector() {
1853   return [](IRContext* context, Instruction* inst,
1854             const std::vector<const analysis::Constant*>& constants) {
1855     assert(inst->opcode() == SpvOpBitcast && constants.size() == 1);
1856     if (constants[0] == nullptr) return false;
1857 
1858     const analysis::Type* type =
1859         context->get_type_mgr()->GetType(inst->type_id());
1860     if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
1861       return false;
1862 
1863     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1864     std::vector<uint32_t> words =
1865         GetWordsFromNumericScalarOrVectorConstant(const_mgr, constants[0]);
1866     if (words.size() == 0) return false;
1867 
1868     const analysis::Constant* bitcasted_constant =
1869         ConvertWordsToNumericScalarOrVectorConstant(const_mgr, words, type);
1870     if (!bitcasted_constant) return false;
1871 
1872     auto new_feeder_id =
1873         const_mgr->GetDefiningInstruction(bitcasted_constant, inst->type_id())
1874             ->result_id();
1875     inst->SetOpcode(SpvOpCopyObject);
1876     inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {new_feeder_id}}});
1877     return true;
1878   };
1879 }
1880 
RedundantSelect()1881 FoldingRule RedundantSelect() {
1882   // An OpSelect instruction where both values are the same or the condition is
1883   // constant can be replaced by one of the values
1884   return [](IRContext*, Instruction* inst,
1885             const std::vector<const analysis::Constant*>& constants) {
1886     assert(inst->opcode() == SpvOpSelect &&
1887            "Wrong opcode.  Should be OpSelect.");
1888     assert(inst->NumInOperands() == 3);
1889     assert(constants.size() == 3);
1890 
1891     uint32_t true_id = inst->GetSingleWordInOperand(1);
1892     uint32_t false_id = inst->GetSingleWordInOperand(2);
1893 
1894     if (true_id == false_id) {
1895       // Both results are the same, condition doesn't matter
1896       inst->SetOpcode(SpvOpCopyObject);
1897       inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}});
1898       return true;
1899     } else if (constants[0]) {
1900       const analysis::Type* type = constants[0]->type();
1901       if (type->AsBool()) {
1902         // Scalar constant value, select the corresponding value.
1903         inst->SetOpcode(SpvOpCopyObject);
1904         if (constants[0]->AsNullConstant() ||
1905             !constants[0]->AsBoolConstant()->value()) {
1906           inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {false_id}}});
1907         } else {
1908           inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}});
1909         }
1910         return true;
1911       } else {
1912         assert(type->AsVector());
1913         if (constants[0]->AsNullConstant()) {
1914           // All values come from false id.
1915           inst->SetOpcode(SpvOpCopyObject);
1916           inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {false_id}}});
1917           return true;
1918         } else {
1919           // Convert to a vector shuffle.
1920           std::vector<Operand> ops;
1921           ops.push_back({SPV_OPERAND_TYPE_ID, {true_id}});
1922           ops.push_back({SPV_OPERAND_TYPE_ID, {false_id}});
1923           const analysis::VectorConstant* vector_const =
1924               constants[0]->AsVectorConstant();
1925           uint32_t size =
1926               static_cast<uint32_t>(vector_const->GetComponents().size());
1927           for (uint32_t i = 0; i != size; ++i) {
1928             const analysis::Constant* component =
1929                 vector_const->GetComponents()[i];
1930             if (component->AsNullConstant() ||
1931                 !component->AsBoolConstant()->value()) {
1932               // Selecting from the false vector which is the second input
1933               // vector to the shuffle. Offset the index by |size|.
1934               ops.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i + size}});
1935             } else {
1936               // Selecting from true vector which is the first input vector to
1937               // the shuffle.
1938               ops.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}});
1939             }
1940           }
1941 
1942           inst->SetOpcode(SpvOpVectorShuffle);
1943           inst->SetInOperands(std::move(ops));
1944           return true;
1945         }
1946       }
1947     }
1948 
1949     return false;
1950   };
1951 }
1952 
1953 enum class FloatConstantKind { Unknown, Zero, One };
1954 
getFloatConstantKind(const analysis::Constant * constant)1955 FloatConstantKind getFloatConstantKind(const analysis::Constant* constant) {
1956   if (constant == nullptr) {
1957     return FloatConstantKind::Unknown;
1958   }
1959 
1960   assert(HasFloatingPoint(constant->type()) && "Unexpected constant type");
1961 
1962   if (constant->AsNullConstant()) {
1963     return FloatConstantKind::Zero;
1964   } else if (const analysis::VectorConstant* vc =
1965                  constant->AsVectorConstant()) {
1966     const std::vector<const analysis::Constant*>& components =
1967         vc->GetComponents();
1968     assert(!components.empty());
1969 
1970     FloatConstantKind kind = getFloatConstantKind(components[0]);
1971 
1972     for (size_t i = 1; i < components.size(); ++i) {
1973       if (getFloatConstantKind(components[i]) != kind) {
1974         return FloatConstantKind::Unknown;
1975       }
1976     }
1977 
1978     return kind;
1979   } else if (const analysis::FloatConstant* fc = constant->AsFloatConstant()) {
1980     if (fc->IsZero()) return FloatConstantKind::Zero;
1981 
1982     uint32_t width = fc->type()->AsFloat()->width();
1983     if (width != 32 && width != 64) return FloatConstantKind::Unknown;
1984 
1985     double value = (width == 64) ? fc->GetDoubleValue() : fc->GetFloatValue();
1986 
1987     if (value == 0.0) {
1988       return FloatConstantKind::Zero;
1989     } else if (value == 1.0) {
1990       return FloatConstantKind::One;
1991     } else {
1992       return FloatConstantKind::Unknown;
1993     }
1994   } else {
1995     return FloatConstantKind::Unknown;
1996   }
1997 }
1998 
RedundantFAdd()1999 FoldingRule RedundantFAdd() {
2000   return [](IRContext*, Instruction* inst,
2001             const std::vector<const analysis::Constant*>& constants) {
2002     assert(inst->opcode() == SpvOpFAdd && "Wrong opcode.  Should be OpFAdd.");
2003     assert(constants.size() == 2);
2004 
2005     if (!inst->IsFloatingPointFoldingAllowed()) {
2006       return false;
2007     }
2008 
2009     FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
2010     FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
2011 
2012     if (kind0 == FloatConstantKind::Zero || kind1 == FloatConstantKind::Zero) {
2013       inst->SetOpcode(SpvOpCopyObject);
2014       inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
2015                             {inst->GetSingleWordInOperand(
2016                                 kind0 == FloatConstantKind::Zero ? 1 : 0)}}});
2017       return true;
2018     }
2019 
2020     return false;
2021   };
2022 }
2023 
RedundantFSub()2024 FoldingRule RedundantFSub() {
2025   return [](IRContext*, Instruction* inst,
2026             const std::vector<const analysis::Constant*>& constants) {
2027     assert(inst->opcode() == SpvOpFSub && "Wrong opcode.  Should be OpFSub.");
2028     assert(constants.size() == 2);
2029 
2030     if (!inst->IsFloatingPointFoldingAllowed()) {
2031       return false;
2032     }
2033 
2034     FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
2035     FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
2036 
2037     if (kind0 == FloatConstantKind::Zero) {
2038       inst->SetOpcode(SpvOpFNegate);
2039       inst->SetInOperands(
2040           {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1)}}});
2041       return true;
2042     }
2043 
2044     if (kind1 == FloatConstantKind::Zero) {
2045       inst->SetOpcode(SpvOpCopyObject);
2046       inst->SetInOperands(
2047           {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
2048       return true;
2049     }
2050 
2051     return false;
2052   };
2053 }
2054 
RedundantFMul()2055 FoldingRule RedundantFMul() {
2056   return [](IRContext*, Instruction* inst,
2057             const std::vector<const analysis::Constant*>& constants) {
2058     assert(inst->opcode() == SpvOpFMul && "Wrong opcode.  Should be OpFMul.");
2059     assert(constants.size() == 2);
2060 
2061     if (!inst->IsFloatingPointFoldingAllowed()) {
2062       return false;
2063     }
2064 
2065     FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
2066     FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
2067 
2068     if (kind0 == FloatConstantKind::Zero || kind1 == FloatConstantKind::Zero) {
2069       inst->SetOpcode(SpvOpCopyObject);
2070       inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
2071                             {inst->GetSingleWordInOperand(
2072                                 kind0 == FloatConstantKind::Zero ? 0 : 1)}}});
2073       return true;
2074     }
2075 
2076     if (kind0 == FloatConstantKind::One || kind1 == FloatConstantKind::One) {
2077       inst->SetOpcode(SpvOpCopyObject);
2078       inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
2079                             {inst->GetSingleWordInOperand(
2080                                 kind0 == FloatConstantKind::One ? 1 : 0)}}});
2081       return true;
2082     }
2083 
2084     return false;
2085   };
2086 }
2087 
RedundantFDiv()2088 FoldingRule RedundantFDiv() {
2089   return [](IRContext*, Instruction* inst,
2090             const std::vector<const analysis::Constant*>& constants) {
2091     assert(inst->opcode() == SpvOpFDiv && "Wrong opcode.  Should be OpFDiv.");
2092     assert(constants.size() == 2);
2093 
2094     if (!inst->IsFloatingPointFoldingAllowed()) {
2095       return false;
2096     }
2097 
2098     FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
2099     FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
2100 
2101     if (kind0 == FloatConstantKind::Zero) {
2102       inst->SetOpcode(SpvOpCopyObject);
2103       inst->SetInOperands(
2104           {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
2105       return true;
2106     }
2107 
2108     if (kind1 == FloatConstantKind::One) {
2109       inst->SetOpcode(SpvOpCopyObject);
2110       inst->SetInOperands(
2111           {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
2112       return true;
2113     }
2114 
2115     return false;
2116   };
2117 }
2118 
RedundantFMix()2119 FoldingRule RedundantFMix() {
2120   return [](IRContext* context, Instruction* inst,
2121             const std::vector<const analysis::Constant*>& constants) {
2122     assert(inst->opcode() == SpvOpExtInst &&
2123            "Wrong opcode.  Should be OpExtInst.");
2124 
2125     if (!inst->IsFloatingPointFoldingAllowed()) {
2126       return false;
2127     }
2128 
2129     uint32_t instSetId =
2130         context->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
2131 
2132     if (inst->GetSingleWordInOperand(kExtInstSetIdInIdx) == instSetId &&
2133         inst->GetSingleWordInOperand(kExtInstInstructionInIdx) ==
2134             GLSLstd450FMix) {
2135       assert(constants.size() == 5);
2136 
2137       FloatConstantKind kind4 = getFloatConstantKind(constants[4]);
2138 
2139       if (kind4 == FloatConstantKind::Zero || kind4 == FloatConstantKind::One) {
2140         inst->SetOpcode(SpvOpCopyObject);
2141         inst->SetInOperands(
2142             {{SPV_OPERAND_TYPE_ID,
2143               {inst->GetSingleWordInOperand(kind4 == FloatConstantKind::Zero
2144                                                 ? kFMixXIdInIdx
2145                                                 : kFMixYIdInIdx)}}});
2146         return true;
2147       }
2148     }
2149 
2150     return false;
2151   };
2152 }
2153 
2154 // This rule handles addition of zero for integers.
RedundantIAdd()2155 FoldingRule RedundantIAdd() {
2156   return [](IRContext* context, Instruction* inst,
2157             const std::vector<const analysis::Constant*>& constants) {
2158     assert(inst->opcode() == SpvOpIAdd && "Wrong opcode. Should be OpIAdd.");
2159 
2160     uint32_t operand = std::numeric_limits<uint32_t>::max();
2161     const analysis::Type* operand_type = nullptr;
2162     if (constants[0] && constants[0]->IsZero()) {
2163       operand = inst->GetSingleWordInOperand(1);
2164       operand_type = constants[0]->type();
2165     } else if (constants[1] && constants[1]->IsZero()) {
2166       operand = inst->GetSingleWordInOperand(0);
2167       operand_type = constants[1]->type();
2168     }
2169 
2170     if (operand != std::numeric_limits<uint32_t>::max()) {
2171       const analysis::Type* inst_type =
2172           context->get_type_mgr()->GetType(inst->type_id());
2173       if (inst_type->IsSame(operand_type)) {
2174         inst->SetOpcode(SpvOpCopyObject);
2175       } else {
2176         inst->SetOpcode(SpvOpBitcast);
2177       }
2178       inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {operand}}});
2179       return true;
2180     }
2181     return false;
2182   };
2183 }
2184 
2185 // This rule look for a dot with a constant vector containing a single 1 and
2186 // the rest 0s.  This is the same as doing an extract.
DotProductDoingExtract()2187 FoldingRule DotProductDoingExtract() {
2188   return [](IRContext* context, Instruction* inst,
2189             const std::vector<const analysis::Constant*>& constants) {
2190     assert(inst->opcode() == SpvOpDot && "Wrong opcode.  Should be OpDot.");
2191 
2192     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
2193 
2194     if (!inst->IsFloatingPointFoldingAllowed()) {
2195       return false;
2196     }
2197 
2198     for (int i = 0; i < 2; ++i) {
2199       if (!constants[i]) {
2200         continue;
2201       }
2202 
2203       const analysis::Vector* vector_type = constants[i]->type()->AsVector();
2204       assert(vector_type && "Inputs to OpDot must be vectors.");
2205       const analysis::Float* element_type =
2206           vector_type->element_type()->AsFloat();
2207       assert(element_type && "Inputs to OpDot must be vectors of floats.");
2208       uint32_t element_width = element_type->width();
2209       if (element_width != 32 && element_width != 64) {
2210         return false;
2211       }
2212 
2213       std::vector<const analysis::Constant*> components;
2214       components = constants[i]->GetVectorComponents(const_mgr);
2215 
2216       const uint32_t kNotFound = std::numeric_limits<uint32_t>::max();
2217 
2218       uint32_t component_with_one = kNotFound;
2219       bool all_others_zero = true;
2220       for (uint32_t j = 0; j < components.size(); ++j) {
2221         const analysis::Constant* element = components[j];
2222         double value =
2223             (element_width == 32 ? element->GetFloat() : element->GetDouble());
2224         if (value == 0.0) {
2225           continue;
2226         } else if (value == 1.0) {
2227           if (component_with_one == kNotFound) {
2228             component_with_one = j;
2229           } else {
2230             component_with_one = kNotFound;
2231             break;
2232           }
2233         } else {
2234           all_others_zero = false;
2235           break;
2236         }
2237       }
2238 
2239       if (!all_others_zero || component_with_one == kNotFound) {
2240         continue;
2241       }
2242 
2243       std::vector<Operand> operands;
2244       operands.push_back(
2245           {SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1u - i)}});
2246       operands.push_back(
2247           {SPV_OPERAND_TYPE_LITERAL_INTEGER, {component_with_one}});
2248 
2249       inst->SetOpcode(SpvOpCompositeExtract);
2250       inst->SetInOperands(std::move(operands));
2251       return true;
2252     }
2253     return false;
2254   };
2255 }
2256 
2257 // If we are storing an undef, then we can remove the store.
2258 //
2259 // TODO: We can do something similar for OpImageWrite, but checking for volatile
2260 // is complicated.  Waiting to see if it is needed.
StoringUndef()2261 FoldingRule StoringUndef() {
2262   return [](IRContext* context, Instruction* inst,
2263             const std::vector<const analysis::Constant*>&) {
2264     assert(inst->opcode() == SpvOpStore && "Wrong opcode.  Should be OpStore.");
2265 
2266     analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
2267 
2268     // If this is a volatile store, the store cannot be removed.
2269     if (inst->NumInOperands() == 3) {
2270       if (inst->GetSingleWordInOperand(2) & SpvMemoryAccessVolatileMask) {
2271         return false;
2272       }
2273     }
2274 
2275     uint32_t object_id = inst->GetSingleWordInOperand(kStoreObjectInIdx);
2276     Instruction* object_inst = def_use_mgr->GetDef(object_id);
2277     if (object_inst->opcode() == SpvOpUndef) {
2278       inst->ToNop();
2279       return true;
2280     }
2281     return false;
2282   };
2283 }
2284 
VectorShuffleFeedingShuffle()2285 FoldingRule VectorShuffleFeedingShuffle() {
2286   return [](IRContext* context, Instruction* inst,
2287             const std::vector<const analysis::Constant*>&) {
2288     assert(inst->opcode() == SpvOpVectorShuffle &&
2289            "Wrong opcode.  Should be OpVectorShuffle.");
2290 
2291     analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
2292     analysis::TypeManager* type_mgr = context->get_type_mgr();
2293 
2294     Instruction* feeding_shuffle_inst =
2295         def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
2296     analysis::Vector* op0_type =
2297         type_mgr->GetType(feeding_shuffle_inst->type_id())->AsVector();
2298     uint32_t op0_length = op0_type->element_count();
2299 
2300     bool feeder_is_op0 = true;
2301     if (feeding_shuffle_inst->opcode() != SpvOpVectorShuffle) {
2302       feeding_shuffle_inst =
2303           def_use_mgr->GetDef(inst->GetSingleWordInOperand(1));
2304       feeder_is_op0 = false;
2305     }
2306 
2307     if (feeding_shuffle_inst->opcode() != SpvOpVectorShuffle) {
2308       return false;
2309     }
2310 
2311     Instruction* feeder2 =
2312         def_use_mgr->GetDef(feeding_shuffle_inst->GetSingleWordInOperand(0));
2313     analysis::Vector* feeder_op0_type =
2314         type_mgr->GetType(feeder2->type_id())->AsVector();
2315     uint32_t feeder_op0_length = feeder_op0_type->element_count();
2316 
2317     uint32_t new_feeder_id = 0;
2318     std::vector<Operand> new_operands;
2319     new_operands.resize(
2320         2, {SPV_OPERAND_TYPE_ID, {0}});  // Place holders for vector operands.
2321     const uint32_t undef_literal = 0xffffffff;
2322     for (uint32_t op = 2; op < inst->NumInOperands(); ++op) {
2323       uint32_t component_index = inst->GetSingleWordInOperand(op);
2324 
2325       // Do not interpret the undefined value literal as coming from operand 1.
2326       if (component_index != undef_literal &&
2327           feeder_is_op0 == (component_index < op0_length)) {
2328         // This component comes from the feeding_shuffle_inst.  Update
2329         // |component_index| to be the index into the operand of the feeder.
2330 
2331         // Adjust component_index to get the index into the operands of the
2332         // feeding_shuffle_inst.
2333         if (component_index >= op0_length) {
2334           component_index -= op0_length;
2335         }
2336         component_index =
2337             feeding_shuffle_inst->GetSingleWordInOperand(component_index + 2);
2338 
2339         // Check if we are using a component from the first or second operand of
2340         // the feeding instruction.
2341         if (component_index < feeder_op0_length) {
2342           if (new_feeder_id == 0) {
2343             // First time through, save the id of the operand the element comes
2344             // from.
2345             new_feeder_id = feeding_shuffle_inst->GetSingleWordInOperand(0);
2346           } else if (new_feeder_id !=
2347                      feeding_shuffle_inst->GetSingleWordInOperand(0)) {
2348             // We need both elements of the feeding_shuffle_inst, so we cannot
2349             // fold.
2350             return false;
2351           }
2352         } else {
2353           if (new_feeder_id == 0) {
2354             // First time through, save the id of the operand the element comes
2355             // from.
2356             new_feeder_id = feeding_shuffle_inst->GetSingleWordInOperand(1);
2357           } else if (new_feeder_id !=
2358                      feeding_shuffle_inst->GetSingleWordInOperand(1)) {
2359             // We need both elements of the feeding_shuffle_inst, so we cannot
2360             // fold.
2361             return false;
2362           }
2363           component_index -= feeder_op0_length;
2364         }
2365 
2366         if (!feeder_is_op0) {
2367           component_index += op0_length;
2368         }
2369       }
2370       new_operands.push_back(
2371           {SPV_OPERAND_TYPE_LITERAL_INTEGER, {component_index}});
2372     }
2373 
2374     if (new_feeder_id == 0) {
2375       analysis::ConstantManager* const_mgr = context->get_constant_mgr();
2376       const analysis::Type* type =
2377           type_mgr->GetType(feeding_shuffle_inst->type_id());
2378       const analysis::Constant* null_const = const_mgr->GetConstant(type, {});
2379       new_feeder_id =
2380           const_mgr->GetDefiningInstruction(null_const, 0)->result_id();
2381     }
2382 
2383     if (feeder_is_op0) {
2384       // If the size of the first vector operand changed then the indices
2385       // referring to the second operand need to be adjusted.
2386       Instruction* new_feeder_inst = def_use_mgr->GetDef(new_feeder_id);
2387       analysis::Type* new_feeder_type =
2388           type_mgr->GetType(new_feeder_inst->type_id());
2389       uint32_t new_op0_size = new_feeder_type->AsVector()->element_count();
2390       int32_t adjustment = op0_length - new_op0_size;
2391 
2392       if (adjustment != 0) {
2393         for (uint32_t i = 2; i < new_operands.size(); i++) {
2394           if (inst->GetSingleWordInOperand(i) >= op0_length) {
2395             new_operands[i].words[0] -= adjustment;
2396           }
2397         }
2398       }
2399 
2400       new_operands[0].words[0] = new_feeder_id;
2401       new_operands[1] = inst->GetInOperand(1);
2402     } else {
2403       new_operands[1].words[0] = new_feeder_id;
2404       new_operands[0] = inst->GetInOperand(0);
2405     }
2406 
2407     inst->SetInOperands(std::move(new_operands));
2408     return true;
2409   };
2410 }
2411 
2412 // Removes duplicate ids from the interface list of an OpEntryPoint
2413 // instruction.
RemoveRedundantOperands()2414 FoldingRule RemoveRedundantOperands() {
2415   return [](IRContext*, Instruction* inst,
2416             const std::vector<const analysis::Constant*>&) {
2417     assert(inst->opcode() == SpvOpEntryPoint &&
2418            "Wrong opcode.  Should be OpEntryPoint.");
2419     bool has_redundant_operand = false;
2420     std::unordered_set<uint32_t> seen_operands;
2421     std::vector<Operand> new_operands;
2422 
2423     new_operands.emplace_back(inst->GetOperand(0));
2424     new_operands.emplace_back(inst->GetOperand(1));
2425     new_operands.emplace_back(inst->GetOperand(2));
2426     for (uint32_t i = 3; i < inst->NumOperands(); ++i) {
2427       if (seen_operands.insert(inst->GetSingleWordOperand(i)).second) {
2428         new_operands.emplace_back(inst->GetOperand(i));
2429       } else {
2430         has_redundant_operand = true;
2431       }
2432     }
2433 
2434     if (!has_redundant_operand) {
2435       return false;
2436     }
2437 
2438     inst->SetInOperands(std::move(new_operands));
2439     return true;
2440   };
2441 }
2442 
2443 // If an image instruction's operand is a constant, updates the image operand
2444 // flag from Offset to ConstOffset.
UpdateImageOperands()2445 FoldingRule UpdateImageOperands() {
2446   return [](IRContext*, Instruction* inst,
2447             const std::vector<const analysis::Constant*>& constants) {
2448     const auto opcode = inst->opcode();
2449     (void)opcode;
2450     assert((opcode == SpvOpImageSampleImplicitLod ||
2451             opcode == SpvOpImageSampleExplicitLod ||
2452             opcode == SpvOpImageSampleDrefImplicitLod ||
2453             opcode == SpvOpImageSampleDrefExplicitLod ||
2454             opcode == SpvOpImageSampleProjImplicitLod ||
2455             opcode == SpvOpImageSampleProjExplicitLod ||
2456             opcode == SpvOpImageSampleProjDrefImplicitLod ||
2457             opcode == SpvOpImageSampleProjDrefExplicitLod ||
2458             opcode == SpvOpImageFetch || opcode == SpvOpImageGather ||
2459             opcode == SpvOpImageDrefGather || opcode == SpvOpImageRead ||
2460             opcode == SpvOpImageWrite ||
2461             opcode == SpvOpImageSparseSampleImplicitLod ||
2462             opcode == SpvOpImageSparseSampleExplicitLod ||
2463             opcode == SpvOpImageSparseSampleDrefImplicitLod ||
2464             opcode == SpvOpImageSparseSampleDrefExplicitLod ||
2465             opcode == SpvOpImageSparseSampleProjImplicitLod ||
2466             opcode == SpvOpImageSparseSampleProjExplicitLod ||
2467             opcode == SpvOpImageSparseSampleProjDrefImplicitLod ||
2468             opcode == SpvOpImageSparseSampleProjDrefExplicitLod ||
2469             opcode == SpvOpImageSparseFetch ||
2470             opcode == SpvOpImageSparseGather ||
2471             opcode == SpvOpImageSparseDrefGather ||
2472             opcode == SpvOpImageSparseRead) &&
2473            "Wrong opcode.  Should be an image instruction.");
2474 
2475     int32_t operand_index = ImageOperandsMaskInOperandIndex(inst);
2476     if (operand_index >= 0) {
2477       auto image_operands = inst->GetSingleWordInOperand(operand_index);
2478       if (image_operands & SpvImageOperandsOffsetMask) {
2479         uint32_t offset_operand_index = operand_index + 1;
2480         if (image_operands & SpvImageOperandsBiasMask) offset_operand_index++;
2481         if (image_operands & SpvImageOperandsLodMask) offset_operand_index++;
2482         if (image_operands & SpvImageOperandsGradMask)
2483           offset_operand_index += 2;
2484         assert(((image_operands & SpvImageOperandsConstOffsetMask) == 0) &&
2485                "Offset and ConstOffset may not be used together");
2486         if (offset_operand_index < inst->NumOperands()) {
2487           if (constants[offset_operand_index]) {
2488             image_operands = image_operands | SpvImageOperandsConstOffsetMask;
2489             image_operands = image_operands & ~SpvImageOperandsOffsetMask;
2490             inst->SetInOperand(operand_index, {image_operands});
2491             return true;
2492           }
2493         }
2494       }
2495     }
2496 
2497     return false;
2498   };
2499 }
2500 
2501 }  // namespace
2502 
AddFoldingRules()2503 void FoldingRules::AddFoldingRules() {
2504   // Add all folding rules to the list for the opcodes to which they apply.
2505   // Note that the order in which rules are added to the list matters. If a rule
2506   // applies to the instruction, the rest of the rules will not be attempted.
2507   // Take that into consideration.
2508   rules_[SpvOpBitcast].push_back(BitCastScalarOrVector());
2509 
2510   rules_[SpvOpCompositeConstruct].push_back(CompositeExtractFeedingConstruct);
2511 
2512   rules_[SpvOpCompositeExtract].push_back(InsertFeedingExtract());
2513   rules_[SpvOpCompositeExtract].push_back(CompositeConstructFeedingExtract());
2514   rules_[SpvOpCompositeExtract].push_back(VectorShuffleFeedingExtract());
2515   rules_[SpvOpCompositeExtract].push_back(FMixFeedingExtract());
2516 
2517   rules_[SpvOpDot].push_back(DotProductDoingExtract());
2518 
2519   rules_[SpvOpEntryPoint].push_back(RemoveRedundantOperands());
2520 
2521   rules_[SpvOpFAdd].push_back(RedundantFAdd());
2522   rules_[SpvOpFAdd].push_back(MergeAddNegateArithmetic());
2523   rules_[SpvOpFAdd].push_back(MergeAddAddArithmetic());
2524   rules_[SpvOpFAdd].push_back(MergeAddSubArithmetic());
2525   rules_[SpvOpFAdd].push_back(MergeGenericAddSubArithmetic());
2526   rules_[SpvOpFAdd].push_back(FactorAddMuls());
2527 
2528   rules_[SpvOpFDiv].push_back(RedundantFDiv());
2529   rules_[SpvOpFDiv].push_back(ReciprocalFDiv());
2530   rules_[SpvOpFDiv].push_back(MergeDivDivArithmetic());
2531   rules_[SpvOpFDiv].push_back(MergeDivMulArithmetic());
2532   rules_[SpvOpFDiv].push_back(MergeDivNegateArithmetic());
2533 
2534   rules_[SpvOpFMul].push_back(RedundantFMul());
2535   rules_[SpvOpFMul].push_back(MergeMulMulArithmetic());
2536   rules_[SpvOpFMul].push_back(MergeMulDivArithmetic());
2537   rules_[SpvOpFMul].push_back(MergeMulNegateArithmetic());
2538 
2539   rules_[SpvOpFNegate].push_back(MergeNegateArithmetic());
2540   rules_[SpvOpFNegate].push_back(MergeNegateAddSubArithmetic());
2541   rules_[SpvOpFNegate].push_back(MergeNegateMulDivArithmetic());
2542 
2543   rules_[SpvOpFSub].push_back(RedundantFSub());
2544   rules_[SpvOpFSub].push_back(MergeSubNegateArithmetic());
2545   rules_[SpvOpFSub].push_back(MergeSubAddArithmetic());
2546   rules_[SpvOpFSub].push_back(MergeSubSubArithmetic());
2547 
2548   rules_[SpvOpIAdd].push_back(RedundantIAdd());
2549   rules_[SpvOpIAdd].push_back(MergeAddNegateArithmetic());
2550   rules_[SpvOpIAdd].push_back(MergeAddAddArithmetic());
2551   rules_[SpvOpIAdd].push_back(MergeAddSubArithmetic());
2552   rules_[SpvOpIAdd].push_back(MergeGenericAddSubArithmetic());
2553   rules_[SpvOpIAdd].push_back(FactorAddMuls());
2554 
2555   rules_[SpvOpIMul].push_back(IntMultipleBy1());
2556   rules_[SpvOpIMul].push_back(MergeMulMulArithmetic());
2557   rules_[SpvOpIMul].push_back(MergeMulNegateArithmetic());
2558 
2559   rules_[SpvOpISub].push_back(MergeSubNegateArithmetic());
2560   rules_[SpvOpISub].push_back(MergeSubAddArithmetic());
2561   rules_[SpvOpISub].push_back(MergeSubSubArithmetic());
2562 
2563   rules_[SpvOpPhi].push_back(RedundantPhi());
2564 
2565   rules_[SpvOpSDiv].push_back(MergeDivNegateArithmetic());
2566 
2567   rules_[SpvOpSNegate].push_back(MergeNegateArithmetic());
2568   rules_[SpvOpSNegate].push_back(MergeNegateMulDivArithmetic());
2569   rules_[SpvOpSNegate].push_back(MergeNegateAddSubArithmetic());
2570 
2571   rules_[SpvOpSelect].push_back(RedundantSelect());
2572 
2573   rules_[SpvOpStore].push_back(StoringUndef());
2574 
2575   rules_[SpvOpUDiv].push_back(MergeDivNegateArithmetic());
2576 
2577   rules_[SpvOpVectorShuffle].push_back(VectorShuffleFeedingShuffle());
2578 
2579   rules_[SpvOpImageSampleImplicitLod].push_back(UpdateImageOperands());
2580   rules_[SpvOpImageSampleExplicitLod].push_back(UpdateImageOperands());
2581   rules_[SpvOpImageSampleDrefImplicitLod].push_back(UpdateImageOperands());
2582   rules_[SpvOpImageSampleDrefExplicitLod].push_back(UpdateImageOperands());
2583   rules_[SpvOpImageSampleProjImplicitLod].push_back(UpdateImageOperands());
2584   rules_[SpvOpImageSampleProjExplicitLod].push_back(UpdateImageOperands());
2585   rules_[SpvOpImageSampleProjDrefImplicitLod].push_back(UpdateImageOperands());
2586   rules_[SpvOpImageSampleProjDrefExplicitLod].push_back(UpdateImageOperands());
2587   rules_[SpvOpImageFetch].push_back(UpdateImageOperands());
2588   rules_[SpvOpImageGather].push_back(UpdateImageOperands());
2589   rules_[SpvOpImageDrefGather].push_back(UpdateImageOperands());
2590   rules_[SpvOpImageRead].push_back(UpdateImageOperands());
2591   rules_[SpvOpImageWrite].push_back(UpdateImageOperands());
2592   rules_[SpvOpImageSparseSampleImplicitLod].push_back(UpdateImageOperands());
2593   rules_[SpvOpImageSparseSampleExplicitLod].push_back(UpdateImageOperands());
2594   rules_[SpvOpImageSparseSampleDrefImplicitLod].push_back(
2595       UpdateImageOperands());
2596   rules_[SpvOpImageSparseSampleDrefExplicitLod].push_back(
2597       UpdateImageOperands());
2598   rules_[SpvOpImageSparseSampleProjImplicitLod].push_back(
2599       UpdateImageOperands());
2600   rules_[SpvOpImageSparseSampleProjExplicitLod].push_back(
2601       UpdateImageOperands());
2602   rules_[SpvOpImageSparseSampleProjDrefImplicitLod].push_back(
2603       UpdateImageOperands());
2604   rules_[SpvOpImageSparseSampleProjDrefExplicitLod].push_back(
2605       UpdateImageOperands());
2606   rules_[SpvOpImageSparseFetch].push_back(UpdateImageOperands());
2607   rules_[SpvOpImageSparseGather].push_back(UpdateImageOperands());
2608   rules_[SpvOpImageSparseDrefGather].push_back(UpdateImageOperands());
2609   rules_[SpvOpImageSparseRead].push_back(UpdateImageOperands());
2610 
2611   FeatureManager* feature_manager = context_->get_feature_mgr();
2612   // Add rules for GLSLstd450
2613   uint32_t ext_inst_glslstd450_id =
2614       feature_manager->GetExtInstImportId_GLSLstd450();
2615   if (ext_inst_glslstd450_id != 0) {
2616     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMix}].push_back(
2617         RedundantFMix());
2618   }
2619 }
2620 }  // namespace opt
2621 }  // namespace spvtools
2622