• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2018 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
#include "source/opt/folding_rules.h"

#include <climits>
#include <cmath>
#include <limits>
#include <memory>
#include <type_traits>
#include <utility>

#include "ir_builder.h"
#include "source/latest_version_glsl_std_450_header.h"
#include "source/opt/ir_context.h"
25 
26 namespace spvtools {
27 namespace opt {
28 namespace {
29 
// In-operand indices (operand positions not counting the result type and
// result id) used by the folding rules in this file.
constexpr uint32_t kExtractCompositeIdInIdx = 0;
constexpr uint32_t kInsertObjectIdInIdx = 0;
constexpr uint32_t kInsertCompositeIdInIdx = 1;
constexpr uint32_t kExtInstSetIdInIdx = 0;
constexpr uint32_t kExtInstInstructionInIdx = 1;
// FMix arguments, counted as in-operands of the OpExtInst that calls it.
constexpr uint32_t kFMixXIdInIdx = 2;
constexpr uint32_t kFMixYIdInIdx = 3;
constexpr uint32_t kFMixAIdInIdx = 4;
constexpr uint32_t kStoreObjectInIdx = 1;
39 
40 // Some image instructions may contain an "image operands" argument.
41 // Returns the operand index for the "image operands".
42 // Returns -1 if the instruction does not have image operands.
ImageOperandsMaskInOperandIndex(Instruction * inst)43 int32_t ImageOperandsMaskInOperandIndex(Instruction* inst) {
44   const auto opcode = inst->opcode();
45   switch (opcode) {
46     case SpvOpImageSampleImplicitLod:
47     case SpvOpImageSampleExplicitLod:
48     case SpvOpImageSampleProjImplicitLod:
49     case SpvOpImageSampleProjExplicitLod:
50     case SpvOpImageFetch:
51     case SpvOpImageRead:
52     case SpvOpImageSparseSampleImplicitLod:
53     case SpvOpImageSparseSampleExplicitLod:
54     case SpvOpImageSparseSampleProjImplicitLod:
55     case SpvOpImageSparseSampleProjExplicitLod:
56     case SpvOpImageSparseFetch:
57     case SpvOpImageSparseRead:
58       return inst->NumOperands() > 4 ? 2 : -1;
59     case SpvOpImageSampleDrefImplicitLod:
60     case SpvOpImageSampleDrefExplicitLod:
61     case SpvOpImageSampleProjDrefImplicitLod:
62     case SpvOpImageSampleProjDrefExplicitLod:
63     case SpvOpImageGather:
64     case SpvOpImageDrefGather:
65     case SpvOpImageSparseSampleDrefImplicitLod:
66     case SpvOpImageSparseSampleDrefExplicitLod:
67     case SpvOpImageSparseSampleProjDrefImplicitLod:
68     case SpvOpImageSparseSampleProjDrefExplicitLod:
69     case SpvOpImageSparseGather:
70     case SpvOpImageSparseDrefGather:
71       return inst->NumOperands() > 5 ? 3 : -1;
72     case SpvOpImageWrite:
73       return inst->NumOperands() > 3 ? 3 : -1;
74     default:
75       return -1;
76   }
77 }
78 
79 // Returns the element width of |type|.
ElementWidth(const analysis::Type * type)80 uint32_t ElementWidth(const analysis::Type* type) {
81   if (const analysis::Vector* vec_type = type->AsVector()) {
82     return ElementWidth(vec_type->element_type());
83   } else if (const analysis::Float* float_type = type->AsFloat()) {
84     return float_type->width();
85   } else {
86     assert(type->AsInteger());
87     return type->AsInteger()->width();
88   }
89 }
90 
91 // Returns true if |type| is Float or a vector of Float.
HasFloatingPoint(const analysis::Type * type)92 bool HasFloatingPoint(const analysis::Type* type) {
93   if (type->AsFloat()) {
94     return true;
95   } else if (const analysis::Vector* vec_type = type->AsVector()) {
96     return vec_type->element_type()->AsFloat() != nullptr;
97   }
98 
99   return false;
100 }
101 
// Returns true if |val| is a normal number or zero, and false if it is NaN,
// infinite or subnormal. Folding rules use this to reject constant results
// that should not be materialized.
//
// |T| must be a floating-point type: for an integral type std::fpclassify
// would promote to double and every value would be reported as valid, which
// would silently defeat the check.
template <typename T>
bool IsValidResult(T val) {
  static_assert(std::is_floating_point<T>::value,
                "IsValidResult requires a floating-point type");
  switch (std::fpclassify(val)) {
    case FP_NAN:
    case FP_INFINITE:
    case FP_SUBNORMAL:
      return false;
    default:
      return true;
  }
}
115 
ConstInput(const std::vector<const analysis::Constant * > & constants)116 const analysis::Constant* ConstInput(
117     const std::vector<const analysis::Constant*>& constants) {
118   return constants[0] ? constants[0] : constants[1];
119 }
120 
NonConstInput(IRContext * context,const analysis::Constant * c,Instruction * inst)121 Instruction* NonConstInput(IRContext* context, const analysis::Constant* c,
122                            Instruction* inst) {
123   uint32_t in_op = c ? 1u : 0u;
124   return context->get_def_use_mgr()->GetDef(
125       inst->GetSingleWordInOperand(in_op));
126 }
127 
// Splits |val| into two 32-bit words, least-significant word first.
std::vector<uint32_t> ExtractInts(uint64_t val) {
  const uint32_t low_word = static_cast<uint32_t>(val & 0xFFFFFFFFu);
  const uint32_t high_word = static_cast<uint32_t>(val >> 32);
  return {low_word, high_word};
}
134 
GetWordsFromScalarIntConstant(const analysis::IntConstant * c)135 std::vector<uint32_t> GetWordsFromScalarIntConstant(
136     const analysis::IntConstant* c) {
137   assert(c != nullptr);
138   uint32_t width = c->type()->AsInteger()->width();
139   assert(width == 32 || width == 64);
140   if (width == 64) {
141     uint64_t uval = static_cast<uint64_t>(c->GetU64());
142     return ExtractInts(uval);
143   }
144   return {c->GetU32()};
145 }
146 
GetWordsFromScalarFloatConstant(const analysis::FloatConstant * c)147 std::vector<uint32_t> GetWordsFromScalarFloatConstant(
148     const analysis::FloatConstant* c) {
149   assert(c != nullptr);
150   uint32_t width = c->type()->AsFloat()->width();
151   assert(width == 32 || width == 64);
152   if (width == 64) {
153     utils::FloatProxy<double> result(c->GetDouble());
154     return result.GetWords();
155   }
156   utils::FloatProxy<float> result(c->GetFloat());
157   return result.GetWords();
158 }
159 
GetWordsFromNumericScalarOrVectorConstant(analysis::ConstantManager * const_mgr,const analysis::Constant * c)160 std::vector<uint32_t> GetWordsFromNumericScalarOrVectorConstant(
161     analysis::ConstantManager* const_mgr, const analysis::Constant* c) {
162   if (const auto* float_constant = c->AsFloatConstant()) {
163     return GetWordsFromScalarFloatConstant(float_constant);
164   } else if (const auto* int_constant = c->AsIntConstant()) {
165     return GetWordsFromScalarIntConstant(int_constant);
166   } else if (const auto* vec_constant = c->AsVectorConstant()) {
167     std::vector<uint32_t> words;
168     for (const auto* comp : vec_constant->GetComponents()) {
169       auto comp_in_words =
170           GetWordsFromNumericScalarOrVectorConstant(const_mgr, comp);
171       words.insert(words.end(), comp_in_words.begin(), comp_in_words.end());
172     }
173     return words;
174   }
175   return {};
176 }
177 
ConvertWordsToNumericScalarOrVectorConstant(analysis::ConstantManager * const_mgr,const std::vector<uint32_t> & words,const analysis::Type * type)178 const analysis::Constant* ConvertWordsToNumericScalarOrVectorConstant(
179     analysis::ConstantManager* const_mgr, const std::vector<uint32_t>& words,
180     const analysis::Type* type) {
181   if (type->AsInteger() || type->AsFloat())
182     return const_mgr->GetConstant(type, words);
183   if (const auto* vec_type = type->AsVector())
184     return const_mgr->GetNumericVectorConstantWithWords(vec_type, words);
185   return nullptr;
186 }
187 
188 // Returns the negation of |c|. |c| must be a 32 or 64 bit floating point
189 // constant.
NegateFloatingPointConstant(analysis::ConstantManager * const_mgr,const analysis::Constant * c)190 uint32_t NegateFloatingPointConstant(analysis::ConstantManager* const_mgr,
191                                      const analysis::Constant* c) {
192   assert(c);
193   assert(c->type()->AsFloat());
194   uint32_t width = c->type()->AsFloat()->width();
195   assert(width == 32 || width == 64);
196   std::vector<uint32_t> words;
197   if (width == 64) {
198     utils::FloatProxy<double> result(c->GetDouble() * -1.0);
199     words = result.GetWords();
200   } else {
201     utils::FloatProxy<float> result(c->GetFloat() * -1.0f);
202     words = result.GetWords();
203   }
204 
205   const analysis::Constant* negated_const =
206       const_mgr->GetConstant(c->type(), std::move(words));
207   return const_mgr->GetDefiningInstruction(negated_const)->result_id();
208 }
209 
210 // Negates the integer constant |c|. Returns the id of the defining instruction.
NegateIntegerConstant(analysis::ConstantManager * const_mgr,const analysis::Constant * c)211 uint32_t NegateIntegerConstant(analysis::ConstantManager* const_mgr,
212                                const analysis::Constant* c) {
213   assert(c);
214   assert(c->type()->AsInteger());
215   uint32_t width = c->type()->AsInteger()->width();
216   assert(width == 32 || width == 64);
217   std::vector<uint32_t> words;
218   if (width == 64) {
219     uint64_t uval = static_cast<uint64_t>(0 - c->GetU64());
220     words = ExtractInts(uval);
221   } else {
222     words.push_back(static_cast<uint32_t>(0 - c->GetU32()));
223   }
224 
225   const analysis::Constant* negated_const =
226       const_mgr->GetConstant(c->type(), std::move(words));
227   return const_mgr->GetDefiningInstruction(negated_const)->result_id();
228 }
229 
230 // Negates the vector constant |c|. Returns the id of the defining instruction.
NegateVectorConstant(analysis::ConstantManager * const_mgr,const analysis::Constant * c)231 uint32_t NegateVectorConstant(analysis::ConstantManager* const_mgr,
232                               const analysis::Constant* c) {
233   assert(const_mgr && c);
234   assert(c->type()->AsVector());
235   if (c->AsNullConstant()) {
236     // 0.0 vs -0.0 shouldn't matter.
237     return const_mgr->GetDefiningInstruction(c)->result_id();
238   } else {
239     const analysis::Type* component_type =
240         c->AsVectorConstant()->component_type();
241     std::vector<uint32_t> words;
242     for (auto& comp : c->AsVectorConstant()->GetComponents()) {
243       if (component_type->AsFloat()) {
244         words.push_back(NegateFloatingPointConstant(const_mgr, comp));
245       } else {
246         assert(component_type->AsInteger());
247         words.push_back(NegateIntegerConstant(const_mgr, comp));
248       }
249     }
250 
251     const analysis::Constant* negated_const =
252         const_mgr->GetConstant(c->type(), std::move(words));
253     return const_mgr->GetDefiningInstruction(negated_const)->result_id();
254   }
255 }
256 
257 // Negates |c|. Returns the id of the defining instruction.
NegateConstant(analysis::ConstantManager * const_mgr,const analysis::Constant * c)258 uint32_t NegateConstant(analysis::ConstantManager* const_mgr,
259                         const analysis::Constant* c) {
260   if (c->type()->AsVector()) {
261     return NegateVectorConstant(const_mgr, c);
262   } else if (c->type()->AsFloat()) {
263     return NegateFloatingPointConstant(const_mgr, c);
264   } else {
265     assert(c->type()->AsInteger());
266     return NegateIntegerConstant(const_mgr, c);
267   }
268 }
269 
270 // Takes the reciprocal of |c|. |c|'s type must be Float or a vector of Float.
271 // Returns 0 if the reciprocal is NaN, infinite or subnormal.
Reciprocal(analysis::ConstantManager * const_mgr,const analysis::Constant * c)272 uint32_t Reciprocal(analysis::ConstantManager* const_mgr,
273                     const analysis::Constant* c) {
274   assert(const_mgr && c);
275   assert(c->type()->AsFloat());
276 
277   uint32_t width = c->type()->AsFloat()->width();
278   assert(width == 32 || width == 64);
279   std::vector<uint32_t> words;
280   if (width == 64) {
281     spvtools::utils::FloatProxy<double> result(1.0 / c->GetDouble());
282     if (!IsValidResult(result.getAsFloat())) return 0;
283     words = result.GetWords();
284   } else {
285     spvtools::utils::FloatProxy<float> result(1.0f / c->GetFloat());
286     if (!IsValidResult(result.getAsFloat())) return 0;
287     words = result.GetWords();
288   }
289 
290   const analysis::Constant* negated_const =
291       const_mgr->GetConstant(c->type(), std::move(words));
292   return const_mgr->GetDefiningInstruction(negated_const)->result_id();
293 }
294 
295 // Replaces fdiv where second operand is constant with fmul.
ReciprocalFDiv()296 FoldingRule ReciprocalFDiv() {
297   return [](IRContext* context, Instruction* inst,
298             const std::vector<const analysis::Constant*>& constants) {
299     assert(inst->opcode() == SpvOpFDiv);
300     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
301     const analysis::Type* type =
302         context->get_type_mgr()->GetType(inst->type_id());
303     if (!inst->IsFloatingPointFoldingAllowed()) return false;
304 
305     uint32_t width = ElementWidth(type);
306     if (width != 32 && width != 64) return false;
307 
308     if (constants[1] != nullptr) {
309       uint32_t id = 0;
310       if (const analysis::VectorConstant* vector_const =
311               constants[1]->AsVectorConstant()) {
312         std::vector<uint32_t> neg_ids;
313         for (auto& comp : vector_const->GetComponents()) {
314           id = Reciprocal(const_mgr, comp);
315           if (id == 0) return false;
316           neg_ids.push_back(id);
317         }
318         const analysis::Constant* negated_const =
319             const_mgr->GetConstant(constants[1]->type(), std::move(neg_ids));
320         id = const_mgr->GetDefiningInstruction(negated_const)->result_id();
321       } else if (constants[1]->AsFloatConstant()) {
322         id = Reciprocal(const_mgr, constants[1]);
323         if (id == 0) return false;
324       } else {
325         // Don't fold a null constant.
326         return false;
327       }
328       inst->SetOpcode(SpvOpFMul);
329       inst->SetInOperands(
330           {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0u)}},
331            {SPV_OPERAND_TYPE_ID, {id}}});
332       return true;
333     }
334 
335     return false;
336   };
337 }
338 
339 // Elides consecutive negate instructions.
MergeNegateArithmetic()340 FoldingRule MergeNegateArithmetic() {
341   return [](IRContext* context, Instruction* inst,
342             const std::vector<const analysis::Constant*>& constants) {
343     assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate);
344     (void)constants;
345     const analysis::Type* type =
346         context->get_type_mgr()->GetType(inst->type_id());
347     if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
348       return false;
349 
350     Instruction* op_inst =
351         context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
352     if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
353       return false;
354 
355     if (op_inst->opcode() == inst->opcode()) {
356       // Elide negates.
357       inst->SetOpcode(SpvOpCopyObject);
358       inst->SetInOperands(
359           {{SPV_OPERAND_TYPE_ID, {op_inst->GetSingleWordInOperand(0u)}}});
360       return true;
361     }
362 
363     return false;
364   };
365 }
366 
367 // Merges negate into a mul or div operation if that operation contains a
368 // constant operand.
369 // Cases:
370 // -(x * 2) = x * -2
371 // -(2 * x) = x * -2
372 // -(x / 2) = x / -2
373 // -(2 / x) = -2 / x
MergeNegateMulDivArithmetic()374 FoldingRule MergeNegateMulDivArithmetic() {
375   return [](IRContext* context, Instruction* inst,
376             const std::vector<const analysis::Constant*>& constants) {
377     assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate);
378     (void)constants;
379     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
380     const analysis::Type* type =
381         context->get_type_mgr()->GetType(inst->type_id());
382     if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
383       return false;
384 
385     Instruction* op_inst =
386         context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
387     if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
388       return false;
389 
390     uint32_t width = ElementWidth(type);
391     if (width != 32 && width != 64) return false;
392 
393     SpvOp opcode = op_inst->opcode();
394     if (opcode == SpvOpFMul || opcode == SpvOpFDiv || opcode == SpvOpIMul ||
395         opcode == SpvOpSDiv || opcode == SpvOpUDiv) {
396       std::vector<const analysis::Constant*> op_constants =
397           const_mgr->GetOperandConstants(op_inst);
398       // Merge negate into mul or div if one operand is constant.
399       if (op_constants[0] || op_constants[1]) {
400         bool zero_is_variable = op_constants[0] == nullptr;
401         const analysis::Constant* c = ConstInput(op_constants);
402         uint32_t neg_id = NegateConstant(const_mgr, c);
403         uint32_t non_const_id = zero_is_variable
404                                     ? op_inst->GetSingleWordInOperand(0u)
405                                     : op_inst->GetSingleWordInOperand(1u);
406         // Change this instruction to a mul/div.
407         inst->SetOpcode(op_inst->opcode());
408         if (opcode == SpvOpFDiv || opcode == SpvOpUDiv || opcode == SpvOpSDiv) {
409           uint32_t op0 = zero_is_variable ? non_const_id : neg_id;
410           uint32_t op1 = zero_is_variable ? neg_id : non_const_id;
411           inst->SetInOperands(
412               {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}});
413         } else {
414           inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
415                                {SPV_OPERAND_TYPE_ID, {neg_id}}});
416         }
417         return true;
418       }
419     }
420 
421     return false;
422   };
423 }
424 
425 // Merges negate into a add or sub operation if that operation contains a
426 // constant operand.
427 // Cases:
428 // -(x + 2) = -2 - x
429 // -(2 + x) = -2 - x
430 // -(x - 2) = 2 - x
431 // -(2 - x) = x - 2
MergeNegateAddSubArithmetic()432 FoldingRule MergeNegateAddSubArithmetic() {
433   return [](IRContext* context, Instruction* inst,
434             const std::vector<const analysis::Constant*>& constants) {
435     assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate);
436     (void)constants;
437     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
438     const analysis::Type* type =
439         context->get_type_mgr()->GetType(inst->type_id());
440     if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
441       return false;
442 
443     Instruction* op_inst =
444         context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
445     if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
446       return false;
447 
448     uint32_t width = ElementWidth(type);
449     if (width != 32 && width != 64) return false;
450 
451     if (op_inst->opcode() == SpvOpFAdd || op_inst->opcode() == SpvOpFSub ||
452         op_inst->opcode() == SpvOpIAdd || op_inst->opcode() == SpvOpISub) {
453       std::vector<const analysis::Constant*> op_constants =
454           const_mgr->GetOperandConstants(op_inst);
455       if (op_constants[0] || op_constants[1]) {
456         bool zero_is_variable = op_constants[0] == nullptr;
457         bool is_add = (op_inst->opcode() == SpvOpFAdd) ||
458                       (op_inst->opcode() == SpvOpIAdd);
459         bool swap_operands = !is_add || zero_is_variable;
460         bool negate_const = is_add;
461         const analysis::Constant* c = ConstInput(op_constants);
462         uint32_t const_id = 0;
463         if (negate_const) {
464           const_id = NegateConstant(const_mgr, c);
465         } else {
466           const_id = zero_is_variable ? op_inst->GetSingleWordInOperand(1u)
467                                       : op_inst->GetSingleWordInOperand(0u);
468         }
469 
470         // Swap operands if necessary and make the instruction a subtraction.
471         uint32_t op0 =
472             zero_is_variable ? op_inst->GetSingleWordInOperand(0u) : const_id;
473         uint32_t op1 =
474             zero_is_variable ? const_id : op_inst->GetSingleWordInOperand(1u);
475         if (swap_operands) std::swap(op0, op1);
476         inst->SetOpcode(HasFloatingPoint(type) ? SpvOpFSub : SpvOpISub);
477         inst->SetInOperands(
478             {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}});
479         return true;
480       }
481     }
482 
483     return false;
484   };
485 }
486 
487 // Returns true if |c| has a zero element.
HasZero(const analysis::Constant * c)488 bool HasZero(const analysis::Constant* c) {
489   if (c->AsNullConstant()) {
490     return true;
491   }
492   if (const analysis::VectorConstant* vec_const = c->AsVectorConstant()) {
493     for (auto& comp : vec_const->GetComponents())
494       if (HasZero(comp)) return true;
495   } else {
496     assert(c->AsScalarConstant());
497     return c->AsScalarConstant()->IsZero();
498   }
499 
500   return false;
501 }
502 
503 // Performs |input1| |opcode| |input2| and returns the merged constant result
504 // id. Returns 0 if the result is not a valid value. The input types must be
505 // Float.
PerformFloatingPointOperation(analysis::ConstantManager * const_mgr,SpvOp opcode,const analysis::Constant * input1,const analysis::Constant * input2)506 uint32_t PerformFloatingPointOperation(analysis::ConstantManager* const_mgr,
507                                        SpvOp opcode,
508                                        const analysis::Constant* input1,
509                                        const analysis::Constant* input2) {
510   const analysis::Type* type = input1->type();
511   assert(type->AsFloat());
512   uint32_t width = type->AsFloat()->width();
513   assert(width == 32 || width == 64);
514   std::vector<uint32_t> words;
515 #define FOLD_OP(op)                                                          \
516   if (width == 64) {                                                         \
517     utils::FloatProxy<double> val =                                          \
518         input1->GetDouble() op input2->GetDouble();                          \
519     double dval = val.getAsFloat();                                          \
520     if (!IsValidResult(dval)) return 0;                                      \
521     words = val.GetWords();                                                  \
522   } else {                                                                   \
523     utils::FloatProxy<float> val = input1->GetFloat() op input2->GetFloat(); \
524     float fval = val.getAsFloat();                                           \
525     if (!IsValidResult(fval)) return 0;                                      \
526     words = val.GetWords();                                                  \
527   }                                                                          \
528   static_assert(true, "require extra semicolon")
529   switch (opcode) {
530     case SpvOpFMul:
531       FOLD_OP(*);
532       break;
533     case SpvOpFDiv:
534       if (HasZero(input2)) return 0;
535       FOLD_OP(/);
536       break;
537     case SpvOpFAdd:
538       FOLD_OP(+);
539       break;
540     case SpvOpFSub:
541       FOLD_OP(-);
542       break;
543     default:
544       assert(false && "Unexpected operation");
545       break;
546   }
547 #undef FOLD_OP
548   const analysis::Constant* merged_const = const_mgr->GetConstant(type, words);
549   return const_mgr->GetDefiningInstruction(merged_const)->result_id();
550 }
551 
552 // Performs |input1| |opcode| |input2| and returns the merged constant result
553 // id. Returns 0 if the result is not a valid value. The input types must be
554 // Integers.
PerformIntegerOperation(analysis::ConstantManager * const_mgr,SpvOp opcode,const analysis::Constant * input1,const analysis::Constant * input2)555 uint32_t PerformIntegerOperation(analysis::ConstantManager* const_mgr,
556                                  SpvOp opcode, const analysis::Constant* input1,
557                                  const analysis::Constant* input2) {
558   assert(input1->type()->AsInteger());
559   const analysis::Integer* type = input1->type()->AsInteger();
560   uint32_t width = type->AsInteger()->width();
561   assert(width == 32 || width == 64);
562   std::vector<uint32_t> words;
563   // Regardless of the sign of the constant, folding is performed on an unsigned
564   // interpretation of the constant data. This avoids signed integer overflow
565   // while folding, and works because sign is irrelevant for the IAdd, ISub and
566   // IMul instructions.
567 #define FOLD_OP(op)                                      \
568   if (width == 64) {                                     \
569     uint64_t val = input1->GetU64() op input2->GetU64(); \
570     words = ExtractInts(val);                            \
571   } else {                                               \
572     uint32_t val = input1->GetU32() op input2->GetU32(); \
573     words.push_back(val);                                \
574   }                                                      \
575   static_assert(true, "require extra semicolon")
576   switch (opcode) {
577     case SpvOpIMul:
578       FOLD_OP(*);
579       break;
580     case SpvOpSDiv:
581     case SpvOpUDiv:
582       assert(false && "Should not merge integer division");
583       break;
584     case SpvOpIAdd:
585       FOLD_OP(+);
586       break;
587     case SpvOpISub:
588       FOLD_OP(-);
589       break;
590     default:
591       assert(false && "Unexpected operation");
592       break;
593   }
594 #undef FOLD_OP
595   const analysis::Constant* merged_const = const_mgr->GetConstant(type, words);
596   return const_mgr->GetDefiningInstruction(merged_const)->result_id();
597 }
598 
599 // Performs |input1| |opcode| |input2| and returns the merged constant result
600 // id. Returns 0 if the result is not a valid value. The input types must be
601 // Integers, Floats or Vectors of such.
PerformOperation(analysis::ConstantManager * const_mgr,SpvOp opcode,const analysis::Constant * input1,const analysis::Constant * input2)602 uint32_t PerformOperation(analysis::ConstantManager* const_mgr, SpvOp opcode,
603                           const analysis::Constant* input1,
604                           const analysis::Constant* input2) {
605   assert(input1 && input2);
606   const analysis::Type* type = input1->type();
607   std::vector<uint32_t> words;
608   if (const analysis::Vector* vector_type = type->AsVector()) {
609     const analysis::Type* ele_type = vector_type->element_type();
610     for (uint32_t i = 0; i != vector_type->element_count(); ++i) {
611       uint32_t id = 0;
612 
613       const analysis::Constant* input1_comp = nullptr;
614       if (const analysis::VectorConstant* input1_vector =
615               input1->AsVectorConstant()) {
616         input1_comp = input1_vector->GetComponents()[i];
617       } else {
618         assert(input1->AsNullConstant());
619         input1_comp = const_mgr->GetConstant(ele_type, {});
620       }
621 
622       const analysis::Constant* input2_comp = nullptr;
623       if (const analysis::VectorConstant* input2_vector =
624               input2->AsVectorConstant()) {
625         input2_comp = input2_vector->GetComponents()[i];
626       } else {
627         assert(input2->AsNullConstant());
628         input2_comp = const_mgr->GetConstant(ele_type, {});
629       }
630 
631       if (ele_type->AsFloat()) {
632         id = PerformFloatingPointOperation(const_mgr, opcode, input1_comp,
633                                            input2_comp);
634       } else {
635         assert(ele_type->AsInteger());
636         id = PerformIntegerOperation(const_mgr, opcode, input1_comp,
637                                      input2_comp);
638       }
639       if (id == 0) return 0;
640       words.push_back(id);
641     }
642     const analysis::Constant* merged_const =
643         const_mgr->GetConstant(type, words);
644     return const_mgr->GetDefiningInstruction(merged_const)->result_id();
645   } else if (type->AsFloat()) {
646     return PerformFloatingPointOperation(const_mgr, opcode, input1, input2);
647   } else {
648     assert(type->AsInteger());
649     return PerformIntegerOperation(const_mgr, opcode, input1, input2);
650   }
651 }
652 
653 // Merges consecutive multiplies where each contains one constant operand.
654 // Cases:
655 // 2 * (x * 2) = x * 4
656 // 2 * (2 * x) = x * 4
657 // (x * 2) * 2 = x * 4
658 // (2 * x) * 2 = x * 4
MergeMulMulArithmetic()659 FoldingRule MergeMulMulArithmetic() {
660   return [](IRContext* context, Instruction* inst,
661             const std::vector<const analysis::Constant*>& constants) {
662     assert(inst->opcode() == SpvOpFMul || inst->opcode() == SpvOpIMul);
663     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
664     const analysis::Type* type =
665         context->get_type_mgr()->GetType(inst->type_id());
666     if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
667       return false;
668 
669     uint32_t width = ElementWidth(type);
670     if (width != 32 && width != 64) return false;
671 
672     // Determine the constant input and the variable input in |inst|.
673     const analysis::Constant* const_input1 = ConstInput(constants);
674     if (!const_input1) return false;
675     Instruction* other_inst = NonConstInput(context, constants[0], inst);
676     if (HasFloatingPoint(type) && !other_inst->IsFloatingPointFoldingAllowed())
677       return false;
678 
679     if (other_inst->opcode() == inst->opcode()) {
680       std::vector<const analysis::Constant*> other_constants =
681           const_mgr->GetOperandConstants(other_inst);
682       const analysis::Constant* const_input2 = ConstInput(other_constants);
683       if (!const_input2) return false;
684 
685       bool other_first_is_variable = other_constants[0] == nullptr;
686       uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
687                                             const_input1, const_input2);
688       if (merged_id == 0) return false;
689 
690       uint32_t non_const_id = other_first_is_variable
691                                   ? other_inst->GetSingleWordInOperand(0u)
692                                   : other_inst->GetSingleWordInOperand(1u);
693       inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
694                            {SPV_OPERAND_TYPE_ID, {merged_id}}});
695       return true;
696     }
697 
698     return false;
699   };
700 }
701 
702 // Merges divides into subsequent multiplies if each instruction contains one
703 // constant operand. Does not support integer operations.
704 // Cases:
705 // 2 * (x / 2) = x * 1
706 // 2 * (2 / x) = 4 / x
707 // (x / 2) * 2 = x * 1
708 // (2 / x) * 2 = 4 / x
709 // (y / x) * x = y
710 // x * (y / x) = y
MergeMulDivArithmetic()711 FoldingRule MergeMulDivArithmetic() {
712   return [](IRContext* context, Instruction* inst,
713             const std::vector<const analysis::Constant*>& constants) {
714     assert(inst->opcode() == SpvOpFMul);
715     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
716     analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
717 
718     const analysis::Type* type =
719         context->get_type_mgr()->GetType(inst->type_id());
720     if (!inst->IsFloatingPointFoldingAllowed()) return false;
721 
722     uint32_t width = ElementWidth(type);
723     if (width != 32 && width != 64) return false;
724 
725     for (uint32_t i = 0; i < 2; i++) {
726       uint32_t op_id = inst->GetSingleWordInOperand(i);
727       Instruction* op_inst = def_use_mgr->GetDef(op_id);
728       if (op_inst->opcode() == SpvOpFDiv) {
729         if (op_inst->GetSingleWordInOperand(1) ==
730             inst->GetSingleWordInOperand(1 - i)) {
731           inst->SetOpcode(SpvOpCopyObject);
732           inst->SetInOperands(
733               {{SPV_OPERAND_TYPE_ID, {op_inst->GetSingleWordInOperand(0)}}});
734           return true;
735         }
736       }
737     }
738 
739     const analysis::Constant* const_input1 = ConstInput(constants);
740     if (!const_input1) return false;
741     Instruction* other_inst = NonConstInput(context, constants[0], inst);
742     if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
743 
744     if (other_inst->opcode() == SpvOpFDiv) {
745       std::vector<const analysis::Constant*> other_constants =
746           const_mgr->GetOperandConstants(other_inst);
747       const analysis::Constant* const_input2 = ConstInput(other_constants);
748       if (!const_input2 || HasZero(const_input2)) return false;
749 
750       bool other_first_is_variable = other_constants[0] == nullptr;
751       // If the variable value is the second operand of the divide, multiply
752       // the constants together. Otherwise divide the constants.
753       uint32_t merged_id = PerformOperation(
754           const_mgr,
755           other_first_is_variable ? other_inst->opcode() : inst->opcode(),
756           const_input1, const_input2);
757       if (merged_id == 0) return false;
758 
759       uint32_t non_const_id = other_first_is_variable
760                                   ? other_inst->GetSingleWordInOperand(0u)
761                                   : other_inst->GetSingleWordInOperand(1u);
762 
763       // If the variable value is on the second operand of the div, then this
764       // operation is a div. Otherwise it should be a multiply.
765       inst->SetOpcode(other_first_is_variable ? inst->opcode()
766                                               : other_inst->opcode());
767       if (other_first_is_variable) {
768         inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
769                              {SPV_OPERAND_TYPE_ID, {merged_id}}});
770       } else {
771         inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {merged_id}},
772                              {SPV_OPERAND_TYPE_ID, {non_const_id}}});
773       }
774       return true;
775     }
776 
777     return false;
778   };
779 }
780 
781 // Merges multiply of constant and negation.
782 // Cases:
783 // (-x) * 2 = x * -2
784 // 2 * (-x) = x * -2
MergeMulNegateArithmetic()785 FoldingRule MergeMulNegateArithmetic() {
786   return [](IRContext* context, Instruction* inst,
787             const std::vector<const analysis::Constant*>& constants) {
788     assert(inst->opcode() == SpvOpFMul || inst->opcode() == SpvOpIMul);
789     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
790     const analysis::Type* type =
791         context->get_type_mgr()->GetType(inst->type_id());
792     bool uses_float = HasFloatingPoint(type);
793     if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
794 
795     uint32_t width = ElementWidth(type);
796     if (width != 32 && width != 64) return false;
797 
798     const analysis::Constant* const_input1 = ConstInput(constants);
799     if (!const_input1) return false;
800     Instruction* other_inst = NonConstInput(context, constants[0], inst);
801     if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
802       return false;
803 
804     if (other_inst->opcode() == SpvOpFNegate ||
805         other_inst->opcode() == SpvOpSNegate) {
806       uint32_t neg_id = NegateConstant(const_mgr, const_input1);
807 
808       inst->SetInOperands(
809           {{SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}},
810            {SPV_OPERAND_TYPE_ID, {neg_id}}});
811       return true;
812     }
813 
814     return false;
815   };
816 }
817 
818 // Merges consecutive divides if each instruction contains one constant operand.
819 // Does not support integer division.
820 // Cases:
821 // 2 / (x / 2) = 4 / x
822 // 4 / (2 / x) = 2 * x
823 // (4 / x) / 2 = 2 / x
824 // (x / 2) / 2 = x / 4
MergeDivDivArithmetic()825 FoldingRule MergeDivDivArithmetic() {
826   return [](IRContext* context, Instruction* inst,
827             const std::vector<const analysis::Constant*>& constants) {
828     assert(inst->opcode() == SpvOpFDiv);
829     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
830     const analysis::Type* type =
831         context->get_type_mgr()->GetType(inst->type_id());
832     if (!inst->IsFloatingPointFoldingAllowed()) return false;
833 
834     uint32_t width = ElementWidth(type);
835     if (width != 32 && width != 64) return false;
836 
837     const analysis::Constant* const_input1 = ConstInput(constants);
838     if (!const_input1 || HasZero(const_input1)) return false;
839     Instruction* other_inst = NonConstInput(context, constants[0], inst);
840     if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
841 
842     bool first_is_variable = constants[0] == nullptr;
843     if (other_inst->opcode() == inst->opcode()) {
844       std::vector<const analysis::Constant*> other_constants =
845           const_mgr->GetOperandConstants(other_inst);
846       const analysis::Constant* const_input2 = ConstInput(other_constants);
847       if (!const_input2 || HasZero(const_input2)) return false;
848 
849       bool other_first_is_variable = other_constants[0] == nullptr;
850 
851       SpvOp merge_op = inst->opcode();
852       if (other_first_is_variable) {
853         // Constants magnify.
854         merge_op = SpvOpFMul;
855       }
856 
857       // This is an x / (*) case. Swap the inputs. Doesn't harm multiply
858       // because it is commutative.
859       if (first_is_variable) std::swap(const_input1, const_input2);
860       uint32_t merged_id =
861           PerformOperation(const_mgr, merge_op, const_input1, const_input2);
862       if (merged_id == 0) return false;
863 
864       uint32_t non_const_id = other_first_is_variable
865                                   ? other_inst->GetSingleWordInOperand(0u)
866                                   : other_inst->GetSingleWordInOperand(1u);
867 
868       SpvOp op = inst->opcode();
869       if (!first_is_variable && !other_first_is_variable) {
870         // Effectively div of 1/x, so change to multiply.
871         op = SpvOpFMul;
872       }
873 
874       uint32_t op1 = merged_id;
875       uint32_t op2 = non_const_id;
876       if (first_is_variable && other_first_is_variable) std::swap(op1, op2);
877       inst->SetOpcode(op);
878       inst->SetInOperands(
879           {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
880       return true;
881     }
882 
883     return false;
884   };
885 }
886 
887 // Fold multiplies succeeded by divides where each instruction contains a
888 // constant operand. Does not support integer divide.
889 // Cases:
890 // 4 / (x * 2) = 2 / x
891 // 4 / (2 * x) = 2 / x
892 // (x * 4) / 2 = x * 2
893 // (4 * x) / 2 = x * 2
894 // (x * y) / x = y
895 // (y * x) / x = y
MergeDivMulArithmetic()896 FoldingRule MergeDivMulArithmetic() {
897   return [](IRContext* context, Instruction* inst,
898             const std::vector<const analysis::Constant*>& constants) {
899     assert(inst->opcode() == SpvOpFDiv);
900     analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
901     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
902 
903     const analysis::Type* type =
904         context->get_type_mgr()->GetType(inst->type_id());
905     if (!inst->IsFloatingPointFoldingAllowed()) return false;
906 
907     uint32_t width = ElementWidth(type);
908     if (width != 32 && width != 64) return false;
909 
910     uint32_t op_id = inst->GetSingleWordInOperand(0);
911     Instruction* op_inst = def_use_mgr->GetDef(op_id);
912 
913     if (op_inst->opcode() == SpvOpFMul) {
914       for (uint32_t i = 0; i < 2; i++) {
915         if (op_inst->GetSingleWordInOperand(i) ==
916             inst->GetSingleWordInOperand(1)) {
917           inst->SetOpcode(SpvOpCopyObject);
918           inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
919                                 {op_inst->GetSingleWordInOperand(1 - i)}}});
920           return true;
921         }
922       }
923     }
924 
925     const analysis::Constant* const_input1 = ConstInput(constants);
926     if (!const_input1 || HasZero(const_input1)) return false;
927     Instruction* other_inst = NonConstInput(context, constants[0], inst);
928     if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
929 
930     bool first_is_variable = constants[0] == nullptr;
931     if (other_inst->opcode() == SpvOpFMul) {
932       std::vector<const analysis::Constant*> other_constants =
933           const_mgr->GetOperandConstants(other_inst);
934       const analysis::Constant* const_input2 = ConstInput(other_constants);
935       if (!const_input2) return false;
936 
937       bool other_first_is_variable = other_constants[0] == nullptr;
938 
939       // This is an x / (*) case. Swap the inputs.
940       if (first_is_variable) std::swap(const_input1, const_input2);
941       uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
942                                             const_input1, const_input2);
943       if (merged_id == 0) return false;
944 
945       uint32_t non_const_id = other_first_is_variable
946                                   ? other_inst->GetSingleWordInOperand(0u)
947                                   : other_inst->GetSingleWordInOperand(1u);
948 
949       uint32_t op1 = merged_id;
950       uint32_t op2 = non_const_id;
951       if (first_is_variable) std::swap(op1, op2);
952 
953       // Convert to multiply
954       if (first_is_variable) inst->SetOpcode(other_inst->opcode());
955       inst->SetInOperands(
956           {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
957       return true;
958     }
959 
960     return false;
961   };
962 }
963 
964 // Fold divides of a constant and a negation.
965 // Cases:
966 // (-x) / 2 = x / -2
967 // 2 / (-x) = -2 / x
MergeDivNegateArithmetic()968 FoldingRule MergeDivNegateArithmetic() {
969   return [](IRContext* context, Instruction* inst,
970             const std::vector<const analysis::Constant*>& constants) {
971     assert(inst->opcode() == SpvOpFDiv);
972     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
973     if (!inst->IsFloatingPointFoldingAllowed()) return false;
974 
975     const analysis::Constant* const_input1 = ConstInput(constants);
976     if (!const_input1) return false;
977     Instruction* other_inst = NonConstInput(context, constants[0], inst);
978     if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
979 
980     bool first_is_variable = constants[0] == nullptr;
981     if (other_inst->opcode() == SpvOpFNegate) {
982       uint32_t neg_id = NegateConstant(const_mgr, const_input1);
983 
984       if (first_is_variable) {
985         inst->SetInOperands(
986             {{SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}},
987              {SPV_OPERAND_TYPE_ID, {neg_id}}});
988       } else {
989         inst->SetInOperands(
990             {{SPV_OPERAND_TYPE_ID, {neg_id}},
991              {SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}});
992       }
993       return true;
994     }
995 
996     return false;
997   };
998 }
999 
1000 // Folds addition of a constant and a negation.
1001 // Cases:
1002 // (-x) + 2 = 2 - x
1003 // 2 + (-x) = 2 - x
MergeAddNegateArithmetic()1004 FoldingRule MergeAddNegateArithmetic() {
1005   return [](IRContext* context, Instruction* inst,
1006             const std::vector<const analysis::Constant*>& constants) {
1007     assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
1008     const analysis::Type* type =
1009         context->get_type_mgr()->GetType(inst->type_id());
1010     bool uses_float = HasFloatingPoint(type);
1011     if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1012 
1013     const analysis::Constant* const_input1 = ConstInput(constants);
1014     if (!const_input1) return false;
1015     Instruction* other_inst = NonConstInput(context, constants[0], inst);
1016     if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1017       return false;
1018 
1019     if (other_inst->opcode() == SpvOpSNegate ||
1020         other_inst->opcode() == SpvOpFNegate) {
1021       inst->SetOpcode(HasFloatingPoint(type) ? SpvOpFSub : SpvOpISub);
1022       uint32_t const_id = constants[0] ? inst->GetSingleWordInOperand(0u)
1023                                        : inst->GetSingleWordInOperand(1u);
1024       inst->SetInOperands(
1025           {{SPV_OPERAND_TYPE_ID, {const_id}},
1026            {SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}});
1027       return true;
1028     }
1029     return false;
1030   };
1031 }
1032 
1033 // Folds subtraction of a constant and a negation.
1034 // Cases:
1035 // (-x) - 2 = -2 - x
1036 // 2 - (-x) = x + 2
MergeSubNegateArithmetic()1037 FoldingRule MergeSubNegateArithmetic() {
1038   return [](IRContext* context, Instruction* inst,
1039             const std::vector<const analysis::Constant*>& constants) {
1040     assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
1041     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1042     const analysis::Type* type =
1043         context->get_type_mgr()->GetType(inst->type_id());
1044     bool uses_float = HasFloatingPoint(type);
1045     if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1046 
1047     uint32_t width = ElementWidth(type);
1048     if (width != 32 && width != 64) return false;
1049 
1050     const analysis::Constant* const_input1 = ConstInput(constants);
1051     if (!const_input1) return false;
1052     Instruction* other_inst = NonConstInput(context, constants[0], inst);
1053     if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1054       return false;
1055 
1056     if (other_inst->opcode() == SpvOpSNegate ||
1057         other_inst->opcode() == SpvOpFNegate) {
1058       uint32_t op1 = 0;
1059       uint32_t op2 = 0;
1060       SpvOp opcode = inst->opcode();
1061       if (constants[0] != nullptr) {
1062         op1 = other_inst->GetSingleWordInOperand(0u);
1063         op2 = inst->GetSingleWordInOperand(0u);
1064         opcode = HasFloatingPoint(type) ? SpvOpFAdd : SpvOpIAdd;
1065       } else {
1066         op1 = NegateConstant(const_mgr, const_input1);
1067         op2 = other_inst->GetSingleWordInOperand(0u);
1068       }
1069 
1070       inst->SetOpcode(opcode);
1071       inst->SetInOperands(
1072           {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
1073       return true;
1074     }
1075     return false;
1076   };
1077 }
1078 
1079 // Folds addition of an addition where each operation has a constant operand.
1080 // Cases:
1081 // (x + 2) + 2 = x + 4
1082 // (2 + x) + 2 = x + 4
1083 // 2 + (x + 2) = x + 4
1084 // 2 + (2 + x) = x + 4
MergeAddAddArithmetic()1085 FoldingRule MergeAddAddArithmetic() {
1086   return [](IRContext* context, Instruction* inst,
1087             const std::vector<const analysis::Constant*>& constants) {
1088     assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
1089     const analysis::Type* type =
1090         context->get_type_mgr()->GetType(inst->type_id());
1091     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1092     bool uses_float = HasFloatingPoint(type);
1093     if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1094 
1095     uint32_t width = ElementWidth(type);
1096     if (width != 32 && width != 64) return false;
1097 
1098     const analysis::Constant* const_input1 = ConstInput(constants);
1099     if (!const_input1) return false;
1100     Instruction* other_inst = NonConstInput(context, constants[0], inst);
1101     if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1102       return false;
1103 
1104     if (other_inst->opcode() == SpvOpFAdd ||
1105         other_inst->opcode() == SpvOpIAdd) {
1106       std::vector<const analysis::Constant*> other_constants =
1107           const_mgr->GetOperandConstants(other_inst);
1108       const analysis::Constant* const_input2 = ConstInput(other_constants);
1109       if (!const_input2) return false;
1110 
1111       Instruction* non_const_input =
1112           NonConstInput(context, other_constants[0], other_inst);
1113       uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
1114                                             const_input1, const_input2);
1115       if (merged_id == 0) return false;
1116 
1117       inst->SetInOperands(
1118           {{SPV_OPERAND_TYPE_ID, {non_const_input->result_id()}},
1119            {SPV_OPERAND_TYPE_ID, {merged_id}}});
1120       return true;
1121     }
1122     return false;
1123   };
1124 }
1125 
1126 // Folds addition of a subtraction where each operation has a constant operand.
1127 // Cases:
1128 // (x - 2) + 2 = x + 0
1129 // (2 - x) + 2 = 4 - x
1130 // 2 + (x - 2) = x + 0
1131 // 2 + (2 - x) = 4 - x
MergeAddSubArithmetic()1132 FoldingRule MergeAddSubArithmetic() {
1133   return [](IRContext* context, Instruction* inst,
1134             const std::vector<const analysis::Constant*>& constants) {
1135     assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
1136     const analysis::Type* type =
1137         context->get_type_mgr()->GetType(inst->type_id());
1138     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1139     bool uses_float = HasFloatingPoint(type);
1140     if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1141 
1142     uint32_t width = ElementWidth(type);
1143     if (width != 32 && width != 64) return false;
1144 
1145     const analysis::Constant* const_input1 = ConstInput(constants);
1146     if (!const_input1) return false;
1147     Instruction* other_inst = NonConstInput(context, constants[0], inst);
1148     if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1149       return false;
1150 
1151     if (other_inst->opcode() == SpvOpFSub ||
1152         other_inst->opcode() == SpvOpISub) {
1153       std::vector<const analysis::Constant*> other_constants =
1154           const_mgr->GetOperandConstants(other_inst);
1155       const analysis::Constant* const_input2 = ConstInput(other_constants);
1156       if (!const_input2) return false;
1157 
1158       bool first_is_variable = other_constants[0] == nullptr;
1159       SpvOp op = inst->opcode();
1160       uint32_t op1 = 0;
1161       uint32_t op2 = 0;
1162       if (first_is_variable) {
1163         // Subtract constants. Non-constant operand is first.
1164         op1 = other_inst->GetSingleWordInOperand(0u);
1165         op2 = PerformOperation(const_mgr, other_inst->opcode(), const_input1,
1166                                const_input2);
1167       } else {
1168         // Add constants. Constant operand is first. Change the opcode.
1169         op1 = PerformOperation(const_mgr, inst->opcode(), const_input1,
1170                                const_input2);
1171         op2 = other_inst->GetSingleWordInOperand(1u);
1172         op = other_inst->opcode();
1173       }
1174       if (op1 == 0 || op2 == 0) return false;
1175 
1176       inst->SetOpcode(op);
1177       inst->SetInOperands(
1178           {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
1179       return true;
1180     }
1181     return false;
1182   };
1183 }
1184 
1185 // Folds subtraction of an addition where each operand has a constant operand.
1186 // Cases:
1187 // (x + 2) - 2 = x + 0
1188 // (2 + x) - 2 = x + 0
1189 // 2 - (x + 2) = 0 - x
1190 // 2 - (2 + x) = 0 - x
MergeSubAddArithmetic()1191 FoldingRule MergeSubAddArithmetic() {
1192   return [](IRContext* context, Instruction* inst,
1193             const std::vector<const analysis::Constant*>& constants) {
1194     assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
1195     const analysis::Type* type =
1196         context->get_type_mgr()->GetType(inst->type_id());
1197     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1198     bool uses_float = HasFloatingPoint(type);
1199     if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1200 
1201     uint32_t width = ElementWidth(type);
1202     if (width != 32 && width != 64) return false;
1203 
1204     const analysis::Constant* const_input1 = ConstInput(constants);
1205     if (!const_input1) return false;
1206     Instruction* other_inst = NonConstInput(context, constants[0], inst);
1207     if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1208       return false;
1209 
1210     if (other_inst->opcode() == SpvOpFAdd ||
1211         other_inst->opcode() == SpvOpIAdd) {
1212       std::vector<const analysis::Constant*> other_constants =
1213           const_mgr->GetOperandConstants(other_inst);
1214       const analysis::Constant* const_input2 = ConstInput(other_constants);
1215       if (!const_input2) return false;
1216 
1217       Instruction* non_const_input =
1218           NonConstInput(context, other_constants[0], other_inst);
1219 
1220       // If the first operand of the sub is not a constant, swap the constants
1221       // so the subtraction has the correct operands.
1222       if (constants[0] == nullptr) std::swap(const_input1, const_input2);
1223       // Subtract the constants.
1224       uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
1225                                             const_input1, const_input2);
1226       SpvOp op = inst->opcode();
1227       uint32_t op1 = 0;
1228       uint32_t op2 = 0;
1229       if (constants[0] == nullptr) {
1230         // Non-constant operand is first. Change the opcode.
1231         op1 = non_const_input->result_id();
1232         op2 = merged_id;
1233         op = other_inst->opcode();
1234       } else {
1235         // Constant operand is first.
1236         op1 = merged_id;
1237         op2 = non_const_input->result_id();
1238       }
1239       if (op1 == 0 || op2 == 0) return false;
1240 
1241       inst->SetOpcode(op);
1242       inst->SetInOperands(
1243           {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
1244       return true;
1245     }
1246     return false;
1247   };
1248 }
1249 
1250 // Folds subtraction of a subtraction where each operand has a constant operand.
1251 // Cases:
1252 // (x - 2) - 2 = x - 4
1253 // (2 - x) - 2 = 0 - x
1254 // 2 - (x - 2) = 4 - x
1255 // 2 - (2 - x) = x + 0
MergeSubSubArithmetic()1256 FoldingRule MergeSubSubArithmetic() {
1257   return [](IRContext* context, Instruction* inst,
1258             const std::vector<const analysis::Constant*>& constants) {
1259     assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
1260     const analysis::Type* type =
1261         context->get_type_mgr()->GetType(inst->type_id());
1262     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1263     bool uses_float = HasFloatingPoint(type);
1264     if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1265 
1266     uint32_t width = ElementWidth(type);
1267     if (width != 32 && width != 64) return false;
1268 
1269     const analysis::Constant* const_input1 = ConstInput(constants);
1270     if (!const_input1) return false;
1271     Instruction* other_inst = NonConstInput(context, constants[0], inst);
1272     if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1273       return false;
1274 
1275     if (other_inst->opcode() == SpvOpFSub ||
1276         other_inst->opcode() == SpvOpISub) {
1277       std::vector<const analysis::Constant*> other_constants =
1278           const_mgr->GetOperandConstants(other_inst);
1279       const analysis::Constant* const_input2 = ConstInput(other_constants);
1280       if (!const_input2) return false;
1281 
1282       Instruction* non_const_input =
1283           NonConstInput(context, other_constants[0], other_inst);
1284 
1285       // Merge the constants.
1286       uint32_t merged_id = 0;
1287       SpvOp merge_op = inst->opcode();
1288       if (other_constants[0] == nullptr) {
1289         merge_op = uses_float ? SpvOpFAdd : SpvOpIAdd;
1290       } else if (constants[0] == nullptr) {
1291         std::swap(const_input1, const_input2);
1292       }
1293       merged_id =
1294           PerformOperation(const_mgr, merge_op, const_input1, const_input2);
1295       if (merged_id == 0) return false;
1296 
1297       SpvOp op = inst->opcode();
1298       if (constants[0] != nullptr && other_constants[0] != nullptr) {
1299         // Change the operation.
1300         op = uses_float ? SpvOpFAdd : SpvOpIAdd;
1301       }
1302 
1303       uint32_t op1 = 0;
1304       uint32_t op2 = 0;
1305       if ((constants[0] == nullptr) ^ (other_constants[0] == nullptr)) {
1306         op1 = merged_id;
1307         op2 = non_const_input->result_id();
1308       } else {
1309         op1 = non_const_input->result_id();
1310         op2 = merged_id;
1311       }
1312 
1313       inst->SetOpcode(op);
1314       inst->SetInOperands(
1315           {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
1316       return true;
1317     }
1318     return false;
1319   };
1320 }
1321 
1322 // Helper function for MergeGenericAddSubArithmetic. If |addend| and
1323 // subtrahend of |sub| is the same, merge to copy of minuend of |sub|.
MergeGenericAddendSub(uint32_t addend,uint32_t sub,Instruction * inst)1324 bool MergeGenericAddendSub(uint32_t addend, uint32_t sub, Instruction* inst) {
1325   IRContext* context = inst->context();
1326   analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1327   Instruction* sub_inst = def_use_mgr->GetDef(sub);
1328   if (sub_inst->opcode() != SpvOpFSub && sub_inst->opcode() != SpvOpISub)
1329     return false;
1330   if (sub_inst->opcode() == SpvOpFSub &&
1331       !sub_inst->IsFloatingPointFoldingAllowed())
1332     return false;
1333   if (addend != sub_inst->GetSingleWordInOperand(1)) return false;
1334   inst->SetOpcode(SpvOpCopyObject);
1335   inst->SetInOperands(
1336       {{SPV_OPERAND_TYPE_ID, {sub_inst->GetSingleWordInOperand(0)}}});
1337   context->UpdateDefUse(inst);
1338   return true;
1339 }
1340 
1341 // Folds addition of a subtraction where the subtrahend is equal to the
1342 // other addend. Return a copy of the minuend. Accepts generic (const and
1343 // non-const) operands.
1344 // Cases:
1345 // (a - b) + b = a
1346 // b + (a - b) = a
MergeGenericAddSubArithmetic()1347 FoldingRule MergeGenericAddSubArithmetic() {
1348   return [](IRContext* context, Instruction* inst,
1349             const std::vector<const analysis::Constant*>&) {
1350     assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
1351     const analysis::Type* type =
1352         context->get_type_mgr()->GetType(inst->type_id());
1353     bool uses_float = HasFloatingPoint(type);
1354     if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1355 
1356     uint32_t width = ElementWidth(type);
1357     if (width != 32 && width != 64) return false;
1358 
1359     uint32_t add_op0 = inst->GetSingleWordInOperand(0);
1360     uint32_t add_op1 = inst->GetSingleWordInOperand(1);
1361     if (MergeGenericAddendSub(add_op0, add_op1, inst)) return true;
1362     return MergeGenericAddendSub(add_op1, add_op0, inst);
1363   };
1364 }
1365 
1366 // Helper function for FactorAddMuls. If |factor0_0| is the same as |factor1_0|,
1367 // generate |factor0_0| * (|factor0_1| + |factor1_1|).
FactorAddMulsOpnds(uint32_t factor0_0,uint32_t factor0_1,uint32_t factor1_0,uint32_t factor1_1,Instruction * inst)1368 bool FactorAddMulsOpnds(uint32_t factor0_0, uint32_t factor0_1,
1369                         uint32_t factor1_0, uint32_t factor1_1,
1370                         Instruction* inst) {
1371   IRContext* context = inst->context();
1372   if (factor0_0 != factor1_0) return false;
1373   InstructionBuilder ir_builder(
1374       context, inst,
1375       IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
1376   Instruction* new_add_inst = ir_builder.AddBinaryOp(
1377       inst->type_id(), inst->opcode(), factor0_1, factor1_1);
1378   inst->SetOpcode(inst->opcode() == SpvOpFAdd ? SpvOpFMul : SpvOpIMul);
1379   inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {factor0_0}},
1380                        {SPV_OPERAND_TYPE_ID, {new_add_inst->result_id()}}});
1381   context->UpdateDefUse(inst);
1382   return true;
1383 }
1384 
1385 // Perform the following factoring identity, handling all operand order
1386 // combinations: (a * b) + (a * c) = a * (b + c)
FactorAddMuls()1387 FoldingRule FactorAddMuls() {
1388   return [](IRContext* context, Instruction* inst,
1389             const std::vector<const analysis::Constant*>&) {
1390     assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
1391     const analysis::Type* type =
1392         context->get_type_mgr()->GetType(inst->type_id());
1393     bool uses_float = HasFloatingPoint(type);
1394     if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1395 
1396     analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1397     uint32_t add_op0 = inst->GetSingleWordInOperand(0);
1398     Instruction* add_op0_inst = def_use_mgr->GetDef(add_op0);
1399     if (add_op0_inst->opcode() != SpvOpFMul &&
1400         add_op0_inst->opcode() != SpvOpIMul)
1401       return false;
1402     uint32_t add_op1 = inst->GetSingleWordInOperand(1);
1403     Instruction* add_op1_inst = def_use_mgr->GetDef(add_op1);
1404     if (add_op1_inst->opcode() != SpvOpFMul &&
1405         add_op1_inst->opcode() != SpvOpIMul)
1406       return false;
1407 
1408     // Only perform this optimization if both of the muls only have one use.
1409     // Otherwise this is a deoptimization in size and performance.
1410     if (def_use_mgr->NumUses(add_op0_inst) > 1) return false;
1411     if (def_use_mgr->NumUses(add_op1_inst) > 1) return false;
1412 
1413     if (add_op0_inst->opcode() == SpvOpFMul &&
1414         (!add_op0_inst->IsFloatingPointFoldingAllowed() ||
1415          !add_op1_inst->IsFloatingPointFoldingAllowed()))
1416       return false;
1417 
1418     for (int i = 0; i < 2; i++) {
1419       for (int j = 0; j < 2; j++) {
1420         // Check if operand i in add_op0_inst matches operand j in add_op1_inst.
1421         if (FactorAddMulsOpnds(add_op0_inst->GetSingleWordInOperand(i),
1422                                add_op0_inst->GetSingleWordInOperand(1 - i),
1423                                add_op1_inst->GetSingleWordInOperand(j),
1424                                add_op1_inst->GetSingleWordInOperand(1 - j),
1425                                inst))
1426           return true;
1427       }
1428     }
1429     return false;
1430   };
1431 }
1432 
IntMultipleBy1()1433 FoldingRule IntMultipleBy1() {
1434   return [](IRContext*, Instruction* inst,
1435             const std::vector<const analysis::Constant*>& constants) {
1436     assert(inst->opcode() == SpvOpIMul && "Wrong opcode.  Should be OpIMul.");
1437     for (uint32_t i = 0; i < 2; i++) {
1438       if (constants[i] == nullptr) {
1439         continue;
1440       }
1441       const analysis::IntConstant* int_constant = constants[i]->AsIntConstant();
1442       if (int_constant) {
1443         uint32_t width = ElementWidth(int_constant->type());
1444         if (width != 32 && width != 64) return false;
1445         bool is_one = (width == 32) ? int_constant->GetU32BitValue() == 1u
1446                                     : int_constant->GetU64BitValue() == 1ull;
1447         if (is_one) {
1448           inst->SetOpcode(SpvOpCopyObject);
1449           inst->SetInOperands(
1450               {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1 - i)}}});
1451           return true;
1452         }
1453       }
1454     }
1455     return false;
1456   };
1457 }
1458 
1459 // Returns the number of elements that the |index|th in operand in |inst|
1460 // contributes to the result of |inst|.  |inst| must be an
1461 // OpCompositeConstructInstruction.
GetNumOfElementsContributedByOperand(IRContext * context,const Instruction * inst,uint32_t index)1462 uint32_t GetNumOfElementsContributedByOperand(IRContext* context,
1463                                               const Instruction* inst,
1464                                               uint32_t index) {
1465   assert(inst->opcode() == SpvOpCompositeConstruct);
1466   analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1467   analysis::TypeManager* type_mgr = context->get_type_mgr();
1468 
1469   analysis::Vector* result_type =
1470       type_mgr->GetType(inst->type_id())->AsVector();
1471   if (result_type == nullptr) {
1472     // If the result of the OpCompositeConstruct is not a vector then every
1473     // operands corresponds to a single element in the result.
1474     return 1;
1475   }
1476 
1477   // If the result type is a vector then the operands are either scalars or
1478   // vectors. If it is a scalar, then it corresponds to a single element.  If it
1479   // is a vector, then each element in the vector will be an element in the
1480   // result.
1481   uint32_t id = inst->GetSingleWordInOperand(index);
1482   Instruction* def = def_use_mgr->GetDef(id);
1483   analysis::Vector* type = type_mgr->GetType(def->type_id())->AsVector();
1484   if (type == nullptr) {
1485     return 1;
1486   }
1487   return type->element_count();
1488 }
1489 
1490 // Returns the in-operands for an OpCompositeExtract instruction that are needed
1491 // to extract the |result_index|th element in the result of |inst| without using
1492 // the result of |inst|. Returns the empty vector if |result_index| is
1493 // out-of-bounds. |inst| must be an |OpCompositeConstruct| instruction.
GetExtractOperandsForElementOfCompositeConstruct(IRContext * context,const Instruction * inst,uint32_t result_index)1494 std::vector<Operand> GetExtractOperandsForElementOfCompositeConstruct(
1495     IRContext* context, const Instruction* inst, uint32_t result_index) {
1496   assert(inst->opcode() == SpvOpCompositeConstruct);
1497   analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1498   analysis::TypeManager* type_mgr = context->get_type_mgr();
1499 
1500   analysis::Type* result_type = type_mgr->GetType(inst->type_id());
1501   if (result_type->AsVector() == nullptr) {
1502     uint32_t id = inst->GetSingleWordInOperand(result_index);
1503     return {Operand(SPV_OPERAND_TYPE_ID, {id})};
1504   }
1505 
1506   // If the result type is a vector, then vector operands are concatenated.
1507   uint32_t total_element_count = 0;
1508   for (uint32_t idx = 0; idx < inst->NumInOperands(); ++idx) {
1509     uint32_t element_count =
1510         GetNumOfElementsContributedByOperand(context, inst, idx);
1511     total_element_count += element_count;
1512     if (result_index < total_element_count) {
1513       std::vector<Operand> operands;
1514       uint32_t id = inst->GetSingleWordInOperand(idx);
1515       Instruction* operand_def = def_use_mgr->GetDef(id);
1516       analysis::Type* operand_type = type_mgr->GetType(operand_def->type_id());
1517 
1518       operands.push_back({SPV_OPERAND_TYPE_ID, {id}});
1519       if (operand_type->AsVector()) {
1520         uint32_t start_index_of_id = total_element_count - element_count;
1521         uint32_t index_into_id = result_index - start_index_of_id;
1522         operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {index_into_id}});
1523       }
1524       return operands;
1525     }
1526   }
1527   return {};
1528 }
1529 
CompositeConstructFeedingExtract(IRContext * context,Instruction * inst,const std::vector<const analysis::Constant * > &)1530 bool CompositeConstructFeedingExtract(
1531     IRContext* context, Instruction* inst,
1532     const std::vector<const analysis::Constant*>&) {
1533   // If the input to an OpCompositeExtract is an OpCompositeConstruct,
1534   // then we can simply use the appropriate element in the construction.
1535   assert(inst->opcode() == SpvOpCompositeExtract &&
1536          "Wrong opcode.  Should be OpCompositeExtract.");
1537   analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1538 
1539   // If there are no index operands, then this rule cannot do anything.
1540   if (inst->NumInOperands() <= 1) {
1541     return false;
1542   }
1543 
1544   uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1545   Instruction* cinst = def_use_mgr->GetDef(cid);
1546 
1547   if (cinst->opcode() != SpvOpCompositeConstruct) {
1548     return false;
1549   }
1550 
1551   uint32_t index_into_result = inst->GetSingleWordInOperand(1);
1552   std::vector<Operand> operands =
1553       GetExtractOperandsForElementOfCompositeConstruct(context, cinst,
1554                                                        index_into_result);
1555 
1556   if (operands.empty()) {
1557     return false;
1558   }
1559 
1560   // Add the remaining indices for extraction.
1561   for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
1562     operands.push_back(
1563         {SPV_OPERAND_TYPE_LITERAL_INTEGER, {inst->GetSingleWordInOperand(i)}});
1564   }
1565 
1566   if (operands.size() == 1) {
1567     // If there were no extra indices, then we have the final object.  No need
1568     // to extract any more.
1569     inst->SetOpcode(SpvOpCopyObject);
1570   }
1571 
1572   inst->SetInOperands(std::move(operands));
1573   return true;
1574 }
1575 
1576 // If the OpCompositeConstruct is simply putting back together elements that
1577 // where extracted from the same source, we can simply reuse the source.
1578 //
1579 // This is a common code pattern because of the way that scalar replacement
1580 // works.
CompositeExtractFeedingConstruct(IRContext * context,Instruction * inst,const std::vector<const analysis::Constant * > &)1581 bool CompositeExtractFeedingConstruct(
1582     IRContext* context, Instruction* inst,
1583     const std::vector<const analysis::Constant*>&) {
1584   assert(inst->opcode() == SpvOpCompositeConstruct &&
1585          "Wrong opcode.  Should be OpCompositeConstruct.");
1586   analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1587   uint32_t original_id = 0;
1588 
1589   if (inst->NumInOperands() == 0) {
1590     // The struct being constructed has no members.
1591     return false;
1592   }
1593 
1594   // Check each element to make sure they are:
1595   // - extractions
1596   // - extracting the same position they are inserting
1597   // - all extract from the same id.
1598   for (uint32_t i = 0; i < inst->NumInOperands(); ++i) {
1599     const uint32_t element_id = inst->GetSingleWordInOperand(i);
1600     Instruction* element_inst = def_use_mgr->GetDef(element_id);
1601 
1602     if (element_inst->opcode() != SpvOpCompositeExtract) {
1603       return false;
1604     }
1605 
1606     if (element_inst->NumInOperands() != 2) {
1607       return false;
1608     }
1609 
1610     if (element_inst->GetSingleWordInOperand(1) != i) {
1611       return false;
1612     }
1613 
1614     if (i == 0) {
1615       original_id =
1616           element_inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1617     } else if (original_id !=
1618                element_inst->GetSingleWordInOperand(kExtractCompositeIdInIdx)) {
1619       return false;
1620     }
1621   }
1622 
1623   // The last check it to see that the object being extracted from is the
1624   // correct type.
1625   Instruction* original_inst = def_use_mgr->GetDef(original_id);
1626   if (original_inst->type_id() != inst->type_id()) {
1627     return false;
1628   }
1629 
1630   // Simplify by using the original object.
1631   inst->SetOpcode(SpvOpCopyObject);
1632   inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {original_id}}});
1633   return true;
1634 }
1635 
InsertFeedingExtract()1636 FoldingRule InsertFeedingExtract() {
1637   return [](IRContext* context, Instruction* inst,
1638             const std::vector<const analysis::Constant*>&) {
1639     assert(inst->opcode() == SpvOpCompositeExtract &&
1640            "Wrong opcode.  Should be OpCompositeExtract.");
1641     analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1642     uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1643     Instruction* cinst = def_use_mgr->GetDef(cid);
1644 
1645     if (cinst->opcode() != SpvOpCompositeInsert) {
1646       return false;
1647     }
1648 
1649     // Find the first position where the list of insert and extract indicies
1650     // differ, if at all.
1651     uint32_t i;
1652     for (i = 1; i < inst->NumInOperands(); ++i) {
1653       if (i + 1 >= cinst->NumInOperands()) {
1654         break;
1655       }
1656 
1657       if (inst->GetSingleWordInOperand(i) !=
1658           cinst->GetSingleWordInOperand(i + 1)) {
1659         break;
1660       }
1661     }
1662 
1663     // We are extracting the element that was inserted.
1664     if (i == inst->NumInOperands() && i + 1 == cinst->NumInOperands()) {
1665       inst->SetOpcode(SpvOpCopyObject);
1666       inst->SetInOperands(
1667           {{SPV_OPERAND_TYPE_ID,
1668             {cinst->GetSingleWordInOperand(kInsertObjectIdInIdx)}}});
1669       return true;
1670     }
1671 
1672     // Extracting the value that was inserted along with values for the base
1673     // composite.  Cannot do anything.
1674     if (i == inst->NumInOperands()) {
1675       return false;
1676     }
1677 
1678     // Extracting an element of the value that was inserted.  Extract from
1679     // that value directly.
1680     if (i + 1 == cinst->NumInOperands()) {
1681       std::vector<Operand> operands;
1682       operands.push_back(
1683           {SPV_OPERAND_TYPE_ID,
1684            {cinst->GetSingleWordInOperand(kInsertObjectIdInIdx)}});
1685       for (; i < inst->NumInOperands(); ++i) {
1686         operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
1687                             {inst->GetSingleWordInOperand(i)}});
1688       }
1689       inst->SetInOperands(std::move(operands));
1690       return true;
1691     }
1692 
1693     // Extracting a value that is disjoint from the element being inserted.
1694     // Rewrite the extract to use the composite input to the insert.
1695     std::vector<Operand> operands;
1696     operands.push_back(
1697         {SPV_OPERAND_TYPE_ID,
1698          {cinst->GetSingleWordInOperand(kInsertCompositeIdInIdx)}});
1699     for (i = 1; i < inst->NumInOperands(); ++i) {
1700       operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
1701                           {inst->GetSingleWordInOperand(i)}});
1702     }
1703     inst->SetInOperands(std::move(operands));
1704     return true;
1705   };
1706 }
1707 
1708 // When a VectorShuffle is feeding an Extract, we can extract from one of the
1709 // operands of the VectorShuffle.  We just need to adjust the index in the
1710 // extract instruction.
VectorShuffleFeedingExtract()1711 FoldingRule VectorShuffleFeedingExtract() {
1712   return [](IRContext* context, Instruction* inst,
1713             const std::vector<const analysis::Constant*>&) {
1714     assert(inst->opcode() == SpvOpCompositeExtract &&
1715            "Wrong opcode.  Should be OpCompositeExtract.");
1716     analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1717     analysis::TypeManager* type_mgr = context->get_type_mgr();
1718     uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1719     Instruction* cinst = def_use_mgr->GetDef(cid);
1720 
1721     if (cinst->opcode() != SpvOpVectorShuffle) {
1722       return false;
1723     }
1724 
1725     // Find the size of the first vector operand of the VectorShuffle
1726     Instruction* first_input =
1727         def_use_mgr->GetDef(cinst->GetSingleWordInOperand(0));
1728     analysis::Type* first_input_type =
1729         type_mgr->GetType(first_input->type_id());
1730     assert(first_input_type->AsVector() &&
1731            "Input to vector shuffle should be vectors.");
1732     uint32_t first_input_size = first_input_type->AsVector()->element_count();
1733 
1734     // Get index of the element the vector shuffle is placing in the position
1735     // being extracted.
1736     uint32_t new_index =
1737         cinst->GetSingleWordInOperand(2 + inst->GetSingleWordInOperand(1));
1738 
1739     // Extracting an undefined value so fold this extract into an undef.
1740     const uint32_t undef_literal_value = 0xffffffff;
1741     if (new_index == undef_literal_value) {
1742       inst->SetOpcode(SpvOpUndef);
1743       inst->SetInOperands({});
1744       return true;
1745     }
1746 
1747     // Get the id of the of the vector the elemtent comes from, and update the
1748     // index if needed.
1749     uint32_t new_vector = 0;
1750     if (new_index < first_input_size) {
1751       new_vector = cinst->GetSingleWordInOperand(0);
1752     } else {
1753       new_vector = cinst->GetSingleWordInOperand(1);
1754       new_index -= first_input_size;
1755     }
1756 
1757     // Update the extract instruction.
1758     inst->SetInOperand(kExtractCompositeIdInIdx, {new_vector});
1759     inst->SetInOperand(1, {new_index});
1760     return true;
1761   };
1762 }
1763 
1764 // When an FMix with is feeding an Extract that extracts an element whose
1765 // corresponding |a| in the FMix is 0 or 1, we can extract from one of the
1766 // operands of the FMix.
FMixFeedingExtract()1767 FoldingRule FMixFeedingExtract() {
1768   return [](IRContext* context, Instruction* inst,
1769             const std::vector<const analysis::Constant*>&) {
1770     assert(inst->opcode() == SpvOpCompositeExtract &&
1771            "Wrong opcode.  Should be OpCompositeExtract.");
1772     analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1773     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1774 
1775     uint32_t composite_id =
1776         inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1777     Instruction* composite_inst = def_use_mgr->GetDef(composite_id);
1778 
1779     if (composite_inst->opcode() != SpvOpExtInst) {
1780       return false;
1781     }
1782 
1783     uint32_t inst_set_id =
1784         context->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
1785 
1786     if (composite_inst->GetSingleWordInOperand(kExtInstSetIdInIdx) !=
1787             inst_set_id ||
1788         composite_inst->GetSingleWordInOperand(kExtInstInstructionInIdx) !=
1789             GLSLstd450FMix) {
1790       return false;
1791     }
1792 
1793     // Get the |a| for the FMix instruction.
1794     uint32_t a_id = composite_inst->GetSingleWordInOperand(kFMixAIdInIdx);
1795     std::unique_ptr<Instruction> a(inst->Clone(context));
1796     a->SetInOperand(kExtractCompositeIdInIdx, {a_id});
1797     context->get_instruction_folder().FoldInstruction(a.get());
1798 
1799     if (a->opcode() != SpvOpCopyObject) {
1800       return false;
1801     }
1802 
1803     const analysis::Constant* a_const =
1804         const_mgr->FindDeclaredConstant(a->GetSingleWordInOperand(0));
1805 
1806     if (!a_const) {
1807       return false;
1808     }
1809 
1810     bool use_x = false;
1811 
1812     assert(a_const->type()->AsFloat());
1813     double element_value = a_const->GetValueAsDouble();
1814     if (element_value == 0.0) {
1815       use_x = true;
1816     } else if (element_value == 1.0) {
1817       use_x = false;
1818     } else {
1819       return false;
1820     }
1821 
1822     // Get the id of the of the vector the element comes from.
1823     uint32_t new_vector = 0;
1824     if (use_x) {
1825       new_vector = composite_inst->GetSingleWordInOperand(kFMixXIdInIdx);
1826     } else {
1827       new_vector = composite_inst->GetSingleWordInOperand(kFMixYIdInIdx);
1828     }
1829 
1830     // Update the extract instruction.
1831     inst->SetInOperand(kExtractCompositeIdInIdx, {new_vector});
1832     return true;
1833   };
1834 }
1835 
RedundantPhi()1836 FoldingRule RedundantPhi() {
1837   // An OpPhi instruction where all values are the same or the result of the phi
1838   // itself, can be replaced by the value itself.
1839   return [](IRContext*, Instruction* inst,
1840             const std::vector<const analysis::Constant*>&) {
1841     assert(inst->opcode() == SpvOpPhi && "Wrong opcode.  Should be OpPhi.");
1842 
1843     uint32_t incoming_value = 0;
1844 
1845     for (uint32_t i = 0; i < inst->NumInOperands(); i += 2) {
1846       uint32_t op_id = inst->GetSingleWordInOperand(i);
1847       if (op_id == inst->result_id()) {
1848         continue;
1849       }
1850 
1851       if (incoming_value == 0) {
1852         incoming_value = op_id;
1853       } else if (op_id != incoming_value) {
1854         // Found two possible value.  Can't simplify.
1855         return false;
1856       }
1857     }
1858 
1859     if (incoming_value == 0) {
1860       // Code looks invalid.  Don't do anything.
1861       return false;
1862     }
1863 
1864     // We have a single incoming value.  Simplify using that value.
1865     inst->SetOpcode(SpvOpCopyObject);
1866     inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {incoming_value}}});
1867     return true;
1868   };
1869 }
1870 
BitCastScalarOrVector()1871 FoldingRule BitCastScalarOrVector() {
1872   return [](IRContext* context, Instruction* inst,
1873             const std::vector<const analysis::Constant*>& constants) {
1874     assert(inst->opcode() == SpvOpBitcast && constants.size() == 1);
1875     if (constants[0] == nullptr) return false;
1876 
1877     const analysis::Type* type =
1878         context->get_type_mgr()->GetType(inst->type_id());
1879     if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
1880       return false;
1881 
1882     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1883     std::vector<uint32_t> words =
1884         GetWordsFromNumericScalarOrVectorConstant(const_mgr, constants[0]);
1885     if (words.size() == 0) return false;
1886 
1887     const analysis::Constant* bitcasted_constant =
1888         ConvertWordsToNumericScalarOrVectorConstant(const_mgr, words, type);
1889     if (!bitcasted_constant) return false;
1890 
1891     auto new_feeder_id =
1892         const_mgr->GetDefiningInstruction(bitcasted_constant, inst->type_id())
1893             ->result_id();
1894     inst->SetOpcode(SpvOpCopyObject);
1895     inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {new_feeder_id}}});
1896     return true;
1897   };
1898 }
1899 
RedundantSelect()1900 FoldingRule RedundantSelect() {
1901   // An OpSelect instruction where both values are the same or the condition is
1902   // constant can be replaced by one of the values
1903   return [](IRContext*, Instruction* inst,
1904             const std::vector<const analysis::Constant*>& constants) {
1905     assert(inst->opcode() == SpvOpSelect &&
1906            "Wrong opcode.  Should be OpSelect.");
1907     assert(inst->NumInOperands() == 3);
1908     assert(constants.size() == 3);
1909 
1910     uint32_t true_id = inst->GetSingleWordInOperand(1);
1911     uint32_t false_id = inst->GetSingleWordInOperand(2);
1912 
1913     if (true_id == false_id) {
1914       // Both results are the same, condition doesn't matter
1915       inst->SetOpcode(SpvOpCopyObject);
1916       inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}});
1917       return true;
1918     } else if (constants[0]) {
1919       const analysis::Type* type = constants[0]->type();
1920       if (type->AsBool()) {
1921         // Scalar constant value, select the corresponding value.
1922         inst->SetOpcode(SpvOpCopyObject);
1923         if (constants[0]->AsNullConstant() ||
1924             !constants[0]->AsBoolConstant()->value()) {
1925           inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {false_id}}});
1926         } else {
1927           inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}});
1928         }
1929         return true;
1930       } else {
1931         assert(type->AsVector());
1932         if (constants[0]->AsNullConstant()) {
1933           // All values come from false id.
1934           inst->SetOpcode(SpvOpCopyObject);
1935           inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {false_id}}});
1936           return true;
1937         } else {
1938           // Convert to a vector shuffle.
1939           std::vector<Operand> ops;
1940           ops.push_back({SPV_OPERAND_TYPE_ID, {true_id}});
1941           ops.push_back({SPV_OPERAND_TYPE_ID, {false_id}});
1942           const analysis::VectorConstant* vector_const =
1943               constants[0]->AsVectorConstant();
1944           uint32_t size =
1945               static_cast<uint32_t>(vector_const->GetComponents().size());
1946           for (uint32_t i = 0; i != size; ++i) {
1947             const analysis::Constant* component =
1948                 vector_const->GetComponents()[i];
1949             if (component->AsNullConstant() ||
1950                 !component->AsBoolConstant()->value()) {
1951               // Selecting from the false vector which is the second input
1952               // vector to the shuffle. Offset the index by |size|.
1953               ops.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i + size}});
1954             } else {
1955               // Selecting from true vector which is the first input vector to
1956               // the shuffle.
1957               ops.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}});
1958             }
1959           }
1960 
1961           inst->SetOpcode(SpvOpVectorShuffle);
1962           inst->SetInOperands(std::move(ops));
1963           return true;
1964         }
1965       }
1966     }
1967 
1968     return false;
1969   };
1970 }
1971 
1972 enum class FloatConstantKind { Unknown, Zero, One };
1973 
getFloatConstantKind(const analysis::Constant * constant)1974 FloatConstantKind getFloatConstantKind(const analysis::Constant* constant) {
1975   if (constant == nullptr) {
1976     return FloatConstantKind::Unknown;
1977   }
1978 
1979   assert(HasFloatingPoint(constant->type()) && "Unexpected constant type");
1980 
1981   if (constant->AsNullConstant()) {
1982     return FloatConstantKind::Zero;
1983   } else if (const analysis::VectorConstant* vc =
1984                  constant->AsVectorConstant()) {
1985     const std::vector<const analysis::Constant*>& components =
1986         vc->GetComponents();
1987     assert(!components.empty());
1988 
1989     FloatConstantKind kind = getFloatConstantKind(components[0]);
1990 
1991     for (size_t i = 1; i < components.size(); ++i) {
1992       if (getFloatConstantKind(components[i]) != kind) {
1993         return FloatConstantKind::Unknown;
1994       }
1995     }
1996 
1997     return kind;
1998   } else if (const analysis::FloatConstant* fc = constant->AsFloatConstant()) {
1999     if (fc->IsZero()) return FloatConstantKind::Zero;
2000 
2001     uint32_t width = fc->type()->AsFloat()->width();
2002     if (width != 32 && width != 64) return FloatConstantKind::Unknown;
2003 
2004     double value = (width == 64) ? fc->GetDoubleValue() : fc->GetFloatValue();
2005 
2006     if (value == 0.0) {
2007       return FloatConstantKind::Zero;
2008     } else if (value == 1.0) {
2009       return FloatConstantKind::One;
2010     } else {
2011       return FloatConstantKind::Unknown;
2012     }
2013   } else {
2014     return FloatConstantKind::Unknown;
2015   }
2016 }
2017 
RedundantFAdd()2018 FoldingRule RedundantFAdd() {
2019   return [](IRContext*, Instruction* inst,
2020             const std::vector<const analysis::Constant*>& constants) {
2021     assert(inst->opcode() == SpvOpFAdd && "Wrong opcode.  Should be OpFAdd.");
2022     assert(constants.size() == 2);
2023 
2024     if (!inst->IsFloatingPointFoldingAllowed()) {
2025       return false;
2026     }
2027 
2028     FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
2029     FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
2030 
2031     if (kind0 == FloatConstantKind::Zero || kind1 == FloatConstantKind::Zero) {
2032       inst->SetOpcode(SpvOpCopyObject);
2033       inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
2034                             {inst->GetSingleWordInOperand(
2035                                 kind0 == FloatConstantKind::Zero ? 1 : 0)}}});
2036       return true;
2037     }
2038 
2039     return false;
2040   };
2041 }
2042 
RedundantFSub()2043 FoldingRule RedundantFSub() {
2044   return [](IRContext*, Instruction* inst,
2045             const std::vector<const analysis::Constant*>& constants) {
2046     assert(inst->opcode() == SpvOpFSub && "Wrong opcode.  Should be OpFSub.");
2047     assert(constants.size() == 2);
2048 
2049     if (!inst->IsFloatingPointFoldingAllowed()) {
2050       return false;
2051     }
2052 
2053     FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
2054     FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
2055 
2056     if (kind0 == FloatConstantKind::Zero) {
2057       inst->SetOpcode(SpvOpFNegate);
2058       inst->SetInOperands(
2059           {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1)}}});
2060       return true;
2061     }
2062 
2063     if (kind1 == FloatConstantKind::Zero) {
2064       inst->SetOpcode(SpvOpCopyObject);
2065       inst->SetInOperands(
2066           {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
2067       return true;
2068     }
2069 
2070     return false;
2071   };
2072 }
2073 
RedundantFMul()2074 FoldingRule RedundantFMul() {
2075   return [](IRContext*, Instruction* inst,
2076             const std::vector<const analysis::Constant*>& constants) {
2077     assert(inst->opcode() == SpvOpFMul && "Wrong opcode.  Should be OpFMul.");
2078     assert(constants.size() == 2);
2079 
2080     if (!inst->IsFloatingPointFoldingAllowed()) {
2081       return false;
2082     }
2083 
2084     FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
2085     FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
2086 
2087     if (kind0 == FloatConstantKind::Zero || kind1 == FloatConstantKind::Zero) {
2088       inst->SetOpcode(SpvOpCopyObject);
2089       inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
2090                             {inst->GetSingleWordInOperand(
2091                                 kind0 == FloatConstantKind::Zero ? 0 : 1)}}});
2092       return true;
2093     }
2094 
2095     if (kind0 == FloatConstantKind::One || kind1 == FloatConstantKind::One) {
2096       inst->SetOpcode(SpvOpCopyObject);
2097       inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
2098                             {inst->GetSingleWordInOperand(
2099                                 kind0 == FloatConstantKind::One ? 1 : 0)}}});
2100       return true;
2101     }
2102 
2103     return false;
2104   };
2105 }
2106 
RedundantFDiv()2107 FoldingRule RedundantFDiv() {
2108   return [](IRContext*, Instruction* inst,
2109             const std::vector<const analysis::Constant*>& constants) {
2110     assert(inst->opcode() == SpvOpFDiv && "Wrong opcode.  Should be OpFDiv.");
2111     assert(constants.size() == 2);
2112 
2113     if (!inst->IsFloatingPointFoldingAllowed()) {
2114       return false;
2115     }
2116 
2117     FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
2118     FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
2119 
2120     if (kind0 == FloatConstantKind::Zero) {
2121       inst->SetOpcode(SpvOpCopyObject);
2122       inst->SetInOperands(
2123           {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
2124       return true;
2125     }
2126 
2127     if (kind1 == FloatConstantKind::One) {
2128       inst->SetOpcode(SpvOpCopyObject);
2129       inst->SetInOperands(
2130           {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
2131       return true;
2132     }
2133 
2134     return false;
2135   };
2136 }
2137 
RedundantFMix()2138 FoldingRule RedundantFMix() {
2139   return [](IRContext* context, Instruction* inst,
2140             const std::vector<const analysis::Constant*>& constants) {
2141     assert(inst->opcode() == SpvOpExtInst &&
2142            "Wrong opcode.  Should be OpExtInst.");
2143 
2144     if (!inst->IsFloatingPointFoldingAllowed()) {
2145       return false;
2146     }
2147 
2148     uint32_t instSetId =
2149         context->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
2150 
2151     if (inst->GetSingleWordInOperand(kExtInstSetIdInIdx) == instSetId &&
2152         inst->GetSingleWordInOperand(kExtInstInstructionInIdx) ==
2153             GLSLstd450FMix) {
2154       assert(constants.size() == 5);
2155 
2156       FloatConstantKind kind4 = getFloatConstantKind(constants[4]);
2157 
2158       if (kind4 == FloatConstantKind::Zero || kind4 == FloatConstantKind::One) {
2159         inst->SetOpcode(SpvOpCopyObject);
2160         inst->SetInOperands(
2161             {{SPV_OPERAND_TYPE_ID,
2162               {inst->GetSingleWordInOperand(kind4 == FloatConstantKind::Zero
2163                                                 ? kFMixXIdInIdx
2164                                                 : kFMixYIdInIdx)}}});
2165         return true;
2166       }
2167     }
2168 
2169     return false;
2170   };
2171 }
2172 
2173 // This rule handles addition of zero for integers.
RedundantIAdd()2174 FoldingRule RedundantIAdd() {
2175   return [](IRContext* context, Instruction* inst,
2176             const std::vector<const analysis::Constant*>& constants) {
2177     assert(inst->opcode() == SpvOpIAdd && "Wrong opcode. Should be OpIAdd.");
2178 
2179     uint32_t operand = std::numeric_limits<uint32_t>::max();
2180     const analysis::Type* operand_type = nullptr;
2181     if (constants[0] && constants[0]->IsZero()) {
2182       operand = inst->GetSingleWordInOperand(1);
2183       operand_type = constants[0]->type();
2184     } else if (constants[1] && constants[1]->IsZero()) {
2185       operand = inst->GetSingleWordInOperand(0);
2186       operand_type = constants[1]->type();
2187     }
2188 
2189     if (operand != std::numeric_limits<uint32_t>::max()) {
2190       const analysis::Type* inst_type =
2191           context->get_type_mgr()->GetType(inst->type_id());
2192       if (inst_type->IsSame(operand_type)) {
2193         inst->SetOpcode(SpvOpCopyObject);
2194       } else {
2195         inst->SetOpcode(SpvOpBitcast);
2196       }
2197       inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {operand}}});
2198       return true;
2199     }
2200     return false;
2201   };
2202 }
2203 
2204 // This rule look for a dot with a constant vector containing a single 1 and
2205 // the rest 0s.  This is the same as doing an extract.
DotProductDoingExtract()2206 FoldingRule DotProductDoingExtract() {
2207   return [](IRContext* context, Instruction* inst,
2208             const std::vector<const analysis::Constant*>& constants) {
2209     assert(inst->opcode() == SpvOpDot && "Wrong opcode.  Should be OpDot.");
2210 
2211     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
2212 
2213     if (!inst->IsFloatingPointFoldingAllowed()) {
2214       return false;
2215     }
2216 
2217     for (int i = 0; i < 2; ++i) {
2218       if (!constants[i]) {
2219         continue;
2220       }
2221 
2222       const analysis::Vector* vector_type = constants[i]->type()->AsVector();
2223       assert(vector_type && "Inputs to OpDot must be vectors.");
2224       const analysis::Float* element_type =
2225           vector_type->element_type()->AsFloat();
2226       assert(element_type && "Inputs to OpDot must be vectors of floats.");
2227       uint32_t element_width = element_type->width();
2228       if (element_width != 32 && element_width != 64) {
2229         return false;
2230       }
2231 
2232       std::vector<const analysis::Constant*> components;
2233       components = constants[i]->GetVectorComponents(const_mgr);
2234 
2235       const uint32_t kNotFound = std::numeric_limits<uint32_t>::max();
2236 
2237       uint32_t component_with_one = kNotFound;
2238       bool all_others_zero = true;
2239       for (uint32_t j = 0; j < components.size(); ++j) {
2240         const analysis::Constant* element = components[j];
2241         double value =
2242             (element_width == 32 ? element->GetFloat() : element->GetDouble());
2243         if (value == 0.0) {
2244           continue;
2245         } else if (value == 1.0) {
2246           if (component_with_one == kNotFound) {
2247             component_with_one = j;
2248           } else {
2249             component_with_one = kNotFound;
2250             break;
2251           }
2252         } else {
2253           all_others_zero = false;
2254           break;
2255         }
2256       }
2257 
2258       if (!all_others_zero || component_with_one == kNotFound) {
2259         continue;
2260       }
2261 
2262       std::vector<Operand> operands;
2263       operands.push_back(
2264           {SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1u - i)}});
2265       operands.push_back(
2266           {SPV_OPERAND_TYPE_LITERAL_INTEGER, {component_with_one}});
2267 
2268       inst->SetOpcode(SpvOpCompositeExtract);
2269       inst->SetInOperands(std::move(operands));
2270       return true;
2271     }
2272     return false;
2273   };
2274 }
2275 
2276 // If we are storing an undef, then we can remove the store.
2277 //
2278 // TODO: We can do something similar for OpImageWrite, but checking for volatile
2279 // is complicated.  Waiting to see if it is needed.
StoringUndef()2280 FoldingRule StoringUndef() {
2281   return [](IRContext* context, Instruction* inst,
2282             const std::vector<const analysis::Constant*>&) {
2283     assert(inst->opcode() == SpvOpStore && "Wrong opcode.  Should be OpStore.");
2284 
2285     analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
2286 
2287     // If this is a volatile store, the store cannot be removed.
2288     if (inst->NumInOperands() == 3) {
2289       if (inst->GetSingleWordInOperand(2) & SpvMemoryAccessVolatileMask) {
2290         return false;
2291       }
2292     }
2293 
2294     uint32_t object_id = inst->GetSingleWordInOperand(kStoreObjectInIdx);
2295     Instruction* object_inst = def_use_mgr->GetDef(object_id);
2296     if (object_inst->opcode() == SpvOpUndef) {
2297       inst->ToNop();
2298       return true;
2299     }
2300     return false;
2301   };
2302 }
2303 
VectorShuffleFeedingShuffle()2304 FoldingRule VectorShuffleFeedingShuffle() {
2305   return [](IRContext* context, Instruction* inst,
2306             const std::vector<const analysis::Constant*>&) {
2307     assert(inst->opcode() == SpvOpVectorShuffle &&
2308            "Wrong opcode.  Should be OpVectorShuffle.");
2309 
2310     analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
2311     analysis::TypeManager* type_mgr = context->get_type_mgr();
2312 
2313     Instruction* feeding_shuffle_inst =
2314         def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
2315     analysis::Vector* op0_type =
2316         type_mgr->GetType(feeding_shuffle_inst->type_id())->AsVector();
2317     uint32_t op0_length = op0_type->element_count();
2318 
2319     bool feeder_is_op0 = true;
2320     if (feeding_shuffle_inst->opcode() != SpvOpVectorShuffle) {
2321       feeding_shuffle_inst =
2322           def_use_mgr->GetDef(inst->GetSingleWordInOperand(1));
2323       feeder_is_op0 = false;
2324     }
2325 
2326     if (feeding_shuffle_inst->opcode() != SpvOpVectorShuffle) {
2327       return false;
2328     }
2329 
2330     Instruction* feeder2 =
2331         def_use_mgr->GetDef(feeding_shuffle_inst->GetSingleWordInOperand(0));
2332     analysis::Vector* feeder_op0_type =
2333         type_mgr->GetType(feeder2->type_id())->AsVector();
2334     uint32_t feeder_op0_length = feeder_op0_type->element_count();
2335 
2336     uint32_t new_feeder_id = 0;
2337     std::vector<Operand> new_operands;
2338     new_operands.resize(
2339         2, {SPV_OPERAND_TYPE_ID, {0}});  // Place holders for vector operands.
2340     const uint32_t undef_literal = 0xffffffff;
2341     for (uint32_t op = 2; op < inst->NumInOperands(); ++op) {
2342       uint32_t component_index = inst->GetSingleWordInOperand(op);
2343 
2344       // Do not interpret the undefined value literal as coming from operand 1.
2345       if (component_index != undef_literal &&
2346           feeder_is_op0 == (component_index < op0_length)) {
2347         // This component comes from the feeding_shuffle_inst.  Update
2348         // |component_index| to be the index into the operand of the feeder.
2349 
2350         // Adjust component_index to get the index into the operands of the
2351         // feeding_shuffle_inst.
2352         if (component_index >= op0_length) {
2353           component_index -= op0_length;
2354         }
2355         component_index =
2356             feeding_shuffle_inst->GetSingleWordInOperand(component_index + 2);
2357 
2358         // Check if we are using a component from the first or second operand of
2359         // the feeding instruction.
2360         if (component_index < feeder_op0_length) {
2361           if (new_feeder_id == 0) {
2362             // First time through, save the id of the operand the element comes
2363             // from.
2364             new_feeder_id = feeding_shuffle_inst->GetSingleWordInOperand(0);
2365           } else if (new_feeder_id !=
2366                      feeding_shuffle_inst->GetSingleWordInOperand(0)) {
2367             // We need both elements of the feeding_shuffle_inst, so we cannot
2368             // fold.
2369             return false;
2370           }
2371         } else {
2372           if (new_feeder_id == 0) {
2373             // First time through, save the id of the operand the element comes
2374             // from.
2375             new_feeder_id = feeding_shuffle_inst->GetSingleWordInOperand(1);
2376           } else if (new_feeder_id !=
2377                      feeding_shuffle_inst->GetSingleWordInOperand(1)) {
2378             // We need both elements of the feeding_shuffle_inst, so we cannot
2379             // fold.
2380             return false;
2381           }
2382           component_index -= feeder_op0_length;
2383         }
2384 
2385         if (!feeder_is_op0) {
2386           component_index += op0_length;
2387         }
2388       }
2389       new_operands.push_back(
2390           {SPV_OPERAND_TYPE_LITERAL_INTEGER, {component_index}});
2391     }
2392 
2393     if (new_feeder_id == 0) {
2394       analysis::ConstantManager* const_mgr = context->get_constant_mgr();
2395       const analysis::Type* type =
2396           type_mgr->GetType(feeding_shuffle_inst->type_id());
2397       const analysis::Constant* null_const = const_mgr->GetConstant(type, {});
2398       new_feeder_id =
2399           const_mgr->GetDefiningInstruction(null_const, 0)->result_id();
2400     }
2401 
2402     if (feeder_is_op0) {
2403       // If the size of the first vector operand changed then the indices
2404       // referring to the second operand need to be adjusted.
2405       Instruction* new_feeder_inst = def_use_mgr->GetDef(new_feeder_id);
2406       analysis::Type* new_feeder_type =
2407           type_mgr->GetType(new_feeder_inst->type_id());
2408       uint32_t new_op0_size = new_feeder_type->AsVector()->element_count();
2409       int32_t adjustment = op0_length - new_op0_size;
2410 
2411       if (adjustment != 0) {
2412         for (uint32_t i = 2; i < new_operands.size(); i++) {
2413           if (inst->GetSingleWordInOperand(i) >= op0_length) {
2414             new_operands[i].words[0] -= adjustment;
2415           }
2416         }
2417       }
2418 
2419       new_operands[0].words[0] = new_feeder_id;
2420       new_operands[1] = inst->GetInOperand(1);
2421     } else {
2422       new_operands[1].words[0] = new_feeder_id;
2423       new_operands[0] = inst->GetInOperand(0);
2424     }
2425 
2426     inst->SetInOperands(std::move(new_operands));
2427     return true;
2428   };
2429 }
2430 
2431 // Removes duplicate ids from the interface list of an OpEntryPoint
2432 // instruction.
RemoveRedundantOperands()2433 FoldingRule RemoveRedundantOperands() {
2434   return [](IRContext*, Instruction* inst,
2435             const std::vector<const analysis::Constant*>&) {
2436     assert(inst->opcode() == SpvOpEntryPoint &&
2437            "Wrong opcode.  Should be OpEntryPoint.");
2438     bool has_redundant_operand = false;
2439     std::unordered_set<uint32_t> seen_operands;
2440     std::vector<Operand> new_operands;
2441 
2442     new_operands.emplace_back(inst->GetOperand(0));
2443     new_operands.emplace_back(inst->GetOperand(1));
2444     new_operands.emplace_back(inst->GetOperand(2));
2445     for (uint32_t i = 3; i < inst->NumOperands(); ++i) {
2446       if (seen_operands.insert(inst->GetSingleWordOperand(i)).second) {
2447         new_operands.emplace_back(inst->GetOperand(i));
2448       } else {
2449         has_redundant_operand = true;
2450       }
2451     }
2452 
2453     if (!has_redundant_operand) {
2454       return false;
2455     }
2456 
2457     inst->SetInOperands(std::move(new_operands));
2458     return true;
2459   };
2460 }
2461 
2462 // If an image instruction's operand is a constant, updates the image operand
2463 // flag from Offset to ConstOffset.
UpdateImageOperands()2464 FoldingRule UpdateImageOperands() {
2465   return [](IRContext*, Instruction* inst,
2466             const std::vector<const analysis::Constant*>& constants) {
2467     const auto opcode = inst->opcode();
2468     (void)opcode;
2469     assert((opcode == SpvOpImageSampleImplicitLod ||
2470             opcode == SpvOpImageSampleExplicitLod ||
2471             opcode == SpvOpImageSampleDrefImplicitLod ||
2472             opcode == SpvOpImageSampleDrefExplicitLod ||
2473             opcode == SpvOpImageSampleProjImplicitLod ||
2474             opcode == SpvOpImageSampleProjExplicitLod ||
2475             opcode == SpvOpImageSampleProjDrefImplicitLod ||
2476             opcode == SpvOpImageSampleProjDrefExplicitLod ||
2477             opcode == SpvOpImageFetch || opcode == SpvOpImageGather ||
2478             opcode == SpvOpImageDrefGather || opcode == SpvOpImageRead ||
2479             opcode == SpvOpImageWrite ||
2480             opcode == SpvOpImageSparseSampleImplicitLod ||
2481             opcode == SpvOpImageSparseSampleExplicitLod ||
2482             opcode == SpvOpImageSparseSampleDrefImplicitLod ||
2483             opcode == SpvOpImageSparseSampleDrefExplicitLod ||
2484             opcode == SpvOpImageSparseSampleProjImplicitLod ||
2485             opcode == SpvOpImageSparseSampleProjExplicitLod ||
2486             opcode == SpvOpImageSparseSampleProjDrefImplicitLod ||
2487             opcode == SpvOpImageSparseSampleProjDrefExplicitLod ||
2488             opcode == SpvOpImageSparseFetch ||
2489             opcode == SpvOpImageSparseGather ||
2490             opcode == SpvOpImageSparseDrefGather ||
2491             opcode == SpvOpImageSparseRead) &&
2492            "Wrong opcode.  Should be an image instruction.");
2493 
2494     int32_t operand_index = ImageOperandsMaskInOperandIndex(inst);
2495     if (operand_index >= 0) {
2496       auto image_operands = inst->GetSingleWordInOperand(operand_index);
2497       if (image_operands & SpvImageOperandsOffsetMask) {
2498         uint32_t offset_operand_index = operand_index + 1;
2499         if (image_operands & SpvImageOperandsBiasMask) offset_operand_index++;
2500         if (image_operands & SpvImageOperandsLodMask) offset_operand_index++;
2501         if (image_operands & SpvImageOperandsGradMask)
2502           offset_operand_index += 2;
2503         assert(((image_operands & SpvImageOperandsConstOffsetMask) == 0) &&
2504                "Offset and ConstOffset may not be used together");
2505         if (offset_operand_index < inst->NumOperands()) {
2506           if (constants[offset_operand_index]) {
2507             image_operands = image_operands | SpvImageOperandsConstOffsetMask;
2508             image_operands = image_operands & ~SpvImageOperandsOffsetMask;
2509             inst->SetInOperand(operand_index, {image_operands});
2510             return true;
2511           }
2512         }
2513       }
2514     }
2515 
2516     return false;
2517   };
2518 }
2519 
2520 }  // namespace
2521 
AddFoldingRules()2522 void FoldingRules::AddFoldingRules() {
2523   // Add all folding rules to the list for the opcodes to which they apply.
2524   // Note that the order in which rules are added to the list matters. If a rule
2525   // applies to the instruction, the rest of the rules will not be attempted.
2526   // Take that into consideration.
2527   rules_[SpvOpBitcast].push_back(BitCastScalarOrVector());
2528 
2529   rules_[SpvOpCompositeConstruct].push_back(CompositeExtractFeedingConstruct);
2530 
2531   rules_[SpvOpCompositeExtract].push_back(InsertFeedingExtract());
2532   rules_[SpvOpCompositeExtract].push_back(CompositeConstructFeedingExtract);
2533   rules_[SpvOpCompositeExtract].push_back(VectorShuffleFeedingExtract());
2534   rules_[SpvOpCompositeExtract].push_back(FMixFeedingExtract());
2535 
2536   rules_[SpvOpDot].push_back(DotProductDoingExtract());
2537 
2538   rules_[SpvOpEntryPoint].push_back(RemoveRedundantOperands());
2539 
2540   rules_[SpvOpFAdd].push_back(RedundantFAdd());
2541   rules_[SpvOpFAdd].push_back(MergeAddNegateArithmetic());
2542   rules_[SpvOpFAdd].push_back(MergeAddAddArithmetic());
2543   rules_[SpvOpFAdd].push_back(MergeAddSubArithmetic());
2544   rules_[SpvOpFAdd].push_back(MergeGenericAddSubArithmetic());
2545   rules_[SpvOpFAdd].push_back(FactorAddMuls());
2546 
2547   rules_[SpvOpFDiv].push_back(RedundantFDiv());
2548   rules_[SpvOpFDiv].push_back(ReciprocalFDiv());
2549   rules_[SpvOpFDiv].push_back(MergeDivDivArithmetic());
2550   rules_[SpvOpFDiv].push_back(MergeDivMulArithmetic());
2551   rules_[SpvOpFDiv].push_back(MergeDivNegateArithmetic());
2552 
2553   rules_[SpvOpFMul].push_back(RedundantFMul());
2554   rules_[SpvOpFMul].push_back(MergeMulMulArithmetic());
2555   rules_[SpvOpFMul].push_back(MergeMulDivArithmetic());
2556   rules_[SpvOpFMul].push_back(MergeMulNegateArithmetic());
2557 
2558   rules_[SpvOpFNegate].push_back(MergeNegateArithmetic());
2559   rules_[SpvOpFNegate].push_back(MergeNegateAddSubArithmetic());
2560   rules_[SpvOpFNegate].push_back(MergeNegateMulDivArithmetic());
2561 
2562   rules_[SpvOpFSub].push_back(RedundantFSub());
2563   rules_[SpvOpFSub].push_back(MergeSubNegateArithmetic());
2564   rules_[SpvOpFSub].push_back(MergeSubAddArithmetic());
2565   rules_[SpvOpFSub].push_back(MergeSubSubArithmetic());
2566 
2567   rules_[SpvOpIAdd].push_back(RedundantIAdd());
2568   rules_[SpvOpIAdd].push_back(MergeAddNegateArithmetic());
2569   rules_[SpvOpIAdd].push_back(MergeAddAddArithmetic());
2570   rules_[SpvOpIAdd].push_back(MergeAddSubArithmetic());
2571   rules_[SpvOpIAdd].push_back(MergeGenericAddSubArithmetic());
2572   rules_[SpvOpIAdd].push_back(FactorAddMuls());
2573 
2574   rules_[SpvOpIMul].push_back(IntMultipleBy1());
2575   rules_[SpvOpIMul].push_back(MergeMulMulArithmetic());
2576   rules_[SpvOpIMul].push_back(MergeMulNegateArithmetic());
2577 
2578   rules_[SpvOpISub].push_back(MergeSubNegateArithmetic());
2579   rules_[SpvOpISub].push_back(MergeSubAddArithmetic());
2580   rules_[SpvOpISub].push_back(MergeSubSubArithmetic());
2581 
2582   rules_[SpvOpPhi].push_back(RedundantPhi());
2583 
2584   rules_[SpvOpSNegate].push_back(MergeNegateArithmetic());
2585   rules_[SpvOpSNegate].push_back(MergeNegateMulDivArithmetic());
2586   rules_[SpvOpSNegate].push_back(MergeNegateAddSubArithmetic());
2587 
2588   rules_[SpvOpSelect].push_back(RedundantSelect());
2589 
2590   rules_[SpvOpStore].push_back(StoringUndef());
2591 
2592   rules_[SpvOpVectorShuffle].push_back(VectorShuffleFeedingShuffle());
2593 
2594   rules_[SpvOpImageSampleImplicitLod].push_back(UpdateImageOperands());
2595   rules_[SpvOpImageSampleExplicitLod].push_back(UpdateImageOperands());
2596   rules_[SpvOpImageSampleDrefImplicitLod].push_back(UpdateImageOperands());
2597   rules_[SpvOpImageSampleDrefExplicitLod].push_back(UpdateImageOperands());
2598   rules_[SpvOpImageSampleProjImplicitLod].push_back(UpdateImageOperands());
2599   rules_[SpvOpImageSampleProjExplicitLod].push_back(UpdateImageOperands());
2600   rules_[SpvOpImageSampleProjDrefImplicitLod].push_back(UpdateImageOperands());
2601   rules_[SpvOpImageSampleProjDrefExplicitLod].push_back(UpdateImageOperands());
2602   rules_[SpvOpImageFetch].push_back(UpdateImageOperands());
2603   rules_[SpvOpImageGather].push_back(UpdateImageOperands());
2604   rules_[SpvOpImageDrefGather].push_back(UpdateImageOperands());
2605   rules_[SpvOpImageRead].push_back(UpdateImageOperands());
2606   rules_[SpvOpImageWrite].push_back(UpdateImageOperands());
2607   rules_[SpvOpImageSparseSampleImplicitLod].push_back(UpdateImageOperands());
2608   rules_[SpvOpImageSparseSampleExplicitLod].push_back(UpdateImageOperands());
2609   rules_[SpvOpImageSparseSampleDrefImplicitLod].push_back(
2610       UpdateImageOperands());
2611   rules_[SpvOpImageSparseSampleDrefExplicitLod].push_back(
2612       UpdateImageOperands());
2613   rules_[SpvOpImageSparseSampleProjImplicitLod].push_back(
2614       UpdateImageOperands());
2615   rules_[SpvOpImageSparseSampleProjExplicitLod].push_back(
2616       UpdateImageOperands());
2617   rules_[SpvOpImageSparseSampleProjDrefImplicitLod].push_back(
2618       UpdateImageOperands());
2619   rules_[SpvOpImageSparseSampleProjDrefExplicitLod].push_back(
2620       UpdateImageOperands());
2621   rules_[SpvOpImageSparseFetch].push_back(UpdateImageOperands());
2622   rules_[SpvOpImageSparseGather].push_back(UpdateImageOperands());
2623   rules_[SpvOpImageSparseDrefGather].push_back(UpdateImageOperands());
2624   rules_[SpvOpImageSparseRead].push_back(UpdateImageOperands());
2625 
2626   FeatureManager* feature_manager = context_->get_feature_mgr();
2627   // Add rules for GLSLstd450
2628   uint32_t ext_inst_glslstd450_id =
2629       feature_manager->GetExtInstImportId_GLSLstd450();
2630   if (ext_inst_glslstd450_id != 0) {
2631     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMix}].push_back(
2632         RedundantFMix());
2633   }
2634 }
2635 }  // namespace opt
2636 }  // namespace spvtools
2637