1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/opt/fold.h"
16
17 #include <cassert>
18 #include <cstdint>
19 #include <vector>
20
21 #include "source/opt/const_folding_rules.h"
22 #include "source/opt/def_use_manager.h"
23 #include "source/opt/folding_rules.h"
24 #include "source/opt/ir_context.h"
25
26 namespace spvtools {
27 namespace opt {
// Fallback definitions of the fixed-width integer limit macros. They are used
// only if <cstdint> did not already provide them (presumably for older
// toolchains whose C++ headers hide these macros; NOTE(review): confirm which
// toolchains still need this). Preprocessor macros are not scoped by
// namespaces, so these are simply visible to the rest of this file.
#ifndef INT32_MIN
// Spelled as (-INT32_MAX - 1): the literal 2147483648 does not fit in a 32-bit
// int, so "-2147483648" would be a negated long / long long (or unsigned long
// in C++03), not an int.
#define INT32_MIN (-2147483647 - 1)
#endif

#ifndef INT32_MAX
#define INT32_MAX 2147483647
#endif

#ifndef UINT32_MAX
#define UINT32_MAX 0xffffffffu /* 4294967295U */
#endif
43
UnaryOperate(spv::Op opcode,uint32_t operand) const44 uint32_t InstructionFolder::UnaryOperate(spv::Op opcode,
45 uint32_t operand) const {
46 switch (opcode) {
47 // Arthimetics
48 case spv::Op::OpSNegate: {
49 int32_t s_operand = static_cast<int32_t>(operand);
50 if (s_operand == std::numeric_limits<int32_t>::min()) {
51 return s_operand;
52 }
53 return -s_operand;
54 }
55 case spv::Op::OpNot:
56 return ~operand;
57 case spv::Op::OpLogicalNot:
58 return !static_cast<bool>(operand);
59 case spv::Op::OpUConvert:
60 return operand;
61 case spv::Op::OpSConvert:
62 return operand;
63 default:
64 assert(false &&
65 "Unsupported unary operation for OpSpecConstantOp instruction");
66 return 0u;
67 }
68 }
69
BinaryOperate(spv::Op opcode,uint32_t a,uint32_t b) const70 uint32_t InstructionFolder::BinaryOperate(spv::Op opcode, uint32_t a,
71 uint32_t b) const {
72 switch (opcode) {
73 // Shifting
74 case spv::Op::OpShiftRightLogical:
75 if (b >= 32) {
76 // This is undefined behaviour when |b| > 32. Choose 0 for consistency.
77 // When |b| == 32, doing the shift in C++ in undefined, but the result
78 // will be 0, so just return that value.
79 return 0;
80 }
81 return a >> b;
82 case spv::Op::OpShiftRightArithmetic:
83 if (b > 32) {
84 // This is undefined behaviour. Choose 0 for consistency.
85 return 0;
86 }
87 if (b == 32) {
88 // Doing the shift in C++ is undefined, but the result is defined in the
89 // spir-v spec. Find that value another way.
90 if (static_cast<int32_t>(a) >= 0) {
91 return 0;
92 } else {
93 return static_cast<uint32_t>(-1);
94 }
95 }
96 return (static_cast<int32_t>(a)) >> b;
97 case spv::Op::OpShiftLeftLogical:
98 if (b >= 32) {
99 // This is undefined behaviour when |b| > 32. Choose 0 for consistency.
100 // When |b| == 32, doing the shift in C++ in undefined, but the result
101 // will be 0, so just return that value.
102 return 0;
103 }
104 return a << b;
105
106 // Bitwise operations
107 case spv::Op::OpBitwiseOr:
108 return a | b;
109 case spv::Op::OpBitwiseAnd:
110 return a & b;
111 case spv::Op::OpBitwiseXor:
112 return a ^ b;
113
114 // Logical
115 case spv::Op::OpLogicalEqual:
116 return (static_cast<bool>(a)) == (static_cast<bool>(b));
117 case spv::Op::OpLogicalNotEqual:
118 return (static_cast<bool>(a)) != (static_cast<bool>(b));
119 case spv::Op::OpLogicalOr:
120 return (static_cast<bool>(a)) || (static_cast<bool>(b));
121 case spv::Op::OpLogicalAnd:
122 return (static_cast<bool>(a)) && (static_cast<bool>(b));
123
124 // Comparison
125 case spv::Op::OpIEqual:
126 return a == b;
127 case spv::Op::OpINotEqual:
128 return a != b;
129 case spv::Op::OpULessThan:
130 return a < b;
131 case spv::Op::OpSLessThan:
132 return (static_cast<int32_t>(a)) < (static_cast<int32_t>(b));
133 case spv::Op::OpUGreaterThan:
134 return a > b;
135 case spv::Op::OpSGreaterThan:
136 return (static_cast<int32_t>(a)) > (static_cast<int32_t>(b));
137 case spv::Op::OpULessThanEqual:
138 return a <= b;
139 case spv::Op::OpSLessThanEqual:
140 return (static_cast<int32_t>(a)) <= (static_cast<int32_t>(b));
141 case spv::Op::OpUGreaterThanEqual:
142 return a >= b;
143 case spv::Op::OpSGreaterThanEqual:
144 return (static_cast<int32_t>(a)) >= (static_cast<int32_t>(b));
145 default:
146 assert(false &&
147 "Unsupported binary operation for OpSpecConstantOp instruction");
148 return 0u;
149 }
150 }
151
TernaryOperate(spv::Op opcode,uint32_t a,uint32_t b,uint32_t c) const152 uint32_t InstructionFolder::TernaryOperate(spv::Op opcode, uint32_t a,
153 uint32_t b, uint32_t c) const {
154 switch (opcode) {
155 case spv::Op::OpSelect:
156 return (static_cast<bool>(a)) ? b : c;
157 default:
158 assert(false &&
159 "Unsupported ternary operation for OpSpecConstantOp instruction");
160 return 0u;
161 }
162 }
163
OperateWords(spv::Op opcode,const std::vector<uint32_t> & operand_words) const164 uint32_t InstructionFolder::OperateWords(
165 spv::Op opcode, const std::vector<uint32_t>& operand_words) const {
166 switch (operand_words.size()) {
167 case 1:
168 return UnaryOperate(opcode, operand_words.front());
169 case 2:
170 return BinaryOperate(opcode, operand_words.front(), operand_words.back());
171 case 3:
172 return TernaryOperate(opcode, operand_words[0], operand_words[1],
173 operand_words[2]);
174 default:
175 assert(false && "Invalid number of operands");
176 return 0;
177 }
178 }
179
FoldInstructionInternal(Instruction * inst) const180 bool InstructionFolder::FoldInstructionInternal(Instruction* inst) const {
181 auto identity_map = [](uint32_t id) { return id; };
182 Instruction* folded_inst = FoldInstructionToConstant(inst, identity_map);
183 if (folded_inst != nullptr) {
184 inst->SetOpcode(spv::Op::OpCopyObject);
185 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {folded_inst->result_id()}}});
186 return true;
187 }
188
189 analysis::ConstantManager* const_manager = context_->get_constant_mgr();
190 std::vector<const analysis::Constant*> constants =
191 const_manager->GetOperandConstants(inst);
192
193 for (const FoldingRule& rule :
194 GetFoldingRules().GetRulesForInstruction(inst)) {
195 if (rule(context_, inst, constants)) {
196 return true;
197 }
198 }
199 return false;
200 }
201
202 // Returns the result of performing an operation on scalar constant operands.
203 // This function extracts the operand values as 32 bit words and returns the
204 // result in 32 bit word. Scalar constants with longer than 32-bit width are
205 // not accepted in this function.
FoldScalars(spv::Op opcode,const std::vector<const analysis::Constant * > & operands) const206 uint32_t InstructionFolder::FoldScalars(
207 spv::Op opcode,
208 const std::vector<const analysis::Constant*>& operands) const {
209 assert(IsFoldableOpcode(opcode) &&
210 "Unhandled instruction opcode in FoldScalars");
211 std::vector<uint32_t> operand_values_in_raw_words;
212 for (const auto& operand : operands) {
213 if (const analysis::ScalarConstant* scalar = operand->AsScalarConstant()) {
214 const auto& scalar_words = scalar->words();
215 assert(scalar_words.size() == 1 &&
216 "Scalar constants with longer than 32-bit width are not allowed "
217 "in FoldScalars()");
218 operand_values_in_raw_words.push_back(scalar_words.front());
219 } else if (operand->AsNullConstant()) {
220 operand_values_in_raw_words.push_back(0u);
221 } else {
222 assert(false &&
223 "FoldScalars() only accepts ScalarConst or NullConst type of "
224 "constant");
225 }
226 }
227 return OperateWords(opcode, operand_values_in_raw_words);
228 }
229
FoldBinaryIntegerOpToConstant(Instruction * inst,const std::function<uint32_t (uint32_t)> & id_map,uint32_t * result) const230 bool InstructionFolder::FoldBinaryIntegerOpToConstant(
231 Instruction* inst, const std::function<uint32_t(uint32_t)>& id_map,
232 uint32_t* result) const {
233 spv::Op opcode = inst->opcode();
234 analysis::ConstantManager* const_manger = context_->get_constant_mgr();
235
236 uint32_t ids[2];
237 const analysis::IntConstant* constants[2];
238 for (uint32_t i = 0; i < 2; i++) {
239 const Operand* operand = &inst->GetInOperand(i);
240 if (operand->type != SPV_OPERAND_TYPE_ID) {
241 return false;
242 }
243 ids[i] = id_map(operand->words[0]);
244 const analysis::Constant* constant =
245 const_manger->FindDeclaredConstant(ids[i]);
246 constants[i] = (constant != nullptr ? constant->AsIntConstant() : nullptr);
247 }
248
249 switch (opcode) {
250 // Arthimetics
251 case spv::Op::OpIMul:
252 for (uint32_t i = 0; i < 2; i++) {
253 if (constants[i] != nullptr && constants[i]->IsZero()) {
254 *result = 0;
255 return true;
256 }
257 }
258 break;
259 case spv::Op::OpUDiv:
260 case spv::Op::OpSDiv:
261 case spv::Op::OpSRem:
262 case spv::Op::OpSMod:
263 case spv::Op::OpUMod:
264 // This changes undefined behaviour (ie divide by 0) into a 0.
265 for (uint32_t i = 0; i < 2; i++) {
266 if (constants[i] != nullptr && constants[i]->IsZero()) {
267 *result = 0;
268 return true;
269 }
270 }
271 break;
272
273 // Shifting
274 case spv::Op::OpShiftRightLogical:
275 case spv::Op::OpShiftLeftLogical:
276 if (constants[1] != nullptr) {
277 // When shifting by a value larger than the size of the result, the
278 // result is undefined. We are setting the undefined behaviour to a
279 // result of 0. If the shift amount is the same as the size of the
280 // result, then the result is defined, and it 0.
281 uint32_t shift_amount = constants[1]->GetU32BitValue();
282 if (shift_amount >= 32) {
283 *result = 0;
284 return true;
285 }
286 }
287 break;
288
289 // Bitwise operations
290 case spv::Op::OpBitwiseOr:
291 for (uint32_t i = 0; i < 2; i++) {
292 if (constants[i] != nullptr) {
293 // TODO: Change the mask against a value based on the bit width of the
294 // instruction result type. This way we can handle say 16-bit values
295 // as well.
296 uint32_t mask = constants[i]->GetU32BitValue();
297 if (mask == 0xFFFFFFFF) {
298 *result = 0xFFFFFFFF;
299 return true;
300 }
301 }
302 }
303 break;
304 case spv::Op::OpBitwiseAnd:
305 for (uint32_t i = 0; i < 2; i++) {
306 if (constants[i] != nullptr) {
307 if (constants[i]->IsZero()) {
308 *result = 0;
309 return true;
310 }
311 }
312 }
313 break;
314
315 // Comparison
316 case spv::Op::OpULessThan:
317 if (constants[0] != nullptr &&
318 constants[0]->GetU32BitValue() == UINT32_MAX) {
319 *result = false;
320 return true;
321 }
322 if (constants[1] != nullptr && constants[1]->GetU32BitValue() == 0) {
323 *result = false;
324 return true;
325 }
326 break;
327 case spv::Op::OpSLessThan:
328 if (constants[0] != nullptr &&
329 constants[0]->GetS32BitValue() == INT32_MAX) {
330 *result = false;
331 return true;
332 }
333 if (constants[1] != nullptr &&
334 constants[1]->GetS32BitValue() == INT32_MIN) {
335 *result = false;
336 return true;
337 }
338 break;
339 case spv::Op::OpUGreaterThan:
340 if (constants[0] != nullptr && constants[0]->IsZero()) {
341 *result = false;
342 return true;
343 }
344 if (constants[1] != nullptr &&
345 constants[1]->GetU32BitValue() == UINT32_MAX) {
346 *result = false;
347 return true;
348 }
349 break;
350 case spv::Op::OpSGreaterThan:
351 if (constants[0] != nullptr &&
352 constants[0]->GetS32BitValue() == INT32_MIN) {
353 *result = false;
354 return true;
355 }
356 if (constants[1] != nullptr &&
357 constants[1]->GetS32BitValue() == INT32_MAX) {
358 *result = false;
359 return true;
360 }
361 break;
362 case spv::Op::OpULessThanEqual:
363 if (constants[0] != nullptr && constants[0]->IsZero()) {
364 *result = true;
365 return true;
366 }
367 if (constants[1] != nullptr &&
368 constants[1]->GetU32BitValue() == UINT32_MAX) {
369 *result = true;
370 return true;
371 }
372 break;
373 case spv::Op::OpSLessThanEqual:
374 if (constants[0] != nullptr &&
375 constants[0]->GetS32BitValue() == INT32_MIN) {
376 *result = true;
377 return true;
378 }
379 if (constants[1] != nullptr &&
380 constants[1]->GetS32BitValue() == INT32_MAX) {
381 *result = true;
382 return true;
383 }
384 break;
385 case spv::Op::OpUGreaterThanEqual:
386 if (constants[0] != nullptr &&
387 constants[0]->GetU32BitValue() == UINT32_MAX) {
388 *result = true;
389 return true;
390 }
391 if (constants[1] != nullptr && constants[1]->GetU32BitValue() == 0) {
392 *result = true;
393 return true;
394 }
395 break;
396 case spv::Op::OpSGreaterThanEqual:
397 if (constants[0] != nullptr &&
398 constants[0]->GetS32BitValue() == INT32_MAX) {
399 *result = true;
400 return true;
401 }
402 if (constants[1] != nullptr &&
403 constants[1]->GetS32BitValue() == INT32_MIN) {
404 *result = true;
405 return true;
406 }
407 break;
408 default:
409 break;
410 }
411 return false;
412 }
413
FoldBinaryBooleanOpToConstant(Instruction * inst,const std::function<uint32_t (uint32_t)> & id_map,uint32_t * result) const414 bool InstructionFolder::FoldBinaryBooleanOpToConstant(
415 Instruction* inst, const std::function<uint32_t(uint32_t)>& id_map,
416 uint32_t* result) const {
417 spv::Op opcode = inst->opcode();
418 analysis::ConstantManager* const_manger = context_->get_constant_mgr();
419
420 uint32_t ids[2];
421 const analysis::BoolConstant* constants[2];
422 for (uint32_t i = 0; i < 2; i++) {
423 const Operand* operand = &inst->GetInOperand(i);
424 if (operand->type != SPV_OPERAND_TYPE_ID) {
425 return false;
426 }
427 ids[i] = id_map(operand->words[0]);
428 const analysis::Constant* constant =
429 const_manger->FindDeclaredConstant(ids[i]);
430 constants[i] = (constant != nullptr ? constant->AsBoolConstant() : nullptr);
431 }
432
433 switch (opcode) {
434 // Logical
435 case spv::Op::OpLogicalOr:
436 for (uint32_t i = 0; i < 2; i++) {
437 if (constants[i] != nullptr) {
438 if (constants[i]->value()) {
439 *result = true;
440 return true;
441 }
442 }
443 }
444 break;
445 case spv::Op::OpLogicalAnd:
446 for (uint32_t i = 0; i < 2; i++) {
447 if (constants[i] != nullptr) {
448 if (!constants[i]->value()) {
449 *result = false;
450 return true;
451 }
452 }
453 }
454 break;
455
456 default:
457 break;
458 }
459 return false;
460 }
461
FoldIntegerOpToConstant(Instruction * inst,const std::function<uint32_t (uint32_t)> & id_map,uint32_t * result) const462 bool InstructionFolder::FoldIntegerOpToConstant(
463 Instruction* inst, const std::function<uint32_t(uint32_t)>& id_map,
464 uint32_t* result) const {
465 assert(IsFoldableOpcode(inst->opcode()) &&
466 "Unhandled instruction opcode in FoldScalars");
467 switch (inst->NumInOperands()) {
468 case 2:
469 return FoldBinaryIntegerOpToConstant(inst, id_map, result) ||
470 FoldBinaryBooleanOpToConstant(inst, id_map, result);
471 default:
472 return false;
473 }
474 }
475
FoldVectors(spv::Op opcode,uint32_t num_dims,const std::vector<const analysis::Constant * > & operands) const476 std::vector<uint32_t> InstructionFolder::FoldVectors(
477 spv::Op opcode, uint32_t num_dims,
478 const std::vector<const analysis::Constant*>& operands) const {
479 assert(IsFoldableOpcode(opcode) &&
480 "Unhandled instruction opcode in FoldVectors");
481 std::vector<uint32_t> result;
482 for (uint32_t d = 0; d < num_dims; d++) {
483 std::vector<uint32_t> operand_values_for_one_dimension;
484 for (const auto& operand : operands) {
485 if (const analysis::VectorConstant* vector_operand =
486 operand->AsVectorConstant()) {
487 // Extract the raw value of the scalar component constants
488 // in 32-bit words here. The reason of not using FoldScalars() here
489 // is that we do not create temporary null constants as components
490 // when the vector operand is a NullConstant because Constant creation
491 // may need extra checks for the validity and that is not managed in
492 // here.
493 if (const analysis::ScalarConstant* scalar_component =
494 vector_operand->GetComponents().at(d)->AsScalarConstant()) {
495 const auto& scalar_words = scalar_component->words();
496 assert(
497 scalar_words.size() == 1 &&
498 "Vector components with longer than 32-bit width are not allowed "
499 "in FoldVectors()");
500 operand_values_for_one_dimension.push_back(scalar_words.front());
501 } else if (operand->AsNullConstant()) {
502 operand_values_for_one_dimension.push_back(0u);
503 } else {
504 assert(false &&
505 "VectorConst should only has ScalarConst or NullConst as "
506 "components");
507 }
508 } else if (operand->AsNullConstant()) {
509 operand_values_for_one_dimension.push_back(0u);
510 } else {
511 assert(false &&
512 "FoldVectors() only accepts VectorConst or NullConst type of "
513 "constant");
514 }
515 }
516 result.push_back(OperateWords(opcode, operand_values_for_one_dimension));
517 }
518 return result;
519 }
520
IsFoldableOpcode(spv::Op opcode) const521 bool InstructionFolder::IsFoldableOpcode(spv::Op opcode) const {
522 // NOTE: Extend to more opcodes as new cases are handled in the folder
523 // functions.
524 switch (opcode) {
525 case spv::Op::OpBitwiseAnd:
526 case spv::Op::OpBitwiseOr:
527 case spv::Op::OpBitwiseXor:
528 case spv::Op::OpIAdd:
529 case spv::Op::OpIEqual:
530 case spv::Op::OpIMul:
531 case spv::Op::OpINotEqual:
532 case spv::Op::OpISub:
533 case spv::Op::OpLogicalAnd:
534 case spv::Op::OpLogicalEqual:
535 case spv::Op::OpLogicalNot:
536 case spv::Op::OpLogicalNotEqual:
537 case spv::Op::OpLogicalOr:
538 case spv::Op::OpNot:
539 case spv::Op::OpSDiv:
540 case spv::Op::OpSelect:
541 case spv::Op::OpSGreaterThan:
542 case spv::Op::OpSGreaterThanEqual:
543 case spv::Op::OpShiftLeftLogical:
544 case spv::Op::OpShiftRightArithmetic:
545 case spv::Op::OpShiftRightLogical:
546 case spv::Op::OpSLessThan:
547 case spv::Op::OpSLessThanEqual:
548 case spv::Op::OpSMod:
549 case spv::Op::OpSNegate:
550 case spv::Op::OpSRem:
551 case spv::Op::OpSConvert:
552 case spv::Op::OpUConvert:
553 case spv::Op::OpUDiv:
554 case spv::Op::OpUGreaterThan:
555 case spv::Op::OpUGreaterThanEqual:
556 case spv::Op::OpULessThan:
557 case spv::Op::OpULessThanEqual:
558 case spv::Op::OpUMod:
559 return true;
560 default:
561 return false;
562 }
563 }
564
IsFoldableConstant(const analysis::Constant * cst) const565 bool InstructionFolder::IsFoldableConstant(
566 const analysis::Constant* cst) const {
567 // Currently supported constants are 32-bit values or null constants.
568 if (const analysis::ScalarConstant* scalar = cst->AsScalarConstant())
569 return scalar->words().size() == 1;
570 else
571 return cst->AsNullConstant() != nullptr;
572 }
573
FoldInstructionToConstant(Instruction * inst,std::function<uint32_t (uint32_t)> id_map) const574 Instruction* InstructionFolder::FoldInstructionToConstant(
575 Instruction* inst, std::function<uint32_t(uint32_t)> id_map) const {
576 analysis::ConstantManager* const_mgr = context_->get_constant_mgr();
577
578 if (!inst->IsFoldableByFoldScalar() && !inst->IsFoldableByFoldVector() &&
579 !GetConstantFoldingRules().HasFoldingRule(inst)) {
580 return nullptr;
581 }
582 // Collect the values of the constant parameters.
583 std::vector<const analysis::Constant*> constants;
584 bool missing_constants = false;
585 inst->ForEachInId([&constants, &missing_constants, const_mgr,
586 &id_map](uint32_t* op_id) {
587 uint32_t id = id_map(*op_id);
588 const analysis::Constant* const_op = const_mgr->FindDeclaredConstant(id);
589 if (!const_op) {
590 constants.push_back(nullptr);
591 missing_constants = true;
592 } else {
593 constants.push_back(const_op);
594 }
595 });
596
597 const analysis::Constant* folded_const = nullptr;
598 for (auto rule : GetConstantFoldingRules().GetRulesForInstruction(inst)) {
599 folded_const = rule(context_, inst, constants);
600 if (folded_const != nullptr) {
601 Instruction* const_inst =
602 const_mgr->GetDefiningInstruction(folded_const, inst->type_id());
603 if (const_inst == nullptr) {
604 return nullptr;
605 }
606 assert(const_inst->type_id() == inst->type_id());
607 // May be a new instruction that needs to be analysed.
608 context_->UpdateDefUse(const_inst);
609 return const_inst;
610 }
611 }
612
613 bool successful = false;
614
615 // If all parameters are constant, fold the instruction to a constant.
616 if (inst->IsFoldableByFoldScalar()) {
617 uint32_t result_val = 0;
618
619 if (!missing_constants) {
620 result_val = FoldScalars(inst->opcode(), constants);
621 successful = true;
622 }
623
624 if (!successful) {
625 successful = FoldIntegerOpToConstant(inst, id_map, &result_val);
626 }
627
628 if (successful) {
629 const analysis::Constant* result_const =
630 const_mgr->GetConstant(const_mgr->GetType(inst), {result_val});
631 Instruction* folded_inst =
632 const_mgr->GetDefiningInstruction(result_const, inst->type_id());
633 return folded_inst;
634 }
635 } else if (inst->IsFoldableByFoldVector()) {
636 std::vector<uint32_t> result_val;
637
638 if (!missing_constants) {
639 if (Instruction* inst_type =
640 context_->get_def_use_mgr()->GetDef(inst->type_id())) {
641 result_val = FoldVectors(
642 inst->opcode(), inst_type->GetSingleWordInOperand(1), constants);
643 successful = true;
644 }
645 }
646
647 if (successful) {
648 const analysis::Constant* result_const =
649 const_mgr->GetNumericVectorConstantWithWords(
650 const_mgr->GetType(inst)->AsVector(), result_val);
651 Instruction* folded_inst =
652 const_mgr->GetDefiningInstruction(result_const, inst->type_id());
653 return folded_inst;
654 }
655 }
656
657 return nullptr;
658 }
659
IsFoldableType(Instruction * type_inst) const660 bool InstructionFolder::IsFoldableType(Instruction* type_inst) const {
661 return IsFoldableScalarType(type_inst) || IsFoldableVectorType(type_inst);
662 }
663
IsFoldableScalarType(Instruction * type_inst) const664 bool InstructionFolder::IsFoldableScalarType(Instruction* type_inst) const {
665 // Support 32-bit integers.
666 if (type_inst->opcode() == spv::Op::OpTypeInt) {
667 return type_inst->GetSingleWordInOperand(0) == 32;
668 }
669 // Support booleans.
670 if (type_inst->opcode() == spv::Op::OpTypeBool) {
671 return true;
672 }
673 // Nothing else yet.
674 return false;
675 }
676
IsFoldableVectorType(Instruction * type_inst) const677 bool InstructionFolder::IsFoldableVectorType(Instruction* type_inst) const {
678 // Support vectors with foldable components
679 if (type_inst->opcode() == spv::Op::OpTypeVector) {
680 uint32_t component_type_id = type_inst->GetSingleWordInOperand(0);
681 Instruction* def_component_type =
682 context_->get_def_use_mgr()->GetDef(component_type_id);
683 return def_component_type != nullptr &&
684 IsFoldableScalarType(def_component_type);
685 }
686 // Nothing else yet.
687 return false;
688 }
689
FoldInstruction(Instruction * inst) const690 bool InstructionFolder::FoldInstruction(Instruction* inst) const {
691 bool modified = false;
692 Instruction* folded_inst(inst);
693 while (folded_inst->opcode() != spv::Op::OpCopyObject &&
694 FoldInstructionInternal(&*folded_inst)) {
695 modified = true;
696 }
697 return modified;
698 }
699
700 } // namespace opt
701 } // namespace spvtools
702