• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2018 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/opt/folding_rules.h"
16 
17 #include <limits>
18 #include <memory>
19 #include <utility>
20 
21 #include "ir_builder.h"
22 #include "source/latest_version_glsl_std_450_header.h"
23 #include "source/opt/ir_context.h"
24 
25 namespace spvtools {
26 namespace opt {
27 namespace {
28 
// In-operand indices used by the folding rules in this file.
// NOTE(review): which instruction each index belongs to (composite
// extract/insert, extended instructions, FMix, store) is inferred from the
// constant names only; confirm against the SPIR-V operand layouts.
29 constexpr uint32_t kExtractCompositeIdInIdx = 0;
30 constexpr uint32_t kInsertObjectIdInIdx = 0;
31 constexpr uint32_t kInsertCompositeIdInIdx = 1;
32 constexpr uint32_t kExtInstSetIdInIdx = 0;
33 constexpr uint32_t kExtInstInstructionInIdx = 1;
34 constexpr uint32_t kFMixXIdInIdx = 2;
35 constexpr uint32_t kFMixYIdInIdx = 3;
36 constexpr uint32_t kFMixAIdInIdx = 4;
37 constexpr uint32_t kStoreObjectInIdx = 1;
38 
39 // Some image instructions may contain an "image operands" argument.
40 // Returns the operand index for the "image operands".
41 // Returns -1 if the instruction does not have image operands.
ImageOperandsMaskInOperandIndex(Instruction * inst)42 int32_t ImageOperandsMaskInOperandIndex(Instruction* inst) {
43   const auto opcode = inst->opcode();
44   switch (opcode) {
45     case spv::Op::OpImageSampleImplicitLod:
46     case spv::Op::OpImageSampleExplicitLod:
47     case spv::Op::OpImageSampleProjImplicitLod:
48     case spv::Op::OpImageSampleProjExplicitLod:
49     case spv::Op::OpImageFetch:
50     case spv::Op::OpImageRead:
51     case spv::Op::OpImageSparseSampleImplicitLod:
52     case spv::Op::OpImageSparseSampleExplicitLod:
53     case spv::Op::OpImageSparseSampleProjImplicitLod:
54     case spv::Op::OpImageSparseSampleProjExplicitLod:
55     case spv::Op::OpImageSparseFetch:
56     case spv::Op::OpImageSparseRead:
57       return inst->NumOperands() > 4 ? 2 : -1;
58     case spv::Op::OpImageSampleDrefImplicitLod:
59     case spv::Op::OpImageSampleDrefExplicitLod:
60     case spv::Op::OpImageSampleProjDrefImplicitLod:
61     case spv::Op::OpImageSampleProjDrefExplicitLod:
62     case spv::Op::OpImageGather:
63     case spv::Op::OpImageDrefGather:
64     case spv::Op::OpImageSparseSampleDrefImplicitLod:
65     case spv::Op::OpImageSparseSampleDrefExplicitLod:
66     case spv::Op::OpImageSparseSampleProjDrefImplicitLod:
67     case spv::Op::OpImageSparseSampleProjDrefExplicitLod:
68     case spv::Op::OpImageSparseGather:
69     case spv::Op::OpImageSparseDrefGather:
70       return inst->NumOperands() > 5 ? 3 : -1;
71     case spv::Op::OpImageWrite:
72       return inst->NumOperands() > 3 ? 3 : -1;
73     default:
74       return -1;
75   }
76 }
77 
78 // Returns the element width of |type|.
ElementWidth(const analysis::Type * type)79 uint32_t ElementWidth(const analysis::Type* type) {
80   if (const analysis::Vector* vec_type = type->AsVector()) {
81     return ElementWidth(vec_type->element_type());
82   } else if (const analysis::Float* float_type = type->AsFloat()) {
83     return float_type->width();
84   } else {
85     assert(type->AsInteger());
86     return type->AsInteger()->width();
87   }
88 }
89 
90 // Returns true if |type| is Float or a vector of Float.
HasFloatingPoint(const analysis::Type * type)91 bool HasFloatingPoint(const analysis::Type* type) {
92   if (type->AsFloat()) {
93     return true;
94   } else if (const analysis::Vector* vec_type = type->AsVector()) {
95     return vec_type->element_type()->AsFloat() != nullptr;
96   }
97 
98   return false;
99 }
100 
// Returns true unless |val| classifies as NaN, infinite or subnormal, in
// which case it returns false.
template <typename T>
bool IsValidResult(T val) {
  const int category = std::fpclassify(val);
  const bool rejected = category == FP_NAN || category == FP_INFINITE ||
                        category == FP_SUBNORMAL;
  return !rejected;
}
114 
ConstInput(const std::vector<const analysis::Constant * > & constants)115 const analysis::Constant* ConstInput(
116     const std::vector<const analysis::Constant*>& constants) {
117   return constants[0] ? constants[0] : constants[1];
118 }
119 
NonConstInput(IRContext * context,const analysis::Constant * c,Instruction * inst)120 Instruction* NonConstInput(IRContext* context, const analysis::Constant* c,
121                            Instruction* inst) {
122   uint32_t in_op = c ? 1u : 0u;
123   return context->get_def_use_mgr()->GetDef(
124       inst->GetSingleWordInOperand(in_op));
125 }
126 
// Splits |val| into two 32-bit words and returns them low word first.
std::vector<uint32_t> ExtractInts(uint64_t val) {
  const uint32_t low_word = static_cast<uint32_t>(val);
  const uint32_t high_word = static_cast<uint32_t>(val >> 32);
  return {low_word, high_word};
}
133 
GetWordsFromScalarIntConstant(const analysis::IntConstant * c)134 std::vector<uint32_t> GetWordsFromScalarIntConstant(
135     const analysis::IntConstant* c) {
136   assert(c != nullptr);
137   uint32_t width = c->type()->AsInteger()->width();
138   assert(width == 8 || width == 16 || width == 32 || width == 64);
139   if (width == 64) {
140     uint64_t uval = static_cast<uint64_t>(c->GetU64());
141     return ExtractInts(uval);
142   }
143   // Section 2.2.1 of the SPIR-V spec guarantees that all integer types
144   // smaller than 32-bits are automatically zero or sign extended to 32-bits.
145   return {c->GetU32BitValue()};
146 }
147 
GetWordsFromScalarFloatConstant(const analysis::FloatConstant * c)148 std::vector<uint32_t> GetWordsFromScalarFloatConstant(
149     const analysis::FloatConstant* c) {
150   assert(c != nullptr);
151   uint32_t width = c->type()->AsFloat()->width();
152   assert(width == 16 || width == 32 || width == 64);
153   if (width == 64) {
154     utils::FloatProxy<double> result(c->GetDouble());
155     return result.GetWords();
156   }
157   // Section 2.2.1 of the SPIR-V spec guarantees that all floating-point types
158   // smaller than 32-bits are automatically zero extended to 32-bits.
159   return {c->GetU32BitValue()};
160 }
161 
GetWordsFromNumericScalarOrVectorConstant(analysis::ConstantManager * const_mgr,const analysis::Constant * c)162 std::vector<uint32_t> GetWordsFromNumericScalarOrVectorConstant(
163     analysis::ConstantManager* const_mgr, const analysis::Constant* c) {
164   if (const auto* float_constant = c->AsFloatConstant()) {
165     return GetWordsFromScalarFloatConstant(float_constant);
166   } else if (const auto* int_constant = c->AsIntConstant()) {
167     return GetWordsFromScalarIntConstant(int_constant);
168   } else if (const auto* vec_constant = c->AsVectorConstant()) {
169     std::vector<uint32_t> words;
170     for (const auto* comp : vec_constant->GetComponents()) {
171       auto comp_in_words =
172           GetWordsFromNumericScalarOrVectorConstant(const_mgr, comp);
173       words.insert(words.end(), comp_in_words.begin(), comp_in_words.end());
174     }
175     return words;
176   }
177   return {};
178 }
179 
ConvertWordsToNumericScalarOrVectorConstant(analysis::ConstantManager * const_mgr,const std::vector<uint32_t> & words,const analysis::Type * type)180 const analysis::Constant* ConvertWordsToNumericScalarOrVectorConstant(
181     analysis::ConstantManager* const_mgr, const std::vector<uint32_t>& words,
182     const analysis::Type* type) {
183   if (type->AsInteger() || type->AsFloat())
184     return const_mgr->GetConstant(type, words);
185   if (const auto* vec_type = type->AsVector())
186     return const_mgr->GetNumericVectorConstantWithWords(vec_type, words);
187   return nullptr;
188 }
189 
190 // Returns the negation of |c|. |c| must be a 32 or 64 bit floating point
191 // constant.
NegateFloatingPointConstant(analysis::ConstantManager * const_mgr,const analysis::Constant * c)192 uint32_t NegateFloatingPointConstant(analysis::ConstantManager* const_mgr,
193                                      const analysis::Constant* c) {
194   assert(c);
195   assert(c->type()->AsFloat());
196   uint32_t width = c->type()->AsFloat()->width();
197   assert(width == 32 || width == 64);
198   std::vector<uint32_t> words;
199   if (width == 64) {
200     utils::FloatProxy<double> result(c->GetDouble() * -1.0);
201     words = result.GetWords();
202   } else {
203     utils::FloatProxy<float> result(c->GetFloat() * -1.0f);
204     words = result.GetWords();
205   }
206 
207   const analysis::Constant* negated_const =
208       const_mgr->GetConstant(c->type(), std::move(words));
209   return const_mgr->GetDefiningInstruction(negated_const)->result_id();
210 }
211 
212 // Negates the integer constant |c|. Returns the id of the defining instruction.
NegateIntegerConstant(analysis::ConstantManager * const_mgr,const analysis::Constant * c)213 uint32_t NegateIntegerConstant(analysis::ConstantManager* const_mgr,
214                                const analysis::Constant* c) {
215   assert(c);
216   assert(c->type()->AsInteger());
217   uint32_t width = c->type()->AsInteger()->width();
218   assert(width == 32 || width == 64);
219   std::vector<uint32_t> words;
220   if (width == 64) {
221     uint64_t uval = static_cast<uint64_t>(0 - c->GetU64());
222     words = ExtractInts(uval);
223   } else {
224     words.push_back(static_cast<uint32_t>(0 - c->GetU32()));
225   }
226 
227   const analysis::Constant* negated_const =
228       const_mgr->GetConstant(c->type(), std::move(words));
229   return const_mgr->GetDefiningInstruction(negated_const)->result_id();
230 }
231 
232 // Negates the vector constant |c|. Returns the id of the defining instruction.
NegateVectorConstant(analysis::ConstantManager * const_mgr,const analysis::Constant * c)233 uint32_t NegateVectorConstant(analysis::ConstantManager* const_mgr,
234                               const analysis::Constant* c) {
235   assert(const_mgr && c);
236   assert(c->type()->AsVector());
237   if (c->AsNullConstant()) {
238     // 0.0 vs -0.0 shouldn't matter.
239     return const_mgr->GetDefiningInstruction(c)->result_id();
240   } else {
241     const analysis::Type* component_type =
242         c->AsVectorConstant()->component_type();
243     std::vector<uint32_t> words;
244     for (auto& comp : c->AsVectorConstant()->GetComponents()) {
245       if (component_type->AsFloat()) {
246         words.push_back(NegateFloatingPointConstant(const_mgr, comp));
247       } else {
248         assert(component_type->AsInteger());
249         words.push_back(NegateIntegerConstant(const_mgr, comp));
250       }
251     }
252 
253     const analysis::Constant* negated_const =
254         const_mgr->GetConstant(c->type(), std::move(words));
255     return const_mgr->GetDefiningInstruction(negated_const)->result_id();
256   }
257 }
258 
259 // Negates |c|. Returns the id of the defining instruction.
NegateConstant(analysis::ConstantManager * const_mgr,const analysis::Constant * c)260 uint32_t NegateConstant(analysis::ConstantManager* const_mgr,
261                         const analysis::Constant* c) {
262   if (c->type()->AsVector()) {
263     return NegateVectorConstant(const_mgr, c);
264   } else if (c->type()->AsFloat()) {
265     return NegateFloatingPointConstant(const_mgr, c);
266   } else {
267     assert(c->type()->AsInteger());
268     return NegateIntegerConstant(const_mgr, c);
269   }
270 }
271 
272 // Takes the reciprocal of |c|. |c|'s type must be Float or a vector of Float.
273 // Returns 0 if the reciprocal is NaN, infinite or subnormal.
Reciprocal(analysis::ConstantManager * const_mgr,const analysis::Constant * c)274 uint32_t Reciprocal(analysis::ConstantManager* const_mgr,
275                     const analysis::Constant* c) {
276   assert(const_mgr && c);
277   assert(c->type()->AsFloat());
278 
279   uint32_t width = c->type()->AsFloat()->width();
280   assert(width == 32 || width == 64);
281   std::vector<uint32_t> words;
282 
283   if (c->IsZero()) {
284     return 0;
285   }
286 
287   if (width == 64) {
288     spvtools::utils::FloatProxy<double> result(1.0 / c->GetDouble());
289     if (!IsValidResult(result.getAsFloat())) return 0;
290     words = result.GetWords();
291   } else {
292     spvtools::utils::FloatProxy<float> result(1.0f / c->GetFloat());
293     if (!IsValidResult(result.getAsFloat())) return 0;
294     words = result.GetWords();
295   }
296 
297   const analysis::Constant* negated_const =
298       const_mgr->GetConstant(c->type(), std::move(words));
299   return const_mgr->GetDefiningInstruction(negated_const)->result_id();
300 }
301 
302 // Replaces fdiv where second operand is constant with fmul.
ReciprocalFDiv()303 FoldingRule ReciprocalFDiv() {
304   return [](IRContext* context, Instruction* inst,
305             const std::vector<const analysis::Constant*>& constants) {
306     assert(inst->opcode() == spv::Op::OpFDiv);
307     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
308     const analysis::Type* type =
309         context->get_type_mgr()->GetType(inst->type_id());
310     if (!inst->IsFloatingPointFoldingAllowed()) return false;
311 
312     uint32_t width = ElementWidth(type);
313     if (width != 32 && width != 64) return false;
314 
315     if (constants[1] != nullptr) {
316       uint32_t id = 0;
317       if (const analysis::VectorConstant* vector_const =
318               constants[1]->AsVectorConstant()) {
319         std::vector<uint32_t> neg_ids;
320         for (auto& comp : vector_const->GetComponents()) {
321           id = Reciprocal(const_mgr, comp);
322           if (id == 0) return false;
323           neg_ids.push_back(id);
324         }
325         const analysis::Constant* negated_const =
326             const_mgr->GetConstant(constants[1]->type(), std::move(neg_ids));
327         id = const_mgr->GetDefiningInstruction(negated_const)->result_id();
328       } else if (constants[1]->AsFloatConstant()) {
329         id = Reciprocal(const_mgr, constants[1]);
330         if (id == 0) return false;
331       } else {
332         // Don't fold a null constant.
333         return false;
334       }
335       inst->SetOpcode(spv::Op::OpFMul);
336       inst->SetInOperands(
337           {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0u)}},
338            {SPV_OPERAND_TYPE_ID, {id}}});
339       return true;
340     }
341 
342     return false;
343   };
344 }
345 
346 // Elides consecutive negate instructions.
MergeNegateArithmetic()347 FoldingRule MergeNegateArithmetic() {
348   return [](IRContext* context, Instruction* inst,
349             const std::vector<const analysis::Constant*>& constants) {
350     assert(inst->opcode() == spv::Op::OpFNegate ||
351            inst->opcode() == spv::Op::OpSNegate);
352     (void)constants;
353     const analysis::Type* type =
354         context->get_type_mgr()->GetType(inst->type_id());
355     if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
356       return false;
357 
358     Instruction* op_inst =
359         context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
360     if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
361       return false;
362 
363     if (op_inst->opcode() == inst->opcode()) {
364       // Elide negates.
365       inst->SetOpcode(spv::Op::OpCopyObject);
366       inst->SetInOperands(
367           {{SPV_OPERAND_TYPE_ID, {op_inst->GetSingleWordInOperand(0u)}}});
368       return true;
369     }
370 
371     return false;
372   };
373 }
374 
375 // Merges negate into a mul or div operation if that operation contains a
376 // constant operand.
377 // Cases:
378 // -(x * 2) = x * -2
379 // -(2 * x) = x * -2
380 // -(x / 2) = x / -2
381 // -(2 / x) = -2 / x
MergeNegateMulDivArithmetic()382 FoldingRule MergeNegateMulDivArithmetic() {
383   return [](IRContext* context, Instruction* inst,
384             const std::vector<const analysis::Constant*>& constants) {
385     assert(inst->opcode() == spv::Op::OpFNegate ||
386            inst->opcode() == spv::Op::OpSNegate);
387     (void)constants;
388     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
389     const analysis::Type* type =
390         context->get_type_mgr()->GetType(inst->type_id());
391     if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
392       return false;
393 
394     Instruction* op_inst =
395         context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
396     if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
397       return false;
398 
399     uint32_t width = ElementWidth(type);
400     if (width != 32 && width != 64) return false;
401 
402     spv::Op opcode = op_inst->opcode();
403     if (opcode == spv::Op::OpFMul || opcode == spv::Op::OpFDiv ||
404         opcode == spv::Op::OpIMul || opcode == spv::Op::OpSDiv ||
405         opcode == spv::Op::OpUDiv) {
406       std::vector<const analysis::Constant*> op_constants =
407           const_mgr->GetOperandConstants(op_inst);
408       // Merge negate into mul or div if one operand is constant.
409       if (op_constants[0] || op_constants[1]) {
410         bool zero_is_variable = op_constants[0] == nullptr;
411         const analysis::Constant* c = ConstInput(op_constants);
412         uint32_t neg_id = NegateConstant(const_mgr, c);
413         uint32_t non_const_id = zero_is_variable
414                                     ? op_inst->GetSingleWordInOperand(0u)
415                                     : op_inst->GetSingleWordInOperand(1u);
416         // Change this instruction to a mul/div.
417         inst->SetOpcode(op_inst->opcode());
418         if (opcode == spv::Op::OpFDiv || opcode == spv::Op::OpUDiv ||
419             opcode == spv::Op::OpSDiv) {
420           uint32_t op0 = zero_is_variable ? non_const_id : neg_id;
421           uint32_t op1 = zero_is_variable ? neg_id : non_const_id;
422           inst->SetInOperands(
423               {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}});
424         } else {
425           inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
426                                {SPV_OPERAND_TYPE_ID, {neg_id}}});
427         }
428         return true;
429       }
430     }
431 
432     return false;
433   };
434 }
435 
436 // Merges negate into a add or sub operation if that operation contains a
437 // constant operand.
438 // Cases:
439 // -(x + 2) = -2 - x
440 // -(2 + x) = -2 - x
441 // -(x - 2) = 2 - x
442 // -(2 - x) = x - 2
MergeNegateAddSubArithmetic()443 FoldingRule MergeNegateAddSubArithmetic() {
444   return [](IRContext* context, Instruction* inst,
445             const std::vector<const analysis::Constant*>& constants) {
446     assert(inst->opcode() == spv::Op::OpFNegate ||
447            inst->opcode() == spv::Op::OpSNegate);
448     (void)constants;
449     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
450     const analysis::Type* type =
451         context->get_type_mgr()->GetType(inst->type_id());
452     if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
453       return false;
454 
455     Instruction* op_inst =
456         context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
457     if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
458       return false;
459 
460     uint32_t width = ElementWidth(type);
461     if (width != 32 && width != 64) return false;
462 
463     if (op_inst->opcode() == spv::Op::OpFAdd ||
464         op_inst->opcode() == spv::Op::OpFSub ||
465         op_inst->opcode() == spv::Op::OpIAdd ||
466         op_inst->opcode() == spv::Op::OpISub) {
467       std::vector<const analysis::Constant*> op_constants =
468           const_mgr->GetOperandConstants(op_inst);
469       if (op_constants[0] || op_constants[1]) {
470         bool zero_is_variable = op_constants[0] == nullptr;
471         bool is_add = (op_inst->opcode() == spv::Op::OpFAdd) ||
472                       (op_inst->opcode() == spv::Op::OpIAdd);
473         bool swap_operands = !is_add || zero_is_variable;
474         bool negate_const = is_add;
475         const analysis::Constant* c = ConstInput(op_constants);
476         uint32_t const_id = 0;
477         if (negate_const) {
478           const_id = NegateConstant(const_mgr, c);
479         } else {
480           const_id = zero_is_variable ? op_inst->GetSingleWordInOperand(1u)
481                                       : op_inst->GetSingleWordInOperand(0u);
482         }
483 
484         // Swap operands if necessary and make the instruction a subtraction.
485         uint32_t op0 =
486             zero_is_variable ? op_inst->GetSingleWordInOperand(0u) : const_id;
487         uint32_t op1 =
488             zero_is_variable ? const_id : op_inst->GetSingleWordInOperand(1u);
489         if (swap_operands) std::swap(op0, op1);
490         inst->SetOpcode(HasFloatingPoint(type) ? spv::Op::OpFSub
491                                                : spv::Op::OpISub);
492         inst->SetInOperands(
493             {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}});
494         return true;
495       }
496     }
497 
498     return false;
499   };
500 }
501 
502 // Returns true if |c| has a zero element.
HasZero(const analysis::Constant * c)503 bool HasZero(const analysis::Constant* c) {
504   if (c->AsNullConstant()) {
505     return true;
506   }
507   if (const analysis::VectorConstant* vec_const = c->AsVectorConstant()) {
508     for (auto& comp : vec_const->GetComponents())
509       if (HasZero(comp)) return true;
510   } else {
511     assert(c->AsScalarConstant());
512     return c->AsScalarConstant()->IsZero();
513   }
514 
515   return false;
516 }
517 
518 // Performs |input1| |opcode| |input2| and returns the merged constant result
519 // id. Returns 0 if the result is not a valid value. The input types must be
520 // Float.
PerformFloatingPointOperation(analysis::ConstantManager * const_mgr,spv::Op opcode,const analysis::Constant * input1,const analysis::Constant * input2)521 uint32_t PerformFloatingPointOperation(analysis::ConstantManager* const_mgr,
522                                        spv::Op opcode,
523                                        const analysis::Constant* input1,
524                                        const analysis::Constant* input2) {
525   const analysis::Type* type = input1->type();
526   assert(type->AsFloat());
527   uint32_t width = type->AsFloat()->width();
528   assert(width == 32 || width == 64);
529   std::vector<uint32_t> words;
530 #define FOLD_OP(op)                                                          \
531   if (width == 64) {                                                         \
532     utils::FloatProxy<double> val =                                          \
533         input1->GetDouble() op input2->GetDouble();                          \
534     double dval = val.getAsFloat();                                          \
535     if (!IsValidResult(dval)) return 0;                                      \
536     words = val.GetWords();                                                  \
537   } else {                                                                   \
538     utils::FloatProxy<float> val = input1->GetFloat() op input2->GetFloat(); \
539     float fval = val.getAsFloat();                                           \
540     if (!IsValidResult(fval)) return 0;                                      \
541     words = val.GetWords();                                                  \
542   }                                                                          \
543   static_assert(true, "require extra semicolon")
544   switch (opcode) {
545     case spv::Op::OpFMul:
546       FOLD_OP(*);
547       break;
548     case spv::Op::OpFDiv:
549       if (HasZero(input2)) return 0;
550       FOLD_OP(/);
551       break;
552     case spv::Op::OpFAdd:
553       FOLD_OP(+);
554       break;
555     case spv::Op::OpFSub:
556       FOLD_OP(-);
557       break;
558     default:
559       assert(false && "Unexpected operation");
560       break;
561   }
562 #undef FOLD_OP
563   const analysis::Constant* merged_const = const_mgr->GetConstant(type, words);
564   return const_mgr->GetDefiningInstruction(merged_const)->result_id();
565 }
566 
567 // Performs |input1| |opcode| |input2| and returns the merged constant result
568 // id. Returns 0 if the result is not a valid value. The input types must be
569 // Integers.
PerformIntegerOperation(analysis::ConstantManager * const_mgr,spv::Op opcode,const analysis::Constant * input1,const analysis::Constant * input2)570 uint32_t PerformIntegerOperation(analysis::ConstantManager* const_mgr,
571                                  spv::Op opcode,
572                                  const analysis::Constant* input1,
573                                  const analysis::Constant* input2) {
574   assert(input1->type()->AsInteger());
575   const analysis::Integer* type = input1->type()->AsInteger();
576   uint32_t width = type->AsInteger()->width();
577   assert(width == 32 || width == 64);
578   std::vector<uint32_t> words;
579   // Regardless of the sign of the constant, folding is performed on an unsigned
580   // interpretation of the constant data. This avoids signed integer overflow
581   // while folding, and works because sign is irrelevant for the IAdd, ISub and
582   // IMul instructions.
583 #define FOLD_OP(op)                                      \
584   if (width == 64) {                                     \
585     uint64_t val = input1->GetU64() op input2->GetU64(); \
586     words = ExtractInts(val);                            \
587   } else {                                               \
588     uint32_t val = input1->GetU32() op input2->GetU32(); \
589     words.push_back(val);                                \
590   }                                                      \
591   static_assert(true, "require extra semicolon")
592   switch (opcode) {
593     case spv::Op::OpIMul:
594       FOLD_OP(*);
595       break;
596     case spv::Op::OpSDiv:
597     case spv::Op::OpUDiv:
598       assert(false && "Should not merge integer division");
599       break;
600     case spv::Op::OpIAdd:
601       FOLD_OP(+);
602       break;
603     case spv::Op::OpISub:
604       FOLD_OP(-);
605       break;
606     default:
607       assert(false && "Unexpected operation");
608       break;
609   }
610 #undef FOLD_OP
611   const analysis::Constant* merged_const = const_mgr->GetConstant(type, words);
612   return const_mgr->GetDefiningInstruction(merged_const)->result_id();
613 }
614 
615 // Performs |input1| |opcode| |input2| and returns the merged constant result
616 // id. Returns 0 if the result is not a valid value. The input types must be
617 // Integers, Floats or Vectors of such.
PerformOperation(analysis::ConstantManager * const_mgr,spv::Op opcode,const analysis::Constant * input1,const analysis::Constant * input2)618 uint32_t PerformOperation(analysis::ConstantManager* const_mgr, spv::Op opcode,
619                           const analysis::Constant* input1,
620                           const analysis::Constant* input2) {
621   assert(input1 && input2);
622   const analysis::Type* type = input1->type();
623   std::vector<uint32_t> words;
624   if (const analysis::Vector* vector_type = type->AsVector()) {
625     const analysis::Type* ele_type = vector_type->element_type();
626     for (uint32_t i = 0; i != vector_type->element_count(); ++i) {
627       uint32_t id = 0;
628 
629       const analysis::Constant* input1_comp = nullptr;
630       if (const analysis::VectorConstant* input1_vector =
631               input1->AsVectorConstant()) {
632         input1_comp = input1_vector->GetComponents()[i];
633       } else {
634         assert(input1->AsNullConstant());
635         input1_comp = const_mgr->GetConstant(ele_type, {});
636       }
637 
638       const analysis::Constant* input2_comp = nullptr;
639       if (const analysis::VectorConstant* input2_vector =
640               input2->AsVectorConstant()) {
641         input2_comp = input2_vector->GetComponents()[i];
642       } else {
643         assert(input2->AsNullConstant());
644         input2_comp = const_mgr->GetConstant(ele_type, {});
645       }
646 
647       if (ele_type->AsFloat()) {
648         id = PerformFloatingPointOperation(const_mgr, opcode, input1_comp,
649                                            input2_comp);
650       } else {
651         assert(ele_type->AsInteger());
652         id = PerformIntegerOperation(const_mgr, opcode, input1_comp,
653                                      input2_comp);
654       }
655       if (id == 0) return 0;
656       words.push_back(id);
657     }
658     const analysis::Constant* merged_const =
659         const_mgr->GetConstant(type, words);
660     return const_mgr->GetDefiningInstruction(merged_const)->result_id();
661   } else if (type->AsFloat()) {
662     return PerformFloatingPointOperation(const_mgr, opcode, input1, input2);
663   } else {
664     assert(type->AsInteger());
665     return PerformIntegerOperation(const_mgr, opcode, input1, input2);
666   }
667 }
668 
669 // Merges consecutive multiplies where each contains one constant operand.
670 // Cases:
671 // 2 * (x * 2) = x * 4
672 // 2 * (2 * x) = x * 4
673 // (x * 2) * 2 = x * 4
674 // (2 * x) * 2 = x * 4
MergeMulMulArithmetic()675 FoldingRule MergeMulMulArithmetic() {
676   return [](IRContext* context, Instruction* inst,
677             const std::vector<const analysis::Constant*>& constants) {
678     assert(inst->opcode() == spv::Op::OpFMul ||
679            inst->opcode() == spv::Op::OpIMul);
680     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
681     const analysis::Type* type =
682         context->get_type_mgr()->GetType(inst->type_id());
683     if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
684       return false;
685 
686     uint32_t width = ElementWidth(type);
687     if (width != 32 && width != 64) return false;
688 
689     // Determine the constant input and the variable input in |inst|.
690     const analysis::Constant* const_input1 = ConstInput(constants);
691     if (!const_input1) return false;
692     Instruction* other_inst = NonConstInput(context, constants[0], inst);
693     if (HasFloatingPoint(type) && !other_inst->IsFloatingPointFoldingAllowed())
694       return false;
695 
696     if (other_inst->opcode() == inst->opcode()) {
697       std::vector<const analysis::Constant*> other_constants =
698           const_mgr->GetOperandConstants(other_inst);
699       const analysis::Constant* const_input2 = ConstInput(other_constants);
700       if (!const_input2) return false;
701 
702       bool other_first_is_variable = other_constants[0] == nullptr;
703       uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
704                                             const_input1, const_input2);
705       if (merged_id == 0) return false;
706 
707       uint32_t non_const_id = other_first_is_variable
708                                   ? other_inst->GetSingleWordInOperand(0u)
709                                   : other_inst->GetSingleWordInOperand(1u);
710       inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
711                            {SPV_OPERAND_TYPE_ID, {merged_id}}});
712       return true;
713     }
714 
715     return false;
716   };
717 }
718 
719 // Merges divides into subsequent multiplies if each instruction contains one
720 // constant operand. Does not support integer operations.
721 // Cases:
722 // 2 * (x / 2) = x * 1
723 // 2 * (2 / x) = 4 / x
724 // (x / 2) * 2 = x * 1
725 // (2 / x) * 2 = 4 / x
726 // (y / x) * x = y
727 // x * (y / x) = y
MergeMulDivArithmetic()728 FoldingRule MergeMulDivArithmetic() {
729   return [](IRContext* context, Instruction* inst,
730             const std::vector<const analysis::Constant*>& constants) {
731     assert(inst->opcode() == spv::Op::OpFMul);
732     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
733     analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
734 
735     const analysis::Type* type =
736         context->get_type_mgr()->GetType(inst->type_id());
737     if (!inst->IsFloatingPointFoldingAllowed()) return false;
738 
739     uint32_t width = ElementWidth(type);
740     if (width != 32 && width != 64) return false;
741 
742     for (uint32_t i = 0; i < 2; i++) {
743       uint32_t op_id = inst->GetSingleWordInOperand(i);
744       Instruction* op_inst = def_use_mgr->GetDef(op_id);
745       if (op_inst->opcode() == spv::Op::OpFDiv) {
746         if (op_inst->GetSingleWordInOperand(1) ==
747             inst->GetSingleWordInOperand(1 - i)) {
748           inst->SetOpcode(spv::Op::OpCopyObject);
749           inst->SetInOperands(
750               {{SPV_OPERAND_TYPE_ID, {op_inst->GetSingleWordInOperand(0)}}});
751           return true;
752         }
753       }
754     }
755 
756     const analysis::Constant* const_input1 = ConstInput(constants);
757     if (!const_input1) return false;
758     Instruction* other_inst = NonConstInput(context, constants[0], inst);
759     if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
760 
761     if (other_inst->opcode() == spv::Op::OpFDiv) {
762       std::vector<const analysis::Constant*> other_constants =
763           const_mgr->GetOperandConstants(other_inst);
764       const analysis::Constant* const_input2 = ConstInput(other_constants);
765       if (!const_input2 || HasZero(const_input2)) return false;
766 
767       bool other_first_is_variable = other_constants[0] == nullptr;
768       // If the variable value is the second operand of the divide, multiply
769       // the constants together. Otherwise divide the constants.
770       uint32_t merged_id = PerformOperation(
771           const_mgr,
772           other_first_is_variable ? other_inst->opcode() : inst->opcode(),
773           const_input1, const_input2);
774       if (merged_id == 0) return false;
775 
776       uint32_t non_const_id = other_first_is_variable
777                                   ? other_inst->GetSingleWordInOperand(0u)
778                                   : other_inst->GetSingleWordInOperand(1u);
779 
780       // If the variable value is on the second operand of the div, then this
781       // operation is a div. Otherwise it should be a multiply.
782       inst->SetOpcode(other_first_is_variable ? inst->opcode()
783                                               : other_inst->opcode());
784       if (other_first_is_variable) {
785         inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
786                              {SPV_OPERAND_TYPE_ID, {merged_id}}});
787       } else {
788         inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {merged_id}},
789                              {SPV_OPERAND_TYPE_ID, {non_const_id}}});
790       }
791       return true;
792     }
793 
794     return false;
795   };
796 }
797 
798 // Merges multiply of constant and negation.
799 // Cases:
800 // (-x) * 2 = x * -2
801 // 2 * (-x) = x * -2
MergeMulNegateArithmetic()802 FoldingRule MergeMulNegateArithmetic() {
803   return [](IRContext* context, Instruction* inst,
804             const std::vector<const analysis::Constant*>& constants) {
805     assert(inst->opcode() == spv::Op::OpFMul ||
806            inst->opcode() == spv::Op::OpIMul);
807     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
808     const analysis::Type* type =
809         context->get_type_mgr()->GetType(inst->type_id());
810     bool uses_float = HasFloatingPoint(type);
811     if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
812 
813     uint32_t width = ElementWidth(type);
814     if (width != 32 && width != 64) return false;
815 
816     const analysis::Constant* const_input1 = ConstInput(constants);
817     if (!const_input1) return false;
818     Instruction* other_inst = NonConstInput(context, constants[0], inst);
819     if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
820       return false;
821 
822     if (other_inst->opcode() == spv::Op::OpFNegate ||
823         other_inst->opcode() == spv::Op::OpSNegate) {
824       uint32_t neg_id = NegateConstant(const_mgr, const_input1);
825 
826       inst->SetInOperands(
827           {{SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}},
828            {SPV_OPERAND_TYPE_ID, {neg_id}}});
829       return true;
830     }
831 
832     return false;
833   };
834 }
835 
836 // Merges consecutive divides if each instruction contains one constant operand.
837 // Does not support integer division.
838 // Cases:
839 // 2 / (x / 2) = 4 / x
840 // 4 / (2 / x) = 2 * x
841 // (4 / x) / 2 = 2 / x
842 // (x / 2) / 2 = x / 4
MergeDivDivArithmetic()843 FoldingRule MergeDivDivArithmetic() {
844   return [](IRContext* context, Instruction* inst,
845             const std::vector<const analysis::Constant*>& constants) {
846     assert(inst->opcode() == spv::Op::OpFDiv);
847     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
848     const analysis::Type* type =
849         context->get_type_mgr()->GetType(inst->type_id());
850     if (!inst->IsFloatingPointFoldingAllowed()) return false;
851 
852     uint32_t width = ElementWidth(type);
853     if (width != 32 && width != 64) return false;
854 
855     const analysis::Constant* const_input1 = ConstInput(constants);
856     if (!const_input1 || HasZero(const_input1)) return false;
857     Instruction* other_inst = NonConstInput(context, constants[0], inst);
858     if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
859 
860     bool first_is_variable = constants[0] == nullptr;
861     if (other_inst->opcode() == inst->opcode()) {
862       std::vector<const analysis::Constant*> other_constants =
863           const_mgr->GetOperandConstants(other_inst);
864       const analysis::Constant* const_input2 = ConstInput(other_constants);
865       if (!const_input2 || HasZero(const_input2)) return false;
866 
867       bool other_first_is_variable = other_constants[0] == nullptr;
868 
869       spv::Op merge_op = inst->opcode();
870       if (other_first_is_variable) {
871         // Constants magnify.
872         merge_op = spv::Op::OpFMul;
873       }
874 
875       // This is an x / (*) case. Swap the inputs. Doesn't harm multiply
876       // because it is commutative.
877       if (first_is_variable) std::swap(const_input1, const_input2);
878       uint32_t merged_id =
879           PerformOperation(const_mgr, merge_op, const_input1, const_input2);
880       if (merged_id == 0) return false;
881 
882       uint32_t non_const_id = other_first_is_variable
883                                   ? other_inst->GetSingleWordInOperand(0u)
884                                   : other_inst->GetSingleWordInOperand(1u);
885 
886       spv::Op op = inst->opcode();
887       if (!first_is_variable && !other_first_is_variable) {
888         // Effectively div of 1/x, so change to multiply.
889         op = spv::Op::OpFMul;
890       }
891 
892       uint32_t op1 = merged_id;
893       uint32_t op2 = non_const_id;
894       if (first_is_variable && other_first_is_variable) std::swap(op1, op2);
895       inst->SetOpcode(op);
896       inst->SetInOperands(
897           {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
898       return true;
899     }
900 
901     return false;
902   };
903 }
904 
905 // Fold multiplies succeeded by divides where each instruction contains a
906 // constant operand. Does not support integer divide.
907 // Cases:
908 // 4 / (x * 2) = 2 / x
909 // 4 / (2 * x) = 2 / x
910 // (x * 4) / 2 = x * 2
911 // (4 * x) / 2 = x * 2
912 // (x * y) / x = y
913 // (y * x) / x = y
MergeDivMulArithmetic()914 FoldingRule MergeDivMulArithmetic() {
915   return [](IRContext* context, Instruction* inst,
916             const std::vector<const analysis::Constant*>& constants) {
917     assert(inst->opcode() == spv::Op::OpFDiv);
918     analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
919     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
920 
921     const analysis::Type* type =
922         context->get_type_mgr()->GetType(inst->type_id());
923     if (!inst->IsFloatingPointFoldingAllowed()) return false;
924 
925     uint32_t width = ElementWidth(type);
926     if (width != 32 && width != 64) return false;
927 
928     uint32_t op_id = inst->GetSingleWordInOperand(0);
929     Instruction* op_inst = def_use_mgr->GetDef(op_id);
930 
931     if (op_inst->opcode() == spv::Op::OpFMul) {
932       for (uint32_t i = 0; i < 2; i++) {
933         if (op_inst->GetSingleWordInOperand(i) ==
934             inst->GetSingleWordInOperand(1)) {
935           inst->SetOpcode(spv::Op::OpCopyObject);
936           inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
937                                 {op_inst->GetSingleWordInOperand(1 - i)}}});
938           return true;
939         }
940       }
941     }
942 
943     const analysis::Constant* const_input1 = ConstInput(constants);
944     if (!const_input1 || HasZero(const_input1)) return false;
945     Instruction* other_inst = NonConstInput(context, constants[0], inst);
946     if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
947 
948     bool first_is_variable = constants[0] == nullptr;
949     if (other_inst->opcode() == spv::Op::OpFMul) {
950       std::vector<const analysis::Constant*> other_constants =
951           const_mgr->GetOperandConstants(other_inst);
952       const analysis::Constant* const_input2 = ConstInput(other_constants);
953       if (!const_input2) return false;
954 
955       bool other_first_is_variable = other_constants[0] == nullptr;
956 
957       // This is an x / (*) case. Swap the inputs.
958       if (first_is_variable) std::swap(const_input1, const_input2);
959       uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
960                                             const_input1, const_input2);
961       if (merged_id == 0) return false;
962 
963       uint32_t non_const_id = other_first_is_variable
964                                   ? other_inst->GetSingleWordInOperand(0u)
965                                   : other_inst->GetSingleWordInOperand(1u);
966 
967       uint32_t op1 = merged_id;
968       uint32_t op2 = non_const_id;
969       if (first_is_variable) std::swap(op1, op2);
970 
971       // Convert to multiply
972       if (first_is_variable) inst->SetOpcode(other_inst->opcode());
973       inst->SetInOperands(
974           {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
975       return true;
976     }
977 
978     return false;
979   };
980 }
981 
982 // Fold divides of a constant and a negation.
983 // Cases:
984 // (-x) / 2 = x / -2
985 // 2 / (-x) = -2 / x
MergeDivNegateArithmetic()986 FoldingRule MergeDivNegateArithmetic() {
987   return [](IRContext* context, Instruction* inst,
988             const std::vector<const analysis::Constant*>& constants) {
989     assert(inst->opcode() == spv::Op::OpFDiv);
990     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
991     if (!inst->IsFloatingPointFoldingAllowed()) return false;
992 
993     const analysis::Constant* const_input1 = ConstInput(constants);
994     if (!const_input1) return false;
995     Instruction* other_inst = NonConstInput(context, constants[0], inst);
996     if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
997 
998     bool first_is_variable = constants[0] == nullptr;
999     if (other_inst->opcode() == spv::Op::OpFNegate) {
1000       uint32_t neg_id = NegateConstant(const_mgr, const_input1);
1001 
1002       if (first_is_variable) {
1003         inst->SetInOperands(
1004             {{SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}},
1005              {SPV_OPERAND_TYPE_ID, {neg_id}}});
1006       } else {
1007         inst->SetInOperands(
1008             {{SPV_OPERAND_TYPE_ID, {neg_id}},
1009              {SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}});
1010       }
1011       return true;
1012     }
1013 
1014     return false;
1015   };
1016 }
1017 
1018 // Folds addition of a constant and a negation.
1019 // Cases:
1020 // (-x) + 2 = 2 - x
1021 // 2 + (-x) = 2 - x
MergeAddNegateArithmetic()1022 FoldingRule MergeAddNegateArithmetic() {
1023   return [](IRContext* context, Instruction* inst,
1024             const std::vector<const analysis::Constant*>& constants) {
1025     assert(inst->opcode() == spv::Op::OpFAdd ||
1026            inst->opcode() == spv::Op::OpIAdd);
1027     const analysis::Type* type =
1028         context->get_type_mgr()->GetType(inst->type_id());
1029     bool uses_float = HasFloatingPoint(type);
1030     if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1031 
1032     const analysis::Constant* const_input1 = ConstInput(constants);
1033     if (!const_input1) return false;
1034     Instruction* other_inst = NonConstInput(context, constants[0], inst);
1035     if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1036       return false;
1037 
1038     if (other_inst->opcode() == spv::Op::OpSNegate ||
1039         other_inst->opcode() == spv::Op::OpFNegate) {
1040       inst->SetOpcode(HasFloatingPoint(type) ? spv::Op::OpFSub
1041                                              : spv::Op::OpISub);
1042       uint32_t const_id = constants[0] ? inst->GetSingleWordInOperand(0u)
1043                                        : inst->GetSingleWordInOperand(1u);
1044       inst->SetInOperands(
1045           {{SPV_OPERAND_TYPE_ID, {const_id}},
1046            {SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}});
1047       return true;
1048     }
1049     return false;
1050   };
1051 }
1052 
1053 // Folds subtraction of a constant and a negation.
1054 // Cases:
1055 // (-x) - 2 = -2 - x
1056 // 2 - (-x) = x + 2
MergeSubNegateArithmetic()1057 FoldingRule MergeSubNegateArithmetic() {
1058   return [](IRContext* context, Instruction* inst,
1059             const std::vector<const analysis::Constant*>& constants) {
1060     assert(inst->opcode() == spv::Op::OpFSub ||
1061            inst->opcode() == spv::Op::OpISub);
1062     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1063     const analysis::Type* type =
1064         context->get_type_mgr()->GetType(inst->type_id());
1065     bool uses_float = HasFloatingPoint(type);
1066     if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1067 
1068     uint32_t width = ElementWidth(type);
1069     if (width != 32 && width != 64) return false;
1070 
1071     const analysis::Constant* const_input1 = ConstInput(constants);
1072     if (!const_input1) return false;
1073     Instruction* other_inst = NonConstInput(context, constants[0], inst);
1074     if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1075       return false;
1076 
1077     if (other_inst->opcode() == spv::Op::OpSNegate ||
1078         other_inst->opcode() == spv::Op::OpFNegate) {
1079       uint32_t op1 = 0;
1080       uint32_t op2 = 0;
1081       spv::Op opcode = inst->opcode();
1082       if (constants[0] != nullptr) {
1083         op1 = other_inst->GetSingleWordInOperand(0u);
1084         op2 = inst->GetSingleWordInOperand(0u);
1085         opcode = HasFloatingPoint(type) ? spv::Op::OpFAdd : spv::Op::OpIAdd;
1086       } else {
1087         op1 = NegateConstant(const_mgr, const_input1);
1088         op2 = other_inst->GetSingleWordInOperand(0u);
1089       }
1090 
1091       inst->SetOpcode(opcode);
1092       inst->SetInOperands(
1093           {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
1094       return true;
1095     }
1096     return false;
1097   };
1098 }
1099 
1100 // Folds addition of an addition where each operation has a constant operand.
1101 // Cases:
1102 // (x + 2) + 2 = x + 4
1103 // (2 + x) + 2 = x + 4
1104 // 2 + (x + 2) = x + 4
1105 // 2 + (2 + x) = x + 4
MergeAddAddArithmetic()1106 FoldingRule MergeAddAddArithmetic() {
1107   return [](IRContext* context, Instruction* inst,
1108             const std::vector<const analysis::Constant*>& constants) {
1109     assert(inst->opcode() == spv::Op::OpFAdd ||
1110            inst->opcode() == spv::Op::OpIAdd);
1111     const analysis::Type* type =
1112         context->get_type_mgr()->GetType(inst->type_id());
1113     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1114     bool uses_float = HasFloatingPoint(type);
1115     if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1116 
1117     uint32_t width = ElementWidth(type);
1118     if (width != 32 && width != 64) return false;
1119 
1120     const analysis::Constant* const_input1 = ConstInput(constants);
1121     if (!const_input1) return false;
1122     Instruction* other_inst = NonConstInput(context, constants[0], inst);
1123     if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1124       return false;
1125 
1126     if (other_inst->opcode() == spv::Op::OpFAdd ||
1127         other_inst->opcode() == spv::Op::OpIAdd) {
1128       std::vector<const analysis::Constant*> other_constants =
1129           const_mgr->GetOperandConstants(other_inst);
1130       const analysis::Constant* const_input2 = ConstInput(other_constants);
1131       if (!const_input2) return false;
1132 
1133       Instruction* non_const_input =
1134           NonConstInput(context, other_constants[0], other_inst);
1135       uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
1136                                             const_input1, const_input2);
1137       if (merged_id == 0) return false;
1138 
1139       inst->SetInOperands(
1140           {{SPV_OPERAND_TYPE_ID, {non_const_input->result_id()}},
1141            {SPV_OPERAND_TYPE_ID, {merged_id}}});
1142       return true;
1143     }
1144     return false;
1145   };
1146 }
1147 
1148 // Folds addition of a subtraction where each operation has a constant operand.
1149 // Cases:
1150 // (x - 2) + 2 = x + 0
1151 // (2 - x) + 2 = 4 - x
1152 // 2 + (x - 2) = x + 0
1153 // 2 + (2 - x) = 4 - x
MergeAddSubArithmetic()1154 FoldingRule MergeAddSubArithmetic() {
1155   return [](IRContext* context, Instruction* inst,
1156             const std::vector<const analysis::Constant*>& constants) {
1157     assert(inst->opcode() == spv::Op::OpFAdd ||
1158            inst->opcode() == spv::Op::OpIAdd);
1159     const analysis::Type* type =
1160         context->get_type_mgr()->GetType(inst->type_id());
1161     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1162     bool uses_float = HasFloatingPoint(type);
1163     if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1164 
1165     uint32_t width = ElementWidth(type);
1166     if (width != 32 && width != 64) return false;
1167 
1168     const analysis::Constant* const_input1 = ConstInput(constants);
1169     if (!const_input1) return false;
1170     Instruction* other_inst = NonConstInput(context, constants[0], inst);
1171     if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1172       return false;
1173 
1174     if (other_inst->opcode() == spv::Op::OpFSub ||
1175         other_inst->opcode() == spv::Op::OpISub) {
1176       std::vector<const analysis::Constant*> other_constants =
1177           const_mgr->GetOperandConstants(other_inst);
1178       const analysis::Constant* const_input2 = ConstInput(other_constants);
1179       if (!const_input2) return false;
1180 
1181       bool first_is_variable = other_constants[0] == nullptr;
1182       spv::Op op = inst->opcode();
1183       uint32_t op1 = 0;
1184       uint32_t op2 = 0;
1185       if (first_is_variable) {
1186         // Subtract constants. Non-constant operand is first.
1187         op1 = other_inst->GetSingleWordInOperand(0u);
1188         op2 = PerformOperation(const_mgr, other_inst->opcode(), const_input1,
1189                                const_input2);
1190       } else {
1191         // Add constants. Constant operand is first. Change the opcode.
1192         op1 = PerformOperation(const_mgr, inst->opcode(), const_input1,
1193                                const_input2);
1194         op2 = other_inst->GetSingleWordInOperand(1u);
1195         op = other_inst->opcode();
1196       }
1197       if (op1 == 0 || op2 == 0) return false;
1198 
1199       inst->SetOpcode(op);
1200       inst->SetInOperands(
1201           {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
1202       return true;
1203     }
1204     return false;
1205   };
1206 }
1207 
1208 // Folds subtraction of an addition where each operand has a constant operand.
1209 // Cases:
1210 // (x + 2) - 2 = x + 0
1211 // (2 + x) - 2 = x + 0
1212 // 2 - (x + 2) = 0 - x
1213 // 2 - (2 + x) = 0 - x
MergeSubAddArithmetic()1214 FoldingRule MergeSubAddArithmetic() {
1215   return [](IRContext* context, Instruction* inst,
1216             const std::vector<const analysis::Constant*>& constants) {
1217     assert(inst->opcode() == spv::Op::OpFSub ||
1218            inst->opcode() == spv::Op::OpISub);
1219     const analysis::Type* type =
1220         context->get_type_mgr()->GetType(inst->type_id());
1221     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1222     bool uses_float = HasFloatingPoint(type);
1223     if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1224 
1225     uint32_t width = ElementWidth(type);
1226     if (width != 32 && width != 64) return false;
1227 
1228     const analysis::Constant* const_input1 = ConstInput(constants);
1229     if (!const_input1) return false;
1230     Instruction* other_inst = NonConstInput(context, constants[0], inst);
1231     if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1232       return false;
1233 
1234     if (other_inst->opcode() == spv::Op::OpFAdd ||
1235         other_inst->opcode() == spv::Op::OpIAdd) {
1236       std::vector<const analysis::Constant*> other_constants =
1237           const_mgr->GetOperandConstants(other_inst);
1238       const analysis::Constant* const_input2 = ConstInput(other_constants);
1239       if (!const_input2) return false;
1240 
1241       Instruction* non_const_input =
1242           NonConstInput(context, other_constants[0], other_inst);
1243 
1244       // If the first operand of the sub is not a constant, swap the constants
1245       // so the subtraction has the correct operands.
1246       if (constants[0] == nullptr) std::swap(const_input1, const_input2);
1247       // Subtract the constants.
1248       uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
1249                                             const_input1, const_input2);
1250       spv::Op op = inst->opcode();
1251       uint32_t op1 = 0;
1252       uint32_t op2 = 0;
1253       if (constants[0] == nullptr) {
1254         // Non-constant operand is first. Change the opcode.
1255         op1 = non_const_input->result_id();
1256         op2 = merged_id;
1257         op = other_inst->opcode();
1258       } else {
1259         // Constant operand is first.
1260         op1 = merged_id;
1261         op2 = non_const_input->result_id();
1262       }
1263       if (op1 == 0 || op2 == 0) return false;
1264 
1265       inst->SetOpcode(op);
1266       inst->SetInOperands(
1267           {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
1268       return true;
1269     }
1270     return false;
1271   };
1272 }
1273 
1274 // Folds subtraction of a subtraction where each operand has a constant operand.
1275 // Cases:
1276 // (x - 2) - 2 = x - 4
1277 // (2 - x) - 2 = 0 - x
1278 // 2 - (x - 2) = 4 - x
1279 // 2 - (2 - x) = x + 0
MergeSubSubArithmetic()1280 FoldingRule MergeSubSubArithmetic() {
1281   return [](IRContext* context, Instruction* inst,
1282             const std::vector<const analysis::Constant*>& constants) {
1283     assert(inst->opcode() == spv::Op::OpFSub ||
1284            inst->opcode() == spv::Op::OpISub);
1285     const analysis::Type* type =
1286         context->get_type_mgr()->GetType(inst->type_id());
1287     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1288     bool uses_float = HasFloatingPoint(type);
1289     if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1290 
1291     uint32_t width = ElementWidth(type);
1292     if (width != 32 && width != 64) return false;
1293 
1294     const analysis::Constant* const_input1 = ConstInput(constants);
1295     if (!const_input1) return false;
1296     Instruction* other_inst = NonConstInput(context, constants[0], inst);
1297     if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1298       return false;
1299 
1300     if (other_inst->opcode() == spv::Op::OpFSub ||
1301         other_inst->opcode() == spv::Op::OpISub) {
1302       std::vector<const analysis::Constant*> other_constants =
1303           const_mgr->GetOperandConstants(other_inst);
1304       const analysis::Constant* const_input2 = ConstInput(other_constants);
1305       if (!const_input2) return false;
1306 
1307       Instruction* non_const_input =
1308           NonConstInput(context, other_constants[0], other_inst);
1309 
1310       // Merge the constants.
1311       uint32_t merged_id = 0;
1312       spv::Op merge_op = inst->opcode();
1313       if (other_constants[0] == nullptr) {
1314         merge_op = uses_float ? spv::Op::OpFAdd : spv::Op::OpIAdd;
1315       } else if (constants[0] == nullptr) {
1316         std::swap(const_input1, const_input2);
1317       }
1318       merged_id =
1319           PerformOperation(const_mgr, merge_op, const_input1, const_input2);
1320       if (merged_id == 0) return false;
1321 
1322       spv::Op op = inst->opcode();
1323       if (constants[0] != nullptr && other_constants[0] != nullptr) {
1324         // Change the operation.
1325         op = uses_float ? spv::Op::OpFAdd : spv::Op::OpIAdd;
1326       }
1327 
1328       uint32_t op1 = 0;
1329       uint32_t op2 = 0;
1330       if ((constants[0] == nullptr) ^ (other_constants[0] == nullptr)) {
1331         op1 = merged_id;
1332         op2 = non_const_input->result_id();
1333       } else {
1334         op1 = non_const_input->result_id();
1335         op2 = merged_id;
1336       }
1337 
1338       inst->SetOpcode(op);
1339       inst->SetInOperands(
1340           {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
1341       return true;
1342     }
1343     return false;
1344   };
1345 }
1346 
1347 // Helper function for MergeGenericAddSubArithmetic. If |addend| and
1348 // subtrahend of |sub| is the same, merge to copy of minuend of |sub|.
MergeGenericAddendSub(uint32_t addend,uint32_t sub,Instruction * inst)1349 bool MergeGenericAddendSub(uint32_t addend, uint32_t sub, Instruction* inst) {
1350   IRContext* context = inst->context();
1351   analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1352   Instruction* sub_inst = def_use_mgr->GetDef(sub);
1353   if (sub_inst->opcode() != spv::Op::OpFSub &&
1354       sub_inst->opcode() != spv::Op::OpISub)
1355     return false;
1356   if (sub_inst->opcode() == spv::Op::OpFSub &&
1357       !sub_inst->IsFloatingPointFoldingAllowed())
1358     return false;
1359   if (addend != sub_inst->GetSingleWordInOperand(1)) return false;
1360   inst->SetOpcode(spv::Op::OpCopyObject);
1361   inst->SetInOperands(
1362       {{SPV_OPERAND_TYPE_ID, {sub_inst->GetSingleWordInOperand(0)}}});
1363   context->UpdateDefUse(inst);
1364   return true;
1365 }
1366 
1367 // Folds addition of a subtraction where the subtrahend is equal to the
1368 // other addend. Return a copy of the minuend. Accepts generic (const and
1369 // non-const) operands.
1370 // Cases:
1371 // (a - b) + b = a
1372 // b + (a - b) = a
MergeGenericAddSubArithmetic()1373 FoldingRule MergeGenericAddSubArithmetic() {
1374   return [](IRContext* context, Instruction* inst,
1375             const std::vector<const analysis::Constant*>&) {
1376     assert(inst->opcode() == spv::Op::OpFAdd ||
1377            inst->opcode() == spv::Op::OpIAdd);
1378     const analysis::Type* type =
1379         context->get_type_mgr()->GetType(inst->type_id());
1380     bool uses_float = HasFloatingPoint(type);
1381     if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1382 
1383     uint32_t width = ElementWidth(type);
1384     if (width != 32 && width != 64) return false;
1385 
1386     uint32_t add_op0 = inst->GetSingleWordInOperand(0);
1387     uint32_t add_op1 = inst->GetSingleWordInOperand(1);
1388     if (MergeGenericAddendSub(add_op0, add_op1, inst)) return true;
1389     return MergeGenericAddendSub(add_op1, add_op0, inst);
1390   };
1391 }
1392 
1393 // Helper function for FactorAddMuls. If |factor0_0| is the same as |factor1_0|,
1394 // generate |factor0_0| * (|factor0_1| + |factor1_1|).
FactorAddMulsOpnds(uint32_t factor0_0,uint32_t factor0_1,uint32_t factor1_0,uint32_t factor1_1,Instruction * inst)1395 bool FactorAddMulsOpnds(uint32_t factor0_0, uint32_t factor0_1,
1396                         uint32_t factor1_0, uint32_t factor1_1,
1397                         Instruction* inst) {
1398   IRContext* context = inst->context();
1399   if (factor0_0 != factor1_0) return false;
1400   InstructionBuilder ir_builder(
1401       context, inst,
1402       IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
1403   Instruction* new_add_inst = ir_builder.AddBinaryOp(
1404       inst->type_id(), inst->opcode(), factor0_1, factor1_1);
1405   inst->SetOpcode(inst->opcode() == spv::Op::OpFAdd ? spv::Op::OpFMul
1406                                                     : spv::Op::OpIMul);
1407   inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {factor0_0}},
1408                        {SPV_OPERAND_TYPE_ID, {new_add_inst->result_id()}}});
1409   context->UpdateDefUse(inst);
1410   return true;
1411 }
1412 
1413 // Perform the following factoring identity, handling all operand order
1414 // combinations: (a * b) + (a * c) = a * (b + c)
FactorAddMuls()1415 FoldingRule FactorAddMuls() {
1416   return [](IRContext* context, Instruction* inst,
1417             const std::vector<const analysis::Constant*>&) {
1418     assert(inst->opcode() == spv::Op::OpFAdd ||
1419            inst->opcode() == spv::Op::OpIAdd);
1420     const analysis::Type* type =
1421         context->get_type_mgr()->GetType(inst->type_id());
1422     bool uses_float = HasFloatingPoint(type);
1423     if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1424 
1425     analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1426     uint32_t add_op0 = inst->GetSingleWordInOperand(0);
1427     Instruction* add_op0_inst = def_use_mgr->GetDef(add_op0);
1428     if (add_op0_inst->opcode() != spv::Op::OpFMul &&
1429         add_op0_inst->opcode() != spv::Op::OpIMul)
1430       return false;
1431     uint32_t add_op1 = inst->GetSingleWordInOperand(1);
1432     Instruction* add_op1_inst = def_use_mgr->GetDef(add_op1);
1433     if (add_op1_inst->opcode() != spv::Op::OpFMul &&
1434         add_op1_inst->opcode() != spv::Op::OpIMul)
1435       return false;
1436 
1437     // Only perform this optimization if both of the muls only have one use.
1438     // Otherwise this is a deoptimization in size and performance.
1439     if (def_use_mgr->NumUses(add_op0_inst) > 1) return false;
1440     if (def_use_mgr->NumUses(add_op1_inst) > 1) return false;
1441 
1442     if (add_op0_inst->opcode() == spv::Op::OpFMul &&
1443         (!add_op0_inst->IsFloatingPointFoldingAllowed() ||
1444          !add_op1_inst->IsFloatingPointFoldingAllowed()))
1445       return false;
1446 
1447     for (int i = 0; i < 2; i++) {
1448       for (int j = 0; j < 2; j++) {
1449         // Check if operand i in add_op0_inst matches operand j in add_op1_inst.
1450         if (FactorAddMulsOpnds(add_op0_inst->GetSingleWordInOperand(i),
1451                                add_op0_inst->GetSingleWordInOperand(1 - i),
1452                                add_op1_inst->GetSingleWordInOperand(j),
1453                                add_op1_inst->GetSingleWordInOperand(1 - j),
1454                                inst))
1455           return true;
1456       }
1457     }
1458     return false;
1459   };
1460 }
1461 
IntMultipleBy1()1462 FoldingRule IntMultipleBy1() {
1463   return [](IRContext*, Instruction* inst,
1464             const std::vector<const analysis::Constant*>& constants) {
1465     assert(inst->opcode() == spv::Op::OpIMul &&
1466            "Wrong opcode.  Should be OpIMul.");
1467     for (uint32_t i = 0; i < 2; i++) {
1468       if (constants[i] == nullptr) {
1469         continue;
1470       }
1471       const analysis::IntConstant* int_constant = constants[i]->AsIntConstant();
1472       if (int_constant) {
1473         uint32_t width = ElementWidth(int_constant->type());
1474         if (width != 32 && width != 64) return false;
1475         bool is_one = (width == 32) ? int_constant->GetU32BitValue() == 1u
1476                                     : int_constant->GetU64BitValue() == 1ull;
1477         if (is_one) {
1478           inst->SetOpcode(spv::Op::OpCopyObject);
1479           inst->SetInOperands(
1480               {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1 - i)}}});
1481           return true;
1482         }
1483       }
1484     }
1485     return false;
1486   };
1487 }
1488 
1489 // Returns the number of elements that the |index|th in operand in |inst|
1490 // contributes to the result of |inst|.  |inst| must be an
1491 // OpCompositeConstructInstruction.
GetNumOfElementsContributedByOperand(IRContext * context,const Instruction * inst,uint32_t index)1492 uint32_t GetNumOfElementsContributedByOperand(IRContext* context,
1493                                               const Instruction* inst,
1494                                               uint32_t index) {
1495   assert(inst->opcode() == spv::Op::OpCompositeConstruct);
1496   analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1497   analysis::TypeManager* type_mgr = context->get_type_mgr();
1498 
1499   analysis::Vector* result_type =
1500       type_mgr->GetType(inst->type_id())->AsVector();
1501   if (result_type == nullptr) {
1502     // If the result of the OpCompositeConstruct is not a vector then every
1503     // operands corresponds to a single element in the result.
1504     return 1;
1505   }
1506 
1507   // If the result type is a vector then the operands are either scalars or
1508   // vectors. If it is a scalar, then it corresponds to a single element.  If it
1509   // is a vector, then each element in the vector will be an element in the
1510   // result.
1511   uint32_t id = inst->GetSingleWordInOperand(index);
1512   Instruction* def = def_use_mgr->GetDef(id);
1513   analysis::Vector* type = type_mgr->GetType(def->type_id())->AsVector();
1514   if (type == nullptr) {
1515     return 1;
1516   }
1517   return type->element_count();
1518 }
1519 
1520 // Returns the in-operands for an OpCompositeExtract instruction that are needed
1521 // to extract the |result_index|th element in the result of |inst| without using
1522 // the result of |inst|. Returns the empty vector if |result_index| is
1523 // out-of-bounds. |inst| must be an |OpCompositeConstruct| instruction.
GetExtractOperandsForElementOfCompositeConstruct(IRContext * context,const Instruction * inst,uint32_t result_index)1524 std::vector<Operand> GetExtractOperandsForElementOfCompositeConstruct(
1525     IRContext* context, const Instruction* inst, uint32_t result_index) {
1526   assert(inst->opcode() == spv::Op::OpCompositeConstruct);
1527   analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1528   analysis::TypeManager* type_mgr = context->get_type_mgr();
1529 
1530   analysis::Type* result_type = type_mgr->GetType(inst->type_id());
1531   if (result_type->AsVector() == nullptr) {
1532     if (result_index < inst->NumInOperands()) {
1533       uint32_t id = inst->GetSingleWordInOperand(result_index);
1534       return {Operand(SPV_OPERAND_TYPE_ID, {id})};
1535     }
1536     return {};
1537   }
1538 
1539   // If the result type is a vector, then vector operands are concatenated.
1540   uint32_t total_element_count = 0;
1541   for (uint32_t idx = 0; idx < inst->NumInOperands(); ++idx) {
1542     uint32_t element_count =
1543         GetNumOfElementsContributedByOperand(context, inst, idx);
1544     total_element_count += element_count;
1545     if (result_index < total_element_count) {
1546       std::vector<Operand> operands;
1547       uint32_t id = inst->GetSingleWordInOperand(idx);
1548       Instruction* operand_def = def_use_mgr->GetDef(id);
1549       analysis::Type* operand_type = type_mgr->GetType(operand_def->type_id());
1550 
1551       operands.push_back({SPV_OPERAND_TYPE_ID, {id}});
1552       if (operand_type->AsVector()) {
1553         uint32_t start_index_of_id = total_element_count - element_count;
1554         uint32_t index_into_id = result_index - start_index_of_id;
1555         operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {index_into_id}});
1556       }
1557       return operands;
1558     }
1559   }
1560   return {};
1561 }
1562 
CompositeConstructFeedingExtract(IRContext * context,Instruction * inst,const std::vector<const analysis::Constant * > &)1563 bool CompositeConstructFeedingExtract(
1564     IRContext* context, Instruction* inst,
1565     const std::vector<const analysis::Constant*>&) {
1566   // If the input to an OpCompositeExtract is an OpCompositeConstruct,
1567   // then we can simply use the appropriate element in the construction.
1568   assert(inst->opcode() == spv::Op::OpCompositeExtract &&
1569          "Wrong opcode.  Should be OpCompositeExtract.");
1570   analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1571 
1572   // If there are no index operands, then this rule cannot do anything.
1573   if (inst->NumInOperands() <= 1) {
1574     return false;
1575   }
1576 
1577   uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1578   Instruction* cinst = def_use_mgr->GetDef(cid);
1579 
1580   if (cinst->opcode() != spv::Op::OpCompositeConstruct) {
1581     return false;
1582   }
1583 
1584   uint32_t index_into_result = inst->GetSingleWordInOperand(1);
1585   std::vector<Operand> operands =
1586       GetExtractOperandsForElementOfCompositeConstruct(context, cinst,
1587                                                        index_into_result);
1588 
1589   if (operands.empty()) {
1590     return false;
1591   }
1592 
1593   // Add the remaining indices for extraction.
1594   for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
1595     operands.push_back(
1596         {SPV_OPERAND_TYPE_LITERAL_INTEGER, {inst->GetSingleWordInOperand(i)}});
1597   }
1598 
1599   if (operands.size() == 1) {
1600     // If there were no extra indices, then we have the final object.  No need
1601     // to extract any more.
1602     inst->SetOpcode(spv::Op::OpCopyObject);
1603   }
1604 
1605   inst->SetInOperands(std::move(operands));
1606   return true;
1607 }
1608 
1609 // Walks the indexes chain from |start| to |end| of an OpCompositeInsert or
1610 // OpCompositeExtract instruction, and returns the type id of the final element
1611 // being accessed. Returns 0 if a valid type could not be found.
GetElementType(uint32_t type_id,Instruction::iterator start,Instruction::iterator end,const analysis::DefUseManager * def_use_manager)1612 uint32_t GetElementType(uint32_t type_id, Instruction::iterator start,
1613                         Instruction::iterator end,
1614                         const analysis::DefUseManager* def_use_manager) {
1615   for (auto index : make_range(std::move(start), std::move(end))) {
1616     const Instruction* type_inst = def_use_manager->GetDef(type_id);
1617     assert(index.type == SPV_OPERAND_TYPE_LITERAL_INTEGER &&
1618            index.words.size() == 1);
1619     if (type_inst->opcode() == spv::Op::OpTypeArray) {
1620       type_id = type_inst->GetSingleWordInOperand(0);
1621     } else if (type_inst->opcode() == spv::Op::OpTypeMatrix) {
1622       type_id = type_inst->GetSingleWordInOperand(0);
1623     } else if (type_inst->opcode() == spv::Op::OpTypeStruct) {
1624       type_id = type_inst->GetSingleWordInOperand(index.words[0]);
1625     } else {
1626       return 0;
1627     }
1628   }
1629   return type_id;
1630 }
1631 
1632 // Returns true of |inst_1| and |inst_2| have the same indexes that will be used
1633 // to index into a composite object, excluding the last index.  The two
1634 // instructions must have the same opcode, and be either OpCompositeExtract or
1635 // OpCompositeInsert instructions.
HaveSameIndexesExceptForLast(Instruction * inst_1,Instruction * inst_2)1636 bool HaveSameIndexesExceptForLast(Instruction* inst_1, Instruction* inst_2) {
1637   assert(inst_1->opcode() == inst_2->opcode() &&
1638          "Expecting the opcodes to be the same.");
1639   assert((inst_1->opcode() == spv::Op::OpCompositeInsert ||
1640           inst_1->opcode() == spv::Op::OpCompositeExtract) &&
1641          "Instructions must be OpCompositeInsert or OpCompositeExtract.");
1642 
1643   if (inst_1->NumInOperands() != inst_2->NumInOperands()) {
1644     return false;
1645   }
1646 
1647   uint32_t first_index_position =
1648       (inst_1->opcode() == spv::Op::OpCompositeInsert ? 2 : 1);
1649   for (uint32_t i = first_index_position; i < inst_1->NumInOperands() - 1;
1650        i++) {
1651     if (inst_1->GetSingleWordInOperand(i) !=
1652         inst_2->GetSingleWordInOperand(i)) {
1653       return false;
1654     }
1655   }
1656   return true;
1657 }
1658 
1659 // If the OpCompositeConstruct is simply putting back together elements that
1660 // where extracted from the same source, we can simply reuse the source.
1661 //
1662 // This is a common code pattern because of the way that scalar replacement
1663 // works.
CompositeExtractFeedingConstruct(IRContext * context,Instruction * inst,const std::vector<const analysis::Constant * > &)1664 bool CompositeExtractFeedingConstruct(
1665     IRContext* context, Instruction* inst,
1666     const std::vector<const analysis::Constant*>&) {
1667   assert(inst->opcode() == spv::Op::OpCompositeConstruct &&
1668          "Wrong opcode.  Should be OpCompositeConstruct.");
1669   analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1670   uint32_t original_id = 0;
1671 
1672   if (inst->NumInOperands() == 0) {
1673     // The struct being constructed has no members.
1674     return false;
1675   }
1676 
1677   // Check each element to make sure they are:
1678   // - extractions
1679   // - extracting the same position they are inserting
1680   // - all extract from the same id.
1681   Instruction* first_element_inst = nullptr;
1682   for (uint32_t i = 0; i < inst->NumInOperands(); ++i) {
1683     const uint32_t element_id = inst->GetSingleWordInOperand(i);
1684     Instruction* element_inst = def_use_mgr->GetDef(element_id);
1685     if (first_element_inst == nullptr) {
1686       first_element_inst = element_inst;
1687     }
1688 
1689     if (element_inst->opcode() != spv::Op::OpCompositeExtract) {
1690       return false;
1691     }
1692 
1693     if (!HaveSameIndexesExceptForLast(element_inst, first_element_inst)) {
1694       return false;
1695     }
1696 
1697     if (element_inst->GetSingleWordInOperand(element_inst->NumInOperands() -
1698                                              1) != i) {
1699       return false;
1700     }
1701 
1702     if (i == 0) {
1703       original_id =
1704           element_inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1705     } else if (original_id !=
1706                element_inst->GetSingleWordInOperand(kExtractCompositeIdInIdx)) {
1707       return false;
1708     }
1709   }
1710 
1711   // The last check it to see that the object being extracted from is the
1712   // correct type.
1713   Instruction* original_inst = def_use_mgr->GetDef(original_id);
1714   uint32_t original_type_id =
1715       GetElementType(original_inst->type_id(), first_element_inst->begin() + 3,
1716                      first_element_inst->end() - 1, def_use_mgr);
1717 
1718   if (inst->type_id() != original_type_id) {
1719     return false;
1720   }
1721 
1722   if (first_element_inst->NumInOperands() == 2) {
1723     // Simplify by using the original object.
1724     inst->SetOpcode(spv::Op::OpCopyObject);
1725     inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {original_id}}});
1726     return true;
1727   }
1728 
1729   // Copies the original id and all indexes except for the last to the new
1730   // extract instruction.
1731   inst->SetOpcode(spv::Op::OpCompositeExtract);
1732   inst->SetInOperands(std::vector<Operand>(first_element_inst->begin() + 2,
1733                                            first_element_inst->end() - 1));
1734   return true;
1735 }
1736 
InsertFeedingExtract()1737 FoldingRule InsertFeedingExtract() {
1738   return [](IRContext* context, Instruction* inst,
1739             const std::vector<const analysis::Constant*>&) {
1740     assert(inst->opcode() == spv::Op::OpCompositeExtract &&
1741            "Wrong opcode.  Should be OpCompositeExtract.");
1742     analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1743     uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1744     Instruction* cinst = def_use_mgr->GetDef(cid);
1745 
1746     if (cinst->opcode() != spv::Op::OpCompositeInsert) {
1747       return false;
1748     }
1749 
1750     // Find the first position where the list of insert and extract indicies
1751     // differ, if at all.
1752     uint32_t i;
1753     for (i = 1; i < inst->NumInOperands(); ++i) {
1754       if (i + 1 >= cinst->NumInOperands()) {
1755         break;
1756       }
1757 
1758       if (inst->GetSingleWordInOperand(i) !=
1759           cinst->GetSingleWordInOperand(i + 1)) {
1760         break;
1761       }
1762     }
1763 
1764     // We are extracting the element that was inserted.
1765     if (i == inst->NumInOperands() && i + 1 == cinst->NumInOperands()) {
1766       inst->SetOpcode(spv::Op::OpCopyObject);
1767       inst->SetInOperands(
1768           {{SPV_OPERAND_TYPE_ID,
1769             {cinst->GetSingleWordInOperand(kInsertObjectIdInIdx)}}});
1770       return true;
1771     }
1772 
1773     // Extracting the value that was inserted along with values for the base
1774     // composite.  Cannot do anything.
1775     if (i == inst->NumInOperands()) {
1776       return false;
1777     }
1778 
1779     // Extracting an element of the value that was inserted.  Extract from
1780     // that value directly.
1781     if (i + 1 == cinst->NumInOperands()) {
1782       std::vector<Operand> operands;
1783       operands.push_back(
1784           {SPV_OPERAND_TYPE_ID,
1785            {cinst->GetSingleWordInOperand(kInsertObjectIdInIdx)}});
1786       for (; i < inst->NumInOperands(); ++i) {
1787         operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
1788                             {inst->GetSingleWordInOperand(i)}});
1789       }
1790       inst->SetInOperands(std::move(operands));
1791       return true;
1792     }
1793 
1794     // Extracting a value that is disjoint from the element being inserted.
1795     // Rewrite the extract to use the composite input to the insert.
1796     std::vector<Operand> operands;
1797     operands.push_back(
1798         {SPV_OPERAND_TYPE_ID,
1799          {cinst->GetSingleWordInOperand(kInsertCompositeIdInIdx)}});
1800     for (i = 1; i < inst->NumInOperands(); ++i) {
1801       operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
1802                           {inst->GetSingleWordInOperand(i)}});
1803     }
1804     inst->SetInOperands(std::move(operands));
1805     return true;
1806   };
1807 }
1808 
1809 // When a VectorShuffle is feeding an Extract, we can extract from one of the
1810 // operands of the VectorShuffle.  We just need to adjust the index in the
1811 // extract instruction.
VectorShuffleFeedingExtract()1812 FoldingRule VectorShuffleFeedingExtract() {
1813   return [](IRContext* context, Instruction* inst,
1814             const std::vector<const analysis::Constant*>&) {
1815     assert(inst->opcode() == spv::Op::OpCompositeExtract &&
1816            "Wrong opcode.  Should be OpCompositeExtract.");
1817     analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1818     analysis::TypeManager* type_mgr = context->get_type_mgr();
1819     uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1820     Instruction* cinst = def_use_mgr->GetDef(cid);
1821 
1822     if (cinst->opcode() != spv::Op::OpVectorShuffle) {
1823       return false;
1824     }
1825 
1826     // Find the size of the first vector operand of the VectorShuffle
1827     Instruction* first_input =
1828         def_use_mgr->GetDef(cinst->GetSingleWordInOperand(0));
1829     analysis::Type* first_input_type =
1830         type_mgr->GetType(first_input->type_id());
1831     assert(first_input_type->AsVector() &&
1832            "Input to vector shuffle should be vectors.");
1833     uint32_t first_input_size = first_input_type->AsVector()->element_count();
1834 
1835     // Get index of the element the vector shuffle is placing in the position
1836     // being extracted.
1837     uint32_t new_index =
1838         cinst->GetSingleWordInOperand(2 + inst->GetSingleWordInOperand(1));
1839 
1840     // Extracting an undefined value so fold this extract into an undef.
1841     const uint32_t undef_literal_value = 0xffffffff;
1842     if (new_index == undef_literal_value) {
1843       inst->SetOpcode(spv::Op::OpUndef);
1844       inst->SetInOperands({});
1845       return true;
1846     }
1847 
1848     // Get the id of the of the vector the elemtent comes from, and update the
1849     // index if needed.
1850     uint32_t new_vector = 0;
1851     if (new_index < first_input_size) {
1852       new_vector = cinst->GetSingleWordInOperand(0);
1853     } else {
1854       new_vector = cinst->GetSingleWordInOperand(1);
1855       new_index -= first_input_size;
1856     }
1857 
1858     // Update the extract instruction.
1859     inst->SetInOperand(kExtractCompositeIdInIdx, {new_vector});
1860     inst->SetInOperand(1, {new_index});
1861     return true;
1862   };
1863 }
1864 
1865 // When an FMix with is feeding an Extract that extracts an element whose
1866 // corresponding |a| in the FMix is 0 or 1, we can extract from one of the
1867 // operands of the FMix.
FMixFeedingExtract()1868 FoldingRule FMixFeedingExtract() {
1869   return [](IRContext* context, Instruction* inst,
1870             const std::vector<const analysis::Constant*>&) {
1871     assert(inst->opcode() == spv::Op::OpCompositeExtract &&
1872            "Wrong opcode.  Should be OpCompositeExtract.");
1873     analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1874     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1875 
1876     uint32_t composite_id =
1877         inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1878     Instruction* composite_inst = def_use_mgr->GetDef(composite_id);
1879 
1880     if (composite_inst->opcode() != spv::Op::OpExtInst) {
1881       return false;
1882     }
1883 
1884     uint32_t inst_set_id =
1885         context->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
1886 
1887     if (composite_inst->GetSingleWordInOperand(kExtInstSetIdInIdx) !=
1888             inst_set_id ||
1889         composite_inst->GetSingleWordInOperand(kExtInstInstructionInIdx) !=
1890             GLSLstd450FMix) {
1891       return false;
1892     }
1893 
1894     // Get the |a| for the FMix instruction.
1895     uint32_t a_id = composite_inst->GetSingleWordInOperand(kFMixAIdInIdx);
1896     std::unique_ptr<Instruction> a(inst->Clone(context));
1897     a->SetInOperand(kExtractCompositeIdInIdx, {a_id});
1898     context->get_instruction_folder().FoldInstruction(a.get());
1899 
1900     if (a->opcode() != spv::Op::OpCopyObject) {
1901       return false;
1902     }
1903 
1904     const analysis::Constant* a_const =
1905         const_mgr->FindDeclaredConstant(a->GetSingleWordInOperand(0));
1906 
1907     if (!a_const) {
1908       return false;
1909     }
1910 
1911     bool use_x = false;
1912 
1913     assert(a_const->type()->AsFloat());
1914     double element_value = a_const->GetValueAsDouble();
1915     if (element_value == 0.0) {
1916       use_x = true;
1917     } else if (element_value == 1.0) {
1918       use_x = false;
1919     } else {
1920       return false;
1921     }
1922 
1923     // Get the id of the of the vector the element comes from.
1924     uint32_t new_vector = 0;
1925     if (use_x) {
1926       new_vector = composite_inst->GetSingleWordInOperand(kFMixXIdInIdx);
1927     } else {
1928       new_vector = composite_inst->GetSingleWordInOperand(kFMixYIdInIdx);
1929     }
1930 
1931     // Update the extract instruction.
1932     inst->SetInOperand(kExtractCompositeIdInIdx, {new_vector});
1933     return true;
1934   };
1935 }
1936 
1937 // Returns the number of elements in the composite type |type|.  Returns 0 if
1938 // |type| is a scalar value. Return UINT32_MAX when the size is unknown at
1939 // compile time.
GetNumberOfElements(const analysis::Type * type)1940 uint32_t GetNumberOfElements(const analysis::Type* type) {
1941   if (auto* vector_type = type->AsVector()) {
1942     return vector_type->element_count();
1943   }
1944   if (auto* matrix_type = type->AsMatrix()) {
1945     return matrix_type->element_count();
1946   }
1947   if (auto* struct_type = type->AsStruct()) {
1948     return static_cast<uint32_t>(struct_type->element_types().size());
1949   }
1950   if (auto* array_type = type->AsArray()) {
1951     if (array_type->length_info().words[0] ==
1952             analysis::Array::LengthInfo::kConstant &&
1953         array_type->length_info().words.size() == 2) {
1954       return array_type->length_info().words[1];
1955     }
1956     return UINT32_MAX;
1957   }
1958   return 0;
1959 }
1960 
1961 // Returns a map with the set of values that were inserted into an object by
1962 // the chain of OpCompositeInsertInstruction starting with |inst|.
1963 // The map will map the index to the value inserted at that index. An empty map
1964 // will be returned if the map could not be properly generated.
GetInsertedValues(Instruction * inst)1965 std::map<uint32_t, uint32_t> GetInsertedValues(Instruction* inst) {
1966   analysis::DefUseManager* def_use_mgr = inst->context()->get_def_use_mgr();
1967   std::map<uint32_t, uint32_t> values_inserted;
1968   Instruction* current_inst = inst;
1969   while (current_inst->opcode() == spv::Op::OpCompositeInsert) {
1970     if (current_inst->NumInOperands() > inst->NumInOperands()) {
1971       // This is to catch the case
1972       //   %2 = OpCompositeInsert %m2x2int %v2int_1_0 %m2x2int_undef 0
1973       //   %3 = OpCompositeInsert %m2x2int %int_4 %2 0 0
1974       //   %4 = OpCompositeInsert %m2x2int %v2int_2_3 %3 1
1975       // In this case we cannot do a single construct to get the matrix.
1976       uint32_t partially_inserted_element_index =
1977           current_inst->GetSingleWordInOperand(inst->NumInOperands() - 1);
1978       if (values_inserted.count(partially_inserted_element_index) == 0)
1979         return {};
1980     }
1981     if (HaveSameIndexesExceptForLast(inst, current_inst)) {
1982       values_inserted.insert(
1983           {current_inst->GetSingleWordInOperand(current_inst->NumInOperands() -
1984                                                 1),
1985            current_inst->GetSingleWordInOperand(kInsertObjectIdInIdx)});
1986     }
1987     current_inst = def_use_mgr->GetDef(
1988         current_inst->GetSingleWordInOperand(kInsertCompositeIdInIdx));
1989   }
1990   return values_inserted;
1991 }
1992 
1993 // Returns true of there is an entry in |values_inserted| for every element of
1994 // |Type|.
DoInsertedValuesCoverEntireObject(const analysis::Type * type,std::map<uint32_t,uint32_t> & values_inserted)1995 bool DoInsertedValuesCoverEntireObject(
1996     const analysis::Type* type, std::map<uint32_t, uint32_t>& values_inserted) {
1997   uint32_t container_size = GetNumberOfElements(type);
1998   if (container_size != values_inserted.size()) {
1999     return false;
2000   }
2001 
2002   if (values_inserted.rbegin()->first >= container_size) {
2003     return false;
2004   }
2005   return true;
2006 }
2007 
2008 // Returns id of the type of the element that immediately contains the element
2009 // being inserted by the OpCompositeInsert instruction |inst|. Returns 0 if it
2010 // could not be found.
GetContainerTypeId(Instruction * inst)2011 uint32_t GetContainerTypeId(Instruction* inst) {
2012   assert(inst->opcode() == spv::Op::OpCompositeInsert);
2013   analysis::DefUseManager* def_use_manager = inst->context()->get_def_use_mgr();
2014   uint32_t container_type_id = GetElementType(
2015       inst->type_id(), inst->begin() + 4, inst->end() - 1, def_use_manager);
2016   return container_type_id;
2017 }
2018 
2019 // Returns an OpCompositeConstruct instruction that build an object with
2020 // |type_id| out of the values in |values_inserted|.  Each value will be
2021 // placed at the index corresponding to the value.  The new instruction will
2022 // be placed before |insert_before|.
BuildCompositeConstruct(uint32_t type_id,const std::map<uint32_t,uint32_t> & values_inserted,Instruction * insert_before)2023 Instruction* BuildCompositeConstruct(
2024     uint32_t type_id, const std::map<uint32_t, uint32_t>& values_inserted,
2025     Instruction* insert_before) {
2026   InstructionBuilder ir_builder(
2027       insert_before->context(), insert_before,
2028       IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
2029 
2030   std::vector<uint32_t> ids_in_order;
2031   for (auto it : values_inserted) {
2032     ids_in_order.push_back(it.second);
2033   }
2034   Instruction* construct =
2035       ir_builder.AddCompositeConstruct(type_id, ids_in_order);
2036   return construct;
2037 }
2038 
2039 // Replaces the OpCompositeInsert |inst| that inserts |construct| into the same
2040 // object as |inst| with final index removed.  If the resulting
2041 // OpCompositeInsert instruction would have no remaining indexes, the
2042 // instruction is replaced with an OpCopyObject instead.
InsertConstructedObject(Instruction * inst,const Instruction * construct)2043 void InsertConstructedObject(Instruction* inst, const Instruction* construct) {
2044   if (inst->NumInOperands() == 3) {
2045     inst->SetOpcode(spv::Op::OpCopyObject);
2046     inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {construct->result_id()}}});
2047   } else {
2048     inst->SetInOperand(kInsertObjectIdInIdx, {construct->result_id()});
2049     inst->RemoveOperand(inst->NumOperands() - 1);
2050   }
2051 }
2052 
2053 // Replaces a series of |OpCompositeInsert| instruction that cover the entire
2054 // object with an |OpCompositeConstruct|.
CompositeInsertToCompositeConstruct(IRContext * context,Instruction * inst,const std::vector<const analysis::Constant * > &)2055 bool CompositeInsertToCompositeConstruct(
2056     IRContext* context, Instruction* inst,
2057     const std::vector<const analysis::Constant*>&) {
2058   assert(inst->opcode() == spv::Op::OpCompositeInsert &&
2059          "Wrong opcode.  Should be OpCompositeInsert.");
2060   if (inst->NumInOperands() < 3) return false;
2061 
2062   std::map<uint32_t, uint32_t> values_inserted = GetInsertedValues(inst);
2063   uint32_t container_type_id = GetContainerTypeId(inst);
2064   if (container_type_id == 0) {
2065     return false;
2066   }
2067 
2068   analysis::TypeManager* type_mgr = context->get_type_mgr();
2069   const analysis::Type* container_type = type_mgr->GetType(container_type_id);
2070   assert(container_type && "GetContainerTypeId returned a bad id.");
2071   if (!DoInsertedValuesCoverEntireObject(container_type, values_inserted)) {
2072     return false;
2073   }
2074 
2075   Instruction* construct =
2076       BuildCompositeConstruct(container_type_id, values_inserted, inst);
2077   InsertConstructedObject(inst, construct);
2078   return true;
2079 }
2080 
RedundantPhi()2081 FoldingRule RedundantPhi() {
2082   // An OpPhi instruction where all values are the same or the result of the phi
2083   // itself, can be replaced by the value itself.
2084   return [](IRContext*, Instruction* inst,
2085             const std::vector<const analysis::Constant*>&) {
2086     assert(inst->opcode() == spv::Op::OpPhi &&
2087            "Wrong opcode.  Should be OpPhi.");
2088 
2089     uint32_t incoming_value = 0;
2090 
2091     for (uint32_t i = 0; i < inst->NumInOperands(); i += 2) {
2092       uint32_t op_id = inst->GetSingleWordInOperand(i);
2093       if (op_id == inst->result_id()) {
2094         continue;
2095       }
2096 
2097       if (incoming_value == 0) {
2098         incoming_value = op_id;
2099       } else if (op_id != incoming_value) {
2100         // Found two possible value.  Can't simplify.
2101         return false;
2102       }
2103     }
2104 
2105     if (incoming_value == 0) {
2106       // Code looks invalid.  Don't do anything.
2107       return false;
2108     }
2109 
2110     // We have a single incoming value.  Simplify using that value.
2111     inst->SetOpcode(spv::Op::OpCopyObject);
2112     inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {incoming_value}}});
2113     return true;
2114   };
2115 }
2116 
BitCastScalarOrVector()2117 FoldingRule BitCastScalarOrVector() {
2118   return [](IRContext* context, Instruction* inst,
2119             const std::vector<const analysis::Constant*>& constants) {
2120     assert(inst->opcode() == spv::Op::OpBitcast && constants.size() == 1);
2121     if (constants[0] == nullptr) return false;
2122 
2123     const analysis::Type* type =
2124         context->get_type_mgr()->GetType(inst->type_id());
2125     if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
2126       return false;
2127 
2128     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
2129     std::vector<uint32_t> words =
2130         GetWordsFromNumericScalarOrVectorConstant(const_mgr, constants[0]);
2131     if (words.size() == 0) return false;
2132 
2133     const analysis::Constant* bitcasted_constant =
2134         ConvertWordsToNumericScalarOrVectorConstant(const_mgr, words, type);
2135     if (!bitcasted_constant) return false;
2136 
2137     auto new_feeder_id =
2138         const_mgr->GetDefiningInstruction(bitcasted_constant, inst->type_id())
2139             ->result_id();
2140     inst->SetOpcode(spv::Op::OpCopyObject);
2141     inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {new_feeder_id}}});
2142     return true;
2143   };
2144 }
2145 
RedundantSelect()2146 FoldingRule RedundantSelect() {
2147   // An OpSelect instruction where both values are the same or the condition is
2148   // constant can be replaced by one of the values
2149   return [](IRContext*, Instruction* inst,
2150             const std::vector<const analysis::Constant*>& constants) {
2151     assert(inst->opcode() == spv::Op::OpSelect &&
2152            "Wrong opcode.  Should be OpSelect.");
2153     assert(inst->NumInOperands() == 3);
2154     assert(constants.size() == 3);
2155 
2156     uint32_t true_id = inst->GetSingleWordInOperand(1);
2157     uint32_t false_id = inst->GetSingleWordInOperand(2);
2158 
2159     if (true_id == false_id) {
2160       // Both results are the same, condition doesn't matter
2161       inst->SetOpcode(spv::Op::OpCopyObject);
2162       inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}});
2163       return true;
2164     } else if (constants[0]) {
2165       const analysis::Type* type = constants[0]->type();
2166       if (type->AsBool()) {
2167         // Scalar constant value, select the corresponding value.
2168         inst->SetOpcode(spv::Op::OpCopyObject);
2169         if (constants[0]->AsNullConstant() ||
2170             !constants[0]->AsBoolConstant()->value()) {
2171           inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {false_id}}});
2172         } else {
2173           inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}});
2174         }
2175         return true;
2176       } else {
2177         assert(type->AsVector());
2178         if (constants[0]->AsNullConstant()) {
2179           // All values come from false id.
2180           inst->SetOpcode(spv::Op::OpCopyObject);
2181           inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {false_id}}});
2182           return true;
2183         } else {
2184           // Convert to a vector shuffle.
2185           std::vector<Operand> ops;
2186           ops.push_back({SPV_OPERAND_TYPE_ID, {true_id}});
2187           ops.push_back({SPV_OPERAND_TYPE_ID, {false_id}});
2188           const analysis::VectorConstant* vector_const =
2189               constants[0]->AsVectorConstant();
2190           uint32_t size =
2191               static_cast<uint32_t>(vector_const->GetComponents().size());
2192           for (uint32_t i = 0; i != size; ++i) {
2193             const analysis::Constant* component =
2194                 vector_const->GetComponents()[i];
2195             if (component->AsNullConstant() ||
2196                 !component->AsBoolConstant()->value()) {
2197               // Selecting from the false vector which is the second input
2198               // vector to the shuffle. Offset the index by |size|.
2199               ops.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i + size}});
2200             } else {
2201               // Selecting from true vector which is the first input vector to
2202               // the shuffle.
2203               ops.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}});
2204             }
2205           }
2206 
2207           inst->SetOpcode(spv::Op::OpVectorShuffle);
2208           inst->SetInOperands(std::move(ops));
2209           return true;
2210         }
2211       }
2212     }
2213 
2214     return false;
2215   };
2216 }
2217 
2218 enum class FloatConstantKind { Unknown, Zero, One };
2219 
getFloatConstantKind(const analysis::Constant * constant)2220 FloatConstantKind getFloatConstantKind(const analysis::Constant* constant) {
2221   if (constant == nullptr) {
2222     return FloatConstantKind::Unknown;
2223   }
2224 
2225   assert(HasFloatingPoint(constant->type()) && "Unexpected constant type");
2226 
2227   if (constant->AsNullConstant()) {
2228     return FloatConstantKind::Zero;
2229   } else if (const analysis::VectorConstant* vc =
2230                  constant->AsVectorConstant()) {
2231     const std::vector<const analysis::Constant*>& components =
2232         vc->GetComponents();
2233     assert(!components.empty());
2234 
2235     FloatConstantKind kind = getFloatConstantKind(components[0]);
2236 
2237     for (size_t i = 1; i < components.size(); ++i) {
2238       if (getFloatConstantKind(components[i]) != kind) {
2239         return FloatConstantKind::Unknown;
2240       }
2241     }
2242 
2243     return kind;
2244   } else if (const analysis::FloatConstant* fc = constant->AsFloatConstant()) {
2245     if (fc->IsZero()) return FloatConstantKind::Zero;
2246 
2247     uint32_t width = fc->type()->AsFloat()->width();
2248     if (width != 32 && width != 64) return FloatConstantKind::Unknown;
2249 
2250     double value = (width == 64) ? fc->GetDoubleValue() : fc->GetFloatValue();
2251 
2252     if (value == 0.0) {
2253       return FloatConstantKind::Zero;
2254     } else if (value == 1.0) {
2255       return FloatConstantKind::One;
2256     } else {
2257       return FloatConstantKind::Unknown;
2258     }
2259   } else {
2260     return FloatConstantKind::Unknown;
2261   }
2262 }
2263 
RedundantFAdd()2264 FoldingRule RedundantFAdd() {
2265   return [](IRContext*, Instruction* inst,
2266             const std::vector<const analysis::Constant*>& constants) {
2267     assert(inst->opcode() == spv::Op::OpFAdd &&
2268            "Wrong opcode.  Should be OpFAdd.");
2269     assert(constants.size() == 2);
2270 
2271     if (!inst->IsFloatingPointFoldingAllowed()) {
2272       return false;
2273     }
2274 
2275     FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
2276     FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
2277 
2278     if (kind0 == FloatConstantKind::Zero || kind1 == FloatConstantKind::Zero) {
2279       inst->SetOpcode(spv::Op::OpCopyObject);
2280       inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
2281                             {inst->GetSingleWordInOperand(
2282                                 kind0 == FloatConstantKind::Zero ? 1 : 0)}}});
2283       return true;
2284     }
2285 
2286     return false;
2287   };
2288 }
2289 
RedundantFSub()2290 FoldingRule RedundantFSub() {
2291   return [](IRContext*, Instruction* inst,
2292             const std::vector<const analysis::Constant*>& constants) {
2293     assert(inst->opcode() == spv::Op::OpFSub &&
2294            "Wrong opcode.  Should be OpFSub.");
2295     assert(constants.size() == 2);
2296 
2297     if (!inst->IsFloatingPointFoldingAllowed()) {
2298       return false;
2299     }
2300 
2301     FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
2302     FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
2303 
2304     if (kind0 == FloatConstantKind::Zero) {
2305       inst->SetOpcode(spv::Op::OpFNegate);
2306       inst->SetInOperands(
2307           {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1)}}});
2308       return true;
2309     }
2310 
2311     if (kind1 == FloatConstantKind::Zero) {
2312       inst->SetOpcode(spv::Op::OpCopyObject);
2313       inst->SetInOperands(
2314           {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
2315       return true;
2316     }
2317 
2318     return false;
2319   };
2320 }
2321 
RedundantFMul()2322 FoldingRule RedundantFMul() {
2323   return [](IRContext*, Instruction* inst,
2324             const std::vector<const analysis::Constant*>& constants) {
2325     assert(inst->opcode() == spv::Op::OpFMul &&
2326            "Wrong opcode.  Should be OpFMul.");
2327     assert(constants.size() == 2);
2328 
2329     if (!inst->IsFloatingPointFoldingAllowed()) {
2330       return false;
2331     }
2332 
2333     FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
2334     FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
2335 
2336     if (kind0 == FloatConstantKind::Zero || kind1 == FloatConstantKind::Zero) {
2337       inst->SetOpcode(spv::Op::OpCopyObject);
2338       inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
2339                             {inst->GetSingleWordInOperand(
2340                                 kind0 == FloatConstantKind::Zero ? 0 : 1)}}});
2341       return true;
2342     }
2343 
2344     if (kind0 == FloatConstantKind::One || kind1 == FloatConstantKind::One) {
2345       inst->SetOpcode(spv::Op::OpCopyObject);
2346       inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
2347                             {inst->GetSingleWordInOperand(
2348                                 kind0 == FloatConstantKind::One ? 1 : 0)}}});
2349       return true;
2350     }
2351 
2352     return false;
2353   };
2354 }
2355 
RedundantFDiv()2356 FoldingRule RedundantFDiv() {
2357   return [](IRContext*, Instruction* inst,
2358             const std::vector<const analysis::Constant*>& constants) {
2359     assert(inst->opcode() == spv::Op::OpFDiv &&
2360            "Wrong opcode.  Should be OpFDiv.");
2361     assert(constants.size() == 2);
2362 
2363     if (!inst->IsFloatingPointFoldingAllowed()) {
2364       return false;
2365     }
2366 
2367     FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
2368     FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
2369 
2370     if (kind0 == FloatConstantKind::Zero) {
2371       inst->SetOpcode(spv::Op::OpCopyObject);
2372       inst->SetInOperands(
2373           {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
2374       return true;
2375     }
2376 
2377     if (kind1 == FloatConstantKind::One) {
2378       inst->SetOpcode(spv::Op::OpCopyObject);
2379       inst->SetInOperands(
2380           {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
2381       return true;
2382     }
2383 
2384     return false;
2385   };
2386 }
2387 
RedundantFMix()2388 FoldingRule RedundantFMix() {
2389   return [](IRContext* context, Instruction* inst,
2390             const std::vector<const analysis::Constant*>& constants) {
2391     assert(inst->opcode() == spv::Op::OpExtInst &&
2392            "Wrong opcode.  Should be OpExtInst.");
2393 
2394     if (!inst->IsFloatingPointFoldingAllowed()) {
2395       return false;
2396     }
2397 
2398     uint32_t instSetId =
2399         context->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
2400 
2401     if (inst->GetSingleWordInOperand(kExtInstSetIdInIdx) == instSetId &&
2402         inst->GetSingleWordInOperand(kExtInstInstructionInIdx) ==
2403             GLSLstd450FMix) {
2404       assert(constants.size() == 5);
2405 
2406       FloatConstantKind kind4 = getFloatConstantKind(constants[4]);
2407 
2408       if (kind4 == FloatConstantKind::Zero || kind4 == FloatConstantKind::One) {
2409         inst->SetOpcode(spv::Op::OpCopyObject);
2410         inst->SetInOperands(
2411             {{SPV_OPERAND_TYPE_ID,
2412               {inst->GetSingleWordInOperand(kind4 == FloatConstantKind::Zero
2413                                                 ? kFMixXIdInIdx
2414                                                 : kFMixYIdInIdx)}}});
2415         return true;
2416       }
2417     }
2418 
2419     return false;
2420   };
2421 }
2422 
2423 // This rule handles addition of zero for integers.
RedundantIAdd()2424 FoldingRule RedundantIAdd() {
2425   return [](IRContext* context, Instruction* inst,
2426             const std::vector<const analysis::Constant*>& constants) {
2427     assert(inst->opcode() == spv::Op::OpIAdd &&
2428            "Wrong opcode. Should be OpIAdd.");
2429 
2430     uint32_t operand = std::numeric_limits<uint32_t>::max();
2431     const analysis::Type* operand_type = nullptr;
2432     if (constants[0] && constants[0]->IsZero()) {
2433       operand = inst->GetSingleWordInOperand(1);
2434       operand_type = constants[0]->type();
2435     } else if (constants[1] && constants[1]->IsZero()) {
2436       operand = inst->GetSingleWordInOperand(0);
2437       operand_type = constants[1]->type();
2438     }
2439 
2440     if (operand != std::numeric_limits<uint32_t>::max()) {
2441       const analysis::Type* inst_type =
2442           context->get_type_mgr()->GetType(inst->type_id());
2443       if (inst_type->IsSame(operand_type)) {
2444         inst->SetOpcode(spv::Op::OpCopyObject);
2445       } else {
2446         inst->SetOpcode(spv::Op::OpBitcast);
2447       }
2448       inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {operand}}});
2449       return true;
2450     }
2451     return false;
2452   };
2453 }
2454 
2455 // This rule look for a dot with a constant vector containing a single 1 and
2456 // the rest 0s.  This is the same as doing an extract.
DotProductDoingExtract()2457 FoldingRule DotProductDoingExtract() {
2458   return [](IRContext* context, Instruction* inst,
2459             const std::vector<const analysis::Constant*>& constants) {
2460     assert(inst->opcode() == spv::Op::OpDot &&
2461            "Wrong opcode.  Should be OpDot.");
2462 
2463     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
2464 
2465     if (!inst->IsFloatingPointFoldingAllowed()) {
2466       return false;
2467     }
2468 
2469     for (int i = 0; i < 2; ++i) {
2470       if (!constants[i]) {
2471         continue;
2472       }
2473 
2474       const analysis::Vector* vector_type = constants[i]->type()->AsVector();
2475       assert(vector_type && "Inputs to OpDot must be vectors.");
2476       const analysis::Float* element_type =
2477           vector_type->element_type()->AsFloat();
2478       assert(element_type && "Inputs to OpDot must be vectors of floats.");
2479       uint32_t element_width = element_type->width();
2480       if (element_width != 32 && element_width != 64) {
2481         return false;
2482       }
2483 
2484       std::vector<const analysis::Constant*> components;
2485       components = constants[i]->GetVectorComponents(const_mgr);
2486 
2487       constexpr uint32_t kNotFound = std::numeric_limits<uint32_t>::max();
2488 
2489       uint32_t component_with_one = kNotFound;
2490       bool all_others_zero = true;
2491       for (uint32_t j = 0; j < components.size(); ++j) {
2492         const analysis::Constant* element = components[j];
2493         double value =
2494             (element_width == 32 ? element->GetFloat() : element->GetDouble());
2495         if (value == 0.0) {
2496           continue;
2497         } else if (value == 1.0) {
2498           if (component_with_one == kNotFound) {
2499             component_with_one = j;
2500           } else {
2501             component_with_one = kNotFound;
2502             break;
2503           }
2504         } else {
2505           all_others_zero = false;
2506           break;
2507         }
2508       }
2509 
2510       if (!all_others_zero || component_with_one == kNotFound) {
2511         continue;
2512       }
2513 
2514       std::vector<Operand> operands;
2515       operands.push_back(
2516           {SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1u - i)}});
2517       operands.push_back(
2518           {SPV_OPERAND_TYPE_LITERAL_INTEGER, {component_with_one}});
2519 
2520       inst->SetOpcode(spv::Op::OpCompositeExtract);
2521       inst->SetInOperands(std::move(operands));
2522       return true;
2523     }
2524     return false;
2525   };
2526 }
2527 
2528 // If we are storing an undef, then we can remove the store.
2529 //
2530 // TODO: We can do something similar for OpImageWrite, but checking for volatile
2531 // is complicated.  Waiting to see if it is needed.
StoringUndef()2532 FoldingRule StoringUndef() {
2533   return [](IRContext* context, Instruction* inst,
2534             const std::vector<const analysis::Constant*>&) {
2535     assert(inst->opcode() == spv::Op::OpStore &&
2536            "Wrong opcode.  Should be OpStore.");
2537 
2538     analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
2539 
2540     // If this is a volatile store, the store cannot be removed.
2541     if (inst->NumInOperands() == 3) {
2542       if (inst->GetSingleWordInOperand(2) &
2543           uint32_t(spv::MemoryAccessMask::Volatile)) {
2544         return false;
2545       }
2546     }
2547 
2548     uint32_t object_id = inst->GetSingleWordInOperand(kStoreObjectInIdx);
2549     Instruction* object_inst = def_use_mgr->GetDef(object_id);
2550     if (object_inst->opcode() == spv::Op::OpUndef) {
2551       inst->ToNop();
2552       return true;
2553     }
2554     return false;
2555   };
2556 }
2557 
VectorShuffleFeedingShuffle()2558 FoldingRule VectorShuffleFeedingShuffle() {
2559   return [](IRContext* context, Instruction* inst,
2560             const std::vector<const analysis::Constant*>&) {
2561     assert(inst->opcode() == spv::Op::OpVectorShuffle &&
2562            "Wrong opcode.  Should be OpVectorShuffle.");
2563 
2564     analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
2565     analysis::TypeManager* type_mgr = context->get_type_mgr();
2566 
2567     Instruction* feeding_shuffle_inst =
2568         def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
2569     analysis::Vector* op0_type =
2570         type_mgr->GetType(feeding_shuffle_inst->type_id())->AsVector();
2571     uint32_t op0_length = op0_type->element_count();
2572 
2573     bool feeder_is_op0 = true;
2574     if (feeding_shuffle_inst->opcode() != spv::Op::OpVectorShuffle) {
2575       feeding_shuffle_inst =
2576           def_use_mgr->GetDef(inst->GetSingleWordInOperand(1));
2577       feeder_is_op0 = false;
2578     }
2579 
2580     if (feeding_shuffle_inst->opcode() != spv::Op::OpVectorShuffle) {
2581       return false;
2582     }
2583 
2584     Instruction* feeder2 =
2585         def_use_mgr->GetDef(feeding_shuffle_inst->GetSingleWordInOperand(0));
2586     analysis::Vector* feeder_op0_type =
2587         type_mgr->GetType(feeder2->type_id())->AsVector();
2588     uint32_t feeder_op0_length = feeder_op0_type->element_count();
2589 
2590     uint32_t new_feeder_id = 0;
2591     std::vector<Operand> new_operands;
2592     new_operands.resize(
2593         2, {SPV_OPERAND_TYPE_ID, {0}});  // Place holders for vector operands.
2594     const uint32_t undef_literal = 0xffffffff;
2595     for (uint32_t op = 2; op < inst->NumInOperands(); ++op) {
2596       uint32_t component_index = inst->GetSingleWordInOperand(op);
2597 
2598       // Do not interpret the undefined value literal as coming from operand 1.
2599       if (component_index != undef_literal &&
2600           feeder_is_op0 == (component_index < op0_length)) {
2601         // This component comes from the feeding_shuffle_inst.  Update
2602         // |component_index| to be the index into the operand of the feeder.
2603 
2604         // Adjust component_index to get the index into the operands of the
2605         // feeding_shuffle_inst.
2606         if (component_index >= op0_length) {
2607           component_index -= op0_length;
2608         }
2609         component_index =
2610             feeding_shuffle_inst->GetSingleWordInOperand(component_index + 2);
2611 
2612         // Check if we are using a component from the first or second operand of
2613         // the feeding instruction.
2614         if (component_index < feeder_op0_length) {
2615           if (new_feeder_id == 0) {
2616             // First time through, save the id of the operand the element comes
2617             // from.
2618             new_feeder_id = feeding_shuffle_inst->GetSingleWordInOperand(0);
2619           } else if (new_feeder_id !=
2620                      feeding_shuffle_inst->GetSingleWordInOperand(0)) {
2621             // We need both elements of the feeding_shuffle_inst, so we cannot
2622             // fold.
2623             return false;
2624           }
2625         } else if (component_index != undef_literal) {
2626           if (new_feeder_id == 0) {
2627             // First time through, save the id of the operand the element comes
2628             // from.
2629             new_feeder_id = feeding_shuffle_inst->GetSingleWordInOperand(1);
2630           } else if (new_feeder_id !=
2631                      feeding_shuffle_inst->GetSingleWordInOperand(1)) {
2632             // We need both elements of the feeding_shuffle_inst, so we cannot
2633             // fold.
2634             return false;
2635           }
2636           component_index -= feeder_op0_length;
2637         }
2638 
2639         if (!feeder_is_op0 && component_index != undef_literal) {
2640           component_index += op0_length;
2641         }
2642       }
2643       new_operands.push_back(
2644           {SPV_OPERAND_TYPE_LITERAL_INTEGER, {component_index}});
2645     }
2646 
2647     if (new_feeder_id == 0) {
2648       analysis::ConstantManager* const_mgr = context->get_constant_mgr();
2649       const analysis::Type* type =
2650           type_mgr->GetType(feeding_shuffle_inst->type_id());
2651       const analysis::Constant* null_const = const_mgr->GetConstant(type, {});
2652       new_feeder_id =
2653           const_mgr->GetDefiningInstruction(null_const, 0)->result_id();
2654     }
2655 
2656     if (feeder_is_op0) {
2657       // If the size of the first vector operand changed then the indices
2658       // referring to the second operand need to be adjusted.
2659       Instruction* new_feeder_inst = def_use_mgr->GetDef(new_feeder_id);
2660       analysis::Type* new_feeder_type =
2661           type_mgr->GetType(new_feeder_inst->type_id());
2662       uint32_t new_op0_size = new_feeder_type->AsVector()->element_count();
2663       int32_t adjustment = op0_length - new_op0_size;
2664 
2665       if (adjustment != 0) {
2666         for (uint32_t i = 2; i < new_operands.size(); i++) {
2667           uint32_t operand = inst->GetSingleWordInOperand(i);
2668           if (operand >= op0_length && operand != undef_literal) {
2669             new_operands[i].words[0] -= adjustment;
2670           }
2671         }
2672       }
2673 
2674       new_operands[0].words[0] = new_feeder_id;
2675       new_operands[1] = inst->GetInOperand(1);
2676     } else {
2677       new_operands[1].words[0] = new_feeder_id;
2678       new_operands[0] = inst->GetInOperand(0);
2679     }
2680 
2681     inst->SetInOperands(std::move(new_operands));
2682     return true;
2683   };
2684 }
2685 
2686 // Removes duplicate ids from the interface list of an OpEntryPoint
2687 // instruction.
RemoveRedundantOperands()2688 FoldingRule RemoveRedundantOperands() {
2689   return [](IRContext*, Instruction* inst,
2690             const std::vector<const analysis::Constant*>&) {
2691     assert(inst->opcode() == spv::Op::OpEntryPoint &&
2692            "Wrong opcode.  Should be OpEntryPoint.");
2693     bool has_redundant_operand = false;
2694     std::unordered_set<uint32_t> seen_operands;
2695     std::vector<Operand> new_operands;
2696 
2697     new_operands.emplace_back(inst->GetOperand(0));
2698     new_operands.emplace_back(inst->GetOperand(1));
2699     new_operands.emplace_back(inst->GetOperand(2));
2700     for (uint32_t i = 3; i < inst->NumOperands(); ++i) {
2701       if (seen_operands.insert(inst->GetSingleWordOperand(i)).second) {
2702         new_operands.emplace_back(inst->GetOperand(i));
2703       } else {
2704         has_redundant_operand = true;
2705       }
2706     }
2707 
2708     if (!has_redundant_operand) {
2709       return false;
2710     }
2711 
2712     inst->SetInOperands(std::move(new_operands));
2713     return true;
2714   };
2715 }
2716 
2717 // If an image instruction's operand is a constant, updates the image operand
2718 // flag from Offset to ConstOffset.
UpdateImageOperands()2719 FoldingRule UpdateImageOperands() {
2720   return [](IRContext*, Instruction* inst,
2721             const std::vector<const analysis::Constant*>& constants) {
2722     const auto opcode = inst->opcode();
2723     (void)opcode;
2724     assert((opcode == spv::Op::OpImageSampleImplicitLod ||
2725             opcode == spv::Op::OpImageSampleExplicitLod ||
2726             opcode == spv::Op::OpImageSampleDrefImplicitLod ||
2727             opcode == spv::Op::OpImageSampleDrefExplicitLod ||
2728             opcode == spv::Op::OpImageSampleProjImplicitLod ||
2729             opcode == spv::Op::OpImageSampleProjExplicitLod ||
2730             opcode == spv::Op::OpImageSampleProjDrefImplicitLod ||
2731             opcode == spv::Op::OpImageSampleProjDrefExplicitLod ||
2732             opcode == spv::Op::OpImageFetch ||
2733             opcode == spv::Op::OpImageGather ||
2734             opcode == spv::Op::OpImageDrefGather ||
2735             opcode == spv::Op::OpImageRead || opcode == spv::Op::OpImageWrite ||
2736             opcode == spv::Op::OpImageSparseSampleImplicitLod ||
2737             opcode == spv::Op::OpImageSparseSampleExplicitLod ||
2738             opcode == spv::Op::OpImageSparseSampleDrefImplicitLod ||
2739             opcode == spv::Op::OpImageSparseSampleDrefExplicitLod ||
2740             opcode == spv::Op::OpImageSparseSampleProjImplicitLod ||
2741             opcode == spv::Op::OpImageSparseSampleProjExplicitLod ||
2742             opcode == spv::Op::OpImageSparseSampleProjDrefImplicitLod ||
2743             opcode == spv::Op::OpImageSparseSampleProjDrefExplicitLod ||
2744             opcode == spv::Op::OpImageSparseFetch ||
2745             opcode == spv::Op::OpImageSparseGather ||
2746             opcode == spv::Op::OpImageSparseDrefGather ||
2747             opcode == spv::Op::OpImageSparseRead) &&
2748            "Wrong opcode.  Should be an image instruction.");
2749 
2750     int32_t operand_index = ImageOperandsMaskInOperandIndex(inst);
2751     if (operand_index >= 0) {
2752       auto image_operands = inst->GetSingleWordInOperand(operand_index);
2753       if (image_operands & uint32_t(spv::ImageOperandsMask::Offset)) {
2754         uint32_t offset_operand_index = operand_index + 1;
2755         if (image_operands & uint32_t(spv::ImageOperandsMask::Bias))
2756           offset_operand_index++;
2757         if (image_operands & uint32_t(spv::ImageOperandsMask::Lod))
2758           offset_operand_index++;
2759         if (image_operands & uint32_t(spv::ImageOperandsMask::Grad))
2760           offset_operand_index += 2;
2761         assert(((image_operands &
2762                  uint32_t(spv::ImageOperandsMask::ConstOffset)) == 0) &&
2763                "Offset and ConstOffset may not be used together");
2764         if (offset_operand_index < inst->NumOperands()) {
2765           if (constants[offset_operand_index]) {
2766             if (constants[offset_operand_index]->IsZero()) {
2767               inst->RemoveInOperand(offset_operand_index);
2768             } else {
2769               image_operands = image_operands |
2770                                uint32_t(spv::ImageOperandsMask::ConstOffset);
2771             }
2772             image_operands =
2773                 image_operands & ~uint32_t(spv::ImageOperandsMask::Offset);
2774             inst->SetInOperand(operand_index, {image_operands});
2775             return true;
2776           }
2777         }
2778       }
2779     }
2780 
2781     return false;
2782   };
2783 }
2784 
2785 }  // namespace
2786 
AddFoldingRules()2787 void FoldingRules::AddFoldingRules() {
2788   // Add all folding rules to the list for the opcodes to which they apply.
2789   // Note that the order in which rules are added to the list matters. If a rule
2790   // applies to the instruction, the rest of the rules will not be attempted.
2791   // Take that into consideration.
2792   rules_[spv::Op::OpBitcast].push_back(BitCastScalarOrVector());
2793 
2794   rules_[spv::Op::OpCompositeConstruct].push_back(
2795       CompositeExtractFeedingConstruct);
2796 
2797   rules_[spv::Op::OpCompositeExtract].push_back(InsertFeedingExtract());
2798   rules_[spv::Op::OpCompositeExtract].push_back(
2799       CompositeConstructFeedingExtract);
2800   rules_[spv::Op::OpCompositeExtract].push_back(VectorShuffleFeedingExtract());
2801   rules_[spv::Op::OpCompositeExtract].push_back(FMixFeedingExtract());
2802 
2803   rules_[spv::Op::OpCompositeInsert].push_back(
2804       CompositeInsertToCompositeConstruct);
2805 
2806   rules_[spv::Op::OpDot].push_back(DotProductDoingExtract());
2807 
2808   rules_[spv::Op::OpEntryPoint].push_back(RemoveRedundantOperands());
2809 
2810   rules_[spv::Op::OpFAdd].push_back(RedundantFAdd());
2811   rules_[spv::Op::OpFAdd].push_back(MergeAddNegateArithmetic());
2812   rules_[spv::Op::OpFAdd].push_back(MergeAddAddArithmetic());
2813   rules_[spv::Op::OpFAdd].push_back(MergeAddSubArithmetic());
2814   rules_[spv::Op::OpFAdd].push_back(MergeGenericAddSubArithmetic());
2815   rules_[spv::Op::OpFAdd].push_back(FactorAddMuls());
2816 
2817   rules_[spv::Op::OpFDiv].push_back(RedundantFDiv());
2818   rules_[spv::Op::OpFDiv].push_back(ReciprocalFDiv());
2819   rules_[spv::Op::OpFDiv].push_back(MergeDivDivArithmetic());
2820   rules_[spv::Op::OpFDiv].push_back(MergeDivMulArithmetic());
2821   rules_[spv::Op::OpFDiv].push_back(MergeDivNegateArithmetic());
2822 
2823   rules_[spv::Op::OpFMul].push_back(RedundantFMul());
2824   rules_[spv::Op::OpFMul].push_back(MergeMulMulArithmetic());
2825   rules_[spv::Op::OpFMul].push_back(MergeMulDivArithmetic());
2826   rules_[spv::Op::OpFMul].push_back(MergeMulNegateArithmetic());
2827 
2828   rules_[spv::Op::OpFNegate].push_back(MergeNegateArithmetic());
2829   rules_[spv::Op::OpFNegate].push_back(MergeNegateAddSubArithmetic());
2830   rules_[spv::Op::OpFNegate].push_back(MergeNegateMulDivArithmetic());
2831 
2832   rules_[spv::Op::OpFSub].push_back(RedundantFSub());
2833   rules_[spv::Op::OpFSub].push_back(MergeSubNegateArithmetic());
2834   rules_[spv::Op::OpFSub].push_back(MergeSubAddArithmetic());
2835   rules_[spv::Op::OpFSub].push_back(MergeSubSubArithmetic());
2836 
2837   rules_[spv::Op::OpIAdd].push_back(RedundantIAdd());
2838   rules_[spv::Op::OpIAdd].push_back(MergeAddNegateArithmetic());
2839   rules_[spv::Op::OpIAdd].push_back(MergeAddAddArithmetic());
2840   rules_[spv::Op::OpIAdd].push_back(MergeAddSubArithmetic());
2841   rules_[spv::Op::OpIAdd].push_back(MergeGenericAddSubArithmetic());
2842   rules_[spv::Op::OpIAdd].push_back(FactorAddMuls());
2843 
2844   rules_[spv::Op::OpIMul].push_back(IntMultipleBy1());
2845   rules_[spv::Op::OpIMul].push_back(MergeMulMulArithmetic());
2846   rules_[spv::Op::OpIMul].push_back(MergeMulNegateArithmetic());
2847 
2848   rules_[spv::Op::OpISub].push_back(MergeSubNegateArithmetic());
2849   rules_[spv::Op::OpISub].push_back(MergeSubAddArithmetic());
2850   rules_[spv::Op::OpISub].push_back(MergeSubSubArithmetic());
2851 
2852   rules_[spv::Op::OpPhi].push_back(RedundantPhi());
2853 
2854   rules_[spv::Op::OpSNegate].push_back(MergeNegateArithmetic());
2855   rules_[spv::Op::OpSNegate].push_back(MergeNegateMulDivArithmetic());
2856   rules_[spv::Op::OpSNegate].push_back(MergeNegateAddSubArithmetic());
2857 
2858   rules_[spv::Op::OpSelect].push_back(RedundantSelect());
2859 
2860   rules_[spv::Op::OpStore].push_back(StoringUndef());
2861 
2862   rules_[spv::Op::OpVectorShuffle].push_back(VectorShuffleFeedingShuffle());
2863 
2864   rules_[spv::Op::OpImageSampleImplicitLod].push_back(UpdateImageOperands());
2865   rules_[spv::Op::OpImageSampleExplicitLod].push_back(UpdateImageOperands());
2866   rules_[spv::Op::OpImageSampleDrefImplicitLod].push_back(
2867       UpdateImageOperands());
2868   rules_[spv::Op::OpImageSampleDrefExplicitLod].push_back(
2869       UpdateImageOperands());
2870   rules_[spv::Op::OpImageSampleProjImplicitLod].push_back(
2871       UpdateImageOperands());
2872   rules_[spv::Op::OpImageSampleProjExplicitLod].push_back(
2873       UpdateImageOperands());
2874   rules_[spv::Op::OpImageSampleProjDrefImplicitLod].push_back(
2875       UpdateImageOperands());
2876   rules_[spv::Op::OpImageSampleProjDrefExplicitLod].push_back(
2877       UpdateImageOperands());
2878   rules_[spv::Op::OpImageFetch].push_back(UpdateImageOperands());
2879   rules_[spv::Op::OpImageGather].push_back(UpdateImageOperands());
2880   rules_[spv::Op::OpImageDrefGather].push_back(UpdateImageOperands());
2881   rules_[spv::Op::OpImageRead].push_back(UpdateImageOperands());
2882   rules_[spv::Op::OpImageWrite].push_back(UpdateImageOperands());
2883   rules_[spv::Op::OpImageSparseSampleImplicitLod].push_back(
2884       UpdateImageOperands());
2885   rules_[spv::Op::OpImageSparseSampleExplicitLod].push_back(
2886       UpdateImageOperands());
2887   rules_[spv::Op::OpImageSparseSampleDrefImplicitLod].push_back(
2888       UpdateImageOperands());
2889   rules_[spv::Op::OpImageSparseSampleDrefExplicitLod].push_back(
2890       UpdateImageOperands());
2891   rules_[spv::Op::OpImageSparseSampleProjImplicitLod].push_back(
2892       UpdateImageOperands());
2893   rules_[spv::Op::OpImageSparseSampleProjExplicitLod].push_back(
2894       UpdateImageOperands());
2895   rules_[spv::Op::OpImageSparseSampleProjDrefImplicitLod].push_back(
2896       UpdateImageOperands());
2897   rules_[spv::Op::OpImageSparseSampleProjDrefExplicitLod].push_back(
2898       UpdateImageOperands());
2899   rules_[spv::Op::OpImageSparseFetch].push_back(UpdateImageOperands());
2900   rules_[spv::Op::OpImageSparseGather].push_back(UpdateImageOperands());
2901   rules_[spv::Op::OpImageSparseDrefGather].push_back(UpdateImageOperands());
2902   rules_[spv::Op::OpImageSparseRead].push_back(UpdateImageOperands());
2903 
2904   FeatureManager* feature_manager = context_->get_feature_mgr();
2905   // Add rules for GLSLstd450
2906   uint32_t ext_inst_glslstd450_id =
2907       feature_manager->GetExtInstImportId_GLSLstd450();
2908   if (ext_inst_glslstd450_id != 0) {
2909     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMix}].push_back(
2910         RedundantFMix());
2911   }
2912 }
2913 }  // namespace opt
2914 }  // namespace spvtools
2915