1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/opt/fold.h"
16
#include <cassert>
#include <cstdint>
#include <limits>
#include <vector>
20
21 #include "source/opt/const_folding_rules.h"
22 #include "source/opt/def_use_manager.h"
23 #include "source/opt/folding_rules.h"
24 #include "source/opt/ir_context.h"
25
26 namespace spvtools {
27 namespace opt {
namespace {

// Fallback definitions for toolchains whose <cstdint> does not provide the
// fixed-width limit macros. Note that macros ignore namespaces; the anonymous
// namespace only groups them visually.

#ifndef INT32_MIN
// Spelled as (-2147483647 - 1): the literal 2147483648 does not fit in an int,
// so "-2147483648" would have a wider (or, on some older compilers, unsigned)
// type instead of int.
#define INT32_MIN (-2147483647 - 1)
#endif

#ifndef INT32_MAX
#define INT32_MAX 2147483647
#endif

#ifndef UINT32_MAX
#define UINT32_MAX 0xffffffffu /* 4294967295U */
#endif

}  // namespace
43
UnaryOperate(spv::Op opcode,uint32_t operand) const44 uint32_t InstructionFolder::UnaryOperate(spv::Op opcode,
45 uint32_t operand) const {
46 switch (opcode) {
47 // Arthimetics
48 case spv::Op::OpSNegate: {
49 int32_t s_operand = static_cast<int32_t>(operand);
50 if (s_operand == std::numeric_limits<int32_t>::min()) {
51 return s_operand;
52 }
53 return -s_operand;
54 }
55 case spv::Op::OpNot:
56 return ~operand;
57 case spv::Op::OpLogicalNot:
58 return !static_cast<bool>(operand);
59 case spv::Op::OpUConvert:
60 return operand;
61 case spv::Op::OpSConvert:
62 return operand;
63 default:
64 assert(false &&
65 "Unsupported unary operation for OpSpecConstantOp instruction");
66 return 0u;
67 }
68 }
69
BinaryOperate(spv::Op opcode,uint32_t a,uint32_t b) const70 uint32_t InstructionFolder::BinaryOperate(spv::Op opcode, uint32_t a,
71 uint32_t b) const {
72 switch (opcode) {
73 // Arthimetics
74 case spv::Op::OpIAdd:
75 return a + b;
76 case spv::Op::OpISub:
77 return a - b;
78 case spv::Op::OpIMul:
79 return a * b;
80 case spv::Op::OpUDiv:
81 if (b != 0) {
82 return a / b;
83 } else {
84 // Dividing by 0 is undefined, so we will just pick 0.
85 return 0;
86 }
87 case spv::Op::OpSDiv:
88 if (b != 0u) {
89 return (static_cast<int32_t>(a)) / (static_cast<int32_t>(b));
90 } else {
91 // Dividing by 0 is undefined, so we will just pick 0.
92 return 0;
93 }
94 case spv::Op::OpSRem: {
95 // The sign of non-zero result comes from the first operand: a. This is
96 // guaranteed by C++11 rules for integer division operator. The division
97 // result is rounded toward zero, so the result of '%' has the sign of
98 // the first operand.
99 if (b != 0u) {
100 return static_cast<int32_t>(a) % static_cast<int32_t>(b);
101 } else {
102 // Remainder when dividing with 0 is undefined, so we will just pick 0.
103 return 0;
104 }
105 }
106 case spv::Op::OpSMod: {
107 // The sign of non-zero result comes from the second operand: b
108 if (b != 0u) {
109 int32_t rem = BinaryOperate(spv::Op::OpSRem, a, b);
110 int32_t b_prim = static_cast<int32_t>(b);
111 return (rem + b_prim) % b_prim;
112 } else {
113 // Mod with 0 is undefined, so we will just pick 0.
114 return 0;
115 }
116 }
117 case spv::Op::OpUMod:
118 if (b != 0u) {
119 return (a % b);
120 } else {
121 // Mod with 0 is undefined, so we will just pick 0.
122 return 0;
123 }
124
125 // Shifting
126 case spv::Op::OpShiftRightLogical:
127 if (b >= 32) {
128 // This is undefined behaviour when |b| > 32. Choose 0 for consistency.
129 // When |b| == 32, doing the shift in C++ in undefined, but the result
130 // will be 0, so just return that value.
131 return 0;
132 }
133 return a >> b;
134 case spv::Op::OpShiftRightArithmetic:
135 if (b > 32) {
136 // This is undefined behaviour. Choose 0 for consistency.
137 return 0;
138 }
139 if (b == 32) {
140 // Doing the shift in C++ is undefined, but the result is defined in the
141 // spir-v spec. Find that value another way.
142 if (static_cast<int32_t>(a) >= 0) {
143 return 0;
144 } else {
145 return static_cast<uint32_t>(-1);
146 }
147 }
148 return (static_cast<int32_t>(a)) >> b;
149 case spv::Op::OpShiftLeftLogical:
150 if (b >= 32) {
151 // This is undefined behaviour when |b| > 32. Choose 0 for consistency.
152 // When |b| == 32, doing the shift in C++ in undefined, but the result
153 // will be 0, so just return that value.
154 return 0;
155 }
156 return a << b;
157
158 // Bitwise operations
159 case spv::Op::OpBitwiseOr:
160 return a | b;
161 case spv::Op::OpBitwiseAnd:
162 return a & b;
163 case spv::Op::OpBitwiseXor:
164 return a ^ b;
165
166 // Logical
167 case spv::Op::OpLogicalEqual:
168 return (static_cast<bool>(a)) == (static_cast<bool>(b));
169 case spv::Op::OpLogicalNotEqual:
170 return (static_cast<bool>(a)) != (static_cast<bool>(b));
171 case spv::Op::OpLogicalOr:
172 return (static_cast<bool>(a)) || (static_cast<bool>(b));
173 case spv::Op::OpLogicalAnd:
174 return (static_cast<bool>(a)) && (static_cast<bool>(b));
175
176 // Comparison
177 case spv::Op::OpIEqual:
178 return a == b;
179 case spv::Op::OpINotEqual:
180 return a != b;
181 case spv::Op::OpULessThan:
182 return a < b;
183 case spv::Op::OpSLessThan:
184 return (static_cast<int32_t>(a)) < (static_cast<int32_t>(b));
185 case spv::Op::OpUGreaterThan:
186 return a > b;
187 case spv::Op::OpSGreaterThan:
188 return (static_cast<int32_t>(a)) > (static_cast<int32_t>(b));
189 case spv::Op::OpULessThanEqual:
190 return a <= b;
191 case spv::Op::OpSLessThanEqual:
192 return (static_cast<int32_t>(a)) <= (static_cast<int32_t>(b));
193 case spv::Op::OpUGreaterThanEqual:
194 return a >= b;
195 case spv::Op::OpSGreaterThanEqual:
196 return (static_cast<int32_t>(a)) >= (static_cast<int32_t>(b));
197 default:
198 assert(false &&
199 "Unsupported binary operation for OpSpecConstantOp instruction");
200 return 0u;
201 }
202 }
203
TernaryOperate(spv::Op opcode,uint32_t a,uint32_t b,uint32_t c) const204 uint32_t InstructionFolder::TernaryOperate(spv::Op opcode, uint32_t a,
205 uint32_t b, uint32_t c) const {
206 switch (opcode) {
207 case spv::Op::OpSelect:
208 return (static_cast<bool>(a)) ? b : c;
209 default:
210 assert(false &&
211 "Unsupported ternary operation for OpSpecConstantOp instruction");
212 return 0u;
213 }
214 }
215
OperateWords(spv::Op opcode,const std::vector<uint32_t> & operand_words) const216 uint32_t InstructionFolder::OperateWords(
217 spv::Op opcode, const std::vector<uint32_t>& operand_words) const {
218 switch (operand_words.size()) {
219 case 1:
220 return UnaryOperate(opcode, operand_words.front());
221 case 2:
222 return BinaryOperate(opcode, operand_words.front(), operand_words.back());
223 case 3:
224 return TernaryOperate(opcode, operand_words[0], operand_words[1],
225 operand_words[2]);
226 default:
227 assert(false && "Invalid number of operands");
228 return 0;
229 }
230 }
231
FoldInstructionInternal(Instruction * inst) const232 bool InstructionFolder::FoldInstructionInternal(Instruction* inst) const {
233 auto identity_map = [](uint32_t id) { return id; };
234 Instruction* folded_inst = FoldInstructionToConstant(inst, identity_map);
235 if (folded_inst != nullptr) {
236 inst->SetOpcode(spv::Op::OpCopyObject);
237 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {folded_inst->result_id()}}});
238 return true;
239 }
240
241 analysis::ConstantManager* const_manager = context_->get_constant_mgr();
242 std::vector<const analysis::Constant*> constants =
243 const_manager->GetOperandConstants(inst);
244
245 for (const FoldingRule& rule :
246 GetFoldingRules().GetRulesForInstruction(inst)) {
247 if (rule(context_, inst, constants)) {
248 return true;
249 }
250 }
251 return false;
252 }
253
254 // Returns the result of performing an operation on scalar constant operands.
255 // This function extracts the operand values as 32 bit words and returns the
256 // result in 32 bit word. Scalar constants with longer than 32-bit width are
257 // not accepted in this function.
FoldScalars(spv::Op opcode,const std::vector<const analysis::Constant * > & operands) const258 uint32_t InstructionFolder::FoldScalars(
259 spv::Op opcode,
260 const std::vector<const analysis::Constant*>& operands) const {
261 assert(IsFoldableOpcode(opcode) &&
262 "Unhandled instruction opcode in FoldScalars");
263 std::vector<uint32_t> operand_values_in_raw_words;
264 for (const auto& operand : operands) {
265 if (const analysis::ScalarConstant* scalar = operand->AsScalarConstant()) {
266 const auto& scalar_words = scalar->words();
267 assert(scalar_words.size() == 1 &&
268 "Scalar constants with longer than 32-bit width are not allowed "
269 "in FoldScalars()");
270 operand_values_in_raw_words.push_back(scalar_words.front());
271 } else if (operand->AsNullConstant()) {
272 operand_values_in_raw_words.push_back(0u);
273 } else {
274 assert(false &&
275 "FoldScalars() only accepts ScalarConst or NullConst type of "
276 "constant");
277 }
278 }
279 return OperateWords(opcode, operand_values_in_raw_words);
280 }
281
FoldBinaryIntegerOpToConstant(Instruction * inst,const std::function<uint32_t (uint32_t)> & id_map,uint32_t * result) const282 bool InstructionFolder::FoldBinaryIntegerOpToConstant(
283 Instruction* inst, const std::function<uint32_t(uint32_t)>& id_map,
284 uint32_t* result) const {
285 spv::Op opcode = inst->opcode();
286 analysis::ConstantManager* const_manger = context_->get_constant_mgr();
287
288 uint32_t ids[2];
289 const analysis::IntConstant* constants[2];
290 for (uint32_t i = 0; i < 2; i++) {
291 const Operand* operand = &inst->GetInOperand(i);
292 if (operand->type != SPV_OPERAND_TYPE_ID) {
293 return false;
294 }
295 ids[i] = id_map(operand->words[0]);
296 const analysis::Constant* constant =
297 const_manger->FindDeclaredConstant(ids[i]);
298 constants[i] = (constant != nullptr ? constant->AsIntConstant() : nullptr);
299 }
300
301 switch (opcode) {
302 // Arthimetics
303 case spv::Op::OpIMul:
304 for (uint32_t i = 0; i < 2; i++) {
305 if (constants[i] != nullptr && constants[i]->IsZero()) {
306 *result = 0;
307 return true;
308 }
309 }
310 break;
311 case spv::Op::OpUDiv:
312 case spv::Op::OpSDiv:
313 case spv::Op::OpSRem:
314 case spv::Op::OpSMod:
315 case spv::Op::OpUMod:
316 // This changes undefined behaviour (ie divide by 0) into a 0.
317 for (uint32_t i = 0; i < 2; i++) {
318 if (constants[i] != nullptr && constants[i]->IsZero()) {
319 *result = 0;
320 return true;
321 }
322 }
323 break;
324
325 // Shifting
326 case spv::Op::OpShiftRightLogical:
327 case spv::Op::OpShiftLeftLogical:
328 if (constants[1] != nullptr) {
329 // When shifting by a value larger than the size of the result, the
330 // result is undefined. We are setting the undefined behaviour to a
331 // result of 0. If the shift amount is the same as the size of the
332 // result, then the result is defined, and it 0.
333 uint32_t shift_amount = constants[1]->GetU32BitValue();
334 if (shift_amount >= 32) {
335 *result = 0;
336 return true;
337 }
338 }
339 break;
340
341 // Bitwise operations
342 case spv::Op::OpBitwiseOr:
343 for (uint32_t i = 0; i < 2; i++) {
344 if (constants[i] != nullptr) {
345 // TODO: Change the mask against a value based on the bit width of the
346 // instruction result type. This way we can handle say 16-bit values
347 // as well.
348 uint32_t mask = constants[i]->GetU32BitValue();
349 if (mask == 0xFFFFFFFF) {
350 *result = 0xFFFFFFFF;
351 return true;
352 }
353 }
354 }
355 break;
356 case spv::Op::OpBitwiseAnd:
357 for (uint32_t i = 0; i < 2; i++) {
358 if (constants[i] != nullptr) {
359 if (constants[i]->IsZero()) {
360 *result = 0;
361 return true;
362 }
363 }
364 }
365 break;
366
367 // Comparison
368 case spv::Op::OpULessThan:
369 if (constants[0] != nullptr &&
370 constants[0]->GetU32BitValue() == UINT32_MAX) {
371 *result = false;
372 return true;
373 }
374 if (constants[1] != nullptr && constants[1]->GetU32BitValue() == 0) {
375 *result = false;
376 return true;
377 }
378 break;
379 case spv::Op::OpSLessThan:
380 if (constants[0] != nullptr &&
381 constants[0]->GetS32BitValue() == INT32_MAX) {
382 *result = false;
383 return true;
384 }
385 if (constants[1] != nullptr &&
386 constants[1]->GetS32BitValue() == INT32_MIN) {
387 *result = false;
388 return true;
389 }
390 break;
391 case spv::Op::OpUGreaterThan:
392 if (constants[0] != nullptr && constants[0]->IsZero()) {
393 *result = false;
394 return true;
395 }
396 if (constants[1] != nullptr &&
397 constants[1]->GetU32BitValue() == UINT32_MAX) {
398 *result = false;
399 return true;
400 }
401 break;
402 case spv::Op::OpSGreaterThan:
403 if (constants[0] != nullptr &&
404 constants[0]->GetS32BitValue() == INT32_MIN) {
405 *result = false;
406 return true;
407 }
408 if (constants[1] != nullptr &&
409 constants[1]->GetS32BitValue() == INT32_MAX) {
410 *result = false;
411 return true;
412 }
413 break;
414 case spv::Op::OpULessThanEqual:
415 if (constants[0] != nullptr && constants[0]->IsZero()) {
416 *result = true;
417 return true;
418 }
419 if (constants[1] != nullptr &&
420 constants[1]->GetU32BitValue() == UINT32_MAX) {
421 *result = true;
422 return true;
423 }
424 break;
425 case spv::Op::OpSLessThanEqual:
426 if (constants[0] != nullptr &&
427 constants[0]->GetS32BitValue() == INT32_MIN) {
428 *result = true;
429 return true;
430 }
431 if (constants[1] != nullptr &&
432 constants[1]->GetS32BitValue() == INT32_MAX) {
433 *result = true;
434 return true;
435 }
436 break;
437 case spv::Op::OpUGreaterThanEqual:
438 if (constants[0] != nullptr &&
439 constants[0]->GetU32BitValue() == UINT32_MAX) {
440 *result = true;
441 return true;
442 }
443 if (constants[1] != nullptr && constants[1]->GetU32BitValue() == 0) {
444 *result = true;
445 return true;
446 }
447 break;
448 case spv::Op::OpSGreaterThanEqual:
449 if (constants[0] != nullptr &&
450 constants[0]->GetS32BitValue() == INT32_MAX) {
451 *result = true;
452 return true;
453 }
454 if (constants[1] != nullptr &&
455 constants[1]->GetS32BitValue() == INT32_MIN) {
456 *result = true;
457 return true;
458 }
459 break;
460 default:
461 break;
462 }
463 return false;
464 }
465
FoldBinaryBooleanOpToConstant(Instruction * inst,const std::function<uint32_t (uint32_t)> & id_map,uint32_t * result) const466 bool InstructionFolder::FoldBinaryBooleanOpToConstant(
467 Instruction* inst, const std::function<uint32_t(uint32_t)>& id_map,
468 uint32_t* result) const {
469 spv::Op opcode = inst->opcode();
470 analysis::ConstantManager* const_manger = context_->get_constant_mgr();
471
472 uint32_t ids[2];
473 const analysis::BoolConstant* constants[2];
474 for (uint32_t i = 0; i < 2; i++) {
475 const Operand* operand = &inst->GetInOperand(i);
476 if (operand->type != SPV_OPERAND_TYPE_ID) {
477 return false;
478 }
479 ids[i] = id_map(operand->words[0]);
480 const analysis::Constant* constant =
481 const_manger->FindDeclaredConstant(ids[i]);
482 constants[i] = (constant != nullptr ? constant->AsBoolConstant() : nullptr);
483 }
484
485 switch (opcode) {
486 // Logical
487 case spv::Op::OpLogicalOr:
488 for (uint32_t i = 0; i < 2; i++) {
489 if (constants[i] != nullptr) {
490 if (constants[i]->value()) {
491 *result = true;
492 return true;
493 }
494 }
495 }
496 break;
497 case spv::Op::OpLogicalAnd:
498 for (uint32_t i = 0; i < 2; i++) {
499 if (constants[i] != nullptr) {
500 if (!constants[i]->value()) {
501 *result = false;
502 return true;
503 }
504 }
505 }
506 break;
507
508 default:
509 break;
510 }
511 return false;
512 }
513
FoldIntegerOpToConstant(Instruction * inst,const std::function<uint32_t (uint32_t)> & id_map,uint32_t * result) const514 bool InstructionFolder::FoldIntegerOpToConstant(
515 Instruction* inst, const std::function<uint32_t(uint32_t)>& id_map,
516 uint32_t* result) const {
517 assert(IsFoldableOpcode(inst->opcode()) &&
518 "Unhandled instruction opcode in FoldScalars");
519 switch (inst->NumInOperands()) {
520 case 2:
521 return FoldBinaryIntegerOpToConstant(inst, id_map, result) ||
522 FoldBinaryBooleanOpToConstant(inst, id_map, result);
523 default:
524 return false;
525 }
526 }
527
FoldVectors(spv::Op opcode,uint32_t num_dims,const std::vector<const analysis::Constant * > & operands) const528 std::vector<uint32_t> InstructionFolder::FoldVectors(
529 spv::Op opcode, uint32_t num_dims,
530 const std::vector<const analysis::Constant*>& operands) const {
531 assert(IsFoldableOpcode(opcode) &&
532 "Unhandled instruction opcode in FoldVectors");
533 std::vector<uint32_t> result;
534 for (uint32_t d = 0; d < num_dims; d++) {
535 std::vector<uint32_t> operand_values_for_one_dimension;
536 for (const auto& operand : operands) {
537 if (const analysis::VectorConstant* vector_operand =
538 operand->AsVectorConstant()) {
539 // Extract the raw value of the scalar component constants
540 // in 32-bit words here. The reason of not using FoldScalars() here
541 // is that we do not create temporary null constants as components
542 // when the vector operand is a NullConstant because Constant creation
543 // may need extra checks for the validity and that is not managed in
544 // here.
545 if (const analysis::ScalarConstant* scalar_component =
546 vector_operand->GetComponents().at(d)->AsScalarConstant()) {
547 const auto& scalar_words = scalar_component->words();
548 assert(
549 scalar_words.size() == 1 &&
550 "Vector components with longer than 32-bit width are not allowed "
551 "in FoldVectors()");
552 operand_values_for_one_dimension.push_back(scalar_words.front());
553 } else if (operand->AsNullConstant()) {
554 operand_values_for_one_dimension.push_back(0u);
555 } else {
556 assert(false &&
557 "VectorConst should only has ScalarConst or NullConst as "
558 "components");
559 }
560 } else if (operand->AsNullConstant()) {
561 operand_values_for_one_dimension.push_back(0u);
562 } else {
563 assert(false &&
564 "FoldVectors() only accepts VectorConst or NullConst type of "
565 "constant");
566 }
567 }
568 result.push_back(OperateWords(opcode, operand_values_for_one_dimension));
569 }
570 return result;
571 }
572
IsFoldableOpcode(spv::Op opcode) const573 bool InstructionFolder::IsFoldableOpcode(spv::Op opcode) const {
574 // NOTE: Extend to more opcodes as new cases are handled in the folder
575 // functions.
576 switch (opcode) {
577 case spv::Op::OpBitwiseAnd:
578 case spv::Op::OpBitwiseOr:
579 case spv::Op::OpBitwiseXor:
580 case spv::Op::OpIAdd:
581 case spv::Op::OpIEqual:
582 case spv::Op::OpIMul:
583 case spv::Op::OpINotEqual:
584 case spv::Op::OpISub:
585 case spv::Op::OpLogicalAnd:
586 case spv::Op::OpLogicalEqual:
587 case spv::Op::OpLogicalNot:
588 case spv::Op::OpLogicalNotEqual:
589 case spv::Op::OpLogicalOr:
590 case spv::Op::OpNot:
591 case spv::Op::OpSDiv:
592 case spv::Op::OpSelect:
593 case spv::Op::OpSGreaterThan:
594 case spv::Op::OpSGreaterThanEqual:
595 case spv::Op::OpShiftLeftLogical:
596 case spv::Op::OpShiftRightArithmetic:
597 case spv::Op::OpShiftRightLogical:
598 case spv::Op::OpSLessThan:
599 case spv::Op::OpSLessThanEqual:
600 case spv::Op::OpSMod:
601 case spv::Op::OpSNegate:
602 case spv::Op::OpSRem:
603 case spv::Op::OpSConvert:
604 case spv::Op::OpUConvert:
605 case spv::Op::OpUDiv:
606 case spv::Op::OpUGreaterThan:
607 case spv::Op::OpUGreaterThanEqual:
608 case spv::Op::OpULessThan:
609 case spv::Op::OpULessThanEqual:
610 case spv::Op::OpUMod:
611 return true;
612 default:
613 return false;
614 }
615 }
616
IsFoldableConstant(const analysis::Constant * cst) const617 bool InstructionFolder::IsFoldableConstant(
618 const analysis::Constant* cst) const {
619 // Currently supported constants are 32-bit values or null constants.
620 if (const analysis::ScalarConstant* scalar = cst->AsScalarConstant())
621 return scalar->words().size() == 1;
622 else
623 return cst->AsNullConstant() != nullptr;
624 }
625
FoldInstructionToConstant(Instruction * inst,std::function<uint32_t (uint32_t)> id_map) const626 Instruction* InstructionFolder::FoldInstructionToConstant(
627 Instruction* inst, std::function<uint32_t(uint32_t)> id_map) const {
628 analysis::ConstantManager* const_mgr = context_->get_constant_mgr();
629
630 if (!inst->IsFoldableByFoldScalar() && !inst->IsFoldableByFoldVector() &&
631 !GetConstantFoldingRules().HasFoldingRule(inst)) {
632 return nullptr;
633 }
634 // Collect the values of the constant parameters.
635 std::vector<const analysis::Constant*> constants;
636 bool missing_constants = false;
637 inst->ForEachInId([&constants, &missing_constants, const_mgr,
638 &id_map](uint32_t* op_id) {
639 uint32_t id = id_map(*op_id);
640 const analysis::Constant* const_op = const_mgr->FindDeclaredConstant(id);
641 if (!const_op) {
642 constants.push_back(nullptr);
643 missing_constants = true;
644 } else {
645 constants.push_back(const_op);
646 }
647 });
648
649 const analysis::Constant* folded_const = nullptr;
650 for (auto rule : GetConstantFoldingRules().GetRulesForInstruction(inst)) {
651 folded_const = rule(context_, inst, constants);
652 if (folded_const != nullptr) {
653 Instruction* const_inst =
654 const_mgr->GetDefiningInstruction(folded_const, inst->type_id());
655 if (const_inst == nullptr) {
656 return nullptr;
657 }
658 assert(const_inst->type_id() == inst->type_id());
659 // May be a new instruction that needs to be analysed.
660 context_->UpdateDefUse(const_inst);
661 return const_inst;
662 }
663 }
664
665 bool successful = false;
666
667 // If all parameters are constant, fold the instruction to a constant.
668 if (inst->IsFoldableByFoldScalar()) {
669 uint32_t result_val = 0;
670
671 if (!missing_constants) {
672 result_val = FoldScalars(inst->opcode(), constants);
673 successful = true;
674 }
675
676 if (!successful) {
677 successful = FoldIntegerOpToConstant(inst, id_map, &result_val);
678 }
679
680 if (successful) {
681 const analysis::Constant* result_const =
682 const_mgr->GetConstant(const_mgr->GetType(inst), {result_val});
683 Instruction* folded_inst =
684 const_mgr->GetDefiningInstruction(result_const, inst->type_id());
685 return folded_inst;
686 }
687 } else if (inst->IsFoldableByFoldVector()) {
688 std::vector<uint32_t> result_val;
689
690 if (!missing_constants) {
691 if (Instruction* inst_type =
692 context_->get_def_use_mgr()->GetDef(inst->type_id())) {
693 result_val = FoldVectors(
694 inst->opcode(), inst_type->GetSingleWordInOperand(1), constants);
695 successful = true;
696 }
697 }
698
699 if (successful) {
700 const analysis::Constant* result_const =
701 const_mgr->GetNumericVectorConstantWithWords(
702 const_mgr->GetType(inst)->AsVector(), result_val);
703 Instruction* folded_inst =
704 const_mgr->GetDefiningInstruction(result_const, inst->type_id());
705 return folded_inst;
706 }
707 }
708
709 return nullptr;
710 }
711
IsFoldableType(Instruction * type_inst) const712 bool InstructionFolder::IsFoldableType(Instruction* type_inst) const {
713 return IsFoldableScalarType(type_inst) || IsFoldableVectorType(type_inst);
714 }
715
IsFoldableScalarType(Instruction * type_inst) const716 bool InstructionFolder::IsFoldableScalarType(Instruction* type_inst) const {
717 // Support 32-bit integers.
718 if (type_inst->opcode() == spv::Op::OpTypeInt) {
719 return type_inst->GetSingleWordInOperand(0) == 32;
720 }
721 // Support booleans.
722 if (type_inst->opcode() == spv::Op::OpTypeBool) {
723 return true;
724 }
725 // Nothing else yet.
726 return false;
727 }
728
IsFoldableVectorType(Instruction * type_inst) const729 bool InstructionFolder::IsFoldableVectorType(Instruction* type_inst) const {
730 // Support vectors with foldable components
731 if (type_inst->opcode() == spv::Op::OpTypeVector) {
732 uint32_t component_type_id = type_inst->GetSingleWordInOperand(0);
733 Instruction* def_component_type =
734 context_->get_def_use_mgr()->GetDef(component_type_id);
735 return def_component_type != nullptr &&
736 IsFoldableScalarType(def_component_type);
737 }
738 // Nothing else yet.
739 return false;
740 }
741
FoldInstruction(Instruction * inst) const742 bool InstructionFolder::FoldInstruction(Instruction* inst) const {
743 bool modified = false;
744 Instruction* folded_inst(inst);
745 while (folded_inst->opcode() != spv::Op::OpCopyObject &&
746 FoldInstructionInternal(&*folded_inst)) {
747 modified = true;
748 }
749 return modified;
750 }
751
752 } // namespace opt
753 } // namespace spvtools
754