// Copyright (c) 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14 
15 #include "source/opt/const_folding_rules.h"
16 
17 #include "source/opt/ir_context.h"
18 
19 namespace spvtools {
20 namespace opt {
21 namespace {
// Position, within the |constants| array handed to a folding rule, of the
// composite operand of an OpCompositeExtract.
constexpr uint32_t kExtractCompositeIdInIdx{0};
23 
24 // Returns a constants with the value NaN of the given type.  Only works for
25 // 32-bit and 64-bit float point types.  Returns |nullptr| if an error occurs.
GetNan(const analysis::Type * type,analysis::ConstantManager * const_mgr)26 const analysis::Constant* GetNan(const analysis::Type* type,
27                                  analysis::ConstantManager* const_mgr) {
28   const analysis::Float* float_type = type->AsFloat();
29   if (float_type == nullptr) {
30     return nullptr;
31   }
32 
33   switch (float_type->width()) {
34     case 32:
35       return const_mgr->GetFloatConst(std::numeric_limits<float>::quiet_NaN());
36     case 64:
37       return const_mgr->GetDoubleConst(
38           std::numeric_limits<double>::quiet_NaN());
39     default:
40       return nullptr;
41   }
42 }
43 
44 // Returns a constants with the value INF of the given type.  Only works for
45 // 32-bit and 64-bit float point types.  Returns |nullptr| if an error occurs.
GetInf(const analysis::Type * type,analysis::ConstantManager * const_mgr)46 const analysis::Constant* GetInf(const analysis::Type* type,
47                                  analysis::ConstantManager* const_mgr) {
48   const analysis::Float* float_type = type->AsFloat();
49   if (float_type == nullptr) {
50     return nullptr;
51   }
52 
53   switch (float_type->width()) {
54     case 32:
55       return const_mgr->GetFloatConst(std::numeric_limits<float>::infinity());
56     case 64:
57       return const_mgr->GetDoubleConst(std::numeric_limits<double>::infinity());
58     default:
59       return nullptr;
60   }
61 }
62 
63 // Returns true if |type| is Float or a vector of Float.
HasFloatingPoint(const analysis::Type * type)64 bool HasFloatingPoint(const analysis::Type* type) {
65   if (type->AsFloat()) {
66     return true;
67   } else if (const analysis::Vector* vec_type = type->AsVector()) {
68     return vec_type->element_type()->AsFloat() != nullptr;
69   }
70 
71   return false;
72 }
73 
74 // Returns a constants with the value |-val| of the given type.  Only works for
75 // 32-bit and 64-bit float point types.  Returns |nullptr| if an error occurs.
NegateFPConst(const analysis::Type * result_type,const analysis::Constant * val,analysis::ConstantManager * const_mgr)76 const analysis::Constant* NegateFPConst(const analysis::Type* result_type,
77                                         const analysis::Constant* val,
78                                         analysis::ConstantManager* const_mgr) {
79   const analysis::Float* float_type = result_type->AsFloat();
80   assert(float_type != nullptr);
81   if (float_type->width() == 32) {
82     float fa = val->GetFloat();
83     return const_mgr->GetFloatConst(-fa);
84   } else if (float_type->width() == 64) {
85     double da = val->GetDouble();
86     return const_mgr->GetDoubleConst(-da);
87   }
88   return nullptr;
89 }
90 
91 // Returns a constants with the value |-val| of the given type.
NegateIntConst(const analysis::Type * result_type,const analysis::Constant * val,analysis::ConstantManager * const_mgr)92 const analysis::Constant* NegateIntConst(const analysis::Type* result_type,
93                                          const analysis::Constant* val,
94                                          analysis::ConstantManager* const_mgr) {
95   const analysis::Integer* int_type = result_type->AsInteger();
96   assert(int_type != nullptr);
97 
98   if (val->AsNullConstant()) {
99     return val;
100   }
101 
102   uint64_t new_value = static_cast<uint64_t>(-val->GetSignExtendedValue());
103   return const_mgr->GetIntConst(new_value, int_type->width(),
104                                 int_type->IsSigned());
105 }
106 
107 // Folds an OpcompositeExtract where input is a composite constant.
FoldExtractWithConstants()108 ConstantFoldingRule FoldExtractWithConstants() {
109   return [](IRContext* context, Instruction* inst,
110             const std::vector<const analysis::Constant*>& constants)
111              -> const analysis::Constant* {
112     const analysis::Constant* c = constants[kExtractCompositeIdInIdx];
113     if (c == nullptr) {
114       return nullptr;
115     }
116 
117     for (uint32_t i = 1; i < inst->NumInOperands(); ++i) {
118       uint32_t element_index = inst->GetSingleWordInOperand(i);
119       if (c->AsNullConstant()) {
120         // Return Null for the return type.
121         analysis::ConstantManager* const_mgr = context->get_constant_mgr();
122         analysis::TypeManager* type_mgr = context->get_type_mgr();
123         return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), {});
124       }
125 
126       auto cc = c->AsCompositeConstant();
127       assert(cc != nullptr);
128       auto components = cc->GetComponents();
129       // Protect against invalid IR.  Refuse to fold if the index is out
130       // of bounds.
131       if (element_index >= components.size()) return nullptr;
132       c = components[element_index];
133     }
134     return c;
135   };
136 }
137 
138 // Folds an OpcompositeInsert where input is a composite constant.
FoldInsertWithConstants()139 ConstantFoldingRule FoldInsertWithConstants() {
140   return [](IRContext* context, Instruction* inst,
141             const std::vector<const analysis::Constant*>& constants)
142              -> const analysis::Constant* {
143     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
144     const analysis::Constant* object = constants[0];
145     const analysis::Constant* composite = constants[1];
146     if (object == nullptr || composite == nullptr) {
147       return nullptr;
148     }
149 
150     // If there is more than 1 index, then each additional constant used by the
151     // index will need to be recreated to use the inserted object.
152     std::vector<const analysis::Constant*> chain;
153     std::vector<const analysis::Constant*> components;
154     const analysis::Type* type = nullptr;
155     const uint32_t final_index = (inst->NumInOperands() - 1);
156 
157     // Work down hierarchy of all indexes
158     for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
159       type = composite->type();
160 
161       if (composite->AsNullConstant()) {
162         // Make new composite so it can be inserted in the index with the
163         // non-null value
164         if (const auto new_composite =
165                 const_mgr->GetNullCompositeConstant(type)) {
166           // Keep track of any indexes along the way to last index
167           if (i != final_index) {
168             chain.push_back(new_composite);
169           }
170           components = new_composite->AsCompositeConstant()->GetComponents();
171         } else {
172           // Unsupported input type (such as structs)
173           return nullptr;
174         }
175       } else {
176         // Keep track of any indexes along the way to last index
177         if (i != final_index) {
178           chain.push_back(composite);
179         }
180         components = composite->AsCompositeConstant()->GetComponents();
181       }
182       const uint32_t index = inst->GetSingleWordInOperand(i);
183       composite = components[index];
184     }
185 
186     // Final index in hierarchy is inserted with new object.
187     const uint32_t final_operand = inst->GetSingleWordInOperand(final_index);
188     std::vector<uint32_t> ids;
189     for (size_t i = 0; i < components.size(); i++) {
190       const analysis::Constant* constant =
191           (i == final_operand) ? object : components[i];
192       Instruction* member_inst = const_mgr->GetDefiningInstruction(constant);
193       ids.push_back(member_inst->result_id());
194     }
195     const analysis::Constant* new_constant = const_mgr->GetConstant(type, ids);
196 
197     // Work backwards up the chain and replace each index with new constant.
198     for (size_t i = chain.size(); i > 0; i--) {
199       // Need to insert any previous instruction into the module first.
200       // Can't just insert in types_values_begin() because it will move above
201       // where the types are declared.
202       // Can't compare with location of inst because not all new added
203       // instructions are added to types_values_
204       auto iter = context->types_values_end();
205       Module::inst_iterator* pos = &iter;
206       const_mgr->BuildInstructionAndAddToModule(new_constant, pos);
207 
208       composite = chain[i - 1];
209       components = composite->AsCompositeConstant()->GetComponents();
210       type = composite->type();
211       ids.clear();
212       for (size_t k = 0; k < components.size(); k++) {
213         const uint32_t index =
214             inst->GetSingleWordInOperand(1 + static_cast<uint32_t>(i));
215         const analysis::Constant* constant =
216             (k == index) ? new_constant : components[k];
217         const uint32_t constant_id =
218             const_mgr->FindDeclaredConstant(constant, 0);
219         ids.push_back(constant_id);
220       }
221       new_constant = const_mgr->GetConstant(type, ids);
222     }
223 
224     // If multiple constants were created, only need to return the top index.
225     return new_constant;
226   };
227 }
228 
FoldVectorShuffleWithConstants()229 ConstantFoldingRule FoldVectorShuffleWithConstants() {
230   return [](IRContext* context, Instruction* inst,
231             const std::vector<const analysis::Constant*>& constants)
232              -> const analysis::Constant* {
233     assert(inst->opcode() == spv::Op::OpVectorShuffle);
234     const analysis::Constant* c1 = constants[0];
235     const analysis::Constant* c2 = constants[1];
236     if (c1 == nullptr || c2 == nullptr) {
237       return nullptr;
238     }
239 
240     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
241     const analysis::Type* element_type = c1->type()->AsVector()->element_type();
242 
243     std::vector<const analysis::Constant*> c1_components;
244     if (const analysis::VectorConstant* vec_const = c1->AsVectorConstant()) {
245       c1_components = vec_const->GetComponents();
246     } else {
247       assert(c1->AsNullConstant());
248       const analysis::Constant* element =
249           const_mgr->GetConstant(element_type, {});
250       c1_components.resize(c1->type()->AsVector()->element_count(), element);
251     }
252     std::vector<const analysis::Constant*> c2_components;
253     if (const analysis::VectorConstant* vec_const = c2->AsVectorConstant()) {
254       c2_components = vec_const->GetComponents();
255     } else {
256       assert(c2->AsNullConstant());
257       const analysis::Constant* element =
258           const_mgr->GetConstant(element_type, {});
259       c2_components.resize(c2->type()->AsVector()->element_count(), element);
260     }
261 
262     std::vector<uint32_t> ids;
263     const uint32_t undef_literal_value = 0xffffffff;
264     for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
265       uint32_t index = inst->GetSingleWordInOperand(i);
266       if (index == undef_literal_value) {
267         // Don't fold shuffle with undef literal value.
268         return nullptr;
269       } else if (index < c1_components.size()) {
270         Instruction* member_inst =
271             const_mgr->GetDefiningInstruction(c1_components[index]);
272         ids.push_back(member_inst->result_id());
273       } else {
274         Instruction* member_inst = const_mgr->GetDefiningInstruction(
275             c2_components[index - c1_components.size()]);
276         ids.push_back(member_inst->result_id());
277       }
278     }
279 
280     analysis::TypeManager* type_mgr = context->get_type_mgr();
281     return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), ids);
282   };
283 }
284 
FoldVectorTimesScalar()285 ConstantFoldingRule FoldVectorTimesScalar() {
286   return [](IRContext* context, Instruction* inst,
287             const std::vector<const analysis::Constant*>& constants)
288              -> const analysis::Constant* {
289     assert(inst->opcode() == spv::Op::OpVectorTimesScalar);
290     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
291     analysis::TypeManager* type_mgr = context->get_type_mgr();
292 
293     if (!inst->IsFloatingPointFoldingAllowed()) {
294       if (HasFloatingPoint(type_mgr->GetType(inst->type_id()))) {
295         return nullptr;
296       }
297     }
298 
299     const analysis::Constant* c1 = constants[0];
300     const analysis::Constant* c2 = constants[1];
301 
302     if (c1 && c1->IsZero()) {
303       return c1;
304     }
305 
306     if (c2 && c2->IsZero()) {
307       // Get or create the NullConstant for this type.
308       std::vector<uint32_t> ids;
309       return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), ids);
310     }
311 
312     if (c1 == nullptr || c2 == nullptr) {
313       return nullptr;
314     }
315 
316     // Check result type.
317     const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
318     const analysis::Vector* vector_type = result_type->AsVector();
319     assert(vector_type != nullptr);
320     const analysis::Type* element_type = vector_type->element_type();
321     assert(element_type != nullptr);
322     const analysis::Float* float_type = element_type->AsFloat();
323     assert(float_type != nullptr);
324 
325     // Check types of c1 and c2.
326     assert(c1->type()->AsVector() == vector_type);
327     assert(c1->type()->AsVector()->element_type() == element_type &&
328            c2->type() == element_type);
329 
330     // Get a float vector that is the result of vector-times-scalar.
331     std::vector<const analysis::Constant*> c1_components =
332         c1->GetVectorComponents(const_mgr);
333     std::vector<uint32_t> ids;
334     if (float_type->width() == 32) {
335       float scalar = c2->GetFloat();
336       for (uint32_t i = 0; i < c1_components.size(); ++i) {
337         utils::FloatProxy<float> result(c1_components[i]->GetFloat() * scalar);
338         std::vector<uint32_t> words = result.GetWords();
339         const analysis::Constant* new_elem =
340             const_mgr->GetConstant(float_type, words);
341         ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
342       }
343       return const_mgr->GetConstant(vector_type, ids);
344     } else if (float_type->width() == 64) {
345       double scalar = c2->GetDouble();
346       for (uint32_t i = 0; i < c1_components.size(); ++i) {
347         utils::FloatProxy<double> result(c1_components[i]->GetDouble() *
348                                          scalar);
349         std::vector<uint32_t> words = result.GetWords();
350         const analysis::Constant* new_elem =
351             const_mgr->GetConstant(float_type, words);
352         ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
353       }
354       return const_mgr->GetConstant(vector_type, ids);
355     }
356     return nullptr;
357   };
358 }
359 
360 // Returns to the constant that results from tranposing |matrix|. The result
361 // will have type |result_type|, and |matrix| must exist in |context|. The
362 // result constant will also exist in |context|.
TransposeMatrix(const analysis::Constant * matrix,analysis::Matrix * result_type,IRContext * context)363 const analysis::Constant* TransposeMatrix(const analysis::Constant* matrix,
364                                           analysis::Matrix* result_type,
365                                           IRContext* context) {
366   analysis::ConstantManager* const_mgr = context->get_constant_mgr();
367   if (matrix->AsNullConstant() != nullptr) {
368     return const_mgr->GetNullCompositeConstant(result_type);
369   }
370 
371   const auto& columns = matrix->AsMatrixConstant()->GetComponents();
372   uint32_t number_of_rows = columns[0]->type()->AsVector()->element_count();
373 
374   // Collect the ids of the elements in their new positions.
375   std::vector<std::vector<uint32_t>> result_elements(number_of_rows);
376   for (const analysis::Constant* column : columns) {
377     if (column->AsNullConstant()) {
378       column = const_mgr->GetNullCompositeConstant(column->type());
379     }
380     const auto& column_components = column->AsVectorConstant()->GetComponents();
381 
382     for (uint32_t row = 0; row < number_of_rows; ++row) {
383       result_elements[row].push_back(
384           const_mgr->GetDefiningInstruction(column_components[row])
385               ->result_id());
386     }
387   }
388 
389   // Create the constant for each row in the result, and collect the ids.
390   std::vector<uint32_t> result_columns(number_of_rows);
391   for (uint32_t col = 0; col < number_of_rows; ++col) {
392     auto* element = const_mgr->GetConstant(result_type->element_type(),
393                                            result_elements[col]);
394     result_columns[col] =
395         const_mgr->GetDefiningInstruction(element)->result_id();
396   }
397 
398   // Create the matrix constant from the row ids, and return it.
399   return const_mgr->GetConstant(result_type, result_columns);
400 }
401 
FoldTranspose(IRContext * context,Instruction * inst,const std::vector<const analysis::Constant * > & constants)402 const analysis::Constant* FoldTranspose(
403     IRContext* context, Instruction* inst,
404     const std::vector<const analysis::Constant*>& constants) {
405   assert(inst->opcode() == spv::Op::OpTranspose);
406 
407   analysis::TypeManager* type_mgr = context->get_type_mgr();
408   if (!inst->IsFloatingPointFoldingAllowed()) {
409     if (HasFloatingPoint(type_mgr->GetType(inst->type_id()))) {
410       return nullptr;
411     }
412   }
413 
414   const analysis::Constant* matrix = constants[0];
415   if (matrix == nullptr) {
416     return nullptr;
417   }
418 
419   auto* result_type = type_mgr->GetType(inst->type_id());
420   return TransposeMatrix(matrix, result_type->AsMatrix(), context);
421 }
422 
FoldVectorTimesMatrix()423 ConstantFoldingRule FoldVectorTimesMatrix() {
424   return [](IRContext* context, Instruction* inst,
425             const std::vector<const analysis::Constant*>& constants)
426              -> const analysis::Constant* {
427     assert(inst->opcode() == spv::Op::OpVectorTimesMatrix);
428     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
429     analysis::TypeManager* type_mgr = context->get_type_mgr();
430 
431     if (!inst->IsFloatingPointFoldingAllowed()) {
432       if (HasFloatingPoint(type_mgr->GetType(inst->type_id()))) {
433         return nullptr;
434       }
435     }
436 
437     const analysis::Constant* c1 = constants[0];
438     const analysis::Constant* c2 = constants[1];
439 
440     if (c1 == nullptr || c2 == nullptr) {
441       return nullptr;
442     }
443 
444     // Check result type.
445     const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
446     const analysis::Vector* vector_type = result_type->AsVector();
447     assert(vector_type != nullptr);
448     const analysis::Type* element_type = vector_type->element_type();
449     assert(element_type != nullptr);
450     const analysis::Float* float_type = element_type->AsFloat();
451     assert(float_type != nullptr);
452 
453     // Check types of c1 and c2.
454     assert(c1->type()->AsVector() == vector_type);
455     assert(c1->type()->AsVector()->element_type() == element_type &&
456            c2->type()->AsMatrix()->element_type() == vector_type);
457 
458     uint32_t resultVectorSize = result_type->AsVector()->element_count();
459     std::vector<uint32_t> ids;
460 
461     if ((c1 && c1->IsZero()) || (c2 && c2->IsZero())) {
462       std::vector<uint32_t> words(float_type->width() / 32, 0);
463       for (uint32_t i = 0; i < resultVectorSize; ++i) {
464         const analysis::Constant* new_elem =
465             const_mgr->GetConstant(float_type, words);
466         ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
467       }
468       return const_mgr->GetConstant(vector_type, ids);
469     }
470 
471     // Get a float vector that is the result of vector-times-matrix.
472     std::vector<const analysis::Constant*> c1_components =
473         c1->GetVectorComponents(const_mgr);
474     std::vector<const analysis::Constant*> c2_components =
475         c2->AsMatrixConstant()->GetComponents();
476 
477     if (float_type->width() == 32) {
478       for (uint32_t i = 0; i < resultVectorSize; ++i) {
479         float result_scalar = 0.0f;
480         if (!c2_components[i]->AsNullConstant()) {
481           const analysis::VectorConstant* c2_vec =
482               c2_components[i]->AsVectorConstant();
483           for (uint32_t j = 0; j < c2_vec->GetComponents().size(); ++j) {
484             float c1_scalar = c1_components[j]->GetFloat();
485             float c2_scalar = c2_vec->GetComponents()[j]->GetFloat();
486             result_scalar += c1_scalar * c2_scalar;
487           }
488         }
489         utils::FloatProxy<float> result(result_scalar);
490         std::vector<uint32_t> words = result.GetWords();
491         const analysis::Constant* new_elem =
492             const_mgr->GetConstant(float_type, words);
493         ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
494       }
495       return const_mgr->GetConstant(vector_type, ids);
496     } else if (float_type->width() == 64) {
497       for (uint32_t i = 0; i < c2_components.size(); ++i) {
498         double result_scalar = 0.0;
499         if (!c2_components[i]->AsNullConstant()) {
500           const analysis::VectorConstant* c2_vec =
501               c2_components[i]->AsVectorConstant();
502           for (uint32_t j = 0; j < c2_vec->GetComponents().size(); ++j) {
503             double c1_scalar = c1_components[j]->GetDouble();
504             double c2_scalar = c2_vec->GetComponents()[j]->GetDouble();
505             result_scalar += c1_scalar * c2_scalar;
506           }
507         }
508         utils::FloatProxy<double> result(result_scalar);
509         std::vector<uint32_t> words = result.GetWords();
510         const analysis::Constant* new_elem =
511             const_mgr->GetConstant(float_type, words);
512         ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
513       }
514       return const_mgr->GetConstant(vector_type, ids);
515     }
516     return nullptr;
517   };
518 }
519 
FoldMatrixTimesVector()520 ConstantFoldingRule FoldMatrixTimesVector() {
521   return [](IRContext* context, Instruction* inst,
522             const std::vector<const analysis::Constant*>& constants)
523              -> const analysis::Constant* {
524     assert(inst->opcode() == spv::Op::OpMatrixTimesVector);
525     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
526     analysis::TypeManager* type_mgr = context->get_type_mgr();
527 
528     if (!inst->IsFloatingPointFoldingAllowed()) {
529       if (HasFloatingPoint(type_mgr->GetType(inst->type_id()))) {
530         return nullptr;
531       }
532     }
533 
534     const analysis::Constant* c1 = constants[0];
535     const analysis::Constant* c2 = constants[1];
536 
537     if (c1 == nullptr || c2 == nullptr) {
538       return nullptr;
539     }
540 
541     // Check result type.
542     const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
543     const analysis::Vector* vector_type = result_type->AsVector();
544     assert(vector_type != nullptr);
545     const analysis::Type* element_type = vector_type->element_type();
546     assert(element_type != nullptr);
547     const analysis::Float* float_type = element_type->AsFloat();
548     assert(float_type != nullptr);
549 
550     // Check types of c1 and c2.
551     assert(c1->type()->AsMatrix()->element_type() == vector_type);
552     assert(c2->type()->AsVector()->element_type() == element_type);
553 
554     uint32_t resultVectorSize = result_type->AsVector()->element_count();
555     std::vector<uint32_t> ids;
556 
557     if ((c1 && c1->IsZero()) || (c2 && c2->IsZero())) {
558       std::vector<uint32_t> words(float_type->width() / 32, 0);
559       for (uint32_t i = 0; i < resultVectorSize; ++i) {
560         const analysis::Constant* new_elem =
561             const_mgr->GetConstant(float_type, words);
562         ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
563       }
564       return const_mgr->GetConstant(vector_type, ids);
565     }
566 
567     // Get a float vector that is the result of matrix-times-vector.
568     std::vector<const analysis::Constant*> c1_components =
569         c1->AsMatrixConstant()->GetComponents();
570     std::vector<const analysis::Constant*> c2_components =
571         c2->GetVectorComponents(const_mgr);
572 
573     if (float_type->width() == 32) {
574       for (uint32_t i = 0; i < resultVectorSize; ++i) {
575         float result_scalar = 0.0f;
576         for (uint32_t j = 0; j < c1_components.size(); ++j) {
577           if (!c1_components[j]->AsNullConstant()) {
578             float c1_scalar = c1_components[j]
579                                   ->AsVectorConstant()
580                                   ->GetComponents()[i]
581                                   ->GetFloat();
582             float c2_scalar = c2_components[j]->GetFloat();
583             result_scalar += c1_scalar * c2_scalar;
584           }
585         }
586         utils::FloatProxy<float> result(result_scalar);
587         std::vector<uint32_t> words = result.GetWords();
588         const analysis::Constant* new_elem =
589             const_mgr->GetConstant(float_type, words);
590         ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
591       }
592       return const_mgr->GetConstant(vector_type, ids);
593     } else if (float_type->width() == 64) {
594       for (uint32_t i = 0; i < resultVectorSize; ++i) {
595         double result_scalar = 0.0;
596         for (uint32_t j = 0; j < c1_components.size(); ++j) {
597           if (!c1_components[j]->AsNullConstant()) {
598             double c1_scalar = c1_components[j]
599                                    ->AsVectorConstant()
600                                    ->GetComponents()[i]
601                                    ->GetDouble();
602             double c2_scalar = c2_components[j]->GetDouble();
603             result_scalar += c1_scalar * c2_scalar;
604           }
605         }
606         utils::FloatProxy<double> result(result_scalar);
607         std::vector<uint32_t> words = result.GetWords();
608         const analysis::Constant* new_elem =
609             const_mgr->GetConstant(float_type, words);
610         ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
611       }
612       return const_mgr->GetConstant(vector_type, ids);
613     }
614     return nullptr;
615   };
616 }
617 
FoldCompositeWithConstants()618 ConstantFoldingRule FoldCompositeWithConstants() {
619   // Folds an OpCompositeConstruct where all of the inputs are constants to a
620   // constant.  A new constant is created if necessary.
621   return [](IRContext* context, Instruction* inst,
622             const std::vector<const analysis::Constant*>& constants)
623              -> const analysis::Constant* {
624     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
625     analysis::TypeManager* type_mgr = context->get_type_mgr();
626     const analysis::Type* new_type = type_mgr->GetType(inst->type_id());
627     Instruction* type_inst =
628         context->get_def_use_mgr()->GetDef(inst->type_id());
629 
630     std::vector<uint32_t> ids;
631     for (uint32_t i = 0; i < constants.size(); ++i) {
632       const analysis::Constant* element_const = constants[i];
633       if (element_const == nullptr) {
634         return nullptr;
635       }
636 
637       uint32_t component_type_id = 0;
638       if (type_inst->opcode() == spv::Op::OpTypeStruct) {
639         component_type_id = type_inst->GetSingleWordInOperand(i);
640       } else if (type_inst->opcode() == spv::Op::OpTypeArray) {
641         component_type_id = type_inst->GetSingleWordInOperand(0);
642       }
643 
644       uint32_t element_id =
645           const_mgr->FindDeclaredConstant(element_const, component_type_id);
646       if (element_id == 0) {
647         return nullptr;
648       }
649       ids.push_back(element_id);
650     }
651     return const_mgr->GetConstant(new_type, ids);
652   };
653 }
654 
// The interface for a function that returns the result of applying a scalar
// unary operation on |a|.  The type of the return value will be
// |result_type|.  The input constant is expected to be of that type as well.
658 using UnaryScalarFoldingRule = std::function<const analysis::Constant*(
659     const analysis::Type* result_type, const analysis::Constant* a,
660     analysis::ConstantManager*)>;
661 
// The interface for a function that returns the result of applying a scalar
// floating-point binary operation on |a| and |b|.  The type of the return value
// will be |result_type|.  The input constants must also be of type
// |result_type|.
665 using BinaryScalarFoldingRule = std::function<const analysis::Constant*(
666     const analysis::Type* result_type, const analysis::Constant* a,
667     const analysis::Constant* b, analysis::ConstantManager*)>;
668 
669 // Returns a |ConstantFoldingRule| that folds unary scalar ops
670 // using |scalar_rule| and unary vectors ops by applying
671 // |scalar_rule| to the elements of the vector.  The |ConstantFoldingRule|
672 // that is returned assumes that |constants| contains 1 entry.  If they are
673 // not |nullptr|, then their type is either |Float| or |Integer| or a |Vector|
674 // whose element type is |Float| or |Integer|.
FoldUnaryOp(UnaryScalarFoldingRule scalar_rule)675 ConstantFoldingRule FoldUnaryOp(UnaryScalarFoldingRule scalar_rule) {
676   return [scalar_rule](IRContext* context, Instruction* inst,
677                        const std::vector<const analysis::Constant*>& constants)
678              -> const analysis::Constant* {
679 
680     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
681     analysis::TypeManager* type_mgr = context->get_type_mgr();
682     const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
683     const analysis::Vector* vector_type = result_type->AsVector();
684 
685     const analysis::Constant* arg =
686         (inst->opcode() == spv::Op::OpExtInst) ? constants[1] : constants[0];
687 
688     if (arg == nullptr) {
689       return nullptr;
690     }
691 
692     if (vector_type != nullptr) {
693       std::vector<const analysis::Constant*> a_components;
694       std::vector<const analysis::Constant*> results_components;
695 
696       a_components = arg->GetVectorComponents(const_mgr);
697 
698       // Fold each component of the vector.
699       for (uint32_t i = 0; i < a_components.size(); ++i) {
700         results_components.push_back(scalar_rule(vector_type->element_type(),
701                                                  a_components[i], const_mgr));
702         if (results_components[i] == nullptr) {
703           return nullptr;
704         }
705       }
706 
707       // Build the constant object and return it.
708       std::vector<uint32_t> ids;
709       for (const analysis::Constant* member : results_components) {
710         ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id());
711       }
712       return const_mgr->GetConstant(vector_type, ids);
713     } else {
714       return scalar_rule(result_type, arg, const_mgr);
715     }
716   };
717 }
718 
719 // Returns a |ConstantFoldingRule| that folds unary floating point scalar ops
720 // using |scalar_rule| and unary float point vectors ops by applying
721 // |scalar_rule| to the elements of the vector.  The |ConstantFoldingRule|
722 // that is returned assumes that |constants| contains 1 entry.  If they are
723 // not |nullptr|, then their type is either |Float| or |Integer| or a |Vector|
724 // whose element type is |Float| or |Integer|.
FoldFPUnaryOp(UnaryScalarFoldingRule scalar_rule)725 ConstantFoldingRule FoldFPUnaryOp(UnaryScalarFoldingRule scalar_rule) {
726   auto folding_rule = FoldUnaryOp(scalar_rule);
727   return [folding_rule](IRContext* context, Instruction* inst,
728                         const std::vector<const analysis::Constant*>& constants)
729              -> const analysis::Constant* {
730     if (!inst->IsFloatingPointFoldingAllowed()) {
731       return nullptr;
732     }
733 
734     return folding_rule(context, inst, constants);
735   };
736 }
737 
738 // Returns the result of folding the constants in |constants| according the
739 // |scalar_rule|.  If |result_type| is a vector, then |scalar_rule| is applied
740 // per component.
FoldFPBinaryOp(BinaryScalarFoldingRule scalar_rule,uint32_t result_type_id,const std::vector<const analysis::Constant * > & constants,IRContext * context)741 const analysis::Constant* FoldFPBinaryOp(
742     BinaryScalarFoldingRule scalar_rule, uint32_t result_type_id,
743     const std::vector<const analysis::Constant*>& constants,
744     IRContext* context) {
745   analysis::ConstantManager* const_mgr = context->get_constant_mgr();
746   analysis::TypeManager* type_mgr = context->get_type_mgr();
747   const analysis::Type* result_type = type_mgr->GetType(result_type_id);
748   const analysis::Vector* vector_type = result_type->AsVector();
749 
750   if (constants[0] == nullptr || constants[1] == nullptr) {
751     return nullptr;
752   }
753 
754   if (vector_type != nullptr) {
755     std::vector<const analysis::Constant*> a_components;
756     std::vector<const analysis::Constant*> b_components;
757     std::vector<const analysis::Constant*> results_components;
758 
759     a_components = constants[0]->GetVectorComponents(const_mgr);
760     b_components = constants[1]->GetVectorComponents(const_mgr);
761 
762     // Fold each component of the vector.
763     for (uint32_t i = 0; i < a_components.size(); ++i) {
764       results_components.push_back(scalar_rule(vector_type->element_type(),
765                                                a_components[i], b_components[i],
766                                                const_mgr));
767       if (results_components[i] == nullptr) {
768         return nullptr;
769       }
770     }
771 
772     // Build the constant object and return it.
773     std::vector<uint32_t> ids;
774     for (const analysis::Constant* member : results_components) {
775       ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id());
776     }
777     return const_mgr->GetConstant(vector_type, ids);
778   } else {
779     return scalar_rule(result_type, constants[0], constants[1], const_mgr);
780   }
781 }
782 
783 // Returns a |ConstantFoldingRule| that folds floating point scalars using
784 // |scalar_rule| and vectors of floating point by applying |scalar_rule| to the
785 // elements of the vector.  The |ConstantFoldingRule| that is returned assumes
786 // that |constants| contains 2 entries.  If they are not |nullptr|, then their
787 // type is either |Float| or a |Vector| whose element type is |Float|.
FoldFPBinaryOp(BinaryScalarFoldingRule scalar_rule)788 ConstantFoldingRule FoldFPBinaryOp(BinaryScalarFoldingRule scalar_rule) {
789   return [scalar_rule](IRContext* context, Instruction* inst,
790                        const std::vector<const analysis::Constant*>& constants)
791              -> const analysis::Constant* {
792     if (!inst->IsFloatingPointFoldingAllowed()) {
793       return nullptr;
794     }
795     if (inst->opcode() == spv::Op::OpExtInst) {
796       return FoldFPBinaryOp(scalar_rule, inst->type_id(),
797                             {constants[1], constants[2]}, context);
798     }
799     return FoldFPBinaryOp(scalar_rule, inst->type_id(), constants, context);
800   };
801 }
802 
803 // This macro defines a |UnaryScalarFoldingRule| that performs float to
804 // integer conversion.
805 // TODO(greg-lunarg): Support for 64-bit integer types.
FoldFToIOp()806 UnaryScalarFoldingRule FoldFToIOp() {
807   return [](const analysis::Type* result_type, const analysis::Constant* a,
808             analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
809     assert(result_type != nullptr && a != nullptr);
810     const analysis::Integer* integer_type = result_type->AsInteger();
811     const analysis::Float* float_type = a->type()->AsFloat();
812     assert(float_type != nullptr);
813     assert(integer_type != nullptr);
814     if (integer_type->width() != 32) return nullptr;
815     if (float_type->width() == 32) {
816       float fa = a->GetFloat();
817       uint32_t result = integer_type->IsSigned()
818                             ? static_cast<uint32_t>(static_cast<int32_t>(fa))
819                             : static_cast<uint32_t>(fa);
820       std::vector<uint32_t> words = {result};
821       return const_mgr->GetConstant(result_type, words);
822     } else if (float_type->width() == 64) {
823       double fa = a->GetDouble();
824       uint32_t result = integer_type->IsSigned()
825                             ? static_cast<uint32_t>(static_cast<int32_t>(fa))
826                             : static_cast<uint32_t>(fa);
827       std::vector<uint32_t> words = {result};
828       return const_mgr->GetConstant(result_type, words);
829     }
830     return nullptr;
831   };
832 }
833 
834 // This function defines a |UnaryScalarFoldingRule| that performs integer to
835 // float conversion.
836 // TODO(greg-lunarg): Support for 64-bit integer types.
FoldIToFOp()837 UnaryScalarFoldingRule FoldIToFOp() {
838   return [](const analysis::Type* result_type, const analysis::Constant* a,
839             analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
840     assert(result_type != nullptr && a != nullptr);
841     const analysis::Integer* integer_type = a->type()->AsInteger();
842     const analysis::Float* float_type = result_type->AsFloat();
843     assert(float_type != nullptr);
844     assert(integer_type != nullptr);
845     if (integer_type->width() != 32) return nullptr;
846     uint32_t ua = a->GetU32();
847     if (float_type->width() == 32) {
848       float result_val = integer_type->IsSigned()
849                              ? static_cast<float>(static_cast<int32_t>(ua))
850                              : static_cast<float>(ua);
851       utils::FloatProxy<float> result(result_val);
852       std::vector<uint32_t> words = {result.data()};
853       return const_mgr->GetConstant(result_type, words);
854     } else if (float_type->width() == 64) {
855       double result_val = integer_type->IsSigned()
856                               ? static_cast<double>(static_cast<int32_t>(ua))
857                               : static_cast<double>(ua);
858       utils::FloatProxy<double> result(result_val);
859       std::vector<uint32_t> words = result.GetWords();
860       return const_mgr->GetConstant(result_type, words);
861     }
862     return nullptr;
863   };
864 }
865 
866 // This defines a |UnaryScalarFoldingRule| that performs |OpQuantizeToF16|.
FoldQuantizeToF16Scalar()867 UnaryScalarFoldingRule FoldQuantizeToF16Scalar() {
868   return [](const analysis::Type* result_type, const analysis::Constant* a,
869             analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
870     assert(result_type != nullptr && a != nullptr);
871     const analysis::Float* float_type = a->type()->AsFloat();
872     assert(float_type != nullptr);
873     if (float_type->width() != 32) {
874       return nullptr;
875     }
876 
877     float fa = a->GetFloat();
878     utils::HexFloat<utils::FloatProxy<float>> orignal(fa);
879     utils::HexFloat<utils::FloatProxy<utils::Float16>> quantized(0);
880     utils::HexFloat<utils::FloatProxy<float>> result(0.0f);
881     orignal.castTo(quantized, utils::round_direction::kToZero);
882     quantized.castTo(result, utils::round_direction::kToZero);
883     std::vector<uint32_t> words = {result.getBits()};
884     return const_mgr->GetConstant(result_type, words);
885   };
886 }
887 
// Expands to a lambda, usable as a |BinaryScalarFoldingRule|, that evaluates
// "a op b" for two scalar floating point constants whose type is the result
// type.  |op| must be an infix operator that works on both float and double
// with the syntax "f1 op f2".  The lambda yields |nullptr| for float widths
// other than 32 and 64.
#define FOLD_FPARITH_OP(op)                                                   \
  [](const analysis::Type* result_type_in_macro, const analysis::Constant* a, \
     const analysis::Constant* b,                                             \
     analysis::ConstantManager* const_mgr_in_macro)                           \
      -> const analysis::Constant* {                                          \
    assert(result_type_in_macro != nullptr && a != nullptr && b != nullptr);  \
    assert(result_type_in_macro == a->type() &&                               \
           result_type_in_macro == b->type());                                \
    const analysis::Float* float_type_in_macro =                              \
        result_type_in_macro->AsFloat();                                      \
    assert(float_type_in_macro != nullptr);                                   \
    switch (float_type_in_macro->width()) {                                   \
      case 32: {                                                              \
        const float lhs_in_macro = a->GetFloat();                             \
        const float rhs_in_macro = b->GetFloat();                             \
        utils::FloatProxy<float> folded_in_macro(lhs_in_macro op              \
                                                     rhs_in_macro);           \
        return const_mgr_in_macro->GetConstant(result_type_in_macro,          \
                                               folded_in_macro.GetWords());   \
      }                                                                       \
      case 64: {                                                              \
        const double lhs_in_macro = a->GetDouble();                           \
        const double rhs_in_macro = b->GetDouble();                           \
        utils::FloatProxy<double> folded_in_macro(lhs_in_macro op             \
                                                      rhs_in_macro);          \
        return const_mgr_in_macro->GetConstant(result_type_in_macro,          \
                                               folded_in_macro.GetWords());   \
      }                                                                       \
      default:                                                                \
        return nullptr;                                                       \
    }                                                                         \
  }
918 
919 // Define the folding rule for conversion between floating point and integer
FoldFToI()920 ConstantFoldingRule FoldFToI() { return FoldFPUnaryOp(FoldFToIOp()); }
FoldIToF()921 ConstantFoldingRule FoldIToF() { return FoldFPUnaryOp(FoldIToFOp()); }
FoldQuantizeToF16()922 ConstantFoldingRule FoldQuantizeToF16() {
923   return FoldFPUnaryOp(FoldQuantizeToF16Scalar());
924 }
925 
926 // Define the folding rules for subtraction, addition, multiplication, and
927 // division for floating point values.
FoldFSub()928 ConstantFoldingRule FoldFSub() { return FoldFPBinaryOp(FOLD_FPARITH_OP(-)); }
FoldFAdd()929 ConstantFoldingRule FoldFAdd() { return FoldFPBinaryOp(FOLD_FPARITH_OP(+)); }
FoldFMul()930 ConstantFoldingRule FoldFMul() { return FoldFPBinaryOp(FOLD_FPARITH_OP(*)); }
931 
932 // Returns the constant that results from evaluating |numerator| / 0.0.  Returns
933 // |nullptr| if the result could not be evaluated.
FoldFPScalarDivideByZero(const analysis::Type * result_type,const analysis::Constant * numerator,analysis::ConstantManager * const_mgr)934 const analysis::Constant* FoldFPScalarDivideByZero(
935     const analysis::Type* result_type, const analysis::Constant* numerator,
936     analysis::ConstantManager* const_mgr) {
937   if (numerator == nullptr) {
938     return nullptr;
939   }
940 
941   if (numerator->IsZero()) {
942     return GetNan(result_type, const_mgr);
943   }
944 
945   const analysis::Constant* result = GetInf(result_type, const_mgr);
946   if (result == nullptr) {
947     return nullptr;
948   }
949 
950   if (numerator->AsFloatConstant()->GetValueAsDouble() < 0.0) {
951     result = NegateFPConst(result_type, result, const_mgr);
952   }
953   return result;
954 }
955 
956 // Returns the result of folding |numerator| / |denominator|.  Returns |nullptr|
957 // if it cannot be folded.
FoldScalarFPDivide(const analysis::Type * result_type,const analysis::Constant * numerator,const analysis::Constant * denominator,analysis::ConstantManager * const_mgr)958 const analysis::Constant* FoldScalarFPDivide(
959     const analysis::Type* result_type, const analysis::Constant* numerator,
960     const analysis::Constant* denominator,
961     analysis::ConstantManager* const_mgr) {
962   if (denominator == nullptr) {
963     return nullptr;
964   }
965 
966   if (denominator->IsZero()) {
967     return FoldFPScalarDivideByZero(result_type, numerator, const_mgr);
968   }
969 
970   uint32_t width = denominator->type()->AsFloat()->width();
971   if (width != 32 && width != 64) {
972     return nullptr;
973   }
974 
975   const analysis::FloatConstant* denominator_float =
976       denominator->AsFloatConstant();
977   if (denominator_float && denominator->GetValueAsDouble() == -0.0) {
978     const analysis::Constant* result =
979         FoldFPScalarDivideByZero(result_type, numerator, const_mgr);
980     if (result != nullptr)
981       result = NegateFPConst(result_type, result, const_mgr);
982     return result;
983   } else {
984     return FOLD_FPARITH_OP(/)(result_type, numerator, denominator, const_mgr);
985   }
986 }
987 
988 // Returns the constant folding rule to fold |OpFDiv| with two constants.
FoldFDiv()989 ConstantFoldingRule FoldFDiv() { return FoldFPBinaryOp(FoldScalarFPDivide); }
990 
// Combines the raw outcome of a floating point comparison with the
// ordered/unordered semantics requested by the instruction.
//   |op_result|    - result of applying the comparison operator.
//   |op_unordered| - true if the operands are unordered.
//   |need_ordered| - true for an ordered comparison: the operands must be
//                    ordered and the comparison must hold.  False for an
//                    unordered comparison: it suffices that the operands are
//                    unordered or that the comparison holds.
bool CompareFloatingPoint(bool op_result, bool op_unordered,
                          bool need_ordered) {
  return need_ordered ? (!op_unordered && op_result)
                      : (op_unordered || op_result);
}
1001 
// Expands to a lambda, usable as a |BinaryScalarFoldingRule|, that compares
// two scalar floating point constants of the same type with |op| and produces
// a boolean constant of type |result_type|.  |op| must be an infix comparison
// operator that works on both float and double with the syntax "f1 op f2".
// |ord| selects ordered (true) or unordered (false) semantics; see
// |CompareFloatingPoint|.  The lambda yields |nullptr| for float widths other
// than 32 and 64.
#define FOLD_FPCMP_OP(op, ord)                                            \
  [](const analysis::Type* result_type, const analysis::Constant* a,      \
     const analysis::Constant* b,                                         \
     analysis::ConstantManager* const_mgr) -> const analysis::Constant* { \
    assert(result_type != nullptr && a != nullptr && b != nullptr);       \
    assert(result_type->AsBool());                                        \
    assert(a->type() == b->type());                                       \
    const analysis::Float* float_type = a->type()->AsFloat();             \
    assert(float_type != nullptr);                                        \
    bool holds = false;                                                   \
    switch (float_type->width()) {                                        \
      case 32: {                                                          \
        const float lhs = a->GetFloat();                                  \
        const float rhs = b->GetFloat();                                  \
        const bool unordered = std::isnan(lhs) || std::isnan(rhs);        \
        holds = CompareFloatingPoint(lhs op rhs, unordered, ord);         \
        break;                                                            \
      }                                                                   \
      case 64: {                                                          \
        const double lhs = a->GetDouble();                                \
        const double rhs = b->GetDouble();                                \
        const bool unordered = std::isnan(lhs) || std::isnan(rhs);        \
        holds = CompareFloatingPoint(lhs op rhs, unordered, ord);         \
        break;                                                            \
      }                                                                   \
      default:                                                            \
        return nullptr;                                                   \
    }                                                                     \
    std::vector<uint32_t> words = {static_cast<uint32_t>(holds)};         \
    return const_mgr->GetConstant(result_type, words);                    \
  }
1030 
1031 // Define the folding rules for ordered and unordered comparison for floating
1032 // point values.
FoldFOrdEqual()1033 ConstantFoldingRule FoldFOrdEqual() {
1034   return FoldFPBinaryOp(FOLD_FPCMP_OP(==, true));
1035 }
FoldFUnordEqual()1036 ConstantFoldingRule FoldFUnordEqual() {
1037   return FoldFPBinaryOp(FOLD_FPCMP_OP(==, false));
1038 }
FoldFOrdNotEqual()1039 ConstantFoldingRule FoldFOrdNotEqual() {
1040   return FoldFPBinaryOp(FOLD_FPCMP_OP(!=, true));
1041 }
FoldFUnordNotEqual()1042 ConstantFoldingRule FoldFUnordNotEqual() {
1043   return FoldFPBinaryOp(FOLD_FPCMP_OP(!=, false));
1044 }
FoldFOrdLessThan()1045 ConstantFoldingRule FoldFOrdLessThan() {
1046   return FoldFPBinaryOp(FOLD_FPCMP_OP(<, true));
1047 }
FoldFUnordLessThan()1048 ConstantFoldingRule FoldFUnordLessThan() {
1049   return FoldFPBinaryOp(FOLD_FPCMP_OP(<, false));
1050 }
FoldFOrdGreaterThan()1051 ConstantFoldingRule FoldFOrdGreaterThan() {
1052   return FoldFPBinaryOp(FOLD_FPCMP_OP(>, true));
1053 }
FoldFUnordGreaterThan()1054 ConstantFoldingRule FoldFUnordGreaterThan() {
1055   return FoldFPBinaryOp(FOLD_FPCMP_OP(>, false));
1056 }
FoldFOrdLessThanEqual()1057 ConstantFoldingRule FoldFOrdLessThanEqual() {
1058   return FoldFPBinaryOp(FOLD_FPCMP_OP(<=, true));
1059 }
FoldFUnordLessThanEqual()1060 ConstantFoldingRule FoldFUnordLessThanEqual() {
1061   return FoldFPBinaryOp(FOLD_FPCMP_OP(<=, false));
1062 }
FoldFOrdGreaterThanEqual()1063 ConstantFoldingRule FoldFOrdGreaterThanEqual() {
1064   return FoldFPBinaryOp(FOLD_FPCMP_OP(>=, true));
1065 }
FoldFUnordGreaterThanEqual()1066 ConstantFoldingRule FoldFUnordGreaterThanEqual() {
1067   return FoldFPBinaryOp(FOLD_FPCMP_OP(>=, false));
1068 }
1069 
1070 // Folds an OpDot where all of the inputs are constants to a
1071 // constant.  A new constant is created if necessary.
FoldOpDotWithConstants()1072 ConstantFoldingRule FoldOpDotWithConstants() {
1073   return [](IRContext* context, Instruction* inst,
1074             const std::vector<const analysis::Constant*>& constants)
1075              -> const analysis::Constant* {
1076     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1077     analysis::TypeManager* type_mgr = context->get_type_mgr();
1078     const analysis::Type* new_type = type_mgr->GetType(inst->type_id());
1079     assert(new_type->AsFloat() && "OpDot should have a float return type.");
1080     const analysis::Float* float_type = new_type->AsFloat();
1081 
1082     if (!inst->IsFloatingPointFoldingAllowed()) {
1083       return nullptr;
1084     }
1085 
1086     // If one of the operands is 0, then the result is 0.
1087     bool has_zero_operand = false;
1088 
1089     for (int i = 0; i < 2; ++i) {
1090       if (constants[i]) {
1091         if (constants[i]->AsNullConstant() ||
1092             constants[i]->AsVectorConstant()->IsZero()) {
1093           has_zero_operand = true;
1094           break;
1095         }
1096       }
1097     }
1098 
1099     if (has_zero_operand) {
1100       if (float_type->width() == 32) {
1101         utils::FloatProxy<float> result(0.0f);
1102         std::vector<uint32_t> words = result.GetWords();
1103         return const_mgr->GetConstant(float_type, words);
1104       }
1105       if (float_type->width() == 64) {
1106         utils::FloatProxy<double> result(0.0);
1107         std::vector<uint32_t> words = result.GetWords();
1108         return const_mgr->GetConstant(float_type, words);
1109       }
1110       return nullptr;
1111     }
1112 
1113     if (constants[0] == nullptr || constants[1] == nullptr) {
1114       return nullptr;
1115     }
1116 
1117     std::vector<const analysis::Constant*> a_components;
1118     std::vector<const analysis::Constant*> b_components;
1119 
1120     a_components = constants[0]->GetVectorComponents(const_mgr);
1121     b_components = constants[1]->GetVectorComponents(const_mgr);
1122 
1123     utils::FloatProxy<double> result(0.0);
1124     std::vector<uint32_t> words = result.GetWords();
1125     const analysis::Constant* result_const =
1126         const_mgr->GetConstant(float_type, words);
1127     for (uint32_t i = 0; i < a_components.size() && result_const != nullptr;
1128          ++i) {
1129       if (a_components[i] == nullptr || b_components[i] == nullptr) {
1130         return nullptr;
1131       }
1132 
1133       const analysis::Constant* component = FOLD_FPARITH_OP(*)(
1134           new_type, a_components[i], b_components[i], const_mgr);
1135       if (component == nullptr) {
1136         return nullptr;
1137       }
1138       result_const =
1139           FOLD_FPARITH_OP(+)(new_type, result_const, component, const_mgr);
1140     }
1141     return result_const;
1142   };
1143 }
1144 
FoldFNegate()1145 ConstantFoldingRule FoldFNegate() { return FoldFPUnaryOp(NegateFPConst); }
FoldSNegate()1146 ConstantFoldingRule FoldSNegate() { return FoldUnaryOp(NegateIntConst); }
1147 
FoldFClampFeedingCompare(spv::Op cmp_opcode)1148 ConstantFoldingRule FoldFClampFeedingCompare(spv::Op cmp_opcode) {
1149   return [cmp_opcode](IRContext* context, Instruction* inst,
1150                       const std::vector<const analysis::Constant*>& constants)
1151              -> const analysis::Constant* {
1152     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1153     analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1154 
1155     if (!inst->IsFloatingPointFoldingAllowed()) {
1156       return nullptr;
1157     }
1158 
1159     uint32_t non_const_idx = (constants[0] ? 1 : 0);
1160     uint32_t operand_id = inst->GetSingleWordInOperand(non_const_idx);
1161     Instruction* operand_inst = def_use_mgr->GetDef(operand_id);
1162 
1163     analysis::TypeManager* type_mgr = context->get_type_mgr();
1164     const analysis::Type* operand_type =
1165         type_mgr->GetType(operand_inst->type_id());
1166 
1167     if (!operand_type->AsFloat()) {
1168       return nullptr;
1169     }
1170 
1171     if (operand_type->AsFloat()->width() != 32 &&
1172         operand_type->AsFloat()->width() != 64) {
1173       return nullptr;
1174     }
1175 
1176     if (operand_inst->opcode() != spv::Op::OpExtInst) {
1177       return nullptr;
1178     }
1179 
1180     if (operand_inst->GetSingleWordInOperand(1) != GLSLstd450FClamp) {
1181       return nullptr;
1182     }
1183 
1184     if (constants[1] == nullptr && constants[0] == nullptr) {
1185       return nullptr;
1186     }
1187 
1188     uint32_t max_id = operand_inst->GetSingleWordInOperand(4);
1189     const analysis::Constant* max_const =
1190         const_mgr->FindDeclaredConstant(max_id);
1191 
1192     uint32_t min_id = operand_inst->GetSingleWordInOperand(3);
1193     const analysis::Constant* min_const =
1194         const_mgr->FindDeclaredConstant(min_id);
1195 
1196     bool found_result = false;
1197     bool result = false;
1198 
1199     switch (cmp_opcode) {
1200       case spv::Op::OpFOrdLessThan:
1201       case spv::Op::OpFUnordLessThan:
1202       case spv::Op::OpFOrdGreaterThanEqual:
1203       case spv::Op::OpFUnordGreaterThanEqual:
1204         if (constants[0]) {
1205           if (min_const) {
1206             if (constants[0]->GetValueAsDouble() <
1207                 min_const->GetValueAsDouble()) {
1208               found_result = true;
1209               result = (cmp_opcode == spv::Op::OpFOrdLessThan ||
1210                         cmp_opcode == spv::Op::OpFUnordLessThan);
1211             }
1212           }
1213           if (max_const) {
1214             if (constants[0]->GetValueAsDouble() >=
1215                 max_const->GetValueAsDouble()) {
1216               found_result = true;
1217               result = !(cmp_opcode == spv::Op::OpFOrdLessThan ||
1218                          cmp_opcode == spv::Op::OpFUnordLessThan);
1219             }
1220           }
1221         }
1222 
1223         if (constants[1]) {
1224           if (max_const) {
1225             if (max_const->GetValueAsDouble() <
1226                 constants[1]->GetValueAsDouble()) {
1227               found_result = true;
1228               result = (cmp_opcode == spv::Op::OpFOrdLessThan ||
1229                         cmp_opcode == spv::Op::OpFUnordLessThan);
1230             }
1231           }
1232 
1233           if (min_const) {
1234             if (min_const->GetValueAsDouble() >=
1235                 constants[1]->GetValueAsDouble()) {
1236               found_result = true;
1237               result = !(cmp_opcode == spv::Op::OpFOrdLessThan ||
1238                          cmp_opcode == spv::Op::OpFUnordLessThan);
1239             }
1240           }
1241         }
1242         break;
1243       case spv::Op::OpFOrdGreaterThan:
1244       case spv::Op::OpFUnordGreaterThan:
1245       case spv::Op::OpFOrdLessThanEqual:
1246       case spv::Op::OpFUnordLessThanEqual:
1247         if (constants[0]) {
1248           if (min_const) {
1249             if (constants[0]->GetValueAsDouble() <=
1250                 min_const->GetValueAsDouble()) {
1251               found_result = true;
1252               result = (cmp_opcode == spv::Op::OpFOrdLessThanEqual ||
1253                         cmp_opcode == spv::Op::OpFUnordLessThanEqual);
1254             }
1255           }
1256           if (max_const) {
1257             if (constants[0]->GetValueAsDouble() >
1258                 max_const->GetValueAsDouble()) {
1259               found_result = true;
1260               result = !(cmp_opcode == spv::Op::OpFOrdLessThanEqual ||
1261                          cmp_opcode == spv::Op::OpFUnordLessThanEqual);
1262             }
1263           }
1264         }
1265 
1266         if (constants[1]) {
1267           if (max_const) {
1268             if (max_const->GetValueAsDouble() <=
1269                 constants[1]->GetValueAsDouble()) {
1270               found_result = true;
1271               result = (cmp_opcode == spv::Op::OpFOrdLessThanEqual ||
1272                         cmp_opcode == spv::Op::OpFUnordLessThanEqual);
1273             }
1274           }
1275 
1276           if (min_const) {
1277             if (min_const->GetValueAsDouble() >
1278                 constants[1]->GetValueAsDouble()) {
1279               found_result = true;
1280               result = !(cmp_opcode == spv::Op::OpFOrdLessThanEqual ||
1281                          cmp_opcode == spv::Op::OpFUnordLessThanEqual);
1282             }
1283           }
1284         }
1285         break;
1286       default:
1287         return nullptr;
1288     }
1289 
1290     if (!found_result) {
1291       return nullptr;
1292     }
1293 
1294     const analysis::Type* bool_type =
1295         context->get_type_mgr()->GetType(inst->type_id());
1296     const analysis::Constant* result_const =
1297         const_mgr->GetConstant(bool_type, {static_cast<uint32_t>(result)});
1298     assert(result_const);
1299     return result_const;
1300   };
1301 }
1302 
FoldFMix()1303 ConstantFoldingRule FoldFMix() {
1304   return [](IRContext* context, Instruction* inst,
1305             const std::vector<const analysis::Constant*>& constants)
1306              -> const analysis::Constant* {
1307     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1308     assert(inst->opcode() == spv::Op::OpExtInst &&
1309            "Expecting an extended instruction.");
1310     assert(inst->GetSingleWordInOperand(0) ==
1311                context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
1312            "Expecting a GLSLstd450 extended instruction.");
1313     assert(inst->GetSingleWordInOperand(1) == GLSLstd450FMix &&
1314            "Expecting and FMix instruction.");
1315 
1316     if (!inst->IsFloatingPointFoldingAllowed()) {
1317       return nullptr;
1318     }
1319 
1320     // Make sure all FMix operands are constants.
1321     for (uint32_t i = 1; i < 4; i++) {
1322       if (constants[i] == nullptr) {
1323         return nullptr;
1324       }
1325     }
1326 
1327     const analysis::Constant* one;
1328     bool is_vector = false;
1329     const analysis::Type* result_type = constants[1]->type();
1330     const analysis::Type* base_type = result_type;
1331     if (base_type->AsVector()) {
1332       is_vector = true;
1333       base_type = base_type->AsVector()->element_type();
1334     }
1335     assert(base_type->AsFloat() != nullptr &&
1336            "FMix is suppose to act on floats or vectors of floats.");
1337 
1338     if (base_type->AsFloat()->width() == 32) {
1339       one = const_mgr->GetConstant(base_type,
1340                                    utils::FloatProxy<float>(1.0f).GetWords());
1341     } else {
1342       one = const_mgr->GetConstant(base_type,
1343                                    utils::FloatProxy<double>(1.0).GetWords());
1344     }
1345 
1346     if (is_vector) {
1347       uint32_t one_id = const_mgr->GetDefiningInstruction(one)->result_id();
1348       one =
1349           const_mgr->GetConstant(result_type, std::vector<uint32_t>(4, one_id));
1350     }
1351 
1352     const analysis::Constant* temp1 = FoldFPBinaryOp(
1353         FOLD_FPARITH_OP(-), inst->type_id(), {one, constants[3]}, context);
1354     if (temp1 == nullptr) {
1355       return nullptr;
1356     }
1357 
1358     const analysis::Constant* temp2 = FoldFPBinaryOp(
1359         FOLD_FPARITH_OP(*), inst->type_id(), {constants[1], temp1}, context);
1360     if (temp2 == nullptr) {
1361       return nullptr;
1362     }
1363     const analysis::Constant* temp3 =
1364         FoldFPBinaryOp(FOLD_FPARITH_OP(*), inst->type_id(),
1365                        {constants[2], constants[3]}, context);
1366     if (temp3 == nullptr) {
1367       return nullptr;
1368     }
1369     return FoldFPBinaryOp(FOLD_FPARITH_OP(+), inst->type_id(), {temp2, temp3},
1370                           context);
1371   };
1372 }
1373 
// Scalar folding rule for "minimum": returns |a| if its value is strictly
// less than that of |b|, and |b| otherwise.  |result_type| selects how the
// values are read: signed or unsigned 32/64-bit integers, or 32/64-bit floats.
// Returns |nullptr| for any other type or width.  The constant manager is not
// needed because the result is always one of the inputs.
const analysis::Constant* FoldMin(const analysis::Type* result_type,
                                  const analysis::Constant* a,
                                  const analysis::Constant* b,
                                  analysis::ConstantManager*) {
  if (const analysis::Integer* int_type = result_type->AsInteger()) {
    const bool is_signed = int_type->IsSigned();
    switch (int_type->width()) {
      case 32:
        if (is_signed) {
          return a->GetS32() < b->GetS32() ? a : b;
        }
        return a->GetU32() < b->GetU32() ? a : b;
      case 64:
        if (is_signed) {
          return a->GetS64() < b->GetS64() ? a : b;
        }
        return a->GetU64() < b->GetU64() ? a : b;
      default:
        return nullptr;
    }
  }

  if (const analysis::Float* float_type = result_type->AsFloat()) {
    switch (float_type->width()) {
      case 32:
        return a->GetFloat() < b->GetFloat() ? a : b;
      case 64:
        return a->GetDouble() < b->GetDouble() ? a : b;
      default:
        return nullptr;
    }
  }

  return nullptr;
}
1413 
// Returns whichever of |a| and |b| holds the larger value, reading both as
// scalars of |result_type|.  Supports 32-bit and 64-bit integers (signed or
// unsigned) and 32-bit and 64-bit floats; any other type yields |nullptr|.
// |b| is returned whenever |a > b| does not hold, which includes equal values
// and, for floats, a NaN operand.  The constant manager is unused.
const analysis::Constant* FoldMax(const analysis::Type* result_type,
                                  const analysis::Constant* a,
                                  const analysis::Constant* b,
                                  analysis::ConstantManager*) {
  if (const analysis::Integer* as_int = result_type->AsInteger()) {
    switch (as_int->width()) {
      case 32: {
        if (as_int->IsSigned()) {
          const int32_t lhs = a->GetS32();
          const int32_t rhs = b->GetS32();
          return lhs > rhs ? a : b;
        }
        const uint32_t lhs = a->GetU32();
        const uint32_t rhs = b->GetU32();
        return lhs > rhs ? a : b;
      }
      case 64: {
        if (as_int->IsSigned()) {
          const int64_t lhs = a->GetS64();
          const int64_t rhs = b->GetS64();
          return lhs > rhs ? a : b;
        }
        const uint64_t lhs = a->GetU64();
        const uint64_t rhs = b->GetU64();
        return lhs > rhs ? a : b;
      }
      default:
        // Integer widths other than 32 and 64 are not folded.
        return nullptr;
    }
  }

  if (const analysis::Float* as_float = result_type->AsFloat()) {
    switch (as_float->width()) {
      case 32: {
        const float lhs = a->GetFloat();
        const float rhs = b->GetFloat();
        return lhs > rhs ? a : b;
      }
      case 64: {
        const double lhs = a->GetDouble();
        const double rhs = b->GetDouble();
        return lhs > rhs ? a : b;
      }
      default:
        // Float widths other than 32 and 64 are not folded.
        return nullptr;
    }
  }

  // Neither an integer nor a float type.
  return nullptr;
}
1453 
1454 // Fold an clamp instruction when all three operands are constant.
FoldClamp1(IRContext * context,Instruction * inst,const std::vector<const analysis::Constant * > & constants)1455 const analysis::Constant* FoldClamp1(
1456     IRContext* context, Instruction* inst,
1457     const std::vector<const analysis::Constant*>& constants) {
1458   assert(inst->opcode() == spv::Op::OpExtInst &&
1459          "Expecting an extended instruction.");
1460   assert(inst->GetSingleWordInOperand(0) ==
1461              context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
1462          "Expecting a GLSLstd450 extended instruction.");
1463 
1464   // Make sure all Clamp operands are constants.
1465   for (uint32_t i = 1; i < 4; i++) {
1466     if (constants[i] == nullptr) {
1467       return nullptr;
1468     }
1469   }
1470 
1471   const analysis::Constant* temp = FoldFPBinaryOp(
1472       FoldMax, inst->type_id(), {constants[1], constants[2]}, context);
1473   if (temp == nullptr) {
1474     return nullptr;
1475   }
1476   return FoldFPBinaryOp(FoldMin, inst->type_id(), {temp, constants[3]},
1477                         context);
1478 }
1479 
1480 // Fold a clamp instruction when |x <= min_val|.
FoldClamp2(IRContext * context,Instruction * inst,const std::vector<const analysis::Constant * > & constants)1481 const analysis::Constant* FoldClamp2(
1482     IRContext* context, Instruction* inst,
1483     const std::vector<const analysis::Constant*>& constants) {
1484   assert(inst->opcode() == spv::Op::OpExtInst &&
1485          "Expecting an extended instruction.");
1486   assert(inst->GetSingleWordInOperand(0) ==
1487              context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
1488          "Expecting a GLSLstd450 extended instruction.");
1489 
1490   const analysis::Constant* x = constants[1];
1491   const analysis::Constant* min_val = constants[2];
1492 
1493   if (x == nullptr || min_val == nullptr) {
1494     return nullptr;
1495   }
1496 
1497   const analysis::Constant* temp =
1498       FoldFPBinaryOp(FoldMax, inst->type_id(), {x, min_val}, context);
1499   if (temp == min_val) {
1500     // We can assume that |min_val| is less than |max_val|.  Therefore, if the
1501     // result of the max operation is |min_val|, we know the result of the min
1502     // operation, even if |max_val| is not a constant.
1503     return min_val;
1504   }
1505   return nullptr;
1506 }
1507 
1508 // Fold a clamp instruction when |x >= max_val|.
FoldClamp3(IRContext * context,Instruction * inst,const std::vector<const analysis::Constant * > & constants)1509 const analysis::Constant* FoldClamp3(
1510     IRContext* context, Instruction* inst,
1511     const std::vector<const analysis::Constant*>& constants) {
1512   assert(inst->opcode() == spv::Op::OpExtInst &&
1513          "Expecting an extended instruction.");
1514   assert(inst->GetSingleWordInOperand(0) ==
1515              context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
1516          "Expecting a GLSLstd450 extended instruction.");
1517 
1518   const analysis::Constant* x = constants[1];
1519   const analysis::Constant* max_val = constants[3];
1520 
1521   if (x == nullptr || max_val == nullptr) {
1522     return nullptr;
1523   }
1524 
1525   const analysis::Constant* temp =
1526       FoldFPBinaryOp(FoldMin, inst->type_id(), {x, max_val}, context);
1527   if (temp == max_val) {
1528     // We can assume that |min_val| is less than |max_val|.  Therefore, if the
1529     // result of the max operation is |min_val|, we know the result of the min
1530     // operation, even if |max_val| is not a constant.
1531     return max_val;
1532   }
1533   return nullptr;
1534 }
1535 
FoldFTranscendentalUnary(double (* fp)(double))1536 UnaryScalarFoldingRule FoldFTranscendentalUnary(double (*fp)(double)) {
1537   return
1538       [fp](const analysis::Type* result_type, const analysis::Constant* a,
1539            analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
1540         assert(result_type != nullptr && a != nullptr);
1541         const analysis::Float* float_type = a->type()->AsFloat();
1542         assert(float_type != nullptr);
1543         assert(float_type == result_type->AsFloat());
1544         if (float_type->width() == 32) {
1545           float fa = a->GetFloat();
1546           float res = static_cast<float>(fp(fa));
1547           utils::FloatProxy<float> result(res);
1548           std::vector<uint32_t> words = result.GetWords();
1549           return const_mgr->GetConstant(result_type, words);
1550         } else if (float_type->width() == 64) {
1551           double fa = a->GetDouble();
1552           double res = fp(fa);
1553           utils::FloatProxy<double> result(res);
1554           std::vector<uint32_t> words = result.GetWords();
1555           return const_mgr->GetConstant(result_type, words);
1556         }
1557         return nullptr;
1558       };
1559 }
1560 
FoldFTranscendentalBinary(double (* fp)(double,double))1561 BinaryScalarFoldingRule FoldFTranscendentalBinary(double (*fp)(double,
1562                                                                double)) {
1563   return
1564       [fp](const analysis::Type* result_type, const analysis::Constant* a,
1565            const analysis::Constant* b,
1566            analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
1567         assert(result_type != nullptr && a != nullptr);
1568         const analysis::Float* float_type = a->type()->AsFloat();
1569         assert(float_type != nullptr);
1570         assert(float_type == result_type->AsFloat());
1571         assert(float_type == b->type()->AsFloat());
1572         if (float_type->width() == 32) {
1573           float fa = a->GetFloat();
1574           float fb = b->GetFloat();
1575           float res = static_cast<float>(fp(fa, fb));
1576           utils::FloatProxy<float> result(res);
1577           std::vector<uint32_t> words = result.GetWords();
1578           return const_mgr->GetConstant(result_type, words);
1579         } else if (float_type->width() == 64) {
1580           double fa = a->GetDouble();
1581           double fb = b->GetDouble();
1582           double res = fp(fa, fb);
1583           utils::FloatProxy<double> result(res);
1584           std::vector<uint32_t> words = result.GetWords();
1585           return const_mgr->GetConstant(result_type, words);
1586         }
1587         return nullptr;
1588       };
1589 }
1590 }  // namespace
1591 
AddFoldingRules()1592 void ConstantFoldingRules::AddFoldingRules() {
1593   // Add all folding rules to the list for the opcodes to which they apply.
1594   // Note that the order in which rules are added to the list matters. If a rule
1595   // applies to the instruction, the rest of the rules will not be attempted.
1596   // Take that into consideration.
1597 
1598   rules_[spv::Op::OpCompositeConstruct].push_back(FoldCompositeWithConstants());
1599 
1600   rules_[spv::Op::OpCompositeExtract].push_back(FoldExtractWithConstants());
1601   rules_[spv::Op::OpCompositeInsert].push_back(FoldInsertWithConstants());
1602 
1603   rules_[spv::Op::OpConvertFToS].push_back(FoldFToI());
1604   rules_[spv::Op::OpConvertFToU].push_back(FoldFToI());
1605   rules_[spv::Op::OpConvertSToF].push_back(FoldIToF());
1606   rules_[spv::Op::OpConvertUToF].push_back(FoldIToF());
1607 
1608   rules_[spv::Op::OpDot].push_back(FoldOpDotWithConstants());
1609   rules_[spv::Op::OpFAdd].push_back(FoldFAdd());
1610   rules_[spv::Op::OpFDiv].push_back(FoldFDiv());
1611   rules_[spv::Op::OpFMul].push_back(FoldFMul());
1612   rules_[spv::Op::OpFSub].push_back(FoldFSub());
1613 
1614   rules_[spv::Op::OpFOrdEqual].push_back(FoldFOrdEqual());
1615 
1616   rules_[spv::Op::OpFUnordEqual].push_back(FoldFUnordEqual());
1617 
1618   rules_[spv::Op::OpFOrdNotEqual].push_back(FoldFOrdNotEqual());
1619 
1620   rules_[spv::Op::OpFUnordNotEqual].push_back(FoldFUnordNotEqual());
1621 
1622   rules_[spv::Op::OpFOrdLessThan].push_back(FoldFOrdLessThan());
1623   rules_[spv::Op::OpFOrdLessThan].push_back(
1624       FoldFClampFeedingCompare(spv::Op::OpFOrdLessThan));
1625 
1626   rules_[spv::Op::OpFUnordLessThan].push_back(FoldFUnordLessThan());
1627   rules_[spv::Op::OpFUnordLessThan].push_back(
1628       FoldFClampFeedingCompare(spv::Op::OpFUnordLessThan));
1629 
1630   rules_[spv::Op::OpFOrdGreaterThan].push_back(FoldFOrdGreaterThan());
1631   rules_[spv::Op::OpFOrdGreaterThan].push_back(
1632       FoldFClampFeedingCompare(spv::Op::OpFOrdGreaterThan));
1633 
1634   rules_[spv::Op::OpFUnordGreaterThan].push_back(FoldFUnordGreaterThan());
1635   rules_[spv::Op::OpFUnordGreaterThan].push_back(
1636       FoldFClampFeedingCompare(spv::Op::OpFUnordGreaterThan));
1637 
1638   rules_[spv::Op::OpFOrdLessThanEqual].push_back(FoldFOrdLessThanEqual());
1639   rules_[spv::Op::OpFOrdLessThanEqual].push_back(
1640       FoldFClampFeedingCompare(spv::Op::OpFOrdLessThanEqual));
1641 
1642   rules_[spv::Op::OpFUnordLessThanEqual].push_back(FoldFUnordLessThanEqual());
1643   rules_[spv::Op::OpFUnordLessThanEqual].push_back(
1644       FoldFClampFeedingCompare(spv::Op::OpFUnordLessThanEqual));
1645 
1646   rules_[spv::Op::OpFOrdGreaterThanEqual].push_back(FoldFOrdGreaterThanEqual());
1647   rules_[spv::Op::OpFOrdGreaterThanEqual].push_back(
1648       FoldFClampFeedingCompare(spv::Op::OpFOrdGreaterThanEqual));
1649 
1650   rules_[spv::Op::OpFUnordGreaterThanEqual].push_back(
1651       FoldFUnordGreaterThanEqual());
1652   rules_[spv::Op::OpFUnordGreaterThanEqual].push_back(
1653       FoldFClampFeedingCompare(spv::Op::OpFUnordGreaterThanEqual));
1654 
1655   rules_[spv::Op::OpVectorShuffle].push_back(FoldVectorShuffleWithConstants());
1656   rules_[spv::Op::OpVectorTimesScalar].push_back(FoldVectorTimesScalar());
1657   rules_[spv::Op::OpVectorTimesMatrix].push_back(FoldVectorTimesMatrix());
1658   rules_[spv::Op::OpMatrixTimesVector].push_back(FoldMatrixTimesVector());
1659   rules_[spv::Op::OpTranspose].push_back(FoldTranspose);
1660 
1661   rules_[spv::Op::OpFNegate].push_back(FoldFNegate());
1662   rules_[spv::Op::OpSNegate].push_back(FoldSNegate());
1663   rules_[spv::Op::OpQuantizeToF16].push_back(FoldQuantizeToF16());
1664 
1665   // Add rules for GLSLstd450
1666   FeatureManager* feature_manager = context_->get_feature_mgr();
1667   uint32_t ext_inst_glslstd450_id =
1668       feature_manager->GetExtInstImportId_GLSLstd450();
1669   if (ext_inst_glslstd450_id != 0) {
1670     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMix}].push_back(FoldFMix());
1671     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SMin}].push_back(
1672         FoldFPBinaryOp(FoldMin));
1673     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UMin}].push_back(
1674         FoldFPBinaryOp(FoldMin));
1675     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMin}].push_back(
1676         FoldFPBinaryOp(FoldMin));
1677     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SMax}].push_back(
1678         FoldFPBinaryOp(FoldMax));
1679     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UMax}].push_back(
1680         FoldFPBinaryOp(FoldMax));
1681     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMax}].push_back(
1682         FoldFPBinaryOp(FoldMax));
1683     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UClamp}].push_back(
1684         FoldClamp1);
1685     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UClamp}].push_back(
1686         FoldClamp2);
1687     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UClamp}].push_back(
1688         FoldClamp3);
1689     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SClamp}].push_back(
1690         FoldClamp1);
1691     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SClamp}].push_back(
1692         FoldClamp2);
1693     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SClamp}].push_back(
1694         FoldClamp3);
1695     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FClamp}].push_back(
1696         FoldClamp1);
1697     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FClamp}].push_back(
1698         FoldClamp2);
1699     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FClamp}].push_back(
1700         FoldClamp3);
1701     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Sin}].push_back(
1702         FoldFPUnaryOp(FoldFTranscendentalUnary(std::sin)));
1703     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Cos}].push_back(
1704         FoldFPUnaryOp(FoldFTranscendentalUnary(std::cos)));
1705     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Tan}].push_back(
1706         FoldFPUnaryOp(FoldFTranscendentalUnary(std::tan)));
1707     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Asin}].push_back(
1708         FoldFPUnaryOp(FoldFTranscendentalUnary(std::asin)));
1709     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Acos}].push_back(
1710         FoldFPUnaryOp(FoldFTranscendentalUnary(std::acos)));
1711     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Atan}].push_back(
1712         FoldFPUnaryOp(FoldFTranscendentalUnary(std::atan)));
1713     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Exp}].push_back(
1714         FoldFPUnaryOp(FoldFTranscendentalUnary(std::exp)));
1715     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Log}].push_back(
1716         FoldFPUnaryOp(FoldFTranscendentalUnary(std::log)));
1717 
1718 #ifdef __ANDROID__
1719     // Android NDK r15c targeting ABI 15 doesn't have full support for C++11
1720     // (no std::exp2/log2). ::exp2 is available from C99 but ::log2 isn't
1721     // available up until ABI 18 so we use a shim
1722     auto log2_shim = [](double v) -> double { return log(v) / log(2.0); };
1723     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Exp2}].push_back(
1724         FoldFPUnaryOp(FoldFTranscendentalUnary(::exp2)));
1725     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Log2}].push_back(
1726         FoldFPUnaryOp(FoldFTranscendentalUnary(log2_shim)));
1727 #else
1728     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Exp2}].push_back(
1729         FoldFPUnaryOp(FoldFTranscendentalUnary(std::exp2)));
1730     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Log2}].push_back(
1731         FoldFPUnaryOp(FoldFTranscendentalUnary(std::log2)));
1732 #endif
1733 
1734     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Sqrt}].push_back(
1735         FoldFPUnaryOp(FoldFTranscendentalUnary(std::sqrt)));
1736     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Atan2}].push_back(
1737         FoldFPBinaryOp(FoldFTranscendentalBinary(std::atan2)));
1738     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Pow}].push_back(
1739         FoldFPBinaryOp(FoldFTranscendentalBinary(std::pow)));
1740   }
1741 }
1742 }  // namespace opt
1743 }  // namespace spvtools
1744