1 // Copyright (c) 2018 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/opt/const_folding_rules.h"
16
17 #include "source/opt/ir_context.h"
18
19 namespace spvtools {
20 namespace opt {
21 namespace {
// In-operand position of the composite operand of an OpCompositeExtract; the
// literal indexes follow it.
constexpr uint32_t kExtractCompositeIdInIdx{0};
23
24 // Returns a constants with the value NaN of the given type. Only works for
25 // 32-bit and 64-bit float point types. Returns |nullptr| if an error occurs.
GetNan(const analysis::Type * type,analysis::ConstantManager * const_mgr)26 const analysis::Constant* GetNan(const analysis::Type* type,
27 analysis::ConstantManager* const_mgr) {
28 const analysis::Float* float_type = type->AsFloat();
29 if (float_type == nullptr) {
30 return nullptr;
31 }
32
33 switch (float_type->width()) {
34 case 32:
35 return const_mgr->GetFloatConst(std::numeric_limits<float>::quiet_NaN());
36 case 64:
37 return const_mgr->GetDoubleConst(
38 std::numeric_limits<double>::quiet_NaN());
39 default:
40 return nullptr;
41 }
42 }
43
44 // Returns a constants with the value INF of the given type. Only works for
45 // 32-bit and 64-bit float point types. Returns |nullptr| if an error occurs.
GetInf(const analysis::Type * type,analysis::ConstantManager * const_mgr)46 const analysis::Constant* GetInf(const analysis::Type* type,
47 analysis::ConstantManager* const_mgr) {
48 const analysis::Float* float_type = type->AsFloat();
49 if (float_type == nullptr) {
50 return nullptr;
51 }
52
53 switch (float_type->width()) {
54 case 32:
55 return const_mgr->GetFloatConst(std::numeric_limits<float>::infinity());
56 case 64:
57 return const_mgr->GetDoubleConst(std::numeric_limits<double>::infinity());
58 default:
59 return nullptr;
60 }
61 }
62
63 // Returns true if |type| is Float or a vector of Float.
HasFloatingPoint(const analysis::Type * type)64 bool HasFloatingPoint(const analysis::Type* type) {
65 if (type->AsFloat()) {
66 return true;
67 } else if (const analysis::Vector* vec_type = type->AsVector()) {
68 return vec_type->element_type()->AsFloat() != nullptr;
69 }
70
71 return false;
72 }
73
74 // Returns a constants with the value |-val| of the given type. Only works for
75 // 32-bit and 64-bit float point types. Returns |nullptr| if an error occurs.
NegateFPConst(const analysis::Type * result_type,const analysis::Constant * val,analysis::ConstantManager * const_mgr)76 const analysis::Constant* NegateFPConst(const analysis::Type* result_type,
77 const analysis::Constant* val,
78 analysis::ConstantManager* const_mgr) {
79 const analysis::Float* float_type = result_type->AsFloat();
80 assert(float_type != nullptr);
81 if (float_type->width() == 32) {
82 float fa = val->GetFloat();
83 return const_mgr->GetFloatConst(-fa);
84 } else if (float_type->width() == 64) {
85 double da = val->GetDouble();
86 return const_mgr->GetDoubleConst(-da);
87 }
88 return nullptr;
89 }
90
91 // Returns a constants with the value |-val| of the given type.
NegateIntConst(const analysis::Type * result_type,const analysis::Constant * val,analysis::ConstantManager * const_mgr)92 const analysis::Constant* NegateIntConst(const analysis::Type* result_type,
93 const analysis::Constant* val,
94 analysis::ConstantManager* const_mgr) {
95 const analysis::Integer* int_type = result_type->AsInteger();
96 assert(int_type != nullptr);
97
98 if (val->AsNullConstant()) {
99 return val;
100 }
101
102 uint64_t new_value = static_cast<uint64_t>(-val->GetSignExtendedValue());
103 return const_mgr->GetIntConst(new_value, int_type->width(),
104 int_type->IsSigned());
105 }
106
107 // Folds an OpcompositeExtract where input is a composite constant.
FoldExtractWithConstants()108 ConstantFoldingRule FoldExtractWithConstants() {
109 return [](IRContext* context, Instruction* inst,
110 const std::vector<const analysis::Constant*>& constants)
111 -> const analysis::Constant* {
112 const analysis::Constant* c = constants[kExtractCompositeIdInIdx];
113 if (c == nullptr) {
114 return nullptr;
115 }
116
117 for (uint32_t i = 1; i < inst->NumInOperands(); ++i) {
118 uint32_t element_index = inst->GetSingleWordInOperand(i);
119 if (c->AsNullConstant()) {
120 // Return Null for the return type.
121 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
122 analysis::TypeManager* type_mgr = context->get_type_mgr();
123 return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), {});
124 }
125
126 auto cc = c->AsCompositeConstant();
127 assert(cc != nullptr);
128 auto components = cc->GetComponents();
129 // Protect against invalid IR. Refuse to fold if the index is out
130 // of bounds.
131 if (element_index >= components.size()) return nullptr;
132 c = components[element_index];
133 }
134 return c;
135 };
136 }
137
138 // Folds an OpcompositeInsert where input is a composite constant.
FoldInsertWithConstants()139 ConstantFoldingRule FoldInsertWithConstants() {
140 return [](IRContext* context, Instruction* inst,
141 const std::vector<const analysis::Constant*>& constants)
142 -> const analysis::Constant* {
143 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
144 const analysis::Constant* object = constants[0];
145 const analysis::Constant* composite = constants[1];
146 if (object == nullptr || composite == nullptr) {
147 return nullptr;
148 }
149
150 // If there is more than 1 index, then each additional constant used by the
151 // index will need to be recreated to use the inserted object.
152 std::vector<const analysis::Constant*> chain;
153 std::vector<const analysis::Constant*> components;
154 const analysis::Type* type = nullptr;
155 const uint32_t final_index = (inst->NumInOperands() - 1);
156
157 // Work down hierarchy of all indexes
158 for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
159 type = composite->type();
160
161 if (composite->AsNullConstant()) {
162 // Make new composite so it can be inserted in the index with the
163 // non-null value
164 if (const auto new_composite =
165 const_mgr->GetNullCompositeConstant(type)) {
166 // Keep track of any indexes along the way to last index
167 if (i != final_index) {
168 chain.push_back(new_composite);
169 }
170 components = new_composite->AsCompositeConstant()->GetComponents();
171 } else {
172 // Unsupported input type (such as structs)
173 return nullptr;
174 }
175 } else {
176 // Keep track of any indexes along the way to last index
177 if (i != final_index) {
178 chain.push_back(composite);
179 }
180 components = composite->AsCompositeConstant()->GetComponents();
181 }
182 const uint32_t index = inst->GetSingleWordInOperand(i);
183 composite = components[index];
184 }
185
186 // Final index in hierarchy is inserted with new object.
187 const uint32_t final_operand = inst->GetSingleWordInOperand(final_index);
188 std::vector<uint32_t> ids;
189 for (size_t i = 0; i < components.size(); i++) {
190 const analysis::Constant* constant =
191 (i == final_operand) ? object : components[i];
192 Instruction* member_inst = const_mgr->GetDefiningInstruction(constant);
193 ids.push_back(member_inst->result_id());
194 }
195 const analysis::Constant* new_constant = const_mgr->GetConstant(type, ids);
196
197 // Work backwards up the chain and replace each index with new constant.
198 for (size_t i = chain.size(); i > 0; i--) {
199 // Need to insert any previous instruction into the module first.
200 // Can't just insert in types_values_begin() because it will move above
201 // where the types are declared.
202 // Can't compare with location of inst because not all new added
203 // instructions are added to types_values_
204 auto iter = context->types_values_end();
205 Module::inst_iterator* pos = &iter;
206 const_mgr->BuildInstructionAndAddToModule(new_constant, pos);
207
208 composite = chain[i - 1];
209 components = composite->AsCompositeConstant()->GetComponents();
210 type = composite->type();
211 ids.clear();
212 for (size_t k = 0; k < components.size(); k++) {
213 const uint32_t index =
214 inst->GetSingleWordInOperand(1 + static_cast<uint32_t>(i));
215 const analysis::Constant* constant =
216 (k == index) ? new_constant : components[k];
217 const uint32_t constant_id =
218 const_mgr->FindDeclaredConstant(constant, 0);
219 ids.push_back(constant_id);
220 }
221 new_constant = const_mgr->GetConstant(type, ids);
222 }
223
224 // If multiple constants were created, only need to return the top index.
225 return new_constant;
226 };
227 }
228
FoldVectorShuffleWithConstants()229 ConstantFoldingRule FoldVectorShuffleWithConstants() {
230 return [](IRContext* context, Instruction* inst,
231 const std::vector<const analysis::Constant*>& constants)
232 -> const analysis::Constant* {
233 assert(inst->opcode() == spv::Op::OpVectorShuffle);
234 const analysis::Constant* c1 = constants[0];
235 const analysis::Constant* c2 = constants[1];
236 if (c1 == nullptr || c2 == nullptr) {
237 return nullptr;
238 }
239
240 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
241 const analysis::Type* element_type = c1->type()->AsVector()->element_type();
242
243 std::vector<const analysis::Constant*> c1_components;
244 if (const analysis::VectorConstant* vec_const = c1->AsVectorConstant()) {
245 c1_components = vec_const->GetComponents();
246 } else {
247 assert(c1->AsNullConstant());
248 const analysis::Constant* element =
249 const_mgr->GetConstant(element_type, {});
250 c1_components.resize(c1->type()->AsVector()->element_count(), element);
251 }
252 std::vector<const analysis::Constant*> c2_components;
253 if (const analysis::VectorConstant* vec_const = c2->AsVectorConstant()) {
254 c2_components = vec_const->GetComponents();
255 } else {
256 assert(c2->AsNullConstant());
257 const analysis::Constant* element =
258 const_mgr->GetConstant(element_type, {});
259 c2_components.resize(c2->type()->AsVector()->element_count(), element);
260 }
261
262 std::vector<uint32_t> ids;
263 const uint32_t undef_literal_value = 0xffffffff;
264 for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
265 uint32_t index = inst->GetSingleWordInOperand(i);
266 if (index == undef_literal_value) {
267 // Don't fold shuffle with undef literal value.
268 return nullptr;
269 } else if (index < c1_components.size()) {
270 Instruction* member_inst =
271 const_mgr->GetDefiningInstruction(c1_components[index]);
272 ids.push_back(member_inst->result_id());
273 } else {
274 Instruction* member_inst = const_mgr->GetDefiningInstruction(
275 c2_components[index - c1_components.size()]);
276 ids.push_back(member_inst->result_id());
277 }
278 }
279
280 analysis::TypeManager* type_mgr = context->get_type_mgr();
281 return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), ids);
282 };
283 }
284
FoldVectorTimesScalar()285 ConstantFoldingRule FoldVectorTimesScalar() {
286 return [](IRContext* context, Instruction* inst,
287 const std::vector<const analysis::Constant*>& constants)
288 -> const analysis::Constant* {
289 assert(inst->opcode() == spv::Op::OpVectorTimesScalar);
290 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
291 analysis::TypeManager* type_mgr = context->get_type_mgr();
292
293 if (!inst->IsFloatingPointFoldingAllowed()) {
294 if (HasFloatingPoint(type_mgr->GetType(inst->type_id()))) {
295 return nullptr;
296 }
297 }
298
299 const analysis::Constant* c1 = constants[0];
300 const analysis::Constant* c2 = constants[1];
301
302 if (c1 && c1->IsZero()) {
303 return c1;
304 }
305
306 if (c2 && c2->IsZero()) {
307 // Get or create the NullConstant for this type.
308 std::vector<uint32_t> ids;
309 return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), ids);
310 }
311
312 if (c1 == nullptr || c2 == nullptr) {
313 return nullptr;
314 }
315
316 // Check result type.
317 const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
318 const analysis::Vector* vector_type = result_type->AsVector();
319 assert(vector_type != nullptr);
320 const analysis::Type* element_type = vector_type->element_type();
321 assert(element_type != nullptr);
322 const analysis::Float* float_type = element_type->AsFloat();
323 assert(float_type != nullptr);
324
325 // Check types of c1 and c2.
326 assert(c1->type()->AsVector() == vector_type);
327 assert(c1->type()->AsVector()->element_type() == element_type &&
328 c2->type() == element_type);
329
330 // Get a float vector that is the result of vector-times-scalar.
331 std::vector<const analysis::Constant*> c1_components =
332 c1->GetVectorComponents(const_mgr);
333 std::vector<uint32_t> ids;
334 if (float_type->width() == 32) {
335 float scalar = c2->GetFloat();
336 for (uint32_t i = 0; i < c1_components.size(); ++i) {
337 utils::FloatProxy<float> result(c1_components[i]->GetFloat() * scalar);
338 std::vector<uint32_t> words = result.GetWords();
339 const analysis::Constant* new_elem =
340 const_mgr->GetConstant(float_type, words);
341 ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
342 }
343 return const_mgr->GetConstant(vector_type, ids);
344 } else if (float_type->width() == 64) {
345 double scalar = c2->GetDouble();
346 for (uint32_t i = 0; i < c1_components.size(); ++i) {
347 utils::FloatProxy<double> result(c1_components[i]->GetDouble() *
348 scalar);
349 std::vector<uint32_t> words = result.GetWords();
350 const analysis::Constant* new_elem =
351 const_mgr->GetConstant(float_type, words);
352 ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
353 }
354 return const_mgr->GetConstant(vector_type, ids);
355 }
356 return nullptr;
357 };
358 }
359
360 // Returns to the constant that results from tranposing |matrix|. The result
361 // will have type |result_type|, and |matrix| must exist in |context|. The
362 // result constant will also exist in |context|.
TransposeMatrix(const analysis::Constant * matrix,analysis::Matrix * result_type,IRContext * context)363 const analysis::Constant* TransposeMatrix(const analysis::Constant* matrix,
364 analysis::Matrix* result_type,
365 IRContext* context) {
366 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
367 if (matrix->AsNullConstant() != nullptr) {
368 return const_mgr->GetNullCompositeConstant(result_type);
369 }
370
371 const auto& columns = matrix->AsMatrixConstant()->GetComponents();
372 uint32_t number_of_rows = columns[0]->type()->AsVector()->element_count();
373
374 // Collect the ids of the elements in their new positions.
375 std::vector<std::vector<uint32_t>> result_elements(number_of_rows);
376 for (const analysis::Constant* column : columns) {
377 if (column->AsNullConstant()) {
378 column = const_mgr->GetNullCompositeConstant(column->type());
379 }
380 const auto& column_components = column->AsVectorConstant()->GetComponents();
381
382 for (uint32_t row = 0; row < number_of_rows; ++row) {
383 result_elements[row].push_back(
384 const_mgr->GetDefiningInstruction(column_components[row])
385 ->result_id());
386 }
387 }
388
389 // Create the constant for each row in the result, and collect the ids.
390 std::vector<uint32_t> result_columns(number_of_rows);
391 for (uint32_t col = 0; col < number_of_rows; ++col) {
392 auto* element = const_mgr->GetConstant(result_type->element_type(),
393 result_elements[col]);
394 result_columns[col] =
395 const_mgr->GetDefiningInstruction(element)->result_id();
396 }
397
398 // Create the matrix constant from the row ids, and return it.
399 return const_mgr->GetConstant(result_type, result_columns);
400 }
401
FoldTranspose(IRContext * context,Instruction * inst,const std::vector<const analysis::Constant * > & constants)402 const analysis::Constant* FoldTranspose(
403 IRContext* context, Instruction* inst,
404 const std::vector<const analysis::Constant*>& constants) {
405 assert(inst->opcode() == spv::Op::OpTranspose);
406
407 analysis::TypeManager* type_mgr = context->get_type_mgr();
408 if (!inst->IsFloatingPointFoldingAllowed()) {
409 if (HasFloatingPoint(type_mgr->GetType(inst->type_id()))) {
410 return nullptr;
411 }
412 }
413
414 const analysis::Constant* matrix = constants[0];
415 if (matrix == nullptr) {
416 return nullptr;
417 }
418
419 auto* result_type = type_mgr->GetType(inst->type_id());
420 return TransposeMatrix(matrix, result_type->AsMatrix(), context);
421 }
422
FoldVectorTimesMatrix()423 ConstantFoldingRule FoldVectorTimesMatrix() {
424 return [](IRContext* context, Instruction* inst,
425 const std::vector<const analysis::Constant*>& constants)
426 -> const analysis::Constant* {
427 assert(inst->opcode() == spv::Op::OpVectorTimesMatrix);
428 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
429 analysis::TypeManager* type_mgr = context->get_type_mgr();
430
431 if (!inst->IsFloatingPointFoldingAllowed()) {
432 if (HasFloatingPoint(type_mgr->GetType(inst->type_id()))) {
433 return nullptr;
434 }
435 }
436
437 const analysis::Constant* c1 = constants[0];
438 const analysis::Constant* c2 = constants[1];
439
440 if (c1 == nullptr || c2 == nullptr) {
441 return nullptr;
442 }
443
444 // Check result type.
445 const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
446 const analysis::Vector* vector_type = result_type->AsVector();
447 assert(vector_type != nullptr);
448 const analysis::Type* element_type = vector_type->element_type();
449 assert(element_type != nullptr);
450 const analysis::Float* float_type = element_type->AsFloat();
451 assert(float_type != nullptr);
452
453 // Check types of c1 and c2.
454 assert(c1->type()->AsVector() == vector_type);
455 assert(c1->type()->AsVector()->element_type() == element_type &&
456 c2->type()->AsMatrix()->element_type() == vector_type);
457
458 uint32_t resultVectorSize = result_type->AsVector()->element_count();
459 std::vector<uint32_t> ids;
460
461 if ((c1 && c1->IsZero()) || (c2 && c2->IsZero())) {
462 std::vector<uint32_t> words(float_type->width() / 32, 0);
463 for (uint32_t i = 0; i < resultVectorSize; ++i) {
464 const analysis::Constant* new_elem =
465 const_mgr->GetConstant(float_type, words);
466 ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
467 }
468 return const_mgr->GetConstant(vector_type, ids);
469 }
470
471 // Get a float vector that is the result of vector-times-matrix.
472 std::vector<const analysis::Constant*> c1_components =
473 c1->GetVectorComponents(const_mgr);
474 std::vector<const analysis::Constant*> c2_components =
475 c2->AsMatrixConstant()->GetComponents();
476
477 if (float_type->width() == 32) {
478 for (uint32_t i = 0; i < resultVectorSize; ++i) {
479 float result_scalar = 0.0f;
480 if (!c2_components[i]->AsNullConstant()) {
481 const analysis::VectorConstant* c2_vec =
482 c2_components[i]->AsVectorConstant();
483 for (uint32_t j = 0; j < c2_vec->GetComponents().size(); ++j) {
484 float c1_scalar = c1_components[j]->GetFloat();
485 float c2_scalar = c2_vec->GetComponents()[j]->GetFloat();
486 result_scalar += c1_scalar * c2_scalar;
487 }
488 }
489 utils::FloatProxy<float> result(result_scalar);
490 std::vector<uint32_t> words = result.GetWords();
491 const analysis::Constant* new_elem =
492 const_mgr->GetConstant(float_type, words);
493 ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
494 }
495 return const_mgr->GetConstant(vector_type, ids);
496 } else if (float_type->width() == 64) {
497 for (uint32_t i = 0; i < c2_components.size(); ++i) {
498 double result_scalar = 0.0;
499 if (!c2_components[i]->AsNullConstant()) {
500 const analysis::VectorConstant* c2_vec =
501 c2_components[i]->AsVectorConstant();
502 for (uint32_t j = 0; j < c2_vec->GetComponents().size(); ++j) {
503 double c1_scalar = c1_components[j]->GetDouble();
504 double c2_scalar = c2_vec->GetComponents()[j]->GetDouble();
505 result_scalar += c1_scalar * c2_scalar;
506 }
507 }
508 utils::FloatProxy<double> result(result_scalar);
509 std::vector<uint32_t> words = result.GetWords();
510 const analysis::Constant* new_elem =
511 const_mgr->GetConstant(float_type, words);
512 ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
513 }
514 return const_mgr->GetConstant(vector_type, ids);
515 }
516 return nullptr;
517 };
518 }
519
FoldMatrixTimesVector()520 ConstantFoldingRule FoldMatrixTimesVector() {
521 return [](IRContext* context, Instruction* inst,
522 const std::vector<const analysis::Constant*>& constants)
523 -> const analysis::Constant* {
524 assert(inst->opcode() == spv::Op::OpMatrixTimesVector);
525 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
526 analysis::TypeManager* type_mgr = context->get_type_mgr();
527
528 if (!inst->IsFloatingPointFoldingAllowed()) {
529 if (HasFloatingPoint(type_mgr->GetType(inst->type_id()))) {
530 return nullptr;
531 }
532 }
533
534 const analysis::Constant* c1 = constants[0];
535 const analysis::Constant* c2 = constants[1];
536
537 if (c1 == nullptr || c2 == nullptr) {
538 return nullptr;
539 }
540
541 // Check result type.
542 const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
543 const analysis::Vector* vector_type = result_type->AsVector();
544 assert(vector_type != nullptr);
545 const analysis::Type* element_type = vector_type->element_type();
546 assert(element_type != nullptr);
547 const analysis::Float* float_type = element_type->AsFloat();
548 assert(float_type != nullptr);
549
550 // Check types of c1 and c2.
551 assert(c1->type()->AsMatrix()->element_type() == vector_type);
552 assert(c2->type()->AsVector()->element_type() == element_type);
553
554 uint32_t resultVectorSize = result_type->AsVector()->element_count();
555 std::vector<uint32_t> ids;
556
557 if ((c1 && c1->IsZero()) || (c2 && c2->IsZero())) {
558 std::vector<uint32_t> words(float_type->width() / 32, 0);
559 for (uint32_t i = 0; i < resultVectorSize; ++i) {
560 const analysis::Constant* new_elem =
561 const_mgr->GetConstant(float_type, words);
562 ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
563 }
564 return const_mgr->GetConstant(vector_type, ids);
565 }
566
567 // Get a float vector that is the result of matrix-times-vector.
568 std::vector<const analysis::Constant*> c1_components =
569 c1->AsMatrixConstant()->GetComponents();
570 std::vector<const analysis::Constant*> c2_components =
571 c2->GetVectorComponents(const_mgr);
572
573 if (float_type->width() == 32) {
574 for (uint32_t i = 0; i < resultVectorSize; ++i) {
575 float result_scalar = 0.0f;
576 for (uint32_t j = 0; j < c1_components.size(); ++j) {
577 if (!c1_components[j]->AsNullConstant()) {
578 float c1_scalar = c1_components[j]
579 ->AsVectorConstant()
580 ->GetComponents()[i]
581 ->GetFloat();
582 float c2_scalar = c2_components[j]->GetFloat();
583 result_scalar += c1_scalar * c2_scalar;
584 }
585 }
586 utils::FloatProxy<float> result(result_scalar);
587 std::vector<uint32_t> words = result.GetWords();
588 const analysis::Constant* new_elem =
589 const_mgr->GetConstant(float_type, words);
590 ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
591 }
592 return const_mgr->GetConstant(vector_type, ids);
593 } else if (float_type->width() == 64) {
594 for (uint32_t i = 0; i < resultVectorSize; ++i) {
595 double result_scalar = 0.0;
596 for (uint32_t j = 0; j < c1_components.size(); ++j) {
597 if (!c1_components[j]->AsNullConstant()) {
598 double c1_scalar = c1_components[j]
599 ->AsVectorConstant()
600 ->GetComponents()[i]
601 ->GetDouble();
602 double c2_scalar = c2_components[j]->GetDouble();
603 result_scalar += c1_scalar * c2_scalar;
604 }
605 }
606 utils::FloatProxy<double> result(result_scalar);
607 std::vector<uint32_t> words = result.GetWords();
608 const analysis::Constant* new_elem =
609 const_mgr->GetConstant(float_type, words);
610 ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
611 }
612 return const_mgr->GetConstant(vector_type, ids);
613 }
614 return nullptr;
615 };
616 }
617
FoldCompositeWithConstants()618 ConstantFoldingRule FoldCompositeWithConstants() {
619 // Folds an OpCompositeConstruct where all of the inputs are constants to a
620 // constant. A new constant is created if necessary.
621 return [](IRContext* context, Instruction* inst,
622 const std::vector<const analysis::Constant*>& constants)
623 -> const analysis::Constant* {
624 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
625 analysis::TypeManager* type_mgr = context->get_type_mgr();
626 const analysis::Type* new_type = type_mgr->GetType(inst->type_id());
627 Instruction* type_inst =
628 context->get_def_use_mgr()->GetDef(inst->type_id());
629
630 std::vector<uint32_t> ids;
631 for (uint32_t i = 0; i < constants.size(); ++i) {
632 const analysis::Constant* element_const = constants[i];
633 if (element_const == nullptr) {
634 return nullptr;
635 }
636
637 uint32_t component_type_id = 0;
638 if (type_inst->opcode() == spv::Op::OpTypeStruct) {
639 component_type_id = type_inst->GetSingleWordInOperand(i);
640 } else if (type_inst->opcode() == spv::Op::OpTypeArray) {
641 component_type_id = type_inst->GetSingleWordInOperand(0);
642 }
643
644 uint32_t element_id =
645 const_mgr->FindDeclaredConstant(element_const, component_type_id);
646 if (element_id == 0) {
647 return nullptr;
648 }
649 ids.push_back(element_id);
650 }
651 return const_mgr->GetConstant(new_type, ids);
652 };
653 }
654
655 // The interface for a function that returns the result of applying a scalar
656 // floating-point binary operation on |a| and |b|. The type of the return value
657 // will be |type|. The input constants must also be of type |type|.
658 using UnaryScalarFoldingRule = std::function<const analysis::Constant*(
659 const analysis::Type* result_type, const analysis::Constant* a,
660 analysis::ConstantManager*)>;
661
662 // The interface for a function that returns the result of applying a scalar
663 // floating-point binary operation on |a| and |b|. The type of the return value
664 // will be |type|. The input constants must also be of type |type|.
665 using BinaryScalarFoldingRule = std::function<const analysis::Constant*(
666 const analysis::Type* result_type, const analysis::Constant* a,
667 const analysis::Constant* b, analysis::ConstantManager*)>;
668
669 // Returns a |ConstantFoldingRule| that folds unary scalar ops
670 // using |scalar_rule| and unary vectors ops by applying
671 // |scalar_rule| to the elements of the vector. The |ConstantFoldingRule|
672 // that is returned assumes that |constants| contains 1 entry. If they are
673 // not |nullptr|, then their type is either |Float| or |Integer| or a |Vector|
674 // whose element type is |Float| or |Integer|.
FoldUnaryOp(UnaryScalarFoldingRule scalar_rule)675 ConstantFoldingRule FoldUnaryOp(UnaryScalarFoldingRule scalar_rule) {
676 return [scalar_rule](IRContext* context, Instruction* inst,
677 const std::vector<const analysis::Constant*>& constants)
678 -> const analysis::Constant* {
679
680 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
681 analysis::TypeManager* type_mgr = context->get_type_mgr();
682 const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
683 const analysis::Vector* vector_type = result_type->AsVector();
684
685 const analysis::Constant* arg =
686 (inst->opcode() == spv::Op::OpExtInst) ? constants[1] : constants[0];
687
688 if (arg == nullptr) {
689 return nullptr;
690 }
691
692 if (vector_type != nullptr) {
693 std::vector<const analysis::Constant*> a_components;
694 std::vector<const analysis::Constant*> results_components;
695
696 a_components = arg->GetVectorComponents(const_mgr);
697
698 // Fold each component of the vector.
699 for (uint32_t i = 0; i < a_components.size(); ++i) {
700 results_components.push_back(scalar_rule(vector_type->element_type(),
701 a_components[i], const_mgr));
702 if (results_components[i] == nullptr) {
703 return nullptr;
704 }
705 }
706
707 // Build the constant object and return it.
708 std::vector<uint32_t> ids;
709 for (const analysis::Constant* member : results_components) {
710 ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id());
711 }
712 return const_mgr->GetConstant(vector_type, ids);
713 } else {
714 return scalar_rule(result_type, arg, const_mgr);
715 }
716 };
717 }
718
719 // Returns a |ConstantFoldingRule| that folds unary floating point scalar ops
720 // using |scalar_rule| and unary float point vectors ops by applying
721 // |scalar_rule| to the elements of the vector. The |ConstantFoldingRule|
722 // that is returned assumes that |constants| contains 1 entry. If they are
723 // not |nullptr|, then their type is either |Float| or |Integer| or a |Vector|
724 // whose element type is |Float| or |Integer|.
FoldFPUnaryOp(UnaryScalarFoldingRule scalar_rule)725 ConstantFoldingRule FoldFPUnaryOp(UnaryScalarFoldingRule scalar_rule) {
726 auto folding_rule = FoldUnaryOp(scalar_rule);
727 return [folding_rule](IRContext* context, Instruction* inst,
728 const std::vector<const analysis::Constant*>& constants)
729 -> const analysis::Constant* {
730 if (!inst->IsFloatingPointFoldingAllowed()) {
731 return nullptr;
732 }
733
734 return folding_rule(context, inst, constants);
735 };
736 }
737
738 // Returns the result of folding the constants in |constants| according the
739 // |scalar_rule|. If |result_type| is a vector, then |scalar_rule| is applied
740 // per component.
FoldFPBinaryOp(BinaryScalarFoldingRule scalar_rule,uint32_t result_type_id,const std::vector<const analysis::Constant * > & constants,IRContext * context)741 const analysis::Constant* FoldFPBinaryOp(
742 BinaryScalarFoldingRule scalar_rule, uint32_t result_type_id,
743 const std::vector<const analysis::Constant*>& constants,
744 IRContext* context) {
745 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
746 analysis::TypeManager* type_mgr = context->get_type_mgr();
747 const analysis::Type* result_type = type_mgr->GetType(result_type_id);
748 const analysis::Vector* vector_type = result_type->AsVector();
749
750 if (constants[0] == nullptr || constants[1] == nullptr) {
751 return nullptr;
752 }
753
754 if (vector_type != nullptr) {
755 std::vector<const analysis::Constant*> a_components;
756 std::vector<const analysis::Constant*> b_components;
757 std::vector<const analysis::Constant*> results_components;
758
759 a_components = constants[0]->GetVectorComponents(const_mgr);
760 b_components = constants[1]->GetVectorComponents(const_mgr);
761
762 // Fold each component of the vector.
763 for (uint32_t i = 0; i < a_components.size(); ++i) {
764 results_components.push_back(scalar_rule(vector_type->element_type(),
765 a_components[i], b_components[i],
766 const_mgr));
767 if (results_components[i] == nullptr) {
768 return nullptr;
769 }
770 }
771
772 // Build the constant object and return it.
773 std::vector<uint32_t> ids;
774 for (const analysis::Constant* member : results_components) {
775 ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id());
776 }
777 return const_mgr->GetConstant(vector_type, ids);
778 } else {
779 return scalar_rule(result_type, constants[0], constants[1], const_mgr);
780 }
781 }
782
783 // Returns a |ConstantFoldingRule| that folds floating point scalars using
784 // |scalar_rule| and vectors of floating point by applying |scalar_rule| to the
785 // elements of the vector. The |ConstantFoldingRule| that is returned assumes
786 // that |constants| contains 2 entries. If they are not |nullptr|, then their
787 // type is either |Float| or a |Vector| whose element type is |Float|.
FoldFPBinaryOp(BinaryScalarFoldingRule scalar_rule)788 ConstantFoldingRule FoldFPBinaryOp(BinaryScalarFoldingRule scalar_rule) {
789 return [scalar_rule](IRContext* context, Instruction* inst,
790 const std::vector<const analysis::Constant*>& constants)
791 -> const analysis::Constant* {
792 if (!inst->IsFloatingPointFoldingAllowed()) {
793 return nullptr;
794 }
795 if (inst->opcode() == spv::Op::OpExtInst) {
796 return FoldFPBinaryOp(scalar_rule, inst->type_id(),
797 {constants[1], constants[2]}, context);
798 }
799 return FoldFPBinaryOp(scalar_rule, inst->type_id(), constants, context);
800 };
801 }
802
803 // This macro defines a |UnaryScalarFoldingRule| that performs float to
804 // integer conversion.
805 // TODO(greg-lunarg): Support for 64-bit integer types.
FoldFToIOp()806 UnaryScalarFoldingRule FoldFToIOp() {
807 return [](const analysis::Type* result_type, const analysis::Constant* a,
808 analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
809 assert(result_type != nullptr && a != nullptr);
810 const analysis::Integer* integer_type = result_type->AsInteger();
811 const analysis::Float* float_type = a->type()->AsFloat();
812 assert(float_type != nullptr);
813 assert(integer_type != nullptr);
814 if (integer_type->width() != 32) return nullptr;
815 if (float_type->width() == 32) {
816 float fa = a->GetFloat();
817 uint32_t result = integer_type->IsSigned()
818 ? static_cast<uint32_t>(static_cast<int32_t>(fa))
819 : static_cast<uint32_t>(fa);
820 std::vector<uint32_t> words = {result};
821 return const_mgr->GetConstant(result_type, words);
822 } else if (float_type->width() == 64) {
823 double fa = a->GetDouble();
824 uint32_t result = integer_type->IsSigned()
825 ? static_cast<uint32_t>(static_cast<int32_t>(fa))
826 : static_cast<uint32_t>(fa);
827 std::vector<uint32_t> words = {result};
828 return const_mgr->GetConstant(result_type, words);
829 }
830 return nullptr;
831 };
832 }
833
834 // This function defines a |UnaryScalarFoldingRule| that performs integer to
835 // float conversion.
836 // TODO(greg-lunarg): Support for 64-bit integer types.
FoldIToFOp()837 UnaryScalarFoldingRule FoldIToFOp() {
838 return [](const analysis::Type* result_type, const analysis::Constant* a,
839 analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
840 assert(result_type != nullptr && a != nullptr);
841 const analysis::Integer* integer_type = a->type()->AsInteger();
842 const analysis::Float* float_type = result_type->AsFloat();
843 assert(float_type != nullptr);
844 assert(integer_type != nullptr);
845 if (integer_type->width() != 32) return nullptr;
846 uint32_t ua = a->GetU32();
847 if (float_type->width() == 32) {
848 float result_val = integer_type->IsSigned()
849 ? static_cast<float>(static_cast<int32_t>(ua))
850 : static_cast<float>(ua);
851 utils::FloatProxy<float> result(result_val);
852 std::vector<uint32_t> words = {result.data()};
853 return const_mgr->GetConstant(result_type, words);
854 } else if (float_type->width() == 64) {
855 double result_val = integer_type->IsSigned()
856 ? static_cast<double>(static_cast<int32_t>(ua))
857 : static_cast<double>(ua);
858 utils::FloatProxy<double> result(result_val);
859 std::vector<uint32_t> words = result.GetWords();
860 return const_mgr->GetConstant(result_type, words);
861 }
862 return nullptr;
863 };
864 }
865
866 // This defines a |UnaryScalarFoldingRule| that performs |OpQuantizeToF16|.
FoldQuantizeToF16Scalar()867 UnaryScalarFoldingRule FoldQuantizeToF16Scalar() {
868 return [](const analysis::Type* result_type, const analysis::Constant* a,
869 analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
870 assert(result_type != nullptr && a != nullptr);
871 const analysis::Float* float_type = a->type()->AsFloat();
872 assert(float_type != nullptr);
873 if (float_type->width() != 32) {
874 return nullptr;
875 }
876
877 float fa = a->GetFloat();
878 utils::HexFloat<utils::FloatProxy<float>> orignal(fa);
879 utils::HexFloat<utils::FloatProxy<utils::Float16>> quantized(0);
880 utils::HexFloat<utils::FloatProxy<float>> result(0.0f);
881 orignal.castTo(quantized, utils::round_direction::kToZero);
882 quantized.castTo(result, utils::round_direction::kToZero);
883 std::vector<uint32_t> words = {result.getBits()};
884 return const_mgr->GetConstant(result_type, words);
885 };
886 }
887
// This macro expands to a lambda usable as a |BinaryScalarFoldingRule| that
// evaluates `a op b` on two float constants. The operator |op| must work for
// both float and double, and use syntax "f1 op f2".
//
// The result type and both operand types are asserted to be the same float
// type. 32-bit operands are computed in float and 64-bit operands in double,
// and the result is a constant of that type. For any other width the lambda
// returns |nullptr|.
#define FOLD_FPARITH_OP(op)                                                   \
  [](const analysis::Type* result_type_in_macro, const analysis::Constant* a, \
     const analysis::Constant* b,                                             \
     analysis::ConstantManager* const_mgr_in_macro)                           \
      -> const analysis::Constant* {                                          \
    assert(result_type_in_macro != nullptr && a != nullptr && b != nullptr);  \
    assert(result_type_in_macro == a->type() &&                               \
           result_type_in_macro == b->type());                                \
    const analysis::Float* float_type_in_macro =                              \
        result_type_in_macro->AsFloat();                                      \
    assert(float_type_in_macro != nullptr);                                   \
    if (float_type_in_macro->width() == 32) {                                 \
      float fa = a->GetFloat();                                               \
      float fb = b->GetFloat();                                               \
      utils::FloatProxy<float> result_in_macro(fa op fb);                     \
      std::vector<uint32_t> words_in_macro = result_in_macro.GetWords();      \
      return const_mgr_in_macro->GetConstant(result_type_in_macro,            \
                                             words_in_macro);                 \
    } else if (float_type_in_macro->width() == 64) {                          \
      double fa = a->GetDouble();                                             \
      double fb = b->GetDouble();                                             \
      utils::FloatProxy<double> result_in_macro(fa op fb);                    \
      std::vector<uint32_t> words_in_macro = result_in_macro.GetWords();      \
      return const_mgr_in_macro->GetConstant(result_type_in_macro,            \
                                             words_in_macro);                 \
    }                                                                         \
    return nullptr;                                                           \
  }
918
919 // Define the folding rule for conversion between floating point and integer
FoldFToI()920 ConstantFoldingRule FoldFToI() { return FoldFPUnaryOp(FoldFToIOp()); }
FoldIToF()921 ConstantFoldingRule FoldIToF() { return FoldFPUnaryOp(FoldIToFOp()); }
FoldQuantizeToF16()922 ConstantFoldingRule FoldQuantizeToF16() {
923 return FoldFPUnaryOp(FoldQuantizeToF16Scalar());
924 }
925
926 // Define the folding rules for subtraction, addition, multiplication, and
927 // division for floating point values.
FoldFSub()928 ConstantFoldingRule FoldFSub() { return FoldFPBinaryOp(FOLD_FPARITH_OP(-)); }
FoldFAdd()929 ConstantFoldingRule FoldFAdd() { return FoldFPBinaryOp(FOLD_FPARITH_OP(+)); }
FoldFMul()930 ConstantFoldingRule FoldFMul() { return FoldFPBinaryOp(FOLD_FPARITH_OP(*)); }
931
932 // Returns the constant that results from evaluating |numerator| / 0.0. Returns
933 // |nullptr| if the result could not be evaluated.
FoldFPScalarDivideByZero(const analysis::Type * result_type,const analysis::Constant * numerator,analysis::ConstantManager * const_mgr)934 const analysis::Constant* FoldFPScalarDivideByZero(
935 const analysis::Type* result_type, const analysis::Constant* numerator,
936 analysis::ConstantManager* const_mgr) {
937 if (numerator == nullptr) {
938 return nullptr;
939 }
940
941 if (numerator->IsZero()) {
942 return GetNan(result_type, const_mgr);
943 }
944
945 const analysis::Constant* result = GetInf(result_type, const_mgr);
946 if (result == nullptr) {
947 return nullptr;
948 }
949
950 if (numerator->AsFloatConstant()->GetValueAsDouble() < 0.0) {
951 result = NegateFPConst(result_type, result, const_mgr);
952 }
953 return result;
954 }
955
956 // Returns the result of folding |numerator| / |denominator|. Returns |nullptr|
957 // if it cannot be folded.
FoldScalarFPDivide(const analysis::Type * result_type,const analysis::Constant * numerator,const analysis::Constant * denominator,analysis::ConstantManager * const_mgr)958 const analysis::Constant* FoldScalarFPDivide(
959 const analysis::Type* result_type, const analysis::Constant* numerator,
960 const analysis::Constant* denominator,
961 analysis::ConstantManager* const_mgr) {
962 if (denominator == nullptr) {
963 return nullptr;
964 }
965
966 if (denominator->IsZero()) {
967 return FoldFPScalarDivideByZero(result_type, numerator, const_mgr);
968 }
969
970 uint32_t width = denominator->type()->AsFloat()->width();
971 if (width != 32 && width != 64) {
972 return nullptr;
973 }
974
975 const analysis::FloatConstant* denominator_float =
976 denominator->AsFloatConstant();
977 if (denominator_float && denominator->GetValueAsDouble() == -0.0) {
978 const analysis::Constant* result =
979 FoldFPScalarDivideByZero(result_type, numerator, const_mgr);
980 if (result != nullptr)
981 result = NegateFPConst(result_type, result, const_mgr);
982 return result;
983 } else {
984 return FOLD_FPARITH_OP(/)(result_type, numerator, denominator, const_mgr);
985 }
986 }
987
988 // Returns the constant folding rule to fold |OpFDiv| with two constants.
FoldFDiv()989 ConstantFoldingRule FoldFDiv() { return FoldFPBinaryOp(FoldScalarFPDivide); }
990
// Combines the raw outcome of a floating-point comparison with NaN handling.
// |op_result| is the value of `a op b`. |op_unordered| is true when either
// operand is NaN. |need_ordered| selects ordered semantics (false when any
// operand is NaN) over unordered semantics (true when any operand is NaN).
bool CompareFloatingPoint(bool op_result, bool op_unordered,
                          bool need_ordered) {
  if (op_unordered) {
    // With a NaN operand the outcome depends only on the kind of comparison.
    return !need_ordered;
  }
  // Both operands are ordered: the raw comparison decides.
  return op_result;
}
1001
// This macro expands to a lambda usable as a |BinaryScalarFoldingRule| that
// evaluates the comparison `a op b` on two float constants of the same type
// and returns a constant of the boolean |result_type|. The operator |op| must
// work for both float and double, and use syntax "f1 op f2".
//
// |ord| is passed to CompareFloatingPoint as |need_ordered|. When either
// operand is NaN, true gives ordered semantics (result false) and false gives
// unordered semantics (result true). Only 32- and 64-bit operands are
// handled; other widths return |nullptr|.
#define FOLD_FPCMP_OP(op, ord)                                            \
  [](const analysis::Type* result_type, const analysis::Constant* a,      \
     const analysis::Constant* b,                                         \
     analysis::ConstantManager* const_mgr) -> const analysis::Constant* { \
    assert(result_type != nullptr && a != nullptr && b != nullptr);       \
    assert(result_type->AsBool());                                        \
    assert(a->type() == b->type());                                       \
    const analysis::Float* float_type = a->type()->AsFloat();             \
    assert(float_type != nullptr);                                        \
    if (float_type->width() == 32) {                                      \
      float fa = a->GetFloat();                                           \
      float fb = b->GetFloat();                                           \
      bool result = CompareFloatingPoint(                                 \
          fa op fb, std::isnan(fa) || std::isnan(fb), ord);               \
      std::vector<uint32_t> words = {uint32_t(result)};                   \
      return const_mgr->GetConstant(result_type, words);                  \
    } else if (float_type->width() == 64) {                               \
      double fa = a->GetDouble();                                         \
      double fb = b->GetDouble();                                         \
      bool result = CompareFloatingPoint(                                 \
          fa op fb, std::isnan(fa) || std::isnan(fb), ord);               \
      std::vector<uint32_t> words = {uint32_t(result)};                   \
      return const_mgr->GetConstant(result_type, words);                  \
    }                                                                     \
    return nullptr;                                                       \
  }
1030
1031 // Define the folding rules for ordered and unordered comparison for floating
1032 // point values.
FoldFOrdEqual()1033 ConstantFoldingRule FoldFOrdEqual() {
1034 return FoldFPBinaryOp(FOLD_FPCMP_OP(==, true));
1035 }
FoldFUnordEqual()1036 ConstantFoldingRule FoldFUnordEqual() {
1037 return FoldFPBinaryOp(FOLD_FPCMP_OP(==, false));
1038 }
FoldFOrdNotEqual()1039 ConstantFoldingRule FoldFOrdNotEqual() {
1040 return FoldFPBinaryOp(FOLD_FPCMP_OP(!=, true));
1041 }
FoldFUnordNotEqual()1042 ConstantFoldingRule FoldFUnordNotEqual() {
1043 return FoldFPBinaryOp(FOLD_FPCMP_OP(!=, false));
1044 }
FoldFOrdLessThan()1045 ConstantFoldingRule FoldFOrdLessThan() {
1046 return FoldFPBinaryOp(FOLD_FPCMP_OP(<, true));
1047 }
FoldFUnordLessThan()1048 ConstantFoldingRule FoldFUnordLessThan() {
1049 return FoldFPBinaryOp(FOLD_FPCMP_OP(<, false));
1050 }
FoldFOrdGreaterThan()1051 ConstantFoldingRule FoldFOrdGreaterThan() {
1052 return FoldFPBinaryOp(FOLD_FPCMP_OP(>, true));
1053 }
FoldFUnordGreaterThan()1054 ConstantFoldingRule FoldFUnordGreaterThan() {
1055 return FoldFPBinaryOp(FOLD_FPCMP_OP(>, false));
1056 }
FoldFOrdLessThanEqual()1057 ConstantFoldingRule FoldFOrdLessThanEqual() {
1058 return FoldFPBinaryOp(FOLD_FPCMP_OP(<=, true));
1059 }
FoldFUnordLessThanEqual()1060 ConstantFoldingRule FoldFUnordLessThanEqual() {
1061 return FoldFPBinaryOp(FOLD_FPCMP_OP(<=, false));
1062 }
FoldFOrdGreaterThanEqual()1063 ConstantFoldingRule FoldFOrdGreaterThanEqual() {
1064 return FoldFPBinaryOp(FOLD_FPCMP_OP(>=, true));
1065 }
FoldFUnordGreaterThanEqual()1066 ConstantFoldingRule FoldFUnordGreaterThanEqual() {
1067 return FoldFPBinaryOp(FOLD_FPCMP_OP(>=, false));
1068 }
1069
1070 // Folds an OpDot where all of the inputs are constants to a
1071 // constant. A new constant is created if necessary.
FoldOpDotWithConstants()1072 ConstantFoldingRule FoldOpDotWithConstants() {
1073 return [](IRContext* context, Instruction* inst,
1074 const std::vector<const analysis::Constant*>& constants)
1075 -> const analysis::Constant* {
1076 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1077 analysis::TypeManager* type_mgr = context->get_type_mgr();
1078 const analysis::Type* new_type = type_mgr->GetType(inst->type_id());
1079 assert(new_type->AsFloat() && "OpDot should have a float return type.");
1080 const analysis::Float* float_type = new_type->AsFloat();
1081
1082 if (!inst->IsFloatingPointFoldingAllowed()) {
1083 return nullptr;
1084 }
1085
1086 // If one of the operands is 0, then the result is 0.
1087 bool has_zero_operand = false;
1088
1089 for (int i = 0; i < 2; ++i) {
1090 if (constants[i]) {
1091 if (constants[i]->AsNullConstant() ||
1092 constants[i]->AsVectorConstant()->IsZero()) {
1093 has_zero_operand = true;
1094 break;
1095 }
1096 }
1097 }
1098
1099 if (has_zero_operand) {
1100 if (float_type->width() == 32) {
1101 utils::FloatProxy<float> result(0.0f);
1102 std::vector<uint32_t> words = result.GetWords();
1103 return const_mgr->GetConstant(float_type, words);
1104 }
1105 if (float_type->width() == 64) {
1106 utils::FloatProxy<double> result(0.0);
1107 std::vector<uint32_t> words = result.GetWords();
1108 return const_mgr->GetConstant(float_type, words);
1109 }
1110 return nullptr;
1111 }
1112
1113 if (constants[0] == nullptr || constants[1] == nullptr) {
1114 return nullptr;
1115 }
1116
1117 std::vector<const analysis::Constant*> a_components;
1118 std::vector<const analysis::Constant*> b_components;
1119
1120 a_components = constants[0]->GetVectorComponents(const_mgr);
1121 b_components = constants[1]->GetVectorComponents(const_mgr);
1122
1123 utils::FloatProxy<double> result(0.0);
1124 std::vector<uint32_t> words = result.GetWords();
1125 const analysis::Constant* result_const =
1126 const_mgr->GetConstant(float_type, words);
1127 for (uint32_t i = 0; i < a_components.size() && result_const != nullptr;
1128 ++i) {
1129 if (a_components[i] == nullptr || b_components[i] == nullptr) {
1130 return nullptr;
1131 }
1132
1133 const analysis::Constant* component = FOLD_FPARITH_OP(*)(
1134 new_type, a_components[i], b_components[i], const_mgr);
1135 if (component == nullptr) {
1136 return nullptr;
1137 }
1138 result_const =
1139 FOLD_FPARITH_OP(+)(new_type, result_const, component, const_mgr);
1140 }
1141 return result_const;
1142 };
1143 }
1144
FoldFNegate()1145 ConstantFoldingRule FoldFNegate() { return FoldFPUnaryOp(NegateFPConst); }
FoldSNegate()1146 ConstantFoldingRule FoldSNegate() { return FoldUnaryOp(NegateIntConst); }
1147
FoldFClampFeedingCompare(spv::Op cmp_opcode)1148 ConstantFoldingRule FoldFClampFeedingCompare(spv::Op cmp_opcode) {
1149 return [cmp_opcode](IRContext* context, Instruction* inst,
1150 const std::vector<const analysis::Constant*>& constants)
1151 -> const analysis::Constant* {
1152 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1153 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1154
1155 if (!inst->IsFloatingPointFoldingAllowed()) {
1156 return nullptr;
1157 }
1158
1159 uint32_t non_const_idx = (constants[0] ? 1 : 0);
1160 uint32_t operand_id = inst->GetSingleWordInOperand(non_const_idx);
1161 Instruction* operand_inst = def_use_mgr->GetDef(operand_id);
1162
1163 analysis::TypeManager* type_mgr = context->get_type_mgr();
1164 const analysis::Type* operand_type =
1165 type_mgr->GetType(operand_inst->type_id());
1166
1167 if (!operand_type->AsFloat()) {
1168 return nullptr;
1169 }
1170
1171 if (operand_type->AsFloat()->width() != 32 &&
1172 operand_type->AsFloat()->width() != 64) {
1173 return nullptr;
1174 }
1175
1176 if (operand_inst->opcode() != spv::Op::OpExtInst) {
1177 return nullptr;
1178 }
1179
1180 if (operand_inst->GetSingleWordInOperand(1) != GLSLstd450FClamp) {
1181 return nullptr;
1182 }
1183
1184 if (constants[1] == nullptr && constants[0] == nullptr) {
1185 return nullptr;
1186 }
1187
1188 uint32_t max_id = operand_inst->GetSingleWordInOperand(4);
1189 const analysis::Constant* max_const =
1190 const_mgr->FindDeclaredConstant(max_id);
1191
1192 uint32_t min_id = operand_inst->GetSingleWordInOperand(3);
1193 const analysis::Constant* min_const =
1194 const_mgr->FindDeclaredConstant(min_id);
1195
1196 bool found_result = false;
1197 bool result = false;
1198
1199 switch (cmp_opcode) {
1200 case spv::Op::OpFOrdLessThan:
1201 case spv::Op::OpFUnordLessThan:
1202 case spv::Op::OpFOrdGreaterThanEqual:
1203 case spv::Op::OpFUnordGreaterThanEqual:
1204 if (constants[0]) {
1205 if (min_const) {
1206 if (constants[0]->GetValueAsDouble() <
1207 min_const->GetValueAsDouble()) {
1208 found_result = true;
1209 result = (cmp_opcode == spv::Op::OpFOrdLessThan ||
1210 cmp_opcode == spv::Op::OpFUnordLessThan);
1211 }
1212 }
1213 if (max_const) {
1214 if (constants[0]->GetValueAsDouble() >=
1215 max_const->GetValueAsDouble()) {
1216 found_result = true;
1217 result = !(cmp_opcode == spv::Op::OpFOrdLessThan ||
1218 cmp_opcode == spv::Op::OpFUnordLessThan);
1219 }
1220 }
1221 }
1222
1223 if (constants[1]) {
1224 if (max_const) {
1225 if (max_const->GetValueAsDouble() <
1226 constants[1]->GetValueAsDouble()) {
1227 found_result = true;
1228 result = (cmp_opcode == spv::Op::OpFOrdLessThan ||
1229 cmp_opcode == spv::Op::OpFUnordLessThan);
1230 }
1231 }
1232
1233 if (min_const) {
1234 if (min_const->GetValueAsDouble() >=
1235 constants[1]->GetValueAsDouble()) {
1236 found_result = true;
1237 result = !(cmp_opcode == spv::Op::OpFOrdLessThan ||
1238 cmp_opcode == spv::Op::OpFUnordLessThan);
1239 }
1240 }
1241 }
1242 break;
1243 case spv::Op::OpFOrdGreaterThan:
1244 case spv::Op::OpFUnordGreaterThan:
1245 case spv::Op::OpFOrdLessThanEqual:
1246 case spv::Op::OpFUnordLessThanEqual:
1247 if (constants[0]) {
1248 if (min_const) {
1249 if (constants[0]->GetValueAsDouble() <=
1250 min_const->GetValueAsDouble()) {
1251 found_result = true;
1252 result = (cmp_opcode == spv::Op::OpFOrdLessThanEqual ||
1253 cmp_opcode == spv::Op::OpFUnordLessThanEqual);
1254 }
1255 }
1256 if (max_const) {
1257 if (constants[0]->GetValueAsDouble() >
1258 max_const->GetValueAsDouble()) {
1259 found_result = true;
1260 result = !(cmp_opcode == spv::Op::OpFOrdLessThanEqual ||
1261 cmp_opcode == spv::Op::OpFUnordLessThanEqual);
1262 }
1263 }
1264 }
1265
1266 if (constants[1]) {
1267 if (max_const) {
1268 if (max_const->GetValueAsDouble() <=
1269 constants[1]->GetValueAsDouble()) {
1270 found_result = true;
1271 result = (cmp_opcode == spv::Op::OpFOrdLessThanEqual ||
1272 cmp_opcode == spv::Op::OpFUnordLessThanEqual);
1273 }
1274 }
1275
1276 if (min_const) {
1277 if (min_const->GetValueAsDouble() >
1278 constants[1]->GetValueAsDouble()) {
1279 found_result = true;
1280 result = !(cmp_opcode == spv::Op::OpFOrdLessThanEqual ||
1281 cmp_opcode == spv::Op::OpFUnordLessThanEqual);
1282 }
1283 }
1284 }
1285 break;
1286 default:
1287 return nullptr;
1288 }
1289
1290 if (!found_result) {
1291 return nullptr;
1292 }
1293
1294 const analysis::Type* bool_type =
1295 context->get_type_mgr()->GetType(inst->type_id());
1296 const analysis::Constant* result_const =
1297 const_mgr->GetConstant(bool_type, {static_cast<uint32_t>(result)});
1298 assert(result_const);
1299 return result_const;
1300 };
1301 }
1302
FoldFMix()1303 ConstantFoldingRule FoldFMix() {
1304 return [](IRContext* context, Instruction* inst,
1305 const std::vector<const analysis::Constant*>& constants)
1306 -> const analysis::Constant* {
1307 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1308 assert(inst->opcode() == spv::Op::OpExtInst &&
1309 "Expecting an extended instruction.");
1310 assert(inst->GetSingleWordInOperand(0) ==
1311 context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
1312 "Expecting a GLSLstd450 extended instruction.");
1313 assert(inst->GetSingleWordInOperand(1) == GLSLstd450FMix &&
1314 "Expecting and FMix instruction.");
1315
1316 if (!inst->IsFloatingPointFoldingAllowed()) {
1317 return nullptr;
1318 }
1319
1320 // Make sure all FMix operands are constants.
1321 for (uint32_t i = 1; i < 4; i++) {
1322 if (constants[i] == nullptr) {
1323 return nullptr;
1324 }
1325 }
1326
1327 const analysis::Constant* one;
1328 bool is_vector = false;
1329 const analysis::Type* result_type = constants[1]->type();
1330 const analysis::Type* base_type = result_type;
1331 if (base_type->AsVector()) {
1332 is_vector = true;
1333 base_type = base_type->AsVector()->element_type();
1334 }
1335 assert(base_type->AsFloat() != nullptr &&
1336 "FMix is suppose to act on floats or vectors of floats.");
1337
1338 if (base_type->AsFloat()->width() == 32) {
1339 one = const_mgr->GetConstant(base_type,
1340 utils::FloatProxy<float>(1.0f).GetWords());
1341 } else {
1342 one = const_mgr->GetConstant(base_type,
1343 utils::FloatProxy<double>(1.0).GetWords());
1344 }
1345
1346 if (is_vector) {
1347 uint32_t one_id = const_mgr->GetDefiningInstruction(one)->result_id();
1348 one =
1349 const_mgr->GetConstant(result_type, std::vector<uint32_t>(4, one_id));
1350 }
1351
1352 const analysis::Constant* temp1 = FoldFPBinaryOp(
1353 FOLD_FPARITH_OP(-), inst->type_id(), {one, constants[3]}, context);
1354 if (temp1 == nullptr) {
1355 return nullptr;
1356 }
1357
1358 const analysis::Constant* temp2 = FoldFPBinaryOp(
1359 FOLD_FPARITH_OP(*), inst->type_id(), {constants[1], temp1}, context);
1360 if (temp2 == nullptr) {
1361 return nullptr;
1362 }
1363 const analysis::Constant* temp3 =
1364 FoldFPBinaryOp(FOLD_FPARITH_OP(*), inst->type_id(),
1365 {constants[2], constants[3]}, context);
1366 if (temp3 == nullptr) {
1367 return nullptr;
1368 }
1369 return FoldFPBinaryOp(FOLD_FPARITH_OP(+), inst->type_id(), {temp2, temp3},
1370 context);
1371 };
1372 }
1373
// Returns whichever of the scalar constants |a| and |b| is smaller. The
// comparison follows |result_type|: signed or unsigned 32/64-bit integers, or
// 32/64-bit floats. When the values are equal (or the comparison is false, as
// with a NaN), |b| is returned. Returns |nullptr| for unsupported types.
const analysis::Constant* FoldMin(const analysis::Type* result_type,
                                  const analysis::Constant* a,
                                  const analysis::Constant* b,
                                  analysis::ConstantManager*) {
  if (const analysis::Integer* int_type = result_type->AsInteger()) {
    const bool is_signed = int_type->IsSigned();
    switch (int_type->width()) {
      case 32:
        if (is_signed) return a->GetS32() < b->GetS32() ? a : b;
        return a->GetU32() < b->GetU32() ? a : b;
      case 64:
        if (is_signed) return a->GetS64() < b->GetS64() ? a : b;
        return a->GetU64() < b->GetU64() ? a : b;
      default:
        return nullptr;
    }
  }

  const analysis::Float* float_type = result_type->AsFloat();
  if (float_type == nullptr) return nullptr;
  switch (float_type->width()) {
    case 32:
      return a->GetFloat() < b->GetFloat() ? a : b;
    case 64:
      return a->GetDouble() < b->GetDouble() ? a : b;
    default:
      return nullptr;
  }
}
1413
// Returns whichever of the scalar constants |a| and |b| is larger. The
// comparison follows |result_type|: signed or unsigned 32/64-bit integers, or
// 32/64-bit floats. When the values are equal (or the comparison is false, as
// with a NaN), |b| is returned. Returns |nullptr| for unsupported types.
const analysis::Constant* FoldMax(const analysis::Type* result_type,
                                  const analysis::Constant* a,
                                  const analysis::Constant* b,
                                  analysis::ConstantManager*) {
  if (const analysis::Integer* int_type = result_type->AsInteger()) {
    const bool is_signed = int_type->IsSigned();
    switch (int_type->width()) {
      case 32:
        if (is_signed) return a->GetS32() > b->GetS32() ? a : b;
        return a->GetU32() > b->GetU32() ? a : b;
      case 64:
        if (is_signed) return a->GetS64() > b->GetS64() ? a : b;
        return a->GetU64() > b->GetU64() ? a : b;
      default:
        return nullptr;
    }
  }

  const analysis::Float* float_type = result_type->AsFloat();
  if (float_type == nullptr) return nullptr;
  switch (float_type->width()) {
    case 32:
      return a->GetFloat() > b->GetFloat() ? a : b;
    case 64:
      return a->GetDouble() > b->GetDouble() ? a : b;
    default:
      return nullptr;
  }
}
1453
1454 // Fold an clamp instruction when all three operands are constant.
FoldClamp1(IRContext * context,Instruction * inst,const std::vector<const analysis::Constant * > & constants)1455 const analysis::Constant* FoldClamp1(
1456 IRContext* context, Instruction* inst,
1457 const std::vector<const analysis::Constant*>& constants) {
1458 assert(inst->opcode() == spv::Op::OpExtInst &&
1459 "Expecting an extended instruction.");
1460 assert(inst->GetSingleWordInOperand(0) ==
1461 context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
1462 "Expecting a GLSLstd450 extended instruction.");
1463
1464 // Make sure all Clamp operands are constants.
1465 for (uint32_t i = 1; i < 4; i++) {
1466 if (constants[i] == nullptr) {
1467 return nullptr;
1468 }
1469 }
1470
1471 const analysis::Constant* temp = FoldFPBinaryOp(
1472 FoldMax, inst->type_id(), {constants[1], constants[2]}, context);
1473 if (temp == nullptr) {
1474 return nullptr;
1475 }
1476 return FoldFPBinaryOp(FoldMin, inst->type_id(), {temp, constants[3]},
1477 context);
1478 }
1479
1480 // Fold a clamp instruction when |x <= min_val|.
FoldClamp2(IRContext * context,Instruction * inst,const std::vector<const analysis::Constant * > & constants)1481 const analysis::Constant* FoldClamp2(
1482 IRContext* context, Instruction* inst,
1483 const std::vector<const analysis::Constant*>& constants) {
1484 assert(inst->opcode() == spv::Op::OpExtInst &&
1485 "Expecting an extended instruction.");
1486 assert(inst->GetSingleWordInOperand(0) ==
1487 context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
1488 "Expecting a GLSLstd450 extended instruction.");
1489
1490 const analysis::Constant* x = constants[1];
1491 const analysis::Constant* min_val = constants[2];
1492
1493 if (x == nullptr || min_val == nullptr) {
1494 return nullptr;
1495 }
1496
1497 const analysis::Constant* temp =
1498 FoldFPBinaryOp(FoldMax, inst->type_id(), {x, min_val}, context);
1499 if (temp == min_val) {
1500 // We can assume that |min_val| is less than |max_val|. Therefore, if the
1501 // result of the max operation is |min_val|, we know the result of the min
1502 // operation, even if |max_val| is not a constant.
1503 return min_val;
1504 }
1505 return nullptr;
1506 }
1507
1508 // Fold a clamp instruction when |x >= max_val|.
FoldClamp3(IRContext * context,Instruction * inst,const std::vector<const analysis::Constant * > & constants)1509 const analysis::Constant* FoldClamp3(
1510 IRContext* context, Instruction* inst,
1511 const std::vector<const analysis::Constant*>& constants) {
1512 assert(inst->opcode() == spv::Op::OpExtInst &&
1513 "Expecting an extended instruction.");
1514 assert(inst->GetSingleWordInOperand(0) ==
1515 context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
1516 "Expecting a GLSLstd450 extended instruction.");
1517
1518 const analysis::Constant* x = constants[1];
1519 const analysis::Constant* max_val = constants[3];
1520
1521 if (x == nullptr || max_val == nullptr) {
1522 return nullptr;
1523 }
1524
1525 const analysis::Constant* temp =
1526 FoldFPBinaryOp(FoldMin, inst->type_id(), {x, max_val}, context);
1527 if (temp == max_val) {
1528 // We can assume that |min_val| is less than |max_val|. Therefore, if the
1529 // result of the max operation is |min_val|, we know the result of the min
1530 // operation, even if |max_val| is not a constant.
1531 return max_val;
1532 }
1533 return nullptr;
1534 }
1535
FoldFTranscendentalUnary(double (* fp)(double))1536 UnaryScalarFoldingRule FoldFTranscendentalUnary(double (*fp)(double)) {
1537 return
1538 [fp](const analysis::Type* result_type, const analysis::Constant* a,
1539 analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
1540 assert(result_type != nullptr && a != nullptr);
1541 const analysis::Float* float_type = a->type()->AsFloat();
1542 assert(float_type != nullptr);
1543 assert(float_type == result_type->AsFloat());
1544 if (float_type->width() == 32) {
1545 float fa = a->GetFloat();
1546 float res = static_cast<float>(fp(fa));
1547 utils::FloatProxy<float> result(res);
1548 std::vector<uint32_t> words = result.GetWords();
1549 return const_mgr->GetConstant(result_type, words);
1550 } else if (float_type->width() == 64) {
1551 double fa = a->GetDouble();
1552 double res = fp(fa);
1553 utils::FloatProxy<double> result(res);
1554 std::vector<uint32_t> words = result.GetWords();
1555 return const_mgr->GetConstant(result_type, words);
1556 }
1557 return nullptr;
1558 };
1559 }
1560
FoldFTranscendentalBinary(double (* fp)(double,double))1561 BinaryScalarFoldingRule FoldFTranscendentalBinary(double (*fp)(double,
1562 double)) {
1563 return
1564 [fp](const analysis::Type* result_type, const analysis::Constant* a,
1565 const analysis::Constant* b,
1566 analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
1567 assert(result_type != nullptr && a != nullptr);
1568 const analysis::Float* float_type = a->type()->AsFloat();
1569 assert(float_type != nullptr);
1570 assert(float_type == result_type->AsFloat());
1571 assert(float_type == b->type()->AsFloat());
1572 if (float_type->width() == 32) {
1573 float fa = a->GetFloat();
1574 float fb = b->GetFloat();
1575 float res = static_cast<float>(fp(fa, fb));
1576 utils::FloatProxy<float> result(res);
1577 std::vector<uint32_t> words = result.GetWords();
1578 return const_mgr->GetConstant(result_type, words);
1579 } else if (float_type->width() == 64) {
1580 double fa = a->GetDouble();
1581 double fb = b->GetDouble();
1582 double res = fp(fa, fb);
1583 utils::FloatProxy<double> result(res);
1584 std::vector<uint32_t> words = result.GetWords();
1585 return const_mgr->GetConstant(result_type, words);
1586 }
1587 return nullptr;
1588 };
1589 }
1590 } // namespace
1591
AddFoldingRules()1592 void ConstantFoldingRules::AddFoldingRules() {
1593 // Add all folding rules to the list for the opcodes to which they apply.
1594 // Note that the order in which rules are added to the list matters. If a rule
1595 // applies to the instruction, the rest of the rules will not be attempted.
1596 // Take that into consideration.
1597
1598 rules_[spv::Op::OpCompositeConstruct].push_back(FoldCompositeWithConstants());
1599
1600 rules_[spv::Op::OpCompositeExtract].push_back(FoldExtractWithConstants());
1601 rules_[spv::Op::OpCompositeInsert].push_back(FoldInsertWithConstants());
1602
1603 rules_[spv::Op::OpConvertFToS].push_back(FoldFToI());
1604 rules_[spv::Op::OpConvertFToU].push_back(FoldFToI());
1605 rules_[spv::Op::OpConvertSToF].push_back(FoldIToF());
1606 rules_[spv::Op::OpConvertUToF].push_back(FoldIToF());
1607
1608 rules_[spv::Op::OpDot].push_back(FoldOpDotWithConstants());
1609 rules_[spv::Op::OpFAdd].push_back(FoldFAdd());
1610 rules_[spv::Op::OpFDiv].push_back(FoldFDiv());
1611 rules_[spv::Op::OpFMul].push_back(FoldFMul());
1612 rules_[spv::Op::OpFSub].push_back(FoldFSub());
1613
1614 rules_[spv::Op::OpFOrdEqual].push_back(FoldFOrdEqual());
1615
1616 rules_[spv::Op::OpFUnordEqual].push_back(FoldFUnordEqual());
1617
1618 rules_[spv::Op::OpFOrdNotEqual].push_back(FoldFOrdNotEqual());
1619
1620 rules_[spv::Op::OpFUnordNotEqual].push_back(FoldFUnordNotEqual());
1621
1622 rules_[spv::Op::OpFOrdLessThan].push_back(FoldFOrdLessThan());
1623 rules_[spv::Op::OpFOrdLessThan].push_back(
1624 FoldFClampFeedingCompare(spv::Op::OpFOrdLessThan));
1625
1626 rules_[spv::Op::OpFUnordLessThan].push_back(FoldFUnordLessThan());
1627 rules_[spv::Op::OpFUnordLessThan].push_back(
1628 FoldFClampFeedingCompare(spv::Op::OpFUnordLessThan));
1629
1630 rules_[spv::Op::OpFOrdGreaterThan].push_back(FoldFOrdGreaterThan());
1631 rules_[spv::Op::OpFOrdGreaterThan].push_back(
1632 FoldFClampFeedingCompare(spv::Op::OpFOrdGreaterThan));
1633
1634 rules_[spv::Op::OpFUnordGreaterThan].push_back(FoldFUnordGreaterThan());
1635 rules_[spv::Op::OpFUnordGreaterThan].push_back(
1636 FoldFClampFeedingCompare(spv::Op::OpFUnordGreaterThan));
1637
1638 rules_[spv::Op::OpFOrdLessThanEqual].push_back(FoldFOrdLessThanEqual());
1639 rules_[spv::Op::OpFOrdLessThanEqual].push_back(
1640 FoldFClampFeedingCompare(spv::Op::OpFOrdLessThanEqual));
1641
1642 rules_[spv::Op::OpFUnordLessThanEqual].push_back(FoldFUnordLessThanEqual());
1643 rules_[spv::Op::OpFUnordLessThanEqual].push_back(
1644 FoldFClampFeedingCompare(spv::Op::OpFUnordLessThanEqual));
1645
1646 rules_[spv::Op::OpFOrdGreaterThanEqual].push_back(FoldFOrdGreaterThanEqual());
1647 rules_[spv::Op::OpFOrdGreaterThanEqual].push_back(
1648 FoldFClampFeedingCompare(spv::Op::OpFOrdGreaterThanEqual));
1649
1650 rules_[spv::Op::OpFUnordGreaterThanEqual].push_back(
1651 FoldFUnordGreaterThanEqual());
1652 rules_[spv::Op::OpFUnordGreaterThanEqual].push_back(
1653 FoldFClampFeedingCompare(spv::Op::OpFUnordGreaterThanEqual));
1654
1655 rules_[spv::Op::OpVectorShuffle].push_back(FoldVectorShuffleWithConstants());
1656 rules_[spv::Op::OpVectorTimesScalar].push_back(FoldVectorTimesScalar());
1657 rules_[spv::Op::OpVectorTimesMatrix].push_back(FoldVectorTimesMatrix());
1658 rules_[spv::Op::OpMatrixTimesVector].push_back(FoldMatrixTimesVector());
1659 rules_[spv::Op::OpTranspose].push_back(FoldTranspose);
1660
1661 rules_[spv::Op::OpFNegate].push_back(FoldFNegate());
1662 rules_[spv::Op::OpSNegate].push_back(FoldSNegate());
1663 rules_[spv::Op::OpQuantizeToF16].push_back(FoldQuantizeToF16());
1664
1665 // Add rules for GLSLstd450
1666 FeatureManager* feature_manager = context_->get_feature_mgr();
1667 uint32_t ext_inst_glslstd450_id =
1668 feature_manager->GetExtInstImportId_GLSLstd450();
1669 if (ext_inst_glslstd450_id != 0) {
1670 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMix}].push_back(FoldFMix());
1671 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SMin}].push_back(
1672 FoldFPBinaryOp(FoldMin));
1673 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UMin}].push_back(
1674 FoldFPBinaryOp(FoldMin));
1675 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMin}].push_back(
1676 FoldFPBinaryOp(FoldMin));
1677 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SMax}].push_back(
1678 FoldFPBinaryOp(FoldMax));
1679 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UMax}].push_back(
1680 FoldFPBinaryOp(FoldMax));
1681 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMax}].push_back(
1682 FoldFPBinaryOp(FoldMax));
1683 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UClamp}].push_back(
1684 FoldClamp1);
1685 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UClamp}].push_back(
1686 FoldClamp2);
1687 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UClamp}].push_back(
1688 FoldClamp3);
1689 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SClamp}].push_back(
1690 FoldClamp1);
1691 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SClamp}].push_back(
1692 FoldClamp2);
1693 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SClamp}].push_back(
1694 FoldClamp3);
1695 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FClamp}].push_back(
1696 FoldClamp1);
1697 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FClamp}].push_back(
1698 FoldClamp2);
1699 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FClamp}].push_back(
1700 FoldClamp3);
1701 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Sin}].push_back(
1702 FoldFPUnaryOp(FoldFTranscendentalUnary(std::sin)));
1703 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Cos}].push_back(
1704 FoldFPUnaryOp(FoldFTranscendentalUnary(std::cos)));
1705 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Tan}].push_back(
1706 FoldFPUnaryOp(FoldFTranscendentalUnary(std::tan)));
1707 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Asin}].push_back(
1708 FoldFPUnaryOp(FoldFTranscendentalUnary(std::asin)));
1709 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Acos}].push_back(
1710 FoldFPUnaryOp(FoldFTranscendentalUnary(std::acos)));
1711 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Atan}].push_back(
1712 FoldFPUnaryOp(FoldFTranscendentalUnary(std::atan)));
1713 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Exp}].push_back(
1714 FoldFPUnaryOp(FoldFTranscendentalUnary(std::exp)));
1715 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Log}].push_back(
1716 FoldFPUnaryOp(FoldFTranscendentalUnary(std::log)));
1717
1718 #ifdef __ANDROID__
1719 // Android NDK r15c targeting ABI 15 doesn't have full support for C++11
1720 // (no std::exp2/log2). ::exp2 is available from C99 but ::log2 isn't
1721 // available up until ABI 18 so we use a shim
1722 auto log2_shim = [](double v) -> double { return log(v) / log(2.0); };
1723 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Exp2}].push_back(
1724 FoldFPUnaryOp(FoldFTranscendentalUnary(::exp2)));
1725 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Log2}].push_back(
1726 FoldFPUnaryOp(FoldFTranscendentalUnary(log2_shim)));
1727 #else
1728 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Exp2}].push_back(
1729 FoldFPUnaryOp(FoldFTranscendentalUnary(std::exp2)));
1730 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Log2}].push_back(
1731 FoldFPUnaryOp(FoldFTranscendentalUnary(std::log2)));
1732 #endif
1733
1734 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Sqrt}].push_back(
1735 FoldFPUnaryOp(FoldFTranscendentalUnary(std::sqrt)));
1736 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Atan2}].push_back(
1737 FoldFPBinaryOp(FoldFTranscendentalBinary(std::atan2)));
1738 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Pow}].push_back(
1739 FoldFPBinaryOp(FoldFTranscendentalBinary(std::pow)));
1740 }
1741 }
1742 } // namespace opt
1743 } // namespace spvtools
1744