1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
#include "source/opt/fold.h"

#include <cassert>
#include <cstdint>
#include <functional>
#include <limits>
#include <vector>

#include "source/opt/const_folding_rules.h"
#include "source/opt/def_use_manager.h"
#include "source/opt/folding_rules.h"
#include "source/opt/ir_builder.h"
#include "source/opt/ir_context.h"
26
27 namespace spvtools {
28 namespace opt {
// Fallback definitions of the fixed-width integer limits, used only when
// <cstdint> did not already provide them. Macros are not scoped by
// namespaces, so they are defined at file level rather than inside an
// anonymous namespace (where they would appear scoped but would not be).
#ifndef INT32_MIN
// Spelled as (-INT32_MAX - 1): the literal 2147483648 does not fit in a
// 32-bit int, so (-2147483648) would be the negation of a wider (or, under
// older literal-typing rules, unsigned) value rather than an int.
#define INT32_MIN (-2147483647 - 1)
#endif

#ifndef INT32_MAX
#define INT32_MAX 2147483647
#endif

#ifndef UINT32_MAX
#define UINT32_MAX 0xffffffffu /* 4294967295U */
#endif
44
UnaryOperate(SpvOp opcode,uint32_t operand) const45 uint32_t InstructionFolder::UnaryOperate(SpvOp opcode, uint32_t operand) const {
46 switch (opcode) {
47 // Arthimetics
48 case SpvOp::SpvOpSNegate: {
49 int32_t s_operand = static_cast<int32_t>(operand);
50 if (s_operand == std::numeric_limits<int32_t>::min()) {
51 return s_operand;
52 }
53 return -s_operand;
54 }
55 case SpvOp::SpvOpNot:
56 return ~operand;
57 case SpvOp::SpvOpLogicalNot:
58 return !static_cast<bool>(operand);
59 case SpvOp::SpvOpUConvert:
60 return operand;
61 case SpvOp::SpvOpSConvert:
62 return operand;
63 default:
64 assert(false &&
65 "Unsupported unary operation for OpSpecConstantOp instruction");
66 return 0u;
67 }
68 }
69
BinaryOperate(SpvOp opcode,uint32_t a,uint32_t b) const70 uint32_t InstructionFolder::BinaryOperate(SpvOp opcode, uint32_t a,
71 uint32_t b) const {
72 switch (opcode) {
73 // Arthimetics
74 case SpvOp::SpvOpIAdd:
75 return a + b;
76 case SpvOp::SpvOpISub:
77 return a - b;
78 case SpvOp::SpvOpIMul:
79 return a * b;
80 case SpvOp::SpvOpUDiv:
81 if (b != 0) {
82 return a / b;
83 } else {
84 // Dividing by 0 is undefined, so we will just pick 0.
85 return 0;
86 }
87 case SpvOp::SpvOpSDiv:
88 if (b != 0u) {
89 return (static_cast<int32_t>(a)) / (static_cast<int32_t>(b));
90 } else {
91 // Dividing by 0 is undefined, so we will just pick 0.
92 return 0;
93 }
94 case SpvOp::SpvOpSRem: {
95 // The sign of non-zero result comes from the first operand: a. This is
96 // guaranteed by C++11 rules for integer division operator. The division
97 // result is rounded toward zero, so the result of '%' has the sign of
98 // the first operand.
99 if (b != 0u) {
100 return static_cast<int32_t>(a) % static_cast<int32_t>(b);
101 } else {
102 // Remainder when dividing with 0 is undefined, so we will just pick 0.
103 return 0;
104 }
105 }
106 case SpvOp::SpvOpSMod: {
107 // The sign of non-zero result comes from the second operand: b
108 if (b != 0u) {
109 int32_t rem = BinaryOperate(SpvOp::SpvOpSRem, a, b);
110 int32_t b_prim = static_cast<int32_t>(b);
111 return (rem + b_prim) % b_prim;
112 } else {
113 // Mod with 0 is undefined, so we will just pick 0.
114 return 0;
115 }
116 }
117 case SpvOp::SpvOpUMod:
118 if (b != 0u) {
119 return (a % b);
120 } else {
121 // Mod with 0 is undefined, so we will just pick 0.
122 return 0;
123 }
124
125 // Shifting
126 case SpvOp::SpvOpShiftRightLogical:
127 if (b >= 32) {
128 // This is undefined behaviour when |b| > 32. Choose 0 for consistency.
129 // When |b| == 32, doing the shift in C++ in undefined, but the result
130 // will be 0, so just return that value.
131 return 0;
132 }
133 return a >> b;
134 case SpvOp::SpvOpShiftRightArithmetic:
135 if (b > 32) {
136 // This is undefined behaviour. Choose 0 for consistency.
137 return 0;
138 }
139 if (b == 32) {
140 // Doing the shift in C++ is undefined, but the result is defined in the
141 // spir-v spec. Find that value another way.
142 if (static_cast<int32_t>(a) >= 0) {
143 return 0;
144 } else {
145 return static_cast<uint32_t>(-1);
146 }
147 }
148 return (static_cast<int32_t>(a)) >> b;
149 case SpvOp::SpvOpShiftLeftLogical:
150 if (b >= 32) {
151 // This is undefined behaviour when |b| > 32. Choose 0 for consistency.
152 // When |b| == 32, doing the shift in C++ in undefined, but the result
153 // will be 0, so just return that value.
154 return 0;
155 }
156 return a << b;
157
158 // Bitwise operations
159 case SpvOp::SpvOpBitwiseOr:
160 return a | b;
161 case SpvOp::SpvOpBitwiseAnd:
162 return a & b;
163 case SpvOp::SpvOpBitwiseXor:
164 return a ^ b;
165
166 // Logical
167 case SpvOp::SpvOpLogicalEqual:
168 return (static_cast<bool>(a)) == (static_cast<bool>(b));
169 case SpvOp::SpvOpLogicalNotEqual:
170 return (static_cast<bool>(a)) != (static_cast<bool>(b));
171 case SpvOp::SpvOpLogicalOr:
172 return (static_cast<bool>(a)) || (static_cast<bool>(b));
173 case SpvOp::SpvOpLogicalAnd:
174 return (static_cast<bool>(a)) && (static_cast<bool>(b));
175
176 // Comparison
177 case SpvOp::SpvOpIEqual:
178 return a == b;
179 case SpvOp::SpvOpINotEqual:
180 return a != b;
181 case SpvOp::SpvOpULessThan:
182 return a < b;
183 case SpvOp::SpvOpSLessThan:
184 return (static_cast<int32_t>(a)) < (static_cast<int32_t>(b));
185 case SpvOp::SpvOpUGreaterThan:
186 return a > b;
187 case SpvOp::SpvOpSGreaterThan:
188 return (static_cast<int32_t>(a)) > (static_cast<int32_t>(b));
189 case SpvOp::SpvOpULessThanEqual:
190 return a <= b;
191 case SpvOp::SpvOpSLessThanEqual:
192 return (static_cast<int32_t>(a)) <= (static_cast<int32_t>(b));
193 case SpvOp::SpvOpUGreaterThanEqual:
194 return a >= b;
195 case SpvOp::SpvOpSGreaterThanEqual:
196 return (static_cast<int32_t>(a)) >= (static_cast<int32_t>(b));
197 default:
198 assert(false &&
199 "Unsupported binary operation for OpSpecConstantOp instruction");
200 return 0u;
201 }
202 }
203
TernaryOperate(SpvOp opcode,uint32_t a,uint32_t b,uint32_t c) const204 uint32_t InstructionFolder::TernaryOperate(SpvOp opcode, uint32_t a, uint32_t b,
205 uint32_t c) const {
206 switch (opcode) {
207 case SpvOp::SpvOpSelect:
208 return (static_cast<bool>(a)) ? b : c;
209 default:
210 assert(false &&
211 "Unsupported ternary operation for OpSpecConstantOp instruction");
212 return 0u;
213 }
214 }
215
OperateWords(SpvOp opcode,const std::vector<uint32_t> & operand_words) const216 uint32_t InstructionFolder::OperateWords(
217 SpvOp opcode, const std::vector<uint32_t>& operand_words) const {
218 switch (operand_words.size()) {
219 case 1:
220 return UnaryOperate(opcode, operand_words.front());
221 case 2:
222 return BinaryOperate(opcode, operand_words.front(), operand_words.back());
223 case 3:
224 return TernaryOperate(opcode, operand_words[0], operand_words[1],
225 operand_words[2]);
226 default:
227 assert(false && "Invalid number of operands");
228 return 0;
229 }
230 }
231
FoldInstructionInternal(Instruction * inst) const232 bool InstructionFolder::FoldInstructionInternal(Instruction* inst) const {
233 auto identity_map = [](uint32_t id) { return id; };
234 Instruction* folded_inst = FoldInstructionToConstant(inst, identity_map);
235 if (folded_inst != nullptr) {
236 inst->SetOpcode(SpvOpCopyObject);
237 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {folded_inst->result_id()}}});
238 return true;
239 }
240
241 analysis::ConstantManager* const_manager = context_->get_constant_mgr();
242 std::vector<const analysis::Constant*> constants =
243 const_manager->GetOperandConstants(inst);
244
245 for (const FoldingRule& rule :
246 GetFoldingRules().GetRulesForInstruction(inst)) {
247 if (rule(context_, inst, constants)) {
248 return true;
249 }
250 }
251 return false;
252 }
253
254 // Returns the result of performing an operation on scalar constant operands.
255 // This function extracts the operand values as 32 bit words and returns the
256 // result in 32 bit word. Scalar constants with longer than 32-bit width are
257 // not accepted in this function.
FoldScalars(SpvOp opcode,const std::vector<const analysis::Constant * > & operands) const258 uint32_t InstructionFolder::FoldScalars(
259 SpvOp opcode,
260 const std::vector<const analysis::Constant*>& operands) const {
261 assert(IsFoldableOpcode(opcode) &&
262 "Unhandled instruction opcode in FoldScalars");
263 std::vector<uint32_t> operand_values_in_raw_words;
264 for (const auto& operand : operands) {
265 if (const analysis::ScalarConstant* scalar = operand->AsScalarConstant()) {
266 const auto& scalar_words = scalar->words();
267 assert(scalar_words.size() == 1 &&
268 "Scalar constants with longer than 32-bit width are not allowed "
269 "in FoldScalars()");
270 operand_values_in_raw_words.push_back(scalar_words.front());
271 } else if (operand->AsNullConstant()) {
272 operand_values_in_raw_words.push_back(0u);
273 } else {
274 assert(false &&
275 "FoldScalars() only accepts ScalarConst or NullConst type of "
276 "constant");
277 }
278 }
279 return OperateWords(opcode, operand_values_in_raw_words);
280 }
281
FoldBinaryIntegerOpToConstant(Instruction * inst,const std::function<uint32_t (uint32_t)> & id_map,uint32_t * result) const282 bool InstructionFolder::FoldBinaryIntegerOpToConstant(
283 Instruction* inst, const std::function<uint32_t(uint32_t)>& id_map,
284 uint32_t* result) const {
285 SpvOp opcode = inst->opcode();
286 analysis::ConstantManager* const_manger = context_->get_constant_mgr();
287
288 uint32_t ids[2];
289 const analysis::IntConstant* constants[2];
290 for (uint32_t i = 0; i < 2; i++) {
291 const Operand* operand = &inst->GetInOperand(i);
292 if (operand->type != SPV_OPERAND_TYPE_ID) {
293 return false;
294 }
295 ids[i] = id_map(operand->words[0]);
296 const analysis::Constant* constant =
297 const_manger->FindDeclaredConstant(ids[i]);
298 constants[i] = (constant != nullptr ? constant->AsIntConstant() : nullptr);
299 }
300
301 switch (opcode) {
302 // Arthimetics
303 case SpvOp::SpvOpIMul:
304 for (uint32_t i = 0; i < 2; i++) {
305 if (constants[i] != nullptr && constants[i]->IsZero()) {
306 *result = 0;
307 return true;
308 }
309 }
310 break;
311 case SpvOp::SpvOpUDiv:
312 case SpvOp::SpvOpSDiv:
313 case SpvOp::SpvOpSRem:
314 case SpvOp::SpvOpSMod:
315 case SpvOp::SpvOpUMod:
316 // This changes undefined behaviour (ie divide by 0) into a 0.
317 for (uint32_t i = 0; i < 2; i++) {
318 if (constants[i] != nullptr && constants[i]->IsZero()) {
319 *result = 0;
320 return true;
321 }
322 }
323 break;
324
325 // Shifting
326 case SpvOp::SpvOpShiftRightLogical:
327 case SpvOp::SpvOpShiftLeftLogical:
328 if (constants[1] != nullptr) {
329 // When shifting by a value larger than the size of the result, the
330 // result is undefined. We are setting the undefined behaviour to a
331 // result of 0. If the shift amount is the same as the size of the
332 // result, then the result is defined, and it 0.
333 uint32_t shift_amount = constants[1]->GetU32BitValue();
334 if (shift_amount >= 32) {
335 *result = 0;
336 return true;
337 }
338 }
339 break;
340
341 // Bitwise operations
342 case SpvOp::SpvOpBitwiseOr:
343 for (uint32_t i = 0; i < 2; i++) {
344 if (constants[i] != nullptr) {
345 // TODO: Change the mask against a value based on the bit width of the
346 // instruction result type. This way we can handle say 16-bit values
347 // as well.
348 uint32_t mask = constants[i]->GetU32BitValue();
349 if (mask == 0xFFFFFFFF) {
350 *result = 0xFFFFFFFF;
351 return true;
352 }
353 }
354 }
355 break;
356 case SpvOp::SpvOpBitwiseAnd:
357 for (uint32_t i = 0; i < 2; i++) {
358 if (constants[i] != nullptr) {
359 if (constants[i]->IsZero()) {
360 *result = 0;
361 return true;
362 }
363 }
364 }
365 break;
366
367 // Comparison
368 case SpvOp::SpvOpULessThan:
369 if (constants[0] != nullptr &&
370 constants[0]->GetU32BitValue() == UINT32_MAX) {
371 *result = false;
372 return true;
373 }
374 if (constants[1] != nullptr && constants[1]->GetU32BitValue() == 0) {
375 *result = false;
376 return true;
377 }
378 break;
379 case SpvOp::SpvOpSLessThan:
380 if (constants[0] != nullptr &&
381 constants[0]->GetS32BitValue() == INT32_MAX) {
382 *result = false;
383 return true;
384 }
385 if (constants[1] != nullptr &&
386 constants[1]->GetS32BitValue() == INT32_MIN) {
387 *result = false;
388 return true;
389 }
390 break;
391 case SpvOp::SpvOpUGreaterThan:
392 if (constants[0] != nullptr && constants[0]->IsZero()) {
393 *result = false;
394 return true;
395 }
396 if (constants[1] != nullptr &&
397 constants[1]->GetU32BitValue() == UINT32_MAX) {
398 *result = false;
399 return true;
400 }
401 break;
402 case SpvOp::SpvOpSGreaterThan:
403 if (constants[0] != nullptr &&
404 constants[0]->GetS32BitValue() == INT32_MIN) {
405 *result = false;
406 return true;
407 }
408 if (constants[1] != nullptr &&
409 constants[1]->GetS32BitValue() == INT32_MAX) {
410 *result = false;
411 return true;
412 }
413 break;
414 case SpvOp::SpvOpULessThanEqual:
415 if (constants[0] != nullptr && constants[0]->IsZero()) {
416 *result = true;
417 return true;
418 }
419 if (constants[1] != nullptr &&
420 constants[1]->GetU32BitValue() == UINT32_MAX) {
421 *result = true;
422 return true;
423 }
424 break;
425 case SpvOp::SpvOpSLessThanEqual:
426 if (constants[0] != nullptr &&
427 constants[0]->GetS32BitValue() == INT32_MIN) {
428 *result = true;
429 return true;
430 }
431 if (constants[1] != nullptr &&
432 constants[1]->GetS32BitValue() == INT32_MAX) {
433 *result = true;
434 return true;
435 }
436 break;
437 case SpvOp::SpvOpUGreaterThanEqual:
438 if (constants[0] != nullptr &&
439 constants[0]->GetU32BitValue() == UINT32_MAX) {
440 *result = true;
441 return true;
442 }
443 if (constants[1] != nullptr && constants[1]->GetU32BitValue() == 0) {
444 *result = true;
445 return true;
446 }
447 break;
448 case SpvOp::SpvOpSGreaterThanEqual:
449 if (constants[0] != nullptr &&
450 constants[0]->GetS32BitValue() == INT32_MAX) {
451 *result = true;
452 return true;
453 }
454 if (constants[1] != nullptr &&
455 constants[1]->GetS32BitValue() == INT32_MIN) {
456 *result = true;
457 return true;
458 }
459 break;
460 default:
461 break;
462 }
463 return false;
464 }
465
FoldBinaryBooleanOpToConstant(Instruction * inst,const std::function<uint32_t (uint32_t)> & id_map,uint32_t * result) const466 bool InstructionFolder::FoldBinaryBooleanOpToConstant(
467 Instruction* inst, const std::function<uint32_t(uint32_t)>& id_map,
468 uint32_t* result) const {
469 SpvOp opcode = inst->opcode();
470 analysis::ConstantManager* const_manger = context_->get_constant_mgr();
471
472 uint32_t ids[2];
473 const analysis::BoolConstant* constants[2];
474 for (uint32_t i = 0; i < 2; i++) {
475 const Operand* operand = &inst->GetInOperand(i);
476 if (operand->type != SPV_OPERAND_TYPE_ID) {
477 return false;
478 }
479 ids[i] = id_map(operand->words[0]);
480 const analysis::Constant* constant =
481 const_manger->FindDeclaredConstant(ids[i]);
482 constants[i] = (constant != nullptr ? constant->AsBoolConstant() : nullptr);
483 }
484
485 switch (opcode) {
486 // Logical
487 case SpvOp::SpvOpLogicalOr:
488 for (uint32_t i = 0; i < 2; i++) {
489 if (constants[i] != nullptr) {
490 if (constants[i]->value()) {
491 *result = true;
492 return true;
493 }
494 }
495 }
496 break;
497 case SpvOp::SpvOpLogicalAnd:
498 for (uint32_t i = 0; i < 2; i++) {
499 if (constants[i] != nullptr) {
500 if (!constants[i]->value()) {
501 *result = false;
502 return true;
503 }
504 }
505 }
506 break;
507
508 default:
509 break;
510 }
511 return false;
512 }
513
FoldIntegerOpToConstant(Instruction * inst,const std::function<uint32_t (uint32_t)> & id_map,uint32_t * result) const514 bool InstructionFolder::FoldIntegerOpToConstant(
515 Instruction* inst, const std::function<uint32_t(uint32_t)>& id_map,
516 uint32_t* result) const {
517 assert(IsFoldableOpcode(inst->opcode()) &&
518 "Unhandled instruction opcode in FoldScalars");
519 switch (inst->NumInOperands()) {
520 case 2:
521 return FoldBinaryIntegerOpToConstant(inst, id_map, result) ||
522 FoldBinaryBooleanOpToConstant(inst, id_map, result);
523 default:
524 return false;
525 }
526 }
527
FoldVectors(SpvOp opcode,uint32_t num_dims,const std::vector<const analysis::Constant * > & operands) const528 std::vector<uint32_t> InstructionFolder::FoldVectors(
529 SpvOp opcode, uint32_t num_dims,
530 const std::vector<const analysis::Constant*>& operands) const {
531 assert(IsFoldableOpcode(opcode) &&
532 "Unhandled instruction opcode in FoldVectors");
533 std::vector<uint32_t> result;
534 for (uint32_t d = 0; d < num_dims; d++) {
535 std::vector<uint32_t> operand_values_for_one_dimension;
536 for (const auto& operand : operands) {
537 if (const analysis::VectorConstant* vector_operand =
538 operand->AsVectorConstant()) {
539 // Extract the raw value of the scalar component constants
540 // in 32-bit words here. The reason of not using FoldScalars() here
541 // is that we do not create temporary null constants as components
542 // when the vector operand is a NullConstant because Constant creation
543 // may need extra checks for the validity and that is not manageed in
544 // here.
545 if (const analysis::ScalarConstant* scalar_component =
546 vector_operand->GetComponents().at(d)->AsScalarConstant()) {
547 const auto& scalar_words = scalar_component->words();
548 assert(
549 scalar_words.size() == 1 &&
550 "Vector components with longer than 32-bit width are not allowed "
551 "in FoldVectors()");
552 operand_values_for_one_dimension.push_back(scalar_words.front());
553 } else if (operand->AsNullConstant()) {
554 operand_values_for_one_dimension.push_back(0u);
555 } else {
556 assert(false &&
557 "VectorConst should only has ScalarConst or NullConst as "
558 "components");
559 }
560 } else if (operand->AsNullConstant()) {
561 operand_values_for_one_dimension.push_back(0u);
562 } else {
563 assert(false &&
564 "FoldVectors() only accepts VectorConst or NullConst type of "
565 "constant");
566 }
567 }
568 result.push_back(OperateWords(opcode, operand_values_for_one_dimension));
569 }
570 return result;
571 }
572
IsFoldableOpcode(SpvOp opcode) const573 bool InstructionFolder::IsFoldableOpcode(SpvOp opcode) const {
574 // NOTE: Extend to more opcodes as new cases are handled in the folder
575 // functions.
576 switch (opcode) {
577 case SpvOp::SpvOpBitwiseAnd:
578 case SpvOp::SpvOpBitwiseOr:
579 case SpvOp::SpvOpBitwiseXor:
580 case SpvOp::SpvOpIAdd:
581 case SpvOp::SpvOpIEqual:
582 case SpvOp::SpvOpIMul:
583 case SpvOp::SpvOpINotEqual:
584 case SpvOp::SpvOpISub:
585 case SpvOp::SpvOpLogicalAnd:
586 case SpvOp::SpvOpLogicalEqual:
587 case SpvOp::SpvOpLogicalNot:
588 case SpvOp::SpvOpLogicalNotEqual:
589 case SpvOp::SpvOpLogicalOr:
590 case SpvOp::SpvOpNot:
591 case SpvOp::SpvOpSDiv:
592 case SpvOp::SpvOpSelect:
593 case SpvOp::SpvOpSGreaterThan:
594 case SpvOp::SpvOpSGreaterThanEqual:
595 case SpvOp::SpvOpShiftLeftLogical:
596 case SpvOp::SpvOpShiftRightArithmetic:
597 case SpvOp::SpvOpShiftRightLogical:
598 case SpvOp::SpvOpSLessThan:
599 case SpvOp::SpvOpSLessThanEqual:
600 case SpvOp::SpvOpSMod:
601 case SpvOp::SpvOpSNegate:
602 case SpvOp::SpvOpSRem:
603 case SpvOp::SpvOpSConvert:
604 case SpvOp::SpvOpUConvert:
605 case SpvOp::SpvOpUDiv:
606 case SpvOp::SpvOpUGreaterThan:
607 case SpvOp::SpvOpUGreaterThanEqual:
608 case SpvOp::SpvOpULessThan:
609 case SpvOp::SpvOpULessThanEqual:
610 case SpvOp::SpvOpUMod:
611 return true;
612 default:
613 return false;
614 }
615 }
616
IsFoldableConstant(const analysis::Constant * cst) const617 bool InstructionFolder::IsFoldableConstant(
618 const analysis::Constant* cst) const {
619 // Currently supported constants are 32-bit values or null constants.
620 if (const analysis::ScalarConstant* scalar = cst->AsScalarConstant())
621 return scalar->words().size() == 1;
622 else
623 return cst->AsNullConstant() != nullptr;
624 }
625
FoldInstructionToConstant(Instruction * inst,std::function<uint32_t (uint32_t)> id_map) const626 Instruction* InstructionFolder::FoldInstructionToConstant(
627 Instruction* inst, std::function<uint32_t(uint32_t)> id_map) const {
628 analysis::ConstantManager* const_mgr = context_->get_constant_mgr();
629
630 if (!inst->IsFoldableByFoldScalar() &&
631 !GetConstantFoldingRules().HasFoldingRule(inst)) {
632 return nullptr;
633 }
634 // Collect the values of the constant parameters.
635 std::vector<const analysis::Constant*> constants;
636 bool missing_constants = false;
637 inst->ForEachInId([&constants, &missing_constants, const_mgr,
638 &id_map](uint32_t* op_id) {
639 uint32_t id = id_map(*op_id);
640 const analysis::Constant* const_op = const_mgr->FindDeclaredConstant(id);
641 if (!const_op) {
642 constants.push_back(nullptr);
643 missing_constants = true;
644 } else {
645 constants.push_back(const_op);
646 }
647 });
648
649 const analysis::Constant* folded_const = nullptr;
650 for (auto rule : GetConstantFoldingRules().GetRulesForInstruction(inst)) {
651 folded_const = rule(context_, inst, constants);
652 if (folded_const != nullptr) {
653 Instruction* const_inst =
654 const_mgr->GetDefiningInstruction(folded_const, inst->type_id());
655 if (const_inst == nullptr) {
656 return nullptr;
657 }
658 assert(const_inst->type_id() == inst->type_id());
659 // May be a new instruction that needs to be analysed.
660 context_->UpdateDefUse(const_inst);
661 return const_inst;
662 }
663 }
664
665 uint32_t result_val = 0;
666 bool successful = false;
667 // If all parameters are constant, fold the instruction to a constant.
668 if (!missing_constants && inst->IsFoldableByFoldScalar()) {
669 result_val = FoldScalars(inst->opcode(), constants);
670 successful = true;
671 }
672
673 if (!successful && inst->IsFoldableByFoldScalar()) {
674 successful = FoldIntegerOpToConstant(inst, id_map, &result_val);
675 }
676
677 if (successful) {
678 const analysis::Constant* result_const =
679 const_mgr->GetConstant(const_mgr->GetType(inst), {result_val});
680 Instruction* folded_inst =
681 const_mgr->GetDefiningInstruction(result_const, inst->type_id());
682 return folded_inst;
683 }
684 return nullptr;
685 }
686
IsFoldableType(Instruction * type_inst) const687 bool InstructionFolder::IsFoldableType(Instruction* type_inst) const {
688 // Support 32-bit integers.
689 if (type_inst->opcode() == SpvOpTypeInt) {
690 return type_inst->GetSingleWordInOperand(0) == 32;
691 }
692 // Support booleans.
693 if (type_inst->opcode() == SpvOpTypeBool) {
694 return true;
695 }
696 // Nothing else yet.
697 return false;
698 }
699
FoldInstruction(Instruction * inst) const700 bool InstructionFolder::FoldInstruction(Instruction* inst) const {
701 bool modified = false;
702 Instruction* folded_inst(inst);
703 while (folded_inst->opcode() != SpvOpCopyObject &&
704 FoldInstructionInternal(&*folded_inst)) {
705 modified = true;
706 }
707 return modified;
708 }
709
710 } // namespace opt
711 } // namespace spvtools
712