1 // Copyright (c) 2018 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/opt/const_folding_rules.h"
16
17 #include "source/opt/ir_context.h"
18
19 namespace spvtools {
20 namespace opt {
21 namespace {
// Position, within the |constants| vector handed to the OpCompositeExtract
// folding rule, of the composite operand being extracted from.
constexpr uint32_t kExtractCompositeIdInIdx{0};
23
// Returns the value obtained by extracting the |number_of_bits| least
// significant bits from |value|, and sign-extending it to 64-bits.
//
// A width of 0 selects no bits, so the result is 0. A width of 64 or more
// returns |value| unchanged, because every bit is already significant.
uint64_t SignExtendValue(uint64_t value, uint32_t number_of_bits) {
  // Guard the shifts below: shifting a 64-bit value by 64 or more positions
  // is undefined, and |number_of_bits - 1| wraps around when the width is 0.
  if (number_of_bits == 0) return 0;
  if (number_of_bits >= 64) return value;

  const uint64_t sign_bit = 1ull << (number_of_bits - 1);
  const uint64_t low_bits = (sign_bit << 1) - 1ull;

  // A set sign bit means the upper bits are filled with ones; otherwise they
  // are cleared.
  return (value & sign_bit) ? (value | ~low_bits) : (value & low_bits);
}
40
// Returns the value obtained by extracting the |number_of_bits| least
// significant bits from |value|, and zero-extending it to 64-bits.
//
// A width of 64 or more returns |value| unchanged, because every bit is
// already significant.
uint64_t ZeroExtendValue(uint64_t value, uint32_t number_of_bits) {
  // Shifting a 64-bit value by 64 or more positions is undefined, so wide
  // requests are answered before the mask is built.
  if (number_of_bits >= 64) return value;

  const uint64_t bits_to_keep = (1ull << number_of_bits) - 1;
  return value & bits_to_keep;
}
51
52 // Returns a constant whose value is `value` and type is `type`. This constant
53 // will be generated by `const_mgr`. The type must be a scalar integer type.
GenerateIntegerConstant(const analysis::Integer * integer_type,uint64_t result,analysis::ConstantManager * const_mgr)54 const analysis::Constant* GenerateIntegerConstant(
55 const analysis::Integer* integer_type, uint64_t result,
56 analysis::ConstantManager* const_mgr) {
57 assert(integer_type != nullptr);
58
59 std::vector<uint32_t> words;
60 if (integer_type->width() == 64) {
61 // In the 64-bit case, two words are needed to represent the value.
62 words = {static_cast<uint32_t>(result),
63 static_cast<uint32_t>(result >> 32)};
64 } else {
65 // In all other cases, only a single word is needed.
66 assert(integer_type->width() <= 32);
67 if (integer_type->IsSigned()) {
68 result = SignExtendValue(result, integer_type->width());
69 } else {
70 result = ZeroExtendValue(result, integer_type->width());
71 }
72 words = {static_cast<uint32_t>(result)};
73 }
74 return const_mgr->GetConstant(integer_type, words);
75 }
76
77 // Returns a constants with the value NaN of the given type. Only works for
78 // 32-bit and 64-bit float point types. Returns |nullptr| if an error occurs.
GetNan(const analysis::Type * type,analysis::ConstantManager * const_mgr)79 const analysis::Constant* GetNan(const analysis::Type* type,
80 analysis::ConstantManager* const_mgr) {
81 const analysis::Float* float_type = type->AsFloat();
82 if (float_type == nullptr) {
83 return nullptr;
84 }
85
86 switch (float_type->width()) {
87 case 32:
88 return const_mgr->GetFloatConst(std::numeric_limits<float>::quiet_NaN());
89 case 64:
90 return const_mgr->GetDoubleConst(
91 std::numeric_limits<double>::quiet_NaN());
92 default:
93 return nullptr;
94 }
95 }
96
97 // Returns a constants with the value INF of the given type. Only works for
98 // 32-bit and 64-bit float point types. Returns |nullptr| if an error occurs.
GetInf(const analysis::Type * type,analysis::ConstantManager * const_mgr)99 const analysis::Constant* GetInf(const analysis::Type* type,
100 analysis::ConstantManager* const_mgr) {
101 const analysis::Float* float_type = type->AsFloat();
102 if (float_type == nullptr) {
103 return nullptr;
104 }
105
106 switch (float_type->width()) {
107 case 32:
108 return const_mgr->GetFloatConst(std::numeric_limits<float>::infinity());
109 case 64:
110 return const_mgr->GetDoubleConst(std::numeric_limits<double>::infinity());
111 default:
112 return nullptr;
113 }
114 }
115
116 // Returns true if |type| is Float or a vector of Float.
HasFloatingPoint(const analysis::Type * type)117 bool HasFloatingPoint(const analysis::Type* type) {
118 if (type->AsFloat()) {
119 return true;
120 } else if (const analysis::Vector* vec_type = type->AsVector()) {
121 return vec_type->element_type()->AsFloat() != nullptr;
122 }
123
124 return false;
125 }
126
127 // Returns a constants with the value |-val| of the given type. Only works for
128 // 32-bit and 64-bit float point types. Returns |nullptr| if an error occurs.
NegateFPConst(const analysis::Type * result_type,const analysis::Constant * val,analysis::ConstantManager * const_mgr)129 const analysis::Constant* NegateFPConst(const analysis::Type* result_type,
130 const analysis::Constant* val,
131 analysis::ConstantManager* const_mgr) {
132 const analysis::Float* float_type = result_type->AsFloat();
133 assert(float_type != nullptr);
134 if (float_type->width() == 32) {
135 float fa = val->GetFloat();
136 return const_mgr->GetFloatConst(-fa);
137 } else if (float_type->width() == 64) {
138 double da = val->GetDouble();
139 return const_mgr->GetDoubleConst(-da);
140 }
141 return nullptr;
142 }
143
144 // Returns a constants with the value |-val| of the given type.
NegateIntConst(const analysis::Type * result_type,const analysis::Constant * val,analysis::ConstantManager * const_mgr)145 const analysis::Constant* NegateIntConst(const analysis::Type* result_type,
146 const analysis::Constant* val,
147 analysis::ConstantManager* const_mgr) {
148 const analysis::Integer* int_type = result_type->AsInteger();
149 assert(int_type != nullptr);
150
151 if (val->AsNullConstant()) {
152 return val;
153 }
154
155 uint64_t new_value = static_cast<uint64_t>(-val->GetSignExtendedValue());
156 return const_mgr->GetIntConst(new_value, int_type->width(),
157 int_type->IsSigned());
158 }
159
160 // Folds an OpcompositeExtract where input is a composite constant.
FoldExtractWithConstants()161 ConstantFoldingRule FoldExtractWithConstants() {
162 return [](IRContext* context, Instruction* inst,
163 const std::vector<const analysis::Constant*>& constants)
164 -> const analysis::Constant* {
165 const analysis::Constant* c = constants[kExtractCompositeIdInIdx];
166 if (c == nullptr) {
167 return nullptr;
168 }
169
170 for (uint32_t i = 1; i < inst->NumInOperands(); ++i) {
171 uint32_t element_index = inst->GetSingleWordInOperand(i);
172 if (c->AsNullConstant()) {
173 // Return Null for the return type.
174 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
175 analysis::TypeManager* type_mgr = context->get_type_mgr();
176 return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), {});
177 }
178
179 auto cc = c->AsCompositeConstant();
180 assert(cc != nullptr);
181 auto components = cc->GetComponents();
182 // Protect against invalid IR. Refuse to fold if the index is out
183 // of bounds.
184 if (element_index >= components.size()) return nullptr;
185 c = components[element_index];
186 }
187 return c;
188 };
189 }
190
191 // Folds an OpcompositeInsert where input is a composite constant.
FoldInsertWithConstants()192 ConstantFoldingRule FoldInsertWithConstants() {
193 return [](IRContext* context, Instruction* inst,
194 const std::vector<const analysis::Constant*>& constants)
195 -> const analysis::Constant* {
196 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
197 const analysis::Constant* object = constants[0];
198 const analysis::Constant* composite = constants[1];
199 if (object == nullptr || composite == nullptr) {
200 return nullptr;
201 }
202
203 // If there is more than 1 index, then each additional constant used by the
204 // index will need to be recreated to use the inserted object.
205 std::vector<const analysis::Constant*> chain;
206 std::vector<const analysis::Constant*> components;
207 const analysis::Type* type = nullptr;
208 const uint32_t final_index = (inst->NumInOperands() - 1);
209
210 // Work down hierarchy of all indexes
211 for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
212 type = composite->type();
213
214 if (composite->AsNullConstant()) {
215 // Make new composite so it can be inserted in the index with the
216 // non-null value
217 if (const auto new_composite =
218 const_mgr->GetNullCompositeConstant(type)) {
219 // Keep track of any indexes along the way to last index
220 if (i != final_index) {
221 chain.push_back(new_composite);
222 }
223 components = new_composite->AsCompositeConstant()->GetComponents();
224 } else {
225 // Unsupported input type (such as structs)
226 return nullptr;
227 }
228 } else {
229 // Keep track of any indexes along the way to last index
230 if (i != final_index) {
231 chain.push_back(composite);
232 }
233 components = composite->AsCompositeConstant()->GetComponents();
234 }
235 const uint32_t index = inst->GetSingleWordInOperand(i);
236 composite = components[index];
237 }
238
239 // Final index in hierarchy is inserted with new object.
240 const uint32_t final_operand = inst->GetSingleWordInOperand(final_index);
241 std::vector<uint32_t> ids;
242 for (size_t i = 0; i < components.size(); i++) {
243 const analysis::Constant* constant =
244 (i == final_operand) ? object : components[i];
245 Instruction* member_inst = const_mgr->GetDefiningInstruction(constant);
246 ids.push_back(member_inst->result_id());
247 }
248 const analysis::Constant* new_constant = const_mgr->GetConstant(type, ids);
249
250 // Work backwards up the chain and replace each index with new constant.
251 for (size_t i = chain.size(); i > 0; i--) {
252 // Need to insert any previous instruction into the module first.
253 // Can't just insert in types_values_begin() because it will move above
254 // where the types are declared.
255 // Can't compare with location of inst because not all new added
256 // instructions are added to types_values_
257 auto iter = context->types_values_end();
258 Module::inst_iterator* pos = &iter;
259 const_mgr->BuildInstructionAndAddToModule(new_constant, pos);
260
261 composite = chain[i - 1];
262 components = composite->AsCompositeConstant()->GetComponents();
263 type = composite->type();
264 ids.clear();
265 for (size_t k = 0; k < components.size(); k++) {
266 const uint32_t index =
267 inst->GetSingleWordInOperand(1 + static_cast<uint32_t>(i));
268 const analysis::Constant* constant =
269 (k == index) ? new_constant : components[k];
270 const uint32_t constant_id =
271 const_mgr->FindDeclaredConstant(constant, 0);
272 ids.push_back(constant_id);
273 }
274 new_constant = const_mgr->GetConstant(type, ids);
275 }
276
277 // If multiple constants were created, only need to return the top index.
278 return new_constant;
279 };
280 }
281
FoldVectorShuffleWithConstants()282 ConstantFoldingRule FoldVectorShuffleWithConstants() {
283 return [](IRContext* context, Instruction* inst,
284 const std::vector<const analysis::Constant*>& constants)
285 -> const analysis::Constant* {
286 assert(inst->opcode() == spv::Op::OpVectorShuffle);
287 const analysis::Constant* c1 = constants[0];
288 const analysis::Constant* c2 = constants[1];
289 if (c1 == nullptr || c2 == nullptr) {
290 return nullptr;
291 }
292
293 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
294 const analysis::Type* element_type = c1->type()->AsVector()->element_type();
295
296 std::vector<const analysis::Constant*> c1_components;
297 if (const analysis::VectorConstant* vec_const = c1->AsVectorConstant()) {
298 c1_components = vec_const->GetComponents();
299 } else {
300 assert(c1->AsNullConstant());
301 const analysis::Constant* element =
302 const_mgr->GetConstant(element_type, {});
303 c1_components.resize(c1->type()->AsVector()->element_count(), element);
304 }
305 std::vector<const analysis::Constant*> c2_components;
306 if (const analysis::VectorConstant* vec_const = c2->AsVectorConstant()) {
307 c2_components = vec_const->GetComponents();
308 } else {
309 assert(c2->AsNullConstant());
310 const analysis::Constant* element =
311 const_mgr->GetConstant(element_type, {});
312 c2_components.resize(c2->type()->AsVector()->element_count(), element);
313 }
314
315 std::vector<uint32_t> ids;
316 const uint32_t undef_literal_value = 0xffffffff;
317 for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
318 uint32_t index = inst->GetSingleWordInOperand(i);
319 if (index == undef_literal_value) {
320 // Don't fold shuffle with undef literal value.
321 return nullptr;
322 } else if (index < c1_components.size()) {
323 Instruction* member_inst =
324 const_mgr->GetDefiningInstruction(c1_components[index]);
325 ids.push_back(member_inst->result_id());
326 } else {
327 Instruction* member_inst = const_mgr->GetDefiningInstruction(
328 c2_components[index - c1_components.size()]);
329 ids.push_back(member_inst->result_id());
330 }
331 }
332
333 analysis::TypeManager* type_mgr = context->get_type_mgr();
334 return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), ids);
335 };
336 }
337
FoldVectorTimesScalar()338 ConstantFoldingRule FoldVectorTimesScalar() {
339 return [](IRContext* context, Instruction* inst,
340 const std::vector<const analysis::Constant*>& constants)
341 -> const analysis::Constant* {
342 assert(inst->opcode() == spv::Op::OpVectorTimesScalar);
343 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
344 analysis::TypeManager* type_mgr = context->get_type_mgr();
345
346 if (!inst->IsFloatingPointFoldingAllowed()) {
347 if (HasFloatingPoint(type_mgr->GetType(inst->type_id()))) {
348 return nullptr;
349 }
350 }
351
352 const analysis::Constant* c1 = constants[0];
353 const analysis::Constant* c2 = constants[1];
354
355 if (c1 && c1->IsZero()) {
356 return c1;
357 }
358
359 if (c2 && c2->IsZero()) {
360 // Get or create the NullConstant for this type.
361 std::vector<uint32_t> ids;
362 return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), ids);
363 }
364
365 if (c1 == nullptr || c2 == nullptr) {
366 return nullptr;
367 }
368
369 // Check result type.
370 const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
371 const analysis::Vector* vector_type = result_type->AsVector();
372 assert(vector_type != nullptr);
373 const analysis::Type* element_type = vector_type->element_type();
374 assert(element_type != nullptr);
375 const analysis::Float* float_type = element_type->AsFloat();
376 assert(float_type != nullptr);
377
378 // Check types of c1 and c2.
379 assert(c1->type()->AsVector() == vector_type);
380 assert(c1->type()->AsVector()->element_type() == element_type &&
381 c2->type() == element_type);
382
383 // Get a float vector that is the result of vector-times-scalar.
384 std::vector<const analysis::Constant*> c1_components =
385 c1->GetVectorComponents(const_mgr);
386 std::vector<uint32_t> ids;
387 if (float_type->width() == 32) {
388 float scalar = c2->GetFloat();
389 for (uint32_t i = 0; i < c1_components.size(); ++i) {
390 utils::FloatProxy<float> result(c1_components[i]->GetFloat() * scalar);
391 std::vector<uint32_t> words = result.GetWords();
392 const analysis::Constant* new_elem =
393 const_mgr->GetConstant(float_type, words);
394 ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
395 }
396 return const_mgr->GetConstant(vector_type, ids);
397 } else if (float_type->width() == 64) {
398 double scalar = c2->GetDouble();
399 for (uint32_t i = 0; i < c1_components.size(); ++i) {
400 utils::FloatProxy<double> result(c1_components[i]->GetDouble() *
401 scalar);
402 std::vector<uint32_t> words = result.GetWords();
403 const analysis::Constant* new_elem =
404 const_mgr->GetConstant(float_type, words);
405 ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
406 }
407 return const_mgr->GetConstant(vector_type, ids);
408 }
409 return nullptr;
410 };
411 }
412
413 // Returns to the constant that results from tranposing |matrix|. The result
414 // will have type |result_type|, and |matrix| must exist in |context|. The
415 // result constant will also exist in |context|.
TransposeMatrix(const analysis::Constant * matrix,analysis::Matrix * result_type,IRContext * context)416 const analysis::Constant* TransposeMatrix(const analysis::Constant* matrix,
417 analysis::Matrix* result_type,
418 IRContext* context) {
419 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
420 if (matrix->AsNullConstant() != nullptr) {
421 return const_mgr->GetNullCompositeConstant(result_type);
422 }
423
424 const auto& columns = matrix->AsMatrixConstant()->GetComponents();
425 uint32_t number_of_rows = columns[0]->type()->AsVector()->element_count();
426
427 // Collect the ids of the elements in their new positions.
428 std::vector<std::vector<uint32_t>> result_elements(number_of_rows);
429 for (const analysis::Constant* column : columns) {
430 if (column->AsNullConstant()) {
431 column = const_mgr->GetNullCompositeConstant(column->type());
432 }
433 const auto& column_components = column->AsVectorConstant()->GetComponents();
434
435 for (uint32_t row = 0; row < number_of_rows; ++row) {
436 result_elements[row].push_back(
437 const_mgr->GetDefiningInstruction(column_components[row])
438 ->result_id());
439 }
440 }
441
442 // Create the constant for each row in the result, and collect the ids.
443 std::vector<uint32_t> result_columns(number_of_rows);
444 for (uint32_t col = 0; col < number_of_rows; ++col) {
445 auto* element = const_mgr->GetConstant(result_type->element_type(),
446 result_elements[col]);
447 result_columns[col] =
448 const_mgr->GetDefiningInstruction(element)->result_id();
449 }
450
451 // Create the matrix constant from the row ids, and return it.
452 return const_mgr->GetConstant(result_type, result_columns);
453 }
454
FoldTranspose(IRContext * context,Instruction * inst,const std::vector<const analysis::Constant * > & constants)455 const analysis::Constant* FoldTranspose(
456 IRContext* context, Instruction* inst,
457 const std::vector<const analysis::Constant*>& constants) {
458 assert(inst->opcode() == spv::Op::OpTranspose);
459
460 analysis::TypeManager* type_mgr = context->get_type_mgr();
461 if (!inst->IsFloatingPointFoldingAllowed()) {
462 if (HasFloatingPoint(type_mgr->GetType(inst->type_id()))) {
463 return nullptr;
464 }
465 }
466
467 const analysis::Constant* matrix = constants[0];
468 if (matrix == nullptr) {
469 return nullptr;
470 }
471
472 auto* result_type = type_mgr->GetType(inst->type_id());
473 return TransposeMatrix(matrix, result_type->AsMatrix(), context);
474 }
475
FoldVectorTimesMatrix()476 ConstantFoldingRule FoldVectorTimesMatrix() {
477 return [](IRContext* context, Instruction* inst,
478 const std::vector<const analysis::Constant*>& constants)
479 -> const analysis::Constant* {
480 assert(inst->opcode() == spv::Op::OpVectorTimesMatrix);
481 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
482 analysis::TypeManager* type_mgr = context->get_type_mgr();
483
484 if (!inst->IsFloatingPointFoldingAllowed()) {
485 if (HasFloatingPoint(type_mgr->GetType(inst->type_id()))) {
486 return nullptr;
487 }
488 }
489
490 const analysis::Constant* c1 = constants[0];
491 const analysis::Constant* c2 = constants[1];
492
493 if (c1 == nullptr || c2 == nullptr) {
494 return nullptr;
495 }
496
497 // Check result type.
498 const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
499 const analysis::Vector* vector_type = result_type->AsVector();
500 assert(vector_type != nullptr);
501 const analysis::Type* element_type = vector_type->element_type();
502 assert(element_type != nullptr);
503 const analysis::Float* float_type = element_type->AsFloat();
504 assert(float_type != nullptr);
505
506 // Check types of c1 and c2.
507 assert(c1->type()->AsVector() == vector_type);
508 assert(c1->type()->AsVector()->element_type() == element_type &&
509 c2->type()->AsMatrix()->element_type() == vector_type);
510
511 uint32_t resultVectorSize = result_type->AsVector()->element_count();
512 std::vector<uint32_t> ids;
513
514 if ((c1 && c1->IsZero()) || (c2 && c2->IsZero())) {
515 std::vector<uint32_t> words(float_type->width() / 32, 0);
516 for (uint32_t i = 0; i < resultVectorSize; ++i) {
517 const analysis::Constant* new_elem =
518 const_mgr->GetConstant(float_type, words);
519 ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
520 }
521 return const_mgr->GetConstant(vector_type, ids);
522 }
523
524 // Get a float vector that is the result of vector-times-matrix.
525 std::vector<const analysis::Constant*> c1_components =
526 c1->GetVectorComponents(const_mgr);
527 std::vector<const analysis::Constant*> c2_components =
528 c2->AsMatrixConstant()->GetComponents();
529
530 if (float_type->width() == 32) {
531 for (uint32_t i = 0; i < resultVectorSize; ++i) {
532 float result_scalar = 0.0f;
533 if (!c2_components[i]->AsNullConstant()) {
534 const analysis::VectorConstant* c2_vec =
535 c2_components[i]->AsVectorConstant();
536 for (uint32_t j = 0; j < c2_vec->GetComponents().size(); ++j) {
537 float c1_scalar = c1_components[j]->GetFloat();
538 float c2_scalar = c2_vec->GetComponents()[j]->GetFloat();
539 result_scalar += c1_scalar * c2_scalar;
540 }
541 }
542 utils::FloatProxy<float> result(result_scalar);
543 std::vector<uint32_t> words = result.GetWords();
544 const analysis::Constant* new_elem =
545 const_mgr->GetConstant(float_type, words);
546 ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
547 }
548 return const_mgr->GetConstant(vector_type, ids);
549 } else if (float_type->width() == 64) {
550 for (uint32_t i = 0; i < c2_components.size(); ++i) {
551 double result_scalar = 0.0;
552 if (!c2_components[i]->AsNullConstant()) {
553 const analysis::VectorConstant* c2_vec =
554 c2_components[i]->AsVectorConstant();
555 for (uint32_t j = 0; j < c2_vec->GetComponents().size(); ++j) {
556 double c1_scalar = c1_components[j]->GetDouble();
557 double c2_scalar = c2_vec->GetComponents()[j]->GetDouble();
558 result_scalar += c1_scalar * c2_scalar;
559 }
560 }
561 utils::FloatProxy<double> result(result_scalar);
562 std::vector<uint32_t> words = result.GetWords();
563 const analysis::Constant* new_elem =
564 const_mgr->GetConstant(float_type, words);
565 ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
566 }
567 return const_mgr->GetConstant(vector_type, ids);
568 }
569 return nullptr;
570 };
571 }
572
FoldMatrixTimesVector()573 ConstantFoldingRule FoldMatrixTimesVector() {
574 return [](IRContext* context, Instruction* inst,
575 const std::vector<const analysis::Constant*>& constants)
576 -> const analysis::Constant* {
577 assert(inst->opcode() == spv::Op::OpMatrixTimesVector);
578 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
579 analysis::TypeManager* type_mgr = context->get_type_mgr();
580
581 if (!inst->IsFloatingPointFoldingAllowed()) {
582 if (HasFloatingPoint(type_mgr->GetType(inst->type_id()))) {
583 return nullptr;
584 }
585 }
586
587 const analysis::Constant* c1 = constants[0];
588 const analysis::Constant* c2 = constants[1];
589
590 if (c1 == nullptr || c2 == nullptr) {
591 return nullptr;
592 }
593
594 // Check result type.
595 const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
596 const analysis::Vector* vector_type = result_type->AsVector();
597 assert(vector_type != nullptr);
598 const analysis::Type* element_type = vector_type->element_type();
599 assert(element_type != nullptr);
600 const analysis::Float* float_type = element_type->AsFloat();
601 assert(float_type != nullptr);
602
603 // Check types of c1 and c2.
604 assert(c1->type()->AsMatrix()->element_type() == vector_type);
605 assert(c2->type()->AsVector()->element_type() == element_type);
606
607 uint32_t resultVectorSize = result_type->AsVector()->element_count();
608 std::vector<uint32_t> ids;
609
610 if ((c1 && c1->IsZero()) || (c2 && c2->IsZero())) {
611 std::vector<uint32_t> words(float_type->width() / 32, 0);
612 for (uint32_t i = 0; i < resultVectorSize; ++i) {
613 const analysis::Constant* new_elem =
614 const_mgr->GetConstant(float_type, words);
615 ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
616 }
617 return const_mgr->GetConstant(vector_type, ids);
618 }
619
620 // Get a float vector that is the result of matrix-times-vector.
621 std::vector<const analysis::Constant*> c1_components =
622 c1->AsMatrixConstant()->GetComponents();
623 std::vector<const analysis::Constant*> c2_components =
624 c2->GetVectorComponents(const_mgr);
625
626 if (float_type->width() == 32) {
627 for (uint32_t i = 0; i < resultVectorSize; ++i) {
628 float result_scalar = 0.0f;
629 for (uint32_t j = 0; j < c1_components.size(); ++j) {
630 if (!c1_components[j]->AsNullConstant()) {
631 float c1_scalar = c1_components[j]
632 ->AsVectorConstant()
633 ->GetComponents()[i]
634 ->GetFloat();
635 float c2_scalar = c2_components[j]->GetFloat();
636 result_scalar += c1_scalar * c2_scalar;
637 }
638 }
639 utils::FloatProxy<float> result(result_scalar);
640 std::vector<uint32_t> words = result.GetWords();
641 const analysis::Constant* new_elem =
642 const_mgr->GetConstant(float_type, words);
643 ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
644 }
645 return const_mgr->GetConstant(vector_type, ids);
646 } else if (float_type->width() == 64) {
647 for (uint32_t i = 0; i < resultVectorSize; ++i) {
648 double result_scalar = 0.0;
649 for (uint32_t j = 0; j < c1_components.size(); ++j) {
650 if (!c1_components[j]->AsNullConstant()) {
651 double c1_scalar = c1_components[j]
652 ->AsVectorConstant()
653 ->GetComponents()[i]
654 ->GetDouble();
655 double c2_scalar = c2_components[j]->GetDouble();
656 result_scalar += c1_scalar * c2_scalar;
657 }
658 }
659 utils::FloatProxy<double> result(result_scalar);
660 std::vector<uint32_t> words = result.GetWords();
661 const analysis::Constant* new_elem =
662 const_mgr->GetConstant(float_type, words);
663 ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
664 }
665 return const_mgr->GetConstant(vector_type, ids);
666 }
667 return nullptr;
668 };
669 }
670
FoldCompositeWithConstants()671 ConstantFoldingRule FoldCompositeWithConstants() {
672 // Folds an OpCompositeConstruct where all of the inputs are constants to a
673 // constant. A new constant is created if necessary.
674 return [](IRContext* context, Instruction* inst,
675 const std::vector<const analysis::Constant*>& constants)
676 -> const analysis::Constant* {
677 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
678 analysis::TypeManager* type_mgr = context->get_type_mgr();
679 const analysis::Type* new_type = type_mgr->GetType(inst->type_id());
680 Instruction* type_inst =
681 context->get_def_use_mgr()->GetDef(inst->type_id());
682
683 std::vector<uint32_t> ids;
684 for (uint32_t i = 0; i < constants.size(); ++i) {
685 const analysis::Constant* element_const = constants[i];
686 if (element_const == nullptr) {
687 return nullptr;
688 }
689
690 uint32_t component_type_id = 0;
691 if (type_inst->opcode() == spv::Op::OpTypeStruct) {
692 component_type_id = type_inst->GetSingleWordInOperand(i);
693 } else if (type_inst->opcode() == spv::Op::OpTypeArray) {
694 component_type_id = type_inst->GetSingleWordInOperand(0);
695 }
696
697 uint32_t element_id =
698 const_mgr->FindDeclaredConstant(element_const, component_type_id);
699 if (element_id == 0) {
700 return nullptr;
701 }
702 ids.push_back(element_id);
703 }
704 return const_mgr->GetConstant(new_type, ids);
705 };
706 }
707
708 // The interface for a function that returns the result of applying a scalar
709 // floating-point binary operation on |a| and |b|. The type of the return value
710 // will be |type|. The input constants must also be of type |type|.
711 using UnaryScalarFoldingRule = std::function<const analysis::Constant*(
712 const analysis::Type* result_type, const analysis::Constant* a,
713 analysis::ConstantManager*)>;
714
715 // The interface for a function that returns the result of applying a scalar
716 // floating-point binary operation on |a| and |b|. The type of the return value
717 // will be |type|. The input constants must also be of type |type|.
718 using BinaryScalarFoldingRule = std::function<const analysis::Constant*(
719 const analysis::Type* result_type, const analysis::Constant* a,
720 const analysis::Constant* b, analysis::ConstantManager*)>;
721
722 // Returns a |ConstantFoldingRule| that folds unary scalar ops
723 // using |scalar_rule| and unary vectors ops by applying
724 // |scalar_rule| to the elements of the vector. The |ConstantFoldingRule|
725 // that is returned assumes that |constants| contains 1 entry. If they are
726 // not |nullptr|, then their type is either |Float| or |Integer| or a |Vector|
727 // whose element type is |Float| or |Integer|.
FoldUnaryOp(UnaryScalarFoldingRule scalar_rule)728 ConstantFoldingRule FoldUnaryOp(UnaryScalarFoldingRule scalar_rule) {
729 return [scalar_rule](IRContext* context, Instruction* inst,
730 const std::vector<const analysis::Constant*>& constants)
731 -> const analysis::Constant* {
732 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
733 analysis::TypeManager* type_mgr = context->get_type_mgr();
734 const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
735 const analysis::Vector* vector_type = result_type->AsVector();
736
737 const analysis::Constant* arg =
738 (inst->opcode() == spv::Op::OpExtInst) ? constants[1] : constants[0];
739
740 if (arg == nullptr) {
741 return nullptr;
742 }
743
744 if (vector_type != nullptr) {
745 std::vector<const analysis::Constant*> a_components;
746 std::vector<const analysis::Constant*> results_components;
747
748 a_components = arg->GetVectorComponents(const_mgr);
749
750 // Fold each component of the vector.
751 for (uint32_t i = 0; i < a_components.size(); ++i) {
752 results_components.push_back(scalar_rule(vector_type->element_type(),
753 a_components[i], const_mgr));
754 if (results_components[i] == nullptr) {
755 return nullptr;
756 }
757 }
758
759 // Build the constant object and return it.
760 std::vector<uint32_t> ids;
761 for (const analysis::Constant* member : results_components) {
762 ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id());
763 }
764 return const_mgr->GetConstant(vector_type, ids);
765 } else {
766 return scalar_rule(result_type, arg, const_mgr);
767 }
768 };
769 }
770
771 // Returns a |ConstantFoldingRule| that folds binary scalar ops
772 // using |scalar_rule| and binary vectors ops by applying
773 // |scalar_rule| to the elements of the vector. The folding rule assumes that op
774 // has two inputs. For regular instruction, those are in operands 0 and 1. For
775 // extended instruction, they are in operands 1 and 2. If an element in
776 // |constants| is not nullprt, then the constant's type is |Float|, |Integer|,
777 // or |Vector| whose element type is |Float| or |Integer|.
FoldBinaryOp(BinaryScalarFoldingRule scalar_rule)778 ConstantFoldingRule FoldBinaryOp(BinaryScalarFoldingRule scalar_rule) {
779 return [scalar_rule](IRContext* context, Instruction* inst,
780 const std::vector<const analysis::Constant*>& constants)
781 -> const analysis::Constant* {
782 assert(constants.size() == inst->NumInOperands());
783 assert(constants.size() == (inst->opcode() == spv::Op::OpExtInst ? 3 : 2));
784 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
785 analysis::TypeManager* type_mgr = context->get_type_mgr();
786 const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
787 const analysis::Vector* vector_type = result_type->AsVector();
788
789 const analysis::Constant* arg1 =
790 (inst->opcode() == spv::Op::OpExtInst) ? constants[1] : constants[0];
791 const analysis::Constant* arg2 =
792 (inst->opcode() == spv::Op::OpExtInst) ? constants[2] : constants[1];
793
794 if (arg1 == nullptr || arg2 == nullptr) {
795 return nullptr;
796 }
797
798 if (vector_type == nullptr) {
799 return scalar_rule(result_type, arg1, arg2, const_mgr);
800 }
801
802 std::vector<const analysis::Constant*> a_components;
803 std::vector<const analysis::Constant*> b_components;
804 std::vector<const analysis::Constant*> results_components;
805
806 a_components = arg1->GetVectorComponents(const_mgr);
807 b_components = arg2->GetVectorComponents(const_mgr);
808 assert(a_components.size() == b_components.size());
809
810 // Fold each component of the vector.
811 for (uint32_t i = 0; i < a_components.size(); ++i) {
812 results_components.push_back(scalar_rule(vector_type->element_type(),
813 a_components[i], b_components[i],
814 const_mgr));
815 if (results_components[i] == nullptr) {
816 return nullptr;
817 }
818 }
819
820 // Build the constant object and return it.
821 std::vector<uint32_t> ids;
822 for (const analysis::Constant* member : results_components) {
823 ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id());
824 }
825 return const_mgr->GetConstant(vector_type, ids);
826 };
827 }
828
829 // Returns a |ConstantFoldingRule| that folds unary floating point scalar ops
830 // using |scalar_rule| and unary float point vectors ops by applying
831 // |scalar_rule| to the elements of the vector. The |ConstantFoldingRule|
832 // that is returned assumes that |constants| contains 1 entry. If they are
833 // not |nullptr|, then their type is either |Float| or |Integer| or a |Vector|
834 // whose element type is |Float| or |Integer|.
FoldFPUnaryOp(UnaryScalarFoldingRule scalar_rule)835 ConstantFoldingRule FoldFPUnaryOp(UnaryScalarFoldingRule scalar_rule) {
836 auto folding_rule = FoldUnaryOp(scalar_rule);
837 return [folding_rule](IRContext* context, Instruction* inst,
838 const std::vector<const analysis::Constant*>& constants)
839 -> const analysis::Constant* {
840 if (!inst->IsFloatingPointFoldingAllowed()) {
841 return nullptr;
842 }
843
844 return folding_rule(context, inst, constants);
845 };
846 }
847
848 // Returns the result of folding the constants in |constants| according the
849 // |scalar_rule|. If |result_type| is a vector, then |scalar_rule| is applied
850 // per component.
FoldFPBinaryOp(BinaryScalarFoldingRule scalar_rule,uint32_t result_type_id,const std::vector<const analysis::Constant * > & constants,IRContext * context)851 const analysis::Constant* FoldFPBinaryOp(
852 BinaryScalarFoldingRule scalar_rule, uint32_t result_type_id,
853 const std::vector<const analysis::Constant*>& constants,
854 IRContext* context) {
855 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
856 analysis::TypeManager* type_mgr = context->get_type_mgr();
857 const analysis::Type* result_type = type_mgr->GetType(result_type_id);
858 const analysis::Vector* vector_type = result_type->AsVector();
859
860 if (constants[0] == nullptr || constants[1] == nullptr) {
861 return nullptr;
862 }
863
864 if (vector_type != nullptr) {
865 std::vector<const analysis::Constant*> a_components;
866 std::vector<const analysis::Constant*> b_components;
867 std::vector<const analysis::Constant*> results_components;
868
869 a_components = constants[0]->GetVectorComponents(const_mgr);
870 b_components = constants[1]->GetVectorComponents(const_mgr);
871
872 // Fold each component of the vector.
873 for (uint32_t i = 0; i < a_components.size(); ++i) {
874 results_components.push_back(scalar_rule(vector_type->element_type(),
875 a_components[i], b_components[i],
876 const_mgr));
877 if (results_components[i] == nullptr) {
878 return nullptr;
879 }
880 }
881
882 // Build the constant object and return it.
883 std::vector<uint32_t> ids;
884 for (const analysis::Constant* member : results_components) {
885 ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id());
886 }
887 return const_mgr->GetConstant(vector_type, ids);
888 } else {
889 return scalar_rule(result_type, constants[0], constants[1], const_mgr);
890 }
891 }
892
893 // Returns a |ConstantFoldingRule| that folds floating point scalars using
894 // |scalar_rule| and vectors of floating point by applying |scalar_rule| to the
895 // elements of the vector. The |ConstantFoldingRule| that is returned assumes
896 // that |constants| contains 2 entries. If they are not |nullptr|, then their
897 // type is either |Float| or a |Vector| whose element type is |Float|.
FoldFPBinaryOp(BinaryScalarFoldingRule scalar_rule)898 ConstantFoldingRule FoldFPBinaryOp(BinaryScalarFoldingRule scalar_rule) {
899 return [scalar_rule](IRContext* context, Instruction* inst,
900 const std::vector<const analysis::Constant*>& constants)
901 -> const analysis::Constant* {
902 if (!inst->IsFloatingPointFoldingAllowed()) {
903 return nullptr;
904 }
905 if (inst->opcode() == spv::Op::OpExtInst) {
906 return FoldFPBinaryOp(scalar_rule, inst->type_id(),
907 {constants[1], constants[2]}, context);
908 }
909 return FoldFPBinaryOp(scalar_rule, inst->type_id(), constants, context);
910 };
911 }
912
913 // This macro defines a |UnaryScalarFoldingRule| that performs float to
914 // integer conversion.
915 // TODO(greg-lunarg): Support for 64-bit integer types.
FoldFToIOp()916 UnaryScalarFoldingRule FoldFToIOp() {
917 return [](const analysis::Type* result_type, const analysis::Constant* a,
918 analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
919 assert(result_type != nullptr && a != nullptr);
920 const analysis::Integer* integer_type = result_type->AsInteger();
921 const analysis::Float* float_type = a->type()->AsFloat();
922 assert(float_type != nullptr);
923 assert(integer_type != nullptr);
924 if (integer_type->width() != 32) return nullptr;
925 if (float_type->width() == 32) {
926 float fa = a->GetFloat();
927 uint32_t result = integer_type->IsSigned()
928 ? static_cast<uint32_t>(static_cast<int32_t>(fa))
929 : static_cast<uint32_t>(fa);
930 std::vector<uint32_t> words = {result};
931 return const_mgr->GetConstant(result_type, words);
932 } else if (float_type->width() == 64) {
933 double fa = a->GetDouble();
934 uint32_t result = integer_type->IsSigned()
935 ? static_cast<uint32_t>(static_cast<int32_t>(fa))
936 : static_cast<uint32_t>(fa);
937 std::vector<uint32_t> words = {result};
938 return const_mgr->GetConstant(result_type, words);
939 }
940 return nullptr;
941 };
942 }
943
944 // This function defines a |UnaryScalarFoldingRule| that performs integer to
945 // float conversion.
946 // TODO(greg-lunarg): Support for 64-bit integer types.
FoldIToFOp()947 UnaryScalarFoldingRule FoldIToFOp() {
948 return [](const analysis::Type* result_type, const analysis::Constant* a,
949 analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
950 assert(result_type != nullptr && a != nullptr);
951 const analysis::Integer* integer_type = a->type()->AsInteger();
952 const analysis::Float* float_type = result_type->AsFloat();
953 assert(float_type != nullptr);
954 assert(integer_type != nullptr);
955 if (integer_type->width() != 32) return nullptr;
956 uint32_t ua = a->GetU32();
957 if (float_type->width() == 32) {
958 float result_val = integer_type->IsSigned()
959 ? static_cast<float>(static_cast<int32_t>(ua))
960 : static_cast<float>(ua);
961 utils::FloatProxy<float> result(result_val);
962 std::vector<uint32_t> words = {result.data()};
963 return const_mgr->GetConstant(result_type, words);
964 } else if (float_type->width() == 64) {
965 double result_val = integer_type->IsSigned()
966 ? static_cast<double>(static_cast<int32_t>(ua))
967 : static_cast<double>(ua);
968 utils::FloatProxy<double> result(result_val);
969 std::vector<uint32_t> words = result.GetWords();
970 return const_mgr->GetConstant(result_type, words);
971 }
972 return nullptr;
973 };
974 }
975
976 // This defines a |UnaryScalarFoldingRule| that performs |OpQuantizeToF16|.
FoldQuantizeToF16Scalar()977 UnaryScalarFoldingRule FoldQuantizeToF16Scalar() {
978 return [](const analysis::Type* result_type, const analysis::Constant* a,
979 analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
980 assert(result_type != nullptr && a != nullptr);
981 const analysis::Float* float_type = a->type()->AsFloat();
982 assert(float_type != nullptr);
983 if (float_type->width() != 32) {
984 return nullptr;
985 }
986
987 float fa = a->GetFloat();
988 utils::HexFloat<utils::FloatProxy<float>> orignal(fa);
989 utils::HexFloat<utils::FloatProxy<utils::Float16>> quantized(0);
990 utils::HexFloat<utils::FloatProxy<float>> result(0.0f);
991 orignal.castTo(quantized, utils::round_direction::kToZero);
992 quantized.castTo(result, utils::round_direction::kToZero);
993 std::vector<uint32_t> words = {result.getBits()};
994 return const_mgr->GetConstant(result_type, words);
995 };
996 }
997
// Expands to a |BinaryScalarFoldingRule| lambda that computes "a op b" for two
// scalar floating point constants of the result type. |op| is pasted between
// the two operands, so it must be an infix operator that works on both float
// and double. The lambda yields |nullptr| for float widths other than 32 and
// 64. Locals carry an "_in_macro" suffix; presumably this keeps them from
// clashing with names at the expansion site.
#define FOLD_FPARITH_OP(op)                                                   \
  [](const analysis::Type* type_in_macro, const analysis::Constant* a,        \
     const analysis::Constant* b,                                             \
     analysis::ConstantManager* mgr_in_macro) -> const analysis::Constant* {  \
    assert(type_in_macro != nullptr && a != nullptr && b != nullptr);         \
    assert(type_in_macro == a->type() && type_in_macro == b->type());         \
    const analysis::Float* fp_type_in_macro = type_in_macro->AsFloat();       \
    assert(fp_type_in_macro != nullptr);                                      \
    std::vector<uint32_t> bits_in_macro;                                      \
    switch (fp_type_in_macro->width()) {                                      \
      case 32: {                                                              \
        const float lhs_in_macro = a->GetFloat();                             \
        const float rhs_in_macro = b->GetFloat();                             \
        utils::FloatProxy<float> value_in_macro(lhs_in_macro op rhs_in_macro); \
        bits_in_macro = value_in_macro.GetWords();                            \
        break;                                                                \
      }                                                                       \
      case 64: {                                                              \
        const double lhs_in_macro = a->GetDouble();                           \
        const double rhs_in_macro = b->GetDouble();                           \
        utils::FloatProxy<double> value_in_macro(lhs_in_macro op rhs_in_macro); \
        bits_in_macro = value_in_macro.GetWords();                            \
        break;                                                                \
      }                                                                       \
      default:                                                                \
        return nullptr;                                                       \
    }                                                                         \
    return mgr_in_macro->GetConstant(type_in_macro, bits_in_macro);           \
  }
1028
1029 // Define the folding rule for conversion between floating point and integer
FoldFToI()1030 ConstantFoldingRule FoldFToI() { return FoldFPUnaryOp(FoldFToIOp()); }
FoldIToF()1031 ConstantFoldingRule FoldIToF() { return FoldFPUnaryOp(FoldIToFOp()); }
FoldQuantizeToF16()1032 ConstantFoldingRule FoldQuantizeToF16() {
1033 return FoldFPUnaryOp(FoldQuantizeToF16Scalar());
1034 }
1035
1036 // Define the folding rules for subtraction, addition, multiplication, and
1037 // division for floating point values.
FoldFSub()1038 ConstantFoldingRule FoldFSub() { return FoldFPBinaryOp(FOLD_FPARITH_OP(-)); }
FoldFAdd()1039 ConstantFoldingRule FoldFAdd() { return FoldFPBinaryOp(FOLD_FPARITH_OP(+)); }
FoldFMul()1040 ConstantFoldingRule FoldFMul() { return FoldFPBinaryOp(FOLD_FPARITH_OP(*)); }
1041
1042 // Returns the constant that results from evaluating |numerator| / 0.0. Returns
1043 // |nullptr| if the result could not be evaluated.
FoldFPScalarDivideByZero(const analysis::Type * result_type,const analysis::Constant * numerator,analysis::ConstantManager * const_mgr)1044 const analysis::Constant* FoldFPScalarDivideByZero(
1045 const analysis::Type* result_type, const analysis::Constant* numerator,
1046 analysis::ConstantManager* const_mgr) {
1047 if (numerator == nullptr) {
1048 return nullptr;
1049 }
1050
1051 if (numerator->IsZero()) {
1052 return GetNan(result_type, const_mgr);
1053 }
1054
1055 const analysis::Constant* result = GetInf(result_type, const_mgr);
1056 if (result == nullptr) {
1057 return nullptr;
1058 }
1059
1060 if (numerator->AsFloatConstant()->GetValueAsDouble() < 0.0) {
1061 result = NegateFPConst(result_type, result, const_mgr);
1062 }
1063 return result;
1064 }
1065
1066 // Returns the result of folding |numerator| / |denominator|. Returns |nullptr|
1067 // if it cannot be folded.
FoldScalarFPDivide(const analysis::Type * result_type,const analysis::Constant * numerator,const analysis::Constant * denominator,analysis::ConstantManager * const_mgr)1068 const analysis::Constant* FoldScalarFPDivide(
1069 const analysis::Type* result_type, const analysis::Constant* numerator,
1070 const analysis::Constant* denominator,
1071 analysis::ConstantManager* const_mgr) {
1072 if (denominator == nullptr) {
1073 return nullptr;
1074 }
1075
1076 if (denominator->IsZero()) {
1077 return FoldFPScalarDivideByZero(result_type, numerator, const_mgr);
1078 }
1079
1080 uint32_t width = denominator->type()->AsFloat()->width();
1081 if (width != 32 && width != 64) {
1082 return nullptr;
1083 }
1084
1085 const analysis::FloatConstant* denominator_float =
1086 denominator->AsFloatConstant();
1087 if (denominator_float && denominator->GetValueAsDouble() == -0.0) {
1088 const analysis::Constant* result =
1089 FoldFPScalarDivideByZero(result_type, numerator, const_mgr);
1090 if (result != nullptr)
1091 result = NegateFPConst(result_type, result, const_mgr);
1092 return result;
1093 } else {
1094 return FOLD_FPARITH_OP(/)(result_type, numerator, denominator, const_mgr);
1095 }
1096 }
1097
1098 // Returns the constant folding rule to fold |OpFDiv| with two constants.
FoldFDiv()1099 ConstantFoldingRule FoldFDiv() { return FoldFPBinaryOp(FoldScalarFPDivide); }
1100
// Combines the raw outcome of a floating point comparison with the
// ordered/unordered requirement of the instruction being folded.
// |op_result| is the outcome of "Operand 1 |op| Operand 2" and |op_unordered|
// tells whether the operands are unordered. An ordered comparison
// (|need_ordered| is true) holds only if the operands are ordered and
// |op_result| is true; an unordered comparison holds if the operands are
// unordered or |op_result| is true.
bool CompareFloatingPoint(bool op_result, bool op_unordered,
                          bool need_ordered) {
  return need_ordered ? (op_result && !op_unordered)
                      : (op_result || op_unordered);
}
1111
// Expands to a |BinaryScalarFoldingRule| lambda that compares two scalar
// floating point constants of the same type and produces a constant of the
// (boolean) result type. |op| is pasted between the two operands, so it must
// be an infix comparison operator that works on both float and double. |ord|
// is forwarded to |CompareFloatingPoint| as |need_ordered|: true selects the
// ordered form of the comparison, false the unordered form. The lambda yields
// |nullptr| for float widths other than 32 and 64.
#define FOLD_FPCMP_OP(op, ord)                                              \
  [](const analysis::Type* result_type, const analysis::Constant* a,        \
     const analysis::Constant* b,                                           \
     analysis::ConstantManager* const_mgr) -> const analysis::Constant* {   \
    assert(result_type != nullptr && a != nullptr && b != nullptr);         \
    assert(result_type->AsBool());                                          \
    assert(a->type() == b->type());                                         \
    const analysis::Float* operand_type = a->type()->AsFloat();             \
    assert(operand_type != nullptr);                                        \
    bool outcome = false;                                                   \
    switch (operand_type->width()) {                                        \
      case 32: {                                                            \
        const float lhs = a->GetFloat();                                    \
        const float rhs = b->GetFloat();                                    \
        const bool unordered = std::isnan(lhs) || std::isnan(rhs);          \
        outcome = CompareFloatingPoint(lhs op rhs, unordered, ord);         \
        break;                                                              \
      }                                                                     \
      case 64: {                                                            \
        const double lhs = a->GetDouble();                                  \
        const double rhs = b->GetDouble();                                  \
        const bool unordered = std::isnan(lhs) || std::isnan(rhs);          \
        outcome = CompareFloatingPoint(lhs op rhs, unordered, ord);         \
        break;                                                              \
      }                                                                     \
      default:                                                              \
        return nullptr;                                                     \
    }                                                                       \
    std::vector<uint32_t> words = {uint32_t(outcome)};                      \
    return const_mgr->GetConstant(result_type, words);                      \
  }
1140
1141 // Define the folding rules for ordered and unordered comparison for floating
1142 // point values.
FoldFOrdEqual()1143 ConstantFoldingRule FoldFOrdEqual() {
1144 return FoldFPBinaryOp(FOLD_FPCMP_OP(==, true));
1145 }
FoldFUnordEqual()1146 ConstantFoldingRule FoldFUnordEqual() {
1147 return FoldFPBinaryOp(FOLD_FPCMP_OP(==, false));
1148 }
FoldFOrdNotEqual()1149 ConstantFoldingRule FoldFOrdNotEqual() {
1150 return FoldFPBinaryOp(FOLD_FPCMP_OP(!=, true));
1151 }
FoldFUnordNotEqual()1152 ConstantFoldingRule FoldFUnordNotEqual() {
1153 return FoldFPBinaryOp(FOLD_FPCMP_OP(!=, false));
1154 }
FoldFOrdLessThan()1155 ConstantFoldingRule FoldFOrdLessThan() {
1156 return FoldFPBinaryOp(FOLD_FPCMP_OP(<, true));
1157 }
FoldFUnordLessThan()1158 ConstantFoldingRule FoldFUnordLessThan() {
1159 return FoldFPBinaryOp(FOLD_FPCMP_OP(<, false));
1160 }
FoldFOrdGreaterThan()1161 ConstantFoldingRule FoldFOrdGreaterThan() {
1162 return FoldFPBinaryOp(FOLD_FPCMP_OP(>, true));
1163 }
FoldFUnordGreaterThan()1164 ConstantFoldingRule FoldFUnordGreaterThan() {
1165 return FoldFPBinaryOp(FOLD_FPCMP_OP(>, false));
1166 }
FoldFOrdLessThanEqual()1167 ConstantFoldingRule FoldFOrdLessThanEqual() {
1168 return FoldFPBinaryOp(FOLD_FPCMP_OP(<=, true));
1169 }
FoldFUnordLessThanEqual()1170 ConstantFoldingRule FoldFUnordLessThanEqual() {
1171 return FoldFPBinaryOp(FOLD_FPCMP_OP(<=, false));
1172 }
FoldFOrdGreaterThanEqual()1173 ConstantFoldingRule FoldFOrdGreaterThanEqual() {
1174 return FoldFPBinaryOp(FOLD_FPCMP_OP(>=, true));
1175 }
FoldFUnordGreaterThanEqual()1176 ConstantFoldingRule FoldFUnordGreaterThanEqual() {
1177 return FoldFPBinaryOp(FOLD_FPCMP_OP(>=, false));
1178 }
1179
1180 // Folds an OpDot where all of the inputs are constants to a
1181 // constant. A new constant is created if necessary.
FoldOpDotWithConstants()1182 ConstantFoldingRule FoldOpDotWithConstants() {
1183 return [](IRContext* context, Instruction* inst,
1184 const std::vector<const analysis::Constant*>& constants)
1185 -> const analysis::Constant* {
1186 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1187 analysis::TypeManager* type_mgr = context->get_type_mgr();
1188 const analysis::Type* new_type = type_mgr->GetType(inst->type_id());
1189 assert(new_type->AsFloat() && "OpDot should have a float return type.");
1190 const analysis::Float* float_type = new_type->AsFloat();
1191
1192 if (!inst->IsFloatingPointFoldingAllowed()) {
1193 return nullptr;
1194 }
1195
1196 // If one of the operands is 0, then the result is 0.
1197 bool has_zero_operand = false;
1198
1199 for (int i = 0; i < 2; ++i) {
1200 if (constants[i]) {
1201 if (constants[i]->AsNullConstant() ||
1202 constants[i]->AsVectorConstant()->IsZero()) {
1203 has_zero_operand = true;
1204 break;
1205 }
1206 }
1207 }
1208
1209 if (has_zero_operand) {
1210 if (float_type->width() == 32) {
1211 utils::FloatProxy<float> result(0.0f);
1212 std::vector<uint32_t> words = result.GetWords();
1213 return const_mgr->GetConstant(float_type, words);
1214 }
1215 if (float_type->width() == 64) {
1216 utils::FloatProxy<double> result(0.0);
1217 std::vector<uint32_t> words = result.GetWords();
1218 return const_mgr->GetConstant(float_type, words);
1219 }
1220 return nullptr;
1221 }
1222
1223 if (constants[0] == nullptr || constants[1] == nullptr) {
1224 return nullptr;
1225 }
1226
1227 std::vector<const analysis::Constant*> a_components;
1228 std::vector<const analysis::Constant*> b_components;
1229
1230 a_components = constants[0]->GetVectorComponents(const_mgr);
1231 b_components = constants[1]->GetVectorComponents(const_mgr);
1232
1233 utils::FloatProxy<double> result(0.0);
1234 std::vector<uint32_t> words = result.GetWords();
1235 const analysis::Constant* result_const =
1236 const_mgr->GetConstant(float_type, words);
1237 for (uint32_t i = 0; i < a_components.size() && result_const != nullptr;
1238 ++i) {
1239 if (a_components[i] == nullptr || b_components[i] == nullptr) {
1240 return nullptr;
1241 }
1242
1243 const analysis::Constant* component = FOLD_FPARITH_OP(*)(
1244 new_type, a_components[i], b_components[i], const_mgr);
1245 if (component == nullptr) {
1246 return nullptr;
1247 }
1248 result_const =
1249 FOLD_FPARITH_OP(+)(new_type, result_const, component, const_mgr);
1250 }
1251 return result_const;
1252 };
1253 }
1254
FoldFNegate()1255 ConstantFoldingRule FoldFNegate() { return FoldFPUnaryOp(NegateFPConst); }
FoldSNegate()1256 ConstantFoldingRule FoldSNegate() { return FoldUnaryOp(NegateIntConst); }
1257
FoldFClampFeedingCompare(spv::Op cmp_opcode)1258 ConstantFoldingRule FoldFClampFeedingCompare(spv::Op cmp_opcode) {
1259 return [cmp_opcode](IRContext* context, Instruction* inst,
1260 const std::vector<const analysis::Constant*>& constants)
1261 -> const analysis::Constant* {
1262 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1263 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1264
1265 if (!inst->IsFloatingPointFoldingAllowed()) {
1266 return nullptr;
1267 }
1268
1269 uint32_t non_const_idx = (constants[0] ? 1 : 0);
1270 uint32_t operand_id = inst->GetSingleWordInOperand(non_const_idx);
1271 Instruction* operand_inst = def_use_mgr->GetDef(operand_id);
1272
1273 analysis::TypeManager* type_mgr = context->get_type_mgr();
1274 const analysis::Type* operand_type =
1275 type_mgr->GetType(operand_inst->type_id());
1276
1277 if (!operand_type->AsFloat()) {
1278 return nullptr;
1279 }
1280
1281 if (operand_type->AsFloat()->width() != 32 &&
1282 operand_type->AsFloat()->width() != 64) {
1283 return nullptr;
1284 }
1285
1286 if (operand_inst->opcode() != spv::Op::OpExtInst) {
1287 return nullptr;
1288 }
1289
1290 if (operand_inst->GetSingleWordInOperand(1) != GLSLstd450FClamp) {
1291 return nullptr;
1292 }
1293
1294 if (constants[1] == nullptr && constants[0] == nullptr) {
1295 return nullptr;
1296 }
1297
1298 uint32_t max_id = operand_inst->GetSingleWordInOperand(4);
1299 const analysis::Constant* max_const =
1300 const_mgr->FindDeclaredConstant(max_id);
1301
1302 uint32_t min_id = operand_inst->GetSingleWordInOperand(3);
1303 const analysis::Constant* min_const =
1304 const_mgr->FindDeclaredConstant(min_id);
1305
1306 bool found_result = false;
1307 bool result = false;
1308
1309 switch (cmp_opcode) {
1310 case spv::Op::OpFOrdLessThan:
1311 case spv::Op::OpFUnordLessThan:
1312 case spv::Op::OpFOrdGreaterThanEqual:
1313 case spv::Op::OpFUnordGreaterThanEqual:
1314 if (constants[0]) {
1315 if (min_const) {
1316 if (constants[0]->GetValueAsDouble() <
1317 min_const->GetValueAsDouble()) {
1318 found_result = true;
1319 result = (cmp_opcode == spv::Op::OpFOrdLessThan ||
1320 cmp_opcode == spv::Op::OpFUnordLessThan);
1321 }
1322 }
1323 if (max_const) {
1324 if (constants[0]->GetValueAsDouble() >=
1325 max_const->GetValueAsDouble()) {
1326 found_result = true;
1327 result = !(cmp_opcode == spv::Op::OpFOrdLessThan ||
1328 cmp_opcode == spv::Op::OpFUnordLessThan);
1329 }
1330 }
1331 }
1332
1333 if (constants[1]) {
1334 if (max_const) {
1335 if (max_const->GetValueAsDouble() <
1336 constants[1]->GetValueAsDouble()) {
1337 found_result = true;
1338 result = (cmp_opcode == spv::Op::OpFOrdLessThan ||
1339 cmp_opcode == spv::Op::OpFUnordLessThan);
1340 }
1341 }
1342
1343 if (min_const) {
1344 if (min_const->GetValueAsDouble() >=
1345 constants[1]->GetValueAsDouble()) {
1346 found_result = true;
1347 result = !(cmp_opcode == spv::Op::OpFOrdLessThan ||
1348 cmp_opcode == spv::Op::OpFUnordLessThan);
1349 }
1350 }
1351 }
1352 break;
1353 case spv::Op::OpFOrdGreaterThan:
1354 case spv::Op::OpFUnordGreaterThan:
1355 case spv::Op::OpFOrdLessThanEqual:
1356 case spv::Op::OpFUnordLessThanEqual:
1357 if (constants[0]) {
1358 if (min_const) {
1359 if (constants[0]->GetValueAsDouble() <=
1360 min_const->GetValueAsDouble()) {
1361 found_result = true;
1362 result = (cmp_opcode == spv::Op::OpFOrdLessThanEqual ||
1363 cmp_opcode == spv::Op::OpFUnordLessThanEqual);
1364 }
1365 }
1366 if (max_const) {
1367 if (constants[0]->GetValueAsDouble() >
1368 max_const->GetValueAsDouble()) {
1369 found_result = true;
1370 result = !(cmp_opcode == spv::Op::OpFOrdLessThanEqual ||
1371 cmp_opcode == spv::Op::OpFUnordLessThanEqual);
1372 }
1373 }
1374 }
1375
1376 if (constants[1]) {
1377 if (max_const) {
1378 if (max_const->GetValueAsDouble() <=
1379 constants[1]->GetValueAsDouble()) {
1380 found_result = true;
1381 result = (cmp_opcode == spv::Op::OpFOrdLessThanEqual ||
1382 cmp_opcode == spv::Op::OpFUnordLessThanEqual);
1383 }
1384 }
1385
1386 if (min_const) {
1387 if (min_const->GetValueAsDouble() >
1388 constants[1]->GetValueAsDouble()) {
1389 found_result = true;
1390 result = !(cmp_opcode == spv::Op::OpFOrdLessThanEqual ||
1391 cmp_opcode == spv::Op::OpFUnordLessThanEqual);
1392 }
1393 }
1394 }
1395 break;
1396 default:
1397 return nullptr;
1398 }
1399
1400 if (!found_result) {
1401 return nullptr;
1402 }
1403
1404 const analysis::Type* bool_type =
1405 context->get_type_mgr()->GetType(inst->type_id());
1406 const analysis::Constant* result_const =
1407 const_mgr->GetConstant(bool_type, {static_cast<uint32_t>(result)});
1408 assert(result_const);
1409 return result_const;
1410 };
1411 }
1412
FoldFMix()1413 ConstantFoldingRule FoldFMix() {
1414 return [](IRContext* context, Instruction* inst,
1415 const std::vector<const analysis::Constant*>& constants)
1416 -> const analysis::Constant* {
1417 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1418 assert(inst->opcode() == spv::Op::OpExtInst &&
1419 "Expecting an extended instruction.");
1420 assert(inst->GetSingleWordInOperand(0) ==
1421 context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
1422 "Expecting a GLSLstd450 extended instruction.");
1423 assert(inst->GetSingleWordInOperand(1) == GLSLstd450FMix &&
1424 "Expecting and FMix instruction.");
1425
1426 if (!inst->IsFloatingPointFoldingAllowed()) {
1427 return nullptr;
1428 }
1429
1430 // Make sure all FMix operands are constants.
1431 for (uint32_t i = 1; i < 4; i++) {
1432 if (constants[i] == nullptr) {
1433 return nullptr;
1434 }
1435 }
1436
1437 const analysis::Constant* one;
1438 bool is_vector = false;
1439 const analysis::Type* result_type = constants[1]->type();
1440 const analysis::Type* base_type = result_type;
1441 if (base_type->AsVector()) {
1442 is_vector = true;
1443 base_type = base_type->AsVector()->element_type();
1444 }
1445 assert(base_type->AsFloat() != nullptr &&
1446 "FMix is suppose to act on floats or vectors of floats.");
1447
1448 if (base_type->AsFloat()->width() == 32) {
1449 one = const_mgr->GetConstant(base_type,
1450 utils::FloatProxy<float>(1.0f).GetWords());
1451 } else {
1452 one = const_mgr->GetConstant(base_type,
1453 utils::FloatProxy<double>(1.0).GetWords());
1454 }
1455
1456 if (is_vector) {
1457 uint32_t one_id = const_mgr->GetDefiningInstruction(one)->result_id();
1458 one =
1459 const_mgr->GetConstant(result_type, std::vector<uint32_t>(4, one_id));
1460 }
1461
1462 const analysis::Constant* temp1 = FoldFPBinaryOp(
1463 FOLD_FPARITH_OP(-), inst->type_id(), {one, constants[3]}, context);
1464 if (temp1 == nullptr) {
1465 return nullptr;
1466 }
1467
1468 const analysis::Constant* temp2 = FoldFPBinaryOp(
1469 FOLD_FPARITH_OP(*), inst->type_id(), {constants[1], temp1}, context);
1470 if (temp2 == nullptr) {
1471 return nullptr;
1472 }
1473 const analysis::Constant* temp3 =
1474 FoldFPBinaryOp(FOLD_FPARITH_OP(*), inst->type_id(),
1475 {constants[2], constants[3]}, context);
1476 if (temp3 == nullptr) {
1477 return nullptr;
1478 }
1479 return FoldFPBinaryOp(FOLD_FPARITH_OP(+), inst->type_id(), {temp2, temp3},
1480 context);
1481 };
1482 }
1483
// Scalar folding rule for "minimum": returns |a| if its value is strictly less
// than the value of |b|, and |b| otherwise. |result_type| selects how the
// values are read: as signed or unsigned 32- or 64-bit integers, or as 32- or
// 64-bit floats. Returns |nullptr| for any other type or width. No new
// constant is created; the constant manager argument is unused.
const analysis::Constant* FoldMin(const analysis::Type* result_type,
                                  const analysis::Constant* a,
                                  const analysis::Constant* b,
                                  analysis::ConstantManager*) {
  bool a_is_smaller = false;

  if (const analysis::Integer* int_type = result_type->AsInteger()) {
    const bool is_signed = int_type->IsSigned();
    switch (int_type->width()) {
      case 32:
        a_is_smaller = is_signed ? (a->GetS32() < b->GetS32())
                                 : (a->GetU32() < b->GetU32());
        break;
      case 64:
        a_is_smaller = is_signed ? (a->GetS64() < b->GetS64())
                                 : (a->GetU64() < b->GetU64());
        break;
      default:
        return nullptr;
    }
  } else if (const analysis::Float* float_type = result_type->AsFloat()) {
    switch (float_type->width()) {
      case 32:
        a_is_smaller = a->GetFloat() < b->GetFloat();
        break;
      case 64:
        a_is_smaller = a->GetDouble() < b->GetDouble();
        break;
      default:
        return nullptr;
    }
  } else {
    return nullptr;
  }

  return a_is_smaller ? a : b;
}
1523
// Scalar folding rule for "maximum": returns |a| if its value is strictly
// greater than the value of |b|, and |b| otherwise. |result_type| selects how
// the values are read: as signed or unsigned 32- or 64-bit integers, or as 32-
// or 64-bit floats. Returns |nullptr| for any other type or width. No new
// constant is created; the constant manager argument is unused.
const analysis::Constant* FoldMax(const analysis::Type* result_type,
                                  const analysis::Constant* a,
                                  const analysis::Constant* b,
                                  analysis::ConstantManager*) {
  bool a_is_larger = false;

  if (const analysis::Integer* int_type = result_type->AsInteger()) {
    const bool is_signed = int_type->IsSigned();
    switch (int_type->width()) {
      case 32:
        a_is_larger = is_signed ? (a->GetS32() > b->GetS32())
                                : (a->GetU32() > b->GetU32());
        break;
      case 64:
        a_is_larger = is_signed ? (a->GetS64() > b->GetS64())
                                : (a->GetU64() > b->GetU64());
        break;
      default:
        return nullptr;
    }
  } else if (const analysis::Float* float_type = result_type->AsFloat()) {
    switch (float_type->width()) {
      case 32:
        a_is_larger = a->GetFloat() > b->GetFloat();
        break;
      case 64:
        a_is_larger = a->GetDouble() > b->GetDouble();
        break;
      default:
        return nullptr;
    }
  } else {
    return nullptr;
  }

  return a_is_larger ? a : b;
}
1563
1564 // Fold an clamp instruction when all three operands are constant.
FoldClamp1(IRContext * context,Instruction * inst,const std::vector<const analysis::Constant * > & constants)1565 const analysis::Constant* FoldClamp1(
1566 IRContext* context, Instruction* inst,
1567 const std::vector<const analysis::Constant*>& constants) {
1568 assert(inst->opcode() == spv::Op::OpExtInst &&
1569 "Expecting an extended instruction.");
1570 assert(inst->GetSingleWordInOperand(0) ==
1571 context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
1572 "Expecting a GLSLstd450 extended instruction.");
1573
1574 // Make sure all Clamp operands are constants.
1575 for (uint32_t i = 1; i < 4; i++) {
1576 if (constants[i] == nullptr) {
1577 return nullptr;
1578 }
1579 }
1580
1581 const analysis::Constant* temp = FoldFPBinaryOp(
1582 FoldMax, inst->type_id(), {constants[1], constants[2]}, context);
1583 if (temp == nullptr) {
1584 return nullptr;
1585 }
1586 return FoldFPBinaryOp(FoldMin, inst->type_id(), {temp, constants[3]},
1587 context);
1588 }
1589
1590 // Fold a clamp instruction when |x <= min_val|.
FoldClamp2(IRContext * context,Instruction * inst,const std::vector<const analysis::Constant * > & constants)1591 const analysis::Constant* FoldClamp2(
1592 IRContext* context, Instruction* inst,
1593 const std::vector<const analysis::Constant*>& constants) {
1594 assert(inst->opcode() == spv::Op::OpExtInst &&
1595 "Expecting an extended instruction.");
1596 assert(inst->GetSingleWordInOperand(0) ==
1597 context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
1598 "Expecting a GLSLstd450 extended instruction.");
1599
1600 const analysis::Constant* x = constants[1];
1601 const analysis::Constant* min_val = constants[2];
1602
1603 if (x == nullptr || min_val == nullptr) {
1604 return nullptr;
1605 }
1606
1607 const analysis::Constant* temp =
1608 FoldFPBinaryOp(FoldMax, inst->type_id(), {x, min_val}, context);
1609 if (temp == min_val) {
1610 // We can assume that |min_val| is less than |max_val|. Therefore, if the
1611 // result of the max operation is |min_val|, we know the result of the min
1612 // operation, even if |max_val| is not a constant.
1613 return min_val;
1614 }
1615 return nullptr;
1616 }
1617
1618 // Fold a clamp instruction when |x >= max_val|.
FoldClamp3(IRContext * context,Instruction * inst,const std::vector<const analysis::Constant * > & constants)1619 const analysis::Constant* FoldClamp3(
1620 IRContext* context, Instruction* inst,
1621 const std::vector<const analysis::Constant*>& constants) {
1622 assert(inst->opcode() == spv::Op::OpExtInst &&
1623 "Expecting an extended instruction.");
1624 assert(inst->GetSingleWordInOperand(0) ==
1625 context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
1626 "Expecting a GLSLstd450 extended instruction.");
1627
1628 const analysis::Constant* x = constants[1];
1629 const analysis::Constant* max_val = constants[3];
1630
1631 if (x == nullptr || max_val == nullptr) {
1632 return nullptr;
1633 }
1634
1635 const analysis::Constant* temp =
1636 FoldFPBinaryOp(FoldMin, inst->type_id(), {x, max_val}, context);
1637 if (temp == max_val) {
1638 // We can assume that |min_val| is less than |max_val|. Therefore, if the
1639 // result of the max operation is |min_val|, we know the result of the min
1640 // operation, even if |max_val| is not a constant.
1641 return max_val;
1642 }
1643 return nullptr;
1644 }
1645
FoldFTranscendentalUnary(double (* fp)(double))1646 UnaryScalarFoldingRule FoldFTranscendentalUnary(double (*fp)(double)) {
1647 return
1648 [fp](const analysis::Type* result_type, const analysis::Constant* a,
1649 analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
1650 assert(result_type != nullptr && a != nullptr);
1651 const analysis::Float* float_type = a->type()->AsFloat();
1652 assert(float_type != nullptr);
1653 assert(float_type == result_type->AsFloat());
1654 if (float_type->width() == 32) {
1655 float fa = a->GetFloat();
1656 float res = static_cast<float>(fp(fa));
1657 utils::FloatProxy<float> result(res);
1658 std::vector<uint32_t> words = result.GetWords();
1659 return const_mgr->GetConstant(result_type, words);
1660 } else if (float_type->width() == 64) {
1661 double fa = a->GetDouble();
1662 double res = fp(fa);
1663 utils::FloatProxy<double> result(res);
1664 std::vector<uint32_t> words = result.GetWords();
1665 return const_mgr->GetConstant(result_type, words);
1666 }
1667 return nullptr;
1668 };
1669 }
1670
FoldFTranscendentalBinary(double (* fp)(double,double))1671 BinaryScalarFoldingRule FoldFTranscendentalBinary(double (*fp)(double,
1672 double)) {
1673 return
1674 [fp](const analysis::Type* result_type, const analysis::Constant* a,
1675 const analysis::Constant* b,
1676 analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
1677 assert(result_type != nullptr && a != nullptr);
1678 const analysis::Float* float_type = a->type()->AsFloat();
1679 assert(float_type != nullptr);
1680 assert(float_type == result_type->AsFloat());
1681 assert(float_type == b->type()->AsFloat());
1682 if (float_type->width() == 32) {
1683 float fa = a->GetFloat();
1684 float fb = b->GetFloat();
1685 float res = static_cast<float>(fp(fa, fb));
1686 utils::FloatProxy<float> result(res);
1687 std::vector<uint32_t> words = result.GetWords();
1688 return const_mgr->GetConstant(result_type, words);
1689 } else if (float_type->width() == 64) {
1690 double fa = a->GetDouble();
1691 double fb = b->GetDouble();
1692 double res = fp(fa, fb);
1693 utils::FloatProxy<double> result(res);
1694 std::vector<uint32_t> words = result.GetWords();
1695 return const_mgr->GetConstant(result_type, words);
1696 }
1697 return nullptr;
1698 };
1699 }
1700
1701 enum Sign { Signed, Unsigned };
1702
1703 // Returns a BinaryScalarFoldingRule that applies `op` to the scalars.
1704 // The `signedness` is used to determine if the operands should be interpreted
1705 // as signed or unsigned. If the operands are signed, the value will be sign
1706 // extended before the value is passed to `op`. Otherwise the values will be
1707 // zero extended.
1708 template <Sign signedness>
FoldBinaryIntegerOperation(uint64_t (* op)(uint64_t,uint64_t))1709 BinaryScalarFoldingRule FoldBinaryIntegerOperation(uint64_t (*op)(uint64_t,
1710 uint64_t)) {
1711 return
1712 [op](const analysis::Type* result_type, const analysis::Constant* a,
1713 const analysis::Constant* b,
1714 analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
1715 assert(result_type != nullptr && a != nullptr && b != nullptr);
1716 const analysis::Integer* integer_type = result_type->AsInteger();
1717 assert(integer_type != nullptr);
1718 assert(a->type()->kind() == analysis::Type::kInteger);
1719 assert(b->type()->kind() == analysis::Type::kInteger);
1720 assert(integer_type->width() == a->type()->AsInteger()->width());
1721 assert(integer_type->width() == b->type()->AsInteger()->width());
1722
1723 // In SPIR-V, all operations support unsigned types, but the way they
1724 // are interpreted depends on the opcode. This is why we use the
1725 // template argument to determine how to interpret the operands.
1726 uint64_t ia = (signedness == Signed ? a->GetSignExtendedValue()
1727 : a->GetZeroExtendedValue());
1728 uint64_t ib = (signedness == Signed ? b->GetSignExtendedValue()
1729 : b->GetZeroExtendedValue());
1730 uint64_t result = op(ia, ib);
1731
1732 const analysis::Constant* result_constant =
1733 GenerateIntegerConstant(integer_type, result, const_mgr);
1734 return result_constant;
1735 };
1736 }
1737
1738 // A scalar folding rule that folds OpSConvert.
FoldScalarSConvert(const analysis::Type * result_type,const analysis::Constant * a,analysis::ConstantManager * const_mgr)1739 const analysis::Constant* FoldScalarSConvert(
1740 const analysis::Type* result_type, const analysis::Constant* a,
1741 analysis::ConstantManager* const_mgr) {
1742 assert(result_type != nullptr);
1743 assert(a != nullptr);
1744 assert(const_mgr != nullptr);
1745 const analysis::Integer* integer_type = result_type->AsInteger();
1746 assert(integer_type && "The result type of an SConvert");
1747 int64_t value = a->GetSignExtendedValue();
1748 return GenerateIntegerConstant(integer_type, value, const_mgr);
1749 }
1750
1751 // A scalar folding rule that folds OpUConvert.
FoldScalarUConvert(const analysis::Type * result_type,const analysis::Constant * a,analysis::ConstantManager * const_mgr)1752 const analysis::Constant* FoldScalarUConvert(
1753 const analysis::Type* result_type, const analysis::Constant* a,
1754 analysis::ConstantManager* const_mgr) {
1755 assert(result_type != nullptr);
1756 assert(a != nullptr);
1757 assert(const_mgr != nullptr);
1758 const analysis::Integer* integer_type = result_type->AsInteger();
1759 assert(integer_type && "The result type of an UConvert");
1760 uint64_t value = a->GetZeroExtendedValue();
1761
1762 // If the operand was an unsigned value with less than 32-bit, it would have
1763 // been sign extended earlier, and we need to clear those bits.
1764 auto* operand_type = a->type()->AsInteger();
1765 value = ZeroExtendValue(value, operand_type->width());
1766 return GenerateIntegerConstant(integer_type, value, const_mgr);
1767 }
1768 } // namespace
1769
AddFoldingRules()1770 void ConstantFoldingRules::AddFoldingRules() {
1771 // Add all folding rules to the list for the opcodes to which they apply.
1772 // Note that the order in which rules are added to the list matters. If a rule
1773 // applies to the instruction, the rest of the rules will not be attempted.
1774 // Take that into consideration.
1775
1776 rules_[spv::Op::OpCompositeConstruct].push_back(FoldCompositeWithConstants());
1777
1778 rules_[spv::Op::OpCompositeExtract].push_back(FoldExtractWithConstants());
1779 rules_[spv::Op::OpCompositeInsert].push_back(FoldInsertWithConstants());
1780
1781 rules_[spv::Op::OpConvertFToS].push_back(FoldFToI());
1782 rules_[spv::Op::OpConvertFToU].push_back(FoldFToI());
1783 rules_[spv::Op::OpConvertSToF].push_back(FoldIToF());
1784 rules_[spv::Op::OpConvertUToF].push_back(FoldIToF());
1785 rules_[spv::Op::OpSConvert].push_back(FoldUnaryOp(FoldScalarSConvert));
1786 rules_[spv::Op::OpUConvert].push_back(FoldUnaryOp(FoldScalarUConvert));
1787
1788 rules_[spv::Op::OpDot].push_back(FoldOpDotWithConstants());
1789 rules_[spv::Op::OpFAdd].push_back(FoldFAdd());
1790 rules_[spv::Op::OpFDiv].push_back(FoldFDiv());
1791 rules_[spv::Op::OpFMul].push_back(FoldFMul());
1792 rules_[spv::Op::OpFSub].push_back(FoldFSub());
1793
1794 rules_[spv::Op::OpFOrdEqual].push_back(FoldFOrdEqual());
1795
1796 rules_[spv::Op::OpFUnordEqual].push_back(FoldFUnordEqual());
1797
1798 rules_[spv::Op::OpFOrdNotEqual].push_back(FoldFOrdNotEqual());
1799
1800 rules_[spv::Op::OpFUnordNotEqual].push_back(FoldFUnordNotEqual());
1801
1802 rules_[spv::Op::OpFOrdLessThan].push_back(FoldFOrdLessThan());
1803 rules_[spv::Op::OpFOrdLessThan].push_back(
1804 FoldFClampFeedingCompare(spv::Op::OpFOrdLessThan));
1805
1806 rules_[spv::Op::OpFUnordLessThan].push_back(FoldFUnordLessThan());
1807 rules_[spv::Op::OpFUnordLessThan].push_back(
1808 FoldFClampFeedingCompare(spv::Op::OpFUnordLessThan));
1809
1810 rules_[spv::Op::OpFOrdGreaterThan].push_back(FoldFOrdGreaterThan());
1811 rules_[spv::Op::OpFOrdGreaterThan].push_back(
1812 FoldFClampFeedingCompare(spv::Op::OpFOrdGreaterThan));
1813
1814 rules_[spv::Op::OpFUnordGreaterThan].push_back(FoldFUnordGreaterThan());
1815 rules_[spv::Op::OpFUnordGreaterThan].push_back(
1816 FoldFClampFeedingCompare(spv::Op::OpFUnordGreaterThan));
1817
1818 rules_[spv::Op::OpFOrdLessThanEqual].push_back(FoldFOrdLessThanEqual());
1819 rules_[spv::Op::OpFOrdLessThanEqual].push_back(
1820 FoldFClampFeedingCompare(spv::Op::OpFOrdLessThanEqual));
1821
1822 rules_[spv::Op::OpFUnordLessThanEqual].push_back(FoldFUnordLessThanEqual());
1823 rules_[spv::Op::OpFUnordLessThanEqual].push_back(
1824 FoldFClampFeedingCompare(spv::Op::OpFUnordLessThanEqual));
1825
1826 rules_[spv::Op::OpFOrdGreaterThanEqual].push_back(FoldFOrdGreaterThanEqual());
1827 rules_[spv::Op::OpFOrdGreaterThanEqual].push_back(
1828 FoldFClampFeedingCompare(spv::Op::OpFOrdGreaterThanEqual));
1829
1830 rules_[spv::Op::OpFUnordGreaterThanEqual].push_back(
1831 FoldFUnordGreaterThanEqual());
1832 rules_[spv::Op::OpFUnordGreaterThanEqual].push_back(
1833 FoldFClampFeedingCompare(spv::Op::OpFUnordGreaterThanEqual));
1834
1835 rules_[spv::Op::OpVectorShuffle].push_back(FoldVectorShuffleWithConstants());
1836 rules_[spv::Op::OpVectorTimesScalar].push_back(FoldVectorTimesScalar());
1837 rules_[spv::Op::OpVectorTimesMatrix].push_back(FoldVectorTimesMatrix());
1838 rules_[spv::Op::OpMatrixTimesVector].push_back(FoldMatrixTimesVector());
1839 rules_[spv::Op::OpTranspose].push_back(FoldTranspose);
1840
1841 rules_[spv::Op::OpFNegate].push_back(FoldFNegate());
1842 rules_[spv::Op::OpSNegate].push_back(FoldSNegate());
1843 rules_[spv::Op::OpQuantizeToF16].push_back(FoldQuantizeToF16());
1844
1845 rules_[spv::Op::OpIAdd].push_back(
1846 FoldBinaryOp(FoldBinaryIntegerOperation<Unsigned>(
1847 [](uint64_t a, uint64_t b) { return a + b; })));
1848 rules_[spv::Op::OpISub].push_back(
1849 FoldBinaryOp(FoldBinaryIntegerOperation<Unsigned>(
1850 [](uint64_t a, uint64_t b) { return a - b; })));
1851 rules_[spv::Op::OpIMul].push_back(
1852 FoldBinaryOp(FoldBinaryIntegerOperation<Unsigned>(
1853 [](uint64_t a, uint64_t b) { return a * b; })));
1854 rules_[spv::Op::OpUDiv].push_back(
1855 FoldBinaryOp(FoldBinaryIntegerOperation<Unsigned>(
1856 [](uint64_t a, uint64_t b) { return (b != 0 ? a / b : 0); })));
1857 rules_[spv::Op::OpSDiv].push_back(FoldBinaryOp(
1858 FoldBinaryIntegerOperation<Signed>([](uint64_t a, uint64_t b) {
1859 return (b != 0 ? static_cast<uint64_t>(static_cast<int64_t>(a) /
1860 static_cast<int64_t>(b))
1861 : 0);
1862 })));
1863 rules_[spv::Op::OpUMod].push_back(
1864 FoldBinaryOp(FoldBinaryIntegerOperation<Unsigned>(
1865 [](uint64_t a, uint64_t b) { return (b != 0 ? a % b : 0); })));
1866
1867 rules_[spv::Op::OpSRem].push_back(FoldBinaryOp(
1868 FoldBinaryIntegerOperation<Signed>([](uint64_t a, uint64_t b) {
1869 return (b != 0 ? static_cast<uint64_t>(static_cast<int64_t>(a) %
1870 static_cast<int64_t>(b))
1871 : 0);
1872 })));
1873
1874 rules_[spv::Op::OpSMod].push_back(FoldBinaryOp(
1875 FoldBinaryIntegerOperation<Signed>([](uint64_t a, uint64_t b) {
1876 if (b == 0) return static_cast<uint64_t>(0ull);
1877
1878 int64_t signed_a = static_cast<int64_t>(a);
1879 int64_t signed_b = static_cast<int64_t>(b);
1880 int64_t result = signed_a % signed_b;
1881 if ((signed_b < 0) != (result < 0)) result += signed_b;
1882 return static_cast<uint64_t>(result);
1883 })));
1884
1885 // Add rules for GLSLstd450
1886 FeatureManager* feature_manager = context_->get_feature_mgr();
1887 uint32_t ext_inst_glslstd450_id =
1888 feature_manager->GetExtInstImportId_GLSLstd450();
1889 if (ext_inst_glslstd450_id != 0) {
1890 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMix}].push_back(FoldFMix());
1891 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SMin}].push_back(
1892 FoldFPBinaryOp(FoldMin));
1893 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UMin}].push_back(
1894 FoldFPBinaryOp(FoldMin));
1895 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMin}].push_back(
1896 FoldFPBinaryOp(FoldMin));
1897 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SMax}].push_back(
1898 FoldFPBinaryOp(FoldMax));
1899 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UMax}].push_back(
1900 FoldFPBinaryOp(FoldMax));
1901 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMax}].push_back(
1902 FoldFPBinaryOp(FoldMax));
1903 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UClamp}].push_back(
1904 FoldClamp1);
1905 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UClamp}].push_back(
1906 FoldClamp2);
1907 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UClamp}].push_back(
1908 FoldClamp3);
1909 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SClamp}].push_back(
1910 FoldClamp1);
1911 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SClamp}].push_back(
1912 FoldClamp2);
1913 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SClamp}].push_back(
1914 FoldClamp3);
1915 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FClamp}].push_back(
1916 FoldClamp1);
1917 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FClamp}].push_back(
1918 FoldClamp2);
1919 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FClamp}].push_back(
1920 FoldClamp3);
1921 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Sin}].push_back(
1922 FoldFPUnaryOp(FoldFTranscendentalUnary(std::sin)));
1923 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Cos}].push_back(
1924 FoldFPUnaryOp(FoldFTranscendentalUnary(std::cos)));
1925 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Tan}].push_back(
1926 FoldFPUnaryOp(FoldFTranscendentalUnary(std::tan)));
1927 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Asin}].push_back(
1928 FoldFPUnaryOp(FoldFTranscendentalUnary(std::asin)));
1929 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Acos}].push_back(
1930 FoldFPUnaryOp(FoldFTranscendentalUnary(std::acos)));
1931 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Atan}].push_back(
1932 FoldFPUnaryOp(FoldFTranscendentalUnary(std::atan)));
1933 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Exp}].push_back(
1934 FoldFPUnaryOp(FoldFTranscendentalUnary(std::exp)));
1935 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Log}].push_back(
1936 FoldFPUnaryOp(FoldFTranscendentalUnary(std::log)));
1937
1938 #ifdef __ANDROID__
1939 // Android NDK r15c targeting ABI 15 doesn't have full support for C++11
1940 // (no std::exp2/log2). ::exp2 is available from C99 but ::log2 isn't
1941 // available up until ABI 18 so we use a shim
1942 auto log2_shim = [](double v) -> double { return log(v) / log(2.0); };
1943 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Exp2}].push_back(
1944 FoldFPUnaryOp(FoldFTranscendentalUnary(::exp2)));
1945 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Log2}].push_back(
1946 FoldFPUnaryOp(FoldFTranscendentalUnary(log2_shim)));
1947 #else
1948 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Exp2}].push_back(
1949 FoldFPUnaryOp(FoldFTranscendentalUnary(std::exp2)));
1950 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Log2}].push_back(
1951 FoldFPUnaryOp(FoldFTranscendentalUnary(std::log2)));
1952 #endif
1953
1954 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Sqrt}].push_back(
1955 FoldFPUnaryOp(FoldFTranscendentalUnary(std::sqrt)));
1956 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Atan2}].push_back(
1957 FoldFPBinaryOp(FoldFTranscendentalBinary(std::atan2)));
1958 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Pow}].push_back(
1959 FoldFPBinaryOp(FoldFTranscendentalBinary(std::pow)));
1960 }
1961 }
1962 } // namespace opt
1963 } // namespace spvtools
1964