1 // Copyright (c) 2018 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/opt/folding_rules.h"
16
17 #include <limits>
18 #include <memory>
19 #include <utility>
20
21 #include "source/latest_version_glsl_std_450_header.h"
22 #include "source/opt/ir_context.h"
23
24 namespace spvtools {
25 namespace opt {
26 namespace {
27
// Named in-operand indices (as passed to GetSingleWordInOperand) for the
// instructions the folding rules inspect. Meanings follow the constant names.
// OpCompositeExtract / OpCompositeInsert operands.
constexpr uint32_t kExtractCompositeIdInIdx = 0;
constexpr uint32_t kInsertObjectIdInIdx = 0;
constexpr uint32_t kInsertCompositeIdInIdx = 1;
// OpExtInst: instruction-set id and instruction number.
constexpr uint32_t kExtInstSetIdInIdx = 0;
constexpr uint32_t kExtInstInstructionInIdx = 1;
// FMix x, y and a operands of an OpExtInst.
constexpr uint32_t kFMixXIdInIdx = 2;
constexpr uint32_t kFMixYIdInIdx = 3;
constexpr uint32_t kFMixAIdInIdx = 4;
// OpStore object operand.
constexpr uint32_t kStoreObjectInIdx = 1;
37
38 // Returns the element width of |type|.
ElementWidth(const analysis::Type * type)39 uint32_t ElementWidth(const analysis::Type* type) {
40 if (const analysis::Vector* vec_type = type->AsVector()) {
41 return ElementWidth(vec_type->element_type());
42 } else if (const analysis::Float* float_type = type->AsFloat()) {
43 return float_type->width();
44 } else {
45 assert(type->AsInteger());
46 return type->AsInteger()->width();
47 }
48 }
49
50 // Returns true if |type| is Float or a vector of Float.
HasFloatingPoint(const analysis::Type * type)51 bool HasFloatingPoint(const analysis::Type* type) {
52 if (type->AsFloat()) {
53 return true;
54 } else if (const analysis::Vector* vec_type = type->AsVector()) {
55 return vec_type->element_type()->AsFloat() != nullptr;
56 }
57
58 return false;
59 }
60
// Reports whether |val| may be materialized as a folded constant: false when
// |val| is classified as NaN, infinite or subnormal, true for any other
// classification (normal numbers and zeros).
template <typename T>
bool IsValidResult(T val) {
  const int category = std::fpclassify(val);
  return category != FP_NAN && category != FP_INFINITE &&
         category != FP_SUBNORMAL;
}
74
ConstInput(const std::vector<const analysis::Constant * > & constants)75 const analysis::Constant* ConstInput(
76 const std::vector<const analysis::Constant*>& constants) {
77 return constants[0] ? constants[0] : constants[1];
78 }
79
NonConstInput(IRContext * context,const analysis::Constant * c,Instruction * inst)80 Instruction* NonConstInput(IRContext* context, const analysis::Constant* c,
81 Instruction* inst) {
82 uint32_t in_op = c ? 1u : 0u;
83 return context->get_def_use_mgr()->GetDef(
84 inst->GetSingleWordInOperand(in_op));
85 }
86
87 // Returns the negation of |c|. |c| must be a 32 or 64 bit floating point
88 // constant.
NegateFloatingPointConstant(analysis::ConstantManager * const_mgr,const analysis::Constant * c)89 uint32_t NegateFloatingPointConstant(analysis::ConstantManager* const_mgr,
90 const analysis::Constant* c) {
91 assert(c);
92 assert(c->type()->AsFloat());
93 uint32_t width = c->type()->AsFloat()->width();
94 assert(width == 32 || width == 64);
95 std::vector<uint32_t> words;
96 if (width == 64) {
97 utils::FloatProxy<double> result(c->GetDouble() * -1.0);
98 words = result.GetWords();
99 } else {
100 utils::FloatProxy<float> result(c->GetFloat() * -1.0f);
101 words = result.GetWords();
102 }
103
104 const analysis::Constant* negated_const =
105 const_mgr->GetConstant(c->type(), std::move(words));
106 return const_mgr->GetDefiningInstruction(negated_const)->result_id();
107 }
108
// Splits |val| into two 32-bit words, least-significant word first.
std::vector<uint32_t> ExtractInts(uint64_t val) {
  const uint32_t low_word = static_cast<uint32_t>(val & 0xFFFFFFFFull);
  const uint32_t high_word = static_cast<uint32_t>(val >> 32);
  return std::vector<uint32_t>{low_word, high_word};
}
115
116 // Negates the integer constant |c|. Returns the id of the defining instruction.
NegateIntegerConstant(analysis::ConstantManager * const_mgr,const analysis::Constant * c)117 uint32_t NegateIntegerConstant(analysis::ConstantManager* const_mgr,
118 const analysis::Constant* c) {
119 assert(c);
120 assert(c->type()->AsInteger());
121 uint32_t width = c->type()->AsInteger()->width();
122 assert(width == 32 || width == 64);
123 std::vector<uint32_t> words;
124 if (width == 64) {
125 uint64_t uval = static_cast<uint64_t>(0 - c->GetU64());
126 words = ExtractInts(uval);
127 } else {
128 words.push_back(static_cast<uint32_t>(0 - c->GetU32()));
129 }
130
131 const analysis::Constant* negated_const =
132 const_mgr->GetConstant(c->type(), std::move(words));
133 return const_mgr->GetDefiningInstruction(negated_const)->result_id();
134 }
135
136 // Negates the vector constant |c|. Returns the id of the defining instruction.
NegateVectorConstant(analysis::ConstantManager * const_mgr,const analysis::Constant * c)137 uint32_t NegateVectorConstant(analysis::ConstantManager* const_mgr,
138 const analysis::Constant* c) {
139 assert(const_mgr && c);
140 assert(c->type()->AsVector());
141 if (c->AsNullConstant()) {
142 // 0.0 vs -0.0 shouldn't matter.
143 return const_mgr->GetDefiningInstruction(c)->result_id();
144 } else {
145 const analysis::Type* component_type =
146 c->AsVectorConstant()->component_type();
147 std::vector<uint32_t> words;
148 for (auto& comp : c->AsVectorConstant()->GetComponents()) {
149 if (component_type->AsFloat()) {
150 words.push_back(NegateFloatingPointConstant(const_mgr, comp));
151 } else {
152 assert(component_type->AsInteger());
153 words.push_back(NegateIntegerConstant(const_mgr, comp));
154 }
155 }
156
157 const analysis::Constant* negated_const =
158 const_mgr->GetConstant(c->type(), std::move(words));
159 return const_mgr->GetDefiningInstruction(negated_const)->result_id();
160 }
161 }
162
163 // Negates |c|. Returns the id of the defining instruction.
NegateConstant(analysis::ConstantManager * const_mgr,const analysis::Constant * c)164 uint32_t NegateConstant(analysis::ConstantManager* const_mgr,
165 const analysis::Constant* c) {
166 if (c->type()->AsVector()) {
167 return NegateVectorConstant(const_mgr, c);
168 } else if (c->type()->AsFloat()) {
169 return NegateFloatingPointConstant(const_mgr, c);
170 } else {
171 assert(c->type()->AsInteger());
172 return NegateIntegerConstant(const_mgr, c);
173 }
174 }
175
176 // Takes the reciprocal of |c|. |c|'s type must be Float or a vector of Float.
177 // Returns 0 if the reciprocal is NaN, infinite or subnormal.
Reciprocal(analysis::ConstantManager * const_mgr,const analysis::Constant * c)178 uint32_t Reciprocal(analysis::ConstantManager* const_mgr,
179 const analysis::Constant* c) {
180 assert(const_mgr && c);
181 assert(c->type()->AsFloat());
182
183 uint32_t width = c->type()->AsFloat()->width();
184 assert(width == 32 || width == 64);
185 std::vector<uint32_t> words;
186 if (width == 64) {
187 spvtools::utils::FloatProxy<double> result(1.0 / c->GetDouble());
188 if (!IsValidResult(result.getAsFloat())) return 0;
189 words = result.GetWords();
190 } else {
191 spvtools::utils::FloatProxy<float> result(1.0f / c->GetFloat());
192 if (!IsValidResult(result.getAsFloat())) return 0;
193 words = result.GetWords();
194 }
195
196 const analysis::Constant* negated_const =
197 const_mgr->GetConstant(c->type(), std::move(words));
198 return const_mgr->GetDefiningInstruction(negated_const)->result_id();
199 }
200
201 // Replaces fdiv where second operand is constant with fmul.
ReciprocalFDiv()202 FoldingRule ReciprocalFDiv() {
203 return [](IRContext* context, Instruction* inst,
204 const std::vector<const analysis::Constant*>& constants) {
205 assert(inst->opcode() == SpvOpFDiv);
206 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
207 const analysis::Type* type =
208 context->get_type_mgr()->GetType(inst->type_id());
209 if (!inst->IsFloatingPointFoldingAllowed()) return false;
210
211 uint32_t width = ElementWidth(type);
212 if (width != 32 && width != 64) return false;
213
214 if (constants[1] != nullptr) {
215 uint32_t id = 0;
216 if (const analysis::VectorConstant* vector_const =
217 constants[1]->AsVectorConstant()) {
218 std::vector<uint32_t> neg_ids;
219 for (auto& comp : vector_const->GetComponents()) {
220 id = Reciprocal(const_mgr, comp);
221 if (id == 0) return false;
222 neg_ids.push_back(id);
223 }
224 const analysis::Constant* negated_const =
225 const_mgr->GetConstant(constants[1]->type(), std::move(neg_ids));
226 id = const_mgr->GetDefiningInstruction(negated_const)->result_id();
227 } else if (constants[1]->AsFloatConstant()) {
228 id = Reciprocal(const_mgr, constants[1]);
229 if (id == 0) return false;
230 } else {
231 // Don't fold a null constant.
232 return false;
233 }
234 inst->SetOpcode(SpvOpFMul);
235 inst->SetInOperands(
236 {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0u)}},
237 {SPV_OPERAND_TYPE_ID, {id}}});
238 return true;
239 }
240
241 return false;
242 };
243 }
244
245 // Elides consecutive negate instructions.
MergeNegateArithmetic()246 FoldingRule MergeNegateArithmetic() {
247 return [](IRContext* context, Instruction* inst,
248 const std::vector<const analysis::Constant*>& constants) {
249 assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate);
250 (void)constants;
251 const analysis::Type* type =
252 context->get_type_mgr()->GetType(inst->type_id());
253 if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
254 return false;
255
256 Instruction* op_inst =
257 context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
258 if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
259 return false;
260
261 if (op_inst->opcode() == inst->opcode()) {
262 // Elide negates.
263 inst->SetOpcode(SpvOpCopyObject);
264 inst->SetInOperands(
265 {{SPV_OPERAND_TYPE_ID, {op_inst->GetSingleWordInOperand(0u)}}});
266 return true;
267 }
268
269 return false;
270 };
271 }
272
273 // Merges negate into a mul or div operation if that operation contains a
274 // constant operand.
275 // Cases:
276 // -(x * 2) = x * -2
277 // -(2 * x) = x * -2
278 // -(x / 2) = x / -2
279 // -(2 / x) = -2 / x
MergeNegateMulDivArithmetic()280 FoldingRule MergeNegateMulDivArithmetic() {
281 return [](IRContext* context, Instruction* inst,
282 const std::vector<const analysis::Constant*>& constants) {
283 assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate);
284 (void)constants;
285 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
286 const analysis::Type* type =
287 context->get_type_mgr()->GetType(inst->type_id());
288 if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
289 return false;
290
291 Instruction* op_inst =
292 context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
293 if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
294 return false;
295
296 uint32_t width = ElementWidth(type);
297 if (width != 32 && width != 64) return false;
298
299 SpvOp opcode = op_inst->opcode();
300 if (opcode == SpvOpFMul || opcode == SpvOpFDiv || opcode == SpvOpIMul ||
301 opcode == SpvOpSDiv || opcode == SpvOpUDiv) {
302 std::vector<const analysis::Constant*> op_constants =
303 const_mgr->GetOperandConstants(op_inst);
304 // Merge negate into mul or div if one operand is constant.
305 if (op_constants[0] || op_constants[1]) {
306 bool zero_is_variable = op_constants[0] == nullptr;
307 const analysis::Constant* c = ConstInput(op_constants);
308 uint32_t neg_id = NegateConstant(const_mgr, c);
309 uint32_t non_const_id = zero_is_variable
310 ? op_inst->GetSingleWordInOperand(0u)
311 : op_inst->GetSingleWordInOperand(1u);
312 // Change this instruction to a mul/div.
313 inst->SetOpcode(op_inst->opcode());
314 if (opcode == SpvOpFDiv || opcode == SpvOpUDiv || opcode == SpvOpSDiv) {
315 uint32_t op0 = zero_is_variable ? non_const_id : neg_id;
316 uint32_t op1 = zero_is_variable ? neg_id : non_const_id;
317 inst->SetInOperands(
318 {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}});
319 } else {
320 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
321 {SPV_OPERAND_TYPE_ID, {neg_id}}});
322 }
323 return true;
324 }
325 }
326
327 return false;
328 };
329 }
330
331 // Merges negate into a add or sub operation if that operation contains a
332 // constant operand.
333 // Cases:
334 // -(x + 2) = -2 - x
335 // -(2 + x) = -2 - x
336 // -(x - 2) = 2 - x
337 // -(2 - x) = x - 2
MergeNegateAddSubArithmetic()338 FoldingRule MergeNegateAddSubArithmetic() {
339 return [](IRContext* context, Instruction* inst,
340 const std::vector<const analysis::Constant*>& constants) {
341 assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate);
342 (void)constants;
343 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
344 const analysis::Type* type =
345 context->get_type_mgr()->GetType(inst->type_id());
346 if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
347 return false;
348
349 Instruction* op_inst =
350 context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
351 if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
352 return false;
353
354 uint32_t width = ElementWidth(type);
355 if (width != 32 && width != 64) return false;
356
357 if (op_inst->opcode() == SpvOpFAdd || op_inst->opcode() == SpvOpFSub ||
358 op_inst->opcode() == SpvOpIAdd || op_inst->opcode() == SpvOpISub) {
359 std::vector<const analysis::Constant*> op_constants =
360 const_mgr->GetOperandConstants(op_inst);
361 if (op_constants[0] || op_constants[1]) {
362 bool zero_is_variable = op_constants[0] == nullptr;
363 bool is_add = (op_inst->opcode() == SpvOpFAdd) ||
364 (op_inst->opcode() == SpvOpIAdd);
365 bool swap_operands = !is_add || zero_is_variable;
366 bool negate_const = is_add;
367 const analysis::Constant* c = ConstInput(op_constants);
368 uint32_t const_id = 0;
369 if (negate_const) {
370 const_id = NegateConstant(const_mgr, c);
371 } else {
372 const_id = zero_is_variable ? op_inst->GetSingleWordInOperand(1u)
373 : op_inst->GetSingleWordInOperand(0u);
374 }
375
376 // Swap operands if necessary and make the instruction a subtraction.
377 uint32_t op0 =
378 zero_is_variable ? op_inst->GetSingleWordInOperand(0u) : const_id;
379 uint32_t op1 =
380 zero_is_variable ? const_id : op_inst->GetSingleWordInOperand(1u);
381 if (swap_operands) std::swap(op0, op1);
382 inst->SetOpcode(HasFloatingPoint(type) ? SpvOpFSub : SpvOpISub);
383 inst->SetInOperands(
384 {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}});
385 return true;
386 }
387 }
388
389 return false;
390 };
391 }
392
393 // Returns true if |c| has a zero element.
HasZero(const analysis::Constant * c)394 bool HasZero(const analysis::Constant* c) {
395 if (c->AsNullConstant()) {
396 return true;
397 }
398 if (const analysis::VectorConstant* vec_const = c->AsVectorConstant()) {
399 for (auto& comp : vec_const->GetComponents())
400 if (HasZero(comp)) return true;
401 } else {
402 assert(c->AsScalarConstant());
403 return c->AsScalarConstant()->IsZero();
404 }
405
406 return false;
407 }
408
409 // Performs |input1| |opcode| |input2| and returns the merged constant result
410 // id. Returns 0 if the result is not a valid value. The input types must be
411 // Float.
PerformFloatingPointOperation(analysis::ConstantManager * const_mgr,SpvOp opcode,const analysis::Constant * input1,const analysis::Constant * input2)412 uint32_t PerformFloatingPointOperation(analysis::ConstantManager* const_mgr,
413 SpvOp opcode,
414 const analysis::Constant* input1,
415 const analysis::Constant* input2) {
416 const analysis::Type* type = input1->type();
417 assert(type->AsFloat());
418 uint32_t width = type->AsFloat()->width();
419 assert(width == 32 || width == 64);
420 std::vector<uint32_t> words;
421 #define FOLD_OP(op) \
422 if (width == 64) { \
423 utils::FloatProxy<double> val = \
424 input1->GetDouble() op input2->GetDouble(); \
425 double dval = val.getAsFloat(); \
426 if (!IsValidResult(dval)) return 0; \
427 words = val.GetWords(); \
428 } else { \
429 utils::FloatProxy<float> val = input1->GetFloat() op input2->GetFloat(); \
430 float fval = val.getAsFloat(); \
431 if (!IsValidResult(fval)) return 0; \
432 words = val.GetWords(); \
433 }
434 switch (opcode) {
435 case SpvOpFMul:
436 FOLD_OP(*);
437 break;
438 case SpvOpFDiv:
439 if (HasZero(input2)) return 0;
440 FOLD_OP(/);
441 break;
442 case SpvOpFAdd:
443 FOLD_OP(+);
444 break;
445 case SpvOpFSub:
446 FOLD_OP(-);
447 break;
448 default:
449 assert(false && "Unexpected operation");
450 break;
451 }
452 #undef FOLD_OP
453 const analysis::Constant* merged_const = const_mgr->GetConstant(type, words);
454 return const_mgr->GetDefiningInstruction(merged_const)->result_id();
455 }
456
457 // Performs |input1| |opcode| |input2| and returns the merged constant result
458 // id. Returns 0 if the result is not a valid value. The input types must be
459 // Integers.
PerformIntegerOperation(analysis::ConstantManager * const_mgr,SpvOp opcode,const analysis::Constant * input1,const analysis::Constant * input2)460 uint32_t PerformIntegerOperation(analysis::ConstantManager* const_mgr,
461 SpvOp opcode, const analysis::Constant* input1,
462 const analysis::Constant* input2) {
463 assert(input1->type()->AsInteger());
464 const analysis::Integer* type = input1->type()->AsInteger();
465 uint32_t width = type->AsInteger()->width();
466 assert(width == 32 || width == 64);
467 std::vector<uint32_t> words;
468 #define FOLD_OP(op) \
469 if (width == 64) { \
470 if (type->IsSigned()) { \
471 int64_t val = input1->GetS64() op input2->GetS64(); \
472 words = ExtractInts(static_cast<uint64_t>(val)); \
473 } else { \
474 uint64_t val = input1->GetU64() op input2->GetU64(); \
475 words = ExtractInts(val); \
476 } \
477 } else { \
478 if (type->IsSigned()) { \
479 int32_t val = input1->GetS32() op input2->GetS32(); \
480 words.push_back(static_cast<uint32_t>(val)); \
481 } else { \
482 uint32_t val = input1->GetU32() op input2->GetU32(); \
483 words.push_back(val); \
484 } \
485 }
486 switch (opcode) {
487 case SpvOpIMul:
488 FOLD_OP(*);
489 break;
490 case SpvOpSDiv:
491 case SpvOpUDiv:
492 assert(false && "Should not merge integer division");
493 break;
494 case SpvOpIAdd:
495 FOLD_OP(+);
496 break;
497 case SpvOpISub:
498 FOLD_OP(-);
499 break;
500 default:
501 assert(false && "Unexpected operation");
502 break;
503 }
504 #undef FOLD_OP
505 const analysis::Constant* merged_const = const_mgr->GetConstant(type, words);
506 return const_mgr->GetDefiningInstruction(merged_const)->result_id();
507 }
508
509 // Performs |input1| |opcode| |input2| and returns the merged constant result
510 // id. Returns 0 if the result is not a valid value. The input types must be
511 // Integers, Floats or Vectors of such.
PerformOperation(analysis::ConstantManager * const_mgr,SpvOp opcode,const analysis::Constant * input1,const analysis::Constant * input2)512 uint32_t PerformOperation(analysis::ConstantManager* const_mgr, SpvOp opcode,
513 const analysis::Constant* input1,
514 const analysis::Constant* input2) {
515 assert(input1 && input2);
516 assert(input1->type() == input2->type());
517 const analysis::Type* type = input1->type();
518 std::vector<uint32_t> words;
519 if (const analysis::Vector* vector_type = type->AsVector()) {
520 const analysis::Type* ele_type = vector_type->element_type();
521 for (uint32_t i = 0; i != vector_type->element_count(); ++i) {
522 uint32_t id = 0;
523
524 const analysis::Constant* input1_comp = nullptr;
525 if (const analysis::VectorConstant* input1_vector =
526 input1->AsVectorConstant()) {
527 input1_comp = input1_vector->GetComponents()[i];
528 } else {
529 assert(input1->AsNullConstant());
530 input1_comp = const_mgr->GetConstant(ele_type, {});
531 }
532
533 const analysis::Constant* input2_comp = nullptr;
534 if (const analysis::VectorConstant* input2_vector =
535 input2->AsVectorConstant()) {
536 input2_comp = input2_vector->GetComponents()[i];
537 } else {
538 assert(input2->AsNullConstant());
539 input2_comp = const_mgr->GetConstant(ele_type, {});
540 }
541
542 if (ele_type->AsFloat()) {
543 id = PerformFloatingPointOperation(const_mgr, opcode, input1_comp,
544 input2_comp);
545 } else {
546 assert(ele_type->AsInteger());
547 id = PerformIntegerOperation(const_mgr, opcode, input1_comp,
548 input2_comp);
549 }
550 if (id == 0) return 0;
551 words.push_back(id);
552 }
553 const analysis::Constant* merged_const =
554 const_mgr->GetConstant(type, words);
555 return const_mgr->GetDefiningInstruction(merged_const)->result_id();
556 } else if (type->AsFloat()) {
557 return PerformFloatingPointOperation(const_mgr, opcode, input1, input2);
558 } else {
559 assert(type->AsInteger());
560 return PerformIntegerOperation(const_mgr, opcode, input1, input2);
561 }
562 }
563
564 // Merges consecutive multiplies where each contains one constant operand.
565 // Cases:
566 // 2 * (x * 2) = x * 4
567 // 2 * (2 * x) = x * 4
568 // (x * 2) * 2 = x * 4
569 // (2 * x) * 2 = x * 4
MergeMulMulArithmetic()570 FoldingRule MergeMulMulArithmetic() {
571 return [](IRContext* context, Instruction* inst,
572 const std::vector<const analysis::Constant*>& constants) {
573 assert(inst->opcode() == SpvOpFMul || inst->opcode() == SpvOpIMul);
574 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
575 const analysis::Type* type =
576 context->get_type_mgr()->GetType(inst->type_id());
577 if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
578 return false;
579
580 uint32_t width = ElementWidth(type);
581 if (width != 32 && width != 64) return false;
582
583 // Determine the constant input and the variable input in |inst|.
584 const analysis::Constant* const_input1 = ConstInput(constants);
585 if (!const_input1) return false;
586 Instruction* other_inst = NonConstInput(context, constants[0], inst);
587 if (HasFloatingPoint(type) && !other_inst->IsFloatingPointFoldingAllowed())
588 return false;
589
590 if (other_inst->opcode() == inst->opcode()) {
591 std::vector<const analysis::Constant*> other_constants =
592 const_mgr->GetOperandConstants(other_inst);
593 const analysis::Constant* const_input2 = ConstInput(other_constants);
594 if (!const_input2) return false;
595
596 bool other_first_is_variable = other_constants[0] == nullptr;
597 uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
598 const_input1, const_input2);
599 if (merged_id == 0) return false;
600
601 uint32_t non_const_id = other_first_is_variable
602 ? other_inst->GetSingleWordInOperand(0u)
603 : other_inst->GetSingleWordInOperand(1u);
604 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
605 {SPV_OPERAND_TYPE_ID, {merged_id}}});
606 return true;
607 }
608
609 return false;
610 };
611 }
612
613 // Merges divides into subsequent multiplies if each instruction contains one
614 // constant operand. Does not support integer operations.
615 // Cases:
616 // 2 * (x / 2) = x * 1
617 // 2 * (2 / x) = 4 / x
618 // (x / 2) * 2 = x * 1
619 // (2 / x) * 2 = 4 / x
620 // (y / x) * x = y
621 // x * (y / x) = y
MergeMulDivArithmetic()622 FoldingRule MergeMulDivArithmetic() {
623 return [](IRContext* context, Instruction* inst,
624 const std::vector<const analysis::Constant*>& constants) {
625 assert(inst->opcode() == SpvOpFMul);
626 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
627 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
628
629 const analysis::Type* type =
630 context->get_type_mgr()->GetType(inst->type_id());
631 if (!inst->IsFloatingPointFoldingAllowed()) return false;
632
633 uint32_t width = ElementWidth(type);
634 if (width != 32 && width != 64) return false;
635
636 for (uint32_t i = 0; i < 2; i++) {
637 uint32_t op_id = inst->GetSingleWordInOperand(i);
638 Instruction* op_inst = def_use_mgr->GetDef(op_id);
639 if (op_inst->opcode() == SpvOpFDiv) {
640 if (op_inst->GetSingleWordInOperand(1) ==
641 inst->GetSingleWordInOperand(1 - i)) {
642 inst->SetOpcode(SpvOpCopyObject);
643 inst->SetInOperands(
644 {{SPV_OPERAND_TYPE_ID, {op_inst->GetSingleWordInOperand(0)}}});
645 return true;
646 }
647 }
648 }
649
650 const analysis::Constant* const_input1 = ConstInput(constants);
651 if (!const_input1) return false;
652 Instruction* other_inst = NonConstInput(context, constants[0], inst);
653 if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
654
655 if (other_inst->opcode() == SpvOpFDiv) {
656 std::vector<const analysis::Constant*> other_constants =
657 const_mgr->GetOperandConstants(other_inst);
658 const analysis::Constant* const_input2 = ConstInput(other_constants);
659 if (!const_input2 || HasZero(const_input2)) return false;
660
661 bool other_first_is_variable = other_constants[0] == nullptr;
662 // If the variable value is the second operand of the divide, multiply
663 // the constants together. Otherwise divide the constants.
664 uint32_t merged_id = PerformOperation(
665 const_mgr,
666 other_first_is_variable ? other_inst->opcode() : inst->opcode(),
667 const_input1, const_input2);
668 if (merged_id == 0) return false;
669
670 uint32_t non_const_id = other_first_is_variable
671 ? other_inst->GetSingleWordInOperand(0u)
672 : other_inst->GetSingleWordInOperand(1u);
673
674 // If the variable value is on the second operand of the div, then this
675 // operation is a div. Otherwise it should be a multiply.
676 inst->SetOpcode(other_first_is_variable ? inst->opcode()
677 : other_inst->opcode());
678 if (other_first_is_variable) {
679 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
680 {SPV_OPERAND_TYPE_ID, {merged_id}}});
681 } else {
682 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {merged_id}},
683 {SPV_OPERAND_TYPE_ID, {non_const_id}}});
684 }
685 return true;
686 }
687
688 return false;
689 };
690 }
691
692 // Merges multiply of constant and negation.
693 // Cases:
694 // (-x) * 2 = x * -2
695 // 2 * (-x) = x * -2
MergeMulNegateArithmetic()696 FoldingRule MergeMulNegateArithmetic() {
697 return [](IRContext* context, Instruction* inst,
698 const std::vector<const analysis::Constant*>& constants) {
699 assert(inst->opcode() == SpvOpFMul || inst->opcode() == SpvOpIMul);
700 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
701 const analysis::Type* type =
702 context->get_type_mgr()->GetType(inst->type_id());
703 bool uses_float = HasFloatingPoint(type);
704 if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
705
706 uint32_t width = ElementWidth(type);
707 if (width != 32 && width != 64) return false;
708
709 const analysis::Constant* const_input1 = ConstInput(constants);
710 if (!const_input1) return false;
711 Instruction* other_inst = NonConstInput(context, constants[0], inst);
712 if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
713 return false;
714
715 if (other_inst->opcode() == SpvOpFNegate ||
716 other_inst->opcode() == SpvOpSNegate) {
717 uint32_t neg_id = NegateConstant(const_mgr, const_input1);
718
719 inst->SetInOperands(
720 {{SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}},
721 {SPV_OPERAND_TYPE_ID, {neg_id}}});
722 return true;
723 }
724
725 return false;
726 };
727 }
728
729 // Merges consecutive divides if each instruction contains one constant operand.
730 // Does not support integer division.
731 // Cases:
732 // 2 / (x / 2) = 4 / x
733 // 4 / (2 / x) = 2 * x
734 // (4 / x) / 2 = 2 / x
735 // (x / 2) / 2 = x / 4
MergeDivDivArithmetic()736 FoldingRule MergeDivDivArithmetic() {
737 return [](IRContext* context, Instruction* inst,
738 const std::vector<const analysis::Constant*>& constants) {
739 assert(inst->opcode() == SpvOpFDiv);
740 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
741 const analysis::Type* type =
742 context->get_type_mgr()->GetType(inst->type_id());
743 if (!inst->IsFloatingPointFoldingAllowed()) return false;
744
745 uint32_t width = ElementWidth(type);
746 if (width != 32 && width != 64) return false;
747
748 const analysis::Constant* const_input1 = ConstInput(constants);
749 if (!const_input1 || HasZero(const_input1)) return false;
750 Instruction* other_inst = NonConstInput(context, constants[0], inst);
751 if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
752
753 bool first_is_variable = constants[0] == nullptr;
754 if (other_inst->opcode() == inst->opcode()) {
755 std::vector<const analysis::Constant*> other_constants =
756 const_mgr->GetOperandConstants(other_inst);
757 const analysis::Constant* const_input2 = ConstInput(other_constants);
758 if (!const_input2 || HasZero(const_input2)) return false;
759
760 bool other_first_is_variable = other_constants[0] == nullptr;
761
762 SpvOp merge_op = inst->opcode();
763 if (other_first_is_variable) {
764 // Constants magnify.
765 merge_op = SpvOpFMul;
766 }
767
768 // This is an x / (*) case. Swap the inputs. Doesn't harm multiply
769 // because it is commutative.
770 if (first_is_variable) std::swap(const_input1, const_input2);
771 uint32_t merged_id =
772 PerformOperation(const_mgr, merge_op, const_input1, const_input2);
773 if (merged_id == 0) return false;
774
775 uint32_t non_const_id = other_first_is_variable
776 ? other_inst->GetSingleWordInOperand(0u)
777 : other_inst->GetSingleWordInOperand(1u);
778
779 SpvOp op = inst->opcode();
780 if (!first_is_variable && !other_first_is_variable) {
781 // Effectively div of 1/x, so change to multiply.
782 op = SpvOpFMul;
783 }
784
785 uint32_t op1 = merged_id;
786 uint32_t op2 = non_const_id;
787 if (first_is_variable && other_first_is_variable) std::swap(op1, op2);
788 inst->SetOpcode(op);
789 inst->SetInOperands(
790 {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
791 return true;
792 }
793
794 return false;
795 };
796 }
797
798 // Fold multiplies succeeded by divides where each instruction contains a
799 // constant operand. Does not support integer divide.
800 // Cases:
801 // 4 / (x * 2) = 2 / x
802 // 4 / (2 * x) = 2 / x
803 // (x * 4) / 2 = x * 2
804 // (4 * x) / 2 = x * 2
805 // (x * y) / x = y
806 // (y * x) / x = y
MergeDivMulArithmetic()807 FoldingRule MergeDivMulArithmetic() {
808 return [](IRContext* context, Instruction* inst,
809 const std::vector<const analysis::Constant*>& constants) {
810 assert(inst->opcode() == SpvOpFDiv);
811 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
812 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
813
814 const analysis::Type* type =
815 context->get_type_mgr()->GetType(inst->type_id());
816 if (!inst->IsFloatingPointFoldingAllowed()) return false;
817
818 uint32_t width = ElementWidth(type);
819 if (width != 32 && width != 64) return false;
820
821 uint32_t op_id = inst->GetSingleWordInOperand(0);
822 Instruction* op_inst = def_use_mgr->GetDef(op_id);
823
824 if (op_inst->opcode() == SpvOpFMul) {
825 for (uint32_t i = 0; i < 2; i++) {
826 if (op_inst->GetSingleWordInOperand(i) ==
827 inst->GetSingleWordInOperand(1)) {
828 inst->SetOpcode(SpvOpCopyObject);
829 inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
830 {op_inst->GetSingleWordInOperand(1 - i)}}});
831 return true;
832 }
833 }
834 }
835
836 const analysis::Constant* const_input1 = ConstInput(constants);
837 if (!const_input1 || HasZero(const_input1)) return false;
838 Instruction* other_inst = NonConstInput(context, constants[0], inst);
839 if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
840
841 bool first_is_variable = constants[0] == nullptr;
842 if (other_inst->opcode() == SpvOpFMul) {
843 std::vector<const analysis::Constant*> other_constants =
844 const_mgr->GetOperandConstants(other_inst);
845 const analysis::Constant* const_input2 = ConstInput(other_constants);
846 if (!const_input2) return false;
847
848 bool other_first_is_variable = other_constants[0] == nullptr;
849
850 // This is an x / (*) case. Swap the inputs.
851 if (first_is_variable) std::swap(const_input1, const_input2);
852 uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
853 const_input1, const_input2);
854 if (merged_id == 0) return false;
855
856 uint32_t non_const_id = other_first_is_variable
857 ? other_inst->GetSingleWordInOperand(0u)
858 : other_inst->GetSingleWordInOperand(1u);
859
860 uint32_t op1 = merged_id;
861 uint32_t op2 = non_const_id;
862 if (first_is_variable) std::swap(op1, op2);
863
864 // Convert to multiply
865 if (first_is_variable) inst->SetOpcode(other_inst->opcode());
866 inst->SetInOperands(
867 {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
868 return true;
869 }
870
871 return false;
872 };
873 }
874
875 // Fold divides of a constant and a negation.
876 // Cases:
877 // (-x) / 2 = x / -2
878 // 2 / (-x) = 2 / -x
MergeDivNegateArithmetic()879 FoldingRule MergeDivNegateArithmetic() {
880 return [](IRContext* context, Instruction* inst,
881 const std::vector<const analysis::Constant*>& constants) {
882 assert(inst->opcode() == SpvOpFDiv || inst->opcode() == SpvOpSDiv ||
883 inst->opcode() == SpvOpUDiv);
884 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
885 const analysis::Type* type =
886 context->get_type_mgr()->GetType(inst->type_id());
887 bool uses_float = HasFloatingPoint(type);
888 if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
889
890 uint32_t width = ElementWidth(type);
891 if (width != 32 && width != 64) return false;
892
893 const analysis::Constant* const_input1 = ConstInput(constants);
894 if (!const_input1) return false;
895 Instruction* other_inst = NonConstInput(context, constants[0], inst);
896 if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
897 return false;
898
899 bool first_is_variable = constants[0] == nullptr;
900 if (other_inst->opcode() == SpvOpFNegate ||
901 other_inst->opcode() == SpvOpSNegate) {
902 uint32_t neg_id = NegateConstant(const_mgr, const_input1);
903
904 if (first_is_variable) {
905 inst->SetInOperands(
906 {{SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}},
907 {SPV_OPERAND_TYPE_ID, {neg_id}}});
908 } else {
909 inst->SetInOperands(
910 {{SPV_OPERAND_TYPE_ID, {neg_id}},
911 {SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}});
912 }
913 return true;
914 }
915
916 return false;
917 };
918 }
919
920 // Folds addition of a constant and a negation.
921 // Cases:
922 // (-x) + 2 = 2 - x
923 // 2 + (-x) = 2 - x
MergeAddNegateArithmetic()924 FoldingRule MergeAddNegateArithmetic() {
925 return [](IRContext* context, Instruction* inst,
926 const std::vector<const analysis::Constant*>& constants) {
927 assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
928 const analysis::Type* type =
929 context->get_type_mgr()->GetType(inst->type_id());
930 bool uses_float = HasFloatingPoint(type);
931 if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
932
933 const analysis::Constant* const_input1 = ConstInput(constants);
934 if (!const_input1) return false;
935 Instruction* other_inst = NonConstInput(context, constants[0], inst);
936 if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
937 return false;
938
939 if (other_inst->opcode() == SpvOpSNegate ||
940 other_inst->opcode() == SpvOpFNegate) {
941 inst->SetOpcode(HasFloatingPoint(type) ? SpvOpFSub : SpvOpISub);
942 uint32_t const_id = constants[0] ? inst->GetSingleWordInOperand(0u)
943 : inst->GetSingleWordInOperand(1u);
944 inst->SetInOperands(
945 {{SPV_OPERAND_TYPE_ID, {const_id}},
946 {SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}});
947 return true;
948 }
949 return false;
950 };
951 }
952
953 // Folds subtraction of a constant and a negation.
954 // Cases:
955 // (-x) - 2 = -2 - x
956 // 2 - (-x) = x + 2
MergeSubNegateArithmetic()957 FoldingRule MergeSubNegateArithmetic() {
958 return [](IRContext* context, Instruction* inst,
959 const std::vector<const analysis::Constant*>& constants) {
960 assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
961 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
962 const analysis::Type* type =
963 context->get_type_mgr()->GetType(inst->type_id());
964 bool uses_float = HasFloatingPoint(type);
965 if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
966
967 uint32_t width = ElementWidth(type);
968 if (width != 32 && width != 64) return false;
969
970 const analysis::Constant* const_input1 = ConstInput(constants);
971 if (!const_input1) return false;
972 Instruction* other_inst = NonConstInput(context, constants[0], inst);
973 if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
974 return false;
975
976 if (other_inst->opcode() == SpvOpSNegate ||
977 other_inst->opcode() == SpvOpFNegate) {
978 uint32_t op1 = 0;
979 uint32_t op2 = 0;
980 SpvOp opcode = inst->opcode();
981 if (constants[0] != nullptr) {
982 op1 = other_inst->GetSingleWordInOperand(0u);
983 op2 = inst->GetSingleWordInOperand(0u);
984 opcode = HasFloatingPoint(type) ? SpvOpFAdd : SpvOpIAdd;
985 } else {
986 op1 = NegateConstant(const_mgr, const_input1);
987 op2 = other_inst->GetSingleWordInOperand(0u);
988 }
989
990 inst->SetOpcode(opcode);
991 inst->SetInOperands(
992 {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
993 return true;
994 }
995 return false;
996 };
997 }
998
999 // Folds addition of an addition where each operation has a constant operand.
1000 // Cases:
1001 // (x + 2) + 2 = x + 4
1002 // (2 + x) + 2 = x + 4
1003 // 2 + (x + 2) = x + 4
1004 // 2 + (2 + x) = x + 4
MergeAddAddArithmetic()1005 FoldingRule MergeAddAddArithmetic() {
1006 return [](IRContext* context, Instruction* inst,
1007 const std::vector<const analysis::Constant*>& constants) {
1008 assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
1009 const analysis::Type* type =
1010 context->get_type_mgr()->GetType(inst->type_id());
1011 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1012 bool uses_float = HasFloatingPoint(type);
1013 if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1014
1015 uint32_t width = ElementWidth(type);
1016 if (width != 32 && width != 64) return false;
1017
1018 const analysis::Constant* const_input1 = ConstInput(constants);
1019 if (!const_input1) return false;
1020 Instruction* other_inst = NonConstInput(context, constants[0], inst);
1021 if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1022 return false;
1023
1024 if (other_inst->opcode() == SpvOpFAdd ||
1025 other_inst->opcode() == SpvOpIAdd) {
1026 std::vector<const analysis::Constant*> other_constants =
1027 const_mgr->GetOperandConstants(other_inst);
1028 const analysis::Constant* const_input2 = ConstInput(other_constants);
1029 if (!const_input2) return false;
1030
1031 Instruction* non_const_input =
1032 NonConstInput(context, other_constants[0], other_inst);
1033 uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
1034 const_input1, const_input2);
1035 if (merged_id == 0) return false;
1036
1037 inst->SetInOperands(
1038 {{SPV_OPERAND_TYPE_ID, {non_const_input->result_id()}},
1039 {SPV_OPERAND_TYPE_ID, {merged_id}}});
1040 return true;
1041 }
1042 return false;
1043 };
1044 }
1045
1046 // Folds addition of a subtraction where each operation has a constant operand.
1047 // Cases:
1048 // (x - 2) + 2 = x + 0
1049 // (2 - x) + 2 = 4 - x
1050 // 2 + (x - 2) = x + 0
1051 // 2 + (2 - x) = 4 - x
MergeAddSubArithmetic()1052 FoldingRule MergeAddSubArithmetic() {
1053 return [](IRContext* context, Instruction* inst,
1054 const std::vector<const analysis::Constant*>& constants) {
1055 assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
1056 const analysis::Type* type =
1057 context->get_type_mgr()->GetType(inst->type_id());
1058 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1059 bool uses_float = HasFloatingPoint(type);
1060 if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1061
1062 uint32_t width = ElementWidth(type);
1063 if (width != 32 && width != 64) return false;
1064
1065 const analysis::Constant* const_input1 = ConstInput(constants);
1066 if (!const_input1) return false;
1067 Instruction* other_inst = NonConstInput(context, constants[0], inst);
1068 if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1069 return false;
1070
1071 if (other_inst->opcode() == SpvOpFSub ||
1072 other_inst->opcode() == SpvOpISub) {
1073 std::vector<const analysis::Constant*> other_constants =
1074 const_mgr->GetOperandConstants(other_inst);
1075 const analysis::Constant* const_input2 = ConstInput(other_constants);
1076 if (!const_input2) return false;
1077
1078 bool first_is_variable = other_constants[0] == nullptr;
1079 SpvOp op = inst->opcode();
1080 uint32_t op1 = 0;
1081 uint32_t op2 = 0;
1082 if (first_is_variable) {
1083 // Subtract constants. Non-constant operand is first.
1084 op1 = other_inst->GetSingleWordInOperand(0u);
1085 op2 = PerformOperation(const_mgr, other_inst->opcode(), const_input1,
1086 const_input2);
1087 } else {
1088 // Add constants. Constant operand is first. Change the opcode.
1089 op1 = PerformOperation(const_mgr, inst->opcode(), const_input1,
1090 const_input2);
1091 op2 = other_inst->GetSingleWordInOperand(1u);
1092 op = other_inst->opcode();
1093 }
1094 if (op1 == 0 || op2 == 0) return false;
1095
1096 inst->SetOpcode(op);
1097 inst->SetInOperands(
1098 {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
1099 return true;
1100 }
1101 return false;
1102 };
1103 }
1104
1105 // Folds subtraction of an addition where each operand has a constant operand.
1106 // Cases:
1107 // (x + 2) - 2 = x + 0
1108 // (2 + x) - 2 = x + 0
1109 // 2 - (x + 2) = 0 - x
1110 // 2 - (2 + x) = 0 - x
MergeSubAddArithmetic()1111 FoldingRule MergeSubAddArithmetic() {
1112 return [](IRContext* context, Instruction* inst,
1113 const std::vector<const analysis::Constant*>& constants) {
1114 assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
1115 const analysis::Type* type =
1116 context->get_type_mgr()->GetType(inst->type_id());
1117 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1118 bool uses_float = HasFloatingPoint(type);
1119 if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1120
1121 uint32_t width = ElementWidth(type);
1122 if (width != 32 && width != 64) return false;
1123
1124 const analysis::Constant* const_input1 = ConstInput(constants);
1125 if (!const_input1) return false;
1126 Instruction* other_inst = NonConstInput(context, constants[0], inst);
1127 if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1128 return false;
1129
1130 if (other_inst->opcode() == SpvOpFAdd ||
1131 other_inst->opcode() == SpvOpIAdd) {
1132 std::vector<const analysis::Constant*> other_constants =
1133 const_mgr->GetOperandConstants(other_inst);
1134 const analysis::Constant* const_input2 = ConstInput(other_constants);
1135 if (!const_input2) return false;
1136
1137 Instruction* non_const_input =
1138 NonConstInput(context, other_constants[0], other_inst);
1139
1140 // If the first operand of the sub is not a constant, swap the constants
1141 // so the subtraction has the correct operands.
1142 if (constants[0] == nullptr) std::swap(const_input1, const_input2);
1143 // Subtract the constants.
1144 uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
1145 const_input1, const_input2);
1146 SpvOp op = inst->opcode();
1147 uint32_t op1 = 0;
1148 uint32_t op2 = 0;
1149 if (constants[0] == nullptr) {
1150 // Non-constant operand is first. Change the opcode.
1151 op1 = non_const_input->result_id();
1152 op2 = merged_id;
1153 op = other_inst->opcode();
1154 } else {
1155 // Constant operand is first.
1156 op1 = merged_id;
1157 op2 = non_const_input->result_id();
1158 }
1159 if (op1 == 0 || op2 == 0) return false;
1160
1161 inst->SetOpcode(op);
1162 inst->SetInOperands(
1163 {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
1164 return true;
1165 }
1166 return false;
1167 };
1168 }
1169
1170 // Folds subtraction of a subtraction where each operand has a constant operand.
1171 // Cases:
1172 // (x - 2) - 2 = x - 4
1173 // (2 - x) - 2 = 0 - x
1174 // 2 - (x - 2) = 4 - x
1175 // 2 - (2 - x) = x + 0
MergeSubSubArithmetic()1176 FoldingRule MergeSubSubArithmetic() {
1177 return [](IRContext* context, Instruction* inst,
1178 const std::vector<const analysis::Constant*>& constants) {
1179 assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
1180 const analysis::Type* type =
1181 context->get_type_mgr()->GetType(inst->type_id());
1182 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1183 bool uses_float = HasFloatingPoint(type);
1184 if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1185
1186 uint32_t width = ElementWidth(type);
1187 if (width != 32 && width != 64) return false;
1188
1189 const analysis::Constant* const_input1 = ConstInput(constants);
1190 if (!const_input1) return false;
1191 Instruction* other_inst = NonConstInput(context, constants[0], inst);
1192 if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1193 return false;
1194
1195 if (other_inst->opcode() == SpvOpFSub ||
1196 other_inst->opcode() == SpvOpISub) {
1197 std::vector<const analysis::Constant*> other_constants =
1198 const_mgr->GetOperandConstants(other_inst);
1199 const analysis::Constant* const_input2 = ConstInput(other_constants);
1200 if (!const_input2) return false;
1201
1202 Instruction* non_const_input =
1203 NonConstInput(context, other_constants[0], other_inst);
1204
1205 // Merge the constants.
1206 uint32_t merged_id = 0;
1207 SpvOp merge_op = inst->opcode();
1208 if (other_constants[0] == nullptr) {
1209 merge_op = uses_float ? SpvOpFAdd : SpvOpIAdd;
1210 } else if (constants[0] == nullptr) {
1211 std::swap(const_input1, const_input2);
1212 }
1213 merged_id =
1214 PerformOperation(const_mgr, merge_op, const_input1, const_input2);
1215 if (merged_id == 0) return false;
1216
1217 SpvOp op = inst->opcode();
1218 if (constants[0] != nullptr && other_constants[0] != nullptr) {
1219 // Change the operation.
1220 op = uses_float ? SpvOpFAdd : SpvOpIAdd;
1221 }
1222
1223 uint32_t op1 = 0;
1224 uint32_t op2 = 0;
1225 if ((constants[0] == nullptr) ^ (other_constants[0] == nullptr)) {
1226 op1 = merged_id;
1227 op2 = non_const_input->result_id();
1228 } else {
1229 op1 = non_const_input->result_id();
1230 op2 = merged_id;
1231 }
1232
1233 inst->SetOpcode(op);
1234 inst->SetInOperands(
1235 {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
1236 return true;
1237 }
1238 return false;
1239 };
1240 }
1241
IntMultipleBy1()1242 FoldingRule IntMultipleBy1() {
1243 return [](IRContext*, Instruction* inst,
1244 const std::vector<const analysis::Constant*>& constants) {
1245 assert(inst->opcode() == SpvOpIMul && "Wrong opcode. Should be OpIMul.");
1246 for (uint32_t i = 0; i < 2; i++) {
1247 if (constants[i] == nullptr) {
1248 continue;
1249 }
1250 const analysis::IntConstant* int_constant = constants[i]->AsIntConstant();
1251 if (int_constant) {
1252 uint32_t width = ElementWidth(int_constant->type());
1253 if (width != 32 && width != 64) return false;
1254 bool is_one = (width == 32) ? int_constant->GetU32BitValue() == 1u
1255 : int_constant->GetU64BitValue() == 1ull;
1256 if (is_one) {
1257 inst->SetOpcode(SpvOpCopyObject);
1258 inst->SetInOperands(
1259 {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1 - i)}}});
1260 return true;
1261 }
1262 }
1263 }
1264 return false;
1265 };
1266 }
1267
CompositeConstructFeedingExtract()1268 FoldingRule CompositeConstructFeedingExtract() {
1269 return [](IRContext* context, Instruction* inst,
1270 const std::vector<const analysis::Constant*>&) {
1271 // If the input to an OpCompositeExtract is an OpCompositeConstruct,
1272 // then we can simply use the appropriate element in the construction.
1273 assert(inst->opcode() == SpvOpCompositeExtract &&
1274 "Wrong opcode. Should be OpCompositeExtract.");
1275 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1276 analysis::TypeManager* type_mgr = context->get_type_mgr();
1277 uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1278 Instruction* cinst = def_use_mgr->GetDef(cid);
1279
1280 if (cinst->opcode() != SpvOpCompositeConstruct) {
1281 return false;
1282 }
1283
1284 std::vector<Operand> operands;
1285 analysis::Type* composite_type = type_mgr->GetType(cinst->type_id());
1286 if (composite_type->AsVector() == nullptr) {
1287 // Get the element being extracted from the OpCompositeConstruct
1288 // Since it is not a vector, it is simple to extract the single element.
1289 uint32_t element_index = inst->GetSingleWordInOperand(1);
1290 uint32_t element_id = cinst->GetSingleWordInOperand(element_index);
1291 operands.push_back({SPV_OPERAND_TYPE_ID, {element_id}});
1292
1293 // Add the remaining indices for extraction.
1294 for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
1295 operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
1296 {inst->GetSingleWordInOperand(i)}});
1297 }
1298
1299 } else {
1300 // With vectors we have to handle the case where it is concatenating
1301 // vectors.
1302 assert(inst->NumInOperands() == 2 &&
1303 "Expecting a vector of scalar values.");
1304
1305 uint32_t element_index = inst->GetSingleWordInOperand(1);
1306 for (uint32_t construct_index = 0;
1307 construct_index < cinst->NumInOperands(); ++construct_index) {
1308 uint32_t element_id = cinst->GetSingleWordInOperand(construct_index);
1309 Instruction* element_def = def_use_mgr->GetDef(element_id);
1310 analysis::Vector* element_type =
1311 type_mgr->GetType(element_def->type_id())->AsVector();
1312 if (element_type) {
1313 uint32_t vector_size = element_type->element_count();
1314 if (vector_size < element_index) {
1315 // The element we want comes after this vector.
1316 element_index -= vector_size;
1317 } else {
1318 // We want an element of this vector.
1319 operands.push_back({SPV_OPERAND_TYPE_ID, {element_id}});
1320 operands.push_back(
1321 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {element_index}});
1322 break;
1323 }
1324 } else {
1325 if (element_index == 0) {
1326 // This is a scalar, and we this is the element we are extracting.
1327 operands.push_back({SPV_OPERAND_TYPE_ID, {element_id}});
1328 break;
1329 } else {
1330 // Skip over this scalar value.
1331 --element_index;
1332 }
1333 }
1334 }
1335 }
1336
1337 // If there were no extra indices, then we have the final object. No need
1338 // to extract even more.
1339 if (operands.size() == 1) {
1340 inst->SetOpcode(SpvOpCopyObject);
1341 }
1342
1343 inst->SetInOperands(std::move(operands));
1344 return true;
1345 };
1346 }
1347
CompositeExtractFeedingConstruct()1348 FoldingRule CompositeExtractFeedingConstruct() {
1349 // If the OpCompositeConstruct is simply putting back together elements that
1350 // where extracted from the same souce, we can simlpy reuse the source.
1351 //
1352 // This is a common code pattern because of the way that scalar replacement
1353 // works.
1354 return [](IRContext* context, Instruction* inst,
1355 const std::vector<const analysis::Constant*>&) {
1356 assert(inst->opcode() == SpvOpCompositeConstruct &&
1357 "Wrong opcode. Should be OpCompositeConstruct.");
1358 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1359 uint32_t original_id = 0;
1360
1361 // Check each element to make sure they are:
1362 // - extractions
1363 // - extracting the same position they are inserting
1364 // - all extract from the same id.
1365 for (uint32_t i = 0; i < inst->NumInOperands(); ++i) {
1366 uint32_t element_id = inst->GetSingleWordInOperand(i);
1367 Instruction* element_inst = def_use_mgr->GetDef(element_id);
1368
1369 if (element_inst->opcode() != SpvOpCompositeExtract) {
1370 return false;
1371 }
1372
1373 if (element_inst->NumInOperands() != 2) {
1374 return false;
1375 }
1376
1377 if (element_inst->GetSingleWordInOperand(1) != i) {
1378 return false;
1379 }
1380
1381 if (i == 0) {
1382 original_id =
1383 element_inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1384 } else if (original_id != element_inst->GetSingleWordInOperand(
1385 kExtractCompositeIdInIdx)) {
1386 return false;
1387 }
1388 }
1389
1390 // The last check it to see that the object being extracted from is the
1391 // correct type.
1392 Instruction* original_inst = def_use_mgr->GetDef(original_id);
1393 if (original_inst->type_id() != inst->type_id()) {
1394 return false;
1395 }
1396
1397 // Simplify by using the original object.
1398 inst->SetOpcode(SpvOpCopyObject);
1399 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {original_id}}});
1400 return true;
1401 };
1402 }
1403
InsertFeedingExtract()1404 FoldingRule InsertFeedingExtract() {
1405 return [](IRContext* context, Instruction* inst,
1406 const std::vector<const analysis::Constant*>&) {
1407 assert(inst->opcode() == SpvOpCompositeExtract &&
1408 "Wrong opcode. Should be OpCompositeExtract.");
1409 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1410 uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1411 Instruction* cinst = def_use_mgr->GetDef(cid);
1412
1413 if (cinst->opcode() != SpvOpCompositeInsert) {
1414 return false;
1415 }
1416
1417 // Find the first position where the list of insert and extract indicies
1418 // differ, if at all.
1419 uint32_t i;
1420 for (i = 1; i < inst->NumInOperands(); ++i) {
1421 if (i + 1 >= cinst->NumInOperands()) {
1422 break;
1423 }
1424
1425 if (inst->GetSingleWordInOperand(i) !=
1426 cinst->GetSingleWordInOperand(i + 1)) {
1427 break;
1428 }
1429 }
1430
1431 // We are extracting the element that was inserted.
1432 if (i == inst->NumInOperands() && i + 1 == cinst->NumInOperands()) {
1433 inst->SetOpcode(SpvOpCopyObject);
1434 inst->SetInOperands(
1435 {{SPV_OPERAND_TYPE_ID,
1436 {cinst->GetSingleWordInOperand(kInsertObjectIdInIdx)}}});
1437 return true;
1438 }
1439
1440 // Extracting the value that was inserted along with values for the base
1441 // composite. Cannot do anything.
1442 if (i == inst->NumInOperands()) {
1443 return false;
1444 }
1445
1446 // Extracting an element of the value that was inserted. Extract from
1447 // that value directly.
1448 if (i + 1 == cinst->NumInOperands()) {
1449 std::vector<Operand> operands;
1450 operands.push_back(
1451 {SPV_OPERAND_TYPE_ID,
1452 {cinst->GetSingleWordInOperand(kInsertObjectIdInIdx)}});
1453 for (; i < inst->NumInOperands(); ++i) {
1454 operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
1455 {inst->GetSingleWordInOperand(i)}});
1456 }
1457 inst->SetInOperands(std::move(operands));
1458 return true;
1459 }
1460
1461 // Extracting a value that is disjoint from the element being inserted.
1462 // Rewrite the extract to use the composite input to the insert.
1463 std::vector<Operand> operands;
1464 operands.push_back(
1465 {SPV_OPERAND_TYPE_ID,
1466 {cinst->GetSingleWordInOperand(kInsertCompositeIdInIdx)}});
1467 for (i = 1; i < inst->NumInOperands(); ++i) {
1468 operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
1469 {inst->GetSingleWordInOperand(i)}});
1470 }
1471 inst->SetInOperands(std::move(operands));
1472 return true;
1473 };
1474 }
1475
1476 // When a VectorShuffle is feeding an Extract, we can extract from one of the
1477 // operands of the VectorShuffle. We just need to adjust the index in the
1478 // extract instruction.
VectorShuffleFeedingExtract()1479 FoldingRule VectorShuffleFeedingExtract() {
1480 return [](IRContext* context, Instruction* inst,
1481 const std::vector<const analysis::Constant*>&) {
1482 assert(inst->opcode() == SpvOpCompositeExtract &&
1483 "Wrong opcode. Should be OpCompositeExtract.");
1484 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1485 analysis::TypeManager* type_mgr = context->get_type_mgr();
1486 uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1487 Instruction* cinst = def_use_mgr->GetDef(cid);
1488
1489 if (cinst->opcode() != SpvOpVectorShuffle) {
1490 return false;
1491 }
1492
1493 // Find the size of the first vector operand of the VectorShuffle
1494 Instruction* first_input =
1495 def_use_mgr->GetDef(cinst->GetSingleWordInOperand(0));
1496 analysis::Type* first_input_type =
1497 type_mgr->GetType(first_input->type_id());
1498 assert(first_input_type->AsVector() &&
1499 "Input to vector shuffle should be vectors.");
1500 uint32_t first_input_size = first_input_type->AsVector()->element_count();
1501
1502 // Get index of the element the vector shuffle is placing in the position
1503 // being extracted.
1504 uint32_t new_index =
1505 cinst->GetSingleWordInOperand(2 + inst->GetSingleWordInOperand(1));
1506
1507 // Extracting an undefined value so fold this extract into an undef.
1508 const uint32_t undef_literal_value = 0xffffffff;
1509 if (new_index == undef_literal_value) {
1510 inst->SetOpcode(SpvOpUndef);
1511 inst->SetInOperands({});
1512 return true;
1513 }
1514
1515 // Get the id of the of the vector the elemtent comes from, and update the
1516 // index if needed.
1517 uint32_t new_vector = 0;
1518 if (new_index < first_input_size) {
1519 new_vector = cinst->GetSingleWordInOperand(0);
1520 } else {
1521 new_vector = cinst->GetSingleWordInOperand(1);
1522 new_index -= first_input_size;
1523 }
1524
1525 // Update the extract instruction.
1526 inst->SetInOperand(kExtractCompositeIdInIdx, {new_vector});
1527 inst->SetInOperand(1, {new_index});
1528 return true;
1529 };
1530 }
1531
1532 // When an FMix with is feeding an Extract that extracts an element whose
1533 // corresponding |a| in the FMix is 0 or 1, we can extract from one of the
1534 // operands of the FMix.
FMixFeedingExtract()1535 FoldingRule FMixFeedingExtract() {
1536 return [](IRContext* context, Instruction* inst,
1537 const std::vector<const analysis::Constant*>&) {
1538 assert(inst->opcode() == SpvOpCompositeExtract &&
1539 "Wrong opcode. Should be OpCompositeExtract.");
1540 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1541 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1542
1543 uint32_t composite_id =
1544 inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1545 Instruction* composite_inst = def_use_mgr->GetDef(composite_id);
1546
1547 if (composite_inst->opcode() != SpvOpExtInst) {
1548 return false;
1549 }
1550
1551 uint32_t inst_set_id =
1552 context->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
1553
1554 if (composite_inst->GetSingleWordInOperand(kExtInstSetIdInIdx) !=
1555 inst_set_id ||
1556 composite_inst->GetSingleWordInOperand(kExtInstInstructionInIdx) !=
1557 GLSLstd450FMix) {
1558 return false;
1559 }
1560
1561 // Get the |a| for the FMix instruction.
1562 uint32_t a_id = composite_inst->GetSingleWordInOperand(kFMixAIdInIdx);
1563 std::unique_ptr<Instruction> a(inst->Clone(context));
1564 a->SetInOperand(kExtractCompositeIdInIdx, {a_id});
1565 context->get_instruction_folder().FoldInstruction(a.get());
1566
1567 if (a->opcode() != SpvOpCopyObject) {
1568 return false;
1569 }
1570
1571 const analysis::Constant* a_const =
1572 const_mgr->FindDeclaredConstant(a->GetSingleWordInOperand(0));
1573
1574 if (!a_const) {
1575 return false;
1576 }
1577
1578 bool use_x = false;
1579
1580 assert(a_const->type()->AsFloat());
1581 double element_value = a_const->GetValueAsDouble();
1582 if (element_value == 0.0) {
1583 use_x = true;
1584 } else if (element_value == 1.0) {
1585 use_x = false;
1586 } else {
1587 return false;
1588 }
1589
1590 // Get the id of the of the vector the element comes from.
1591 uint32_t new_vector = 0;
1592 if (use_x) {
1593 new_vector = composite_inst->GetSingleWordInOperand(kFMixXIdInIdx);
1594 } else {
1595 new_vector = composite_inst->GetSingleWordInOperand(kFMixYIdInIdx);
1596 }
1597
1598 // Update the extract instruction.
1599 inst->SetInOperand(kExtractCompositeIdInIdx, {new_vector});
1600 return true;
1601 };
1602 }
1603
RedundantPhi()1604 FoldingRule RedundantPhi() {
1605 // An OpPhi instruction where all values are the same or the result of the phi
1606 // itself, can be replaced by the value itself.
1607 return [](IRContext*, Instruction* inst,
1608 const std::vector<const analysis::Constant*>&) {
1609 assert(inst->opcode() == SpvOpPhi && "Wrong opcode. Should be OpPhi.");
1610
1611 uint32_t incoming_value = 0;
1612
1613 for (uint32_t i = 0; i < inst->NumInOperands(); i += 2) {
1614 uint32_t op_id = inst->GetSingleWordInOperand(i);
1615 if (op_id == inst->result_id()) {
1616 continue;
1617 }
1618
1619 if (incoming_value == 0) {
1620 incoming_value = op_id;
1621 } else if (op_id != incoming_value) {
1622 // Found two possible value. Can't simplify.
1623 return false;
1624 }
1625 }
1626
1627 if (incoming_value == 0) {
1628 // Code looks invalid. Don't do anything.
1629 return false;
1630 }
1631
1632 // We have a single incoming value. Simplify using that value.
1633 inst->SetOpcode(SpvOpCopyObject);
1634 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {incoming_value}}});
1635 return true;
1636 };
1637 }
1638
RedundantSelect()1639 FoldingRule RedundantSelect() {
1640 // An OpSelect instruction where both values are the same or the condition is
1641 // constant can be replaced by one of the values
1642 return [](IRContext*, Instruction* inst,
1643 const std::vector<const analysis::Constant*>& constants) {
1644 assert(inst->opcode() == SpvOpSelect &&
1645 "Wrong opcode. Should be OpSelect.");
1646 assert(inst->NumInOperands() == 3);
1647 assert(constants.size() == 3);
1648
1649 uint32_t true_id = inst->GetSingleWordInOperand(1);
1650 uint32_t false_id = inst->GetSingleWordInOperand(2);
1651
1652 if (true_id == false_id) {
1653 // Both results are the same, condition doesn't matter
1654 inst->SetOpcode(SpvOpCopyObject);
1655 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}});
1656 return true;
1657 } else if (constants[0]) {
1658 const analysis::Type* type = constants[0]->type();
1659 if (type->AsBool()) {
1660 // Scalar constant value, select the corresponding value.
1661 inst->SetOpcode(SpvOpCopyObject);
1662 if (constants[0]->AsNullConstant() ||
1663 !constants[0]->AsBoolConstant()->value()) {
1664 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {false_id}}});
1665 } else {
1666 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}});
1667 }
1668 return true;
1669 } else {
1670 assert(type->AsVector());
1671 if (constants[0]->AsNullConstant()) {
1672 // All values come from false id.
1673 inst->SetOpcode(SpvOpCopyObject);
1674 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {false_id}}});
1675 return true;
1676 } else {
1677 // Convert to a vector shuffle.
1678 std::vector<Operand> ops;
1679 ops.push_back({SPV_OPERAND_TYPE_ID, {true_id}});
1680 ops.push_back({SPV_OPERAND_TYPE_ID, {false_id}});
1681 const analysis::VectorConstant* vector_const =
1682 constants[0]->AsVectorConstant();
1683 uint32_t size =
1684 static_cast<uint32_t>(vector_const->GetComponents().size());
1685 for (uint32_t i = 0; i != size; ++i) {
1686 const analysis::Constant* component =
1687 vector_const->GetComponents()[i];
1688 if (component->AsNullConstant() ||
1689 !component->AsBoolConstant()->value()) {
1690 // Selecting from the false vector which is the second input
1691 // vector to the shuffle. Offset the index by |size|.
1692 ops.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i + size}});
1693 } else {
1694 // Selecting from true vector which is the first input vector to
1695 // the shuffle.
1696 ops.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}});
1697 }
1698 }
1699
1700 inst->SetOpcode(SpvOpVectorShuffle);
1701 inst->SetInOperands(std::move(ops));
1702 return true;
1703 }
1704 }
1705 }
1706
1707 return false;
1708 };
1709 }
1710
// Classification of a floating-point constant, as computed by
// getFloatConstantKind.
enum class FloatConstantKind {
  Unknown,  // No constant, or not uniformly 0.0 or 1.0.
  Zero,     // Every component is zero (including null constants).
  One,      // Every component is 1.0.
};
1712
getFloatConstantKind(const analysis::Constant * constant)1713 FloatConstantKind getFloatConstantKind(const analysis::Constant* constant) {
1714 if (constant == nullptr) {
1715 return FloatConstantKind::Unknown;
1716 }
1717
1718 assert(HasFloatingPoint(constant->type()) && "Unexpected constant type");
1719
1720 if (constant->AsNullConstant()) {
1721 return FloatConstantKind::Zero;
1722 } else if (const analysis::VectorConstant* vc =
1723 constant->AsVectorConstant()) {
1724 const std::vector<const analysis::Constant*>& components =
1725 vc->GetComponents();
1726 assert(!components.empty());
1727
1728 FloatConstantKind kind = getFloatConstantKind(components[0]);
1729
1730 for (size_t i = 1; i < components.size(); ++i) {
1731 if (getFloatConstantKind(components[i]) != kind) {
1732 return FloatConstantKind::Unknown;
1733 }
1734 }
1735
1736 return kind;
1737 } else if (const analysis::FloatConstant* fc = constant->AsFloatConstant()) {
1738 if (fc->IsZero()) return FloatConstantKind::Zero;
1739
1740 uint32_t width = fc->type()->AsFloat()->width();
1741 if (width != 32 && width != 64) return FloatConstantKind::Unknown;
1742
1743 double value = (width == 64) ? fc->GetDoubleValue() : fc->GetFloatValue();
1744
1745 if (value == 0.0) {
1746 return FloatConstantKind::Zero;
1747 } else if (value == 1.0) {
1748 return FloatConstantKind::One;
1749 } else {
1750 return FloatConstantKind::Unknown;
1751 }
1752 } else {
1753 return FloatConstantKind::Unknown;
1754 }
1755 }
1756
RedundantFAdd()1757 FoldingRule RedundantFAdd() {
1758 return [](IRContext*, Instruction* inst,
1759 const std::vector<const analysis::Constant*>& constants) {
1760 assert(inst->opcode() == SpvOpFAdd && "Wrong opcode. Should be OpFAdd.");
1761 assert(constants.size() == 2);
1762
1763 if (!inst->IsFloatingPointFoldingAllowed()) {
1764 return false;
1765 }
1766
1767 FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
1768 FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
1769
1770 if (kind0 == FloatConstantKind::Zero || kind1 == FloatConstantKind::Zero) {
1771 inst->SetOpcode(SpvOpCopyObject);
1772 inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
1773 {inst->GetSingleWordInOperand(
1774 kind0 == FloatConstantKind::Zero ? 1 : 0)}}});
1775 return true;
1776 }
1777
1778 return false;
1779 };
1780 }
1781
RedundantFSub()1782 FoldingRule RedundantFSub() {
1783 return [](IRContext*, Instruction* inst,
1784 const std::vector<const analysis::Constant*>& constants) {
1785 assert(inst->opcode() == SpvOpFSub && "Wrong opcode. Should be OpFSub.");
1786 assert(constants.size() == 2);
1787
1788 if (!inst->IsFloatingPointFoldingAllowed()) {
1789 return false;
1790 }
1791
1792 FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
1793 FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
1794
1795 if (kind0 == FloatConstantKind::Zero) {
1796 inst->SetOpcode(SpvOpFNegate);
1797 inst->SetInOperands(
1798 {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1)}}});
1799 return true;
1800 }
1801
1802 if (kind1 == FloatConstantKind::Zero) {
1803 inst->SetOpcode(SpvOpCopyObject);
1804 inst->SetInOperands(
1805 {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
1806 return true;
1807 }
1808
1809 return false;
1810 };
1811 }
1812
RedundantFMul()1813 FoldingRule RedundantFMul() {
1814 return [](IRContext*, Instruction* inst,
1815 const std::vector<const analysis::Constant*>& constants) {
1816 assert(inst->opcode() == SpvOpFMul && "Wrong opcode. Should be OpFMul.");
1817 assert(constants.size() == 2);
1818
1819 if (!inst->IsFloatingPointFoldingAllowed()) {
1820 return false;
1821 }
1822
1823 FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
1824 FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
1825
1826 if (kind0 == FloatConstantKind::Zero || kind1 == FloatConstantKind::Zero) {
1827 inst->SetOpcode(SpvOpCopyObject);
1828 inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
1829 {inst->GetSingleWordInOperand(
1830 kind0 == FloatConstantKind::Zero ? 0 : 1)}}});
1831 return true;
1832 }
1833
1834 if (kind0 == FloatConstantKind::One || kind1 == FloatConstantKind::One) {
1835 inst->SetOpcode(SpvOpCopyObject);
1836 inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
1837 {inst->GetSingleWordInOperand(
1838 kind0 == FloatConstantKind::One ? 1 : 0)}}});
1839 return true;
1840 }
1841
1842 return false;
1843 };
1844 }
1845
RedundantFDiv()1846 FoldingRule RedundantFDiv() {
1847 return [](IRContext*, Instruction* inst,
1848 const std::vector<const analysis::Constant*>& constants) {
1849 assert(inst->opcode() == SpvOpFDiv && "Wrong opcode. Should be OpFDiv.");
1850 assert(constants.size() == 2);
1851
1852 if (!inst->IsFloatingPointFoldingAllowed()) {
1853 return false;
1854 }
1855
1856 FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
1857 FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
1858
1859 if (kind0 == FloatConstantKind::Zero) {
1860 inst->SetOpcode(SpvOpCopyObject);
1861 inst->SetInOperands(
1862 {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
1863 return true;
1864 }
1865
1866 if (kind1 == FloatConstantKind::One) {
1867 inst->SetOpcode(SpvOpCopyObject);
1868 inst->SetInOperands(
1869 {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
1870 return true;
1871 }
1872
1873 return false;
1874 };
1875 }
1876
RedundantFMix()1877 FoldingRule RedundantFMix() {
1878 return [](IRContext* context, Instruction* inst,
1879 const std::vector<const analysis::Constant*>& constants) {
1880 assert(inst->opcode() == SpvOpExtInst &&
1881 "Wrong opcode. Should be OpExtInst.");
1882
1883 if (!inst->IsFloatingPointFoldingAllowed()) {
1884 return false;
1885 }
1886
1887 uint32_t instSetId =
1888 context->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
1889
1890 if (inst->GetSingleWordInOperand(kExtInstSetIdInIdx) == instSetId &&
1891 inst->GetSingleWordInOperand(kExtInstInstructionInIdx) ==
1892 GLSLstd450FMix) {
1893 assert(constants.size() == 5);
1894
1895 FloatConstantKind kind4 = getFloatConstantKind(constants[4]);
1896
1897 if (kind4 == FloatConstantKind::Zero || kind4 == FloatConstantKind::One) {
1898 inst->SetOpcode(SpvOpCopyObject);
1899 inst->SetInOperands(
1900 {{SPV_OPERAND_TYPE_ID,
1901 {inst->GetSingleWordInOperand(kind4 == FloatConstantKind::Zero
1902 ? kFMixXIdInIdx
1903 : kFMixYIdInIdx)}}});
1904 return true;
1905 }
1906 }
1907
1908 return false;
1909 };
1910 }
1911
1912 // This rule handles addition of zero for integers.
RedundantIAdd()1913 FoldingRule RedundantIAdd() {
1914 return [](IRContext* context, Instruction* inst,
1915 const std::vector<const analysis::Constant*>& constants) {
1916 assert(inst->opcode() == SpvOpIAdd && "Wrong opcode. Should be OpIAdd.");
1917
1918 uint32_t operand = std::numeric_limits<uint32_t>::max();
1919 const analysis::Type* operand_type = nullptr;
1920 if (constants[0] && constants[0]->IsZero()) {
1921 operand = inst->GetSingleWordInOperand(1);
1922 operand_type = constants[0]->type();
1923 } else if (constants[1] && constants[1]->IsZero()) {
1924 operand = inst->GetSingleWordInOperand(0);
1925 operand_type = constants[1]->type();
1926 }
1927
1928 if (operand != std::numeric_limits<uint32_t>::max()) {
1929 const analysis::Type* inst_type =
1930 context->get_type_mgr()->GetType(inst->type_id());
1931 if (inst_type->IsSame(operand_type)) {
1932 inst->SetOpcode(SpvOpCopyObject);
1933 } else {
1934 inst->SetOpcode(SpvOpBitcast);
1935 }
1936 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {operand}}});
1937 return true;
1938 }
1939 return false;
1940 };
1941 }
1942
1943 // This rule look for a dot with a constant vector containing a single 1 and
1944 // the rest 0s. This is the same as doing an extract.
DotProductDoingExtract()1945 FoldingRule DotProductDoingExtract() {
1946 return [](IRContext* context, Instruction* inst,
1947 const std::vector<const analysis::Constant*>& constants) {
1948 assert(inst->opcode() == SpvOpDot && "Wrong opcode. Should be OpDot.");
1949
1950 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1951
1952 if (!inst->IsFloatingPointFoldingAllowed()) {
1953 return false;
1954 }
1955
1956 for (int i = 0; i < 2; ++i) {
1957 if (!constants[i]) {
1958 continue;
1959 }
1960
1961 const analysis::Vector* vector_type = constants[i]->type()->AsVector();
1962 assert(vector_type && "Inputs to OpDot must be vectors.");
1963 const analysis::Float* element_type =
1964 vector_type->element_type()->AsFloat();
1965 assert(element_type && "Inputs to OpDot must be vectors of floats.");
1966 uint32_t element_width = element_type->width();
1967 if (element_width != 32 && element_width != 64) {
1968 return false;
1969 }
1970
1971 std::vector<const analysis::Constant*> components;
1972 components = constants[i]->GetVectorComponents(const_mgr);
1973
1974 const uint32_t kNotFound = std::numeric_limits<uint32_t>::max();
1975
1976 uint32_t component_with_one = kNotFound;
1977 bool all_others_zero = true;
1978 for (uint32_t j = 0; j < components.size(); ++j) {
1979 const analysis::Constant* element = components[j];
1980 double value =
1981 (element_width == 32 ? element->GetFloat() : element->GetDouble());
1982 if (value == 0.0) {
1983 continue;
1984 } else if (value == 1.0) {
1985 if (component_with_one == kNotFound) {
1986 component_with_one = j;
1987 } else {
1988 component_with_one = kNotFound;
1989 break;
1990 }
1991 } else {
1992 all_others_zero = false;
1993 break;
1994 }
1995 }
1996
1997 if (!all_others_zero || component_with_one == kNotFound) {
1998 continue;
1999 }
2000
2001 std::vector<Operand> operands;
2002 operands.push_back(
2003 {SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1u - i)}});
2004 operands.push_back(
2005 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {component_with_one}});
2006
2007 inst->SetOpcode(SpvOpCompositeExtract);
2008 inst->SetInOperands(std::move(operands));
2009 return true;
2010 }
2011 return false;
2012 };
2013 }
2014
2015 // If we are storing an undef, then we can remove the store.
2016 //
2017 // TODO: We can do something similar for OpImageWrite, but checking for volatile
2018 // is complicated. Waiting to see if it is needed.
StoringUndef()2019 FoldingRule StoringUndef() {
2020 return [](IRContext* context, Instruction* inst,
2021 const std::vector<const analysis::Constant*>&) {
2022 assert(inst->opcode() == SpvOpStore && "Wrong opcode. Should be OpStore.");
2023
2024 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
2025
2026 // If this is a volatile store, the store cannot be removed.
2027 if (inst->NumInOperands() == 3) {
2028 if (inst->GetSingleWordInOperand(2) & SpvMemoryAccessVolatileMask) {
2029 return false;
2030 }
2031 }
2032
2033 uint32_t object_id = inst->GetSingleWordInOperand(kStoreObjectInIdx);
2034 Instruction* object_inst = def_use_mgr->GetDef(object_id);
2035 if (object_inst->opcode() == SpvOpUndef) {
2036 inst->ToNop();
2037 return true;
2038 }
2039 return false;
2040 };
2041 }
2042
VectorShuffleFeedingShuffle()2043 FoldingRule VectorShuffleFeedingShuffle() {
2044 return [](IRContext* context, Instruction* inst,
2045 const std::vector<const analysis::Constant*>&) {
2046 assert(inst->opcode() == SpvOpVectorShuffle &&
2047 "Wrong opcode. Should be OpVectorShuffle.");
2048
2049 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
2050 analysis::TypeManager* type_mgr = context->get_type_mgr();
2051
2052 Instruction* feeding_shuffle_inst =
2053 def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
2054 analysis::Vector* op0_type =
2055 type_mgr->GetType(feeding_shuffle_inst->type_id())->AsVector();
2056 uint32_t op0_length = op0_type->element_count();
2057
2058 bool feeder_is_op0 = true;
2059 if (feeding_shuffle_inst->opcode() != SpvOpVectorShuffle) {
2060 feeding_shuffle_inst =
2061 def_use_mgr->GetDef(inst->GetSingleWordInOperand(1));
2062 feeder_is_op0 = false;
2063 }
2064
2065 if (feeding_shuffle_inst->opcode() != SpvOpVectorShuffle) {
2066 return false;
2067 }
2068
2069 Instruction* feeder2 =
2070 def_use_mgr->GetDef(feeding_shuffle_inst->GetSingleWordInOperand(0));
2071 analysis::Vector* feeder_op0_type =
2072 type_mgr->GetType(feeder2->type_id())->AsVector();
2073 uint32_t feeder_op0_length = feeder_op0_type->element_count();
2074
2075 uint32_t new_feeder_id = 0;
2076 std::vector<Operand> new_operands;
2077 new_operands.resize(
2078 2, {SPV_OPERAND_TYPE_ID, {0}}); // Place holders for vector operands.
2079 const uint32_t undef_literal = 0xffffffff;
2080 for (uint32_t op = 2; op < inst->NumInOperands(); ++op) {
2081 uint32_t component_index = inst->GetSingleWordInOperand(op);
2082
2083 // Do not interpret the undefined value literal as coming from operand 1.
2084 if (component_index != undef_literal &&
2085 feeder_is_op0 == (component_index < op0_length)) {
2086 // This component comes from the feeding_shuffle_inst. Update
2087 // |component_index| to be the index into the operand of the feeder.
2088
2089 // Adjust component_index to get the index into the operands of the
2090 // feeding_shuffle_inst.
2091 if (component_index >= op0_length) {
2092 component_index -= op0_length;
2093 }
2094 component_index =
2095 feeding_shuffle_inst->GetSingleWordInOperand(component_index + 2);
2096
2097 // Check if we are using a component from the first or second operand of
2098 // the feeding instruction.
2099 if (component_index < feeder_op0_length) {
2100 if (new_feeder_id == 0) {
2101 // First time through, save the id of the operand the element comes
2102 // from.
2103 new_feeder_id = feeding_shuffle_inst->GetSingleWordInOperand(0);
2104 } else if (new_feeder_id !=
2105 feeding_shuffle_inst->GetSingleWordInOperand(0)) {
2106 // We need both elements of the feeding_shuffle_inst, so we cannot
2107 // fold.
2108 return false;
2109 }
2110 } else {
2111 if (new_feeder_id == 0) {
2112 // First time through, save the id of the operand the element comes
2113 // from.
2114 new_feeder_id = feeding_shuffle_inst->GetSingleWordInOperand(1);
2115 } else if (new_feeder_id !=
2116 feeding_shuffle_inst->GetSingleWordInOperand(1)) {
2117 // We need both elements of the feeding_shuffle_inst, so we cannot
2118 // fold.
2119 return false;
2120 }
2121 component_index -= feeder_op0_length;
2122 }
2123
2124 if (!feeder_is_op0) {
2125 component_index += op0_length;
2126 }
2127 }
2128 new_operands.push_back(
2129 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {component_index}});
2130 }
2131
2132 if (new_feeder_id == 0) {
2133 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
2134 const analysis::Type* type =
2135 type_mgr->GetType(feeding_shuffle_inst->type_id());
2136 const analysis::Constant* null_const = const_mgr->GetConstant(type, {});
2137 new_feeder_id =
2138 const_mgr->GetDefiningInstruction(null_const, 0)->result_id();
2139 }
2140
2141 if (feeder_is_op0) {
2142 // If the size of the first vector operand changed then the indices
2143 // referring to the second operand need to be adjusted.
2144 Instruction* new_feeder_inst = def_use_mgr->GetDef(new_feeder_id);
2145 analysis::Type* new_feeder_type =
2146 type_mgr->GetType(new_feeder_inst->type_id());
2147 uint32_t new_op0_size = new_feeder_type->AsVector()->element_count();
2148 int32_t adjustment = op0_length - new_op0_size;
2149
2150 if (adjustment != 0) {
2151 for (uint32_t i = 2; i < new_operands.size(); i++) {
2152 if (inst->GetSingleWordInOperand(i) >= op0_length) {
2153 new_operands[i].words[0] -= adjustment;
2154 }
2155 }
2156 }
2157
2158 new_operands[0].words[0] = new_feeder_id;
2159 new_operands[1] = inst->GetInOperand(1);
2160 } else {
2161 new_operands[1].words[0] = new_feeder_id;
2162 new_operands[0] = inst->GetInOperand(0);
2163 }
2164
2165 inst->SetInOperands(std::move(new_operands));
2166 return true;
2167 };
2168 }
2169
2170 } // namespace
2171
// Builds the table mapping each opcode to the ordered list of folding rules
// that are attempted for instructions with that opcode.
FoldingRules::FoldingRules() {
  // Add all folding rules to the list for the opcodes to which they apply.
  // Note that the order in which rules are added to the list matters. If a rule
  // applies to the instruction, the rest of the rules will not be attempted.
  // Take that into consideration.
  rules_[SpvOpCompositeConstruct].push_back(CompositeExtractFeedingConstruct());

  rules_[SpvOpCompositeExtract].push_back(InsertFeedingExtract());
  rules_[SpvOpCompositeExtract].push_back(CompositeConstructFeedingExtract());
  rules_[SpvOpCompositeExtract].push_back(VectorShuffleFeedingExtract());
  rules_[SpvOpCompositeExtract].push_back(FMixFeedingExtract());

  rules_[SpvOpDot].push_back(DotProductDoingExtract());

  // Only GLSL.std.450 FMix is recognized by RedundantFMix; other extended
  // instructions fall through untouched.
  rules_[SpvOpExtInst].push_back(RedundantFMix());

  // For the arithmetic opcodes below, the Redundant* identity folds
  // (x + 0, x * 1, x / 1, ...) are registered before the Merge* rules so they
  // are tried first.
  rules_[SpvOpFAdd].push_back(RedundantFAdd());
  rules_[SpvOpFAdd].push_back(MergeAddNegateArithmetic());
  rules_[SpvOpFAdd].push_back(MergeAddAddArithmetic());
  rules_[SpvOpFAdd].push_back(MergeAddSubArithmetic());

  rules_[SpvOpFDiv].push_back(RedundantFDiv());
  rules_[SpvOpFDiv].push_back(ReciprocalFDiv());
  rules_[SpvOpFDiv].push_back(MergeDivDivArithmetic());
  rules_[SpvOpFDiv].push_back(MergeDivMulArithmetic());
  rules_[SpvOpFDiv].push_back(MergeDivNegateArithmetic());

  rules_[SpvOpFMul].push_back(RedundantFMul());
  rules_[SpvOpFMul].push_back(MergeMulMulArithmetic());
  rules_[SpvOpFMul].push_back(MergeMulDivArithmetic());
  rules_[SpvOpFMul].push_back(MergeMulNegateArithmetic());

  // NOTE(review): FNegate registers AddSub before MulDiv, while SNegate below
  // registers MulDiv before AddSub. Confirm whether the difference is
  // intentional, since rule order affects which fold wins.
  rules_[SpvOpFNegate].push_back(MergeNegateArithmetic());
  rules_[SpvOpFNegate].push_back(MergeNegateAddSubArithmetic());
  rules_[SpvOpFNegate].push_back(MergeNegateMulDivArithmetic());

  rules_[SpvOpFSub].push_back(RedundantFSub());
  rules_[SpvOpFSub].push_back(MergeSubNegateArithmetic());
  rules_[SpvOpFSub].push_back(MergeSubAddArithmetic());
  rules_[SpvOpFSub].push_back(MergeSubSubArithmetic());

  rules_[SpvOpIAdd].push_back(RedundantIAdd());
  rules_[SpvOpIAdd].push_back(MergeAddNegateArithmetic());
  rules_[SpvOpIAdd].push_back(MergeAddAddArithmetic());
  rules_[SpvOpIAdd].push_back(MergeAddSubArithmetic());

  rules_[SpvOpIMul].push_back(IntMultipleBy1());
  rules_[SpvOpIMul].push_back(MergeMulMulArithmetic());
  rules_[SpvOpIMul].push_back(MergeMulNegateArithmetic());

  rules_[SpvOpISub].push_back(MergeSubNegateArithmetic());
  rules_[SpvOpISub].push_back(MergeSubAddArithmetic());
  rules_[SpvOpISub].push_back(MergeSubSubArithmetic());

  rules_[SpvOpPhi].push_back(RedundantPhi());

  rules_[SpvOpSDiv].push_back(MergeDivNegateArithmetic());

  rules_[SpvOpSNegate].push_back(MergeNegateArithmetic());
  rules_[SpvOpSNegate].push_back(MergeNegateMulDivArithmetic());
  rules_[SpvOpSNegate].push_back(MergeNegateAddSubArithmetic());

  rules_[SpvOpSelect].push_back(RedundantSelect());

  // Removes stores of OpUndef values (non-volatile stores only).
  rules_[SpvOpStore].push_back(StoringUndef());

  rules_[SpvOpUDiv].push_back(MergeDivNegateArithmetic());

  rules_[SpvOpVectorShuffle].push_back(VectorShuffleFeedingShuffle());
}
2242 } // namespace opt
2243 } // namespace spvtools
2244