1 // Copyright (c) 2018 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/opt/folding_rules.h"
16
17 #include <limits>
18 #include <memory>
19 #include <utility>
20
21 #include "ir_builder.h"
22 #include "source/latest_version_glsl_std_450_header.h"
23 #include "source/opt/ir_context.h"
24
25 namespace spvtools {
26 namespace opt {
27 namespace {
28
// In-operand indices used by the folding rules. The groupings below follow
// the constant names; their uses are outside this part of the file, so treat
// the per-group descriptions as presumed meanings (TODO confirm against the
// users).

// Composite extract: the composite operand.
constexpr uint32_t kExtractCompositeIdInIdx = 0;

// Composite insert: the inserted object and the composite it goes into.
constexpr uint32_t kInsertObjectIdInIdx = 0;
constexpr uint32_t kInsertCompositeIdInIdx = 1;

// Extended instruction: the instruction-set id and the instruction number.
constexpr uint32_t kExtInstSetIdInIdx = 0;
constexpr uint32_t kExtInstInstructionInIdx = 1;

// FMix extended instruction: the x, y and a arguments.
constexpr uint32_t kFMixXIdInIdx = 2;
constexpr uint32_t kFMixYIdInIdx = 3;
constexpr uint32_t kFMixAIdInIdx = 4;

// Store: the object being stored.
constexpr uint32_t kStoreObjectInIdx = 1;
38
39 // Some image instructions may contain an "image operands" argument.
40 // Returns the operand index for the "image operands".
41 // Returns -1 if the instruction does not have image operands.
ImageOperandsMaskInOperandIndex(Instruction * inst)42 int32_t ImageOperandsMaskInOperandIndex(Instruction* inst) {
43 const auto opcode = inst->opcode();
44 switch (opcode) {
45 case SpvOpImageSampleImplicitLod:
46 case SpvOpImageSampleExplicitLod:
47 case SpvOpImageSampleProjImplicitLod:
48 case SpvOpImageSampleProjExplicitLod:
49 case SpvOpImageFetch:
50 case SpvOpImageRead:
51 case SpvOpImageSparseSampleImplicitLod:
52 case SpvOpImageSparseSampleExplicitLod:
53 case SpvOpImageSparseSampleProjImplicitLod:
54 case SpvOpImageSparseSampleProjExplicitLod:
55 case SpvOpImageSparseFetch:
56 case SpvOpImageSparseRead:
57 return inst->NumOperands() > 4 ? 2 : -1;
58 case SpvOpImageSampleDrefImplicitLod:
59 case SpvOpImageSampleDrefExplicitLod:
60 case SpvOpImageSampleProjDrefImplicitLod:
61 case SpvOpImageSampleProjDrefExplicitLod:
62 case SpvOpImageGather:
63 case SpvOpImageDrefGather:
64 case SpvOpImageSparseSampleDrefImplicitLod:
65 case SpvOpImageSparseSampleDrefExplicitLod:
66 case SpvOpImageSparseSampleProjDrefImplicitLod:
67 case SpvOpImageSparseSampleProjDrefExplicitLod:
68 case SpvOpImageSparseGather:
69 case SpvOpImageSparseDrefGather:
70 return inst->NumOperands() > 5 ? 3 : -1;
71 case SpvOpImageWrite:
72 return inst->NumOperands() > 3 ? 3 : -1;
73 default:
74 return -1;
75 }
76 }
77
78 // Returns the element width of |type|.
ElementWidth(const analysis::Type * type)79 uint32_t ElementWidth(const analysis::Type* type) {
80 if (const analysis::Vector* vec_type = type->AsVector()) {
81 return ElementWidth(vec_type->element_type());
82 } else if (const analysis::Float* float_type = type->AsFloat()) {
83 return float_type->width();
84 } else {
85 assert(type->AsInteger());
86 return type->AsInteger()->width();
87 }
88 }
89
90 // Returns true if |type| is Float or a vector of Float.
HasFloatingPoint(const analysis::Type * type)91 bool HasFloatingPoint(const analysis::Type* type) {
92 if (type->AsFloat()) {
93 return true;
94 } else if (const analysis::Vector* vec_type = type->AsVector()) {
95 return vec_type->element_type()->AsFloat() != nullptr;
96 }
97
98 return false;
99 }
100
// Returns true unless |val| classifies as NaN, infinite or subnormal.
template <typename T>
bool IsValidResult(T val) {
  const int category = std::fpclassify(val);
  if (category == FP_NAN) return false;
  if (category == FP_INFINITE) return false;
  if (category == FP_SUBNORMAL) return false;
  return true;
}
114
ConstInput(const std::vector<const analysis::Constant * > & constants)115 const analysis::Constant* ConstInput(
116 const std::vector<const analysis::Constant*>& constants) {
117 return constants[0] ? constants[0] : constants[1];
118 }
119
NonConstInput(IRContext * context,const analysis::Constant * c,Instruction * inst)120 Instruction* NonConstInput(IRContext* context, const analysis::Constant* c,
121 Instruction* inst) {
122 uint32_t in_op = c ? 1u : 0u;
123 return context->get_def_use_mgr()->GetDef(
124 inst->GetSingleWordInOperand(in_op));
125 }
126
// Splits |val| into two 32-bit words, low word first.
std::vector<uint32_t> ExtractInts(uint64_t val) {
  const uint32_t low_word = static_cast<uint32_t>(val);
  const uint32_t high_word = static_cast<uint32_t>(val >> 32);
  return {low_word, high_word};
}
133
GetWordsFromScalarIntConstant(const analysis::IntConstant * c)134 std::vector<uint32_t> GetWordsFromScalarIntConstant(
135 const analysis::IntConstant* c) {
136 assert(c != nullptr);
137 uint32_t width = c->type()->AsInteger()->width();
138 assert(width == 32 || width == 64);
139 if (width == 64) {
140 uint64_t uval = static_cast<uint64_t>(c->GetU64());
141 return ExtractInts(uval);
142 }
143 return {c->GetU32()};
144 }
145
GetWordsFromScalarFloatConstant(const analysis::FloatConstant * c)146 std::vector<uint32_t> GetWordsFromScalarFloatConstant(
147 const analysis::FloatConstant* c) {
148 assert(c != nullptr);
149 uint32_t width = c->type()->AsFloat()->width();
150 assert(width == 32 || width == 64);
151 if (width == 64) {
152 utils::FloatProxy<double> result(c->GetDouble());
153 return result.GetWords();
154 }
155 utils::FloatProxy<float> result(c->GetFloat());
156 return result.GetWords();
157 }
158
GetWordsFromNumericScalarOrVectorConstant(analysis::ConstantManager * const_mgr,const analysis::Constant * c)159 std::vector<uint32_t> GetWordsFromNumericScalarOrVectorConstant(
160 analysis::ConstantManager* const_mgr, const analysis::Constant* c) {
161 if (const auto* float_constant = c->AsFloatConstant()) {
162 return GetWordsFromScalarFloatConstant(float_constant);
163 } else if (const auto* int_constant = c->AsIntConstant()) {
164 return GetWordsFromScalarIntConstant(int_constant);
165 } else if (const auto* vec_constant = c->AsVectorConstant()) {
166 std::vector<uint32_t> words;
167 for (const auto* comp : vec_constant->GetComponents()) {
168 auto comp_in_words =
169 GetWordsFromNumericScalarOrVectorConstant(const_mgr, comp);
170 words.insert(words.end(), comp_in_words.begin(), comp_in_words.end());
171 }
172 return words;
173 }
174 return {};
175 }
176
ConvertWordsToNumericScalarOrVectorConstant(analysis::ConstantManager * const_mgr,const std::vector<uint32_t> & words,const analysis::Type * type)177 const analysis::Constant* ConvertWordsToNumericScalarOrVectorConstant(
178 analysis::ConstantManager* const_mgr, const std::vector<uint32_t>& words,
179 const analysis::Type* type) {
180 if (type->AsInteger() || type->AsFloat())
181 return const_mgr->GetConstant(type, words);
182 if (const auto* vec_type = type->AsVector())
183 return const_mgr->GetNumericVectorConstantWithWords(vec_type, words);
184 return nullptr;
185 }
186
187 // Returns the negation of |c|. |c| must be a 32 or 64 bit floating point
188 // constant.
NegateFloatingPointConstant(analysis::ConstantManager * const_mgr,const analysis::Constant * c)189 uint32_t NegateFloatingPointConstant(analysis::ConstantManager* const_mgr,
190 const analysis::Constant* c) {
191 assert(c);
192 assert(c->type()->AsFloat());
193 uint32_t width = c->type()->AsFloat()->width();
194 assert(width == 32 || width == 64);
195 std::vector<uint32_t> words;
196 if (width == 64) {
197 utils::FloatProxy<double> result(c->GetDouble() * -1.0);
198 words = result.GetWords();
199 } else {
200 utils::FloatProxy<float> result(c->GetFloat() * -1.0f);
201 words = result.GetWords();
202 }
203
204 const analysis::Constant* negated_const =
205 const_mgr->GetConstant(c->type(), std::move(words));
206 return const_mgr->GetDefiningInstruction(negated_const)->result_id();
207 }
208
209 // Negates the integer constant |c|. Returns the id of the defining instruction.
NegateIntegerConstant(analysis::ConstantManager * const_mgr,const analysis::Constant * c)210 uint32_t NegateIntegerConstant(analysis::ConstantManager* const_mgr,
211 const analysis::Constant* c) {
212 assert(c);
213 assert(c->type()->AsInteger());
214 uint32_t width = c->type()->AsInteger()->width();
215 assert(width == 32 || width == 64);
216 std::vector<uint32_t> words;
217 if (width == 64) {
218 uint64_t uval = static_cast<uint64_t>(0 - c->GetU64());
219 words = ExtractInts(uval);
220 } else {
221 words.push_back(static_cast<uint32_t>(0 - c->GetU32()));
222 }
223
224 const analysis::Constant* negated_const =
225 const_mgr->GetConstant(c->type(), std::move(words));
226 return const_mgr->GetDefiningInstruction(negated_const)->result_id();
227 }
228
229 // Negates the vector constant |c|. Returns the id of the defining instruction.
NegateVectorConstant(analysis::ConstantManager * const_mgr,const analysis::Constant * c)230 uint32_t NegateVectorConstant(analysis::ConstantManager* const_mgr,
231 const analysis::Constant* c) {
232 assert(const_mgr && c);
233 assert(c->type()->AsVector());
234 if (c->AsNullConstant()) {
235 // 0.0 vs -0.0 shouldn't matter.
236 return const_mgr->GetDefiningInstruction(c)->result_id();
237 } else {
238 const analysis::Type* component_type =
239 c->AsVectorConstant()->component_type();
240 std::vector<uint32_t> words;
241 for (auto& comp : c->AsVectorConstant()->GetComponents()) {
242 if (component_type->AsFloat()) {
243 words.push_back(NegateFloatingPointConstant(const_mgr, comp));
244 } else {
245 assert(component_type->AsInteger());
246 words.push_back(NegateIntegerConstant(const_mgr, comp));
247 }
248 }
249
250 const analysis::Constant* negated_const =
251 const_mgr->GetConstant(c->type(), std::move(words));
252 return const_mgr->GetDefiningInstruction(negated_const)->result_id();
253 }
254 }
255
256 // Negates |c|. Returns the id of the defining instruction.
NegateConstant(analysis::ConstantManager * const_mgr,const analysis::Constant * c)257 uint32_t NegateConstant(analysis::ConstantManager* const_mgr,
258 const analysis::Constant* c) {
259 if (c->type()->AsVector()) {
260 return NegateVectorConstant(const_mgr, c);
261 } else if (c->type()->AsFloat()) {
262 return NegateFloatingPointConstant(const_mgr, c);
263 } else {
264 assert(c->type()->AsInteger());
265 return NegateIntegerConstant(const_mgr, c);
266 }
267 }
268
269 // Takes the reciprocal of |c|. |c|'s type must be Float or a vector of Float.
270 // Returns 0 if the reciprocal is NaN, infinite or subnormal.
Reciprocal(analysis::ConstantManager * const_mgr,const analysis::Constant * c)271 uint32_t Reciprocal(analysis::ConstantManager* const_mgr,
272 const analysis::Constant* c) {
273 assert(const_mgr && c);
274 assert(c->type()->AsFloat());
275
276 uint32_t width = c->type()->AsFloat()->width();
277 assert(width == 32 || width == 64);
278 std::vector<uint32_t> words;
279 if (width == 64) {
280 spvtools::utils::FloatProxy<double> result(1.0 / c->GetDouble());
281 if (!IsValidResult(result.getAsFloat())) return 0;
282 words = result.GetWords();
283 } else {
284 spvtools::utils::FloatProxy<float> result(1.0f / c->GetFloat());
285 if (!IsValidResult(result.getAsFloat())) return 0;
286 words = result.GetWords();
287 }
288
289 const analysis::Constant* negated_const =
290 const_mgr->GetConstant(c->type(), std::move(words));
291 return const_mgr->GetDefiningInstruction(negated_const)->result_id();
292 }
293
294 // Replaces fdiv where second operand is constant with fmul.
ReciprocalFDiv()295 FoldingRule ReciprocalFDiv() {
296 return [](IRContext* context, Instruction* inst,
297 const std::vector<const analysis::Constant*>& constants) {
298 assert(inst->opcode() == SpvOpFDiv);
299 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
300 const analysis::Type* type =
301 context->get_type_mgr()->GetType(inst->type_id());
302 if (!inst->IsFloatingPointFoldingAllowed()) return false;
303
304 uint32_t width = ElementWidth(type);
305 if (width != 32 && width != 64) return false;
306
307 if (constants[1] != nullptr) {
308 uint32_t id = 0;
309 if (const analysis::VectorConstant* vector_const =
310 constants[1]->AsVectorConstant()) {
311 std::vector<uint32_t> neg_ids;
312 for (auto& comp : vector_const->GetComponents()) {
313 id = Reciprocal(const_mgr, comp);
314 if (id == 0) return false;
315 neg_ids.push_back(id);
316 }
317 const analysis::Constant* negated_const =
318 const_mgr->GetConstant(constants[1]->type(), std::move(neg_ids));
319 id = const_mgr->GetDefiningInstruction(negated_const)->result_id();
320 } else if (constants[1]->AsFloatConstant()) {
321 id = Reciprocal(const_mgr, constants[1]);
322 if (id == 0) return false;
323 } else {
324 // Don't fold a null constant.
325 return false;
326 }
327 inst->SetOpcode(SpvOpFMul);
328 inst->SetInOperands(
329 {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0u)}},
330 {SPV_OPERAND_TYPE_ID, {id}}});
331 return true;
332 }
333
334 return false;
335 };
336 }
337
338 // Elides consecutive negate instructions.
MergeNegateArithmetic()339 FoldingRule MergeNegateArithmetic() {
340 return [](IRContext* context, Instruction* inst,
341 const std::vector<const analysis::Constant*>& constants) {
342 assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate);
343 (void)constants;
344 const analysis::Type* type =
345 context->get_type_mgr()->GetType(inst->type_id());
346 if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
347 return false;
348
349 Instruction* op_inst =
350 context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
351 if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
352 return false;
353
354 if (op_inst->opcode() == inst->opcode()) {
355 // Elide negates.
356 inst->SetOpcode(SpvOpCopyObject);
357 inst->SetInOperands(
358 {{SPV_OPERAND_TYPE_ID, {op_inst->GetSingleWordInOperand(0u)}}});
359 return true;
360 }
361
362 return false;
363 };
364 }
365
366 // Merges negate into a mul or div operation if that operation contains a
367 // constant operand.
368 // Cases:
369 // -(x * 2) = x * -2
370 // -(2 * x) = x * -2
371 // -(x / 2) = x / -2
372 // -(2 / x) = -2 / x
MergeNegateMulDivArithmetic()373 FoldingRule MergeNegateMulDivArithmetic() {
374 return [](IRContext* context, Instruction* inst,
375 const std::vector<const analysis::Constant*>& constants) {
376 assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate);
377 (void)constants;
378 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
379 const analysis::Type* type =
380 context->get_type_mgr()->GetType(inst->type_id());
381 if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
382 return false;
383
384 Instruction* op_inst =
385 context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
386 if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
387 return false;
388
389 uint32_t width = ElementWidth(type);
390 if (width != 32 && width != 64) return false;
391
392 SpvOp opcode = op_inst->opcode();
393 if (opcode == SpvOpFMul || opcode == SpvOpFDiv || opcode == SpvOpIMul ||
394 opcode == SpvOpSDiv || opcode == SpvOpUDiv) {
395 std::vector<const analysis::Constant*> op_constants =
396 const_mgr->GetOperandConstants(op_inst);
397 // Merge negate into mul or div if one operand is constant.
398 if (op_constants[0] || op_constants[1]) {
399 bool zero_is_variable = op_constants[0] == nullptr;
400 const analysis::Constant* c = ConstInput(op_constants);
401 uint32_t neg_id = NegateConstant(const_mgr, c);
402 uint32_t non_const_id = zero_is_variable
403 ? op_inst->GetSingleWordInOperand(0u)
404 : op_inst->GetSingleWordInOperand(1u);
405 // Change this instruction to a mul/div.
406 inst->SetOpcode(op_inst->opcode());
407 if (opcode == SpvOpFDiv || opcode == SpvOpUDiv || opcode == SpvOpSDiv) {
408 uint32_t op0 = zero_is_variable ? non_const_id : neg_id;
409 uint32_t op1 = zero_is_variable ? neg_id : non_const_id;
410 inst->SetInOperands(
411 {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}});
412 } else {
413 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
414 {SPV_OPERAND_TYPE_ID, {neg_id}}});
415 }
416 return true;
417 }
418 }
419
420 return false;
421 };
422 }
423
424 // Merges negate into a add or sub operation if that operation contains a
425 // constant operand.
426 // Cases:
427 // -(x + 2) = -2 - x
428 // -(2 + x) = -2 - x
429 // -(x - 2) = 2 - x
430 // -(2 - x) = x - 2
MergeNegateAddSubArithmetic()431 FoldingRule MergeNegateAddSubArithmetic() {
432 return [](IRContext* context, Instruction* inst,
433 const std::vector<const analysis::Constant*>& constants) {
434 assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate);
435 (void)constants;
436 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
437 const analysis::Type* type =
438 context->get_type_mgr()->GetType(inst->type_id());
439 if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
440 return false;
441
442 Instruction* op_inst =
443 context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
444 if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
445 return false;
446
447 uint32_t width = ElementWidth(type);
448 if (width != 32 && width != 64) return false;
449
450 if (op_inst->opcode() == SpvOpFAdd || op_inst->opcode() == SpvOpFSub ||
451 op_inst->opcode() == SpvOpIAdd || op_inst->opcode() == SpvOpISub) {
452 std::vector<const analysis::Constant*> op_constants =
453 const_mgr->GetOperandConstants(op_inst);
454 if (op_constants[0] || op_constants[1]) {
455 bool zero_is_variable = op_constants[0] == nullptr;
456 bool is_add = (op_inst->opcode() == SpvOpFAdd) ||
457 (op_inst->opcode() == SpvOpIAdd);
458 bool swap_operands = !is_add || zero_is_variable;
459 bool negate_const = is_add;
460 const analysis::Constant* c = ConstInput(op_constants);
461 uint32_t const_id = 0;
462 if (negate_const) {
463 const_id = NegateConstant(const_mgr, c);
464 } else {
465 const_id = zero_is_variable ? op_inst->GetSingleWordInOperand(1u)
466 : op_inst->GetSingleWordInOperand(0u);
467 }
468
469 // Swap operands if necessary and make the instruction a subtraction.
470 uint32_t op0 =
471 zero_is_variable ? op_inst->GetSingleWordInOperand(0u) : const_id;
472 uint32_t op1 =
473 zero_is_variable ? const_id : op_inst->GetSingleWordInOperand(1u);
474 if (swap_operands) std::swap(op0, op1);
475 inst->SetOpcode(HasFloatingPoint(type) ? SpvOpFSub : SpvOpISub);
476 inst->SetInOperands(
477 {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}});
478 return true;
479 }
480 }
481
482 return false;
483 };
484 }
485
486 // Returns true if |c| has a zero element.
HasZero(const analysis::Constant * c)487 bool HasZero(const analysis::Constant* c) {
488 if (c->AsNullConstant()) {
489 return true;
490 }
491 if (const analysis::VectorConstant* vec_const = c->AsVectorConstant()) {
492 for (auto& comp : vec_const->GetComponents())
493 if (HasZero(comp)) return true;
494 } else {
495 assert(c->AsScalarConstant());
496 return c->AsScalarConstant()->IsZero();
497 }
498
499 return false;
500 }
501
502 // Performs |input1| |opcode| |input2| and returns the merged constant result
503 // id. Returns 0 if the result is not a valid value. The input types must be
504 // Float.
PerformFloatingPointOperation(analysis::ConstantManager * const_mgr,SpvOp opcode,const analysis::Constant * input1,const analysis::Constant * input2)505 uint32_t PerformFloatingPointOperation(analysis::ConstantManager* const_mgr,
506 SpvOp opcode,
507 const analysis::Constant* input1,
508 const analysis::Constant* input2) {
509 const analysis::Type* type = input1->type();
510 assert(type->AsFloat());
511 uint32_t width = type->AsFloat()->width();
512 assert(width == 32 || width == 64);
513 std::vector<uint32_t> words;
514 #define FOLD_OP(op) \
515 if (width == 64) { \
516 utils::FloatProxy<double> val = \
517 input1->GetDouble() op input2->GetDouble(); \
518 double dval = val.getAsFloat(); \
519 if (!IsValidResult(dval)) return 0; \
520 words = val.GetWords(); \
521 } else { \
522 utils::FloatProxy<float> val = input1->GetFloat() op input2->GetFloat(); \
523 float fval = val.getAsFloat(); \
524 if (!IsValidResult(fval)) return 0; \
525 words = val.GetWords(); \
526 } static_assert(true, "require extra semicolon")
527 switch (opcode) {
528 case SpvOpFMul:
529 FOLD_OP(*);
530 break;
531 case SpvOpFDiv:
532 if (HasZero(input2)) return 0;
533 FOLD_OP(/);
534 break;
535 case SpvOpFAdd:
536 FOLD_OP(+);
537 break;
538 case SpvOpFSub:
539 FOLD_OP(-);
540 break;
541 default:
542 assert(false && "Unexpected operation");
543 break;
544 }
545 #undef FOLD_OP
546 const analysis::Constant* merged_const = const_mgr->GetConstant(type, words);
547 return const_mgr->GetDefiningInstruction(merged_const)->result_id();
548 }
549
550 // Performs |input1| |opcode| |input2| and returns the merged constant result
551 // id. Returns 0 if the result is not a valid value. The input types must be
552 // Integers.
PerformIntegerOperation(analysis::ConstantManager * const_mgr,SpvOp opcode,const analysis::Constant * input1,const analysis::Constant * input2)553 uint32_t PerformIntegerOperation(analysis::ConstantManager* const_mgr,
554 SpvOp opcode, const analysis::Constant* input1,
555 const analysis::Constant* input2) {
556 assert(input1->type()->AsInteger());
557 const analysis::Integer* type = input1->type()->AsInteger();
558 uint32_t width = type->AsInteger()->width();
559 assert(width == 32 || width == 64);
560 std::vector<uint32_t> words;
561 #define FOLD_OP(op) \
562 if (width == 64) { \
563 if (type->IsSigned()) { \
564 int64_t val = input1->GetS64() op input2->GetS64(); \
565 words = ExtractInts(static_cast<uint64_t>(val)); \
566 } else { \
567 uint64_t val = input1->GetU64() op input2->GetU64(); \
568 words = ExtractInts(val); \
569 } \
570 } else { \
571 if (type->IsSigned()) { \
572 int32_t val = input1->GetS32() op input2->GetS32(); \
573 words.push_back(static_cast<uint32_t>(val)); \
574 } else { \
575 uint32_t val = input1->GetU32() op input2->GetU32(); \
576 words.push_back(val); \
577 } \
578 } static_assert(true, "require extra semicalon")
579 switch (opcode) {
580 case SpvOpIMul:
581 FOLD_OP(*);
582 break;
583 case SpvOpSDiv:
584 case SpvOpUDiv:
585 assert(false && "Should not merge integer division");
586 break;
587 case SpvOpIAdd:
588 FOLD_OP(+);
589 break;
590 case SpvOpISub:
591 FOLD_OP(-);
592 break;
593 default:
594 assert(false && "Unexpected operation");
595 break;
596 }
597 #undef FOLD_OP
598 const analysis::Constant* merged_const = const_mgr->GetConstant(type, words);
599 return const_mgr->GetDefiningInstruction(merged_const)->result_id();
600 }
601
602 // Performs |input1| |opcode| |input2| and returns the merged constant result
603 // id. Returns 0 if the result is not a valid value. The input types must be
604 // Integers, Floats or Vectors of such.
PerformOperation(analysis::ConstantManager * const_mgr,SpvOp opcode,const analysis::Constant * input1,const analysis::Constant * input2)605 uint32_t PerformOperation(analysis::ConstantManager* const_mgr, SpvOp opcode,
606 const analysis::Constant* input1,
607 const analysis::Constant* input2) {
608 assert(input1 && input2);
609 const analysis::Type* type = input1->type();
610 std::vector<uint32_t> words;
611 if (const analysis::Vector* vector_type = type->AsVector()) {
612 const analysis::Type* ele_type = vector_type->element_type();
613 for (uint32_t i = 0; i != vector_type->element_count(); ++i) {
614 uint32_t id = 0;
615
616 const analysis::Constant* input1_comp = nullptr;
617 if (const analysis::VectorConstant* input1_vector =
618 input1->AsVectorConstant()) {
619 input1_comp = input1_vector->GetComponents()[i];
620 } else {
621 assert(input1->AsNullConstant());
622 input1_comp = const_mgr->GetConstant(ele_type, {});
623 }
624
625 const analysis::Constant* input2_comp = nullptr;
626 if (const analysis::VectorConstant* input2_vector =
627 input2->AsVectorConstant()) {
628 input2_comp = input2_vector->GetComponents()[i];
629 } else {
630 assert(input2->AsNullConstant());
631 input2_comp = const_mgr->GetConstant(ele_type, {});
632 }
633
634 if (ele_type->AsFloat()) {
635 id = PerformFloatingPointOperation(const_mgr, opcode, input1_comp,
636 input2_comp);
637 } else {
638 assert(ele_type->AsInteger());
639 id = PerformIntegerOperation(const_mgr, opcode, input1_comp,
640 input2_comp);
641 }
642 if (id == 0) return 0;
643 words.push_back(id);
644 }
645 const analysis::Constant* merged_const =
646 const_mgr->GetConstant(type, words);
647 return const_mgr->GetDefiningInstruction(merged_const)->result_id();
648 } else if (type->AsFloat()) {
649 return PerformFloatingPointOperation(const_mgr, opcode, input1, input2);
650 } else {
651 assert(type->AsInteger());
652 return PerformIntegerOperation(const_mgr, opcode, input1, input2);
653 }
654 }
655
656 // Merges consecutive multiplies where each contains one constant operand.
657 // Cases:
658 // 2 * (x * 2) = x * 4
659 // 2 * (2 * x) = x * 4
660 // (x * 2) * 2 = x * 4
661 // (2 * x) * 2 = x * 4
MergeMulMulArithmetic()662 FoldingRule MergeMulMulArithmetic() {
663 return [](IRContext* context, Instruction* inst,
664 const std::vector<const analysis::Constant*>& constants) {
665 assert(inst->opcode() == SpvOpFMul || inst->opcode() == SpvOpIMul);
666 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
667 const analysis::Type* type =
668 context->get_type_mgr()->GetType(inst->type_id());
669 if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
670 return false;
671
672 uint32_t width = ElementWidth(type);
673 if (width != 32 && width != 64) return false;
674
675 // Determine the constant input and the variable input in |inst|.
676 const analysis::Constant* const_input1 = ConstInput(constants);
677 if (!const_input1) return false;
678 Instruction* other_inst = NonConstInput(context, constants[0], inst);
679 if (HasFloatingPoint(type) && !other_inst->IsFloatingPointFoldingAllowed())
680 return false;
681
682 if (other_inst->opcode() == inst->opcode()) {
683 std::vector<const analysis::Constant*> other_constants =
684 const_mgr->GetOperandConstants(other_inst);
685 const analysis::Constant* const_input2 = ConstInput(other_constants);
686 if (!const_input2) return false;
687
688 bool other_first_is_variable = other_constants[0] == nullptr;
689 uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
690 const_input1, const_input2);
691 if (merged_id == 0) return false;
692
693 uint32_t non_const_id = other_first_is_variable
694 ? other_inst->GetSingleWordInOperand(0u)
695 : other_inst->GetSingleWordInOperand(1u);
696 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
697 {SPV_OPERAND_TYPE_ID, {merged_id}}});
698 return true;
699 }
700
701 return false;
702 };
703 }
704
705 // Merges divides into subsequent multiplies if each instruction contains one
706 // constant operand. Does not support integer operations.
707 // Cases:
708 // 2 * (x / 2) = x * 1
709 // 2 * (2 / x) = 4 / x
710 // (x / 2) * 2 = x * 1
711 // (2 / x) * 2 = 4 / x
712 // (y / x) * x = y
713 // x * (y / x) = y
MergeMulDivArithmetic()714 FoldingRule MergeMulDivArithmetic() {
715 return [](IRContext* context, Instruction* inst,
716 const std::vector<const analysis::Constant*>& constants) {
717 assert(inst->opcode() == SpvOpFMul);
718 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
719 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
720
721 const analysis::Type* type =
722 context->get_type_mgr()->GetType(inst->type_id());
723 if (!inst->IsFloatingPointFoldingAllowed()) return false;
724
725 uint32_t width = ElementWidth(type);
726 if (width != 32 && width != 64) return false;
727
728 for (uint32_t i = 0; i < 2; i++) {
729 uint32_t op_id = inst->GetSingleWordInOperand(i);
730 Instruction* op_inst = def_use_mgr->GetDef(op_id);
731 if (op_inst->opcode() == SpvOpFDiv) {
732 if (op_inst->GetSingleWordInOperand(1) ==
733 inst->GetSingleWordInOperand(1 - i)) {
734 inst->SetOpcode(SpvOpCopyObject);
735 inst->SetInOperands(
736 {{SPV_OPERAND_TYPE_ID, {op_inst->GetSingleWordInOperand(0)}}});
737 return true;
738 }
739 }
740 }
741
742 const analysis::Constant* const_input1 = ConstInput(constants);
743 if (!const_input1) return false;
744 Instruction* other_inst = NonConstInput(context, constants[0], inst);
745 if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
746
747 if (other_inst->opcode() == SpvOpFDiv) {
748 std::vector<const analysis::Constant*> other_constants =
749 const_mgr->GetOperandConstants(other_inst);
750 const analysis::Constant* const_input2 = ConstInput(other_constants);
751 if (!const_input2 || HasZero(const_input2)) return false;
752
753 bool other_first_is_variable = other_constants[0] == nullptr;
754 // If the variable value is the second operand of the divide, multiply
755 // the constants together. Otherwise divide the constants.
756 uint32_t merged_id = PerformOperation(
757 const_mgr,
758 other_first_is_variable ? other_inst->opcode() : inst->opcode(),
759 const_input1, const_input2);
760 if (merged_id == 0) return false;
761
762 uint32_t non_const_id = other_first_is_variable
763 ? other_inst->GetSingleWordInOperand(0u)
764 : other_inst->GetSingleWordInOperand(1u);
765
766 // If the variable value is on the second operand of the div, then this
767 // operation is a div. Otherwise it should be a multiply.
768 inst->SetOpcode(other_first_is_variable ? inst->opcode()
769 : other_inst->opcode());
770 if (other_first_is_variable) {
771 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
772 {SPV_OPERAND_TYPE_ID, {merged_id}}});
773 } else {
774 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {merged_id}},
775 {SPV_OPERAND_TYPE_ID, {non_const_id}}});
776 }
777 return true;
778 }
779
780 return false;
781 };
782 }
783
784 // Merges multiply of constant and negation.
785 // Cases:
786 // (-x) * 2 = x * -2
787 // 2 * (-x) = x * -2
MergeMulNegateArithmetic()788 FoldingRule MergeMulNegateArithmetic() {
789 return [](IRContext* context, Instruction* inst,
790 const std::vector<const analysis::Constant*>& constants) {
791 assert(inst->opcode() == SpvOpFMul || inst->opcode() == SpvOpIMul);
792 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
793 const analysis::Type* type =
794 context->get_type_mgr()->GetType(inst->type_id());
795 bool uses_float = HasFloatingPoint(type);
796 if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
797
798 uint32_t width = ElementWidth(type);
799 if (width != 32 && width != 64) return false;
800
801 const analysis::Constant* const_input1 = ConstInput(constants);
802 if (!const_input1) return false;
803 Instruction* other_inst = NonConstInput(context, constants[0], inst);
804 if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
805 return false;
806
807 if (other_inst->opcode() == SpvOpFNegate ||
808 other_inst->opcode() == SpvOpSNegate) {
809 uint32_t neg_id = NegateConstant(const_mgr, const_input1);
810
811 inst->SetInOperands(
812 {{SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}},
813 {SPV_OPERAND_TYPE_ID, {neg_id}}});
814 return true;
815 }
816
817 return false;
818 };
819 }
820
821 // Merges consecutive divides if each instruction contains one constant operand.
822 // Does not support integer division.
823 // Cases:
824 // 2 / (x / 2) = 4 / x
825 // 4 / (2 / x) = 2 * x
826 // (4 / x) / 2 = 2 / x
827 // (x / 2) / 2 = x / 4
MergeDivDivArithmetic()828 FoldingRule MergeDivDivArithmetic() {
829 return [](IRContext* context, Instruction* inst,
830 const std::vector<const analysis::Constant*>& constants) {
831 assert(inst->opcode() == SpvOpFDiv);
832 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
833 const analysis::Type* type =
834 context->get_type_mgr()->GetType(inst->type_id());
835 if (!inst->IsFloatingPointFoldingAllowed()) return false;
836
837 uint32_t width = ElementWidth(type);
838 if (width != 32 && width != 64) return false;
839
840 const analysis::Constant* const_input1 = ConstInput(constants);
841 if (!const_input1 || HasZero(const_input1)) return false;
842 Instruction* other_inst = NonConstInput(context, constants[0], inst);
843 if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
844
845 bool first_is_variable = constants[0] == nullptr;
846 if (other_inst->opcode() == inst->opcode()) {
847 std::vector<const analysis::Constant*> other_constants =
848 const_mgr->GetOperandConstants(other_inst);
849 const analysis::Constant* const_input2 = ConstInput(other_constants);
850 if (!const_input2 || HasZero(const_input2)) return false;
851
852 bool other_first_is_variable = other_constants[0] == nullptr;
853
854 SpvOp merge_op = inst->opcode();
855 if (other_first_is_variable) {
856 // Constants magnify.
857 merge_op = SpvOpFMul;
858 }
859
860 // This is an x / (*) case. Swap the inputs. Doesn't harm multiply
861 // because it is commutative.
862 if (first_is_variable) std::swap(const_input1, const_input2);
863 uint32_t merged_id =
864 PerformOperation(const_mgr, merge_op, const_input1, const_input2);
865 if (merged_id == 0) return false;
866
867 uint32_t non_const_id = other_first_is_variable
868 ? other_inst->GetSingleWordInOperand(0u)
869 : other_inst->GetSingleWordInOperand(1u);
870
871 SpvOp op = inst->opcode();
872 if (!first_is_variable && !other_first_is_variable) {
873 // Effectively div of 1/x, so change to multiply.
874 op = SpvOpFMul;
875 }
876
877 uint32_t op1 = merged_id;
878 uint32_t op2 = non_const_id;
879 if (first_is_variable && other_first_is_variable) std::swap(op1, op2);
880 inst->SetOpcode(op);
881 inst->SetInOperands(
882 {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
883 return true;
884 }
885
886 return false;
887 };
888 }
889
890 // Fold multiplies succeeded by divides where each instruction contains a
891 // constant operand. Does not support integer divide.
892 // Cases:
893 // 4 / (x * 2) = 2 / x
894 // 4 / (2 * x) = 2 / x
895 // (x * 4) / 2 = x * 2
896 // (4 * x) / 2 = x * 2
897 // (x * y) / x = y
898 // (y * x) / x = y
MergeDivMulArithmetic()899 FoldingRule MergeDivMulArithmetic() {
900 return [](IRContext* context, Instruction* inst,
901 const std::vector<const analysis::Constant*>& constants) {
902 assert(inst->opcode() == SpvOpFDiv);
903 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
904 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
905
906 const analysis::Type* type =
907 context->get_type_mgr()->GetType(inst->type_id());
908 if (!inst->IsFloatingPointFoldingAllowed()) return false;
909
910 uint32_t width = ElementWidth(type);
911 if (width != 32 && width != 64) return false;
912
913 uint32_t op_id = inst->GetSingleWordInOperand(0);
914 Instruction* op_inst = def_use_mgr->GetDef(op_id);
915
916 if (op_inst->opcode() == SpvOpFMul) {
917 for (uint32_t i = 0; i < 2; i++) {
918 if (op_inst->GetSingleWordInOperand(i) ==
919 inst->GetSingleWordInOperand(1)) {
920 inst->SetOpcode(SpvOpCopyObject);
921 inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
922 {op_inst->GetSingleWordInOperand(1 - i)}}});
923 return true;
924 }
925 }
926 }
927
928 const analysis::Constant* const_input1 = ConstInput(constants);
929 if (!const_input1 || HasZero(const_input1)) return false;
930 Instruction* other_inst = NonConstInput(context, constants[0], inst);
931 if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
932
933 bool first_is_variable = constants[0] == nullptr;
934 if (other_inst->opcode() == SpvOpFMul) {
935 std::vector<const analysis::Constant*> other_constants =
936 const_mgr->GetOperandConstants(other_inst);
937 const analysis::Constant* const_input2 = ConstInput(other_constants);
938 if (!const_input2) return false;
939
940 bool other_first_is_variable = other_constants[0] == nullptr;
941
942 // This is an x / (*) case. Swap the inputs.
943 if (first_is_variable) std::swap(const_input1, const_input2);
944 uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
945 const_input1, const_input2);
946 if (merged_id == 0) return false;
947
948 uint32_t non_const_id = other_first_is_variable
949 ? other_inst->GetSingleWordInOperand(0u)
950 : other_inst->GetSingleWordInOperand(1u);
951
952 uint32_t op1 = merged_id;
953 uint32_t op2 = non_const_id;
954 if (first_is_variable) std::swap(op1, op2);
955
956 // Convert to multiply
957 if (first_is_variable) inst->SetOpcode(other_inst->opcode());
958 inst->SetInOperands(
959 {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
960 return true;
961 }
962
963 return false;
964 };
965 }
966
967 // Fold divides of a constant and a negation.
968 // Cases:
969 // (-x) / 2 = x / -2
970 // 2 / (-x) = 2 / -x
MergeDivNegateArithmetic()971 FoldingRule MergeDivNegateArithmetic() {
972 return [](IRContext* context, Instruction* inst,
973 const std::vector<const analysis::Constant*>& constants) {
974 assert(inst->opcode() == SpvOpFDiv || inst->opcode() == SpvOpSDiv ||
975 inst->opcode() == SpvOpUDiv);
976 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
977 const analysis::Type* type =
978 context->get_type_mgr()->GetType(inst->type_id());
979 bool uses_float = HasFloatingPoint(type);
980 if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
981
982 uint32_t width = ElementWidth(type);
983 if (width != 32 && width != 64) return false;
984
985 const analysis::Constant* const_input1 = ConstInput(constants);
986 if (!const_input1) return false;
987 Instruction* other_inst = NonConstInput(context, constants[0], inst);
988 if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
989 return false;
990
991 bool first_is_variable = constants[0] == nullptr;
992 if (other_inst->opcode() == SpvOpFNegate ||
993 other_inst->opcode() == SpvOpSNegate) {
994 uint32_t neg_id = NegateConstant(const_mgr, const_input1);
995
996 if (first_is_variable) {
997 inst->SetInOperands(
998 {{SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}},
999 {SPV_OPERAND_TYPE_ID, {neg_id}}});
1000 } else {
1001 inst->SetInOperands(
1002 {{SPV_OPERAND_TYPE_ID, {neg_id}},
1003 {SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}});
1004 }
1005 return true;
1006 }
1007
1008 return false;
1009 };
1010 }
1011
1012 // Folds addition of a constant and a negation.
1013 // Cases:
1014 // (-x) + 2 = 2 - x
1015 // 2 + (-x) = 2 - x
MergeAddNegateArithmetic()1016 FoldingRule MergeAddNegateArithmetic() {
1017 return [](IRContext* context, Instruction* inst,
1018 const std::vector<const analysis::Constant*>& constants) {
1019 assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
1020 const analysis::Type* type =
1021 context->get_type_mgr()->GetType(inst->type_id());
1022 bool uses_float = HasFloatingPoint(type);
1023 if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1024
1025 const analysis::Constant* const_input1 = ConstInput(constants);
1026 if (!const_input1) return false;
1027 Instruction* other_inst = NonConstInput(context, constants[0], inst);
1028 if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1029 return false;
1030
1031 if (other_inst->opcode() == SpvOpSNegate ||
1032 other_inst->opcode() == SpvOpFNegate) {
1033 inst->SetOpcode(HasFloatingPoint(type) ? SpvOpFSub : SpvOpISub);
1034 uint32_t const_id = constants[0] ? inst->GetSingleWordInOperand(0u)
1035 : inst->GetSingleWordInOperand(1u);
1036 inst->SetInOperands(
1037 {{SPV_OPERAND_TYPE_ID, {const_id}},
1038 {SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}});
1039 return true;
1040 }
1041 return false;
1042 };
1043 }
1044
1045 // Folds subtraction of a constant and a negation.
1046 // Cases:
1047 // (-x) - 2 = -2 - x
1048 // 2 - (-x) = x + 2
MergeSubNegateArithmetic()1049 FoldingRule MergeSubNegateArithmetic() {
1050 return [](IRContext* context, Instruction* inst,
1051 const std::vector<const analysis::Constant*>& constants) {
1052 assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
1053 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1054 const analysis::Type* type =
1055 context->get_type_mgr()->GetType(inst->type_id());
1056 bool uses_float = HasFloatingPoint(type);
1057 if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1058
1059 uint32_t width = ElementWidth(type);
1060 if (width != 32 && width != 64) return false;
1061
1062 const analysis::Constant* const_input1 = ConstInput(constants);
1063 if (!const_input1) return false;
1064 Instruction* other_inst = NonConstInput(context, constants[0], inst);
1065 if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1066 return false;
1067
1068 if (other_inst->opcode() == SpvOpSNegate ||
1069 other_inst->opcode() == SpvOpFNegate) {
1070 uint32_t op1 = 0;
1071 uint32_t op2 = 0;
1072 SpvOp opcode = inst->opcode();
1073 if (constants[0] != nullptr) {
1074 op1 = other_inst->GetSingleWordInOperand(0u);
1075 op2 = inst->GetSingleWordInOperand(0u);
1076 opcode = HasFloatingPoint(type) ? SpvOpFAdd : SpvOpIAdd;
1077 } else {
1078 op1 = NegateConstant(const_mgr, const_input1);
1079 op2 = other_inst->GetSingleWordInOperand(0u);
1080 }
1081
1082 inst->SetOpcode(opcode);
1083 inst->SetInOperands(
1084 {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
1085 return true;
1086 }
1087 return false;
1088 };
1089 }
1090
1091 // Folds addition of an addition where each operation has a constant operand.
1092 // Cases:
1093 // (x + 2) + 2 = x + 4
1094 // (2 + x) + 2 = x + 4
1095 // 2 + (x + 2) = x + 4
1096 // 2 + (2 + x) = x + 4
MergeAddAddArithmetic()1097 FoldingRule MergeAddAddArithmetic() {
1098 return [](IRContext* context, Instruction* inst,
1099 const std::vector<const analysis::Constant*>& constants) {
1100 assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
1101 const analysis::Type* type =
1102 context->get_type_mgr()->GetType(inst->type_id());
1103 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1104 bool uses_float = HasFloatingPoint(type);
1105 if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1106
1107 uint32_t width = ElementWidth(type);
1108 if (width != 32 && width != 64) return false;
1109
1110 const analysis::Constant* const_input1 = ConstInput(constants);
1111 if (!const_input1) return false;
1112 Instruction* other_inst = NonConstInput(context, constants[0], inst);
1113 if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1114 return false;
1115
1116 if (other_inst->opcode() == SpvOpFAdd ||
1117 other_inst->opcode() == SpvOpIAdd) {
1118 std::vector<const analysis::Constant*> other_constants =
1119 const_mgr->GetOperandConstants(other_inst);
1120 const analysis::Constant* const_input2 = ConstInput(other_constants);
1121 if (!const_input2) return false;
1122
1123 Instruction* non_const_input =
1124 NonConstInput(context, other_constants[0], other_inst);
1125 uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
1126 const_input1, const_input2);
1127 if (merged_id == 0) return false;
1128
1129 inst->SetInOperands(
1130 {{SPV_OPERAND_TYPE_ID, {non_const_input->result_id()}},
1131 {SPV_OPERAND_TYPE_ID, {merged_id}}});
1132 return true;
1133 }
1134 return false;
1135 };
1136 }
1137
1138 // Folds addition of a subtraction where each operation has a constant operand.
1139 // Cases:
1140 // (x - 2) + 2 = x + 0
1141 // (2 - x) + 2 = 4 - x
1142 // 2 + (x - 2) = x + 0
1143 // 2 + (2 - x) = 4 - x
MergeAddSubArithmetic()1144 FoldingRule MergeAddSubArithmetic() {
1145 return [](IRContext* context, Instruction* inst,
1146 const std::vector<const analysis::Constant*>& constants) {
1147 assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
1148 const analysis::Type* type =
1149 context->get_type_mgr()->GetType(inst->type_id());
1150 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1151 bool uses_float = HasFloatingPoint(type);
1152 if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1153
1154 uint32_t width = ElementWidth(type);
1155 if (width != 32 && width != 64) return false;
1156
1157 const analysis::Constant* const_input1 = ConstInput(constants);
1158 if (!const_input1) return false;
1159 Instruction* other_inst = NonConstInput(context, constants[0], inst);
1160 if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1161 return false;
1162
1163 if (other_inst->opcode() == SpvOpFSub ||
1164 other_inst->opcode() == SpvOpISub) {
1165 std::vector<const analysis::Constant*> other_constants =
1166 const_mgr->GetOperandConstants(other_inst);
1167 const analysis::Constant* const_input2 = ConstInput(other_constants);
1168 if (!const_input2) return false;
1169
1170 bool first_is_variable = other_constants[0] == nullptr;
1171 SpvOp op = inst->opcode();
1172 uint32_t op1 = 0;
1173 uint32_t op2 = 0;
1174 if (first_is_variable) {
1175 // Subtract constants. Non-constant operand is first.
1176 op1 = other_inst->GetSingleWordInOperand(0u);
1177 op2 = PerformOperation(const_mgr, other_inst->opcode(), const_input1,
1178 const_input2);
1179 } else {
1180 // Add constants. Constant operand is first. Change the opcode.
1181 op1 = PerformOperation(const_mgr, inst->opcode(), const_input1,
1182 const_input2);
1183 op2 = other_inst->GetSingleWordInOperand(1u);
1184 op = other_inst->opcode();
1185 }
1186 if (op1 == 0 || op2 == 0) return false;
1187
1188 inst->SetOpcode(op);
1189 inst->SetInOperands(
1190 {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
1191 return true;
1192 }
1193 return false;
1194 };
1195 }
1196
1197 // Folds subtraction of an addition where each operand has a constant operand.
1198 // Cases:
1199 // (x + 2) - 2 = x + 0
1200 // (2 + x) - 2 = x + 0
1201 // 2 - (x + 2) = 0 - x
1202 // 2 - (2 + x) = 0 - x
MergeSubAddArithmetic()1203 FoldingRule MergeSubAddArithmetic() {
1204 return [](IRContext* context, Instruction* inst,
1205 const std::vector<const analysis::Constant*>& constants) {
1206 assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
1207 const analysis::Type* type =
1208 context->get_type_mgr()->GetType(inst->type_id());
1209 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1210 bool uses_float = HasFloatingPoint(type);
1211 if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1212
1213 uint32_t width = ElementWidth(type);
1214 if (width != 32 && width != 64) return false;
1215
1216 const analysis::Constant* const_input1 = ConstInput(constants);
1217 if (!const_input1) return false;
1218 Instruction* other_inst = NonConstInput(context, constants[0], inst);
1219 if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1220 return false;
1221
1222 if (other_inst->opcode() == SpvOpFAdd ||
1223 other_inst->opcode() == SpvOpIAdd) {
1224 std::vector<const analysis::Constant*> other_constants =
1225 const_mgr->GetOperandConstants(other_inst);
1226 const analysis::Constant* const_input2 = ConstInput(other_constants);
1227 if (!const_input2) return false;
1228
1229 Instruction* non_const_input =
1230 NonConstInput(context, other_constants[0], other_inst);
1231
1232 // If the first operand of the sub is not a constant, swap the constants
1233 // so the subtraction has the correct operands.
1234 if (constants[0] == nullptr) std::swap(const_input1, const_input2);
1235 // Subtract the constants.
1236 uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
1237 const_input1, const_input2);
1238 SpvOp op = inst->opcode();
1239 uint32_t op1 = 0;
1240 uint32_t op2 = 0;
1241 if (constants[0] == nullptr) {
1242 // Non-constant operand is first. Change the opcode.
1243 op1 = non_const_input->result_id();
1244 op2 = merged_id;
1245 op = other_inst->opcode();
1246 } else {
1247 // Constant operand is first.
1248 op1 = merged_id;
1249 op2 = non_const_input->result_id();
1250 }
1251 if (op1 == 0 || op2 == 0) return false;
1252
1253 inst->SetOpcode(op);
1254 inst->SetInOperands(
1255 {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
1256 return true;
1257 }
1258 return false;
1259 };
1260 }
1261
1262 // Folds subtraction of a subtraction where each operand has a constant operand.
1263 // Cases:
1264 // (x - 2) - 2 = x - 4
1265 // (2 - x) - 2 = 0 - x
1266 // 2 - (x - 2) = 4 - x
1267 // 2 - (2 - x) = x + 0
MergeSubSubArithmetic()1268 FoldingRule MergeSubSubArithmetic() {
1269 return [](IRContext* context, Instruction* inst,
1270 const std::vector<const analysis::Constant*>& constants) {
1271 assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
1272 const analysis::Type* type =
1273 context->get_type_mgr()->GetType(inst->type_id());
1274 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1275 bool uses_float = HasFloatingPoint(type);
1276 if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1277
1278 uint32_t width = ElementWidth(type);
1279 if (width != 32 && width != 64) return false;
1280
1281 const analysis::Constant* const_input1 = ConstInput(constants);
1282 if (!const_input1) return false;
1283 Instruction* other_inst = NonConstInput(context, constants[0], inst);
1284 if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1285 return false;
1286
1287 if (other_inst->opcode() == SpvOpFSub ||
1288 other_inst->opcode() == SpvOpISub) {
1289 std::vector<const analysis::Constant*> other_constants =
1290 const_mgr->GetOperandConstants(other_inst);
1291 const analysis::Constant* const_input2 = ConstInput(other_constants);
1292 if (!const_input2) return false;
1293
1294 Instruction* non_const_input =
1295 NonConstInput(context, other_constants[0], other_inst);
1296
1297 // Merge the constants.
1298 uint32_t merged_id = 0;
1299 SpvOp merge_op = inst->opcode();
1300 if (other_constants[0] == nullptr) {
1301 merge_op = uses_float ? SpvOpFAdd : SpvOpIAdd;
1302 } else if (constants[0] == nullptr) {
1303 std::swap(const_input1, const_input2);
1304 }
1305 merged_id =
1306 PerformOperation(const_mgr, merge_op, const_input1, const_input2);
1307 if (merged_id == 0) return false;
1308
1309 SpvOp op = inst->opcode();
1310 if (constants[0] != nullptr && other_constants[0] != nullptr) {
1311 // Change the operation.
1312 op = uses_float ? SpvOpFAdd : SpvOpIAdd;
1313 }
1314
1315 uint32_t op1 = 0;
1316 uint32_t op2 = 0;
1317 if ((constants[0] == nullptr) ^ (other_constants[0] == nullptr)) {
1318 op1 = merged_id;
1319 op2 = non_const_input->result_id();
1320 } else {
1321 op1 = non_const_input->result_id();
1322 op2 = merged_id;
1323 }
1324
1325 inst->SetOpcode(op);
1326 inst->SetInOperands(
1327 {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
1328 return true;
1329 }
1330 return false;
1331 };
1332 }
1333
1334 // Helper function for MergeGenericAddSubArithmetic. If |addend| and
1335 // subtrahend of |sub| is the same, merge to copy of minuend of |sub|.
MergeGenericAddendSub(uint32_t addend,uint32_t sub,Instruction * inst)1336 bool MergeGenericAddendSub(uint32_t addend, uint32_t sub, Instruction* inst) {
1337 IRContext* context = inst->context();
1338 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1339 Instruction* sub_inst = def_use_mgr->GetDef(sub);
1340 if (sub_inst->opcode() != SpvOpFSub && sub_inst->opcode() != SpvOpISub)
1341 return false;
1342 if (sub_inst->opcode() == SpvOpFSub &&
1343 !sub_inst->IsFloatingPointFoldingAllowed())
1344 return false;
1345 if (addend != sub_inst->GetSingleWordInOperand(1)) return false;
1346 inst->SetOpcode(SpvOpCopyObject);
1347 inst->SetInOperands(
1348 {{SPV_OPERAND_TYPE_ID, {sub_inst->GetSingleWordInOperand(0)}}});
1349 context->UpdateDefUse(inst);
1350 return true;
1351 }
1352
1353 // Folds addition of a subtraction where the subtrahend is equal to the
1354 // other addend. Return a copy of the minuend. Accepts generic (const and
1355 // non-const) operands.
1356 // Cases:
1357 // (a - b) + b = a
1358 // b + (a - b) = a
MergeGenericAddSubArithmetic()1359 FoldingRule MergeGenericAddSubArithmetic() {
1360 return [](IRContext* context, Instruction* inst,
1361 const std::vector<const analysis::Constant*>&) {
1362 assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
1363 const analysis::Type* type =
1364 context->get_type_mgr()->GetType(inst->type_id());
1365 bool uses_float = HasFloatingPoint(type);
1366 if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1367
1368 uint32_t width = ElementWidth(type);
1369 if (width != 32 && width != 64) return false;
1370
1371 uint32_t add_op0 = inst->GetSingleWordInOperand(0);
1372 uint32_t add_op1 = inst->GetSingleWordInOperand(1);
1373 if (MergeGenericAddendSub(add_op0, add_op1, inst)) return true;
1374 return MergeGenericAddendSub(add_op1, add_op0, inst);
1375 };
1376 }
1377
1378 // Helper function for FactorAddMuls. If |factor0_0| is the same as |factor1_0|,
1379 // generate |factor0_0| * (|factor0_1| + |factor1_1|).
FactorAddMulsOpnds(uint32_t factor0_0,uint32_t factor0_1,uint32_t factor1_0,uint32_t factor1_1,Instruction * inst)1380 bool FactorAddMulsOpnds(uint32_t factor0_0, uint32_t factor0_1,
1381 uint32_t factor1_0, uint32_t factor1_1,
1382 Instruction* inst) {
1383 IRContext* context = inst->context();
1384 if (factor0_0 != factor1_0) return false;
1385 InstructionBuilder ir_builder(
1386 context, inst,
1387 IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
1388 Instruction* new_add_inst = ir_builder.AddBinaryOp(
1389 inst->type_id(), inst->opcode(), factor0_1, factor1_1);
1390 inst->SetOpcode(inst->opcode() == SpvOpFAdd ? SpvOpFMul : SpvOpIMul);
1391 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {factor0_0}},
1392 {SPV_OPERAND_TYPE_ID, {new_add_inst->result_id()}}});
1393 context->UpdateDefUse(inst);
1394 return true;
1395 }
1396
1397 // Perform the following factoring identity, handling all operand order
1398 // combinations: (a * b) + (a * c) = a * (b + c)
FactorAddMuls()1399 FoldingRule FactorAddMuls() {
1400 return [](IRContext* context, Instruction* inst,
1401 const std::vector<const analysis::Constant*>&) {
1402 assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
1403 const analysis::Type* type =
1404 context->get_type_mgr()->GetType(inst->type_id());
1405 bool uses_float = HasFloatingPoint(type);
1406 if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1407
1408 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1409 uint32_t add_op0 = inst->GetSingleWordInOperand(0);
1410 Instruction* add_op0_inst = def_use_mgr->GetDef(add_op0);
1411 if (add_op0_inst->opcode() != SpvOpFMul &&
1412 add_op0_inst->opcode() != SpvOpIMul)
1413 return false;
1414 uint32_t add_op1 = inst->GetSingleWordInOperand(1);
1415 Instruction* add_op1_inst = def_use_mgr->GetDef(add_op1);
1416 if (add_op1_inst->opcode() != SpvOpFMul &&
1417 add_op1_inst->opcode() != SpvOpIMul)
1418 return false;
1419
1420 // Only perform this optimization if both of the muls only have one use.
1421 // Otherwise this is a deoptimization in size and performance.
1422 if (def_use_mgr->NumUses(add_op0_inst) > 1) return false;
1423 if (def_use_mgr->NumUses(add_op1_inst) > 1) return false;
1424
1425 if (add_op0_inst->opcode() == SpvOpFMul &&
1426 (!add_op0_inst->IsFloatingPointFoldingAllowed() ||
1427 !add_op1_inst->IsFloatingPointFoldingAllowed()))
1428 return false;
1429
1430 for (int i = 0; i < 2; i++) {
1431 for (int j = 0; j < 2; j++) {
1432 // Check if operand i in add_op0_inst matches operand j in add_op1_inst.
1433 if (FactorAddMulsOpnds(add_op0_inst->GetSingleWordInOperand(i),
1434 add_op0_inst->GetSingleWordInOperand(1 - i),
1435 add_op1_inst->GetSingleWordInOperand(j),
1436 add_op1_inst->GetSingleWordInOperand(1 - j),
1437 inst))
1438 return true;
1439 }
1440 }
1441 return false;
1442 };
1443 }
1444
IntMultipleBy1()1445 FoldingRule IntMultipleBy1() {
1446 return [](IRContext*, Instruction* inst,
1447 const std::vector<const analysis::Constant*>& constants) {
1448 assert(inst->opcode() == SpvOpIMul && "Wrong opcode. Should be OpIMul.");
1449 for (uint32_t i = 0; i < 2; i++) {
1450 if (constants[i] == nullptr) {
1451 continue;
1452 }
1453 const analysis::IntConstant* int_constant = constants[i]->AsIntConstant();
1454 if (int_constant) {
1455 uint32_t width = ElementWidth(int_constant->type());
1456 if (width != 32 && width != 64) return false;
1457 bool is_one = (width == 32) ? int_constant->GetU32BitValue() == 1u
1458 : int_constant->GetU64BitValue() == 1ull;
1459 if (is_one) {
1460 inst->SetOpcode(SpvOpCopyObject);
1461 inst->SetInOperands(
1462 {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1 - i)}}});
1463 return true;
1464 }
1465 }
1466 }
1467 return false;
1468 };
1469 }
1470
CompositeConstructFeedingExtract()1471 FoldingRule CompositeConstructFeedingExtract() {
1472 return [](IRContext* context, Instruction* inst,
1473 const std::vector<const analysis::Constant*>&) {
1474 // If the input to an OpCompositeExtract is an OpCompositeConstruct,
1475 // then we can simply use the appropriate element in the construction.
1476 assert(inst->opcode() == SpvOpCompositeExtract &&
1477 "Wrong opcode. Should be OpCompositeExtract.");
1478 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1479 analysis::TypeManager* type_mgr = context->get_type_mgr();
1480
1481 // If there are no index operands, then this rule cannot do anything.
1482 if (inst->NumInOperands() <= 1) {
1483 return false;
1484 }
1485
1486 uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1487 Instruction* cinst = def_use_mgr->GetDef(cid);
1488
1489 if (cinst->opcode() != SpvOpCompositeConstruct) {
1490 return false;
1491 }
1492
1493 std::vector<Operand> operands;
1494 analysis::Type* composite_type = type_mgr->GetType(cinst->type_id());
1495 if (composite_type->AsVector() == nullptr) {
1496 // Get the element being extracted from the OpCompositeConstruct
1497 // Since it is not a vector, it is simple to extract the single element.
1498 uint32_t element_index = inst->GetSingleWordInOperand(1);
1499 uint32_t element_id = cinst->GetSingleWordInOperand(element_index);
1500 operands.push_back({SPV_OPERAND_TYPE_ID, {element_id}});
1501
1502 // Add the remaining indices for extraction.
1503 for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
1504 operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
1505 {inst->GetSingleWordInOperand(i)}});
1506 }
1507
1508 } else {
1509 // With vectors we have to handle the case where it is concatenating
1510 // vectors.
1511 assert(inst->NumInOperands() == 2 &&
1512 "Expecting a vector of scalar values.");
1513
1514 uint32_t element_index = inst->GetSingleWordInOperand(1);
1515 for (uint32_t construct_index = 0;
1516 construct_index < cinst->NumInOperands(); ++construct_index) {
1517 uint32_t element_id = cinst->GetSingleWordInOperand(construct_index);
1518 Instruction* element_def = def_use_mgr->GetDef(element_id);
1519 analysis::Vector* element_type =
1520 type_mgr->GetType(element_def->type_id())->AsVector();
1521 if (element_type) {
1522 uint32_t vector_size = element_type->element_count();
1523 if (vector_size <= element_index) {
1524 // The element we want comes after this vector.
1525 element_index -= vector_size;
1526 } else {
1527 // We want an element of this vector.
1528 operands.push_back({SPV_OPERAND_TYPE_ID, {element_id}});
1529 operands.push_back(
1530 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {element_index}});
1531 break;
1532 }
1533 } else {
1534 if (element_index == 0) {
1535 // This is a scalar, and we this is the element we are extracting.
1536 operands.push_back({SPV_OPERAND_TYPE_ID, {element_id}});
1537 break;
1538 } else {
1539 // Skip over this scalar value.
1540 --element_index;
1541 }
1542 }
1543 }
1544 }
1545
1546 // If there were no extra indices, then we have the final object. No need
1547 // to extract even more.
1548 if (operands.size() == 1) {
1549 inst->SetOpcode(SpvOpCopyObject);
1550 }
1551
1552 inst->SetInOperands(std::move(operands));
1553 return true;
1554 };
1555 }
1556
1557 // If the OpCompositeConstruct is simply putting back together elements that
1558 // where extracted from the same source, we can simply reuse the source.
1559 //
1560 // This is a common code pattern because of the way that scalar replacement
1561 // works.
CompositeExtractFeedingConstruct(IRContext * context,Instruction * inst,const std::vector<const analysis::Constant * > &)1562 bool CompositeExtractFeedingConstruct(
1563 IRContext* context, Instruction* inst,
1564 const std::vector<const analysis::Constant*>&) {
1565 assert(inst->opcode() == SpvOpCompositeConstruct &&
1566 "Wrong opcode. Should be OpCompositeConstruct.");
1567 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1568 uint32_t original_id = 0;
1569
1570 if (inst->NumInOperands() == 0) {
1571 // The struct being constructed has no members.
1572 return false;
1573 }
1574
1575 // Check each element to make sure they are:
1576 // - extractions
1577 // - extracting the same position they are inserting
1578 // - all extract from the same id.
1579 for (uint32_t i = 0; i < inst->NumInOperands(); ++i) {
1580 const uint32_t element_id = inst->GetSingleWordInOperand(i);
1581 Instruction* element_inst = def_use_mgr->GetDef(element_id);
1582
1583 if (element_inst->opcode() != SpvOpCompositeExtract) {
1584 return false;
1585 }
1586
1587 if (element_inst->NumInOperands() != 2) {
1588 return false;
1589 }
1590
1591 if (element_inst->GetSingleWordInOperand(1) != i) {
1592 return false;
1593 }
1594
1595 if (i == 0) {
1596 original_id =
1597 element_inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1598 } else if (original_id !=
1599 element_inst->GetSingleWordInOperand(kExtractCompositeIdInIdx)) {
1600 return false;
1601 }
1602 }
1603
1604 // The last check it to see that the object being extracted from is the
1605 // correct type.
1606 Instruction* original_inst = def_use_mgr->GetDef(original_id);
1607 if (original_inst->type_id() != inst->type_id()) {
1608 return false;
1609 }
1610
1611 // Simplify by using the original object.
1612 inst->SetOpcode(SpvOpCopyObject);
1613 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {original_id}}});
1614 return true;
1615 }
1616
InsertFeedingExtract()1617 FoldingRule InsertFeedingExtract() {
1618 return [](IRContext* context, Instruction* inst,
1619 const std::vector<const analysis::Constant*>&) {
1620 assert(inst->opcode() == SpvOpCompositeExtract &&
1621 "Wrong opcode. Should be OpCompositeExtract.");
1622 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1623 uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1624 Instruction* cinst = def_use_mgr->GetDef(cid);
1625
1626 if (cinst->opcode() != SpvOpCompositeInsert) {
1627 return false;
1628 }
1629
1630 // Find the first position where the list of insert and extract indicies
1631 // differ, if at all.
1632 uint32_t i;
1633 for (i = 1; i < inst->NumInOperands(); ++i) {
1634 if (i + 1 >= cinst->NumInOperands()) {
1635 break;
1636 }
1637
1638 if (inst->GetSingleWordInOperand(i) !=
1639 cinst->GetSingleWordInOperand(i + 1)) {
1640 break;
1641 }
1642 }
1643
1644 // We are extracting the element that was inserted.
1645 if (i == inst->NumInOperands() && i + 1 == cinst->NumInOperands()) {
1646 inst->SetOpcode(SpvOpCopyObject);
1647 inst->SetInOperands(
1648 {{SPV_OPERAND_TYPE_ID,
1649 {cinst->GetSingleWordInOperand(kInsertObjectIdInIdx)}}});
1650 return true;
1651 }
1652
1653 // Extracting the value that was inserted along with values for the base
1654 // composite. Cannot do anything.
1655 if (i == inst->NumInOperands()) {
1656 return false;
1657 }
1658
1659 // Extracting an element of the value that was inserted. Extract from
1660 // that value directly.
1661 if (i + 1 == cinst->NumInOperands()) {
1662 std::vector<Operand> operands;
1663 operands.push_back(
1664 {SPV_OPERAND_TYPE_ID,
1665 {cinst->GetSingleWordInOperand(kInsertObjectIdInIdx)}});
1666 for (; i < inst->NumInOperands(); ++i) {
1667 operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
1668 {inst->GetSingleWordInOperand(i)}});
1669 }
1670 inst->SetInOperands(std::move(operands));
1671 return true;
1672 }
1673
1674 // Extracting a value that is disjoint from the element being inserted.
1675 // Rewrite the extract to use the composite input to the insert.
1676 std::vector<Operand> operands;
1677 operands.push_back(
1678 {SPV_OPERAND_TYPE_ID,
1679 {cinst->GetSingleWordInOperand(kInsertCompositeIdInIdx)}});
1680 for (i = 1; i < inst->NumInOperands(); ++i) {
1681 operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
1682 {inst->GetSingleWordInOperand(i)}});
1683 }
1684 inst->SetInOperands(std::move(operands));
1685 return true;
1686 };
1687 }
1688
1689 // When a VectorShuffle is feeding an Extract, we can extract from one of the
1690 // operands of the VectorShuffle. We just need to adjust the index in the
1691 // extract instruction.
VectorShuffleFeedingExtract()1692 FoldingRule VectorShuffleFeedingExtract() {
1693 return [](IRContext* context, Instruction* inst,
1694 const std::vector<const analysis::Constant*>&) {
1695 assert(inst->opcode() == SpvOpCompositeExtract &&
1696 "Wrong opcode. Should be OpCompositeExtract.");
1697 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1698 analysis::TypeManager* type_mgr = context->get_type_mgr();
1699 uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1700 Instruction* cinst = def_use_mgr->GetDef(cid);
1701
1702 if (cinst->opcode() != SpvOpVectorShuffle) {
1703 return false;
1704 }
1705
1706 // Find the size of the first vector operand of the VectorShuffle
1707 Instruction* first_input =
1708 def_use_mgr->GetDef(cinst->GetSingleWordInOperand(0));
1709 analysis::Type* first_input_type =
1710 type_mgr->GetType(first_input->type_id());
1711 assert(first_input_type->AsVector() &&
1712 "Input to vector shuffle should be vectors.");
1713 uint32_t first_input_size = first_input_type->AsVector()->element_count();
1714
1715 // Get index of the element the vector shuffle is placing in the position
1716 // being extracted.
1717 uint32_t new_index =
1718 cinst->GetSingleWordInOperand(2 + inst->GetSingleWordInOperand(1));
1719
1720 // Extracting an undefined value so fold this extract into an undef.
1721 const uint32_t undef_literal_value = 0xffffffff;
1722 if (new_index == undef_literal_value) {
1723 inst->SetOpcode(SpvOpUndef);
1724 inst->SetInOperands({});
1725 return true;
1726 }
1727
1728 // Get the id of the of the vector the elemtent comes from, and update the
1729 // index if needed.
1730 uint32_t new_vector = 0;
1731 if (new_index < first_input_size) {
1732 new_vector = cinst->GetSingleWordInOperand(0);
1733 } else {
1734 new_vector = cinst->GetSingleWordInOperand(1);
1735 new_index -= first_input_size;
1736 }
1737
1738 // Update the extract instruction.
1739 inst->SetInOperand(kExtractCompositeIdInIdx, {new_vector});
1740 inst->SetInOperand(1, {new_index});
1741 return true;
1742 };
1743 }
1744
1745 // When an FMix with is feeding an Extract that extracts an element whose
1746 // corresponding |a| in the FMix is 0 or 1, we can extract from one of the
1747 // operands of the FMix.
FMixFeedingExtract()1748 FoldingRule FMixFeedingExtract() {
1749 return [](IRContext* context, Instruction* inst,
1750 const std::vector<const analysis::Constant*>&) {
1751 assert(inst->opcode() == SpvOpCompositeExtract &&
1752 "Wrong opcode. Should be OpCompositeExtract.");
1753 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1754 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1755
1756 uint32_t composite_id =
1757 inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1758 Instruction* composite_inst = def_use_mgr->GetDef(composite_id);
1759
1760 if (composite_inst->opcode() != SpvOpExtInst) {
1761 return false;
1762 }
1763
1764 uint32_t inst_set_id =
1765 context->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
1766
1767 if (composite_inst->GetSingleWordInOperand(kExtInstSetIdInIdx) !=
1768 inst_set_id ||
1769 composite_inst->GetSingleWordInOperand(kExtInstInstructionInIdx) !=
1770 GLSLstd450FMix) {
1771 return false;
1772 }
1773
1774 // Get the |a| for the FMix instruction.
1775 uint32_t a_id = composite_inst->GetSingleWordInOperand(kFMixAIdInIdx);
1776 std::unique_ptr<Instruction> a(inst->Clone(context));
1777 a->SetInOperand(kExtractCompositeIdInIdx, {a_id});
1778 context->get_instruction_folder().FoldInstruction(a.get());
1779
1780 if (a->opcode() != SpvOpCopyObject) {
1781 return false;
1782 }
1783
1784 const analysis::Constant* a_const =
1785 const_mgr->FindDeclaredConstant(a->GetSingleWordInOperand(0));
1786
1787 if (!a_const) {
1788 return false;
1789 }
1790
1791 bool use_x = false;
1792
1793 assert(a_const->type()->AsFloat());
1794 double element_value = a_const->GetValueAsDouble();
1795 if (element_value == 0.0) {
1796 use_x = true;
1797 } else if (element_value == 1.0) {
1798 use_x = false;
1799 } else {
1800 return false;
1801 }
1802
1803 // Get the id of the of the vector the element comes from.
1804 uint32_t new_vector = 0;
1805 if (use_x) {
1806 new_vector = composite_inst->GetSingleWordInOperand(kFMixXIdInIdx);
1807 } else {
1808 new_vector = composite_inst->GetSingleWordInOperand(kFMixYIdInIdx);
1809 }
1810
1811 // Update the extract instruction.
1812 inst->SetInOperand(kExtractCompositeIdInIdx, {new_vector});
1813 return true;
1814 };
1815 }
1816
RedundantPhi()1817 FoldingRule RedundantPhi() {
1818 // An OpPhi instruction where all values are the same or the result of the phi
1819 // itself, can be replaced by the value itself.
1820 return [](IRContext*, Instruction* inst,
1821 const std::vector<const analysis::Constant*>&) {
1822 assert(inst->opcode() == SpvOpPhi && "Wrong opcode. Should be OpPhi.");
1823
1824 uint32_t incoming_value = 0;
1825
1826 for (uint32_t i = 0; i < inst->NumInOperands(); i += 2) {
1827 uint32_t op_id = inst->GetSingleWordInOperand(i);
1828 if (op_id == inst->result_id()) {
1829 continue;
1830 }
1831
1832 if (incoming_value == 0) {
1833 incoming_value = op_id;
1834 } else if (op_id != incoming_value) {
1835 // Found two possible value. Can't simplify.
1836 return false;
1837 }
1838 }
1839
1840 if (incoming_value == 0) {
1841 // Code looks invalid. Don't do anything.
1842 return false;
1843 }
1844
1845 // We have a single incoming value. Simplify using that value.
1846 inst->SetOpcode(SpvOpCopyObject);
1847 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {incoming_value}}});
1848 return true;
1849 };
1850 }
1851
BitCastScalarOrVector()1852 FoldingRule BitCastScalarOrVector() {
1853 return [](IRContext* context, Instruction* inst,
1854 const std::vector<const analysis::Constant*>& constants) {
1855 assert(inst->opcode() == SpvOpBitcast && constants.size() == 1);
1856 if (constants[0] == nullptr) return false;
1857
1858 const analysis::Type* type =
1859 context->get_type_mgr()->GetType(inst->type_id());
1860 if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
1861 return false;
1862
1863 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1864 std::vector<uint32_t> words =
1865 GetWordsFromNumericScalarOrVectorConstant(const_mgr, constants[0]);
1866 if (words.size() == 0) return false;
1867
1868 const analysis::Constant* bitcasted_constant =
1869 ConvertWordsToNumericScalarOrVectorConstant(const_mgr, words, type);
1870 if (!bitcasted_constant) return false;
1871
1872 auto new_feeder_id =
1873 const_mgr->GetDefiningInstruction(bitcasted_constant, inst->type_id())
1874 ->result_id();
1875 inst->SetOpcode(SpvOpCopyObject);
1876 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {new_feeder_id}}});
1877 return true;
1878 };
1879 }
1880
RedundantSelect()1881 FoldingRule RedundantSelect() {
1882 // An OpSelect instruction where both values are the same or the condition is
1883 // constant can be replaced by one of the values
1884 return [](IRContext*, Instruction* inst,
1885 const std::vector<const analysis::Constant*>& constants) {
1886 assert(inst->opcode() == SpvOpSelect &&
1887 "Wrong opcode. Should be OpSelect.");
1888 assert(inst->NumInOperands() == 3);
1889 assert(constants.size() == 3);
1890
1891 uint32_t true_id = inst->GetSingleWordInOperand(1);
1892 uint32_t false_id = inst->GetSingleWordInOperand(2);
1893
1894 if (true_id == false_id) {
1895 // Both results are the same, condition doesn't matter
1896 inst->SetOpcode(SpvOpCopyObject);
1897 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}});
1898 return true;
1899 } else if (constants[0]) {
1900 const analysis::Type* type = constants[0]->type();
1901 if (type->AsBool()) {
1902 // Scalar constant value, select the corresponding value.
1903 inst->SetOpcode(SpvOpCopyObject);
1904 if (constants[0]->AsNullConstant() ||
1905 !constants[0]->AsBoolConstant()->value()) {
1906 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {false_id}}});
1907 } else {
1908 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}});
1909 }
1910 return true;
1911 } else {
1912 assert(type->AsVector());
1913 if (constants[0]->AsNullConstant()) {
1914 // All values come from false id.
1915 inst->SetOpcode(SpvOpCopyObject);
1916 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {false_id}}});
1917 return true;
1918 } else {
1919 // Convert to a vector shuffle.
1920 std::vector<Operand> ops;
1921 ops.push_back({SPV_OPERAND_TYPE_ID, {true_id}});
1922 ops.push_back({SPV_OPERAND_TYPE_ID, {false_id}});
1923 const analysis::VectorConstant* vector_const =
1924 constants[0]->AsVectorConstant();
1925 uint32_t size =
1926 static_cast<uint32_t>(vector_const->GetComponents().size());
1927 for (uint32_t i = 0; i != size; ++i) {
1928 const analysis::Constant* component =
1929 vector_const->GetComponents()[i];
1930 if (component->AsNullConstant() ||
1931 !component->AsBoolConstant()->value()) {
1932 // Selecting from the false vector which is the second input
1933 // vector to the shuffle. Offset the index by |size|.
1934 ops.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i + size}});
1935 } else {
1936 // Selecting from true vector which is the first input vector to
1937 // the shuffle.
1938 ops.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}});
1939 }
1940 }
1941
1942 inst->SetOpcode(SpvOpVectorShuffle);
1943 inst->SetInOperands(std::move(ops));
1944 return true;
1945 }
1946 }
1947 }
1948
1949 return false;
1950 };
1951 }
1952
// Classification of a floating point constant, as produced by
// |getFloatConstantKind|.
enum class FloatConstantKind {
  Unknown,  // Not a known constant, or neither all-zero nor all-one.
  Zero,     // Every element is zero.
  One       // Every element is one.
};
1954
getFloatConstantKind(const analysis::Constant * constant)1955 FloatConstantKind getFloatConstantKind(const analysis::Constant* constant) {
1956 if (constant == nullptr) {
1957 return FloatConstantKind::Unknown;
1958 }
1959
1960 assert(HasFloatingPoint(constant->type()) && "Unexpected constant type");
1961
1962 if (constant->AsNullConstant()) {
1963 return FloatConstantKind::Zero;
1964 } else if (const analysis::VectorConstant* vc =
1965 constant->AsVectorConstant()) {
1966 const std::vector<const analysis::Constant*>& components =
1967 vc->GetComponents();
1968 assert(!components.empty());
1969
1970 FloatConstantKind kind = getFloatConstantKind(components[0]);
1971
1972 for (size_t i = 1; i < components.size(); ++i) {
1973 if (getFloatConstantKind(components[i]) != kind) {
1974 return FloatConstantKind::Unknown;
1975 }
1976 }
1977
1978 return kind;
1979 } else if (const analysis::FloatConstant* fc = constant->AsFloatConstant()) {
1980 if (fc->IsZero()) return FloatConstantKind::Zero;
1981
1982 uint32_t width = fc->type()->AsFloat()->width();
1983 if (width != 32 && width != 64) return FloatConstantKind::Unknown;
1984
1985 double value = (width == 64) ? fc->GetDoubleValue() : fc->GetFloatValue();
1986
1987 if (value == 0.0) {
1988 return FloatConstantKind::Zero;
1989 } else if (value == 1.0) {
1990 return FloatConstantKind::One;
1991 } else {
1992 return FloatConstantKind::Unknown;
1993 }
1994 } else {
1995 return FloatConstantKind::Unknown;
1996 }
1997 }
1998
RedundantFAdd()1999 FoldingRule RedundantFAdd() {
2000 return [](IRContext*, Instruction* inst,
2001 const std::vector<const analysis::Constant*>& constants) {
2002 assert(inst->opcode() == SpvOpFAdd && "Wrong opcode. Should be OpFAdd.");
2003 assert(constants.size() == 2);
2004
2005 if (!inst->IsFloatingPointFoldingAllowed()) {
2006 return false;
2007 }
2008
2009 FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
2010 FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
2011
2012 if (kind0 == FloatConstantKind::Zero || kind1 == FloatConstantKind::Zero) {
2013 inst->SetOpcode(SpvOpCopyObject);
2014 inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
2015 {inst->GetSingleWordInOperand(
2016 kind0 == FloatConstantKind::Zero ? 1 : 0)}}});
2017 return true;
2018 }
2019
2020 return false;
2021 };
2022 }
2023
RedundantFSub()2024 FoldingRule RedundantFSub() {
2025 return [](IRContext*, Instruction* inst,
2026 const std::vector<const analysis::Constant*>& constants) {
2027 assert(inst->opcode() == SpvOpFSub && "Wrong opcode. Should be OpFSub.");
2028 assert(constants.size() == 2);
2029
2030 if (!inst->IsFloatingPointFoldingAllowed()) {
2031 return false;
2032 }
2033
2034 FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
2035 FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
2036
2037 if (kind0 == FloatConstantKind::Zero) {
2038 inst->SetOpcode(SpvOpFNegate);
2039 inst->SetInOperands(
2040 {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1)}}});
2041 return true;
2042 }
2043
2044 if (kind1 == FloatConstantKind::Zero) {
2045 inst->SetOpcode(SpvOpCopyObject);
2046 inst->SetInOperands(
2047 {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
2048 return true;
2049 }
2050
2051 return false;
2052 };
2053 }
2054
RedundantFMul()2055 FoldingRule RedundantFMul() {
2056 return [](IRContext*, Instruction* inst,
2057 const std::vector<const analysis::Constant*>& constants) {
2058 assert(inst->opcode() == SpvOpFMul && "Wrong opcode. Should be OpFMul.");
2059 assert(constants.size() == 2);
2060
2061 if (!inst->IsFloatingPointFoldingAllowed()) {
2062 return false;
2063 }
2064
2065 FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
2066 FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
2067
2068 if (kind0 == FloatConstantKind::Zero || kind1 == FloatConstantKind::Zero) {
2069 inst->SetOpcode(SpvOpCopyObject);
2070 inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
2071 {inst->GetSingleWordInOperand(
2072 kind0 == FloatConstantKind::Zero ? 0 : 1)}}});
2073 return true;
2074 }
2075
2076 if (kind0 == FloatConstantKind::One || kind1 == FloatConstantKind::One) {
2077 inst->SetOpcode(SpvOpCopyObject);
2078 inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
2079 {inst->GetSingleWordInOperand(
2080 kind0 == FloatConstantKind::One ? 1 : 0)}}});
2081 return true;
2082 }
2083
2084 return false;
2085 };
2086 }
2087
RedundantFDiv()2088 FoldingRule RedundantFDiv() {
2089 return [](IRContext*, Instruction* inst,
2090 const std::vector<const analysis::Constant*>& constants) {
2091 assert(inst->opcode() == SpvOpFDiv && "Wrong opcode. Should be OpFDiv.");
2092 assert(constants.size() == 2);
2093
2094 if (!inst->IsFloatingPointFoldingAllowed()) {
2095 return false;
2096 }
2097
2098 FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
2099 FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
2100
2101 if (kind0 == FloatConstantKind::Zero) {
2102 inst->SetOpcode(SpvOpCopyObject);
2103 inst->SetInOperands(
2104 {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
2105 return true;
2106 }
2107
2108 if (kind1 == FloatConstantKind::One) {
2109 inst->SetOpcode(SpvOpCopyObject);
2110 inst->SetInOperands(
2111 {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
2112 return true;
2113 }
2114
2115 return false;
2116 };
2117 }
2118
RedundantFMix()2119 FoldingRule RedundantFMix() {
2120 return [](IRContext* context, Instruction* inst,
2121 const std::vector<const analysis::Constant*>& constants) {
2122 assert(inst->opcode() == SpvOpExtInst &&
2123 "Wrong opcode. Should be OpExtInst.");
2124
2125 if (!inst->IsFloatingPointFoldingAllowed()) {
2126 return false;
2127 }
2128
2129 uint32_t instSetId =
2130 context->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
2131
2132 if (inst->GetSingleWordInOperand(kExtInstSetIdInIdx) == instSetId &&
2133 inst->GetSingleWordInOperand(kExtInstInstructionInIdx) ==
2134 GLSLstd450FMix) {
2135 assert(constants.size() == 5);
2136
2137 FloatConstantKind kind4 = getFloatConstantKind(constants[4]);
2138
2139 if (kind4 == FloatConstantKind::Zero || kind4 == FloatConstantKind::One) {
2140 inst->SetOpcode(SpvOpCopyObject);
2141 inst->SetInOperands(
2142 {{SPV_OPERAND_TYPE_ID,
2143 {inst->GetSingleWordInOperand(kind4 == FloatConstantKind::Zero
2144 ? kFMixXIdInIdx
2145 : kFMixYIdInIdx)}}});
2146 return true;
2147 }
2148 }
2149
2150 return false;
2151 };
2152 }
2153
2154 // This rule handles addition of zero for integers.
RedundantIAdd()2155 FoldingRule RedundantIAdd() {
2156 return [](IRContext* context, Instruction* inst,
2157 const std::vector<const analysis::Constant*>& constants) {
2158 assert(inst->opcode() == SpvOpIAdd && "Wrong opcode. Should be OpIAdd.");
2159
2160 uint32_t operand = std::numeric_limits<uint32_t>::max();
2161 const analysis::Type* operand_type = nullptr;
2162 if (constants[0] && constants[0]->IsZero()) {
2163 operand = inst->GetSingleWordInOperand(1);
2164 operand_type = constants[0]->type();
2165 } else if (constants[1] && constants[1]->IsZero()) {
2166 operand = inst->GetSingleWordInOperand(0);
2167 operand_type = constants[1]->type();
2168 }
2169
2170 if (operand != std::numeric_limits<uint32_t>::max()) {
2171 const analysis::Type* inst_type =
2172 context->get_type_mgr()->GetType(inst->type_id());
2173 if (inst_type->IsSame(operand_type)) {
2174 inst->SetOpcode(SpvOpCopyObject);
2175 } else {
2176 inst->SetOpcode(SpvOpBitcast);
2177 }
2178 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {operand}}});
2179 return true;
2180 }
2181 return false;
2182 };
2183 }
2184
2185 // This rule look for a dot with a constant vector containing a single 1 and
2186 // the rest 0s. This is the same as doing an extract.
DotProductDoingExtract()2187 FoldingRule DotProductDoingExtract() {
2188 return [](IRContext* context, Instruction* inst,
2189 const std::vector<const analysis::Constant*>& constants) {
2190 assert(inst->opcode() == SpvOpDot && "Wrong opcode. Should be OpDot.");
2191
2192 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
2193
2194 if (!inst->IsFloatingPointFoldingAllowed()) {
2195 return false;
2196 }
2197
2198 for (int i = 0; i < 2; ++i) {
2199 if (!constants[i]) {
2200 continue;
2201 }
2202
2203 const analysis::Vector* vector_type = constants[i]->type()->AsVector();
2204 assert(vector_type && "Inputs to OpDot must be vectors.");
2205 const analysis::Float* element_type =
2206 vector_type->element_type()->AsFloat();
2207 assert(element_type && "Inputs to OpDot must be vectors of floats.");
2208 uint32_t element_width = element_type->width();
2209 if (element_width != 32 && element_width != 64) {
2210 return false;
2211 }
2212
2213 std::vector<const analysis::Constant*> components;
2214 components = constants[i]->GetVectorComponents(const_mgr);
2215
2216 const uint32_t kNotFound = std::numeric_limits<uint32_t>::max();
2217
2218 uint32_t component_with_one = kNotFound;
2219 bool all_others_zero = true;
2220 for (uint32_t j = 0; j < components.size(); ++j) {
2221 const analysis::Constant* element = components[j];
2222 double value =
2223 (element_width == 32 ? element->GetFloat() : element->GetDouble());
2224 if (value == 0.0) {
2225 continue;
2226 } else if (value == 1.0) {
2227 if (component_with_one == kNotFound) {
2228 component_with_one = j;
2229 } else {
2230 component_with_one = kNotFound;
2231 break;
2232 }
2233 } else {
2234 all_others_zero = false;
2235 break;
2236 }
2237 }
2238
2239 if (!all_others_zero || component_with_one == kNotFound) {
2240 continue;
2241 }
2242
2243 std::vector<Operand> operands;
2244 operands.push_back(
2245 {SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1u - i)}});
2246 operands.push_back(
2247 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {component_with_one}});
2248
2249 inst->SetOpcode(SpvOpCompositeExtract);
2250 inst->SetInOperands(std::move(operands));
2251 return true;
2252 }
2253 return false;
2254 };
2255 }
2256
2257 // If we are storing an undef, then we can remove the store.
2258 //
2259 // TODO: We can do something similar for OpImageWrite, but checking for volatile
2260 // is complicated. Waiting to see if it is needed.
StoringUndef()2261 FoldingRule StoringUndef() {
2262 return [](IRContext* context, Instruction* inst,
2263 const std::vector<const analysis::Constant*>&) {
2264 assert(inst->opcode() == SpvOpStore && "Wrong opcode. Should be OpStore.");
2265
2266 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
2267
2268 // If this is a volatile store, the store cannot be removed.
2269 if (inst->NumInOperands() == 3) {
2270 if (inst->GetSingleWordInOperand(2) & SpvMemoryAccessVolatileMask) {
2271 return false;
2272 }
2273 }
2274
2275 uint32_t object_id = inst->GetSingleWordInOperand(kStoreObjectInIdx);
2276 Instruction* object_inst = def_use_mgr->GetDef(object_id);
2277 if (object_inst->opcode() == SpvOpUndef) {
2278 inst->ToNop();
2279 return true;
2280 }
2281 return false;
2282 };
2283 }
2284
VectorShuffleFeedingShuffle()2285 FoldingRule VectorShuffleFeedingShuffle() {
2286 return [](IRContext* context, Instruction* inst,
2287 const std::vector<const analysis::Constant*>&) {
2288 assert(inst->opcode() == SpvOpVectorShuffle &&
2289 "Wrong opcode. Should be OpVectorShuffle.");
2290
2291 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
2292 analysis::TypeManager* type_mgr = context->get_type_mgr();
2293
2294 Instruction* feeding_shuffle_inst =
2295 def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
2296 analysis::Vector* op0_type =
2297 type_mgr->GetType(feeding_shuffle_inst->type_id())->AsVector();
2298 uint32_t op0_length = op0_type->element_count();
2299
2300 bool feeder_is_op0 = true;
2301 if (feeding_shuffle_inst->opcode() != SpvOpVectorShuffle) {
2302 feeding_shuffle_inst =
2303 def_use_mgr->GetDef(inst->GetSingleWordInOperand(1));
2304 feeder_is_op0 = false;
2305 }
2306
2307 if (feeding_shuffle_inst->opcode() != SpvOpVectorShuffle) {
2308 return false;
2309 }
2310
2311 Instruction* feeder2 =
2312 def_use_mgr->GetDef(feeding_shuffle_inst->GetSingleWordInOperand(0));
2313 analysis::Vector* feeder_op0_type =
2314 type_mgr->GetType(feeder2->type_id())->AsVector();
2315 uint32_t feeder_op0_length = feeder_op0_type->element_count();
2316
2317 uint32_t new_feeder_id = 0;
2318 std::vector<Operand> new_operands;
2319 new_operands.resize(
2320 2, {SPV_OPERAND_TYPE_ID, {0}}); // Place holders for vector operands.
2321 const uint32_t undef_literal = 0xffffffff;
2322 for (uint32_t op = 2; op < inst->NumInOperands(); ++op) {
2323 uint32_t component_index = inst->GetSingleWordInOperand(op);
2324
2325 // Do not interpret the undefined value literal as coming from operand 1.
2326 if (component_index != undef_literal &&
2327 feeder_is_op0 == (component_index < op0_length)) {
2328 // This component comes from the feeding_shuffle_inst. Update
2329 // |component_index| to be the index into the operand of the feeder.
2330
2331 // Adjust component_index to get the index into the operands of the
2332 // feeding_shuffle_inst.
2333 if (component_index >= op0_length) {
2334 component_index -= op0_length;
2335 }
2336 component_index =
2337 feeding_shuffle_inst->GetSingleWordInOperand(component_index + 2);
2338
2339 // Check if we are using a component from the first or second operand of
2340 // the feeding instruction.
2341 if (component_index < feeder_op0_length) {
2342 if (new_feeder_id == 0) {
2343 // First time through, save the id of the operand the element comes
2344 // from.
2345 new_feeder_id = feeding_shuffle_inst->GetSingleWordInOperand(0);
2346 } else if (new_feeder_id !=
2347 feeding_shuffle_inst->GetSingleWordInOperand(0)) {
2348 // We need both elements of the feeding_shuffle_inst, so we cannot
2349 // fold.
2350 return false;
2351 }
2352 } else {
2353 if (new_feeder_id == 0) {
2354 // First time through, save the id of the operand the element comes
2355 // from.
2356 new_feeder_id = feeding_shuffle_inst->GetSingleWordInOperand(1);
2357 } else if (new_feeder_id !=
2358 feeding_shuffle_inst->GetSingleWordInOperand(1)) {
2359 // We need both elements of the feeding_shuffle_inst, so we cannot
2360 // fold.
2361 return false;
2362 }
2363 component_index -= feeder_op0_length;
2364 }
2365
2366 if (!feeder_is_op0) {
2367 component_index += op0_length;
2368 }
2369 }
2370 new_operands.push_back(
2371 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {component_index}});
2372 }
2373
2374 if (new_feeder_id == 0) {
2375 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
2376 const analysis::Type* type =
2377 type_mgr->GetType(feeding_shuffle_inst->type_id());
2378 const analysis::Constant* null_const = const_mgr->GetConstant(type, {});
2379 new_feeder_id =
2380 const_mgr->GetDefiningInstruction(null_const, 0)->result_id();
2381 }
2382
2383 if (feeder_is_op0) {
2384 // If the size of the first vector operand changed then the indices
2385 // referring to the second operand need to be adjusted.
2386 Instruction* new_feeder_inst = def_use_mgr->GetDef(new_feeder_id);
2387 analysis::Type* new_feeder_type =
2388 type_mgr->GetType(new_feeder_inst->type_id());
2389 uint32_t new_op0_size = new_feeder_type->AsVector()->element_count();
2390 int32_t adjustment = op0_length - new_op0_size;
2391
2392 if (adjustment != 0) {
2393 for (uint32_t i = 2; i < new_operands.size(); i++) {
2394 if (inst->GetSingleWordInOperand(i) >= op0_length) {
2395 new_operands[i].words[0] -= adjustment;
2396 }
2397 }
2398 }
2399
2400 new_operands[0].words[0] = new_feeder_id;
2401 new_operands[1] = inst->GetInOperand(1);
2402 } else {
2403 new_operands[1].words[0] = new_feeder_id;
2404 new_operands[0] = inst->GetInOperand(0);
2405 }
2406
2407 inst->SetInOperands(std::move(new_operands));
2408 return true;
2409 };
2410 }
2411
2412 // Removes duplicate ids from the interface list of an OpEntryPoint
2413 // instruction.
RemoveRedundantOperands()2414 FoldingRule RemoveRedundantOperands() {
2415 return [](IRContext*, Instruction* inst,
2416 const std::vector<const analysis::Constant*>&) {
2417 assert(inst->opcode() == SpvOpEntryPoint &&
2418 "Wrong opcode. Should be OpEntryPoint.");
2419 bool has_redundant_operand = false;
2420 std::unordered_set<uint32_t> seen_operands;
2421 std::vector<Operand> new_operands;
2422
2423 new_operands.emplace_back(inst->GetOperand(0));
2424 new_operands.emplace_back(inst->GetOperand(1));
2425 new_operands.emplace_back(inst->GetOperand(2));
2426 for (uint32_t i = 3; i < inst->NumOperands(); ++i) {
2427 if (seen_operands.insert(inst->GetSingleWordOperand(i)).second) {
2428 new_operands.emplace_back(inst->GetOperand(i));
2429 } else {
2430 has_redundant_operand = true;
2431 }
2432 }
2433
2434 if (!has_redundant_operand) {
2435 return false;
2436 }
2437
2438 inst->SetInOperands(std::move(new_operands));
2439 return true;
2440 };
2441 }
2442
2443 // If an image instruction's operand is a constant, updates the image operand
2444 // flag from Offset to ConstOffset.
UpdateImageOperands()2445 FoldingRule UpdateImageOperands() {
2446 return [](IRContext*, Instruction* inst,
2447 const std::vector<const analysis::Constant*>& constants) {
2448 const auto opcode = inst->opcode();
2449 (void)opcode;
2450 assert((opcode == SpvOpImageSampleImplicitLod ||
2451 opcode == SpvOpImageSampleExplicitLod ||
2452 opcode == SpvOpImageSampleDrefImplicitLod ||
2453 opcode == SpvOpImageSampleDrefExplicitLod ||
2454 opcode == SpvOpImageSampleProjImplicitLod ||
2455 opcode == SpvOpImageSampleProjExplicitLod ||
2456 opcode == SpvOpImageSampleProjDrefImplicitLod ||
2457 opcode == SpvOpImageSampleProjDrefExplicitLod ||
2458 opcode == SpvOpImageFetch || opcode == SpvOpImageGather ||
2459 opcode == SpvOpImageDrefGather || opcode == SpvOpImageRead ||
2460 opcode == SpvOpImageWrite ||
2461 opcode == SpvOpImageSparseSampleImplicitLod ||
2462 opcode == SpvOpImageSparseSampleExplicitLod ||
2463 opcode == SpvOpImageSparseSampleDrefImplicitLod ||
2464 opcode == SpvOpImageSparseSampleDrefExplicitLod ||
2465 opcode == SpvOpImageSparseSampleProjImplicitLod ||
2466 opcode == SpvOpImageSparseSampleProjExplicitLod ||
2467 opcode == SpvOpImageSparseSampleProjDrefImplicitLod ||
2468 opcode == SpvOpImageSparseSampleProjDrefExplicitLod ||
2469 opcode == SpvOpImageSparseFetch ||
2470 opcode == SpvOpImageSparseGather ||
2471 opcode == SpvOpImageSparseDrefGather ||
2472 opcode == SpvOpImageSparseRead) &&
2473 "Wrong opcode. Should be an image instruction.");
2474
2475 int32_t operand_index = ImageOperandsMaskInOperandIndex(inst);
2476 if (operand_index >= 0) {
2477 auto image_operands = inst->GetSingleWordInOperand(operand_index);
2478 if (image_operands & SpvImageOperandsOffsetMask) {
2479 uint32_t offset_operand_index = operand_index + 1;
2480 if (image_operands & SpvImageOperandsBiasMask) offset_operand_index++;
2481 if (image_operands & SpvImageOperandsLodMask) offset_operand_index++;
2482 if (image_operands & SpvImageOperandsGradMask)
2483 offset_operand_index += 2;
2484 assert(((image_operands & SpvImageOperandsConstOffsetMask) == 0) &&
2485 "Offset and ConstOffset may not be used together");
2486 if (offset_operand_index < inst->NumOperands()) {
2487 if (constants[offset_operand_index]) {
2488 image_operands = image_operands | SpvImageOperandsConstOffsetMask;
2489 image_operands = image_operands & ~SpvImageOperandsOffsetMask;
2490 inst->SetInOperand(operand_index, {image_operands});
2491 return true;
2492 }
2493 }
2494 }
2495 }
2496
2497 return false;
2498 };
2499 }
2500
2501 } // namespace
2502
AddFoldingRules()2503 void FoldingRules::AddFoldingRules() {
2504 // Add all folding rules to the list for the opcodes to which they apply.
2505 // Note that the order in which rules are added to the list matters. If a rule
2506 // applies to the instruction, the rest of the rules will not be attempted.
2507 // Take that into consideration.
2508 rules_[SpvOpBitcast].push_back(BitCastScalarOrVector());
2509
2510 rules_[SpvOpCompositeConstruct].push_back(CompositeExtractFeedingConstruct);
2511
2512 rules_[SpvOpCompositeExtract].push_back(InsertFeedingExtract());
2513 rules_[SpvOpCompositeExtract].push_back(CompositeConstructFeedingExtract());
2514 rules_[SpvOpCompositeExtract].push_back(VectorShuffleFeedingExtract());
2515 rules_[SpvOpCompositeExtract].push_back(FMixFeedingExtract());
2516
2517 rules_[SpvOpDot].push_back(DotProductDoingExtract());
2518
2519 rules_[SpvOpEntryPoint].push_back(RemoveRedundantOperands());
2520
2521 rules_[SpvOpFAdd].push_back(RedundantFAdd());
2522 rules_[SpvOpFAdd].push_back(MergeAddNegateArithmetic());
2523 rules_[SpvOpFAdd].push_back(MergeAddAddArithmetic());
2524 rules_[SpvOpFAdd].push_back(MergeAddSubArithmetic());
2525 rules_[SpvOpFAdd].push_back(MergeGenericAddSubArithmetic());
2526 rules_[SpvOpFAdd].push_back(FactorAddMuls());
2527
2528 rules_[SpvOpFDiv].push_back(RedundantFDiv());
2529 rules_[SpvOpFDiv].push_back(ReciprocalFDiv());
2530 rules_[SpvOpFDiv].push_back(MergeDivDivArithmetic());
2531 rules_[SpvOpFDiv].push_back(MergeDivMulArithmetic());
2532 rules_[SpvOpFDiv].push_back(MergeDivNegateArithmetic());
2533
2534 rules_[SpvOpFMul].push_back(RedundantFMul());
2535 rules_[SpvOpFMul].push_back(MergeMulMulArithmetic());
2536 rules_[SpvOpFMul].push_back(MergeMulDivArithmetic());
2537 rules_[SpvOpFMul].push_back(MergeMulNegateArithmetic());
2538
2539 rules_[SpvOpFNegate].push_back(MergeNegateArithmetic());
2540 rules_[SpvOpFNegate].push_back(MergeNegateAddSubArithmetic());
2541 rules_[SpvOpFNegate].push_back(MergeNegateMulDivArithmetic());
2542
2543 rules_[SpvOpFSub].push_back(RedundantFSub());
2544 rules_[SpvOpFSub].push_back(MergeSubNegateArithmetic());
2545 rules_[SpvOpFSub].push_back(MergeSubAddArithmetic());
2546 rules_[SpvOpFSub].push_back(MergeSubSubArithmetic());
2547
2548 rules_[SpvOpIAdd].push_back(RedundantIAdd());
2549 rules_[SpvOpIAdd].push_back(MergeAddNegateArithmetic());
2550 rules_[SpvOpIAdd].push_back(MergeAddAddArithmetic());
2551 rules_[SpvOpIAdd].push_back(MergeAddSubArithmetic());
2552 rules_[SpvOpIAdd].push_back(MergeGenericAddSubArithmetic());
2553 rules_[SpvOpIAdd].push_back(FactorAddMuls());
2554
2555 rules_[SpvOpIMul].push_back(IntMultipleBy1());
2556 rules_[SpvOpIMul].push_back(MergeMulMulArithmetic());
2557 rules_[SpvOpIMul].push_back(MergeMulNegateArithmetic());
2558
2559 rules_[SpvOpISub].push_back(MergeSubNegateArithmetic());
2560 rules_[SpvOpISub].push_back(MergeSubAddArithmetic());
2561 rules_[SpvOpISub].push_back(MergeSubSubArithmetic());
2562
2563 rules_[SpvOpPhi].push_back(RedundantPhi());
2564
2565 rules_[SpvOpSDiv].push_back(MergeDivNegateArithmetic());
2566
2567 rules_[SpvOpSNegate].push_back(MergeNegateArithmetic());
2568 rules_[SpvOpSNegate].push_back(MergeNegateMulDivArithmetic());
2569 rules_[SpvOpSNegate].push_back(MergeNegateAddSubArithmetic());
2570
2571 rules_[SpvOpSelect].push_back(RedundantSelect());
2572
2573 rules_[SpvOpStore].push_back(StoringUndef());
2574
2575 rules_[SpvOpUDiv].push_back(MergeDivNegateArithmetic());
2576
2577 rules_[SpvOpVectorShuffle].push_back(VectorShuffleFeedingShuffle());
2578
2579 rules_[SpvOpImageSampleImplicitLod].push_back(UpdateImageOperands());
2580 rules_[SpvOpImageSampleExplicitLod].push_back(UpdateImageOperands());
2581 rules_[SpvOpImageSampleDrefImplicitLod].push_back(UpdateImageOperands());
2582 rules_[SpvOpImageSampleDrefExplicitLod].push_back(UpdateImageOperands());
2583 rules_[SpvOpImageSampleProjImplicitLod].push_back(UpdateImageOperands());
2584 rules_[SpvOpImageSampleProjExplicitLod].push_back(UpdateImageOperands());
2585 rules_[SpvOpImageSampleProjDrefImplicitLod].push_back(UpdateImageOperands());
2586 rules_[SpvOpImageSampleProjDrefExplicitLod].push_back(UpdateImageOperands());
2587 rules_[SpvOpImageFetch].push_back(UpdateImageOperands());
2588 rules_[SpvOpImageGather].push_back(UpdateImageOperands());
2589 rules_[SpvOpImageDrefGather].push_back(UpdateImageOperands());
2590 rules_[SpvOpImageRead].push_back(UpdateImageOperands());
2591 rules_[SpvOpImageWrite].push_back(UpdateImageOperands());
2592 rules_[SpvOpImageSparseSampleImplicitLod].push_back(UpdateImageOperands());
2593 rules_[SpvOpImageSparseSampleExplicitLod].push_back(UpdateImageOperands());
2594 rules_[SpvOpImageSparseSampleDrefImplicitLod].push_back(
2595 UpdateImageOperands());
2596 rules_[SpvOpImageSparseSampleDrefExplicitLod].push_back(
2597 UpdateImageOperands());
2598 rules_[SpvOpImageSparseSampleProjImplicitLod].push_back(
2599 UpdateImageOperands());
2600 rules_[SpvOpImageSparseSampleProjExplicitLod].push_back(
2601 UpdateImageOperands());
2602 rules_[SpvOpImageSparseSampleProjDrefImplicitLod].push_back(
2603 UpdateImageOperands());
2604 rules_[SpvOpImageSparseSampleProjDrefExplicitLod].push_back(
2605 UpdateImageOperands());
2606 rules_[SpvOpImageSparseFetch].push_back(UpdateImageOperands());
2607 rules_[SpvOpImageSparseGather].push_back(UpdateImageOperands());
2608 rules_[SpvOpImageSparseDrefGather].push_back(UpdateImageOperands());
2609 rules_[SpvOpImageSparseRead].push_back(UpdateImageOperands());
2610
2611 FeatureManager* feature_manager = context_->get_feature_mgr();
2612 // Add rules for GLSLstd450
2613 uint32_t ext_inst_glslstd450_id =
2614 feature_manager->GetExtInstImportId_GLSLstd450();
2615 if (ext_inst_glslstd450_id != 0) {
2616 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMix}].push_back(
2617 RedundantFMix());
2618 }
2619 }
2620 } // namespace opt
2621 } // namespace spvtools
2622