• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2018 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/opt/folding_rules.h"
16 
17 #include <limits>
18 #include <memory>
19 #include <utility>
20 
21 #include "ir_builder.h"
22 #include "source/latest_version_glsl_std_450_header.h"
23 #include "source/opt/ir_context.h"
24 
25 namespace spvtools {
26 namespace opt {
27 namespace {
28 
// Operand positions shared by the folding rules. Going by their names, these
// are in-operand indices for composite extract/insert, for extended
// instructions (including the x, y and a arguments of FMix) and for the
// object operand of a store.
constexpr uint32_t kExtractCompositeIdInIdx{0};
constexpr uint32_t kInsertObjectIdInIdx{0};
constexpr uint32_t kInsertCompositeIdInIdx{1};
constexpr uint32_t kExtInstSetIdInIdx{0};
constexpr uint32_t kExtInstInstructionInIdx{1};
constexpr uint32_t kFMixXIdInIdx{2};
constexpr uint32_t kFMixYIdInIdx{3};
constexpr uint32_t kFMixAIdInIdx{4};
constexpr uint32_t kStoreObjectInIdx{1};
38 
39 // Some image instructions may contain an "image operands" argument.
40 // Returns the operand index for the "image operands".
41 // Returns -1 if the instruction does not have image operands.
ImageOperandsMaskInOperandIndex(Instruction * inst)42 int32_t ImageOperandsMaskInOperandIndex(Instruction* inst) {
43   const auto opcode = inst->opcode();
44   switch (opcode) {
45     case spv::Op::OpImageSampleImplicitLod:
46     case spv::Op::OpImageSampleExplicitLod:
47     case spv::Op::OpImageSampleProjImplicitLod:
48     case spv::Op::OpImageSampleProjExplicitLod:
49     case spv::Op::OpImageFetch:
50     case spv::Op::OpImageRead:
51     case spv::Op::OpImageSparseSampleImplicitLod:
52     case spv::Op::OpImageSparseSampleExplicitLod:
53     case spv::Op::OpImageSparseSampleProjImplicitLod:
54     case spv::Op::OpImageSparseSampleProjExplicitLod:
55     case spv::Op::OpImageSparseFetch:
56     case spv::Op::OpImageSparseRead:
57       return inst->NumOperands() > 4 ? 2 : -1;
58     case spv::Op::OpImageSampleDrefImplicitLod:
59     case spv::Op::OpImageSampleDrefExplicitLod:
60     case spv::Op::OpImageSampleProjDrefImplicitLod:
61     case spv::Op::OpImageSampleProjDrefExplicitLod:
62     case spv::Op::OpImageGather:
63     case spv::Op::OpImageDrefGather:
64     case spv::Op::OpImageSparseSampleDrefImplicitLod:
65     case spv::Op::OpImageSparseSampleDrefExplicitLod:
66     case spv::Op::OpImageSparseSampleProjDrefImplicitLod:
67     case spv::Op::OpImageSparseSampleProjDrefExplicitLod:
68     case spv::Op::OpImageSparseGather:
69     case spv::Op::OpImageSparseDrefGather:
70       return inst->NumOperands() > 5 ? 3 : -1;
71     case spv::Op::OpImageWrite:
72       return inst->NumOperands() > 3 ? 3 : -1;
73     default:
74       return -1;
75   }
76 }
77 
78 // Returns the element width of |type|.
ElementWidth(const analysis::Type * type)79 uint32_t ElementWidth(const analysis::Type* type) {
80   if (const analysis::Vector* vec_type = type->AsVector()) {
81     return ElementWidth(vec_type->element_type());
82   } else if (const analysis::Float* float_type = type->AsFloat()) {
83     return float_type->width();
84   } else {
85     assert(type->AsInteger());
86     return type->AsInteger()->width();
87   }
88 }
89 
90 // Returns true if |type| is Float or a vector of Float.
HasFloatingPoint(const analysis::Type * type)91 bool HasFloatingPoint(const analysis::Type* type) {
92   if (type->AsFloat()) {
93     return true;
94   } else if (const analysis::Vector* vec_type = type->AsVector()) {
95     return vec_type->element_type()->AsFloat() != nullptr;
96   }
97 
98   return false;
99 }
100 
// Decides whether |val| is acceptable as a folded floating-point result.
// NaN, infinite and subnormal values are rejected; every other
// classification reported by std::fpclassify is accepted.
template <typename T>
bool IsValidResult(T val) {
  const int category = std::fpclassify(val);
  const bool rejected = category == FP_NAN || category == FP_INFINITE ||
                        category == FP_SUBNORMAL;
  return !rejected;
}
114 
ConstInput(const std::vector<const analysis::Constant * > & constants)115 const analysis::Constant* ConstInput(
116     const std::vector<const analysis::Constant*>& constants) {
117   return constants[0] ? constants[0] : constants[1];
118 }
119 
NonConstInput(IRContext * context,const analysis::Constant * c,Instruction * inst)120 Instruction* NonConstInput(IRContext* context, const analysis::Constant* c,
121                            Instruction* inst) {
122   uint32_t in_op = c ? 1u : 0u;
123   return context->get_def_use_mgr()->GetDef(
124       inst->GetSingleWordInOperand(in_op));
125 }
126 
// Splits the 64-bit value |val| into two 32-bit words, low-order word first.
std::vector<uint32_t> ExtractInts(uint64_t val) {
  const uint32_t low_word = static_cast<uint32_t>(val);
  const uint32_t high_word = static_cast<uint32_t>(val >> 32);
  return {low_word, high_word};
}
133 
GetWordsFromScalarIntConstant(const analysis::IntConstant * c)134 std::vector<uint32_t> GetWordsFromScalarIntConstant(
135     const analysis::IntConstant* c) {
136   assert(c != nullptr);
137   uint32_t width = c->type()->AsInteger()->width();
138   assert(width == 8 || width == 16 || width == 32 || width == 64);
139   if (width == 64) {
140     uint64_t uval = static_cast<uint64_t>(c->GetU64());
141     return ExtractInts(uval);
142   }
143   // Section 2.2.1 of the SPIR-V spec guarantees that all integer types
144   // smaller than 32-bits are automatically zero or sign extended to 32-bits.
145   return {c->GetU32BitValue()};
146 }
147 
GetWordsFromScalarFloatConstant(const analysis::FloatConstant * c)148 std::vector<uint32_t> GetWordsFromScalarFloatConstant(
149     const analysis::FloatConstant* c) {
150   assert(c != nullptr);
151   uint32_t width = c->type()->AsFloat()->width();
152   assert(width == 16 || width == 32 || width == 64);
153   if (width == 64) {
154     utils::FloatProxy<double> result(c->GetDouble());
155     return result.GetWords();
156   }
157   // Section 2.2.1 of the SPIR-V spec guarantees that all floating-point types
158   // smaller than 32-bits are automatically zero extended to 32-bits.
159   return {c->GetU32BitValue()};
160 }
161 
GetWordsFromNumericScalarOrVectorConstant(analysis::ConstantManager * const_mgr,const analysis::Constant * c)162 std::vector<uint32_t> GetWordsFromNumericScalarOrVectorConstant(
163     analysis::ConstantManager* const_mgr, const analysis::Constant* c) {
164   if (const auto* float_constant = c->AsFloatConstant()) {
165     return GetWordsFromScalarFloatConstant(float_constant);
166   } else if (const auto* int_constant = c->AsIntConstant()) {
167     return GetWordsFromScalarIntConstant(int_constant);
168   } else if (const auto* vec_constant = c->AsVectorConstant()) {
169     std::vector<uint32_t> words;
170     for (const auto* comp : vec_constant->GetComponents()) {
171       auto comp_in_words =
172           GetWordsFromNumericScalarOrVectorConstant(const_mgr, comp);
173       words.insert(words.end(), comp_in_words.begin(), comp_in_words.end());
174     }
175     return words;
176   }
177   return {};
178 }
179 
ConvertWordsToNumericScalarOrVectorConstant(analysis::ConstantManager * const_mgr,const std::vector<uint32_t> & words,const analysis::Type * type)180 const analysis::Constant* ConvertWordsToNumericScalarOrVectorConstant(
181     analysis::ConstantManager* const_mgr, const std::vector<uint32_t>& words,
182     const analysis::Type* type) {
183   if (type->AsInteger() || type->AsFloat())
184     return const_mgr->GetConstant(type, words);
185   if (const auto* vec_type = type->AsVector())
186     return const_mgr->GetNumericVectorConstantWithWords(vec_type, words);
187   return nullptr;
188 }
189 
190 // Returns the negation of |c|. |c| must be a 32 or 64 bit floating point
191 // constant.
NegateFloatingPointConstant(analysis::ConstantManager * const_mgr,const analysis::Constant * c)192 uint32_t NegateFloatingPointConstant(analysis::ConstantManager* const_mgr,
193                                      const analysis::Constant* c) {
194   assert(c);
195   assert(c->type()->AsFloat());
196   uint32_t width = c->type()->AsFloat()->width();
197   assert(width == 32 || width == 64);
198   std::vector<uint32_t> words;
199   if (width == 64) {
200     utils::FloatProxy<double> result(c->GetDouble() * -1.0);
201     words = result.GetWords();
202   } else {
203     utils::FloatProxy<float> result(c->GetFloat() * -1.0f);
204     words = result.GetWords();
205   }
206 
207   const analysis::Constant* negated_const =
208       const_mgr->GetConstant(c->type(), std::move(words));
209   return const_mgr->GetDefiningInstruction(negated_const)->result_id();
210 }
211 
212 // Negates the integer constant |c|. Returns the id of the defining instruction.
NegateIntegerConstant(analysis::ConstantManager * const_mgr,const analysis::Constant * c)213 uint32_t NegateIntegerConstant(analysis::ConstantManager* const_mgr,
214                                const analysis::Constant* c) {
215   assert(c);
216   assert(c->type()->AsInteger());
217   uint32_t width = c->type()->AsInteger()->width();
218   assert(width == 32 || width == 64);
219   std::vector<uint32_t> words;
220   if (width == 64) {
221     uint64_t uval = static_cast<uint64_t>(0 - c->GetU64());
222     words = ExtractInts(uval);
223   } else {
224     words.push_back(static_cast<uint32_t>(0 - c->GetU32()));
225   }
226 
227   const analysis::Constant* negated_const =
228       const_mgr->GetConstant(c->type(), std::move(words));
229   return const_mgr->GetDefiningInstruction(negated_const)->result_id();
230 }
231 
232 // Negates the vector constant |c|. Returns the id of the defining instruction.
NegateVectorConstant(analysis::ConstantManager * const_mgr,const analysis::Constant * c)233 uint32_t NegateVectorConstant(analysis::ConstantManager* const_mgr,
234                               const analysis::Constant* c) {
235   assert(const_mgr && c);
236   assert(c->type()->AsVector());
237   if (c->AsNullConstant()) {
238     // 0.0 vs -0.0 shouldn't matter.
239     return const_mgr->GetDefiningInstruction(c)->result_id();
240   } else {
241     const analysis::Type* component_type =
242         c->AsVectorConstant()->component_type();
243     std::vector<uint32_t> words;
244     for (auto& comp : c->AsVectorConstant()->GetComponents()) {
245       if (component_type->AsFloat()) {
246         words.push_back(NegateFloatingPointConstant(const_mgr, comp));
247       } else {
248         assert(component_type->AsInteger());
249         words.push_back(NegateIntegerConstant(const_mgr, comp));
250       }
251     }
252 
253     const analysis::Constant* negated_const =
254         const_mgr->GetConstant(c->type(), std::move(words));
255     return const_mgr->GetDefiningInstruction(negated_const)->result_id();
256   }
257 }
258 
259 // Negates |c|. Returns the id of the defining instruction.
NegateConstant(analysis::ConstantManager * const_mgr,const analysis::Constant * c)260 uint32_t NegateConstant(analysis::ConstantManager* const_mgr,
261                         const analysis::Constant* c) {
262   if (c->type()->AsVector()) {
263     return NegateVectorConstant(const_mgr, c);
264   } else if (c->type()->AsFloat()) {
265     return NegateFloatingPointConstant(const_mgr, c);
266   } else {
267     assert(c->type()->AsInteger());
268     return NegateIntegerConstant(const_mgr, c);
269   }
270 }
271 
272 // Takes the reciprocal of |c|. |c|'s type must be Float or a vector of Float.
273 // Returns 0 if the reciprocal is NaN, infinite or subnormal.
Reciprocal(analysis::ConstantManager * const_mgr,const analysis::Constant * c)274 uint32_t Reciprocal(analysis::ConstantManager* const_mgr,
275                     const analysis::Constant* c) {
276   assert(const_mgr && c);
277   assert(c->type()->AsFloat());
278 
279   uint32_t width = c->type()->AsFloat()->width();
280   assert(width == 32 || width == 64);
281   std::vector<uint32_t> words;
282 
283   if (c->IsZero()) {
284     return 0;
285   }
286 
287   if (width == 64) {
288     spvtools::utils::FloatProxy<double> result(1.0 / c->GetDouble());
289     if (!IsValidResult(result.getAsFloat())) return 0;
290     words = result.GetWords();
291   } else {
292     spvtools::utils::FloatProxy<float> result(1.0f / c->GetFloat());
293     if (!IsValidResult(result.getAsFloat())) return 0;
294     words = result.GetWords();
295   }
296 
297   const analysis::Constant* negated_const =
298       const_mgr->GetConstant(c->type(), std::move(words));
299   return const_mgr->GetDefiningInstruction(negated_const)->result_id();
300 }
301 
302 // Replaces fdiv where second operand is constant with fmul.
ReciprocalFDiv()303 FoldingRule ReciprocalFDiv() {
304   return [](IRContext* context, Instruction* inst,
305             const std::vector<const analysis::Constant*>& constants) {
306     assert(inst->opcode() == spv::Op::OpFDiv);
307     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
308     const analysis::Type* type =
309         context->get_type_mgr()->GetType(inst->type_id());
310     if (!inst->IsFloatingPointFoldingAllowed()) return false;
311 
312     uint32_t width = ElementWidth(type);
313     if (width != 32 && width != 64) return false;
314 
315     if (constants[1] != nullptr) {
316       uint32_t id = 0;
317       if (const analysis::VectorConstant* vector_const =
318               constants[1]->AsVectorConstant()) {
319         std::vector<uint32_t> neg_ids;
320         for (auto& comp : vector_const->GetComponents()) {
321           id = Reciprocal(const_mgr, comp);
322           if (id == 0) return false;
323           neg_ids.push_back(id);
324         }
325         const analysis::Constant* negated_const =
326             const_mgr->GetConstant(constants[1]->type(), std::move(neg_ids));
327         id = const_mgr->GetDefiningInstruction(negated_const)->result_id();
328       } else if (constants[1]->AsFloatConstant()) {
329         id = Reciprocal(const_mgr, constants[1]);
330         if (id == 0) return false;
331       } else {
332         // Don't fold a null constant.
333         return false;
334       }
335       inst->SetOpcode(spv::Op::OpFMul);
336       inst->SetInOperands(
337           {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0u)}},
338            {SPV_OPERAND_TYPE_ID, {id}}});
339       return true;
340     }
341 
342     return false;
343   };
344 }
345 
346 // Elides consecutive negate instructions.
MergeNegateArithmetic()347 FoldingRule MergeNegateArithmetic() {
348   return [](IRContext* context, Instruction* inst,
349             const std::vector<const analysis::Constant*>& constants) {
350     assert(inst->opcode() == spv::Op::OpFNegate ||
351            inst->opcode() == spv::Op::OpSNegate);
352     (void)constants;
353     const analysis::Type* type =
354         context->get_type_mgr()->GetType(inst->type_id());
355     if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
356       return false;
357 
358     Instruction* op_inst =
359         context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
360     if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
361       return false;
362 
363     if (op_inst->opcode() == inst->opcode()) {
364       // Elide negates.
365       inst->SetOpcode(spv::Op::OpCopyObject);
366       inst->SetInOperands(
367           {{SPV_OPERAND_TYPE_ID, {op_inst->GetSingleWordInOperand(0u)}}});
368       return true;
369     }
370 
371     return false;
372   };
373 }
374 
375 // Merges negate into a mul or div operation if that operation contains a
376 // constant operand.
377 // Cases:
378 // -(x * 2) = x * -2
379 // -(2 * x) = x * -2
380 // -(x / 2) = x / -2
381 // -(2 / x) = -2 / x
MergeNegateMulDivArithmetic()382 FoldingRule MergeNegateMulDivArithmetic() {
383   return [](IRContext* context, Instruction* inst,
384             const std::vector<const analysis::Constant*>& constants) {
385     assert(inst->opcode() == spv::Op::OpFNegate ||
386            inst->opcode() == spv::Op::OpSNegate);
387     (void)constants;
388     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
389     const analysis::Type* type =
390         context->get_type_mgr()->GetType(inst->type_id());
391     if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
392       return false;
393 
394     Instruction* op_inst =
395         context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
396     if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
397       return false;
398 
399     uint32_t width = ElementWidth(type);
400     if (width != 32 && width != 64) return false;
401 
402     spv::Op opcode = op_inst->opcode();
403     if (opcode == spv::Op::OpFMul || opcode == spv::Op::OpFDiv ||
404         opcode == spv::Op::OpIMul || opcode == spv::Op::OpSDiv ||
405         opcode == spv::Op::OpUDiv) {
406       std::vector<const analysis::Constant*> op_constants =
407           const_mgr->GetOperandConstants(op_inst);
408       // Merge negate into mul or div if one operand is constant.
409       if (op_constants[0] || op_constants[1]) {
410         bool zero_is_variable = op_constants[0] == nullptr;
411         const analysis::Constant* c = ConstInput(op_constants);
412         uint32_t neg_id = NegateConstant(const_mgr, c);
413         uint32_t non_const_id = zero_is_variable
414                                     ? op_inst->GetSingleWordInOperand(0u)
415                                     : op_inst->GetSingleWordInOperand(1u);
416         // Change this instruction to a mul/div.
417         inst->SetOpcode(op_inst->opcode());
418         if (opcode == spv::Op::OpFDiv || opcode == spv::Op::OpUDiv ||
419             opcode == spv::Op::OpSDiv) {
420           uint32_t op0 = zero_is_variable ? non_const_id : neg_id;
421           uint32_t op1 = zero_is_variable ? neg_id : non_const_id;
422           inst->SetInOperands(
423               {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}});
424         } else {
425           inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
426                                {SPV_OPERAND_TYPE_ID, {neg_id}}});
427         }
428         return true;
429       }
430     }
431 
432     return false;
433   };
434 }
435 
436 // Merges negate into a add or sub operation if that operation contains a
437 // constant operand.
438 // Cases:
439 // -(x + 2) = -2 - x
440 // -(2 + x) = -2 - x
441 // -(x - 2) = 2 - x
442 // -(2 - x) = x - 2
MergeNegateAddSubArithmetic()443 FoldingRule MergeNegateAddSubArithmetic() {
444   return [](IRContext* context, Instruction* inst,
445             const std::vector<const analysis::Constant*>& constants) {
446     assert(inst->opcode() == spv::Op::OpFNegate ||
447            inst->opcode() == spv::Op::OpSNegate);
448     (void)constants;
449     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
450     const analysis::Type* type =
451         context->get_type_mgr()->GetType(inst->type_id());
452     if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
453       return false;
454 
455     Instruction* op_inst =
456         context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
457     if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
458       return false;
459 
460     uint32_t width = ElementWidth(type);
461     if (width != 32 && width != 64) return false;
462 
463     if (op_inst->opcode() == spv::Op::OpFAdd ||
464         op_inst->opcode() == spv::Op::OpFSub ||
465         op_inst->opcode() == spv::Op::OpIAdd ||
466         op_inst->opcode() == spv::Op::OpISub) {
467       std::vector<const analysis::Constant*> op_constants =
468           const_mgr->GetOperandConstants(op_inst);
469       if (op_constants[0] || op_constants[1]) {
470         bool zero_is_variable = op_constants[0] == nullptr;
471         bool is_add = (op_inst->opcode() == spv::Op::OpFAdd) ||
472                       (op_inst->opcode() == spv::Op::OpIAdd);
473         bool swap_operands = !is_add || zero_is_variable;
474         bool negate_const = is_add;
475         const analysis::Constant* c = ConstInput(op_constants);
476         uint32_t const_id = 0;
477         if (negate_const) {
478           const_id = NegateConstant(const_mgr, c);
479         } else {
480           const_id = zero_is_variable ? op_inst->GetSingleWordInOperand(1u)
481                                       : op_inst->GetSingleWordInOperand(0u);
482         }
483 
484         // Swap operands if necessary and make the instruction a subtraction.
485         uint32_t op0 =
486             zero_is_variable ? op_inst->GetSingleWordInOperand(0u) : const_id;
487         uint32_t op1 =
488             zero_is_variable ? const_id : op_inst->GetSingleWordInOperand(1u);
489         if (swap_operands) std::swap(op0, op1);
490         inst->SetOpcode(HasFloatingPoint(type) ? spv::Op::OpFSub
491                                                : spv::Op::OpISub);
492         inst->SetInOperands(
493             {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}});
494         return true;
495       }
496     }
497 
498     return false;
499   };
500 }
501 
502 // Returns true if |c| has a zero element.
HasZero(const analysis::Constant * c)503 bool HasZero(const analysis::Constant* c) {
504   if (c->AsNullConstant()) {
505     return true;
506   }
507   if (const analysis::VectorConstant* vec_const = c->AsVectorConstant()) {
508     for (auto& comp : vec_const->GetComponents())
509       if (HasZero(comp)) return true;
510   } else {
511     assert(c->AsScalarConstant());
512     return c->AsScalarConstant()->IsZero();
513   }
514 
515   return false;
516 }
517 
518 // Performs |input1| |opcode| |input2| and returns the merged constant result
519 // id. Returns 0 if the result is not a valid value. The input types must be
520 // Float.
PerformFloatingPointOperation(analysis::ConstantManager * const_mgr,spv::Op opcode,const analysis::Constant * input1,const analysis::Constant * input2)521 uint32_t PerformFloatingPointOperation(analysis::ConstantManager* const_mgr,
522                                        spv::Op opcode,
523                                        const analysis::Constant* input1,
524                                        const analysis::Constant* input2) {
525   const analysis::Type* type = input1->type();
526   assert(type->AsFloat());
527   uint32_t width = type->AsFloat()->width();
528   assert(width == 32 || width == 64);
529   std::vector<uint32_t> words;
530 #define FOLD_OP(op)                                                          \
531   if (width == 64) {                                                         \
532     utils::FloatProxy<double> val =                                          \
533         input1->GetDouble() op input2->GetDouble();                          \
534     double dval = val.getAsFloat();                                          \
535     if (!IsValidResult(dval)) return 0;                                      \
536     words = val.GetWords();                                                  \
537   } else {                                                                   \
538     utils::FloatProxy<float> val = input1->GetFloat() op input2->GetFloat(); \
539     float fval = val.getAsFloat();                                           \
540     if (!IsValidResult(fval)) return 0;                                      \
541     words = val.GetWords();                                                  \
542   }                                                                          \
543   static_assert(true, "require extra semicolon")
544   switch (opcode) {
545     case spv::Op::OpFMul:
546       FOLD_OP(*);
547       break;
548     case spv::Op::OpFDiv:
549       if (HasZero(input2)) return 0;
550       FOLD_OP(/);
551       break;
552     case spv::Op::OpFAdd:
553       FOLD_OP(+);
554       break;
555     case spv::Op::OpFSub:
556       FOLD_OP(-);
557       break;
558     default:
559       assert(false && "Unexpected operation");
560       break;
561   }
562 #undef FOLD_OP
563   const analysis::Constant* merged_const = const_mgr->GetConstant(type, words);
564   return const_mgr->GetDefiningInstruction(merged_const)->result_id();
565 }
566 
567 // Performs |input1| |opcode| |input2| and returns the merged constant result
568 // id. Returns 0 if the result is not a valid value. The input types must be
569 // Integers.
PerformIntegerOperation(analysis::ConstantManager * const_mgr,spv::Op opcode,const analysis::Constant * input1,const analysis::Constant * input2)570 uint32_t PerformIntegerOperation(analysis::ConstantManager* const_mgr,
571                                  spv::Op opcode,
572                                  const analysis::Constant* input1,
573                                  const analysis::Constant* input2) {
574   assert(input1->type()->AsInteger());
575   const analysis::Integer* type = input1->type()->AsInteger();
576   uint32_t width = type->AsInteger()->width();
577   assert(width == 32 || width == 64);
578   std::vector<uint32_t> words;
579   // Regardless of the sign of the constant, folding is performed on an unsigned
580   // interpretation of the constant data. This avoids signed integer overflow
581   // while folding, and works because sign is irrelevant for the IAdd, ISub and
582   // IMul instructions.
583 #define FOLD_OP(op)                                      \
584   if (width == 64) {                                     \
585     uint64_t val = input1->GetU64() op input2->GetU64(); \
586     words = ExtractInts(val);                            \
587   } else {                                               \
588     uint32_t val = input1->GetU32() op input2->GetU32(); \
589     words.push_back(val);                                \
590   }                                                      \
591   static_assert(true, "require extra semicolon")
592   switch (opcode) {
593     case spv::Op::OpIMul:
594       FOLD_OP(*);
595       break;
596     case spv::Op::OpSDiv:
597     case spv::Op::OpUDiv:
598       assert(false && "Should not merge integer division");
599       break;
600     case spv::Op::OpIAdd:
601       FOLD_OP(+);
602       break;
603     case spv::Op::OpISub:
604       FOLD_OP(-);
605       break;
606     default:
607       assert(false && "Unexpected operation");
608       break;
609   }
610 #undef FOLD_OP
611   const analysis::Constant* merged_const = const_mgr->GetConstant(type, words);
612   return const_mgr->GetDefiningInstruction(merged_const)->result_id();
613 }
614 
615 // Performs |input1| |opcode| |input2| and returns the merged constant result
616 // id. Returns 0 if the result is not a valid value. The input types must be
617 // Integers, Floats or Vectors of such.
PerformOperation(analysis::ConstantManager * const_mgr,spv::Op opcode,const analysis::Constant * input1,const analysis::Constant * input2)618 uint32_t PerformOperation(analysis::ConstantManager* const_mgr, spv::Op opcode,
619                           const analysis::Constant* input1,
620                           const analysis::Constant* input2) {
621   assert(input1 && input2);
622   const analysis::Type* type = input1->type();
623   std::vector<uint32_t> words;
624   if (const analysis::Vector* vector_type = type->AsVector()) {
625     const analysis::Type* ele_type = vector_type->element_type();
626     for (uint32_t i = 0; i != vector_type->element_count(); ++i) {
627       uint32_t id = 0;
628 
629       const analysis::Constant* input1_comp = nullptr;
630       if (const analysis::VectorConstant* input1_vector =
631               input1->AsVectorConstant()) {
632         input1_comp = input1_vector->GetComponents()[i];
633       } else {
634         assert(input1->AsNullConstant());
635         input1_comp = const_mgr->GetConstant(ele_type, {});
636       }
637 
638       const analysis::Constant* input2_comp = nullptr;
639       if (const analysis::VectorConstant* input2_vector =
640               input2->AsVectorConstant()) {
641         input2_comp = input2_vector->GetComponents()[i];
642       } else {
643         assert(input2->AsNullConstant());
644         input2_comp = const_mgr->GetConstant(ele_type, {});
645       }
646 
647       if (ele_type->AsFloat()) {
648         id = PerformFloatingPointOperation(const_mgr, opcode, input1_comp,
649                                            input2_comp);
650       } else {
651         assert(ele_type->AsInteger());
652         id = PerformIntegerOperation(const_mgr, opcode, input1_comp,
653                                      input2_comp);
654       }
655       if (id == 0) return 0;
656       words.push_back(id);
657     }
658     const analysis::Constant* merged_const =
659         const_mgr->GetConstant(type, words);
660     return const_mgr->GetDefiningInstruction(merged_const)->result_id();
661   } else if (type->AsFloat()) {
662     return PerformFloatingPointOperation(const_mgr, opcode, input1, input2);
663   } else {
664     assert(type->AsInteger());
665     return PerformIntegerOperation(const_mgr, opcode, input1, input2);
666   }
667 }
668 
669 // Merges consecutive multiplies where each contains one constant operand.
670 // Cases:
671 // 2 * (x * 2) = x * 4
672 // 2 * (2 * x) = x * 4
673 // (x * 2) * 2 = x * 4
674 // (2 * x) * 2 = x * 4
MergeMulMulArithmetic()675 FoldingRule MergeMulMulArithmetic() {
676   return [](IRContext* context, Instruction* inst,
677             const std::vector<const analysis::Constant*>& constants) {
678     assert(inst->opcode() == spv::Op::OpFMul ||
679            inst->opcode() == spv::Op::OpIMul);
680     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
681     const analysis::Type* type =
682         context->get_type_mgr()->GetType(inst->type_id());
683     if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
684       return false;
685 
686     uint32_t width = ElementWidth(type);
687     if (width != 32 && width != 64) return false;
688 
689     // Determine the constant input and the variable input in |inst|.
690     const analysis::Constant* const_input1 = ConstInput(constants);
691     if (!const_input1) return false;
692     Instruction* other_inst = NonConstInput(context, constants[0], inst);
693     if (HasFloatingPoint(type) && !other_inst->IsFloatingPointFoldingAllowed())
694       return false;
695 
696     if (other_inst->opcode() == inst->opcode()) {
697       std::vector<const analysis::Constant*> other_constants =
698           const_mgr->GetOperandConstants(other_inst);
699       const analysis::Constant* const_input2 = ConstInput(other_constants);
700       if (!const_input2) return false;
701 
702       bool other_first_is_variable = other_constants[0] == nullptr;
703       uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
704                                             const_input1, const_input2);
705       if (merged_id == 0) return false;
706 
707       uint32_t non_const_id = other_first_is_variable
708                                   ? other_inst->GetSingleWordInOperand(0u)
709                                   : other_inst->GetSingleWordInOperand(1u);
710       inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
711                            {SPV_OPERAND_TYPE_ID, {merged_id}}});
712       return true;
713     }
714 
715     return false;
716   };
717 }
718 
719 // Merges divides into subsequent multiplies if each instruction contains one
720 // constant operand. Does not support integer operations.
721 // Cases:
722 // 2 * (x / 2) = x * 1
723 // 2 * (2 / x) = 4 / x
724 // (x / 2) * 2 = x * 1
725 // (2 / x) * 2 = 4 / x
726 // (y / x) * x = y
727 // x * (y / x) = y
MergeMulDivArithmetic()728 FoldingRule MergeMulDivArithmetic() {
729   return [](IRContext* context, Instruction* inst,
730             const std::vector<const analysis::Constant*>& constants) {
731     assert(inst->opcode() == spv::Op::OpFMul);
732     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
733     analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
734 
735     const analysis::Type* type =
736         context->get_type_mgr()->GetType(inst->type_id());
737     if (!inst->IsFloatingPointFoldingAllowed()) return false;
738 
739     uint32_t width = ElementWidth(type);
740     if (width != 32 && width != 64) return false;
741 
742     for (uint32_t i = 0; i < 2; i++) {
743       uint32_t op_id = inst->GetSingleWordInOperand(i);
744       Instruction* op_inst = def_use_mgr->GetDef(op_id);
745       if (op_inst->opcode() == spv::Op::OpFDiv) {
746         if (op_inst->GetSingleWordInOperand(1) ==
747             inst->GetSingleWordInOperand(1 - i)) {
748           inst->SetOpcode(spv::Op::OpCopyObject);
749           inst->SetInOperands(
750               {{SPV_OPERAND_TYPE_ID, {op_inst->GetSingleWordInOperand(0)}}});
751           return true;
752         }
753       }
754     }
755 
756     const analysis::Constant* const_input1 = ConstInput(constants);
757     if (!const_input1) return false;
758     Instruction* other_inst = NonConstInput(context, constants[0], inst);
759     if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
760 
761     if (other_inst->opcode() == spv::Op::OpFDiv) {
762       std::vector<const analysis::Constant*> other_constants =
763           const_mgr->GetOperandConstants(other_inst);
764       const analysis::Constant* const_input2 = ConstInput(other_constants);
765       if (!const_input2 || HasZero(const_input2)) return false;
766 
767       bool other_first_is_variable = other_constants[0] == nullptr;
768       // If the variable value is the second operand of the divide, multiply
769       // the constants together. Otherwise divide the constants.
770       uint32_t merged_id = PerformOperation(
771           const_mgr,
772           other_first_is_variable ? other_inst->opcode() : inst->opcode(),
773           const_input1, const_input2);
774       if (merged_id == 0) return false;
775 
776       uint32_t non_const_id = other_first_is_variable
777                                   ? other_inst->GetSingleWordInOperand(0u)
778                                   : other_inst->GetSingleWordInOperand(1u);
779 
780       // If the variable value is on the second operand of the div, then this
781       // operation is a div. Otherwise it should be a multiply.
782       inst->SetOpcode(other_first_is_variable ? inst->opcode()
783                                               : other_inst->opcode());
784       if (other_first_is_variable) {
785         inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
786                              {SPV_OPERAND_TYPE_ID, {merged_id}}});
787       } else {
788         inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {merged_id}},
789                              {SPV_OPERAND_TYPE_ID, {non_const_id}}});
790       }
791       return true;
792     }
793 
794     return false;
795   };
796 }
797 
798 // Merges multiply of constant and negation.
799 // Cases:
800 // (-x) * 2 = x * -2
801 // 2 * (-x) = x * -2
MergeMulNegateArithmetic()802 FoldingRule MergeMulNegateArithmetic() {
803   return [](IRContext* context, Instruction* inst,
804             const std::vector<const analysis::Constant*>& constants) {
805     assert(inst->opcode() == spv::Op::OpFMul ||
806            inst->opcode() == spv::Op::OpIMul);
807     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
808     const analysis::Type* type =
809         context->get_type_mgr()->GetType(inst->type_id());
810     bool uses_float = HasFloatingPoint(type);
811     if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
812 
813     uint32_t width = ElementWidth(type);
814     if (width != 32 && width != 64) return false;
815 
816     const analysis::Constant* const_input1 = ConstInput(constants);
817     if (!const_input1) return false;
818     Instruction* other_inst = NonConstInput(context, constants[0], inst);
819     if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
820       return false;
821 
822     if (other_inst->opcode() == spv::Op::OpFNegate ||
823         other_inst->opcode() == spv::Op::OpSNegate) {
824       uint32_t neg_id = NegateConstant(const_mgr, const_input1);
825 
826       inst->SetInOperands(
827           {{SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}},
828            {SPV_OPERAND_TYPE_ID, {neg_id}}});
829       return true;
830     }
831 
832     return false;
833   };
834 }
835 
836 // Merges consecutive divides if each instruction contains one constant operand.
837 // Does not support integer division.
838 // Cases:
839 // 2 / (x / 2) = 4 / x
840 // 4 / (2 / x) = 2 * x
841 // (4 / x) / 2 = 2 / x
842 // (x / 2) / 2 = x / 4
MergeDivDivArithmetic()843 FoldingRule MergeDivDivArithmetic() {
844   return [](IRContext* context, Instruction* inst,
845             const std::vector<const analysis::Constant*>& constants) {
846     assert(inst->opcode() == spv::Op::OpFDiv);
847     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
848     const analysis::Type* type =
849         context->get_type_mgr()->GetType(inst->type_id());
850     if (!inst->IsFloatingPointFoldingAllowed()) return false;
851 
852     uint32_t width = ElementWidth(type);
853     if (width != 32 && width != 64) return false;
854 
855     const analysis::Constant* const_input1 = ConstInput(constants);
856     if (!const_input1 || HasZero(const_input1)) return false;
857     Instruction* other_inst = NonConstInput(context, constants[0], inst);
858     if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
859 
860     bool first_is_variable = constants[0] == nullptr;
861     if (other_inst->opcode() == inst->opcode()) {
862       std::vector<const analysis::Constant*> other_constants =
863           const_mgr->GetOperandConstants(other_inst);
864       const analysis::Constant* const_input2 = ConstInput(other_constants);
865       if (!const_input2 || HasZero(const_input2)) return false;
866 
867       bool other_first_is_variable = other_constants[0] == nullptr;
868 
869       spv::Op merge_op = inst->opcode();
870       if (other_first_is_variable) {
871         // Constants magnify.
872         merge_op = spv::Op::OpFMul;
873       }
874 
875       // This is an x / (*) case. Swap the inputs. Doesn't harm multiply
876       // because it is commutative.
877       if (first_is_variable) std::swap(const_input1, const_input2);
878       uint32_t merged_id =
879           PerformOperation(const_mgr, merge_op, const_input1, const_input2);
880       if (merged_id == 0) return false;
881 
882       uint32_t non_const_id = other_first_is_variable
883                                   ? other_inst->GetSingleWordInOperand(0u)
884                                   : other_inst->GetSingleWordInOperand(1u);
885 
886       spv::Op op = inst->opcode();
887       if (!first_is_variable && !other_first_is_variable) {
888         // Effectively div of 1/x, so change to multiply.
889         op = spv::Op::OpFMul;
890       }
891 
892       uint32_t op1 = merged_id;
893       uint32_t op2 = non_const_id;
894       if (first_is_variable && other_first_is_variable) std::swap(op1, op2);
895       inst->SetOpcode(op);
896       inst->SetInOperands(
897           {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
898       return true;
899     }
900 
901     return false;
902   };
903 }
904 
905 // Fold multiplies succeeded by divides where each instruction contains a
906 // constant operand. Does not support integer divide.
907 // Cases:
908 // 4 / (x * 2) = 2 / x
909 // 4 / (2 * x) = 2 / x
910 // (x * 4) / 2 = x * 2
911 // (4 * x) / 2 = x * 2
912 // (x * y) / x = y
913 // (y * x) / x = y
MergeDivMulArithmetic()914 FoldingRule MergeDivMulArithmetic() {
915   return [](IRContext* context, Instruction* inst,
916             const std::vector<const analysis::Constant*>& constants) {
917     assert(inst->opcode() == spv::Op::OpFDiv);
918     analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
919     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
920 
921     const analysis::Type* type =
922         context->get_type_mgr()->GetType(inst->type_id());
923     if (!inst->IsFloatingPointFoldingAllowed()) return false;
924 
925     uint32_t width = ElementWidth(type);
926     if (width != 32 && width != 64) return false;
927 
928     uint32_t op_id = inst->GetSingleWordInOperand(0);
929     Instruction* op_inst = def_use_mgr->GetDef(op_id);
930 
931     if (op_inst->opcode() == spv::Op::OpFMul) {
932       for (uint32_t i = 0; i < 2; i++) {
933         if (op_inst->GetSingleWordInOperand(i) ==
934             inst->GetSingleWordInOperand(1)) {
935           inst->SetOpcode(spv::Op::OpCopyObject);
936           inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
937                                 {op_inst->GetSingleWordInOperand(1 - i)}}});
938           return true;
939         }
940       }
941     }
942 
943     const analysis::Constant* const_input1 = ConstInput(constants);
944     if (!const_input1 || HasZero(const_input1)) return false;
945     Instruction* other_inst = NonConstInput(context, constants[0], inst);
946     if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
947 
948     bool first_is_variable = constants[0] == nullptr;
949     if (other_inst->opcode() == spv::Op::OpFMul) {
950       std::vector<const analysis::Constant*> other_constants =
951           const_mgr->GetOperandConstants(other_inst);
952       const analysis::Constant* const_input2 = ConstInput(other_constants);
953       if (!const_input2) return false;
954 
955       bool other_first_is_variable = other_constants[0] == nullptr;
956 
957       // This is an x / (*) case. Swap the inputs.
958       if (first_is_variable) std::swap(const_input1, const_input2);
959       uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
960                                             const_input1, const_input2);
961       if (merged_id == 0) return false;
962 
963       uint32_t non_const_id = other_first_is_variable
964                                   ? other_inst->GetSingleWordInOperand(0u)
965                                   : other_inst->GetSingleWordInOperand(1u);
966 
967       uint32_t op1 = merged_id;
968       uint32_t op2 = non_const_id;
969       if (first_is_variable) std::swap(op1, op2);
970 
971       // Convert to multiply
972       if (first_is_variable) inst->SetOpcode(other_inst->opcode());
973       inst->SetInOperands(
974           {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
975       return true;
976     }
977 
978     return false;
979   };
980 }
981 
982 // Fold divides of a constant and a negation.
983 // Cases:
984 // (-x) / 2 = x / -2
985 // 2 / (-x) = -2 / x
MergeDivNegateArithmetic()986 FoldingRule MergeDivNegateArithmetic() {
987   return [](IRContext* context, Instruction* inst,
988             const std::vector<const analysis::Constant*>& constants) {
989     assert(inst->opcode() == spv::Op::OpFDiv);
990     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
991     if (!inst->IsFloatingPointFoldingAllowed()) return false;
992 
993     const analysis::Constant* const_input1 = ConstInput(constants);
994     if (!const_input1) return false;
995     Instruction* other_inst = NonConstInput(context, constants[0], inst);
996     if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
997 
998     bool first_is_variable = constants[0] == nullptr;
999     if (other_inst->opcode() == spv::Op::OpFNegate) {
1000       uint32_t neg_id = NegateConstant(const_mgr, const_input1);
1001 
1002       if (first_is_variable) {
1003         inst->SetInOperands(
1004             {{SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}},
1005              {SPV_OPERAND_TYPE_ID, {neg_id}}});
1006       } else {
1007         inst->SetInOperands(
1008             {{SPV_OPERAND_TYPE_ID, {neg_id}},
1009              {SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}});
1010       }
1011       return true;
1012     }
1013 
1014     return false;
1015   };
1016 }
1017 
1018 // Folds addition of a constant and a negation.
1019 // Cases:
1020 // (-x) + 2 = 2 - x
1021 // 2 + (-x) = 2 - x
MergeAddNegateArithmetic()1022 FoldingRule MergeAddNegateArithmetic() {
1023   return [](IRContext* context, Instruction* inst,
1024             const std::vector<const analysis::Constant*>& constants) {
1025     assert(inst->opcode() == spv::Op::OpFAdd ||
1026            inst->opcode() == spv::Op::OpIAdd);
1027     const analysis::Type* type =
1028         context->get_type_mgr()->GetType(inst->type_id());
1029     bool uses_float = HasFloatingPoint(type);
1030     if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1031 
1032     const analysis::Constant* const_input1 = ConstInput(constants);
1033     if (!const_input1) return false;
1034     Instruction* other_inst = NonConstInput(context, constants[0], inst);
1035     if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1036       return false;
1037 
1038     if (other_inst->opcode() == spv::Op::OpSNegate ||
1039         other_inst->opcode() == spv::Op::OpFNegate) {
1040       inst->SetOpcode(HasFloatingPoint(type) ? spv::Op::OpFSub
1041                                              : spv::Op::OpISub);
1042       uint32_t const_id = constants[0] ? inst->GetSingleWordInOperand(0u)
1043                                        : inst->GetSingleWordInOperand(1u);
1044       inst->SetInOperands(
1045           {{SPV_OPERAND_TYPE_ID, {const_id}},
1046            {SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}});
1047       return true;
1048     }
1049     return false;
1050   };
1051 }
1052 
1053 // Folds subtraction of a constant and a negation.
1054 // Cases:
1055 // (-x) - 2 = -2 - x
1056 // 2 - (-x) = x + 2
MergeSubNegateArithmetic()1057 FoldingRule MergeSubNegateArithmetic() {
1058   return [](IRContext* context, Instruction* inst,
1059             const std::vector<const analysis::Constant*>& constants) {
1060     assert(inst->opcode() == spv::Op::OpFSub ||
1061            inst->opcode() == spv::Op::OpISub);
1062     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1063     const analysis::Type* type =
1064         context->get_type_mgr()->GetType(inst->type_id());
1065     bool uses_float = HasFloatingPoint(type);
1066     if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1067 
1068     uint32_t width = ElementWidth(type);
1069     if (width != 32 && width != 64) return false;
1070 
1071     const analysis::Constant* const_input1 = ConstInput(constants);
1072     if (!const_input1) return false;
1073     Instruction* other_inst = NonConstInput(context, constants[0], inst);
1074     if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1075       return false;
1076 
1077     if (other_inst->opcode() == spv::Op::OpSNegate ||
1078         other_inst->opcode() == spv::Op::OpFNegate) {
1079       uint32_t op1 = 0;
1080       uint32_t op2 = 0;
1081       spv::Op opcode = inst->opcode();
1082       if (constants[0] != nullptr) {
1083         op1 = other_inst->GetSingleWordInOperand(0u);
1084         op2 = inst->GetSingleWordInOperand(0u);
1085         opcode = HasFloatingPoint(type) ? spv::Op::OpFAdd : spv::Op::OpIAdd;
1086       } else {
1087         op1 = NegateConstant(const_mgr, const_input1);
1088         op2 = other_inst->GetSingleWordInOperand(0u);
1089       }
1090 
1091       inst->SetOpcode(opcode);
1092       inst->SetInOperands(
1093           {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
1094       return true;
1095     }
1096     return false;
1097   };
1098 }
1099 
1100 // Folds addition of an addition where each operation has a constant operand.
1101 // Cases:
1102 // (x + 2) + 2 = x + 4
1103 // (2 + x) + 2 = x + 4
1104 // 2 + (x + 2) = x + 4
1105 // 2 + (2 + x) = x + 4
MergeAddAddArithmetic()1106 FoldingRule MergeAddAddArithmetic() {
1107   return [](IRContext* context, Instruction* inst,
1108             const std::vector<const analysis::Constant*>& constants) {
1109     assert(inst->opcode() == spv::Op::OpFAdd ||
1110            inst->opcode() == spv::Op::OpIAdd);
1111     const analysis::Type* type =
1112         context->get_type_mgr()->GetType(inst->type_id());
1113     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1114     bool uses_float = HasFloatingPoint(type);
1115     if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1116 
1117     uint32_t width = ElementWidth(type);
1118     if (width != 32 && width != 64) return false;
1119 
1120     const analysis::Constant* const_input1 = ConstInput(constants);
1121     if (!const_input1) return false;
1122     Instruction* other_inst = NonConstInput(context, constants[0], inst);
1123     if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1124       return false;
1125 
1126     if (other_inst->opcode() == spv::Op::OpFAdd ||
1127         other_inst->opcode() == spv::Op::OpIAdd) {
1128       std::vector<const analysis::Constant*> other_constants =
1129           const_mgr->GetOperandConstants(other_inst);
1130       const analysis::Constant* const_input2 = ConstInput(other_constants);
1131       if (!const_input2) return false;
1132 
1133       Instruction* non_const_input =
1134           NonConstInput(context, other_constants[0], other_inst);
1135       uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
1136                                             const_input1, const_input2);
1137       if (merged_id == 0) return false;
1138 
1139       inst->SetInOperands(
1140           {{SPV_OPERAND_TYPE_ID, {non_const_input->result_id()}},
1141            {SPV_OPERAND_TYPE_ID, {merged_id}}});
1142       return true;
1143     }
1144     return false;
1145   };
1146 }
1147 
1148 // Folds addition of a subtraction where each operation has a constant operand.
1149 // Cases:
1150 // (x - 2) + 2 = x + 0
1151 // (2 - x) + 2 = 4 - x
1152 // 2 + (x - 2) = x + 0
1153 // 2 + (2 - x) = 4 - x
MergeAddSubArithmetic()1154 FoldingRule MergeAddSubArithmetic() {
1155   return [](IRContext* context, Instruction* inst,
1156             const std::vector<const analysis::Constant*>& constants) {
1157     assert(inst->opcode() == spv::Op::OpFAdd ||
1158            inst->opcode() == spv::Op::OpIAdd);
1159     const analysis::Type* type =
1160         context->get_type_mgr()->GetType(inst->type_id());
1161     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1162     bool uses_float = HasFloatingPoint(type);
1163     if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1164 
1165     uint32_t width = ElementWidth(type);
1166     if (width != 32 && width != 64) return false;
1167 
1168     const analysis::Constant* const_input1 = ConstInput(constants);
1169     if (!const_input1) return false;
1170     Instruction* other_inst = NonConstInput(context, constants[0], inst);
1171     if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1172       return false;
1173 
1174     if (other_inst->opcode() == spv::Op::OpFSub ||
1175         other_inst->opcode() == spv::Op::OpISub) {
1176       std::vector<const analysis::Constant*> other_constants =
1177           const_mgr->GetOperandConstants(other_inst);
1178       const analysis::Constant* const_input2 = ConstInput(other_constants);
1179       if (!const_input2) return false;
1180 
1181       bool first_is_variable = other_constants[0] == nullptr;
1182       spv::Op op = inst->opcode();
1183       uint32_t op1 = 0;
1184       uint32_t op2 = 0;
1185       if (first_is_variable) {
1186         // Subtract constants. Non-constant operand is first.
1187         op1 = other_inst->GetSingleWordInOperand(0u);
1188         op2 = PerformOperation(const_mgr, other_inst->opcode(), const_input1,
1189                                const_input2);
1190       } else {
1191         // Add constants. Constant operand is first. Change the opcode.
1192         op1 = PerformOperation(const_mgr, inst->opcode(), const_input1,
1193                                const_input2);
1194         op2 = other_inst->GetSingleWordInOperand(1u);
1195         op = other_inst->opcode();
1196       }
1197       if (op1 == 0 || op2 == 0) return false;
1198 
1199       inst->SetOpcode(op);
1200       inst->SetInOperands(
1201           {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
1202       return true;
1203     }
1204     return false;
1205   };
1206 }
1207 
1208 // Folds subtraction of an addition where each operand has a constant operand.
1209 // Cases:
1210 // (x + 2) - 2 = x + 0
1211 // (2 + x) - 2 = x + 0
1212 // 2 - (x + 2) = 0 - x
1213 // 2 - (2 + x) = 0 - x
MergeSubAddArithmetic()1214 FoldingRule MergeSubAddArithmetic() {
1215   return [](IRContext* context, Instruction* inst,
1216             const std::vector<const analysis::Constant*>& constants) {
1217     assert(inst->opcode() == spv::Op::OpFSub ||
1218            inst->opcode() == spv::Op::OpISub);
1219     const analysis::Type* type =
1220         context->get_type_mgr()->GetType(inst->type_id());
1221     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1222     bool uses_float = HasFloatingPoint(type);
1223     if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1224 
1225     uint32_t width = ElementWidth(type);
1226     if (width != 32 && width != 64) return false;
1227 
1228     const analysis::Constant* const_input1 = ConstInput(constants);
1229     if (!const_input1) return false;
1230     Instruction* other_inst = NonConstInput(context, constants[0], inst);
1231     if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1232       return false;
1233 
1234     if (other_inst->opcode() == spv::Op::OpFAdd ||
1235         other_inst->opcode() == spv::Op::OpIAdd) {
1236       std::vector<const analysis::Constant*> other_constants =
1237           const_mgr->GetOperandConstants(other_inst);
1238       const analysis::Constant* const_input2 = ConstInput(other_constants);
1239       if (!const_input2) return false;
1240 
1241       Instruction* non_const_input =
1242           NonConstInput(context, other_constants[0], other_inst);
1243 
1244       // If the first operand of the sub is not a constant, swap the constants
1245       // so the subtraction has the correct operands.
1246       if (constants[0] == nullptr) std::swap(const_input1, const_input2);
1247       // Subtract the constants.
1248       uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
1249                                             const_input1, const_input2);
1250       spv::Op op = inst->opcode();
1251       uint32_t op1 = 0;
1252       uint32_t op2 = 0;
1253       if (constants[0] == nullptr) {
1254         // Non-constant operand is first. Change the opcode.
1255         op1 = non_const_input->result_id();
1256         op2 = merged_id;
1257         op = other_inst->opcode();
1258       } else {
1259         // Constant operand is first.
1260         op1 = merged_id;
1261         op2 = non_const_input->result_id();
1262       }
1263       if (op1 == 0 || op2 == 0) return false;
1264 
1265       inst->SetOpcode(op);
1266       inst->SetInOperands(
1267           {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
1268       return true;
1269     }
1270     return false;
1271   };
1272 }
1273 
1274 // Folds subtraction of a subtraction where each operand has a constant operand.
1275 // Cases:
1276 // (x - 2) - 2 = x - 4
1277 // (2 - x) - 2 = 0 - x
1278 // 2 - (x - 2) = 4 - x
1279 // 2 - (2 - x) = x + 0
MergeSubSubArithmetic()1280 FoldingRule MergeSubSubArithmetic() {
1281   return [](IRContext* context, Instruction* inst,
1282             const std::vector<const analysis::Constant*>& constants) {
1283     assert(inst->opcode() == spv::Op::OpFSub ||
1284            inst->opcode() == spv::Op::OpISub);
1285     const analysis::Type* type =
1286         context->get_type_mgr()->GetType(inst->type_id());
1287     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1288     bool uses_float = HasFloatingPoint(type);
1289     if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1290 
1291     uint32_t width = ElementWidth(type);
1292     if (width != 32 && width != 64) return false;
1293 
1294     const analysis::Constant* const_input1 = ConstInput(constants);
1295     if (!const_input1) return false;
1296     Instruction* other_inst = NonConstInput(context, constants[0], inst);
1297     if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1298       return false;
1299 
1300     if (other_inst->opcode() == spv::Op::OpFSub ||
1301         other_inst->opcode() == spv::Op::OpISub) {
1302       std::vector<const analysis::Constant*> other_constants =
1303           const_mgr->GetOperandConstants(other_inst);
1304       const analysis::Constant* const_input2 = ConstInput(other_constants);
1305       if (!const_input2) return false;
1306 
1307       Instruction* non_const_input =
1308           NonConstInput(context, other_constants[0], other_inst);
1309 
1310       // Merge the constants.
1311       uint32_t merged_id = 0;
1312       spv::Op merge_op = inst->opcode();
1313       if (other_constants[0] == nullptr) {
1314         merge_op = uses_float ? spv::Op::OpFAdd : spv::Op::OpIAdd;
1315       } else if (constants[0] == nullptr) {
1316         std::swap(const_input1, const_input2);
1317       }
1318       merged_id =
1319           PerformOperation(const_mgr, merge_op, const_input1, const_input2);
1320       if (merged_id == 0) return false;
1321 
1322       spv::Op op = inst->opcode();
1323       if (constants[0] != nullptr && other_constants[0] != nullptr) {
1324         // Change the operation.
1325         op = uses_float ? spv::Op::OpFAdd : spv::Op::OpIAdd;
1326       }
1327 
1328       uint32_t op1 = 0;
1329       uint32_t op2 = 0;
1330       if ((constants[0] == nullptr) ^ (other_constants[0] == nullptr)) {
1331         op1 = merged_id;
1332         op2 = non_const_input->result_id();
1333       } else {
1334         op1 = non_const_input->result_id();
1335         op2 = merged_id;
1336       }
1337 
1338       inst->SetOpcode(op);
1339       inst->SetInOperands(
1340           {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
1341       return true;
1342     }
1343     return false;
1344   };
1345 }
1346 
1347 // Helper function for MergeGenericAddSubArithmetic. If |addend| and
1348 // subtrahend of |sub| is the same, merge to copy of minuend of |sub|.
MergeGenericAddendSub(uint32_t addend,uint32_t sub,Instruction * inst)1349 bool MergeGenericAddendSub(uint32_t addend, uint32_t sub, Instruction* inst) {
1350   IRContext* context = inst->context();
1351   analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1352   Instruction* sub_inst = def_use_mgr->GetDef(sub);
1353   if (sub_inst->opcode() != spv::Op::OpFSub &&
1354       sub_inst->opcode() != spv::Op::OpISub)
1355     return false;
1356   if (sub_inst->opcode() == spv::Op::OpFSub &&
1357       !sub_inst->IsFloatingPointFoldingAllowed())
1358     return false;
1359   if (addend != sub_inst->GetSingleWordInOperand(1)) return false;
1360   inst->SetOpcode(spv::Op::OpCopyObject);
1361   inst->SetInOperands(
1362       {{SPV_OPERAND_TYPE_ID, {sub_inst->GetSingleWordInOperand(0)}}});
1363   context->UpdateDefUse(inst);
1364   return true;
1365 }
1366 
1367 // Folds addition of a subtraction where the subtrahend is equal to the
1368 // other addend. Return a copy of the minuend. Accepts generic (const and
1369 // non-const) operands.
1370 // Cases:
1371 // (a - b) + b = a
1372 // b + (a - b) = a
MergeGenericAddSubArithmetic()1373 FoldingRule MergeGenericAddSubArithmetic() {
1374   return [](IRContext* context, Instruction* inst,
1375             const std::vector<const analysis::Constant*>&) {
1376     assert(inst->opcode() == spv::Op::OpFAdd ||
1377            inst->opcode() == spv::Op::OpIAdd);
1378     const analysis::Type* type =
1379         context->get_type_mgr()->GetType(inst->type_id());
1380     bool uses_float = HasFloatingPoint(type);
1381     if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1382 
1383     uint32_t width = ElementWidth(type);
1384     if (width != 32 && width != 64) return false;
1385 
1386     uint32_t add_op0 = inst->GetSingleWordInOperand(0);
1387     uint32_t add_op1 = inst->GetSingleWordInOperand(1);
1388     if (MergeGenericAddendSub(add_op0, add_op1, inst)) return true;
1389     return MergeGenericAddendSub(add_op1, add_op0, inst);
1390   };
1391 }
1392 
1393 // Helper function for FactorAddMuls. If |factor0_0| is the same as |factor1_0|,
1394 // generate |factor0_0| * (|factor0_1| + |factor1_1|).
FactorAddMulsOpnds(uint32_t factor0_0,uint32_t factor0_1,uint32_t factor1_0,uint32_t factor1_1,Instruction * inst)1395 bool FactorAddMulsOpnds(uint32_t factor0_0, uint32_t factor0_1,
1396                         uint32_t factor1_0, uint32_t factor1_1,
1397                         Instruction* inst) {
1398   IRContext* context = inst->context();
1399   if (factor0_0 != factor1_0) return false;
1400   InstructionBuilder ir_builder(
1401       context, inst,
1402       IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
1403   Instruction* new_add_inst = ir_builder.AddBinaryOp(
1404       inst->type_id(), inst->opcode(), factor0_1, factor1_1);
1405   inst->SetOpcode(inst->opcode() == spv::Op::OpFAdd ? spv::Op::OpFMul
1406                                                     : spv::Op::OpIMul);
1407   inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {factor0_0}},
1408                        {SPV_OPERAND_TYPE_ID, {new_add_inst->result_id()}}});
1409   context->UpdateDefUse(inst);
1410   return true;
1411 }
1412 
1413 // Perform the following factoring identity, handling all operand order
1414 // combinations: (a * b) + (a * c) = a * (b + c)
FactorAddMuls()1415 FoldingRule FactorAddMuls() {
1416   return [](IRContext* context, Instruction* inst,
1417             const std::vector<const analysis::Constant*>&) {
1418     assert(inst->opcode() == spv::Op::OpFAdd ||
1419            inst->opcode() == spv::Op::OpIAdd);
1420     const analysis::Type* type =
1421         context->get_type_mgr()->GetType(inst->type_id());
1422     bool uses_float = HasFloatingPoint(type);
1423     if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1424 
1425     analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1426     uint32_t add_op0 = inst->GetSingleWordInOperand(0);
1427     Instruction* add_op0_inst = def_use_mgr->GetDef(add_op0);
1428     if (add_op0_inst->opcode() != spv::Op::OpFMul &&
1429         add_op0_inst->opcode() != spv::Op::OpIMul)
1430       return false;
1431     uint32_t add_op1 = inst->GetSingleWordInOperand(1);
1432     Instruction* add_op1_inst = def_use_mgr->GetDef(add_op1);
1433     if (add_op1_inst->opcode() != spv::Op::OpFMul &&
1434         add_op1_inst->opcode() != spv::Op::OpIMul)
1435       return false;
1436 
1437     // Only perform this optimization if both of the muls only have one use.
1438     // Otherwise this is a deoptimization in size and performance.
1439     if (def_use_mgr->NumUses(add_op0_inst) > 1) return false;
1440     if (def_use_mgr->NumUses(add_op1_inst) > 1) return false;
1441 
1442     if (add_op0_inst->opcode() == spv::Op::OpFMul &&
1443         (!add_op0_inst->IsFloatingPointFoldingAllowed() ||
1444          !add_op1_inst->IsFloatingPointFoldingAllowed()))
1445       return false;
1446 
1447     for (int i = 0; i < 2; i++) {
1448       for (int j = 0; j < 2; j++) {
1449         // Check if operand i in add_op0_inst matches operand j in add_op1_inst.
1450         if (FactorAddMulsOpnds(add_op0_inst->GetSingleWordInOperand(i),
1451                                add_op0_inst->GetSingleWordInOperand(1 - i),
1452                                add_op1_inst->GetSingleWordInOperand(j),
1453                                add_op1_inst->GetSingleWordInOperand(1 - j),
1454                                inst))
1455           return true;
1456       }
1457     }
1458     return false;
1459   };
1460 }
1461 
1462 // Replaces |inst| inplace with an FMA instruction |(x*y)+a|.
ReplaceWithFma(Instruction * inst,uint32_t x,uint32_t y,uint32_t a)1463 void ReplaceWithFma(Instruction* inst, uint32_t x, uint32_t y, uint32_t a) {
1464   uint32_t ext =
1465       inst->context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
1466 
1467   if (ext == 0) {
1468     inst->context()->AddExtInstImport("GLSL.std.450");
1469     ext = inst->context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
1470     assert(ext != 0 &&
1471            "Could not add the GLSL.std.450 extended instruction set");
1472   }
1473 
1474   std::vector<Operand> operands;
1475   operands.push_back({SPV_OPERAND_TYPE_ID, {ext}});
1476   operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {GLSLstd450Fma}});
1477   operands.push_back({SPV_OPERAND_TYPE_ID, {x}});
1478   operands.push_back({SPV_OPERAND_TYPE_ID, {y}});
1479   operands.push_back({SPV_OPERAND_TYPE_ID, {a}});
1480 
1481   inst->SetOpcode(spv::Op::OpExtInst);
1482   inst->SetInOperands(std::move(operands));
1483 }
1484 
1485 // Folds a multiple and add into an Fma.
1486 //
1487 // Cases:
1488 // (x * y) + a = Fma x y a
1489 // a + (x * y) = Fma x y a
MergeMulAddArithmetic(IRContext * context,Instruction * inst,const std::vector<const analysis::Constant * > &)1490 bool MergeMulAddArithmetic(IRContext* context, Instruction* inst,
1491                            const std::vector<const analysis::Constant*>&) {
1492   assert(inst->opcode() == spv::Op::OpFAdd);
1493 
1494   if (!inst->IsFloatingPointFoldingAllowed()) {
1495     return false;
1496   }
1497 
1498   analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1499   for (int i = 0; i < 2; i++) {
1500     uint32_t op_id = inst->GetSingleWordInOperand(i);
1501     Instruction* op_inst = def_use_mgr->GetDef(op_id);
1502 
1503     if (op_inst->opcode() != spv::Op::OpFMul) {
1504       continue;
1505     }
1506 
1507     if (!op_inst->IsFloatingPointFoldingAllowed()) {
1508       continue;
1509     }
1510 
1511     uint32_t x = op_inst->GetSingleWordInOperand(0);
1512     uint32_t y = op_inst->GetSingleWordInOperand(1);
1513     uint32_t a = inst->GetSingleWordInOperand((i + 1) % 2);
1514     ReplaceWithFma(inst, x, y, a);
1515     return true;
1516   }
1517   return false;
1518 }
1519 
1520 // Replaces |sub| inplace with an FMA instruction |(x*y)+a| where |a| first gets
1521 // negated if |negate_addition| is true, otherwise |x| gets negated.
ReplaceWithFmaAndNegate(Instruction * sub,uint32_t x,uint32_t y,uint32_t a,bool negate_addition)1522 void ReplaceWithFmaAndNegate(Instruction* sub, uint32_t x, uint32_t y,
1523                              uint32_t a, bool negate_addition) {
1524   uint32_t ext =
1525       sub->context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
1526 
1527   if (ext == 0) {
1528     sub->context()->AddExtInstImport("GLSL.std.450");
1529     ext = sub->context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
1530     assert(ext != 0 &&
1531            "Could not add the GLSL.std.450 extended instruction set");
1532   }
1533 
1534   InstructionBuilder ir_builder(
1535       sub->context(), sub,
1536       IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
1537 
1538   Instruction* neg = ir_builder.AddUnaryOp(sub->type_id(), spv::Op::OpFNegate,
1539                                            negate_addition ? a : x);
1540   uint32_t neg_op = neg->result_id();  // -a : -x
1541 
1542   std::vector<Operand> operands;
1543   operands.push_back({SPV_OPERAND_TYPE_ID, {ext}});
1544   operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {GLSLstd450Fma}});
1545   operands.push_back({SPV_OPERAND_TYPE_ID, {negate_addition ? x : neg_op}});
1546   operands.push_back({SPV_OPERAND_TYPE_ID, {y}});
1547   operands.push_back({SPV_OPERAND_TYPE_ID, {negate_addition ? neg_op : a}});
1548 
1549   sub->SetOpcode(spv::Op::OpExtInst);
1550   sub->SetInOperands(std::move(operands));
1551 }
1552 
1553 // Folds a multiply and subtract into an Fma and negation.
1554 //
1555 // Cases:
1556 // (x * y) - a = Fma x y -a
1557 // a - (x * y) = Fma -x y a
MergeMulSubArithmetic(IRContext * context,Instruction * sub,const std::vector<const analysis::Constant * > &)1558 bool MergeMulSubArithmetic(IRContext* context, Instruction* sub,
1559                            const std::vector<const analysis::Constant*>&) {
1560   assert(sub->opcode() == spv::Op::OpFSub);
1561 
1562   if (!sub->IsFloatingPointFoldingAllowed()) {
1563     return false;
1564   }
1565 
1566   analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1567   for (int i = 0; i < 2; i++) {
1568     uint32_t op_id = sub->GetSingleWordInOperand(i);
1569     Instruction* mul = def_use_mgr->GetDef(op_id);
1570 
1571     if (mul->opcode() != spv::Op::OpFMul) {
1572       continue;
1573     }
1574 
1575     if (!mul->IsFloatingPointFoldingAllowed()) {
1576       continue;
1577     }
1578 
1579     uint32_t x = mul->GetSingleWordInOperand(0);
1580     uint32_t y = mul->GetSingleWordInOperand(1);
1581     uint32_t a = sub->GetSingleWordInOperand((i + 1) % 2);
1582     ReplaceWithFmaAndNegate(sub, x, y, a, i == 0);
1583     return true;
1584   }
1585   return false;
1586 }
1587 
IntMultipleBy1()1588 FoldingRule IntMultipleBy1() {
1589   return [](IRContext*, Instruction* inst,
1590             const std::vector<const analysis::Constant*>& constants) {
1591     assert(inst->opcode() == spv::Op::OpIMul &&
1592            "Wrong opcode.  Should be OpIMul.");
1593     for (uint32_t i = 0; i < 2; i++) {
1594       if (constants[i] == nullptr) {
1595         continue;
1596       }
1597       const analysis::IntConstant* int_constant = constants[i]->AsIntConstant();
1598       if (int_constant) {
1599         uint32_t width = ElementWidth(int_constant->type());
1600         if (width != 32 && width != 64) return false;
1601         bool is_one = (width == 32) ? int_constant->GetU32BitValue() == 1u
1602                                     : int_constant->GetU64BitValue() == 1ull;
1603         if (is_one) {
1604           inst->SetOpcode(spv::Op::OpCopyObject);
1605           inst->SetInOperands(
1606               {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1 - i)}}});
1607           return true;
1608         }
1609       }
1610     }
1611     return false;
1612   };
1613 }
1614 
1615 // Returns the number of elements that the |index|th in operand in |inst|
1616 // contributes to the result of |inst|.  |inst| must be an
1617 // OpCompositeConstructInstruction.
GetNumOfElementsContributedByOperand(IRContext * context,const Instruction * inst,uint32_t index)1618 uint32_t GetNumOfElementsContributedByOperand(IRContext* context,
1619                                               const Instruction* inst,
1620                                               uint32_t index) {
1621   assert(inst->opcode() == spv::Op::OpCompositeConstruct);
1622   analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1623   analysis::TypeManager* type_mgr = context->get_type_mgr();
1624 
1625   analysis::Vector* result_type =
1626       type_mgr->GetType(inst->type_id())->AsVector();
1627   if (result_type == nullptr) {
1628     // If the result of the OpCompositeConstruct is not a vector then every
1629     // operands corresponds to a single element in the result.
1630     return 1;
1631   }
1632 
1633   // If the result type is a vector then the operands are either scalars or
1634   // vectors. If it is a scalar, then it corresponds to a single element.  If it
1635   // is a vector, then each element in the vector will be an element in the
1636   // result.
1637   uint32_t id = inst->GetSingleWordInOperand(index);
1638   Instruction* def = def_use_mgr->GetDef(id);
1639   analysis::Vector* type = type_mgr->GetType(def->type_id())->AsVector();
1640   if (type == nullptr) {
1641     return 1;
1642   }
1643   return type->element_count();
1644 }
1645 
1646 // Returns the in-operands for an OpCompositeExtract instruction that are needed
1647 // to extract the |result_index|th element in the result of |inst| without using
1648 // the result of |inst|. Returns the empty vector if |result_index| is
1649 // out-of-bounds. |inst| must be an |OpCompositeConstruct| instruction.
GetExtractOperandsForElementOfCompositeConstruct(IRContext * context,const Instruction * inst,uint32_t result_index)1650 std::vector<Operand> GetExtractOperandsForElementOfCompositeConstruct(
1651     IRContext* context, const Instruction* inst, uint32_t result_index) {
1652   assert(inst->opcode() == spv::Op::OpCompositeConstruct);
1653   analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1654   analysis::TypeManager* type_mgr = context->get_type_mgr();
1655 
1656   analysis::Type* result_type = type_mgr->GetType(inst->type_id());
1657   if (result_type->AsVector() == nullptr) {
1658     if (result_index < inst->NumInOperands()) {
1659       uint32_t id = inst->GetSingleWordInOperand(result_index);
1660       return {Operand(SPV_OPERAND_TYPE_ID, {id})};
1661     }
1662     return {};
1663   }
1664 
1665   // If the result type is a vector, then vector operands are concatenated.
1666   uint32_t total_element_count = 0;
1667   for (uint32_t idx = 0; idx < inst->NumInOperands(); ++idx) {
1668     uint32_t element_count =
1669         GetNumOfElementsContributedByOperand(context, inst, idx);
1670     total_element_count += element_count;
1671     if (result_index < total_element_count) {
1672       std::vector<Operand> operands;
1673       uint32_t id = inst->GetSingleWordInOperand(idx);
1674       Instruction* operand_def = def_use_mgr->GetDef(id);
1675       analysis::Type* operand_type = type_mgr->GetType(operand_def->type_id());
1676 
1677       operands.push_back({SPV_OPERAND_TYPE_ID, {id}});
1678       if (operand_type->AsVector()) {
1679         uint32_t start_index_of_id = total_element_count - element_count;
1680         uint32_t index_into_id = result_index - start_index_of_id;
1681         operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {index_into_id}});
1682       }
1683       return operands;
1684     }
1685   }
1686   return {};
1687 }
1688 
CompositeConstructFeedingExtract(IRContext * context,Instruction * inst,const std::vector<const analysis::Constant * > &)1689 bool CompositeConstructFeedingExtract(
1690     IRContext* context, Instruction* inst,
1691     const std::vector<const analysis::Constant*>&) {
1692   // If the input to an OpCompositeExtract is an OpCompositeConstruct,
1693   // then we can simply use the appropriate element in the construction.
1694   assert(inst->opcode() == spv::Op::OpCompositeExtract &&
1695          "Wrong opcode.  Should be OpCompositeExtract.");
1696   analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1697 
1698   // If there are no index operands, then this rule cannot do anything.
1699   if (inst->NumInOperands() <= 1) {
1700     return false;
1701   }
1702 
1703   uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1704   Instruction* cinst = def_use_mgr->GetDef(cid);
1705 
1706   if (cinst->opcode() != spv::Op::OpCompositeConstruct) {
1707     return false;
1708   }
1709 
1710   uint32_t index_into_result = inst->GetSingleWordInOperand(1);
1711   std::vector<Operand> operands =
1712       GetExtractOperandsForElementOfCompositeConstruct(context, cinst,
1713                                                        index_into_result);
1714 
1715   if (operands.empty()) {
1716     return false;
1717   }
1718 
1719   // Add the remaining indices for extraction.
1720   for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
1721     operands.push_back(
1722         {SPV_OPERAND_TYPE_LITERAL_INTEGER, {inst->GetSingleWordInOperand(i)}});
1723   }
1724 
1725   if (operands.size() == 1) {
1726     // If there were no extra indices, then we have the final object.  No need
1727     // to extract any more.
1728     inst->SetOpcode(spv::Op::OpCopyObject);
1729   }
1730 
1731   inst->SetInOperands(std::move(operands));
1732   return true;
1733 }
1734 
1735 // Walks the indexes chain from |start| to |end| of an OpCompositeInsert or
1736 // OpCompositeExtract instruction, and returns the type of the final element
1737 // being accessed.
GetElementType(uint32_t type_id,Instruction::iterator start,Instruction::iterator end,const analysis::TypeManager * type_mgr)1738 const analysis::Type* GetElementType(uint32_t type_id,
1739                                      Instruction::iterator start,
1740                                      Instruction::iterator end,
1741                                      const analysis::TypeManager* type_mgr) {
1742   const analysis::Type* type = type_mgr->GetType(type_id);
1743   for (auto index : make_range(std::move(start), std::move(end))) {
1744     assert(index.type == SPV_OPERAND_TYPE_LITERAL_INTEGER &&
1745            index.words.size() == 1);
1746     if (auto* array_type = type->AsArray()) {
1747       type = array_type->element_type();
1748     } else if (auto* matrix_type = type->AsMatrix()) {
1749       type = matrix_type->element_type();
1750     } else if (auto* struct_type = type->AsStruct()) {
1751       type = struct_type->element_types()[index.words[0]];
1752     } else {
1753       type = nullptr;
1754     }
1755   }
1756   return type;
1757 }
1758 
1759 // Returns true of |inst_1| and |inst_2| have the same indexes that will be used
1760 // to index into a composite object, excluding the last index.  The two
1761 // instructions must have the same opcode, and be either OpCompositeExtract or
1762 // OpCompositeInsert instructions.
HaveSameIndexesExceptForLast(Instruction * inst_1,Instruction * inst_2)1763 bool HaveSameIndexesExceptForLast(Instruction* inst_1, Instruction* inst_2) {
1764   assert(inst_1->opcode() == inst_2->opcode() &&
1765          "Expecting the opcodes to be the same.");
1766   assert((inst_1->opcode() == spv::Op::OpCompositeInsert ||
1767           inst_1->opcode() == spv::Op::OpCompositeExtract) &&
1768          "Instructions must be OpCompositeInsert or OpCompositeExtract.");
1769 
1770   if (inst_1->NumInOperands() != inst_2->NumInOperands()) {
1771     return false;
1772   }
1773 
1774   uint32_t first_index_position =
1775       (inst_1->opcode() == spv::Op::OpCompositeInsert ? 2 : 1);
1776   for (uint32_t i = first_index_position; i < inst_1->NumInOperands() - 1;
1777        i++) {
1778     if (inst_1->GetSingleWordInOperand(i) !=
1779         inst_2->GetSingleWordInOperand(i)) {
1780       return false;
1781     }
1782   }
1783   return true;
1784 }
1785 
1786 // If the OpCompositeConstruct is simply putting back together elements that
1787 // where extracted from the same source, we can simply reuse the source.
1788 //
1789 // This is a common code pattern because of the way that scalar replacement
1790 // works.
CompositeExtractFeedingConstruct(IRContext * context,Instruction * inst,const std::vector<const analysis::Constant * > &)1791 bool CompositeExtractFeedingConstruct(
1792     IRContext* context, Instruction* inst,
1793     const std::vector<const analysis::Constant*>&) {
1794   assert(inst->opcode() == spv::Op::OpCompositeConstruct &&
1795          "Wrong opcode.  Should be OpCompositeConstruct.");
1796   analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1797   uint32_t original_id = 0;
1798 
1799   if (inst->NumInOperands() == 0) {
1800     // The struct being constructed has no members.
1801     return false;
1802   }
1803 
1804   // Check each element to make sure they are:
1805   // - extractions
1806   // - extracting the same position they are inserting
1807   // - all extract from the same id.
1808   Instruction* first_element_inst = nullptr;
1809   for (uint32_t i = 0; i < inst->NumInOperands(); ++i) {
1810     const uint32_t element_id = inst->GetSingleWordInOperand(i);
1811     Instruction* element_inst = def_use_mgr->GetDef(element_id);
1812     if (first_element_inst == nullptr) {
1813       first_element_inst = element_inst;
1814     }
1815 
1816     if (element_inst->opcode() != spv::Op::OpCompositeExtract) {
1817       return false;
1818     }
1819 
1820     if (!HaveSameIndexesExceptForLast(element_inst, first_element_inst)) {
1821       return false;
1822     }
1823 
1824     if (element_inst->GetSingleWordInOperand(element_inst->NumInOperands() -
1825                                              1) != i) {
1826       return false;
1827     }
1828 
1829     if (i == 0) {
1830       original_id =
1831           element_inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1832     } else if (original_id !=
1833                element_inst->GetSingleWordInOperand(kExtractCompositeIdInIdx)) {
1834       return false;
1835     }
1836   }
1837 
1838   // The last check it to see that the object being extracted from is the
1839   // correct type.
1840   Instruction* original_inst = def_use_mgr->GetDef(original_id);
1841   analysis::TypeManager* type_mgr = context->get_type_mgr();
1842   const analysis::Type* original_type =
1843       GetElementType(original_inst->type_id(), first_element_inst->begin() + 3,
1844                      first_element_inst->end() - 1, type_mgr);
1845 
1846   if (original_type == nullptr) {
1847     return false;
1848   }
1849 
1850   if (inst->type_id() != type_mgr->GetId(original_type)) {
1851     return false;
1852   }
1853 
1854   if (first_element_inst->NumInOperands() == 2) {
1855     // Simplify by using the original object.
1856     inst->SetOpcode(spv::Op::OpCopyObject);
1857     inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {original_id}}});
1858     return true;
1859   }
1860 
1861   // Copies the original id and all indexes except for the last to the new
1862   // extract instruction.
1863   inst->SetOpcode(spv::Op::OpCompositeExtract);
1864   inst->SetInOperands(std::vector<Operand>(first_element_inst->begin() + 2,
1865                                            first_element_inst->end() - 1));
1866   return true;
1867 }
1868 
InsertFeedingExtract()1869 FoldingRule InsertFeedingExtract() {
1870   return [](IRContext* context, Instruction* inst,
1871             const std::vector<const analysis::Constant*>&) {
1872     assert(inst->opcode() == spv::Op::OpCompositeExtract &&
1873            "Wrong opcode.  Should be OpCompositeExtract.");
1874     analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1875     uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1876     Instruction* cinst = def_use_mgr->GetDef(cid);
1877 
1878     if (cinst->opcode() != spv::Op::OpCompositeInsert) {
1879       return false;
1880     }
1881 
1882     // Find the first position where the list of insert and extract indicies
1883     // differ, if at all.
1884     uint32_t i;
1885     for (i = 1; i < inst->NumInOperands(); ++i) {
1886       if (i + 1 >= cinst->NumInOperands()) {
1887         break;
1888       }
1889 
1890       if (inst->GetSingleWordInOperand(i) !=
1891           cinst->GetSingleWordInOperand(i + 1)) {
1892         break;
1893       }
1894     }
1895 
1896     // We are extracting the element that was inserted.
1897     if (i == inst->NumInOperands() && i + 1 == cinst->NumInOperands()) {
1898       inst->SetOpcode(spv::Op::OpCopyObject);
1899       inst->SetInOperands(
1900           {{SPV_OPERAND_TYPE_ID,
1901             {cinst->GetSingleWordInOperand(kInsertObjectIdInIdx)}}});
1902       return true;
1903     }
1904 
1905     // Extracting the value that was inserted along with values for the base
1906     // composite.  Cannot do anything.
1907     if (i == inst->NumInOperands()) {
1908       return false;
1909     }
1910 
1911     // Extracting an element of the value that was inserted.  Extract from
1912     // that value directly.
1913     if (i + 1 == cinst->NumInOperands()) {
1914       std::vector<Operand> operands;
1915       operands.push_back(
1916           {SPV_OPERAND_TYPE_ID,
1917            {cinst->GetSingleWordInOperand(kInsertObjectIdInIdx)}});
1918       for (; i < inst->NumInOperands(); ++i) {
1919         operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
1920                             {inst->GetSingleWordInOperand(i)}});
1921       }
1922       inst->SetInOperands(std::move(operands));
1923       return true;
1924     }
1925 
1926     // Extracting a value that is disjoint from the element being inserted.
1927     // Rewrite the extract to use the composite input to the insert.
1928     std::vector<Operand> operands;
1929     operands.push_back(
1930         {SPV_OPERAND_TYPE_ID,
1931          {cinst->GetSingleWordInOperand(kInsertCompositeIdInIdx)}});
1932     for (i = 1; i < inst->NumInOperands(); ++i) {
1933       operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
1934                           {inst->GetSingleWordInOperand(i)}});
1935     }
1936     inst->SetInOperands(std::move(operands));
1937     return true;
1938   };
1939 }
1940 
1941 // When a VectorShuffle is feeding an Extract, we can extract from one of the
1942 // operands of the VectorShuffle.  We just need to adjust the index in the
1943 // extract instruction.
VectorShuffleFeedingExtract()1944 FoldingRule VectorShuffleFeedingExtract() {
1945   return [](IRContext* context, Instruction* inst,
1946             const std::vector<const analysis::Constant*>&) {
1947     assert(inst->opcode() == spv::Op::OpCompositeExtract &&
1948            "Wrong opcode.  Should be OpCompositeExtract.");
1949     analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1950     analysis::TypeManager* type_mgr = context->get_type_mgr();
1951     uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1952     Instruction* cinst = def_use_mgr->GetDef(cid);
1953 
1954     if (cinst->opcode() != spv::Op::OpVectorShuffle) {
1955       return false;
1956     }
1957 
1958     // Find the size of the first vector operand of the VectorShuffle
1959     Instruction* first_input =
1960         def_use_mgr->GetDef(cinst->GetSingleWordInOperand(0));
1961     analysis::Type* first_input_type =
1962         type_mgr->GetType(first_input->type_id());
1963     assert(first_input_type->AsVector() &&
1964            "Input to vector shuffle should be vectors.");
1965     uint32_t first_input_size = first_input_type->AsVector()->element_count();
1966 
1967     // Get index of the element the vector shuffle is placing in the position
1968     // being extracted.
1969     uint32_t new_index =
1970         cinst->GetSingleWordInOperand(2 + inst->GetSingleWordInOperand(1));
1971 
1972     // Extracting an undefined value so fold this extract into an undef.
1973     const uint32_t undef_literal_value = 0xffffffff;
1974     if (new_index == undef_literal_value) {
1975       inst->SetOpcode(spv::Op::OpUndef);
1976       inst->SetInOperands({});
1977       return true;
1978     }
1979 
1980     // Get the id of the of the vector the elemtent comes from, and update the
1981     // index if needed.
1982     uint32_t new_vector = 0;
1983     if (new_index < first_input_size) {
1984       new_vector = cinst->GetSingleWordInOperand(0);
1985     } else {
1986       new_vector = cinst->GetSingleWordInOperand(1);
1987       new_index -= first_input_size;
1988     }
1989 
1990     // Update the extract instruction.
1991     inst->SetInOperand(kExtractCompositeIdInIdx, {new_vector});
1992     inst->SetInOperand(1, {new_index});
1993     return true;
1994   };
1995 }
1996 
1997 // When an FMix with is feeding an Extract that extracts an element whose
1998 // corresponding |a| in the FMix is 0 or 1, we can extract from one of the
1999 // operands of the FMix.
FMixFeedingExtract()2000 FoldingRule FMixFeedingExtract() {
2001   return [](IRContext* context, Instruction* inst,
2002             const std::vector<const analysis::Constant*>&) {
2003     assert(inst->opcode() == spv::Op::OpCompositeExtract &&
2004            "Wrong opcode.  Should be OpCompositeExtract.");
2005     analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
2006     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
2007 
2008     uint32_t composite_id =
2009         inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
2010     Instruction* composite_inst = def_use_mgr->GetDef(composite_id);
2011 
2012     if (composite_inst->opcode() != spv::Op::OpExtInst) {
2013       return false;
2014     }
2015 
2016     uint32_t inst_set_id =
2017         context->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
2018 
2019     if (composite_inst->GetSingleWordInOperand(kExtInstSetIdInIdx) !=
2020             inst_set_id ||
2021         composite_inst->GetSingleWordInOperand(kExtInstInstructionInIdx) !=
2022             GLSLstd450FMix) {
2023       return false;
2024     }
2025 
2026     // Get the |a| for the FMix instruction.
2027     uint32_t a_id = composite_inst->GetSingleWordInOperand(kFMixAIdInIdx);
2028     std::unique_ptr<Instruction> a(inst->Clone(context));
2029     a->SetInOperand(kExtractCompositeIdInIdx, {a_id});
2030     context->get_instruction_folder().FoldInstruction(a.get());
2031 
2032     if (a->opcode() != spv::Op::OpCopyObject) {
2033       return false;
2034     }
2035 
2036     const analysis::Constant* a_const =
2037         const_mgr->FindDeclaredConstant(a->GetSingleWordInOperand(0));
2038 
2039     if (!a_const) {
2040       return false;
2041     }
2042 
2043     bool use_x = false;
2044 
2045     assert(a_const->type()->AsFloat());
2046     double element_value = a_const->GetValueAsDouble();
2047     if (element_value == 0.0) {
2048       use_x = true;
2049     } else if (element_value == 1.0) {
2050       use_x = false;
2051     } else {
2052       return false;
2053     }
2054 
2055     // Get the id of the of the vector the element comes from.
2056     uint32_t new_vector = 0;
2057     if (use_x) {
2058       new_vector = composite_inst->GetSingleWordInOperand(kFMixXIdInIdx);
2059     } else {
2060       new_vector = composite_inst->GetSingleWordInOperand(kFMixYIdInIdx);
2061     }
2062 
2063     // Update the extract instruction.
2064     inst->SetInOperand(kExtractCompositeIdInIdx, {new_vector});
2065     return true;
2066   };
2067 }
2068 
2069 // Returns the number of elements in the composite type |type|.  Returns 0 if
2070 // |type| is a scalar value. Return UINT32_MAX when the size is unknown at
2071 // compile time.
GetNumberOfElements(const analysis::Type * type)2072 uint32_t GetNumberOfElements(const analysis::Type* type) {
2073   if (auto* vector_type = type->AsVector()) {
2074     return vector_type->element_count();
2075   }
2076   if (auto* matrix_type = type->AsMatrix()) {
2077     return matrix_type->element_count();
2078   }
2079   if (auto* struct_type = type->AsStruct()) {
2080     return static_cast<uint32_t>(struct_type->element_types().size());
2081   }
2082   if (auto* array_type = type->AsArray()) {
2083     if (array_type->length_info().words[0] ==
2084             analysis::Array::LengthInfo::kConstant &&
2085         array_type->length_info().words.size() == 2) {
2086       return array_type->length_info().words[1];
2087     }
2088     return UINT32_MAX;
2089   }
2090   return 0;
2091 }
2092 
2093 // Returns a map with the set of values that were inserted into an object by
2094 // the chain of OpCompositeInsertInstruction starting with |inst|.
2095 // The map will map the index to the value inserted at that index. An empty map
2096 // will be returned if the map could not be properly generated.
GetInsertedValues(Instruction * inst)2097 std::map<uint32_t, uint32_t> GetInsertedValues(Instruction* inst) {
2098   analysis::DefUseManager* def_use_mgr = inst->context()->get_def_use_mgr();
2099   std::map<uint32_t, uint32_t> values_inserted;
2100   Instruction* current_inst = inst;
2101   while (current_inst->opcode() == spv::Op::OpCompositeInsert) {
2102     if (current_inst->NumInOperands() > inst->NumInOperands()) {
2103       // This is to catch the case
2104       //   %2 = OpCompositeInsert %m2x2int %v2int_1_0 %m2x2int_undef 0
2105       //   %3 = OpCompositeInsert %m2x2int %int_4 %2 0 0
2106       //   %4 = OpCompositeInsert %m2x2int %v2int_2_3 %3 1
2107       // In this case we cannot do a single construct to get the matrix.
2108       uint32_t partially_inserted_element_index =
2109           current_inst->GetSingleWordInOperand(inst->NumInOperands() - 1);
2110       if (values_inserted.count(partially_inserted_element_index) == 0)
2111         return {};
2112     }
2113     if (HaveSameIndexesExceptForLast(inst, current_inst)) {
2114       values_inserted.insert(
2115           {current_inst->GetSingleWordInOperand(current_inst->NumInOperands() -
2116                                                 1),
2117            current_inst->GetSingleWordInOperand(kInsertObjectIdInIdx)});
2118     }
2119     current_inst = def_use_mgr->GetDef(
2120         current_inst->GetSingleWordInOperand(kInsertCompositeIdInIdx));
2121   }
2122   return values_inserted;
2123 }
2124 
2125 // Returns true of there is an entry in |values_inserted| for every element of
2126 // |Type|.
DoInsertedValuesCoverEntireObject(const analysis::Type * type,std::map<uint32_t,uint32_t> & values_inserted)2127 bool DoInsertedValuesCoverEntireObject(
2128     const analysis::Type* type, std::map<uint32_t, uint32_t>& values_inserted) {
2129   uint32_t container_size = GetNumberOfElements(type);
2130   if (container_size != values_inserted.size()) {
2131     return false;
2132   }
2133 
2134   if (values_inserted.rbegin()->first >= container_size) {
2135     return false;
2136   }
2137   return true;
2138 }
2139 
2140 // Returns the type of the element that immediately contains the element being
2141 // inserted by the OpCompositeInsert instruction |inst|.
GetContainerType(Instruction * inst)2142 const analysis::Type* GetContainerType(Instruction* inst) {
2143   assert(inst->opcode() == spv::Op::OpCompositeInsert);
2144   analysis::TypeManager* type_mgr = inst->context()->get_type_mgr();
2145   return GetElementType(inst->type_id(), inst->begin() + 4, inst->end() - 1,
2146                         type_mgr);
2147 }
2148 
2149 // Returns an OpCompositeConstruct instruction that build an object with
2150 // |type_id| out of the values in |values_inserted|.  Each value will be
2151 // placed at the index corresponding to the value.  The new instruction will
2152 // be placed before |insert_before|.
BuildCompositeConstruct(uint32_t type_id,const std::map<uint32_t,uint32_t> & values_inserted,Instruction * insert_before)2153 Instruction* BuildCompositeConstruct(
2154     uint32_t type_id, const std::map<uint32_t, uint32_t>& values_inserted,
2155     Instruction* insert_before) {
2156   InstructionBuilder ir_builder(
2157       insert_before->context(), insert_before,
2158       IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
2159 
2160   std::vector<uint32_t> ids_in_order;
2161   for (auto it : values_inserted) {
2162     ids_in_order.push_back(it.second);
2163   }
2164   Instruction* construct =
2165       ir_builder.AddCompositeConstruct(type_id, ids_in_order);
2166   return construct;
2167 }
2168 
2169 // Replaces the OpCompositeInsert |inst| that inserts |construct| into the same
2170 // object as |inst| with final index removed.  If the resulting
2171 // OpCompositeInsert instruction would have no remaining indexes, the
2172 // instruction is replaced with an OpCopyObject instead.
InsertConstructedObject(Instruction * inst,const Instruction * construct)2173 void InsertConstructedObject(Instruction* inst, const Instruction* construct) {
2174   if (inst->NumInOperands() == 3) {
2175     inst->SetOpcode(spv::Op::OpCopyObject);
2176     inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {construct->result_id()}}});
2177   } else {
2178     inst->SetInOperand(kInsertObjectIdInIdx, {construct->result_id()});
2179     inst->RemoveOperand(inst->NumOperands() - 1);
2180   }
2181 }
2182 
2183 // Replaces a series of |OpCompositeInsert| instruction that cover the entire
2184 // object with an |OpCompositeConstruct|.
CompositeInsertToCompositeConstruct(IRContext * context,Instruction * inst,const std::vector<const analysis::Constant * > &)2185 bool CompositeInsertToCompositeConstruct(
2186     IRContext* context, Instruction* inst,
2187     const std::vector<const analysis::Constant*>&) {
2188   assert(inst->opcode() == spv::Op::OpCompositeInsert &&
2189          "Wrong opcode.  Should be OpCompositeInsert.");
2190   if (inst->NumInOperands() < 3) return false;
2191 
2192   std::map<uint32_t, uint32_t> values_inserted = GetInsertedValues(inst);
2193   const analysis::Type* container_type = GetContainerType(inst);
2194   if (container_type == nullptr) {
2195     return false;
2196   }
2197 
2198   if (!DoInsertedValuesCoverEntireObject(container_type, values_inserted)) {
2199     return false;
2200   }
2201 
2202   analysis::TypeManager* type_mgr = context->get_type_mgr();
2203   Instruction* construct = BuildCompositeConstruct(
2204       type_mgr->GetId(container_type), values_inserted, inst);
2205   InsertConstructedObject(inst, construct);
2206   return true;
2207 }
2208 
RedundantPhi()2209 FoldingRule RedundantPhi() {
2210   // An OpPhi instruction where all values are the same or the result of the phi
2211   // itself, can be replaced by the value itself.
2212   return [](IRContext*, Instruction* inst,
2213             const std::vector<const analysis::Constant*>&) {
2214     assert(inst->opcode() == spv::Op::OpPhi &&
2215            "Wrong opcode.  Should be OpPhi.");
2216 
2217     uint32_t incoming_value = 0;
2218 
2219     for (uint32_t i = 0; i < inst->NumInOperands(); i += 2) {
2220       uint32_t op_id = inst->GetSingleWordInOperand(i);
2221       if (op_id == inst->result_id()) {
2222         continue;
2223       }
2224 
2225       if (incoming_value == 0) {
2226         incoming_value = op_id;
2227       } else if (op_id != incoming_value) {
2228         // Found two possible value.  Can't simplify.
2229         return false;
2230       }
2231     }
2232 
2233     if (incoming_value == 0) {
2234       // Code looks invalid.  Don't do anything.
2235       return false;
2236     }
2237 
2238     // We have a single incoming value.  Simplify using that value.
2239     inst->SetOpcode(spv::Op::OpCopyObject);
2240     inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {incoming_value}}});
2241     return true;
2242   };
2243 }
2244 
BitCastScalarOrVector()2245 FoldingRule BitCastScalarOrVector() {
2246   return [](IRContext* context, Instruction* inst,
2247             const std::vector<const analysis::Constant*>& constants) {
2248     assert(inst->opcode() == spv::Op::OpBitcast && constants.size() == 1);
2249     if (constants[0] == nullptr) return false;
2250 
2251     const analysis::Type* type =
2252         context->get_type_mgr()->GetType(inst->type_id());
2253     if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
2254       return false;
2255 
2256     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
2257     std::vector<uint32_t> words =
2258         GetWordsFromNumericScalarOrVectorConstant(const_mgr, constants[0]);
2259     if (words.size() == 0) return false;
2260 
2261     const analysis::Constant* bitcasted_constant =
2262         ConvertWordsToNumericScalarOrVectorConstant(const_mgr, words, type);
2263     if (!bitcasted_constant) return false;
2264 
2265     auto new_feeder_id =
2266         const_mgr->GetDefiningInstruction(bitcasted_constant, inst->type_id())
2267             ->result_id();
2268     inst->SetOpcode(spv::Op::OpCopyObject);
2269     inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {new_feeder_id}}});
2270     return true;
2271   };
2272 }
2273 
RedundantSelect()2274 FoldingRule RedundantSelect() {
2275   // An OpSelect instruction where both values are the same or the condition is
2276   // constant can be replaced by one of the values
2277   return [](IRContext*, Instruction* inst,
2278             const std::vector<const analysis::Constant*>& constants) {
2279     assert(inst->opcode() == spv::Op::OpSelect &&
2280            "Wrong opcode.  Should be OpSelect.");
2281     assert(inst->NumInOperands() == 3);
2282     assert(constants.size() == 3);
2283 
2284     uint32_t true_id = inst->GetSingleWordInOperand(1);
2285     uint32_t false_id = inst->GetSingleWordInOperand(2);
2286 
2287     if (true_id == false_id) {
2288       // Both results are the same, condition doesn't matter
2289       inst->SetOpcode(spv::Op::OpCopyObject);
2290       inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}});
2291       return true;
2292     } else if (constants[0]) {
2293       const analysis::Type* type = constants[0]->type();
2294       if (type->AsBool()) {
2295         // Scalar constant value, select the corresponding value.
2296         inst->SetOpcode(spv::Op::OpCopyObject);
2297         if (constants[0]->AsNullConstant() ||
2298             !constants[0]->AsBoolConstant()->value()) {
2299           inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {false_id}}});
2300         } else {
2301           inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}});
2302         }
2303         return true;
2304       } else {
2305         assert(type->AsVector());
2306         if (constants[0]->AsNullConstant()) {
2307           // All values come from false id.
2308           inst->SetOpcode(spv::Op::OpCopyObject);
2309           inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {false_id}}});
2310           return true;
2311         } else {
2312           // Convert to a vector shuffle.
2313           std::vector<Operand> ops;
2314           ops.push_back({SPV_OPERAND_TYPE_ID, {true_id}});
2315           ops.push_back({SPV_OPERAND_TYPE_ID, {false_id}});
2316           const analysis::VectorConstant* vector_const =
2317               constants[0]->AsVectorConstant();
2318           uint32_t size =
2319               static_cast<uint32_t>(vector_const->GetComponents().size());
2320           for (uint32_t i = 0; i != size; ++i) {
2321             const analysis::Constant* component =
2322                 vector_const->GetComponents()[i];
2323             if (component->AsNullConstant() ||
2324                 !component->AsBoolConstant()->value()) {
2325               // Selecting from the false vector which is the second input
2326               // vector to the shuffle. Offset the index by |size|.
2327               ops.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i + size}});
2328             } else {
2329               // Selecting from true vector which is the first input vector to
2330               // the shuffle.
2331               ops.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}});
2332             }
2333           }
2334 
2335           inst->SetOpcode(spv::Op::OpVectorShuffle);
2336           inst->SetInOperands(std::move(ops));
2337           return true;
2338         }
2339       }
2340     }
2341 
2342     return false;
2343   };
2344 }
2345 
// Classification of a floating-point constant operand, as produced by
// getFloatConstantKind().
enum class FloatConstantKind {
  Unknown,  // Not a constant, or not recognized as all-zero / all-one.
  Zero,     // Null constant, or every component compares equal to 0.0.
  One       // Every component compares equal to 1.0.
};
2347 
getFloatConstantKind(const analysis::Constant * constant)2348 FloatConstantKind getFloatConstantKind(const analysis::Constant* constant) {
2349   if (constant == nullptr) {
2350     return FloatConstantKind::Unknown;
2351   }
2352 
2353   assert(HasFloatingPoint(constant->type()) && "Unexpected constant type");
2354 
2355   if (constant->AsNullConstant()) {
2356     return FloatConstantKind::Zero;
2357   } else if (const analysis::VectorConstant* vc =
2358                  constant->AsVectorConstant()) {
2359     const std::vector<const analysis::Constant*>& components =
2360         vc->GetComponents();
2361     assert(!components.empty());
2362 
2363     FloatConstantKind kind = getFloatConstantKind(components[0]);
2364 
2365     for (size_t i = 1; i < components.size(); ++i) {
2366       if (getFloatConstantKind(components[i]) != kind) {
2367         return FloatConstantKind::Unknown;
2368       }
2369     }
2370 
2371     return kind;
2372   } else if (const analysis::FloatConstant* fc = constant->AsFloatConstant()) {
2373     if (fc->IsZero()) return FloatConstantKind::Zero;
2374 
2375     uint32_t width = fc->type()->AsFloat()->width();
2376     if (width != 32 && width != 64) return FloatConstantKind::Unknown;
2377 
2378     double value = (width == 64) ? fc->GetDoubleValue() : fc->GetFloatValue();
2379 
2380     if (value == 0.0) {
2381       return FloatConstantKind::Zero;
2382     } else if (value == 1.0) {
2383       return FloatConstantKind::One;
2384     } else {
2385       return FloatConstantKind::Unknown;
2386     }
2387   } else {
2388     return FloatConstantKind::Unknown;
2389   }
2390 }
2391 
RedundantFAdd()2392 FoldingRule RedundantFAdd() {
2393   return [](IRContext*, Instruction* inst,
2394             const std::vector<const analysis::Constant*>& constants) {
2395     assert(inst->opcode() == spv::Op::OpFAdd &&
2396            "Wrong opcode.  Should be OpFAdd.");
2397     assert(constants.size() == 2);
2398 
2399     if (!inst->IsFloatingPointFoldingAllowed()) {
2400       return false;
2401     }
2402 
2403     FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
2404     FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
2405 
2406     if (kind0 == FloatConstantKind::Zero || kind1 == FloatConstantKind::Zero) {
2407       inst->SetOpcode(spv::Op::OpCopyObject);
2408       inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
2409                             {inst->GetSingleWordInOperand(
2410                                 kind0 == FloatConstantKind::Zero ? 1 : 0)}}});
2411       return true;
2412     }
2413 
2414     return false;
2415   };
2416 }
2417 
RedundantFSub()2418 FoldingRule RedundantFSub() {
2419   return [](IRContext*, Instruction* inst,
2420             const std::vector<const analysis::Constant*>& constants) {
2421     assert(inst->opcode() == spv::Op::OpFSub &&
2422            "Wrong opcode.  Should be OpFSub.");
2423     assert(constants.size() == 2);
2424 
2425     if (!inst->IsFloatingPointFoldingAllowed()) {
2426       return false;
2427     }
2428 
2429     FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
2430     FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
2431 
2432     if (kind0 == FloatConstantKind::Zero) {
2433       inst->SetOpcode(spv::Op::OpFNegate);
2434       inst->SetInOperands(
2435           {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1)}}});
2436       return true;
2437     }
2438 
2439     if (kind1 == FloatConstantKind::Zero) {
2440       inst->SetOpcode(spv::Op::OpCopyObject);
2441       inst->SetInOperands(
2442           {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
2443       return true;
2444     }
2445 
2446     return false;
2447   };
2448 }
2449 
RedundantFMul()2450 FoldingRule RedundantFMul() {
2451   return [](IRContext*, Instruction* inst,
2452             const std::vector<const analysis::Constant*>& constants) {
2453     assert(inst->opcode() == spv::Op::OpFMul &&
2454            "Wrong opcode.  Should be OpFMul.");
2455     assert(constants.size() == 2);
2456 
2457     if (!inst->IsFloatingPointFoldingAllowed()) {
2458       return false;
2459     }
2460 
2461     FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
2462     FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
2463 
2464     if (kind0 == FloatConstantKind::Zero || kind1 == FloatConstantKind::Zero) {
2465       inst->SetOpcode(spv::Op::OpCopyObject);
2466       inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
2467                             {inst->GetSingleWordInOperand(
2468                                 kind0 == FloatConstantKind::Zero ? 0 : 1)}}});
2469       return true;
2470     }
2471 
2472     if (kind0 == FloatConstantKind::One || kind1 == FloatConstantKind::One) {
2473       inst->SetOpcode(spv::Op::OpCopyObject);
2474       inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
2475                             {inst->GetSingleWordInOperand(
2476                                 kind0 == FloatConstantKind::One ? 1 : 0)}}});
2477       return true;
2478     }
2479 
2480     return false;
2481   };
2482 }
2483 
RedundantFDiv()2484 FoldingRule RedundantFDiv() {
2485   return [](IRContext*, Instruction* inst,
2486             const std::vector<const analysis::Constant*>& constants) {
2487     assert(inst->opcode() == spv::Op::OpFDiv &&
2488            "Wrong opcode.  Should be OpFDiv.");
2489     assert(constants.size() == 2);
2490 
2491     if (!inst->IsFloatingPointFoldingAllowed()) {
2492       return false;
2493     }
2494 
2495     FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
2496     FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
2497 
2498     if (kind0 == FloatConstantKind::Zero) {
2499       inst->SetOpcode(spv::Op::OpCopyObject);
2500       inst->SetInOperands(
2501           {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
2502       return true;
2503     }
2504 
2505     if (kind1 == FloatConstantKind::One) {
2506       inst->SetOpcode(spv::Op::OpCopyObject);
2507       inst->SetInOperands(
2508           {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
2509       return true;
2510     }
2511 
2512     return false;
2513   };
2514 }
2515 
RedundantFMix()2516 FoldingRule RedundantFMix() {
2517   return [](IRContext* context, Instruction* inst,
2518             const std::vector<const analysis::Constant*>& constants) {
2519     assert(inst->opcode() == spv::Op::OpExtInst &&
2520            "Wrong opcode.  Should be OpExtInst.");
2521 
2522     if (!inst->IsFloatingPointFoldingAllowed()) {
2523       return false;
2524     }
2525 
2526     uint32_t instSetId =
2527         context->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
2528 
2529     if (inst->GetSingleWordInOperand(kExtInstSetIdInIdx) == instSetId &&
2530         inst->GetSingleWordInOperand(kExtInstInstructionInIdx) ==
2531             GLSLstd450FMix) {
2532       assert(constants.size() == 5);
2533 
2534       FloatConstantKind kind4 = getFloatConstantKind(constants[4]);
2535 
2536       if (kind4 == FloatConstantKind::Zero || kind4 == FloatConstantKind::One) {
2537         inst->SetOpcode(spv::Op::OpCopyObject);
2538         inst->SetInOperands(
2539             {{SPV_OPERAND_TYPE_ID,
2540               {inst->GetSingleWordInOperand(kind4 == FloatConstantKind::Zero
2541                                                 ? kFMixXIdInIdx
2542                                                 : kFMixYIdInIdx)}}});
2543         return true;
2544       }
2545     }
2546 
2547     return false;
2548   };
2549 }
2550 
2551 // This rule handles addition of zero for integers.
RedundantIAdd()2552 FoldingRule RedundantIAdd() {
2553   return [](IRContext* context, Instruction* inst,
2554             const std::vector<const analysis::Constant*>& constants) {
2555     assert(inst->opcode() == spv::Op::OpIAdd &&
2556            "Wrong opcode. Should be OpIAdd.");
2557 
2558     uint32_t operand = std::numeric_limits<uint32_t>::max();
2559     const analysis::Type* operand_type = nullptr;
2560     if (constants[0] && constants[0]->IsZero()) {
2561       operand = inst->GetSingleWordInOperand(1);
2562       operand_type = constants[0]->type();
2563     } else if (constants[1] && constants[1]->IsZero()) {
2564       operand = inst->GetSingleWordInOperand(0);
2565       operand_type = constants[1]->type();
2566     }
2567 
2568     if (operand != std::numeric_limits<uint32_t>::max()) {
2569       const analysis::Type* inst_type =
2570           context->get_type_mgr()->GetType(inst->type_id());
2571       if (inst_type->IsSame(operand_type)) {
2572         inst->SetOpcode(spv::Op::OpCopyObject);
2573       } else {
2574         inst->SetOpcode(spv::Op::OpBitcast);
2575       }
2576       inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {operand}}});
2577       return true;
2578     }
2579     return false;
2580   };
2581 }
2582 
2583 // This rule look for a dot with a constant vector containing a single 1 and
2584 // the rest 0s.  This is the same as doing an extract.
DotProductDoingExtract()2585 FoldingRule DotProductDoingExtract() {
2586   return [](IRContext* context, Instruction* inst,
2587             const std::vector<const analysis::Constant*>& constants) {
2588     assert(inst->opcode() == spv::Op::OpDot &&
2589            "Wrong opcode.  Should be OpDot.");
2590 
2591     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
2592 
2593     if (!inst->IsFloatingPointFoldingAllowed()) {
2594       return false;
2595     }
2596 
2597     for (int i = 0; i < 2; ++i) {
2598       if (!constants[i]) {
2599         continue;
2600       }
2601 
2602       const analysis::Vector* vector_type = constants[i]->type()->AsVector();
2603       assert(vector_type && "Inputs to OpDot must be vectors.");
2604       const analysis::Float* element_type =
2605           vector_type->element_type()->AsFloat();
2606       assert(element_type && "Inputs to OpDot must be vectors of floats.");
2607       uint32_t element_width = element_type->width();
2608       if (element_width != 32 && element_width != 64) {
2609         return false;
2610       }
2611 
2612       std::vector<const analysis::Constant*> components;
2613       components = constants[i]->GetVectorComponents(const_mgr);
2614 
2615       constexpr uint32_t kNotFound = std::numeric_limits<uint32_t>::max();
2616 
2617       uint32_t component_with_one = kNotFound;
2618       bool all_others_zero = true;
2619       for (uint32_t j = 0; j < components.size(); ++j) {
2620         const analysis::Constant* element = components[j];
2621         double value =
2622             (element_width == 32 ? element->GetFloat() : element->GetDouble());
2623         if (value == 0.0) {
2624           continue;
2625         } else if (value == 1.0) {
2626           if (component_with_one == kNotFound) {
2627             component_with_one = j;
2628           } else {
2629             component_with_one = kNotFound;
2630             break;
2631           }
2632         } else {
2633           all_others_zero = false;
2634           break;
2635         }
2636       }
2637 
2638       if (!all_others_zero || component_with_one == kNotFound) {
2639         continue;
2640       }
2641 
2642       std::vector<Operand> operands;
2643       operands.push_back(
2644           {SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1u - i)}});
2645       operands.push_back(
2646           {SPV_OPERAND_TYPE_LITERAL_INTEGER, {component_with_one}});
2647 
2648       inst->SetOpcode(spv::Op::OpCompositeExtract);
2649       inst->SetInOperands(std::move(operands));
2650       return true;
2651     }
2652     return false;
2653   };
2654 }
2655 
2656 // If we are storing an undef, then we can remove the store.
2657 //
2658 // TODO: We can do something similar for OpImageWrite, but checking for volatile
2659 // is complicated.  Waiting to see if it is needed.
StoringUndef()2660 FoldingRule StoringUndef() {
2661   return [](IRContext* context, Instruction* inst,
2662             const std::vector<const analysis::Constant*>&) {
2663     assert(inst->opcode() == spv::Op::OpStore &&
2664            "Wrong opcode.  Should be OpStore.");
2665 
2666     analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
2667 
2668     // If this is a volatile store, the store cannot be removed.
2669     if (inst->NumInOperands() == 3) {
2670       if (inst->GetSingleWordInOperand(2) &
2671           uint32_t(spv::MemoryAccessMask::Volatile)) {
2672         return false;
2673       }
2674     }
2675 
2676     uint32_t object_id = inst->GetSingleWordInOperand(kStoreObjectInIdx);
2677     Instruction* object_inst = def_use_mgr->GetDef(object_id);
2678     if (object_inst->opcode() == spv::Op::OpUndef) {
2679       inst->ToNop();
2680       return true;
2681     }
2682     return false;
2683   };
2684 }
2685 
VectorShuffleFeedingShuffle()2686 FoldingRule VectorShuffleFeedingShuffle() {
2687   return [](IRContext* context, Instruction* inst,
2688             const std::vector<const analysis::Constant*>&) {
2689     assert(inst->opcode() == spv::Op::OpVectorShuffle &&
2690            "Wrong opcode.  Should be OpVectorShuffle.");
2691 
2692     analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
2693     analysis::TypeManager* type_mgr = context->get_type_mgr();
2694 
2695     Instruction* feeding_shuffle_inst =
2696         def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
2697     analysis::Vector* op0_type =
2698         type_mgr->GetType(feeding_shuffle_inst->type_id())->AsVector();
2699     uint32_t op0_length = op0_type->element_count();
2700 
2701     bool feeder_is_op0 = true;
2702     if (feeding_shuffle_inst->opcode() != spv::Op::OpVectorShuffle) {
2703       feeding_shuffle_inst =
2704           def_use_mgr->GetDef(inst->GetSingleWordInOperand(1));
2705       feeder_is_op0 = false;
2706     }
2707 
2708     if (feeding_shuffle_inst->opcode() != spv::Op::OpVectorShuffle) {
2709       return false;
2710     }
2711 
2712     Instruction* feeder2 =
2713         def_use_mgr->GetDef(feeding_shuffle_inst->GetSingleWordInOperand(0));
2714     analysis::Vector* feeder_op0_type =
2715         type_mgr->GetType(feeder2->type_id())->AsVector();
2716     uint32_t feeder_op0_length = feeder_op0_type->element_count();
2717 
2718     uint32_t new_feeder_id = 0;
2719     std::vector<Operand> new_operands;
2720     new_operands.resize(
2721         2, {SPV_OPERAND_TYPE_ID, {0}});  // Place holders for vector operands.
2722     const uint32_t undef_literal = 0xffffffff;
2723     for (uint32_t op = 2; op < inst->NumInOperands(); ++op) {
2724       uint32_t component_index = inst->GetSingleWordInOperand(op);
2725 
2726       // Do not interpret the undefined value literal as coming from operand 1.
2727       if (component_index != undef_literal &&
2728           feeder_is_op0 == (component_index < op0_length)) {
2729         // This component comes from the feeding_shuffle_inst.  Update
2730         // |component_index| to be the index into the operand of the feeder.
2731 
2732         // Adjust component_index to get the index into the operands of the
2733         // feeding_shuffle_inst.
2734         if (component_index >= op0_length) {
2735           component_index -= op0_length;
2736         }
2737         component_index =
2738             feeding_shuffle_inst->GetSingleWordInOperand(component_index + 2);
2739 
2740         // Check if we are using a component from the first or second operand of
2741         // the feeding instruction.
2742         if (component_index < feeder_op0_length) {
2743           if (new_feeder_id == 0) {
2744             // First time through, save the id of the operand the element comes
2745             // from.
2746             new_feeder_id = feeding_shuffle_inst->GetSingleWordInOperand(0);
2747           } else if (new_feeder_id !=
2748                      feeding_shuffle_inst->GetSingleWordInOperand(0)) {
2749             // We need both elements of the feeding_shuffle_inst, so we cannot
2750             // fold.
2751             return false;
2752           }
2753         } else if (component_index != undef_literal) {
2754           if (new_feeder_id == 0) {
2755             // First time through, save the id of the operand the element comes
2756             // from.
2757             new_feeder_id = feeding_shuffle_inst->GetSingleWordInOperand(1);
2758           } else if (new_feeder_id !=
2759                      feeding_shuffle_inst->GetSingleWordInOperand(1)) {
2760             // We need both elements of the feeding_shuffle_inst, so we cannot
2761             // fold.
2762             return false;
2763           }
2764           component_index -= feeder_op0_length;
2765         }
2766 
2767         if (!feeder_is_op0 && component_index != undef_literal) {
2768           component_index += op0_length;
2769         }
2770       }
2771       new_operands.push_back(
2772           {SPV_OPERAND_TYPE_LITERAL_INTEGER, {component_index}});
2773     }
2774 
2775     if (new_feeder_id == 0) {
2776       analysis::ConstantManager* const_mgr = context->get_constant_mgr();
2777       const analysis::Type* type =
2778           type_mgr->GetType(feeding_shuffle_inst->type_id());
2779       const analysis::Constant* null_const = const_mgr->GetConstant(type, {});
2780       new_feeder_id =
2781           const_mgr->GetDefiningInstruction(null_const, 0)->result_id();
2782     }
2783 
2784     if (feeder_is_op0) {
2785       // If the size of the first vector operand changed then the indices
2786       // referring to the second operand need to be adjusted.
2787       Instruction* new_feeder_inst = def_use_mgr->GetDef(new_feeder_id);
2788       analysis::Type* new_feeder_type =
2789           type_mgr->GetType(new_feeder_inst->type_id());
2790       uint32_t new_op0_size = new_feeder_type->AsVector()->element_count();
2791       int32_t adjustment = op0_length - new_op0_size;
2792 
2793       if (adjustment != 0) {
2794         for (uint32_t i = 2; i < new_operands.size(); i++) {
2795           uint32_t operand = inst->GetSingleWordInOperand(i);
2796           if (operand >= op0_length && operand != undef_literal) {
2797             new_operands[i].words[0] -= adjustment;
2798           }
2799         }
2800       }
2801 
2802       new_operands[0].words[0] = new_feeder_id;
2803       new_operands[1] = inst->GetInOperand(1);
2804     } else {
2805       new_operands[1].words[0] = new_feeder_id;
2806       new_operands[0] = inst->GetInOperand(0);
2807     }
2808 
2809     inst->SetInOperands(std::move(new_operands));
2810     return true;
2811   };
2812 }
2813 
2814 // Removes duplicate ids from the interface list of an OpEntryPoint
2815 // instruction.
RemoveRedundantOperands()2816 FoldingRule RemoveRedundantOperands() {
2817   return [](IRContext*, Instruction* inst,
2818             const std::vector<const analysis::Constant*>&) {
2819     assert(inst->opcode() == spv::Op::OpEntryPoint &&
2820            "Wrong opcode.  Should be OpEntryPoint.");
2821     bool has_redundant_operand = false;
2822     std::unordered_set<uint32_t> seen_operands;
2823     std::vector<Operand> new_operands;
2824 
2825     new_operands.emplace_back(inst->GetOperand(0));
2826     new_operands.emplace_back(inst->GetOperand(1));
2827     new_operands.emplace_back(inst->GetOperand(2));
2828     for (uint32_t i = 3; i < inst->NumOperands(); ++i) {
2829       if (seen_operands.insert(inst->GetSingleWordOperand(i)).second) {
2830         new_operands.emplace_back(inst->GetOperand(i));
2831       } else {
2832         has_redundant_operand = true;
2833       }
2834     }
2835 
2836     if (!has_redundant_operand) {
2837       return false;
2838     }
2839 
2840     inst->SetInOperands(std::move(new_operands));
2841     return true;
2842   };
2843 }
2844 
2845 // If an image instruction's operand is a constant, updates the image operand
2846 // flag from Offset to ConstOffset.
UpdateImageOperands()2847 FoldingRule UpdateImageOperands() {
2848   return [](IRContext*, Instruction* inst,
2849             const std::vector<const analysis::Constant*>& constants) {
2850     const auto opcode = inst->opcode();
2851     (void)opcode;
2852     assert((opcode == spv::Op::OpImageSampleImplicitLod ||
2853             opcode == spv::Op::OpImageSampleExplicitLod ||
2854             opcode == spv::Op::OpImageSampleDrefImplicitLod ||
2855             opcode == spv::Op::OpImageSampleDrefExplicitLod ||
2856             opcode == spv::Op::OpImageSampleProjImplicitLod ||
2857             opcode == spv::Op::OpImageSampleProjExplicitLod ||
2858             opcode == spv::Op::OpImageSampleProjDrefImplicitLod ||
2859             opcode == spv::Op::OpImageSampleProjDrefExplicitLod ||
2860             opcode == spv::Op::OpImageFetch ||
2861             opcode == spv::Op::OpImageGather ||
2862             opcode == spv::Op::OpImageDrefGather ||
2863             opcode == spv::Op::OpImageRead || opcode == spv::Op::OpImageWrite ||
2864             opcode == spv::Op::OpImageSparseSampleImplicitLod ||
2865             opcode == spv::Op::OpImageSparseSampleExplicitLod ||
2866             opcode == spv::Op::OpImageSparseSampleDrefImplicitLod ||
2867             opcode == spv::Op::OpImageSparseSampleDrefExplicitLod ||
2868             opcode == spv::Op::OpImageSparseSampleProjImplicitLod ||
2869             opcode == spv::Op::OpImageSparseSampleProjExplicitLod ||
2870             opcode == spv::Op::OpImageSparseSampleProjDrefImplicitLod ||
2871             opcode == spv::Op::OpImageSparseSampleProjDrefExplicitLod ||
2872             opcode == spv::Op::OpImageSparseFetch ||
2873             opcode == spv::Op::OpImageSparseGather ||
2874             opcode == spv::Op::OpImageSparseDrefGather ||
2875             opcode == spv::Op::OpImageSparseRead) &&
2876            "Wrong opcode.  Should be an image instruction.");
2877 
2878     int32_t operand_index = ImageOperandsMaskInOperandIndex(inst);
2879     if (operand_index >= 0) {
2880       auto image_operands = inst->GetSingleWordInOperand(operand_index);
2881       if (image_operands & uint32_t(spv::ImageOperandsMask::Offset)) {
2882         uint32_t offset_operand_index = operand_index + 1;
2883         if (image_operands & uint32_t(spv::ImageOperandsMask::Bias))
2884           offset_operand_index++;
2885         if (image_operands & uint32_t(spv::ImageOperandsMask::Lod))
2886           offset_operand_index++;
2887         if (image_operands & uint32_t(spv::ImageOperandsMask::Grad))
2888           offset_operand_index += 2;
2889         assert(((image_operands &
2890                  uint32_t(spv::ImageOperandsMask::ConstOffset)) == 0) &&
2891                "Offset and ConstOffset may not be used together");
2892         if (offset_operand_index < inst->NumOperands()) {
2893           if (constants[offset_operand_index]) {
2894             if (constants[offset_operand_index]->IsZero()) {
2895               inst->RemoveInOperand(offset_operand_index);
2896             } else {
2897               image_operands = image_operands |
2898                                uint32_t(spv::ImageOperandsMask::ConstOffset);
2899             }
2900             image_operands =
2901                 image_operands & ~uint32_t(spv::ImageOperandsMask::Offset);
2902             inst->SetInOperand(operand_index, {image_operands});
2903             return true;
2904           }
2905         }
2906       }
2907     }
2908 
2909     return false;
2910   };
2911 }
2912 
2913 }  // namespace
2914 
AddFoldingRules()2915 void FoldingRules::AddFoldingRules() {
2916   // Add all folding rules to the list for the opcodes to which they apply.
2917   // Note that the order in which rules are added to the list matters. If a rule
2918   // applies to the instruction, the rest of the rules will not be attempted.
2919   // Take that into consideration.
2920   rules_[spv::Op::OpBitcast].push_back(BitCastScalarOrVector());
2921 
2922   rules_[spv::Op::OpCompositeConstruct].push_back(
2923       CompositeExtractFeedingConstruct);
2924 
2925   rules_[spv::Op::OpCompositeExtract].push_back(InsertFeedingExtract());
2926   rules_[spv::Op::OpCompositeExtract].push_back(
2927       CompositeConstructFeedingExtract);
2928   rules_[spv::Op::OpCompositeExtract].push_back(VectorShuffleFeedingExtract());
2929   rules_[spv::Op::OpCompositeExtract].push_back(FMixFeedingExtract());
2930 
2931   rules_[spv::Op::OpCompositeInsert].push_back(
2932       CompositeInsertToCompositeConstruct);
2933 
2934   rules_[spv::Op::OpDot].push_back(DotProductDoingExtract());
2935 
2936   rules_[spv::Op::OpEntryPoint].push_back(RemoveRedundantOperands());
2937 
2938   rules_[spv::Op::OpFAdd].push_back(RedundantFAdd());
2939   rules_[spv::Op::OpFAdd].push_back(MergeAddNegateArithmetic());
2940   rules_[spv::Op::OpFAdd].push_back(MergeAddAddArithmetic());
2941   rules_[spv::Op::OpFAdd].push_back(MergeAddSubArithmetic());
2942   rules_[spv::Op::OpFAdd].push_back(MergeGenericAddSubArithmetic());
2943   rules_[spv::Op::OpFAdd].push_back(FactorAddMuls());
2944   rules_[spv::Op::OpFAdd].push_back(MergeMulAddArithmetic);
2945 
2946   rules_[spv::Op::OpFDiv].push_back(RedundantFDiv());
2947   rules_[spv::Op::OpFDiv].push_back(ReciprocalFDiv());
2948   rules_[spv::Op::OpFDiv].push_back(MergeDivDivArithmetic());
2949   rules_[spv::Op::OpFDiv].push_back(MergeDivMulArithmetic());
2950   rules_[spv::Op::OpFDiv].push_back(MergeDivNegateArithmetic());
2951 
2952   rules_[spv::Op::OpFMul].push_back(RedundantFMul());
2953   rules_[spv::Op::OpFMul].push_back(MergeMulMulArithmetic());
2954   rules_[spv::Op::OpFMul].push_back(MergeMulDivArithmetic());
2955   rules_[spv::Op::OpFMul].push_back(MergeMulNegateArithmetic());
2956 
2957   rules_[spv::Op::OpFNegate].push_back(MergeNegateArithmetic());
2958   rules_[spv::Op::OpFNegate].push_back(MergeNegateAddSubArithmetic());
2959   rules_[spv::Op::OpFNegate].push_back(MergeNegateMulDivArithmetic());
2960 
2961   rules_[spv::Op::OpFSub].push_back(RedundantFSub());
2962   rules_[spv::Op::OpFSub].push_back(MergeSubNegateArithmetic());
2963   rules_[spv::Op::OpFSub].push_back(MergeSubAddArithmetic());
2964   rules_[spv::Op::OpFSub].push_back(MergeSubSubArithmetic());
2965   rules_[spv::Op::OpFSub].push_back(MergeMulSubArithmetic);
2966 
2967   rules_[spv::Op::OpIAdd].push_back(RedundantIAdd());
2968   rules_[spv::Op::OpIAdd].push_back(MergeAddNegateArithmetic());
2969   rules_[spv::Op::OpIAdd].push_back(MergeAddAddArithmetic());
2970   rules_[spv::Op::OpIAdd].push_back(MergeAddSubArithmetic());
2971   rules_[spv::Op::OpIAdd].push_back(MergeGenericAddSubArithmetic());
2972   rules_[spv::Op::OpIAdd].push_back(FactorAddMuls());
2973 
2974   rules_[spv::Op::OpIMul].push_back(IntMultipleBy1());
2975   rules_[spv::Op::OpIMul].push_back(MergeMulMulArithmetic());
2976   rules_[spv::Op::OpIMul].push_back(MergeMulNegateArithmetic());
2977 
2978   rules_[spv::Op::OpISub].push_back(MergeSubNegateArithmetic());
2979   rules_[spv::Op::OpISub].push_back(MergeSubAddArithmetic());
2980   rules_[spv::Op::OpISub].push_back(MergeSubSubArithmetic());
2981 
2982   rules_[spv::Op::OpPhi].push_back(RedundantPhi());
2983 
2984   rules_[spv::Op::OpSNegate].push_back(MergeNegateArithmetic());
2985   rules_[spv::Op::OpSNegate].push_back(MergeNegateMulDivArithmetic());
2986   rules_[spv::Op::OpSNegate].push_back(MergeNegateAddSubArithmetic());
2987 
2988   rules_[spv::Op::OpSelect].push_back(RedundantSelect());
2989 
2990   rules_[spv::Op::OpStore].push_back(StoringUndef());
2991 
2992   rules_[spv::Op::OpVectorShuffle].push_back(VectorShuffleFeedingShuffle());
2993 
2994   rules_[spv::Op::OpImageSampleImplicitLod].push_back(UpdateImageOperands());
2995   rules_[spv::Op::OpImageSampleExplicitLod].push_back(UpdateImageOperands());
2996   rules_[spv::Op::OpImageSampleDrefImplicitLod].push_back(
2997       UpdateImageOperands());
2998   rules_[spv::Op::OpImageSampleDrefExplicitLod].push_back(
2999       UpdateImageOperands());
3000   rules_[spv::Op::OpImageSampleProjImplicitLod].push_back(
3001       UpdateImageOperands());
3002   rules_[spv::Op::OpImageSampleProjExplicitLod].push_back(
3003       UpdateImageOperands());
3004   rules_[spv::Op::OpImageSampleProjDrefImplicitLod].push_back(
3005       UpdateImageOperands());
3006   rules_[spv::Op::OpImageSampleProjDrefExplicitLod].push_back(
3007       UpdateImageOperands());
3008   rules_[spv::Op::OpImageFetch].push_back(UpdateImageOperands());
3009   rules_[spv::Op::OpImageGather].push_back(UpdateImageOperands());
3010   rules_[spv::Op::OpImageDrefGather].push_back(UpdateImageOperands());
3011   rules_[spv::Op::OpImageRead].push_back(UpdateImageOperands());
3012   rules_[spv::Op::OpImageWrite].push_back(UpdateImageOperands());
3013   rules_[spv::Op::OpImageSparseSampleImplicitLod].push_back(
3014       UpdateImageOperands());
3015   rules_[spv::Op::OpImageSparseSampleExplicitLod].push_back(
3016       UpdateImageOperands());
3017   rules_[spv::Op::OpImageSparseSampleDrefImplicitLod].push_back(
3018       UpdateImageOperands());
3019   rules_[spv::Op::OpImageSparseSampleDrefExplicitLod].push_back(
3020       UpdateImageOperands());
3021   rules_[spv::Op::OpImageSparseSampleProjImplicitLod].push_back(
3022       UpdateImageOperands());
3023   rules_[spv::Op::OpImageSparseSampleProjExplicitLod].push_back(
3024       UpdateImageOperands());
3025   rules_[spv::Op::OpImageSparseSampleProjDrefImplicitLod].push_back(
3026       UpdateImageOperands());
3027   rules_[spv::Op::OpImageSparseSampleProjDrefExplicitLod].push_back(
3028       UpdateImageOperands());
3029   rules_[spv::Op::OpImageSparseFetch].push_back(UpdateImageOperands());
3030   rules_[spv::Op::OpImageSparseGather].push_back(UpdateImageOperands());
3031   rules_[spv::Op::OpImageSparseDrefGather].push_back(UpdateImageOperands());
3032   rules_[spv::Op::OpImageSparseRead].push_back(UpdateImageOperands());
3033 
3034   FeatureManager* feature_manager = context_->get_feature_mgr();
3035   // Add rules for GLSLstd450
3036   uint32_t ext_inst_glslstd450_id =
3037       feature_manager->GetExtInstImportId_GLSLstd450();
3038   if (ext_inst_glslstd450_id != 0) {
3039     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMix}].push_back(
3040         RedundantFMix());
3041   }
3042 }
3043 }  // namespace opt
3044 }  // namespace spvtools
3045