1 // Copyright (c) 2018 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/opt/folding_rules.h"
16
17 #include <climits>
18 #include <limits>
19 #include <memory>
20 #include <utility>
21
22 #include "ir_builder.h"
23 #include "source/latest_version_glsl_std_450_header.h"
24 #include "source/opt/ir_context.h"
25
26 namespace spvtools {
27 namespace opt {
28 namespace {
29
// Named in-operand index constants referenced by the folding rules in this
// file.
constexpr uint32_t kExtractCompositeIdInIdx = 0;
constexpr uint32_t kInsertObjectIdInIdx = 0;
constexpr uint32_t kInsertCompositeIdInIdx = 1;
constexpr uint32_t kExtInstSetIdInIdx = 0;
constexpr uint32_t kExtInstInstructionInIdx = 1;
constexpr uint32_t kFMixXIdInIdx = 2;
constexpr uint32_t kFMixYIdInIdx = 3;
constexpr uint32_t kFMixAIdInIdx = 4;
constexpr uint32_t kStoreObjectInIdx = 1;
39
40 // Some image instructions may contain an "image operands" argument.
41 // Returns the operand index for the "image operands".
42 // Returns -1 if the instruction does not have image operands.
ImageOperandsMaskInOperandIndex(Instruction * inst)43 int32_t ImageOperandsMaskInOperandIndex(Instruction* inst) {
44 const auto opcode = inst->opcode();
45 switch (opcode) {
46 case SpvOpImageSampleImplicitLod:
47 case SpvOpImageSampleExplicitLod:
48 case SpvOpImageSampleProjImplicitLod:
49 case SpvOpImageSampleProjExplicitLod:
50 case SpvOpImageFetch:
51 case SpvOpImageRead:
52 case SpvOpImageSparseSampleImplicitLod:
53 case SpvOpImageSparseSampleExplicitLod:
54 case SpvOpImageSparseSampleProjImplicitLod:
55 case SpvOpImageSparseSampleProjExplicitLod:
56 case SpvOpImageSparseFetch:
57 case SpvOpImageSparseRead:
58 return inst->NumOperands() > 4 ? 2 : -1;
59 case SpvOpImageSampleDrefImplicitLod:
60 case SpvOpImageSampleDrefExplicitLod:
61 case SpvOpImageSampleProjDrefImplicitLod:
62 case SpvOpImageSampleProjDrefExplicitLod:
63 case SpvOpImageGather:
64 case SpvOpImageDrefGather:
65 case SpvOpImageSparseSampleDrefImplicitLod:
66 case SpvOpImageSparseSampleDrefExplicitLod:
67 case SpvOpImageSparseSampleProjDrefImplicitLod:
68 case SpvOpImageSparseSampleProjDrefExplicitLod:
69 case SpvOpImageSparseGather:
70 case SpvOpImageSparseDrefGather:
71 return inst->NumOperands() > 5 ? 3 : -1;
72 case SpvOpImageWrite:
73 return inst->NumOperands() > 3 ? 3 : -1;
74 default:
75 return -1;
76 }
77 }
78
79 // Returns the element width of |type|.
ElementWidth(const analysis::Type * type)80 uint32_t ElementWidth(const analysis::Type* type) {
81 if (const analysis::Vector* vec_type = type->AsVector()) {
82 return ElementWidth(vec_type->element_type());
83 } else if (const analysis::Float* float_type = type->AsFloat()) {
84 return float_type->width();
85 } else {
86 assert(type->AsInteger());
87 return type->AsInteger()->width();
88 }
89 }
90
91 // Returns true if |type| is Float or a vector of Float.
HasFloatingPoint(const analysis::Type * type)92 bool HasFloatingPoint(const analysis::Type* type) {
93 if (type->AsFloat()) {
94 return true;
95 } else if (const analysis::Vector* vec_type = type->AsVector()) {
96 return vec_type->element_type()->AsFloat() != nullptr;
97 }
98
99 return false;
100 }
101
// Returns true if |val| is a "well-behaved" folding result, i.e. it is
// classified as neither NaN, infinite, nor subnormal (zero and normal values
// are accepted).
template <typename T>
bool IsValidResult(T val) {
  const int category = std::fpclassify(val);
  return category != FP_NAN && category != FP_INFINITE &&
         category != FP_SUBNORMAL;
}
115
ConstInput(const std::vector<const analysis::Constant * > & constants)116 const analysis::Constant* ConstInput(
117 const std::vector<const analysis::Constant*>& constants) {
118 return constants[0] ? constants[0] : constants[1];
119 }
120
NonConstInput(IRContext * context,const analysis::Constant * c,Instruction * inst)121 Instruction* NonConstInput(IRContext* context, const analysis::Constant* c,
122 Instruction* inst) {
123 uint32_t in_op = c ? 1u : 0u;
124 return context->get_def_use_mgr()->GetDef(
125 inst->GetSingleWordInOperand(in_op));
126 }
127
// Splits the 64-bit |val| into two 32-bit words, low-order word first.
std::vector<uint32_t> ExtractInts(uint64_t val) {
  const uint32_t low_word = static_cast<uint32_t>(val & 0xFFFFFFFFull);
  const uint32_t high_word = static_cast<uint32_t>(val >> 32);
  return std::vector<uint32_t>{low_word, high_word};
}
134
GetWordsFromScalarIntConstant(const analysis::IntConstant * c)135 std::vector<uint32_t> GetWordsFromScalarIntConstant(
136 const analysis::IntConstant* c) {
137 assert(c != nullptr);
138 uint32_t width = c->type()->AsInteger()->width();
139 assert(width == 32 || width == 64);
140 if (width == 64) {
141 uint64_t uval = static_cast<uint64_t>(c->GetU64());
142 return ExtractInts(uval);
143 }
144 return {c->GetU32()};
145 }
146
GetWordsFromScalarFloatConstant(const analysis::FloatConstant * c)147 std::vector<uint32_t> GetWordsFromScalarFloatConstant(
148 const analysis::FloatConstant* c) {
149 assert(c != nullptr);
150 uint32_t width = c->type()->AsFloat()->width();
151 assert(width == 32 || width == 64);
152 if (width == 64) {
153 utils::FloatProxy<double> result(c->GetDouble());
154 return result.GetWords();
155 }
156 utils::FloatProxy<float> result(c->GetFloat());
157 return result.GetWords();
158 }
159
GetWordsFromNumericScalarOrVectorConstant(analysis::ConstantManager * const_mgr,const analysis::Constant * c)160 std::vector<uint32_t> GetWordsFromNumericScalarOrVectorConstant(
161 analysis::ConstantManager* const_mgr, const analysis::Constant* c) {
162 if (const auto* float_constant = c->AsFloatConstant()) {
163 return GetWordsFromScalarFloatConstant(float_constant);
164 } else if (const auto* int_constant = c->AsIntConstant()) {
165 return GetWordsFromScalarIntConstant(int_constant);
166 } else if (const auto* vec_constant = c->AsVectorConstant()) {
167 std::vector<uint32_t> words;
168 for (const auto* comp : vec_constant->GetComponents()) {
169 auto comp_in_words =
170 GetWordsFromNumericScalarOrVectorConstant(const_mgr, comp);
171 words.insert(words.end(), comp_in_words.begin(), comp_in_words.end());
172 }
173 return words;
174 }
175 return {};
176 }
177
ConvertWordsToNumericScalarOrVectorConstant(analysis::ConstantManager * const_mgr,const std::vector<uint32_t> & words,const analysis::Type * type)178 const analysis::Constant* ConvertWordsToNumericScalarOrVectorConstant(
179 analysis::ConstantManager* const_mgr, const std::vector<uint32_t>& words,
180 const analysis::Type* type) {
181 if (type->AsInteger() || type->AsFloat())
182 return const_mgr->GetConstant(type, words);
183 if (const auto* vec_type = type->AsVector())
184 return const_mgr->GetNumericVectorConstantWithWords(vec_type, words);
185 return nullptr;
186 }
187
188 // Returns the negation of |c|. |c| must be a 32 or 64 bit floating point
189 // constant.
NegateFloatingPointConstant(analysis::ConstantManager * const_mgr,const analysis::Constant * c)190 uint32_t NegateFloatingPointConstant(analysis::ConstantManager* const_mgr,
191 const analysis::Constant* c) {
192 assert(c);
193 assert(c->type()->AsFloat());
194 uint32_t width = c->type()->AsFloat()->width();
195 assert(width == 32 || width == 64);
196 std::vector<uint32_t> words;
197 if (width == 64) {
198 utils::FloatProxy<double> result(c->GetDouble() * -1.0);
199 words = result.GetWords();
200 } else {
201 utils::FloatProxy<float> result(c->GetFloat() * -1.0f);
202 words = result.GetWords();
203 }
204
205 const analysis::Constant* negated_const =
206 const_mgr->GetConstant(c->type(), std::move(words));
207 return const_mgr->GetDefiningInstruction(negated_const)->result_id();
208 }
209
210 // Negates the integer constant |c|. Returns the id of the defining instruction.
NegateIntegerConstant(analysis::ConstantManager * const_mgr,const analysis::Constant * c)211 uint32_t NegateIntegerConstant(analysis::ConstantManager* const_mgr,
212 const analysis::Constant* c) {
213 assert(c);
214 assert(c->type()->AsInteger());
215 uint32_t width = c->type()->AsInteger()->width();
216 assert(width == 32 || width == 64);
217 std::vector<uint32_t> words;
218 if (width == 64) {
219 uint64_t uval = static_cast<uint64_t>(0 - c->GetU64());
220 words = ExtractInts(uval);
221 } else {
222 words.push_back(static_cast<uint32_t>(0 - c->GetU32()));
223 }
224
225 const analysis::Constant* negated_const =
226 const_mgr->GetConstant(c->type(), std::move(words));
227 return const_mgr->GetDefiningInstruction(negated_const)->result_id();
228 }
229
230 // Negates the vector constant |c|. Returns the id of the defining instruction.
NegateVectorConstant(analysis::ConstantManager * const_mgr,const analysis::Constant * c)231 uint32_t NegateVectorConstant(analysis::ConstantManager* const_mgr,
232 const analysis::Constant* c) {
233 assert(const_mgr && c);
234 assert(c->type()->AsVector());
235 if (c->AsNullConstant()) {
236 // 0.0 vs -0.0 shouldn't matter.
237 return const_mgr->GetDefiningInstruction(c)->result_id();
238 } else {
239 const analysis::Type* component_type =
240 c->AsVectorConstant()->component_type();
241 std::vector<uint32_t> words;
242 for (auto& comp : c->AsVectorConstant()->GetComponents()) {
243 if (component_type->AsFloat()) {
244 words.push_back(NegateFloatingPointConstant(const_mgr, comp));
245 } else {
246 assert(component_type->AsInteger());
247 words.push_back(NegateIntegerConstant(const_mgr, comp));
248 }
249 }
250
251 const analysis::Constant* negated_const =
252 const_mgr->GetConstant(c->type(), std::move(words));
253 return const_mgr->GetDefiningInstruction(negated_const)->result_id();
254 }
255 }
256
257 // Negates |c|. Returns the id of the defining instruction.
NegateConstant(analysis::ConstantManager * const_mgr,const analysis::Constant * c)258 uint32_t NegateConstant(analysis::ConstantManager* const_mgr,
259 const analysis::Constant* c) {
260 if (c->type()->AsVector()) {
261 return NegateVectorConstant(const_mgr, c);
262 } else if (c->type()->AsFloat()) {
263 return NegateFloatingPointConstant(const_mgr, c);
264 } else {
265 assert(c->type()->AsInteger());
266 return NegateIntegerConstant(const_mgr, c);
267 }
268 }
269
270 // Takes the reciprocal of |c|. |c|'s type must be Float or a vector of Float.
271 // Returns 0 if the reciprocal is NaN, infinite or subnormal.
Reciprocal(analysis::ConstantManager * const_mgr,const analysis::Constant * c)272 uint32_t Reciprocal(analysis::ConstantManager* const_mgr,
273 const analysis::Constant* c) {
274 assert(const_mgr && c);
275 assert(c->type()->AsFloat());
276
277 uint32_t width = c->type()->AsFloat()->width();
278 assert(width == 32 || width == 64);
279 std::vector<uint32_t> words;
280 if (width == 64) {
281 spvtools::utils::FloatProxy<double> result(1.0 / c->GetDouble());
282 if (!IsValidResult(result.getAsFloat())) return 0;
283 words = result.GetWords();
284 } else {
285 spvtools::utils::FloatProxy<float> result(1.0f / c->GetFloat());
286 if (!IsValidResult(result.getAsFloat())) return 0;
287 words = result.GetWords();
288 }
289
290 const analysis::Constant* negated_const =
291 const_mgr->GetConstant(c->type(), std::move(words));
292 return const_mgr->GetDefiningInstruction(negated_const)->result_id();
293 }
294
295 // Replaces fdiv where second operand is constant with fmul.
ReciprocalFDiv()296 FoldingRule ReciprocalFDiv() {
297 return [](IRContext* context, Instruction* inst,
298 const std::vector<const analysis::Constant*>& constants) {
299 assert(inst->opcode() == SpvOpFDiv);
300 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
301 const analysis::Type* type =
302 context->get_type_mgr()->GetType(inst->type_id());
303 if (!inst->IsFloatingPointFoldingAllowed()) return false;
304
305 uint32_t width = ElementWidth(type);
306 if (width != 32 && width != 64) return false;
307
308 if (constants[1] != nullptr) {
309 uint32_t id = 0;
310 if (const analysis::VectorConstant* vector_const =
311 constants[1]->AsVectorConstant()) {
312 std::vector<uint32_t> neg_ids;
313 for (auto& comp : vector_const->GetComponents()) {
314 id = Reciprocal(const_mgr, comp);
315 if (id == 0) return false;
316 neg_ids.push_back(id);
317 }
318 const analysis::Constant* negated_const =
319 const_mgr->GetConstant(constants[1]->type(), std::move(neg_ids));
320 id = const_mgr->GetDefiningInstruction(negated_const)->result_id();
321 } else if (constants[1]->AsFloatConstant()) {
322 id = Reciprocal(const_mgr, constants[1]);
323 if (id == 0) return false;
324 } else {
325 // Don't fold a null constant.
326 return false;
327 }
328 inst->SetOpcode(SpvOpFMul);
329 inst->SetInOperands(
330 {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0u)}},
331 {SPV_OPERAND_TYPE_ID, {id}}});
332 return true;
333 }
334
335 return false;
336 };
337 }
338
339 // Elides consecutive negate instructions.
MergeNegateArithmetic()340 FoldingRule MergeNegateArithmetic() {
341 return [](IRContext* context, Instruction* inst,
342 const std::vector<const analysis::Constant*>& constants) {
343 assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate);
344 (void)constants;
345 const analysis::Type* type =
346 context->get_type_mgr()->GetType(inst->type_id());
347 if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
348 return false;
349
350 Instruction* op_inst =
351 context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
352 if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
353 return false;
354
355 if (op_inst->opcode() == inst->opcode()) {
356 // Elide negates.
357 inst->SetOpcode(SpvOpCopyObject);
358 inst->SetInOperands(
359 {{SPV_OPERAND_TYPE_ID, {op_inst->GetSingleWordInOperand(0u)}}});
360 return true;
361 }
362
363 return false;
364 };
365 }
366
367 // Merges negate into a mul or div operation if that operation contains a
368 // constant operand.
369 // Cases:
370 // -(x * 2) = x * -2
371 // -(2 * x) = x * -2
372 // -(x / 2) = x / -2
373 // -(2 / x) = -2 / x
MergeNegateMulDivArithmetic()374 FoldingRule MergeNegateMulDivArithmetic() {
375 return [](IRContext* context, Instruction* inst,
376 const std::vector<const analysis::Constant*>& constants) {
377 assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate);
378 (void)constants;
379 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
380 const analysis::Type* type =
381 context->get_type_mgr()->GetType(inst->type_id());
382 if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
383 return false;
384
385 Instruction* op_inst =
386 context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
387 if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
388 return false;
389
390 uint32_t width = ElementWidth(type);
391 if (width != 32 && width != 64) return false;
392
393 SpvOp opcode = op_inst->opcode();
394 if (opcode == SpvOpFMul || opcode == SpvOpFDiv || opcode == SpvOpIMul ||
395 opcode == SpvOpSDiv || opcode == SpvOpUDiv) {
396 std::vector<const analysis::Constant*> op_constants =
397 const_mgr->GetOperandConstants(op_inst);
398 // Merge negate into mul or div if one operand is constant.
399 if (op_constants[0] || op_constants[1]) {
400 bool zero_is_variable = op_constants[0] == nullptr;
401 const analysis::Constant* c = ConstInput(op_constants);
402 uint32_t neg_id = NegateConstant(const_mgr, c);
403 uint32_t non_const_id = zero_is_variable
404 ? op_inst->GetSingleWordInOperand(0u)
405 : op_inst->GetSingleWordInOperand(1u);
406 // Change this instruction to a mul/div.
407 inst->SetOpcode(op_inst->opcode());
408 if (opcode == SpvOpFDiv || opcode == SpvOpUDiv || opcode == SpvOpSDiv) {
409 uint32_t op0 = zero_is_variable ? non_const_id : neg_id;
410 uint32_t op1 = zero_is_variable ? neg_id : non_const_id;
411 inst->SetInOperands(
412 {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}});
413 } else {
414 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
415 {SPV_OPERAND_TYPE_ID, {neg_id}}});
416 }
417 return true;
418 }
419 }
420
421 return false;
422 };
423 }
424
425 // Merges negate into a add or sub operation if that operation contains a
426 // constant operand.
427 // Cases:
428 // -(x + 2) = -2 - x
429 // -(2 + x) = -2 - x
430 // -(x - 2) = 2 - x
431 // -(2 - x) = x - 2
MergeNegateAddSubArithmetic()432 FoldingRule MergeNegateAddSubArithmetic() {
433 return [](IRContext* context, Instruction* inst,
434 const std::vector<const analysis::Constant*>& constants) {
435 assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate);
436 (void)constants;
437 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
438 const analysis::Type* type =
439 context->get_type_mgr()->GetType(inst->type_id());
440 if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
441 return false;
442
443 Instruction* op_inst =
444 context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
445 if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
446 return false;
447
448 uint32_t width = ElementWidth(type);
449 if (width != 32 && width != 64) return false;
450
451 if (op_inst->opcode() == SpvOpFAdd || op_inst->opcode() == SpvOpFSub ||
452 op_inst->opcode() == SpvOpIAdd || op_inst->opcode() == SpvOpISub) {
453 std::vector<const analysis::Constant*> op_constants =
454 const_mgr->GetOperandConstants(op_inst);
455 if (op_constants[0] || op_constants[1]) {
456 bool zero_is_variable = op_constants[0] == nullptr;
457 bool is_add = (op_inst->opcode() == SpvOpFAdd) ||
458 (op_inst->opcode() == SpvOpIAdd);
459 bool swap_operands = !is_add || zero_is_variable;
460 bool negate_const = is_add;
461 const analysis::Constant* c = ConstInput(op_constants);
462 uint32_t const_id = 0;
463 if (negate_const) {
464 const_id = NegateConstant(const_mgr, c);
465 } else {
466 const_id = zero_is_variable ? op_inst->GetSingleWordInOperand(1u)
467 : op_inst->GetSingleWordInOperand(0u);
468 }
469
470 // Swap operands if necessary and make the instruction a subtraction.
471 uint32_t op0 =
472 zero_is_variable ? op_inst->GetSingleWordInOperand(0u) : const_id;
473 uint32_t op1 =
474 zero_is_variable ? const_id : op_inst->GetSingleWordInOperand(1u);
475 if (swap_operands) std::swap(op0, op1);
476 inst->SetOpcode(HasFloatingPoint(type) ? SpvOpFSub : SpvOpISub);
477 inst->SetInOperands(
478 {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}});
479 return true;
480 }
481 }
482
483 return false;
484 };
485 }
486
487 // Returns true if |c| has a zero element.
HasZero(const analysis::Constant * c)488 bool HasZero(const analysis::Constant* c) {
489 if (c->AsNullConstant()) {
490 return true;
491 }
492 if (const analysis::VectorConstant* vec_const = c->AsVectorConstant()) {
493 for (auto& comp : vec_const->GetComponents())
494 if (HasZero(comp)) return true;
495 } else {
496 assert(c->AsScalarConstant());
497 return c->AsScalarConstant()->IsZero();
498 }
499
500 return false;
501 }
502
503 // Performs |input1| |opcode| |input2| and returns the merged constant result
504 // id. Returns 0 if the result is not a valid value. The input types must be
505 // Float.
PerformFloatingPointOperation(analysis::ConstantManager * const_mgr,SpvOp opcode,const analysis::Constant * input1,const analysis::Constant * input2)506 uint32_t PerformFloatingPointOperation(analysis::ConstantManager* const_mgr,
507 SpvOp opcode,
508 const analysis::Constant* input1,
509 const analysis::Constant* input2) {
510 const analysis::Type* type = input1->type();
511 assert(type->AsFloat());
512 uint32_t width = type->AsFloat()->width();
513 assert(width == 32 || width == 64);
514 std::vector<uint32_t> words;
515 #define FOLD_OP(op) \
516 if (width == 64) { \
517 utils::FloatProxy<double> val = \
518 input1->GetDouble() op input2->GetDouble(); \
519 double dval = val.getAsFloat(); \
520 if (!IsValidResult(dval)) return 0; \
521 words = val.GetWords(); \
522 } else { \
523 utils::FloatProxy<float> val = input1->GetFloat() op input2->GetFloat(); \
524 float fval = val.getAsFloat(); \
525 if (!IsValidResult(fval)) return 0; \
526 words = val.GetWords(); \
527 } \
528 static_assert(true, "require extra semicolon")
529 switch (opcode) {
530 case SpvOpFMul:
531 FOLD_OP(*);
532 break;
533 case SpvOpFDiv:
534 if (HasZero(input2)) return 0;
535 FOLD_OP(/);
536 break;
537 case SpvOpFAdd:
538 FOLD_OP(+);
539 break;
540 case SpvOpFSub:
541 FOLD_OP(-);
542 break;
543 default:
544 assert(false && "Unexpected operation");
545 break;
546 }
547 #undef FOLD_OP
548 const analysis::Constant* merged_const = const_mgr->GetConstant(type, words);
549 return const_mgr->GetDefiningInstruction(merged_const)->result_id();
550 }
551
552 // Performs |input1| |opcode| |input2| and returns the merged constant result
553 // id. Returns 0 if the result is not a valid value. The input types must be
554 // Integers.
PerformIntegerOperation(analysis::ConstantManager * const_mgr,SpvOp opcode,const analysis::Constant * input1,const analysis::Constant * input2)555 uint32_t PerformIntegerOperation(analysis::ConstantManager* const_mgr,
556 SpvOp opcode, const analysis::Constant* input1,
557 const analysis::Constant* input2) {
558 assert(input1->type()->AsInteger());
559 const analysis::Integer* type = input1->type()->AsInteger();
560 uint32_t width = type->AsInteger()->width();
561 assert(width == 32 || width == 64);
562 std::vector<uint32_t> words;
563 // Regardless of the sign of the constant, folding is performed on an unsigned
564 // interpretation of the constant data. This avoids signed integer overflow
565 // while folding, and works because sign is irrelevant for the IAdd, ISub and
566 // IMul instructions.
567 #define FOLD_OP(op) \
568 if (width == 64) { \
569 uint64_t val = input1->GetU64() op input2->GetU64(); \
570 words = ExtractInts(val); \
571 } else { \
572 uint32_t val = input1->GetU32() op input2->GetU32(); \
573 words.push_back(val); \
574 } \
575 static_assert(true, "require extra semicolon")
576 switch (opcode) {
577 case SpvOpIMul:
578 FOLD_OP(*);
579 break;
580 case SpvOpSDiv:
581 case SpvOpUDiv:
582 assert(false && "Should not merge integer division");
583 break;
584 case SpvOpIAdd:
585 FOLD_OP(+);
586 break;
587 case SpvOpISub:
588 FOLD_OP(-);
589 break;
590 default:
591 assert(false && "Unexpected operation");
592 break;
593 }
594 #undef FOLD_OP
595 const analysis::Constant* merged_const = const_mgr->GetConstant(type, words);
596 return const_mgr->GetDefiningInstruction(merged_const)->result_id();
597 }
598
599 // Performs |input1| |opcode| |input2| and returns the merged constant result
600 // id. Returns 0 if the result is not a valid value. The input types must be
601 // Integers, Floats or Vectors of such.
PerformOperation(analysis::ConstantManager * const_mgr,SpvOp opcode,const analysis::Constant * input1,const analysis::Constant * input2)602 uint32_t PerformOperation(analysis::ConstantManager* const_mgr, SpvOp opcode,
603 const analysis::Constant* input1,
604 const analysis::Constant* input2) {
605 assert(input1 && input2);
606 const analysis::Type* type = input1->type();
607 std::vector<uint32_t> words;
608 if (const analysis::Vector* vector_type = type->AsVector()) {
609 const analysis::Type* ele_type = vector_type->element_type();
610 for (uint32_t i = 0; i != vector_type->element_count(); ++i) {
611 uint32_t id = 0;
612
613 const analysis::Constant* input1_comp = nullptr;
614 if (const analysis::VectorConstant* input1_vector =
615 input1->AsVectorConstant()) {
616 input1_comp = input1_vector->GetComponents()[i];
617 } else {
618 assert(input1->AsNullConstant());
619 input1_comp = const_mgr->GetConstant(ele_type, {});
620 }
621
622 const analysis::Constant* input2_comp = nullptr;
623 if (const analysis::VectorConstant* input2_vector =
624 input2->AsVectorConstant()) {
625 input2_comp = input2_vector->GetComponents()[i];
626 } else {
627 assert(input2->AsNullConstant());
628 input2_comp = const_mgr->GetConstant(ele_type, {});
629 }
630
631 if (ele_type->AsFloat()) {
632 id = PerformFloatingPointOperation(const_mgr, opcode, input1_comp,
633 input2_comp);
634 } else {
635 assert(ele_type->AsInteger());
636 id = PerformIntegerOperation(const_mgr, opcode, input1_comp,
637 input2_comp);
638 }
639 if (id == 0) return 0;
640 words.push_back(id);
641 }
642 const analysis::Constant* merged_const =
643 const_mgr->GetConstant(type, words);
644 return const_mgr->GetDefiningInstruction(merged_const)->result_id();
645 } else if (type->AsFloat()) {
646 return PerformFloatingPointOperation(const_mgr, opcode, input1, input2);
647 } else {
648 assert(type->AsInteger());
649 return PerformIntegerOperation(const_mgr, opcode, input1, input2);
650 }
651 }
652
653 // Merges consecutive multiplies where each contains one constant operand.
654 // Cases:
655 // 2 * (x * 2) = x * 4
656 // 2 * (2 * x) = x * 4
657 // (x * 2) * 2 = x * 4
658 // (2 * x) * 2 = x * 4
MergeMulMulArithmetic()659 FoldingRule MergeMulMulArithmetic() {
660 return [](IRContext* context, Instruction* inst,
661 const std::vector<const analysis::Constant*>& constants) {
662 assert(inst->opcode() == SpvOpFMul || inst->opcode() == SpvOpIMul);
663 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
664 const analysis::Type* type =
665 context->get_type_mgr()->GetType(inst->type_id());
666 if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
667 return false;
668
669 uint32_t width = ElementWidth(type);
670 if (width != 32 && width != 64) return false;
671
672 // Determine the constant input and the variable input in |inst|.
673 const analysis::Constant* const_input1 = ConstInput(constants);
674 if (!const_input1) return false;
675 Instruction* other_inst = NonConstInput(context, constants[0], inst);
676 if (HasFloatingPoint(type) && !other_inst->IsFloatingPointFoldingAllowed())
677 return false;
678
679 if (other_inst->opcode() == inst->opcode()) {
680 std::vector<const analysis::Constant*> other_constants =
681 const_mgr->GetOperandConstants(other_inst);
682 const analysis::Constant* const_input2 = ConstInput(other_constants);
683 if (!const_input2) return false;
684
685 bool other_first_is_variable = other_constants[0] == nullptr;
686 uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
687 const_input1, const_input2);
688 if (merged_id == 0) return false;
689
690 uint32_t non_const_id = other_first_is_variable
691 ? other_inst->GetSingleWordInOperand(0u)
692 : other_inst->GetSingleWordInOperand(1u);
693 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
694 {SPV_OPERAND_TYPE_ID, {merged_id}}});
695 return true;
696 }
697
698 return false;
699 };
700 }
701
702 // Merges divides into subsequent multiplies if each instruction contains one
703 // constant operand. Does not support integer operations.
704 // Cases:
705 // 2 * (x / 2) = x * 1
706 // 2 * (2 / x) = 4 / x
707 // (x / 2) * 2 = x * 1
708 // (2 / x) * 2 = 4 / x
709 // (y / x) * x = y
710 // x * (y / x) = y
MergeMulDivArithmetic()711 FoldingRule MergeMulDivArithmetic() {
712 return [](IRContext* context, Instruction* inst,
713 const std::vector<const analysis::Constant*>& constants) {
714 assert(inst->opcode() == SpvOpFMul);
715 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
716 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
717
718 const analysis::Type* type =
719 context->get_type_mgr()->GetType(inst->type_id());
720 if (!inst->IsFloatingPointFoldingAllowed()) return false;
721
722 uint32_t width = ElementWidth(type);
723 if (width != 32 && width != 64) return false;
724
725 for (uint32_t i = 0; i < 2; i++) {
726 uint32_t op_id = inst->GetSingleWordInOperand(i);
727 Instruction* op_inst = def_use_mgr->GetDef(op_id);
728 if (op_inst->opcode() == SpvOpFDiv) {
729 if (op_inst->GetSingleWordInOperand(1) ==
730 inst->GetSingleWordInOperand(1 - i)) {
731 inst->SetOpcode(SpvOpCopyObject);
732 inst->SetInOperands(
733 {{SPV_OPERAND_TYPE_ID, {op_inst->GetSingleWordInOperand(0)}}});
734 return true;
735 }
736 }
737 }
738
739 const analysis::Constant* const_input1 = ConstInput(constants);
740 if (!const_input1) return false;
741 Instruction* other_inst = NonConstInput(context, constants[0], inst);
742 if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
743
744 if (other_inst->opcode() == SpvOpFDiv) {
745 std::vector<const analysis::Constant*> other_constants =
746 const_mgr->GetOperandConstants(other_inst);
747 const analysis::Constant* const_input2 = ConstInput(other_constants);
748 if (!const_input2 || HasZero(const_input2)) return false;
749
750 bool other_first_is_variable = other_constants[0] == nullptr;
751 // If the variable value is the second operand of the divide, multiply
752 // the constants together. Otherwise divide the constants.
753 uint32_t merged_id = PerformOperation(
754 const_mgr,
755 other_first_is_variable ? other_inst->opcode() : inst->opcode(),
756 const_input1, const_input2);
757 if (merged_id == 0) return false;
758
759 uint32_t non_const_id = other_first_is_variable
760 ? other_inst->GetSingleWordInOperand(0u)
761 : other_inst->GetSingleWordInOperand(1u);
762
763 // If the variable value is on the second operand of the div, then this
764 // operation is a div. Otherwise it should be a multiply.
765 inst->SetOpcode(other_first_is_variable ? inst->opcode()
766 : other_inst->opcode());
767 if (other_first_is_variable) {
768 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
769 {SPV_OPERAND_TYPE_ID, {merged_id}}});
770 } else {
771 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {merged_id}},
772 {SPV_OPERAND_TYPE_ID, {non_const_id}}});
773 }
774 return true;
775 }
776
777 return false;
778 };
779 }
780
781 // Merges multiply of constant and negation.
782 // Cases:
783 // (-x) * 2 = x * -2
784 // 2 * (-x) = x * -2
MergeMulNegateArithmetic()785 FoldingRule MergeMulNegateArithmetic() {
786 return [](IRContext* context, Instruction* inst,
787 const std::vector<const analysis::Constant*>& constants) {
788 assert(inst->opcode() == SpvOpFMul || inst->opcode() == SpvOpIMul);
789 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
790 const analysis::Type* type =
791 context->get_type_mgr()->GetType(inst->type_id());
792 bool uses_float = HasFloatingPoint(type);
793 if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
794
795 uint32_t width = ElementWidth(type);
796 if (width != 32 && width != 64) return false;
797
798 const analysis::Constant* const_input1 = ConstInput(constants);
799 if (!const_input1) return false;
800 Instruction* other_inst = NonConstInput(context, constants[0], inst);
801 if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
802 return false;
803
804 if (other_inst->opcode() == SpvOpFNegate ||
805 other_inst->opcode() == SpvOpSNegate) {
806 uint32_t neg_id = NegateConstant(const_mgr, const_input1);
807
808 inst->SetInOperands(
809 {{SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}},
810 {SPV_OPERAND_TYPE_ID, {neg_id}}});
811 return true;
812 }
813
814 return false;
815 };
816 }
817
818 // Merges consecutive divides if each instruction contains one constant operand.
819 // Does not support integer division.
820 // Cases:
821 // 2 / (x / 2) = 4 / x
822 // 4 / (2 / x) = 2 * x
823 // (4 / x) / 2 = 2 / x
824 // (x / 2) / 2 = x / 4
MergeDivDivArithmetic()825 FoldingRule MergeDivDivArithmetic() {
826 return [](IRContext* context, Instruction* inst,
827 const std::vector<const analysis::Constant*>& constants) {
828 assert(inst->opcode() == SpvOpFDiv);
829 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
830 const analysis::Type* type =
831 context->get_type_mgr()->GetType(inst->type_id());
832 if (!inst->IsFloatingPointFoldingAllowed()) return false;
833
834 uint32_t width = ElementWidth(type);
835 if (width != 32 && width != 64) return false;
836
837 const analysis::Constant* const_input1 = ConstInput(constants);
838 if (!const_input1 || HasZero(const_input1)) return false;
839 Instruction* other_inst = NonConstInput(context, constants[0], inst);
840 if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
841
842 bool first_is_variable = constants[0] == nullptr;
843 if (other_inst->opcode() == inst->opcode()) {
844 std::vector<const analysis::Constant*> other_constants =
845 const_mgr->GetOperandConstants(other_inst);
846 const analysis::Constant* const_input2 = ConstInput(other_constants);
847 if (!const_input2 || HasZero(const_input2)) return false;
848
849 bool other_first_is_variable = other_constants[0] == nullptr;
850
851 SpvOp merge_op = inst->opcode();
852 if (other_first_is_variable) {
853 // Constants magnify.
854 merge_op = SpvOpFMul;
855 }
856
857 // This is an x / (*) case. Swap the inputs. Doesn't harm multiply
858 // because it is commutative.
859 if (first_is_variable) std::swap(const_input1, const_input2);
860 uint32_t merged_id =
861 PerformOperation(const_mgr, merge_op, const_input1, const_input2);
862 if (merged_id == 0) return false;
863
864 uint32_t non_const_id = other_first_is_variable
865 ? other_inst->GetSingleWordInOperand(0u)
866 : other_inst->GetSingleWordInOperand(1u);
867
868 SpvOp op = inst->opcode();
869 if (!first_is_variable && !other_first_is_variable) {
870 // Effectively div of 1/x, so change to multiply.
871 op = SpvOpFMul;
872 }
873
874 uint32_t op1 = merged_id;
875 uint32_t op2 = non_const_id;
876 if (first_is_variable && other_first_is_variable) std::swap(op1, op2);
877 inst->SetOpcode(op);
878 inst->SetInOperands(
879 {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
880 return true;
881 }
882
883 return false;
884 };
885 }
886
887 // Fold multiplies succeeded by divides where each instruction contains a
888 // constant operand. Does not support integer divide.
889 // Cases:
890 // 4 / (x * 2) = 2 / x
891 // 4 / (2 * x) = 2 / x
892 // (x * 4) / 2 = x * 2
893 // (4 * x) / 2 = x * 2
894 // (x * y) / x = y
895 // (y * x) / x = y
MergeDivMulArithmetic()896 FoldingRule MergeDivMulArithmetic() {
897 return [](IRContext* context, Instruction* inst,
898 const std::vector<const analysis::Constant*>& constants) {
899 assert(inst->opcode() == SpvOpFDiv);
900 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
901 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
902
903 const analysis::Type* type =
904 context->get_type_mgr()->GetType(inst->type_id());
905 if (!inst->IsFloatingPointFoldingAllowed()) return false;
906
907 uint32_t width = ElementWidth(type);
908 if (width != 32 && width != 64) return false;
909
910 uint32_t op_id = inst->GetSingleWordInOperand(0);
911 Instruction* op_inst = def_use_mgr->GetDef(op_id);
912
913 if (op_inst->opcode() == SpvOpFMul) {
914 for (uint32_t i = 0; i < 2; i++) {
915 if (op_inst->GetSingleWordInOperand(i) ==
916 inst->GetSingleWordInOperand(1)) {
917 inst->SetOpcode(SpvOpCopyObject);
918 inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
919 {op_inst->GetSingleWordInOperand(1 - i)}}});
920 return true;
921 }
922 }
923 }
924
925 const analysis::Constant* const_input1 = ConstInput(constants);
926 if (!const_input1 || HasZero(const_input1)) return false;
927 Instruction* other_inst = NonConstInput(context, constants[0], inst);
928 if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
929
930 bool first_is_variable = constants[0] == nullptr;
931 if (other_inst->opcode() == SpvOpFMul) {
932 std::vector<const analysis::Constant*> other_constants =
933 const_mgr->GetOperandConstants(other_inst);
934 const analysis::Constant* const_input2 = ConstInput(other_constants);
935 if (!const_input2) return false;
936
937 bool other_first_is_variable = other_constants[0] == nullptr;
938
939 // This is an x / (*) case. Swap the inputs.
940 if (first_is_variable) std::swap(const_input1, const_input2);
941 uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
942 const_input1, const_input2);
943 if (merged_id == 0) return false;
944
945 uint32_t non_const_id = other_first_is_variable
946 ? other_inst->GetSingleWordInOperand(0u)
947 : other_inst->GetSingleWordInOperand(1u);
948
949 uint32_t op1 = merged_id;
950 uint32_t op2 = non_const_id;
951 if (first_is_variable) std::swap(op1, op2);
952
953 // Convert to multiply
954 if (first_is_variable) inst->SetOpcode(other_inst->opcode());
955 inst->SetInOperands(
956 {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
957 return true;
958 }
959
960 return false;
961 };
962 }
963
964 // Fold divides of a constant and a negation.
965 // Cases:
966 // (-x) / 2 = x / -2
967 // 2 / (-x) = -2 / x
MergeDivNegateArithmetic()968 FoldingRule MergeDivNegateArithmetic() {
969 return [](IRContext* context, Instruction* inst,
970 const std::vector<const analysis::Constant*>& constants) {
971 assert(inst->opcode() == SpvOpFDiv);
972 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
973 if (!inst->IsFloatingPointFoldingAllowed()) return false;
974
975 const analysis::Constant* const_input1 = ConstInput(constants);
976 if (!const_input1) return false;
977 Instruction* other_inst = NonConstInput(context, constants[0], inst);
978 if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
979
980 bool first_is_variable = constants[0] == nullptr;
981 if (other_inst->opcode() == SpvOpFNegate) {
982 uint32_t neg_id = NegateConstant(const_mgr, const_input1);
983
984 if (first_is_variable) {
985 inst->SetInOperands(
986 {{SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}},
987 {SPV_OPERAND_TYPE_ID, {neg_id}}});
988 } else {
989 inst->SetInOperands(
990 {{SPV_OPERAND_TYPE_ID, {neg_id}},
991 {SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}});
992 }
993 return true;
994 }
995
996 return false;
997 };
998 }
999
1000 // Folds addition of a constant and a negation.
1001 // Cases:
1002 // (-x) + 2 = 2 - x
1003 // 2 + (-x) = 2 - x
MergeAddNegateArithmetic()1004 FoldingRule MergeAddNegateArithmetic() {
1005 return [](IRContext* context, Instruction* inst,
1006 const std::vector<const analysis::Constant*>& constants) {
1007 assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
1008 const analysis::Type* type =
1009 context->get_type_mgr()->GetType(inst->type_id());
1010 bool uses_float = HasFloatingPoint(type);
1011 if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1012
1013 const analysis::Constant* const_input1 = ConstInput(constants);
1014 if (!const_input1) return false;
1015 Instruction* other_inst = NonConstInput(context, constants[0], inst);
1016 if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1017 return false;
1018
1019 if (other_inst->opcode() == SpvOpSNegate ||
1020 other_inst->opcode() == SpvOpFNegate) {
1021 inst->SetOpcode(HasFloatingPoint(type) ? SpvOpFSub : SpvOpISub);
1022 uint32_t const_id = constants[0] ? inst->GetSingleWordInOperand(0u)
1023 : inst->GetSingleWordInOperand(1u);
1024 inst->SetInOperands(
1025 {{SPV_OPERAND_TYPE_ID, {const_id}},
1026 {SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}});
1027 return true;
1028 }
1029 return false;
1030 };
1031 }
1032
1033 // Folds subtraction of a constant and a negation.
1034 // Cases:
1035 // (-x) - 2 = -2 - x
1036 // 2 - (-x) = x + 2
MergeSubNegateArithmetic()1037 FoldingRule MergeSubNegateArithmetic() {
1038 return [](IRContext* context, Instruction* inst,
1039 const std::vector<const analysis::Constant*>& constants) {
1040 assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
1041 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1042 const analysis::Type* type =
1043 context->get_type_mgr()->GetType(inst->type_id());
1044 bool uses_float = HasFloatingPoint(type);
1045 if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1046
1047 uint32_t width = ElementWidth(type);
1048 if (width != 32 && width != 64) return false;
1049
1050 const analysis::Constant* const_input1 = ConstInput(constants);
1051 if (!const_input1) return false;
1052 Instruction* other_inst = NonConstInput(context, constants[0], inst);
1053 if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1054 return false;
1055
1056 if (other_inst->opcode() == SpvOpSNegate ||
1057 other_inst->opcode() == SpvOpFNegate) {
1058 uint32_t op1 = 0;
1059 uint32_t op2 = 0;
1060 SpvOp opcode = inst->opcode();
1061 if (constants[0] != nullptr) {
1062 op1 = other_inst->GetSingleWordInOperand(0u);
1063 op2 = inst->GetSingleWordInOperand(0u);
1064 opcode = HasFloatingPoint(type) ? SpvOpFAdd : SpvOpIAdd;
1065 } else {
1066 op1 = NegateConstant(const_mgr, const_input1);
1067 op2 = other_inst->GetSingleWordInOperand(0u);
1068 }
1069
1070 inst->SetOpcode(opcode);
1071 inst->SetInOperands(
1072 {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
1073 return true;
1074 }
1075 return false;
1076 };
1077 }
1078
1079 // Folds addition of an addition where each operation has a constant operand.
1080 // Cases:
1081 // (x + 2) + 2 = x + 4
1082 // (2 + x) + 2 = x + 4
1083 // 2 + (x + 2) = x + 4
1084 // 2 + (2 + x) = x + 4
MergeAddAddArithmetic()1085 FoldingRule MergeAddAddArithmetic() {
1086 return [](IRContext* context, Instruction* inst,
1087 const std::vector<const analysis::Constant*>& constants) {
1088 assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
1089 const analysis::Type* type =
1090 context->get_type_mgr()->GetType(inst->type_id());
1091 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1092 bool uses_float = HasFloatingPoint(type);
1093 if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1094
1095 uint32_t width = ElementWidth(type);
1096 if (width != 32 && width != 64) return false;
1097
1098 const analysis::Constant* const_input1 = ConstInput(constants);
1099 if (!const_input1) return false;
1100 Instruction* other_inst = NonConstInput(context, constants[0], inst);
1101 if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1102 return false;
1103
1104 if (other_inst->opcode() == SpvOpFAdd ||
1105 other_inst->opcode() == SpvOpIAdd) {
1106 std::vector<const analysis::Constant*> other_constants =
1107 const_mgr->GetOperandConstants(other_inst);
1108 const analysis::Constant* const_input2 = ConstInput(other_constants);
1109 if (!const_input2) return false;
1110
1111 Instruction* non_const_input =
1112 NonConstInput(context, other_constants[0], other_inst);
1113 uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
1114 const_input1, const_input2);
1115 if (merged_id == 0) return false;
1116
1117 inst->SetInOperands(
1118 {{SPV_OPERAND_TYPE_ID, {non_const_input->result_id()}},
1119 {SPV_OPERAND_TYPE_ID, {merged_id}}});
1120 return true;
1121 }
1122 return false;
1123 };
1124 }
1125
1126 // Folds addition of a subtraction where each operation has a constant operand.
1127 // Cases:
1128 // (x - 2) + 2 = x + 0
1129 // (2 - x) + 2 = 4 - x
1130 // 2 + (x - 2) = x + 0
1131 // 2 + (2 - x) = 4 - x
MergeAddSubArithmetic()1132 FoldingRule MergeAddSubArithmetic() {
1133 return [](IRContext* context, Instruction* inst,
1134 const std::vector<const analysis::Constant*>& constants) {
1135 assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
1136 const analysis::Type* type =
1137 context->get_type_mgr()->GetType(inst->type_id());
1138 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1139 bool uses_float = HasFloatingPoint(type);
1140 if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1141
1142 uint32_t width = ElementWidth(type);
1143 if (width != 32 && width != 64) return false;
1144
1145 const analysis::Constant* const_input1 = ConstInput(constants);
1146 if (!const_input1) return false;
1147 Instruction* other_inst = NonConstInput(context, constants[0], inst);
1148 if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1149 return false;
1150
1151 if (other_inst->opcode() == SpvOpFSub ||
1152 other_inst->opcode() == SpvOpISub) {
1153 std::vector<const analysis::Constant*> other_constants =
1154 const_mgr->GetOperandConstants(other_inst);
1155 const analysis::Constant* const_input2 = ConstInput(other_constants);
1156 if (!const_input2) return false;
1157
1158 bool first_is_variable = other_constants[0] == nullptr;
1159 SpvOp op = inst->opcode();
1160 uint32_t op1 = 0;
1161 uint32_t op2 = 0;
1162 if (first_is_variable) {
1163 // Subtract constants. Non-constant operand is first.
1164 op1 = other_inst->GetSingleWordInOperand(0u);
1165 op2 = PerformOperation(const_mgr, other_inst->opcode(), const_input1,
1166 const_input2);
1167 } else {
1168 // Add constants. Constant operand is first. Change the opcode.
1169 op1 = PerformOperation(const_mgr, inst->opcode(), const_input1,
1170 const_input2);
1171 op2 = other_inst->GetSingleWordInOperand(1u);
1172 op = other_inst->opcode();
1173 }
1174 if (op1 == 0 || op2 == 0) return false;
1175
1176 inst->SetOpcode(op);
1177 inst->SetInOperands(
1178 {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
1179 return true;
1180 }
1181 return false;
1182 };
1183 }
1184
1185 // Folds subtraction of an addition where each operand has a constant operand.
1186 // Cases:
1187 // (x + 2) - 2 = x + 0
1188 // (2 + x) - 2 = x + 0
1189 // 2 - (x + 2) = 0 - x
1190 // 2 - (2 + x) = 0 - x
MergeSubAddArithmetic()1191 FoldingRule MergeSubAddArithmetic() {
1192 return [](IRContext* context, Instruction* inst,
1193 const std::vector<const analysis::Constant*>& constants) {
1194 assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
1195 const analysis::Type* type =
1196 context->get_type_mgr()->GetType(inst->type_id());
1197 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1198 bool uses_float = HasFloatingPoint(type);
1199 if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1200
1201 uint32_t width = ElementWidth(type);
1202 if (width != 32 && width != 64) return false;
1203
1204 const analysis::Constant* const_input1 = ConstInput(constants);
1205 if (!const_input1) return false;
1206 Instruction* other_inst = NonConstInput(context, constants[0], inst);
1207 if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1208 return false;
1209
1210 if (other_inst->opcode() == SpvOpFAdd ||
1211 other_inst->opcode() == SpvOpIAdd) {
1212 std::vector<const analysis::Constant*> other_constants =
1213 const_mgr->GetOperandConstants(other_inst);
1214 const analysis::Constant* const_input2 = ConstInput(other_constants);
1215 if (!const_input2) return false;
1216
1217 Instruction* non_const_input =
1218 NonConstInput(context, other_constants[0], other_inst);
1219
1220 // If the first operand of the sub is not a constant, swap the constants
1221 // so the subtraction has the correct operands.
1222 if (constants[0] == nullptr) std::swap(const_input1, const_input2);
1223 // Subtract the constants.
1224 uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
1225 const_input1, const_input2);
1226 SpvOp op = inst->opcode();
1227 uint32_t op1 = 0;
1228 uint32_t op2 = 0;
1229 if (constants[0] == nullptr) {
1230 // Non-constant operand is first. Change the opcode.
1231 op1 = non_const_input->result_id();
1232 op2 = merged_id;
1233 op = other_inst->opcode();
1234 } else {
1235 // Constant operand is first.
1236 op1 = merged_id;
1237 op2 = non_const_input->result_id();
1238 }
1239 if (op1 == 0 || op2 == 0) return false;
1240
1241 inst->SetOpcode(op);
1242 inst->SetInOperands(
1243 {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
1244 return true;
1245 }
1246 return false;
1247 };
1248 }
1249
1250 // Folds subtraction of a subtraction where each operand has a constant operand.
1251 // Cases:
1252 // (x - 2) - 2 = x - 4
1253 // (2 - x) - 2 = 0 - x
1254 // 2 - (x - 2) = 4 - x
1255 // 2 - (2 - x) = x + 0
MergeSubSubArithmetic()1256 FoldingRule MergeSubSubArithmetic() {
1257 return [](IRContext* context, Instruction* inst,
1258 const std::vector<const analysis::Constant*>& constants) {
1259 assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
1260 const analysis::Type* type =
1261 context->get_type_mgr()->GetType(inst->type_id());
1262 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1263 bool uses_float = HasFloatingPoint(type);
1264 if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1265
1266 uint32_t width = ElementWidth(type);
1267 if (width != 32 && width != 64) return false;
1268
1269 const analysis::Constant* const_input1 = ConstInput(constants);
1270 if (!const_input1) return false;
1271 Instruction* other_inst = NonConstInput(context, constants[0], inst);
1272 if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1273 return false;
1274
1275 if (other_inst->opcode() == SpvOpFSub ||
1276 other_inst->opcode() == SpvOpISub) {
1277 std::vector<const analysis::Constant*> other_constants =
1278 const_mgr->GetOperandConstants(other_inst);
1279 const analysis::Constant* const_input2 = ConstInput(other_constants);
1280 if (!const_input2) return false;
1281
1282 Instruction* non_const_input =
1283 NonConstInput(context, other_constants[0], other_inst);
1284
1285 // Merge the constants.
1286 uint32_t merged_id = 0;
1287 SpvOp merge_op = inst->opcode();
1288 if (other_constants[0] == nullptr) {
1289 merge_op = uses_float ? SpvOpFAdd : SpvOpIAdd;
1290 } else if (constants[0] == nullptr) {
1291 std::swap(const_input1, const_input2);
1292 }
1293 merged_id =
1294 PerformOperation(const_mgr, merge_op, const_input1, const_input2);
1295 if (merged_id == 0) return false;
1296
1297 SpvOp op = inst->opcode();
1298 if (constants[0] != nullptr && other_constants[0] != nullptr) {
1299 // Change the operation.
1300 op = uses_float ? SpvOpFAdd : SpvOpIAdd;
1301 }
1302
1303 uint32_t op1 = 0;
1304 uint32_t op2 = 0;
1305 if ((constants[0] == nullptr) ^ (other_constants[0] == nullptr)) {
1306 op1 = merged_id;
1307 op2 = non_const_input->result_id();
1308 } else {
1309 op1 = non_const_input->result_id();
1310 op2 = merged_id;
1311 }
1312
1313 inst->SetOpcode(op);
1314 inst->SetInOperands(
1315 {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
1316 return true;
1317 }
1318 return false;
1319 };
1320 }
1321
1322 // Helper function for MergeGenericAddSubArithmetic. If |addend| and
1323 // subtrahend of |sub| is the same, merge to copy of minuend of |sub|.
MergeGenericAddendSub(uint32_t addend,uint32_t sub,Instruction * inst)1324 bool MergeGenericAddendSub(uint32_t addend, uint32_t sub, Instruction* inst) {
1325 IRContext* context = inst->context();
1326 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1327 Instruction* sub_inst = def_use_mgr->GetDef(sub);
1328 if (sub_inst->opcode() != SpvOpFSub && sub_inst->opcode() != SpvOpISub)
1329 return false;
1330 if (sub_inst->opcode() == SpvOpFSub &&
1331 !sub_inst->IsFloatingPointFoldingAllowed())
1332 return false;
1333 if (addend != sub_inst->GetSingleWordInOperand(1)) return false;
1334 inst->SetOpcode(SpvOpCopyObject);
1335 inst->SetInOperands(
1336 {{SPV_OPERAND_TYPE_ID, {sub_inst->GetSingleWordInOperand(0)}}});
1337 context->UpdateDefUse(inst);
1338 return true;
1339 }
1340
1341 // Folds addition of a subtraction where the subtrahend is equal to the
1342 // other addend. Return a copy of the minuend. Accepts generic (const and
1343 // non-const) operands.
1344 // Cases:
1345 // (a - b) + b = a
1346 // b + (a - b) = a
MergeGenericAddSubArithmetic()1347 FoldingRule MergeGenericAddSubArithmetic() {
1348 return [](IRContext* context, Instruction* inst,
1349 const std::vector<const analysis::Constant*>&) {
1350 assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
1351 const analysis::Type* type =
1352 context->get_type_mgr()->GetType(inst->type_id());
1353 bool uses_float = HasFloatingPoint(type);
1354 if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1355
1356 uint32_t width = ElementWidth(type);
1357 if (width != 32 && width != 64) return false;
1358
1359 uint32_t add_op0 = inst->GetSingleWordInOperand(0);
1360 uint32_t add_op1 = inst->GetSingleWordInOperand(1);
1361 if (MergeGenericAddendSub(add_op0, add_op1, inst)) return true;
1362 return MergeGenericAddendSub(add_op1, add_op0, inst);
1363 };
1364 }
1365
1366 // Helper function for FactorAddMuls. If |factor0_0| is the same as |factor1_0|,
1367 // generate |factor0_0| * (|factor0_1| + |factor1_1|).
FactorAddMulsOpnds(uint32_t factor0_0,uint32_t factor0_1,uint32_t factor1_0,uint32_t factor1_1,Instruction * inst)1368 bool FactorAddMulsOpnds(uint32_t factor0_0, uint32_t factor0_1,
1369 uint32_t factor1_0, uint32_t factor1_1,
1370 Instruction* inst) {
1371 IRContext* context = inst->context();
1372 if (factor0_0 != factor1_0) return false;
1373 InstructionBuilder ir_builder(
1374 context, inst,
1375 IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
1376 Instruction* new_add_inst = ir_builder.AddBinaryOp(
1377 inst->type_id(), inst->opcode(), factor0_1, factor1_1);
1378 inst->SetOpcode(inst->opcode() == SpvOpFAdd ? SpvOpFMul : SpvOpIMul);
1379 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {factor0_0}},
1380 {SPV_OPERAND_TYPE_ID, {new_add_inst->result_id()}}});
1381 context->UpdateDefUse(inst);
1382 return true;
1383 }
1384
1385 // Perform the following factoring identity, handling all operand order
1386 // combinations: (a * b) + (a * c) = a * (b + c)
FactorAddMuls()1387 FoldingRule FactorAddMuls() {
1388 return [](IRContext* context, Instruction* inst,
1389 const std::vector<const analysis::Constant*>&) {
1390 assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
1391 const analysis::Type* type =
1392 context->get_type_mgr()->GetType(inst->type_id());
1393 bool uses_float = HasFloatingPoint(type);
1394 if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1395
1396 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1397 uint32_t add_op0 = inst->GetSingleWordInOperand(0);
1398 Instruction* add_op0_inst = def_use_mgr->GetDef(add_op0);
1399 if (add_op0_inst->opcode() != SpvOpFMul &&
1400 add_op0_inst->opcode() != SpvOpIMul)
1401 return false;
1402 uint32_t add_op1 = inst->GetSingleWordInOperand(1);
1403 Instruction* add_op1_inst = def_use_mgr->GetDef(add_op1);
1404 if (add_op1_inst->opcode() != SpvOpFMul &&
1405 add_op1_inst->opcode() != SpvOpIMul)
1406 return false;
1407
1408 // Only perform this optimization if both of the muls only have one use.
1409 // Otherwise this is a deoptimization in size and performance.
1410 if (def_use_mgr->NumUses(add_op0_inst) > 1) return false;
1411 if (def_use_mgr->NumUses(add_op1_inst) > 1) return false;
1412
1413 if (add_op0_inst->opcode() == SpvOpFMul &&
1414 (!add_op0_inst->IsFloatingPointFoldingAllowed() ||
1415 !add_op1_inst->IsFloatingPointFoldingAllowed()))
1416 return false;
1417
1418 for (int i = 0; i < 2; i++) {
1419 for (int j = 0; j < 2; j++) {
1420 // Check if operand i in add_op0_inst matches operand j in add_op1_inst.
1421 if (FactorAddMulsOpnds(add_op0_inst->GetSingleWordInOperand(i),
1422 add_op0_inst->GetSingleWordInOperand(1 - i),
1423 add_op1_inst->GetSingleWordInOperand(j),
1424 add_op1_inst->GetSingleWordInOperand(1 - j),
1425 inst))
1426 return true;
1427 }
1428 }
1429 return false;
1430 };
1431 }
1432
IntMultipleBy1()1433 FoldingRule IntMultipleBy1() {
1434 return [](IRContext*, Instruction* inst,
1435 const std::vector<const analysis::Constant*>& constants) {
1436 assert(inst->opcode() == SpvOpIMul && "Wrong opcode. Should be OpIMul.");
1437 for (uint32_t i = 0; i < 2; i++) {
1438 if (constants[i] == nullptr) {
1439 continue;
1440 }
1441 const analysis::IntConstant* int_constant = constants[i]->AsIntConstant();
1442 if (int_constant) {
1443 uint32_t width = ElementWidth(int_constant->type());
1444 if (width != 32 && width != 64) return false;
1445 bool is_one = (width == 32) ? int_constant->GetU32BitValue() == 1u
1446 : int_constant->GetU64BitValue() == 1ull;
1447 if (is_one) {
1448 inst->SetOpcode(SpvOpCopyObject);
1449 inst->SetInOperands(
1450 {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1 - i)}}});
1451 return true;
1452 }
1453 }
1454 }
1455 return false;
1456 };
1457 }
1458
1459 // Returns the number of elements that the |index|th in operand in |inst|
1460 // contributes to the result of |inst|. |inst| must be an
1461 // OpCompositeConstructInstruction.
GetNumOfElementsContributedByOperand(IRContext * context,const Instruction * inst,uint32_t index)1462 uint32_t GetNumOfElementsContributedByOperand(IRContext* context,
1463 const Instruction* inst,
1464 uint32_t index) {
1465 assert(inst->opcode() == SpvOpCompositeConstruct);
1466 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1467 analysis::TypeManager* type_mgr = context->get_type_mgr();
1468
1469 analysis::Vector* result_type =
1470 type_mgr->GetType(inst->type_id())->AsVector();
1471 if (result_type == nullptr) {
1472 // If the result of the OpCompositeConstruct is not a vector then every
1473 // operands corresponds to a single element in the result.
1474 return 1;
1475 }
1476
1477 // If the result type is a vector then the operands are either scalars or
1478 // vectors. If it is a scalar, then it corresponds to a single element. If it
1479 // is a vector, then each element in the vector will be an element in the
1480 // result.
1481 uint32_t id = inst->GetSingleWordInOperand(index);
1482 Instruction* def = def_use_mgr->GetDef(id);
1483 analysis::Vector* type = type_mgr->GetType(def->type_id())->AsVector();
1484 if (type == nullptr) {
1485 return 1;
1486 }
1487 return type->element_count();
1488 }
1489
1490 // Returns the in-operands for an OpCompositeExtract instruction that are needed
1491 // to extract the |result_index|th element in the result of |inst| without using
1492 // the result of |inst|. Returns the empty vector if |result_index| is
1493 // out-of-bounds. |inst| must be an |OpCompositeConstruct| instruction.
GetExtractOperandsForElementOfCompositeConstruct(IRContext * context,const Instruction * inst,uint32_t result_index)1494 std::vector<Operand> GetExtractOperandsForElementOfCompositeConstruct(
1495 IRContext* context, const Instruction* inst, uint32_t result_index) {
1496 assert(inst->opcode() == SpvOpCompositeConstruct);
1497 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1498 analysis::TypeManager* type_mgr = context->get_type_mgr();
1499
1500 analysis::Type* result_type = type_mgr->GetType(inst->type_id());
1501 if (result_type->AsVector() == nullptr) {
1502 uint32_t id = inst->GetSingleWordInOperand(result_index);
1503 return {Operand(SPV_OPERAND_TYPE_ID, {id})};
1504 }
1505
1506 // If the result type is a vector, then vector operands are concatenated.
1507 uint32_t total_element_count = 0;
1508 for (uint32_t idx = 0; idx < inst->NumInOperands(); ++idx) {
1509 uint32_t element_count =
1510 GetNumOfElementsContributedByOperand(context, inst, idx);
1511 total_element_count += element_count;
1512 if (result_index < total_element_count) {
1513 std::vector<Operand> operands;
1514 uint32_t id = inst->GetSingleWordInOperand(idx);
1515 Instruction* operand_def = def_use_mgr->GetDef(id);
1516 analysis::Type* operand_type = type_mgr->GetType(operand_def->type_id());
1517
1518 operands.push_back({SPV_OPERAND_TYPE_ID, {id}});
1519 if (operand_type->AsVector()) {
1520 uint32_t start_index_of_id = total_element_count - element_count;
1521 uint32_t index_into_id = result_index - start_index_of_id;
1522 operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {index_into_id}});
1523 }
1524 return operands;
1525 }
1526 }
1527 return {};
1528 }
1529
CompositeConstructFeedingExtract(IRContext * context,Instruction * inst,const std::vector<const analysis::Constant * > &)1530 bool CompositeConstructFeedingExtract(
1531 IRContext* context, Instruction* inst,
1532 const std::vector<const analysis::Constant*>&) {
1533 // If the input to an OpCompositeExtract is an OpCompositeConstruct,
1534 // then we can simply use the appropriate element in the construction.
1535 assert(inst->opcode() == SpvOpCompositeExtract &&
1536 "Wrong opcode. Should be OpCompositeExtract.");
1537 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1538
1539 // If there are no index operands, then this rule cannot do anything.
1540 if (inst->NumInOperands() <= 1) {
1541 return false;
1542 }
1543
1544 uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1545 Instruction* cinst = def_use_mgr->GetDef(cid);
1546
1547 if (cinst->opcode() != SpvOpCompositeConstruct) {
1548 return false;
1549 }
1550
1551 uint32_t index_into_result = inst->GetSingleWordInOperand(1);
1552 std::vector<Operand> operands =
1553 GetExtractOperandsForElementOfCompositeConstruct(context, cinst,
1554 index_into_result);
1555
1556 if (operands.empty()) {
1557 return false;
1558 }
1559
1560 // Add the remaining indices for extraction.
1561 for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
1562 operands.push_back(
1563 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {inst->GetSingleWordInOperand(i)}});
1564 }
1565
1566 if (operands.size() == 1) {
1567 // If there were no extra indices, then we have the final object. No need
1568 // to extract any more.
1569 inst->SetOpcode(SpvOpCopyObject);
1570 }
1571
1572 inst->SetInOperands(std::move(operands));
1573 return true;
1574 }
1575
1576 // If the OpCompositeConstruct is simply putting back together elements that
1577 // where extracted from the same source, we can simply reuse the source.
1578 //
1579 // This is a common code pattern because of the way that scalar replacement
1580 // works.
CompositeExtractFeedingConstruct(IRContext * context,Instruction * inst,const std::vector<const analysis::Constant * > &)1581 bool CompositeExtractFeedingConstruct(
1582 IRContext* context, Instruction* inst,
1583 const std::vector<const analysis::Constant*>&) {
1584 assert(inst->opcode() == SpvOpCompositeConstruct &&
1585 "Wrong opcode. Should be OpCompositeConstruct.");
1586 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1587 uint32_t original_id = 0;
1588
1589 if (inst->NumInOperands() == 0) {
1590 // The struct being constructed has no members.
1591 return false;
1592 }
1593
1594 // Check each element to make sure they are:
1595 // - extractions
1596 // - extracting the same position they are inserting
1597 // - all extract from the same id.
1598 for (uint32_t i = 0; i < inst->NumInOperands(); ++i) {
1599 const uint32_t element_id = inst->GetSingleWordInOperand(i);
1600 Instruction* element_inst = def_use_mgr->GetDef(element_id);
1601
1602 if (element_inst->opcode() != SpvOpCompositeExtract) {
1603 return false;
1604 }
1605
1606 if (element_inst->NumInOperands() != 2) {
1607 return false;
1608 }
1609
1610 if (element_inst->GetSingleWordInOperand(1) != i) {
1611 return false;
1612 }
1613
1614 if (i == 0) {
1615 original_id =
1616 element_inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1617 } else if (original_id !=
1618 element_inst->GetSingleWordInOperand(kExtractCompositeIdInIdx)) {
1619 return false;
1620 }
1621 }
1622
1623 // The last check it to see that the object being extracted from is the
1624 // correct type.
1625 Instruction* original_inst = def_use_mgr->GetDef(original_id);
1626 if (original_inst->type_id() != inst->type_id()) {
1627 return false;
1628 }
1629
1630 // Simplify by using the original object.
1631 inst->SetOpcode(SpvOpCopyObject);
1632 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {original_id}}});
1633 return true;
1634 }
1635
InsertFeedingExtract()1636 FoldingRule InsertFeedingExtract() {
1637 return [](IRContext* context, Instruction* inst,
1638 const std::vector<const analysis::Constant*>&) {
1639 assert(inst->opcode() == SpvOpCompositeExtract &&
1640 "Wrong opcode. Should be OpCompositeExtract.");
1641 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1642 uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1643 Instruction* cinst = def_use_mgr->GetDef(cid);
1644
1645 if (cinst->opcode() != SpvOpCompositeInsert) {
1646 return false;
1647 }
1648
1649 // Find the first position where the list of insert and extract indicies
1650 // differ, if at all.
1651 uint32_t i;
1652 for (i = 1; i < inst->NumInOperands(); ++i) {
1653 if (i + 1 >= cinst->NumInOperands()) {
1654 break;
1655 }
1656
1657 if (inst->GetSingleWordInOperand(i) !=
1658 cinst->GetSingleWordInOperand(i + 1)) {
1659 break;
1660 }
1661 }
1662
1663 // We are extracting the element that was inserted.
1664 if (i == inst->NumInOperands() && i + 1 == cinst->NumInOperands()) {
1665 inst->SetOpcode(SpvOpCopyObject);
1666 inst->SetInOperands(
1667 {{SPV_OPERAND_TYPE_ID,
1668 {cinst->GetSingleWordInOperand(kInsertObjectIdInIdx)}}});
1669 return true;
1670 }
1671
1672 // Extracting the value that was inserted along with values for the base
1673 // composite. Cannot do anything.
1674 if (i == inst->NumInOperands()) {
1675 return false;
1676 }
1677
1678 // Extracting an element of the value that was inserted. Extract from
1679 // that value directly.
1680 if (i + 1 == cinst->NumInOperands()) {
1681 std::vector<Operand> operands;
1682 operands.push_back(
1683 {SPV_OPERAND_TYPE_ID,
1684 {cinst->GetSingleWordInOperand(kInsertObjectIdInIdx)}});
1685 for (; i < inst->NumInOperands(); ++i) {
1686 operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
1687 {inst->GetSingleWordInOperand(i)}});
1688 }
1689 inst->SetInOperands(std::move(operands));
1690 return true;
1691 }
1692
1693 // Extracting a value that is disjoint from the element being inserted.
1694 // Rewrite the extract to use the composite input to the insert.
1695 std::vector<Operand> operands;
1696 operands.push_back(
1697 {SPV_OPERAND_TYPE_ID,
1698 {cinst->GetSingleWordInOperand(kInsertCompositeIdInIdx)}});
1699 for (i = 1; i < inst->NumInOperands(); ++i) {
1700 operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
1701 {inst->GetSingleWordInOperand(i)}});
1702 }
1703 inst->SetInOperands(std::move(operands));
1704 return true;
1705 };
1706 }
1707
1708 // When a VectorShuffle is feeding an Extract, we can extract from one of the
1709 // operands of the VectorShuffle. We just need to adjust the index in the
1710 // extract instruction.
VectorShuffleFeedingExtract()1711 FoldingRule VectorShuffleFeedingExtract() {
1712 return [](IRContext* context, Instruction* inst,
1713 const std::vector<const analysis::Constant*>&) {
1714 assert(inst->opcode() == SpvOpCompositeExtract &&
1715 "Wrong opcode. Should be OpCompositeExtract.");
1716 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1717 analysis::TypeManager* type_mgr = context->get_type_mgr();
1718 uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1719 Instruction* cinst = def_use_mgr->GetDef(cid);
1720
1721 if (cinst->opcode() != SpvOpVectorShuffle) {
1722 return false;
1723 }
1724
1725 // Find the size of the first vector operand of the VectorShuffle
1726 Instruction* first_input =
1727 def_use_mgr->GetDef(cinst->GetSingleWordInOperand(0));
1728 analysis::Type* first_input_type =
1729 type_mgr->GetType(first_input->type_id());
1730 assert(first_input_type->AsVector() &&
1731 "Input to vector shuffle should be vectors.");
1732 uint32_t first_input_size = first_input_type->AsVector()->element_count();
1733
1734 // Get index of the element the vector shuffle is placing in the position
1735 // being extracted.
1736 uint32_t new_index =
1737 cinst->GetSingleWordInOperand(2 + inst->GetSingleWordInOperand(1));
1738
1739 // Extracting an undefined value so fold this extract into an undef.
1740 const uint32_t undef_literal_value = 0xffffffff;
1741 if (new_index == undef_literal_value) {
1742 inst->SetOpcode(SpvOpUndef);
1743 inst->SetInOperands({});
1744 return true;
1745 }
1746
1747 // Get the id of the of the vector the elemtent comes from, and update the
1748 // index if needed.
1749 uint32_t new_vector = 0;
1750 if (new_index < first_input_size) {
1751 new_vector = cinst->GetSingleWordInOperand(0);
1752 } else {
1753 new_vector = cinst->GetSingleWordInOperand(1);
1754 new_index -= first_input_size;
1755 }
1756
1757 // Update the extract instruction.
1758 inst->SetInOperand(kExtractCompositeIdInIdx, {new_vector});
1759 inst->SetInOperand(1, {new_index});
1760 return true;
1761 };
1762 }
1763
1764 // When an FMix with is feeding an Extract that extracts an element whose
1765 // corresponding |a| in the FMix is 0 or 1, we can extract from one of the
1766 // operands of the FMix.
FMixFeedingExtract()1767 FoldingRule FMixFeedingExtract() {
1768 return [](IRContext* context, Instruction* inst,
1769 const std::vector<const analysis::Constant*>&) {
1770 assert(inst->opcode() == SpvOpCompositeExtract &&
1771 "Wrong opcode. Should be OpCompositeExtract.");
1772 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1773 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1774
1775 uint32_t composite_id =
1776 inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1777 Instruction* composite_inst = def_use_mgr->GetDef(composite_id);
1778
1779 if (composite_inst->opcode() != SpvOpExtInst) {
1780 return false;
1781 }
1782
1783 uint32_t inst_set_id =
1784 context->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
1785
1786 if (composite_inst->GetSingleWordInOperand(kExtInstSetIdInIdx) !=
1787 inst_set_id ||
1788 composite_inst->GetSingleWordInOperand(kExtInstInstructionInIdx) !=
1789 GLSLstd450FMix) {
1790 return false;
1791 }
1792
1793 // Get the |a| for the FMix instruction.
1794 uint32_t a_id = composite_inst->GetSingleWordInOperand(kFMixAIdInIdx);
1795 std::unique_ptr<Instruction> a(inst->Clone(context));
1796 a->SetInOperand(kExtractCompositeIdInIdx, {a_id});
1797 context->get_instruction_folder().FoldInstruction(a.get());
1798
1799 if (a->opcode() != SpvOpCopyObject) {
1800 return false;
1801 }
1802
1803 const analysis::Constant* a_const =
1804 const_mgr->FindDeclaredConstant(a->GetSingleWordInOperand(0));
1805
1806 if (!a_const) {
1807 return false;
1808 }
1809
1810 bool use_x = false;
1811
1812 assert(a_const->type()->AsFloat());
1813 double element_value = a_const->GetValueAsDouble();
1814 if (element_value == 0.0) {
1815 use_x = true;
1816 } else if (element_value == 1.0) {
1817 use_x = false;
1818 } else {
1819 return false;
1820 }
1821
1822 // Get the id of the of the vector the element comes from.
1823 uint32_t new_vector = 0;
1824 if (use_x) {
1825 new_vector = composite_inst->GetSingleWordInOperand(kFMixXIdInIdx);
1826 } else {
1827 new_vector = composite_inst->GetSingleWordInOperand(kFMixYIdInIdx);
1828 }
1829
1830 // Update the extract instruction.
1831 inst->SetInOperand(kExtractCompositeIdInIdx, {new_vector});
1832 return true;
1833 };
1834 }
1835
RedundantPhi()1836 FoldingRule RedundantPhi() {
1837 // An OpPhi instruction where all values are the same or the result of the phi
1838 // itself, can be replaced by the value itself.
1839 return [](IRContext*, Instruction* inst,
1840 const std::vector<const analysis::Constant*>&) {
1841 assert(inst->opcode() == SpvOpPhi && "Wrong opcode. Should be OpPhi.");
1842
1843 uint32_t incoming_value = 0;
1844
1845 for (uint32_t i = 0; i < inst->NumInOperands(); i += 2) {
1846 uint32_t op_id = inst->GetSingleWordInOperand(i);
1847 if (op_id == inst->result_id()) {
1848 continue;
1849 }
1850
1851 if (incoming_value == 0) {
1852 incoming_value = op_id;
1853 } else if (op_id != incoming_value) {
1854 // Found two possible value. Can't simplify.
1855 return false;
1856 }
1857 }
1858
1859 if (incoming_value == 0) {
1860 // Code looks invalid. Don't do anything.
1861 return false;
1862 }
1863
1864 // We have a single incoming value. Simplify using that value.
1865 inst->SetOpcode(SpvOpCopyObject);
1866 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {incoming_value}}});
1867 return true;
1868 };
1869 }
1870
BitCastScalarOrVector()1871 FoldingRule BitCastScalarOrVector() {
1872 return [](IRContext* context, Instruction* inst,
1873 const std::vector<const analysis::Constant*>& constants) {
1874 assert(inst->opcode() == SpvOpBitcast && constants.size() == 1);
1875 if (constants[0] == nullptr) return false;
1876
1877 const analysis::Type* type =
1878 context->get_type_mgr()->GetType(inst->type_id());
1879 if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
1880 return false;
1881
1882 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1883 std::vector<uint32_t> words =
1884 GetWordsFromNumericScalarOrVectorConstant(const_mgr, constants[0]);
1885 if (words.size() == 0) return false;
1886
1887 const analysis::Constant* bitcasted_constant =
1888 ConvertWordsToNumericScalarOrVectorConstant(const_mgr, words, type);
1889 if (!bitcasted_constant) return false;
1890
1891 auto new_feeder_id =
1892 const_mgr->GetDefiningInstruction(bitcasted_constant, inst->type_id())
1893 ->result_id();
1894 inst->SetOpcode(SpvOpCopyObject);
1895 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {new_feeder_id}}});
1896 return true;
1897 };
1898 }
1899
RedundantSelect()1900 FoldingRule RedundantSelect() {
1901 // An OpSelect instruction where both values are the same or the condition is
1902 // constant can be replaced by one of the values
1903 return [](IRContext*, Instruction* inst,
1904 const std::vector<const analysis::Constant*>& constants) {
1905 assert(inst->opcode() == SpvOpSelect &&
1906 "Wrong opcode. Should be OpSelect.");
1907 assert(inst->NumInOperands() == 3);
1908 assert(constants.size() == 3);
1909
1910 uint32_t true_id = inst->GetSingleWordInOperand(1);
1911 uint32_t false_id = inst->GetSingleWordInOperand(2);
1912
1913 if (true_id == false_id) {
1914 // Both results are the same, condition doesn't matter
1915 inst->SetOpcode(SpvOpCopyObject);
1916 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}});
1917 return true;
1918 } else if (constants[0]) {
1919 const analysis::Type* type = constants[0]->type();
1920 if (type->AsBool()) {
1921 // Scalar constant value, select the corresponding value.
1922 inst->SetOpcode(SpvOpCopyObject);
1923 if (constants[0]->AsNullConstant() ||
1924 !constants[0]->AsBoolConstant()->value()) {
1925 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {false_id}}});
1926 } else {
1927 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}});
1928 }
1929 return true;
1930 } else {
1931 assert(type->AsVector());
1932 if (constants[0]->AsNullConstant()) {
1933 // All values come from false id.
1934 inst->SetOpcode(SpvOpCopyObject);
1935 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {false_id}}});
1936 return true;
1937 } else {
1938 // Convert to a vector shuffle.
1939 std::vector<Operand> ops;
1940 ops.push_back({SPV_OPERAND_TYPE_ID, {true_id}});
1941 ops.push_back({SPV_OPERAND_TYPE_ID, {false_id}});
1942 const analysis::VectorConstant* vector_const =
1943 constants[0]->AsVectorConstant();
1944 uint32_t size =
1945 static_cast<uint32_t>(vector_const->GetComponents().size());
1946 for (uint32_t i = 0; i != size; ++i) {
1947 const analysis::Constant* component =
1948 vector_const->GetComponents()[i];
1949 if (component->AsNullConstant() ||
1950 !component->AsBoolConstant()->value()) {
1951 // Selecting from the false vector which is the second input
1952 // vector to the shuffle. Offset the index by |size|.
1953 ops.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i + size}});
1954 } else {
1955 // Selecting from true vector which is the first input vector to
1956 // the shuffle.
1957 ops.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}});
1958 }
1959 }
1960
1961 inst->SetOpcode(SpvOpVectorShuffle);
1962 inst->SetInOperands(std::move(ops));
1963 return true;
1964 }
1965 }
1966 }
1967
1968 return false;
1969 };
1970 }
1971
1972 enum class FloatConstantKind { Unknown, Zero, One };
1973
getFloatConstantKind(const analysis::Constant * constant)1974 FloatConstantKind getFloatConstantKind(const analysis::Constant* constant) {
1975 if (constant == nullptr) {
1976 return FloatConstantKind::Unknown;
1977 }
1978
1979 assert(HasFloatingPoint(constant->type()) && "Unexpected constant type");
1980
1981 if (constant->AsNullConstant()) {
1982 return FloatConstantKind::Zero;
1983 } else if (const analysis::VectorConstant* vc =
1984 constant->AsVectorConstant()) {
1985 const std::vector<const analysis::Constant*>& components =
1986 vc->GetComponents();
1987 assert(!components.empty());
1988
1989 FloatConstantKind kind = getFloatConstantKind(components[0]);
1990
1991 for (size_t i = 1; i < components.size(); ++i) {
1992 if (getFloatConstantKind(components[i]) != kind) {
1993 return FloatConstantKind::Unknown;
1994 }
1995 }
1996
1997 return kind;
1998 } else if (const analysis::FloatConstant* fc = constant->AsFloatConstant()) {
1999 if (fc->IsZero()) return FloatConstantKind::Zero;
2000
2001 uint32_t width = fc->type()->AsFloat()->width();
2002 if (width != 32 && width != 64) return FloatConstantKind::Unknown;
2003
2004 double value = (width == 64) ? fc->GetDoubleValue() : fc->GetFloatValue();
2005
2006 if (value == 0.0) {
2007 return FloatConstantKind::Zero;
2008 } else if (value == 1.0) {
2009 return FloatConstantKind::One;
2010 } else {
2011 return FloatConstantKind::Unknown;
2012 }
2013 } else {
2014 return FloatConstantKind::Unknown;
2015 }
2016 }
2017
RedundantFAdd()2018 FoldingRule RedundantFAdd() {
2019 return [](IRContext*, Instruction* inst,
2020 const std::vector<const analysis::Constant*>& constants) {
2021 assert(inst->opcode() == SpvOpFAdd && "Wrong opcode. Should be OpFAdd.");
2022 assert(constants.size() == 2);
2023
2024 if (!inst->IsFloatingPointFoldingAllowed()) {
2025 return false;
2026 }
2027
2028 FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
2029 FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
2030
2031 if (kind0 == FloatConstantKind::Zero || kind1 == FloatConstantKind::Zero) {
2032 inst->SetOpcode(SpvOpCopyObject);
2033 inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
2034 {inst->GetSingleWordInOperand(
2035 kind0 == FloatConstantKind::Zero ? 1 : 0)}}});
2036 return true;
2037 }
2038
2039 return false;
2040 };
2041 }
2042
RedundantFSub()2043 FoldingRule RedundantFSub() {
2044 return [](IRContext*, Instruction* inst,
2045 const std::vector<const analysis::Constant*>& constants) {
2046 assert(inst->opcode() == SpvOpFSub && "Wrong opcode. Should be OpFSub.");
2047 assert(constants.size() == 2);
2048
2049 if (!inst->IsFloatingPointFoldingAllowed()) {
2050 return false;
2051 }
2052
2053 FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
2054 FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
2055
2056 if (kind0 == FloatConstantKind::Zero) {
2057 inst->SetOpcode(SpvOpFNegate);
2058 inst->SetInOperands(
2059 {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1)}}});
2060 return true;
2061 }
2062
2063 if (kind1 == FloatConstantKind::Zero) {
2064 inst->SetOpcode(SpvOpCopyObject);
2065 inst->SetInOperands(
2066 {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
2067 return true;
2068 }
2069
2070 return false;
2071 };
2072 }
2073
RedundantFMul()2074 FoldingRule RedundantFMul() {
2075 return [](IRContext*, Instruction* inst,
2076 const std::vector<const analysis::Constant*>& constants) {
2077 assert(inst->opcode() == SpvOpFMul && "Wrong opcode. Should be OpFMul.");
2078 assert(constants.size() == 2);
2079
2080 if (!inst->IsFloatingPointFoldingAllowed()) {
2081 return false;
2082 }
2083
2084 FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
2085 FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
2086
2087 if (kind0 == FloatConstantKind::Zero || kind1 == FloatConstantKind::Zero) {
2088 inst->SetOpcode(SpvOpCopyObject);
2089 inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
2090 {inst->GetSingleWordInOperand(
2091 kind0 == FloatConstantKind::Zero ? 0 : 1)}}});
2092 return true;
2093 }
2094
2095 if (kind0 == FloatConstantKind::One || kind1 == FloatConstantKind::One) {
2096 inst->SetOpcode(SpvOpCopyObject);
2097 inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
2098 {inst->GetSingleWordInOperand(
2099 kind0 == FloatConstantKind::One ? 1 : 0)}}});
2100 return true;
2101 }
2102
2103 return false;
2104 };
2105 }
2106
RedundantFDiv()2107 FoldingRule RedundantFDiv() {
2108 return [](IRContext*, Instruction* inst,
2109 const std::vector<const analysis::Constant*>& constants) {
2110 assert(inst->opcode() == SpvOpFDiv && "Wrong opcode. Should be OpFDiv.");
2111 assert(constants.size() == 2);
2112
2113 if (!inst->IsFloatingPointFoldingAllowed()) {
2114 return false;
2115 }
2116
2117 FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
2118 FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
2119
2120 if (kind0 == FloatConstantKind::Zero) {
2121 inst->SetOpcode(SpvOpCopyObject);
2122 inst->SetInOperands(
2123 {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
2124 return true;
2125 }
2126
2127 if (kind1 == FloatConstantKind::One) {
2128 inst->SetOpcode(SpvOpCopyObject);
2129 inst->SetInOperands(
2130 {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
2131 return true;
2132 }
2133
2134 return false;
2135 };
2136 }
2137
RedundantFMix()2138 FoldingRule RedundantFMix() {
2139 return [](IRContext* context, Instruction* inst,
2140 const std::vector<const analysis::Constant*>& constants) {
2141 assert(inst->opcode() == SpvOpExtInst &&
2142 "Wrong opcode. Should be OpExtInst.");
2143
2144 if (!inst->IsFloatingPointFoldingAllowed()) {
2145 return false;
2146 }
2147
2148 uint32_t instSetId =
2149 context->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
2150
2151 if (inst->GetSingleWordInOperand(kExtInstSetIdInIdx) == instSetId &&
2152 inst->GetSingleWordInOperand(kExtInstInstructionInIdx) ==
2153 GLSLstd450FMix) {
2154 assert(constants.size() == 5);
2155
2156 FloatConstantKind kind4 = getFloatConstantKind(constants[4]);
2157
2158 if (kind4 == FloatConstantKind::Zero || kind4 == FloatConstantKind::One) {
2159 inst->SetOpcode(SpvOpCopyObject);
2160 inst->SetInOperands(
2161 {{SPV_OPERAND_TYPE_ID,
2162 {inst->GetSingleWordInOperand(kind4 == FloatConstantKind::Zero
2163 ? kFMixXIdInIdx
2164 : kFMixYIdInIdx)}}});
2165 return true;
2166 }
2167 }
2168
2169 return false;
2170 };
2171 }
2172
2173 // This rule handles addition of zero for integers.
RedundantIAdd()2174 FoldingRule RedundantIAdd() {
2175 return [](IRContext* context, Instruction* inst,
2176 const std::vector<const analysis::Constant*>& constants) {
2177 assert(inst->opcode() == SpvOpIAdd && "Wrong opcode. Should be OpIAdd.");
2178
2179 uint32_t operand = std::numeric_limits<uint32_t>::max();
2180 const analysis::Type* operand_type = nullptr;
2181 if (constants[0] && constants[0]->IsZero()) {
2182 operand = inst->GetSingleWordInOperand(1);
2183 operand_type = constants[0]->type();
2184 } else if (constants[1] && constants[1]->IsZero()) {
2185 operand = inst->GetSingleWordInOperand(0);
2186 operand_type = constants[1]->type();
2187 }
2188
2189 if (operand != std::numeric_limits<uint32_t>::max()) {
2190 const analysis::Type* inst_type =
2191 context->get_type_mgr()->GetType(inst->type_id());
2192 if (inst_type->IsSame(operand_type)) {
2193 inst->SetOpcode(SpvOpCopyObject);
2194 } else {
2195 inst->SetOpcode(SpvOpBitcast);
2196 }
2197 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {operand}}});
2198 return true;
2199 }
2200 return false;
2201 };
2202 }
2203
2204 // This rule look for a dot with a constant vector containing a single 1 and
2205 // the rest 0s. This is the same as doing an extract.
DotProductDoingExtract()2206 FoldingRule DotProductDoingExtract() {
2207 return [](IRContext* context, Instruction* inst,
2208 const std::vector<const analysis::Constant*>& constants) {
2209 assert(inst->opcode() == SpvOpDot && "Wrong opcode. Should be OpDot.");
2210
2211 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
2212
2213 if (!inst->IsFloatingPointFoldingAllowed()) {
2214 return false;
2215 }
2216
2217 for (int i = 0; i < 2; ++i) {
2218 if (!constants[i]) {
2219 continue;
2220 }
2221
2222 const analysis::Vector* vector_type = constants[i]->type()->AsVector();
2223 assert(vector_type && "Inputs to OpDot must be vectors.");
2224 const analysis::Float* element_type =
2225 vector_type->element_type()->AsFloat();
2226 assert(element_type && "Inputs to OpDot must be vectors of floats.");
2227 uint32_t element_width = element_type->width();
2228 if (element_width != 32 && element_width != 64) {
2229 return false;
2230 }
2231
2232 std::vector<const analysis::Constant*> components;
2233 components = constants[i]->GetVectorComponents(const_mgr);
2234
2235 const uint32_t kNotFound = std::numeric_limits<uint32_t>::max();
2236
2237 uint32_t component_with_one = kNotFound;
2238 bool all_others_zero = true;
2239 for (uint32_t j = 0; j < components.size(); ++j) {
2240 const analysis::Constant* element = components[j];
2241 double value =
2242 (element_width == 32 ? element->GetFloat() : element->GetDouble());
2243 if (value == 0.0) {
2244 continue;
2245 } else if (value == 1.0) {
2246 if (component_with_one == kNotFound) {
2247 component_with_one = j;
2248 } else {
2249 component_with_one = kNotFound;
2250 break;
2251 }
2252 } else {
2253 all_others_zero = false;
2254 break;
2255 }
2256 }
2257
2258 if (!all_others_zero || component_with_one == kNotFound) {
2259 continue;
2260 }
2261
2262 std::vector<Operand> operands;
2263 operands.push_back(
2264 {SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1u - i)}});
2265 operands.push_back(
2266 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {component_with_one}});
2267
2268 inst->SetOpcode(SpvOpCompositeExtract);
2269 inst->SetInOperands(std::move(operands));
2270 return true;
2271 }
2272 return false;
2273 };
2274 }
2275
2276 // If we are storing an undef, then we can remove the store.
2277 //
2278 // TODO: We can do something similar for OpImageWrite, but checking for volatile
2279 // is complicated. Waiting to see if it is needed.
StoringUndef()2280 FoldingRule StoringUndef() {
2281 return [](IRContext* context, Instruction* inst,
2282 const std::vector<const analysis::Constant*>&) {
2283 assert(inst->opcode() == SpvOpStore && "Wrong opcode. Should be OpStore.");
2284
2285 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
2286
2287 // If this is a volatile store, the store cannot be removed.
2288 if (inst->NumInOperands() == 3) {
2289 if (inst->GetSingleWordInOperand(2) & SpvMemoryAccessVolatileMask) {
2290 return false;
2291 }
2292 }
2293
2294 uint32_t object_id = inst->GetSingleWordInOperand(kStoreObjectInIdx);
2295 Instruction* object_inst = def_use_mgr->GetDef(object_id);
2296 if (object_inst->opcode() == SpvOpUndef) {
2297 inst->ToNop();
2298 return true;
2299 }
2300 return false;
2301 };
2302 }
2303
VectorShuffleFeedingShuffle()2304 FoldingRule VectorShuffleFeedingShuffle() {
2305 return [](IRContext* context, Instruction* inst,
2306 const std::vector<const analysis::Constant*>&) {
2307 assert(inst->opcode() == SpvOpVectorShuffle &&
2308 "Wrong opcode. Should be OpVectorShuffle.");
2309
2310 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
2311 analysis::TypeManager* type_mgr = context->get_type_mgr();
2312
2313 Instruction* feeding_shuffle_inst =
2314 def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
2315 analysis::Vector* op0_type =
2316 type_mgr->GetType(feeding_shuffle_inst->type_id())->AsVector();
2317 uint32_t op0_length = op0_type->element_count();
2318
2319 bool feeder_is_op0 = true;
2320 if (feeding_shuffle_inst->opcode() != SpvOpVectorShuffle) {
2321 feeding_shuffle_inst =
2322 def_use_mgr->GetDef(inst->GetSingleWordInOperand(1));
2323 feeder_is_op0 = false;
2324 }
2325
2326 if (feeding_shuffle_inst->opcode() != SpvOpVectorShuffle) {
2327 return false;
2328 }
2329
2330 Instruction* feeder2 =
2331 def_use_mgr->GetDef(feeding_shuffle_inst->GetSingleWordInOperand(0));
2332 analysis::Vector* feeder_op0_type =
2333 type_mgr->GetType(feeder2->type_id())->AsVector();
2334 uint32_t feeder_op0_length = feeder_op0_type->element_count();
2335
2336 uint32_t new_feeder_id = 0;
2337 std::vector<Operand> new_operands;
2338 new_operands.resize(
2339 2, {SPV_OPERAND_TYPE_ID, {0}}); // Place holders for vector operands.
2340 const uint32_t undef_literal = 0xffffffff;
2341 for (uint32_t op = 2; op < inst->NumInOperands(); ++op) {
2342 uint32_t component_index = inst->GetSingleWordInOperand(op);
2343
2344 // Do not interpret the undefined value literal as coming from operand 1.
2345 if (component_index != undef_literal &&
2346 feeder_is_op0 == (component_index < op0_length)) {
2347 // This component comes from the feeding_shuffle_inst. Update
2348 // |component_index| to be the index into the operand of the feeder.
2349
2350 // Adjust component_index to get the index into the operands of the
2351 // feeding_shuffle_inst.
2352 if (component_index >= op0_length) {
2353 component_index -= op0_length;
2354 }
2355 component_index =
2356 feeding_shuffle_inst->GetSingleWordInOperand(component_index + 2);
2357
2358 // Check if we are using a component from the first or second operand of
2359 // the feeding instruction.
2360 if (component_index < feeder_op0_length) {
2361 if (new_feeder_id == 0) {
2362 // First time through, save the id of the operand the element comes
2363 // from.
2364 new_feeder_id = feeding_shuffle_inst->GetSingleWordInOperand(0);
2365 } else if (new_feeder_id !=
2366 feeding_shuffle_inst->GetSingleWordInOperand(0)) {
2367 // We need both elements of the feeding_shuffle_inst, so we cannot
2368 // fold.
2369 return false;
2370 }
2371 } else {
2372 if (new_feeder_id == 0) {
2373 // First time through, save the id of the operand the element comes
2374 // from.
2375 new_feeder_id = feeding_shuffle_inst->GetSingleWordInOperand(1);
2376 } else if (new_feeder_id !=
2377 feeding_shuffle_inst->GetSingleWordInOperand(1)) {
2378 // We need both elements of the feeding_shuffle_inst, so we cannot
2379 // fold.
2380 return false;
2381 }
2382 component_index -= feeder_op0_length;
2383 }
2384
2385 if (!feeder_is_op0) {
2386 component_index += op0_length;
2387 }
2388 }
2389 new_operands.push_back(
2390 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {component_index}});
2391 }
2392
2393 if (new_feeder_id == 0) {
2394 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
2395 const analysis::Type* type =
2396 type_mgr->GetType(feeding_shuffle_inst->type_id());
2397 const analysis::Constant* null_const = const_mgr->GetConstant(type, {});
2398 new_feeder_id =
2399 const_mgr->GetDefiningInstruction(null_const, 0)->result_id();
2400 }
2401
2402 if (feeder_is_op0) {
2403 // If the size of the first vector operand changed then the indices
2404 // referring to the second operand need to be adjusted.
2405 Instruction* new_feeder_inst = def_use_mgr->GetDef(new_feeder_id);
2406 analysis::Type* new_feeder_type =
2407 type_mgr->GetType(new_feeder_inst->type_id());
2408 uint32_t new_op0_size = new_feeder_type->AsVector()->element_count();
2409 int32_t adjustment = op0_length - new_op0_size;
2410
2411 if (adjustment != 0) {
2412 for (uint32_t i = 2; i < new_operands.size(); i++) {
2413 if (inst->GetSingleWordInOperand(i) >= op0_length) {
2414 new_operands[i].words[0] -= adjustment;
2415 }
2416 }
2417 }
2418
2419 new_operands[0].words[0] = new_feeder_id;
2420 new_operands[1] = inst->GetInOperand(1);
2421 } else {
2422 new_operands[1].words[0] = new_feeder_id;
2423 new_operands[0] = inst->GetInOperand(0);
2424 }
2425
2426 inst->SetInOperands(std::move(new_operands));
2427 return true;
2428 };
2429 }
2430
2431 // Removes duplicate ids from the interface list of an OpEntryPoint
2432 // instruction.
RemoveRedundantOperands()2433 FoldingRule RemoveRedundantOperands() {
2434 return [](IRContext*, Instruction* inst,
2435 const std::vector<const analysis::Constant*>&) {
2436 assert(inst->opcode() == SpvOpEntryPoint &&
2437 "Wrong opcode. Should be OpEntryPoint.");
2438 bool has_redundant_operand = false;
2439 std::unordered_set<uint32_t> seen_operands;
2440 std::vector<Operand> new_operands;
2441
2442 new_operands.emplace_back(inst->GetOperand(0));
2443 new_operands.emplace_back(inst->GetOperand(1));
2444 new_operands.emplace_back(inst->GetOperand(2));
2445 for (uint32_t i = 3; i < inst->NumOperands(); ++i) {
2446 if (seen_operands.insert(inst->GetSingleWordOperand(i)).second) {
2447 new_operands.emplace_back(inst->GetOperand(i));
2448 } else {
2449 has_redundant_operand = true;
2450 }
2451 }
2452
2453 if (!has_redundant_operand) {
2454 return false;
2455 }
2456
2457 inst->SetInOperands(std::move(new_operands));
2458 return true;
2459 };
2460 }
2461
2462 // If an image instruction's operand is a constant, updates the image operand
2463 // flag from Offset to ConstOffset.
UpdateImageOperands()2464 FoldingRule UpdateImageOperands() {
2465 return [](IRContext*, Instruction* inst,
2466 const std::vector<const analysis::Constant*>& constants) {
2467 const auto opcode = inst->opcode();
2468 (void)opcode;
2469 assert((opcode == SpvOpImageSampleImplicitLod ||
2470 opcode == SpvOpImageSampleExplicitLod ||
2471 opcode == SpvOpImageSampleDrefImplicitLod ||
2472 opcode == SpvOpImageSampleDrefExplicitLod ||
2473 opcode == SpvOpImageSampleProjImplicitLod ||
2474 opcode == SpvOpImageSampleProjExplicitLod ||
2475 opcode == SpvOpImageSampleProjDrefImplicitLod ||
2476 opcode == SpvOpImageSampleProjDrefExplicitLod ||
2477 opcode == SpvOpImageFetch || opcode == SpvOpImageGather ||
2478 opcode == SpvOpImageDrefGather || opcode == SpvOpImageRead ||
2479 opcode == SpvOpImageWrite ||
2480 opcode == SpvOpImageSparseSampleImplicitLod ||
2481 opcode == SpvOpImageSparseSampleExplicitLod ||
2482 opcode == SpvOpImageSparseSampleDrefImplicitLod ||
2483 opcode == SpvOpImageSparseSampleDrefExplicitLod ||
2484 opcode == SpvOpImageSparseSampleProjImplicitLod ||
2485 opcode == SpvOpImageSparseSampleProjExplicitLod ||
2486 opcode == SpvOpImageSparseSampleProjDrefImplicitLod ||
2487 opcode == SpvOpImageSparseSampleProjDrefExplicitLod ||
2488 opcode == SpvOpImageSparseFetch ||
2489 opcode == SpvOpImageSparseGather ||
2490 opcode == SpvOpImageSparseDrefGather ||
2491 opcode == SpvOpImageSparseRead) &&
2492 "Wrong opcode. Should be an image instruction.");
2493
2494 int32_t operand_index = ImageOperandsMaskInOperandIndex(inst);
2495 if (operand_index >= 0) {
2496 auto image_operands = inst->GetSingleWordInOperand(operand_index);
2497 if (image_operands & SpvImageOperandsOffsetMask) {
2498 uint32_t offset_operand_index = operand_index + 1;
2499 if (image_operands & SpvImageOperandsBiasMask) offset_operand_index++;
2500 if (image_operands & SpvImageOperandsLodMask) offset_operand_index++;
2501 if (image_operands & SpvImageOperandsGradMask)
2502 offset_operand_index += 2;
2503 assert(((image_operands & SpvImageOperandsConstOffsetMask) == 0) &&
2504 "Offset and ConstOffset may not be used together");
2505 if (offset_operand_index < inst->NumOperands()) {
2506 if (constants[offset_operand_index]) {
2507 image_operands = image_operands | SpvImageOperandsConstOffsetMask;
2508 image_operands = image_operands & ~SpvImageOperandsOffsetMask;
2509 inst->SetInOperand(operand_index, {image_operands});
2510 return true;
2511 }
2512 }
2513 }
2514 }
2515
2516 return false;
2517 };
2518 }
2519
2520 } // namespace
2521
AddFoldingRules()2522 void FoldingRules::AddFoldingRules() {
2523 // Add all folding rules to the list for the opcodes to which they apply.
2524 // Note that the order in which rules are added to the list matters. If a rule
2525 // applies to the instruction, the rest of the rules will not be attempted.
2526 // Take that into consideration.
2527 rules_[SpvOpBitcast].push_back(BitCastScalarOrVector());
2528
2529 rules_[SpvOpCompositeConstruct].push_back(CompositeExtractFeedingConstruct);
2530
2531 rules_[SpvOpCompositeExtract].push_back(InsertFeedingExtract());
2532 rules_[SpvOpCompositeExtract].push_back(CompositeConstructFeedingExtract);
2533 rules_[SpvOpCompositeExtract].push_back(VectorShuffleFeedingExtract());
2534 rules_[SpvOpCompositeExtract].push_back(FMixFeedingExtract());
2535
2536 rules_[SpvOpDot].push_back(DotProductDoingExtract());
2537
2538 rules_[SpvOpEntryPoint].push_back(RemoveRedundantOperands());
2539
2540 rules_[SpvOpFAdd].push_back(RedundantFAdd());
2541 rules_[SpvOpFAdd].push_back(MergeAddNegateArithmetic());
2542 rules_[SpvOpFAdd].push_back(MergeAddAddArithmetic());
2543 rules_[SpvOpFAdd].push_back(MergeAddSubArithmetic());
2544 rules_[SpvOpFAdd].push_back(MergeGenericAddSubArithmetic());
2545 rules_[SpvOpFAdd].push_back(FactorAddMuls());
2546
2547 rules_[SpvOpFDiv].push_back(RedundantFDiv());
2548 rules_[SpvOpFDiv].push_back(ReciprocalFDiv());
2549 rules_[SpvOpFDiv].push_back(MergeDivDivArithmetic());
2550 rules_[SpvOpFDiv].push_back(MergeDivMulArithmetic());
2551 rules_[SpvOpFDiv].push_back(MergeDivNegateArithmetic());
2552
2553 rules_[SpvOpFMul].push_back(RedundantFMul());
2554 rules_[SpvOpFMul].push_back(MergeMulMulArithmetic());
2555 rules_[SpvOpFMul].push_back(MergeMulDivArithmetic());
2556 rules_[SpvOpFMul].push_back(MergeMulNegateArithmetic());
2557
2558 rules_[SpvOpFNegate].push_back(MergeNegateArithmetic());
2559 rules_[SpvOpFNegate].push_back(MergeNegateAddSubArithmetic());
2560 rules_[SpvOpFNegate].push_back(MergeNegateMulDivArithmetic());
2561
2562 rules_[SpvOpFSub].push_back(RedundantFSub());
2563 rules_[SpvOpFSub].push_back(MergeSubNegateArithmetic());
2564 rules_[SpvOpFSub].push_back(MergeSubAddArithmetic());
2565 rules_[SpvOpFSub].push_back(MergeSubSubArithmetic());
2566
2567 rules_[SpvOpIAdd].push_back(RedundantIAdd());
2568 rules_[SpvOpIAdd].push_back(MergeAddNegateArithmetic());
2569 rules_[SpvOpIAdd].push_back(MergeAddAddArithmetic());
2570 rules_[SpvOpIAdd].push_back(MergeAddSubArithmetic());
2571 rules_[SpvOpIAdd].push_back(MergeGenericAddSubArithmetic());
2572 rules_[SpvOpIAdd].push_back(FactorAddMuls());
2573
2574 rules_[SpvOpIMul].push_back(IntMultipleBy1());
2575 rules_[SpvOpIMul].push_back(MergeMulMulArithmetic());
2576 rules_[SpvOpIMul].push_back(MergeMulNegateArithmetic());
2577
2578 rules_[SpvOpISub].push_back(MergeSubNegateArithmetic());
2579 rules_[SpvOpISub].push_back(MergeSubAddArithmetic());
2580 rules_[SpvOpISub].push_back(MergeSubSubArithmetic());
2581
2582 rules_[SpvOpPhi].push_back(RedundantPhi());
2583
2584 rules_[SpvOpSNegate].push_back(MergeNegateArithmetic());
2585 rules_[SpvOpSNegate].push_back(MergeNegateMulDivArithmetic());
2586 rules_[SpvOpSNegate].push_back(MergeNegateAddSubArithmetic());
2587
2588 rules_[SpvOpSelect].push_back(RedundantSelect());
2589
2590 rules_[SpvOpStore].push_back(StoringUndef());
2591
2592 rules_[SpvOpVectorShuffle].push_back(VectorShuffleFeedingShuffle());
2593
2594 rules_[SpvOpImageSampleImplicitLod].push_back(UpdateImageOperands());
2595 rules_[SpvOpImageSampleExplicitLod].push_back(UpdateImageOperands());
2596 rules_[SpvOpImageSampleDrefImplicitLod].push_back(UpdateImageOperands());
2597 rules_[SpvOpImageSampleDrefExplicitLod].push_back(UpdateImageOperands());
2598 rules_[SpvOpImageSampleProjImplicitLod].push_back(UpdateImageOperands());
2599 rules_[SpvOpImageSampleProjExplicitLod].push_back(UpdateImageOperands());
2600 rules_[SpvOpImageSampleProjDrefImplicitLod].push_back(UpdateImageOperands());
2601 rules_[SpvOpImageSampleProjDrefExplicitLod].push_back(UpdateImageOperands());
2602 rules_[SpvOpImageFetch].push_back(UpdateImageOperands());
2603 rules_[SpvOpImageGather].push_back(UpdateImageOperands());
2604 rules_[SpvOpImageDrefGather].push_back(UpdateImageOperands());
2605 rules_[SpvOpImageRead].push_back(UpdateImageOperands());
2606 rules_[SpvOpImageWrite].push_back(UpdateImageOperands());
2607 rules_[SpvOpImageSparseSampleImplicitLod].push_back(UpdateImageOperands());
2608 rules_[SpvOpImageSparseSampleExplicitLod].push_back(UpdateImageOperands());
2609 rules_[SpvOpImageSparseSampleDrefImplicitLod].push_back(
2610 UpdateImageOperands());
2611 rules_[SpvOpImageSparseSampleDrefExplicitLod].push_back(
2612 UpdateImageOperands());
2613 rules_[SpvOpImageSparseSampleProjImplicitLod].push_back(
2614 UpdateImageOperands());
2615 rules_[SpvOpImageSparseSampleProjExplicitLod].push_back(
2616 UpdateImageOperands());
2617 rules_[SpvOpImageSparseSampleProjDrefImplicitLod].push_back(
2618 UpdateImageOperands());
2619 rules_[SpvOpImageSparseSampleProjDrefExplicitLod].push_back(
2620 UpdateImageOperands());
2621 rules_[SpvOpImageSparseFetch].push_back(UpdateImageOperands());
2622 rules_[SpvOpImageSparseGather].push_back(UpdateImageOperands());
2623 rules_[SpvOpImageSparseDrefGather].push_back(UpdateImageOperands());
2624 rules_[SpvOpImageSparseRead].push_back(UpdateImageOperands());
2625
2626 FeatureManager* feature_manager = context_->get_feature_mgr();
2627 // Add rules for GLSLstd450
2628 uint32_t ext_inst_glslstd450_id =
2629 feature_manager->GetExtInstImportId_GLSLstd450();
2630 if (ext_inst_glslstd450_id != 0) {
2631 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMix}].push_back(
2632 RedundantFMix());
2633 }
2634 }
2635 } // namespace opt
2636 } // namespace spvtools
2637