1 // Copyright (c) 2018 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/opt/folding_rules.h"
16
17 #include <limits>
18 #include <memory>
19 #include <utility>
20
21 #include "ir_builder.h"
22 #include "source/latest_version_glsl_std_450_header.h"
23 #include "source/opt/ir_context.h"
24
25 namespace spvtools {
26 namespace opt {
27 namespace {
28
// In-operand indices used by the folding rules in this file.

// OpCompositeExtract: the composite being read from.
constexpr uint32_t kExtractCompositeIdInIdx{0};
// OpCompositeInsert: the object being inserted, then the target composite.
constexpr uint32_t kInsertObjectIdInIdx{0};
constexpr uint32_t kInsertCompositeIdInIdx{1};
// OpExtInst: the extended instruction set id, then the instruction number.
constexpr uint32_t kExtInstSetIdInIdx{0};
constexpr uint32_t kExtInstInstructionInIdx{1};
// OpExtInst FMix: the x, y and a arguments follow the set and instruction.
constexpr uint32_t kFMixXIdInIdx{2};
constexpr uint32_t kFMixYIdInIdx{3};
constexpr uint32_t kFMixAIdInIdx{4};
// OpStore: the object being stored (in-operand 0 is the pointer).
constexpr uint32_t kStoreObjectInIdx{1};
38
39 // Some image instructions may contain an "image operands" argument.
40 // Returns the operand index for the "image operands".
41 // Returns -1 if the instruction does not have image operands.
ImageOperandsMaskInOperandIndex(Instruction * inst)42 int32_t ImageOperandsMaskInOperandIndex(Instruction* inst) {
43 const auto opcode = inst->opcode();
44 switch (opcode) {
45 case spv::Op::OpImageSampleImplicitLod:
46 case spv::Op::OpImageSampleExplicitLod:
47 case spv::Op::OpImageSampleProjImplicitLod:
48 case spv::Op::OpImageSampleProjExplicitLod:
49 case spv::Op::OpImageFetch:
50 case spv::Op::OpImageRead:
51 case spv::Op::OpImageSparseSampleImplicitLod:
52 case spv::Op::OpImageSparseSampleExplicitLod:
53 case spv::Op::OpImageSparseSampleProjImplicitLod:
54 case spv::Op::OpImageSparseSampleProjExplicitLod:
55 case spv::Op::OpImageSparseFetch:
56 case spv::Op::OpImageSparseRead:
57 return inst->NumOperands() > 4 ? 2 : -1;
58 case spv::Op::OpImageSampleDrefImplicitLod:
59 case spv::Op::OpImageSampleDrefExplicitLod:
60 case spv::Op::OpImageSampleProjDrefImplicitLod:
61 case spv::Op::OpImageSampleProjDrefExplicitLod:
62 case spv::Op::OpImageGather:
63 case spv::Op::OpImageDrefGather:
64 case spv::Op::OpImageSparseSampleDrefImplicitLod:
65 case spv::Op::OpImageSparseSampleDrefExplicitLod:
66 case spv::Op::OpImageSparseSampleProjDrefImplicitLod:
67 case spv::Op::OpImageSparseSampleProjDrefExplicitLod:
68 case spv::Op::OpImageSparseGather:
69 case spv::Op::OpImageSparseDrefGather:
70 return inst->NumOperands() > 5 ? 3 : -1;
71 case spv::Op::OpImageWrite:
72 return inst->NumOperands() > 3 ? 3 : -1;
73 default:
74 return -1;
75 }
76 }
77
78 // Returns the element width of |type|.
ElementWidth(const analysis::Type * type)79 uint32_t ElementWidth(const analysis::Type* type) {
80 if (const analysis::Vector* vec_type = type->AsVector()) {
81 return ElementWidth(vec_type->element_type());
82 } else if (const analysis::Float* float_type = type->AsFloat()) {
83 return float_type->width();
84 } else {
85 assert(type->AsInteger());
86 return type->AsInteger()->width();
87 }
88 }
89
90 // Returns true if |type| is Float or a vector of Float.
HasFloatingPoint(const analysis::Type * type)91 bool HasFloatingPoint(const analysis::Type* type) {
92 if (type->AsFloat()) {
93 return true;
94 } else if (const analysis::Vector* vec_type = type->AsVector()) {
95 return vec_type->element_type()->AsFloat() != nullptr;
96 }
97
98 return false;
99 }
100
// Returns true if |val| is zero or a normal number; returns false if it is
// NaN, infinite or subnormal.
template <typename T>
bool IsValidResult(T val) {
  const int category = std::fpclassify(val);
  const bool is_special = category == FP_NAN || category == FP_INFINITE ||
                          category == FP_SUBNORMAL;
  return !is_special;
}
114
ConstInput(const std::vector<const analysis::Constant * > & constants)115 const analysis::Constant* ConstInput(
116 const std::vector<const analysis::Constant*>& constants) {
117 return constants[0] ? constants[0] : constants[1];
118 }
119
NonConstInput(IRContext * context,const analysis::Constant * c,Instruction * inst)120 Instruction* NonConstInput(IRContext* context, const analysis::Constant* c,
121 Instruction* inst) {
122 uint32_t in_op = c ? 1u : 0u;
123 return context->get_def_use_mgr()->GetDef(
124 inst->GetSingleWordInOperand(in_op));
125 }
126
// Splits the 64-bit value |val| into two 32-bit words, low-order word first.
std::vector<uint32_t> ExtractInts(uint64_t val) {
  const uint32_t low = static_cast<uint32_t>(val & 0xFFFFFFFFu);
  const uint32_t high = static_cast<uint32_t>(val >> 32);
  return {low, high};
}
133
GetWordsFromScalarIntConstant(const analysis::IntConstant * c)134 std::vector<uint32_t> GetWordsFromScalarIntConstant(
135 const analysis::IntConstant* c) {
136 assert(c != nullptr);
137 uint32_t width = c->type()->AsInteger()->width();
138 assert(width == 8 || width == 16 || width == 32 || width == 64);
139 if (width == 64) {
140 uint64_t uval = static_cast<uint64_t>(c->GetU64());
141 return ExtractInts(uval);
142 }
143 // Section 2.2.1 of the SPIR-V spec guarantees that all integer types
144 // smaller than 32-bits are automatically zero or sign extended to 32-bits.
145 return {c->GetU32BitValue()};
146 }
147
GetWordsFromScalarFloatConstant(const analysis::FloatConstant * c)148 std::vector<uint32_t> GetWordsFromScalarFloatConstant(
149 const analysis::FloatConstant* c) {
150 assert(c != nullptr);
151 uint32_t width = c->type()->AsFloat()->width();
152 assert(width == 16 || width == 32 || width == 64);
153 if (width == 64) {
154 utils::FloatProxy<double> result(c->GetDouble());
155 return result.GetWords();
156 }
157 // Section 2.2.1 of the SPIR-V spec guarantees that all floating-point types
158 // smaller than 32-bits are automatically zero extended to 32-bits.
159 return {c->GetU32BitValue()};
160 }
161
GetWordsFromNumericScalarOrVectorConstant(analysis::ConstantManager * const_mgr,const analysis::Constant * c)162 std::vector<uint32_t> GetWordsFromNumericScalarOrVectorConstant(
163 analysis::ConstantManager* const_mgr, const analysis::Constant* c) {
164 if (const auto* float_constant = c->AsFloatConstant()) {
165 return GetWordsFromScalarFloatConstant(float_constant);
166 } else if (const auto* int_constant = c->AsIntConstant()) {
167 return GetWordsFromScalarIntConstant(int_constant);
168 } else if (const auto* vec_constant = c->AsVectorConstant()) {
169 std::vector<uint32_t> words;
170 for (const auto* comp : vec_constant->GetComponents()) {
171 auto comp_in_words =
172 GetWordsFromNumericScalarOrVectorConstant(const_mgr, comp);
173 words.insert(words.end(), comp_in_words.begin(), comp_in_words.end());
174 }
175 return words;
176 }
177 return {};
178 }
179
ConvertWordsToNumericScalarOrVectorConstant(analysis::ConstantManager * const_mgr,const std::vector<uint32_t> & words,const analysis::Type * type)180 const analysis::Constant* ConvertWordsToNumericScalarOrVectorConstant(
181 analysis::ConstantManager* const_mgr, const std::vector<uint32_t>& words,
182 const analysis::Type* type) {
183 if (type->AsInteger() || type->AsFloat())
184 return const_mgr->GetConstant(type, words);
185 if (const auto* vec_type = type->AsVector())
186 return const_mgr->GetNumericVectorConstantWithWords(vec_type, words);
187 return nullptr;
188 }
189
190 // Returns the negation of |c|. |c| must be a 32 or 64 bit floating point
191 // constant.
NegateFloatingPointConstant(analysis::ConstantManager * const_mgr,const analysis::Constant * c)192 uint32_t NegateFloatingPointConstant(analysis::ConstantManager* const_mgr,
193 const analysis::Constant* c) {
194 assert(c);
195 assert(c->type()->AsFloat());
196 uint32_t width = c->type()->AsFloat()->width();
197 assert(width == 32 || width == 64);
198 std::vector<uint32_t> words;
199 if (width == 64) {
200 utils::FloatProxy<double> result(c->GetDouble() * -1.0);
201 words = result.GetWords();
202 } else {
203 utils::FloatProxy<float> result(c->GetFloat() * -1.0f);
204 words = result.GetWords();
205 }
206
207 const analysis::Constant* negated_const =
208 const_mgr->GetConstant(c->type(), std::move(words));
209 return const_mgr->GetDefiningInstruction(negated_const)->result_id();
210 }
211
212 // Negates the integer constant |c|. Returns the id of the defining instruction.
NegateIntegerConstant(analysis::ConstantManager * const_mgr,const analysis::Constant * c)213 uint32_t NegateIntegerConstant(analysis::ConstantManager* const_mgr,
214 const analysis::Constant* c) {
215 assert(c);
216 assert(c->type()->AsInteger());
217 uint32_t width = c->type()->AsInteger()->width();
218 assert(width == 32 || width == 64);
219 std::vector<uint32_t> words;
220 if (width == 64) {
221 uint64_t uval = static_cast<uint64_t>(0 - c->GetU64());
222 words = ExtractInts(uval);
223 } else {
224 words.push_back(static_cast<uint32_t>(0 - c->GetU32()));
225 }
226
227 const analysis::Constant* negated_const =
228 const_mgr->GetConstant(c->type(), std::move(words));
229 return const_mgr->GetDefiningInstruction(negated_const)->result_id();
230 }
231
232 // Negates the vector constant |c|. Returns the id of the defining instruction.
NegateVectorConstant(analysis::ConstantManager * const_mgr,const analysis::Constant * c)233 uint32_t NegateVectorConstant(analysis::ConstantManager* const_mgr,
234 const analysis::Constant* c) {
235 assert(const_mgr && c);
236 assert(c->type()->AsVector());
237 if (c->AsNullConstant()) {
238 // 0.0 vs -0.0 shouldn't matter.
239 return const_mgr->GetDefiningInstruction(c)->result_id();
240 } else {
241 const analysis::Type* component_type =
242 c->AsVectorConstant()->component_type();
243 std::vector<uint32_t> words;
244 for (auto& comp : c->AsVectorConstant()->GetComponents()) {
245 if (component_type->AsFloat()) {
246 words.push_back(NegateFloatingPointConstant(const_mgr, comp));
247 } else {
248 assert(component_type->AsInteger());
249 words.push_back(NegateIntegerConstant(const_mgr, comp));
250 }
251 }
252
253 const analysis::Constant* negated_const =
254 const_mgr->GetConstant(c->type(), std::move(words));
255 return const_mgr->GetDefiningInstruction(negated_const)->result_id();
256 }
257 }
258
259 // Negates |c|. Returns the id of the defining instruction.
NegateConstant(analysis::ConstantManager * const_mgr,const analysis::Constant * c)260 uint32_t NegateConstant(analysis::ConstantManager* const_mgr,
261 const analysis::Constant* c) {
262 if (c->type()->AsVector()) {
263 return NegateVectorConstant(const_mgr, c);
264 } else if (c->type()->AsFloat()) {
265 return NegateFloatingPointConstant(const_mgr, c);
266 } else {
267 assert(c->type()->AsInteger());
268 return NegateIntegerConstant(const_mgr, c);
269 }
270 }
271
272 // Takes the reciprocal of |c|. |c|'s type must be Float or a vector of Float.
273 // Returns 0 if the reciprocal is NaN, infinite or subnormal.
Reciprocal(analysis::ConstantManager * const_mgr,const analysis::Constant * c)274 uint32_t Reciprocal(analysis::ConstantManager* const_mgr,
275 const analysis::Constant* c) {
276 assert(const_mgr && c);
277 assert(c->type()->AsFloat());
278
279 uint32_t width = c->type()->AsFloat()->width();
280 assert(width == 32 || width == 64);
281 std::vector<uint32_t> words;
282
283 if (c->IsZero()) {
284 return 0;
285 }
286
287 if (width == 64) {
288 spvtools::utils::FloatProxy<double> result(1.0 / c->GetDouble());
289 if (!IsValidResult(result.getAsFloat())) return 0;
290 words = result.GetWords();
291 } else {
292 spvtools::utils::FloatProxy<float> result(1.0f / c->GetFloat());
293 if (!IsValidResult(result.getAsFloat())) return 0;
294 words = result.GetWords();
295 }
296
297 const analysis::Constant* negated_const =
298 const_mgr->GetConstant(c->type(), std::move(words));
299 return const_mgr->GetDefiningInstruction(negated_const)->result_id();
300 }
301
302 // Replaces fdiv where second operand is constant with fmul.
ReciprocalFDiv()303 FoldingRule ReciprocalFDiv() {
304 return [](IRContext* context, Instruction* inst,
305 const std::vector<const analysis::Constant*>& constants) {
306 assert(inst->opcode() == spv::Op::OpFDiv);
307 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
308 const analysis::Type* type =
309 context->get_type_mgr()->GetType(inst->type_id());
310 if (!inst->IsFloatingPointFoldingAllowed()) return false;
311
312 uint32_t width = ElementWidth(type);
313 if (width != 32 && width != 64) return false;
314
315 if (constants[1] != nullptr) {
316 uint32_t id = 0;
317 if (const analysis::VectorConstant* vector_const =
318 constants[1]->AsVectorConstant()) {
319 std::vector<uint32_t> neg_ids;
320 for (auto& comp : vector_const->GetComponents()) {
321 id = Reciprocal(const_mgr, comp);
322 if (id == 0) return false;
323 neg_ids.push_back(id);
324 }
325 const analysis::Constant* negated_const =
326 const_mgr->GetConstant(constants[1]->type(), std::move(neg_ids));
327 id = const_mgr->GetDefiningInstruction(negated_const)->result_id();
328 } else if (constants[1]->AsFloatConstant()) {
329 id = Reciprocal(const_mgr, constants[1]);
330 if (id == 0) return false;
331 } else {
332 // Don't fold a null constant.
333 return false;
334 }
335 inst->SetOpcode(spv::Op::OpFMul);
336 inst->SetInOperands(
337 {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0u)}},
338 {SPV_OPERAND_TYPE_ID, {id}}});
339 return true;
340 }
341
342 return false;
343 };
344 }
345
346 // Elides consecutive negate instructions.
MergeNegateArithmetic()347 FoldingRule MergeNegateArithmetic() {
348 return [](IRContext* context, Instruction* inst,
349 const std::vector<const analysis::Constant*>& constants) {
350 assert(inst->opcode() == spv::Op::OpFNegate ||
351 inst->opcode() == spv::Op::OpSNegate);
352 (void)constants;
353 const analysis::Type* type =
354 context->get_type_mgr()->GetType(inst->type_id());
355 if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
356 return false;
357
358 Instruction* op_inst =
359 context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
360 if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
361 return false;
362
363 if (op_inst->opcode() == inst->opcode()) {
364 // Elide negates.
365 inst->SetOpcode(spv::Op::OpCopyObject);
366 inst->SetInOperands(
367 {{SPV_OPERAND_TYPE_ID, {op_inst->GetSingleWordInOperand(0u)}}});
368 return true;
369 }
370
371 return false;
372 };
373 }
374
375 // Merges negate into a mul or div operation if that operation contains a
376 // constant operand.
377 // Cases:
378 // -(x * 2) = x * -2
379 // -(2 * x) = x * -2
380 // -(x / 2) = x / -2
381 // -(2 / x) = -2 / x
MergeNegateMulDivArithmetic()382 FoldingRule MergeNegateMulDivArithmetic() {
383 return [](IRContext* context, Instruction* inst,
384 const std::vector<const analysis::Constant*>& constants) {
385 assert(inst->opcode() == spv::Op::OpFNegate ||
386 inst->opcode() == spv::Op::OpSNegate);
387 (void)constants;
388 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
389 const analysis::Type* type =
390 context->get_type_mgr()->GetType(inst->type_id());
391 if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
392 return false;
393
394 Instruction* op_inst =
395 context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
396 if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
397 return false;
398
399 uint32_t width = ElementWidth(type);
400 if (width != 32 && width != 64) return false;
401
402 spv::Op opcode = op_inst->opcode();
403 if (opcode == spv::Op::OpFMul || opcode == spv::Op::OpFDiv ||
404 opcode == spv::Op::OpIMul || opcode == spv::Op::OpSDiv ||
405 opcode == spv::Op::OpUDiv) {
406 std::vector<const analysis::Constant*> op_constants =
407 const_mgr->GetOperandConstants(op_inst);
408 // Merge negate into mul or div if one operand is constant.
409 if (op_constants[0] || op_constants[1]) {
410 bool zero_is_variable = op_constants[0] == nullptr;
411 const analysis::Constant* c = ConstInput(op_constants);
412 uint32_t neg_id = NegateConstant(const_mgr, c);
413 uint32_t non_const_id = zero_is_variable
414 ? op_inst->GetSingleWordInOperand(0u)
415 : op_inst->GetSingleWordInOperand(1u);
416 // Change this instruction to a mul/div.
417 inst->SetOpcode(op_inst->opcode());
418 if (opcode == spv::Op::OpFDiv || opcode == spv::Op::OpUDiv ||
419 opcode == spv::Op::OpSDiv) {
420 uint32_t op0 = zero_is_variable ? non_const_id : neg_id;
421 uint32_t op1 = zero_is_variable ? neg_id : non_const_id;
422 inst->SetInOperands(
423 {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}});
424 } else {
425 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
426 {SPV_OPERAND_TYPE_ID, {neg_id}}});
427 }
428 return true;
429 }
430 }
431
432 return false;
433 };
434 }
435
436 // Merges negate into a add or sub operation if that operation contains a
437 // constant operand.
438 // Cases:
439 // -(x + 2) = -2 - x
440 // -(2 + x) = -2 - x
441 // -(x - 2) = 2 - x
442 // -(2 - x) = x - 2
MergeNegateAddSubArithmetic()443 FoldingRule MergeNegateAddSubArithmetic() {
444 return [](IRContext* context, Instruction* inst,
445 const std::vector<const analysis::Constant*>& constants) {
446 assert(inst->opcode() == spv::Op::OpFNegate ||
447 inst->opcode() == spv::Op::OpSNegate);
448 (void)constants;
449 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
450 const analysis::Type* type =
451 context->get_type_mgr()->GetType(inst->type_id());
452 if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
453 return false;
454
455 Instruction* op_inst =
456 context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
457 if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
458 return false;
459
460 uint32_t width = ElementWidth(type);
461 if (width != 32 && width != 64) return false;
462
463 if (op_inst->opcode() == spv::Op::OpFAdd ||
464 op_inst->opcode() == spv::Op::OpFSub ||
465 op_inst->opcode() == spv::Op::OpIAdd ||
466 op_inst->opcode() == spv::Op::OpISub) {
467 std::vector<const analysis::Constant*> op_constants =
468 const_mgr->GetOperandConstants(op_inst);
469 if (op_constants[0] || op_constants[1]) {
470 bool zero_is_variable = op_constants[0] == nullptr;
471 bool is_add = (op_inst->opcode() == spv::Op::OpFAdd) ||
472 (op_inst->opcode() == spv::Op::OpIAdd);
473 bool swap_operands = !is_add || zero_is_variable;
474 bool negate_const = is_add;
475 const analysis::Constant* c = ConstInput(op_constants);
476 uint32_t const_id = 0;
477 if (negate_const) {
478 const_id = NegateConstant(const_mgr, c);
479 } else {
480 const_id = zero_is_variable ? op_inst->GetSingleWordInOperand(1u)
481 : op_inst->GetSingleWordInOperand(0u);
482 }
483
484 // Swap operands if necessary and make the instruction a subtraction.
485 uint32_t op0 =
486 zero_is_variable ? op_inst->GetSingleWordInOperand(0u) : const_id;
487 uint32_t op1 =
488 zero_is_variable ? const_id : op_inst->GetSingleWordInOperand(1u);
489 if (swap_operands) std::swap(op0, op1);
490 inst->SetOpcode(HasFloatingPoint(type) ? spv::Op::OpFSub
491 : spv::Op::OpISub);
492 inst->SetInOperands(
493 {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}});
494 return true;
495 }
496 }
497
498 return false;
499 };
500 }
501
502 // Returns true if |c| has a zero element.
HasZero(const analysis::Constant * c)503 bool HasZero(const analysis::Constant* c) {
504 if (c->AsNullConstant()) {
505 return true;
506 }
507 if (const analysis::VectorConstant* vec_const = c->AsVectorConstant()) {
508 for (auto& comp : vec_const->GetComponents())
509 if (HasZero(comp)) return true;
510 } else {
511 assert(c->AsScalarConstant());
512 return c->AsScalarConstant()->IsZero();
513 }
514
515 return false;
516 }
517
518 // Performs |input1| |opcode| |input2| and returns the merged constant result
519 // id. Returns 0 if the result is not a valid value. The input types must be
520 // Float.
PerformFloatingPointOperation(analysis::ConstantManager * const_mgr,spv::Op opcode,const analysis::Constant * input1,const analysis::Constant * input2)521 uint32_t PerformFloatingPointOperation(analysis::ConstantManager* const_mgr,
522 spv::Op opcode,
523 const analysis::Constant* input1,
524 const analysis::Constant* input2) {
525 const analysis::Type* type = input1->type();
526 assert(type->AsFloat());
527 uint32_t width = type->AsFloat()->width();
528 assert(width == 32 || width == 64);
529 std::vector<uint32_t> words;
530 #define FOLD_OP(op) \
531 if (width == 64) { \
532 utils::FloatProxy<double> val = \
533 input1->GetDouble() op input2->GetDouble(); \
534 double dval = val.getAsFloat(); \
535 if (!IsValidResult(dval)) return 0; \
536 words = val.GetWords(); \
537 } else { \
538 utils::FloatProxy<float> val = input1->GetFloat() op input2->GetFloat(); \
539 float fval = val.getAsFloat(); \
540 if (!IsValidResult(fval)) return 0; \
541 words = val.GetWords(); \
542 } \
543 static_assert(true, "require extra semicolon")
544 switch (opcode) {
545 case spv::Op::OpFMul:
546 FOLD_OP(*);
547 break;
548 case spv::Op::OpFDiv:
549 if (HasZero(input2)) return 0;
550 FOLD_OP(/);
551 break;
552 case spv::Op::OpFAdd:
553 FOLD_OP(+);
554 break;
555 case spv::Op::OpFSub:
556 FOLD_OP(-);
557 break;
558 default:
559 assert(false && "Unexpected operation");
560 break;
561 }
562 #undef FOLD_OP
563 const analysis::Constant* merged_const = const_mgr->GetConstant(type, words);
564 return const_mgr->GetDefiningInstruction(merged_const)->result_id();
565 }
566
567 // Performs |input1| |opcode| |input2| and returns the merged constant result
568 // id. Returns 0 if the result is not a valid value. The input types must be
569 // Integers.
PerformIntegerOperation(analysis::ConstantManager * const_mgr,spv::Op opcode,const analysis::Constant * input1,const analysis::Constant * input2)570 uint32_t PerformIntegerOperation(analysis::ConstantManager* const_mgr,
571 spv::Op opcode,
572 const analysis::Constant* input1,
573 const analysis::Constant* input2) {
574 assert(input1->type()->AsInteger());
575 const analysis::Integer* type = input1->type()->AsInteger();
576 uint32_t width = type->AsInteger()->width();
577 assert(width == 32 || width == 64);
578 std::vector<uint32_t> words;
579 // Regardless of the sign of the constant, folding is performed on an unsigned
580 // interpretation of the constant data. This avoids signed integer overflow
581 // while folding, and works because sign is irrelevant for the IAdd, ISub and
582 // IMul instructions.
583 #define FOLD_OP(op) \
584 if (width == 64) { \
585 uint64_t val = input1->GetU64() op input2->GetU64(); \
586 words = ExtractInts(val); \
587 } else { \
588 uint32_t val = input1->GetU32() op input2->GetU32(); \
589 words.push_back(val); \
590 } \
591 static_assert(true, "require extra semicolon")
592 switch (opcode) {
593 case spv::Op::OpIMul:
594 FOLD_OP(*);
595 break;
596 case spv::Op::OpSDiv:
597 case spv::Op::OpUDiv:
598 assert(false && "Should not merge integer division");
599 break;
600 case spv::Op::OpIAdd:
601 FOLD_OP(+);
602 break;
603 case spv::Op::OpISub:
604 FOLD_OP(-);
605 break;
606 default:
607 assert(false && "Unexpected operation");
608 break;
609 }
610 #undef FOLD_OP
611 const analysis::Constant* merged_const = const_mgr->GetConstant(type, words);
612 return const_mgr->GetDefiningInstruction(merged_const)->result_id();
613 }
614
615 // Performs |input1| |opcode| |input2| and returns the merged constant result
616 // id. Returns 0 if the result is not a valid value. The input types must be
617 // Integers, Floats or Vectors of such.
PerformOperation(analysis::ConstantManager * const_mgr,spv::Op opcode,const analysis::Constant * input1,const analysis::Constant * input2)618 uint32_t PerformOperation(analysis::ConstantManager* const_mgr, spv::Op opcode,
619 const analysis::Constant* input1,
620 const analysis::Constant* input2) {
621 assert(input1 && input2);
622 const analysis::Type* type = input1->type();
623 std::vector<uint32_t> words;
624 if (const analysis::Vector* vector_type = type->AsVector()) {
625 const analysis::Type* ele_type = vector_type->element_type();
626 for (uint32_t i = 0; i != vector_type->element_count(); ++i) {
627 uint32_t id = 0;
628
629 const analysis::Constant* input1_comp = nullptr;
630 if (const analysis::VectorConstant* input1_vector =
631 input1->AsVectorConstant()) {
632 input1_comp = input1_vector->GetComponents()[i];
633 } else {
634 assert(input1->AsNullConstant());
635 input1_comp = const_mgr->GetConstant(ele_type, {});
636 }
637
638 const analysis::Constant* input2_comp = nullptr;
639 if (const analysis::VectorConstant* input2_vector =
640 input2->AsVectorConstant()) {
641 input2_comp = input2_vector->GetComponents()[i];
642 } else {
643 assert(input2->AsNullConstant());
644 input2_comp = const_mgr->GetConstant(ele_type, {});
645 }
646
647 if (ele_type->AsFloat()) {
648 id = PerformFloatingPointOperation(const_mgr, opcode, input1_comp,
649 input2_comp);
650 } else {
651 assert(ele_type->AsInteger());
652 id = PerformIntegerOperation(const_mgr, opcode, input1_comp,
653 input2_comp);
654 }
655 if (id == 0) return 0;
656 words.push_back(id);
657 }
658 const analysis::Constant* merged_const =
659 const_mgr->GetConstant(type, words);
660 return const_mgr->GetDefiningInstruction(merged_const)->result_id();
661 } else if (type->AsFloat()) {
662 return PerformFloatingPointOperation(const_mgr, opcode, input1, input2);
663 } else {
664 assert(type->AsInteger());
665 return PerformIntegerOperation(const_mgr, opcode, input1, input2);
666 }
667 }
668
669 // Merges consecutive multiplies where each contains one constant operand.
670 // Cases:
671 // 2 * (x * 2) = x * 4
672 // 2 * (2 * x) = x * 4
673 // (x * 2) * 2 = x * 4
674 // (2 * x) * 2 = x * 4
MergeMulMulArithmetic()675 FoldingRule MergeMulMulArithmetic() {
676 return [](IRContext* context, Instruction* inst,
677 const std::vector<const analysis::Constant*>& constants) {
678 assert(inst->opcode() == spv::Op::OpFMul ||
679 inst->opcode() == spv::Op::OpIMul);
680 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
681 const analysis::Type* type =
682 context->get_type_mgr()->GetType(inst->type_id());
683 if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
684 return false;
685
686 uint32_t width = ElementWidth(type);
687 if (width != 32 && width != 64) return false;
688
689 // Determine the constant input and the variable input in |inst|.
690 const analysis::Constant* const_input1 = ConstInput(constants);
691 if (!const_input1) return false;
692 Instruction* other_inst = NonConstInput(context, constants[0], inst);
693 if (HasFloatingPoint(type) && !other_inst->IsFloatingPointFoldingAllowed())
694 return false;
695
696 if (other_inst->opcode() == inst->opcode()) {
697 std::vector<const analysis::Constant*> other_constants =
698 const_mgr->GetOperandConstants(other_inst);
699 const analysis::Constant* const_input2 = ConstInput(other_constants);
700 if (!const_input2) return false;
701
702 bool other_first_is_variable = other_constants[0] == nullptr;
703 uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
704 const_input1, const_input2);
705 if (merged_id == 0) return false;
706
707 uint32_t non_const_id = other_first_is_variable
708 ? other_inst->GetSingleWordInOperand(0u)
709 : other_inst->GetSingleWordInOperand(1u);
710 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
711 {SPV_OPERAND_TYPE_ID, {merged_id}}});
712 return true;
713 }
714
715 return false;
716 };
717 }
718
719 // Merges divides into subsequent multiplies if each instruction contains one
720 // constant operand. Does not support integer operations.
721 // Cases:
722 // 2 * (x / 2) = x * 1
723 // 2 * (2 / x) = 4 / x
724 // (x / 2) * 2 = x * 1
725 // (2 / x) * 2 = 4 / x
726 // (y / x) * x = y
727 // x * (y / x) = y
MergeMulDivArithmetic()728 FoldingRule MergeMulDivArithmetic() {
729 return [](IRContext* context, Instruction* inst,
730 const std::vector<const analysis::Constant*>& constants) {
731 assert(inst->opcode() == spv::Op::OpFMul);
732 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
733 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
734
735 const analysis::Type* type =
736 context->get_type_mgr()->GetType(inst->type_id());
737 if (!inst->IsFloatingPointFoldingAllowed()) return false;
738
739 uint32_t width = ElementWidth(type);
740 if (width != 32 && width != 64) return false;
741
742 for (uint32_t i = 0; i < 2; i++) {
743 uint32_t op_id = inst->GetSingleWordInOperand(i);
744 Instruction* op_inst = def_use_mgr->GetDef(op_id);
745 if (op_inst->opcode() == spv::Op::OpFDiv) {
746 if (op_inst->GetSingleWordInOperand(1) ==
747 inst->GetSingleWordInOperand(1 - i)) {
748 inst->SetOpcode(spv::Op::OpCopyObject);
749 inst->SetInOperands(
750 {{SPV_OPERAND_TYPE_ID, {op_inst->GetSingleWordInOperand(0)}}});
751 return true;
752 }
753 }
754 }
755
756 const analysis::Constant* const_input1 = ConstInput(constants);
757 if (!const_input1) return false;
758 Instruction* other_inst = NonConstInput(context, constants[0], inst);
759 if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
760
761 if (other_inst->opcode() == spv::Op::OpFDiv) {
762 std::vector<const analysis::Constant*> other_constants =
763 const_mgr->GetOperandConstants(other_inst);
764 const analysis::Constant* const_input2 = ConstInput(other_constants);
765 if (!const_input2 || HasZero(const_input2)) return false;
766
767 bool other_first_is_variable = other_constants[0] == nullptr;
768 // If the variable value is the second operand of the divide, multiply
769 // the constants together. Otherwise divide the constants.
770 uint32_t merged_id = PerformOperation(
771 const_mgr,
772 other_first_is_variable ? other_inst->opcode() : inst->opcode(),
773 const_input1, const_input2);
774 if (merged_id == 0) return false;
775
776 uint32_t non_const_id = other_first_is_variable
777 ? other_inst->GetSingleWordInOperand(0u)
778 : other_inst->GetSingleWordInOperand(1u);
779
780 // If the variable value is on the second operand of the div, then this
781 // operation is a div. Otherwise it should be a multiply.
782 inst->SetOpcode(other_first_is_variable ? inst->opcode()
783 : other_inst->opcode());
784 if (other_first_is_variable) {
785 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
786 {SPV_OPERAND_TYPE_ID, {merged_id}}});
787 } else {
788 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {merged_id}},
789 {SPV_OPERAND_TYPE_ID, {non_const_id}}});
790 }
791 return true;
792 }
793
794 return false;
795 };
796 }
797
798 // Merges multiply of constant and negation.
799 // Cases:
800 // (-x) * 2 = x * -2
801 // 2 * (-x) = x * -2
MergeMulNegateArithmetic()802 FoldingRule MergeMulNegateArithmetic() {
803 return [](IRContext* context, Instruction* inst,
804 const std::vector<const analysis::Constant*>& constants) {
805 assert(inst->opcode() == spv::Op::OpFMul ||
806 inst->opcode() == spv::Op::OpIMul);
807 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
808 const analysis::Type* type =
809 context->get_type_mgr()->GetType(inst->type_id());
810 bool uses_float = HasFloatingPoint(type);
811 if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
812
813 uint32_t width = ElementWidth(type);
814 if (width != 32 && width != 64) return false;
815
816 const analysis::Constant* const_input1 = ConstInput(constants);
817 if (!const_input1) return false;
818 Instruction* other_inst = NonConstInput(context, constants[0], inst);
819 if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
820 return false;
821
822 if (other_inst->opcode() == spv::Op::OpFNegate ||
823 other_inst->opcode() == spv::Op::OpSNegate) {
824 uint32_t neg_id = NegateConstant(const_mgr, const_input1);
825
826 inst->SetInOperands(
827 {{SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}},
828 {SPV_OPERAND_TYPE_ID, {neg_id}}});
829 return true;
830 }
831
832 return false;
833 };
834 }
835
836 // Merges consecutive divides if each instruction contains one constant operand.
837 // Does not support integer division.
838 // Cases:
839 // 2 / (x / 2) = 4 / x
840 // 4 / (2 / x) = 2 * x
841 // (4 / x) / 2 = 2 / x
842 // (x / 2) / 2 = x / 4
MergeDivDivArithmetic()843 FoldingRule MergeDivDivArithmetic() {
844 return [](IRContext* context, Instruction* inst,
845 const std::vector<const analysis::Constant*>& constants) {
846 assert(inst->opcode() == spv::Op::OpFDiv);
847 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
848 const analysis::Type* type =
849 context->get_type_mgr()->GetType(inst->type_id());
850 if (!inst->IsFloatingPointFoldingAllowed()) return false;
851
852 uint32_t width = ElementWidth(type);
853 if (width != 32 && width != 64) return false;
854
855 const analysis::Constant* const_input1 = ConstInput(constants);
856 if (!const_input1 || HasZero(const_input1)) return false;
857 Instruction* other_inst = NonConstInput(context, constants[0], inst);
858 if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
859
860 bool first_is_variable = constants[0] == nullptr;
861 if (other_inst->opcode() == inst->opcode()) {
862 std::vector<const analysis::Constant*> other_constants =
863 const_mgr->GetOperandConstants(other_inst);
864 const analysis::Constant* const_input2 = ConstInput(other_constants);
865 if (!const_input2 || HasZero(const_input2)) return false;
866
867 bool other_first_is_variable = other_constants[0] == nullptr;
868
869 spv::Op merge_op = inst->opcode();
870 if (other_first_is_variable) {
871 // Constants magnify.
872 merge_op = spv::Op::OpFMul;
873 }
874
875 // This is an x / (*) case. Swap the inputs. Doesn't harm multiply
876 // because it is commutative.
877 if (first_is_variable) std::swap(const_input1, const_input2);
878 uint32_t merged_id =
879 PerformOperation(const_mgr, merge_op, const_input1, const_input2);
880 if (merged_id == 0) return false;
881
882 uint32_t non_const_id = other_first_is_variable
883 ? other_inst->GetSingleWordInOperand(0u)
884 : other_inst->GetSingleWordInOperand(1u);
885
886 spv::Op op = inst->opcode();
887 if (!first_is_variable && !other_first_is_variable) {
888 // Effectively div of 1/x, so change to multiply.
889 op = spv::Op::OpFMul;
890 }
891
892 uint32_t op1 = merged_id;
893 uint32_t op2 = non_const_id;
894 if (first_is_variable && other_first_is_variable) std::swap(op1, op2);
895 inst->SetOpcode(op);
896 inst->SetInOperands(
897 {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
898 return true;
899 }
900
901 return false;
902 };
903 }
904
905 // Fold multiplies succeeded by divides where each instruction contains a
906 // constant operand. Does not support integer divide.
907 // Cases:
908 // 4 / (x * 2) = 2 / x
909 // 4 / (2 * x) = 2 / x
910 // (x * 4) / 2 = x * 2
911 // (4 * x) / 2 = x * 2
912 // (x * y) / x = y
913 // (y * x) / x = y
MergeDivMulArithmetic()914 FoldingRule MergeDivMulArithmetic() {
915 return [](IRContext* context, Instruction* inst,
916 const std::vector<const analysis::Constant*>& constants) {
917 assert(inst->opcode() == spv::Op::OpFDiv);
918 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
919 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
920
921 const analysis::Type* type =
922 context->get_type_mgr()->GetType(inst->type_id());
923 if (!inst->IsFloatingPointFoldingAllowed()) return false;
924
925 uint32_t width = ElementWidth(type);
926 if (width != 32 && width != 64) return false;
927
928 uint32_t op_id = inst->GetSingleWordInOperand(0);
929 Instruction* op_inst = def_use_mgr->GetDef(op_id);
930
931 if (op_inst->opcode() == spv::Op::OpFMul) {
932 for (uint32_t i = 0; i < 2; i++) {
933 if (op_inst->GetSingleWordInOperand(i) ==
934 inst->GetSingleWordInOperand(1)) {
935 inst->SetOpcode(spv::Op::OpCopyObject);
936 inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
937 {op_inst->GetSingleWordInOperand(1 - i)}}});
938 return true;
939 }
940 }
941 }
942
943 const analysis::Constant* const_input1 = ConstInput(constants);
944 if (!const_input1 || HasZero(const_input1)) return false;
945 Instruction* other_inst = NonConstInput(context, constants[0], inst);
946 if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
947
948 bool first_is_variable = constants[0] == nullptr;
949 if (other_inst->opcode() == spv::Op::OpFMul) {
950 std::vector<const analysis::Constant*> other_constants =
951 const_mgr->GetOperandConstants(other_inst);
952 const analysis::Constant* const_input2 = ConstInput(other_constants);
953 if (!const_input2) return false;
954
955 bool other_first_is_variable = other_constants[0] == nullptr;
956
957 // This is an x / (*) case. Swap the inputs.
958 if (first_is_variable) std::swap(const_input1, const_input2);
959 uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
960 const_input1, const_input2);
961 if (merged_id == 0) return false;
962
963 uint32_t non_const_id = other_first_is_variable
964 ? other_inst->GetSingleWordInOperand(0u)
965 : other_inst->GetSingleWordInOperand(1u);
966
967 uint32_t op1 = merged_id;
968 uint32_t op2 = non_const_id;
969 if (first_is_variable) std::swap(op1, op2);
970
971 // Convert to multiply
972 if (first_is_variable) inst->SetOpcode(other_inst->opcode());
973 inst->SetInOperands(
974 {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
975 return true;
976 }
977
978 return false;
979 };
980 }
981
982 // Fold divides of a constant and a negation.
983 // Cases:
984 // (-x) / 2 = x / -2
985 // 2 / (-x) = -2 / x
MergeDivNegateArithmetic()986 FoldingRule MergeDivNegateArithmetic() {
987 return [](IRContext* context, Instruction* inst,
988 const std::vector<const analysis::Constant*>& constants) {
989 assert(inst->opcode() == spv::Op::OpFDiv);
990 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
991 if (!inst->IsFloatingPointFoldingAllowed()) return false;
992
993 const analysis::Constant* const_input1 = ConstInput(constants);
994 if (!const_input1) return false;
995 Instruction* other_inst = NonConstInput(context, constants[0], inst);
996 if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
997
998 bool first_is_variable = constants[0] == nullptr;
999 if (other_inst->opcode() == spv::Op::OpFNegate) {
1000 uint32_t neg_id = NegateConstant(const_mgr, const_input1);
1001
1002 if (first_is_variable) {
1003 inst->SetInOperands(
1004 {{SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}},
1005 {SPV_OPERAND_TYPE_ID, {neg_id}}});
1006 } else {
1007 inst->SetInOperands(
1008 {{SPV_OPERAND_TYPE_ID, {neg_id}},
1009 {SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}});
1010 }
1011 return true;
1012 }
1013
1014 return false;
1015 };
1016 }
1017
1018 // Folds addition of a constant and a negation.
1019 // Cases:
1020 // (-x) + 2 = 2 - x
1021 // 2 + (-x) = 2 - x
MergeAddNegateArithmetic()1022 FoldingRule MergeAddNegateArithmetic() {
1023 return [](IRContext* context, Instruction* inst,
1024 const std::vector<const analysis::Constant*>& constants) {
1025 assert(inst->opcode() == spv::Op::OpFAdd ||
1026 inst->opcode() == spv::Op::OpIAdd);
1027 const analysis::Type* type =
1028 context->get_type_mgr()->GetType(inst->type_id());
1029 bool uses_float = HasFloatingPoint(type);
1030 if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1031
1032 const analysis::Constant* const_input1 = ConstInput(constants);
1033 if (!const_input1) return false;
1034 Instruction* other_inst = NonConstInput(context, constants[0], inst);
1035 if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1036 return false;
1037
1038 if (other_inst->opcode() == spv::Op::OpSNegate ||
1039 other_inst->opcode() == spv::Op::OpFNegate) {
1040 inst->SetOpcode(HasFloatingPoint(type) ? spv::Op::OpFSub
1041 : spv::Op::OpISub);
1042 uint32_t const_id = constants[0] ? inst->GetSingleWordInOperand(0u)
1043 : inst->GetSingleWordInOperand(1u);
1044 inst->SetInOperands(
1045 {{SPV_OPERAND_TYPE_ID, {const_id}},
1046 {SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}});
1047 return true;
1048 }
1049 return false;
1050 };
1051 }
1052
1053 // Folds subtraction of a constant and a negation.
1054 // Cases:
1055 // (-x) - 2 = -2 - x
1056 // 2 - (-x) = x + 2
MergeSubNegateArithmetic()1057 FoldingRule MergeSubNegateArithmetic() {
1058 return [](IRContext* context, Instruction* inst,
1059 const std::vector<const analysis::Constant*>& constants) {
1060 assert(inst->opcode() == spv::Op::OpFSub ||
1061 inst->opcode() == spv::Op::OpISub);
1062 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1063 const analysis::Type* type =
1064 context->get_type_mgr()->GetType(inst->type_id());
1065 bool uses_float = HasFloatingPoint(type);
1066 if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1067
1068 uint32_t width = ElementWidth(type);
1069 if (width != 32 && width != 64) return false;
1070
1071 const analysis::Constant* const_input1 = ConstInput(constants);
1072 if (!const_input1) return false;
1073 Instruction* other_inst = NonConstInput(context, constants[0], inst);
1074 if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1075 return false;
1076
1077 if (other_inst->opcode() == spv::Op::OpSNegate ||
1078 other_inst->opcode() == spv::Op::OpFNegate) {
1079 uint32_t op1 = 0;
1080 uint32_t op2 = 0;
1081 spv::Op opcode = inst->opcode();
1082 if (constants[0] != nullptr) {
1083 op1 = other_inst->GetSingleWordInOperand(0u);
1084 op2 = inst->GetSingleWordInOperand(0u);
1085 opcode = HasFloatingPoint(type) ? spv::Op::OpFAdd : spv::Op::OpIAdd;
1086 } else {
1087 op1 = NegateConstant(const_mgr, const_input1);
1088 op2 = other_inst->GetSingleWordInOperand(0u);
1089 }
1090
1091 inst->SetOpcode(opcode);
1092 inst->SetInOperands(
1093 {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
1094 return true;
1095 }
1096 return false;
1097 };
1098 }
1099
1100 // Folds addition of an addition where each operation has a constant operand.
1101 // Cases:
1102 // (x + 2) + 2 = x + 4
1103 // (2 + x) + 2 = x + 4
1104 // 2 + (x + 2) = x + 4
1105 // 2 + (2 + x) = x + 4
MergeAddAddArithmetic()1106 FoldingRule MergeAddAddArithmetic() {
1107 return [](IRContext* context, Instruction* inst,
1108 const std::vector<const analysis::Constant*>& constants) {
1109 assert(inst->opcode() == spv::Op::OpFAdd ||
1110 inst->opcode() == spv::Op::OpIAdd);
1111 const analysis::Type* type =
1112 context->get_type_mgr()->GetType(inst->type_id());
1113 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1114 bool uses_float = HasFloatingPoint(type);
1115 if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1116
1117 uint32_t width = ElementWidth(type);
1118 if (width != 32 && width != 64) return false;
1119
1120 const analysis::Constant* const_input1 = ConstInput(constants);
1121 if (!const_input1) return false;
1122 Instruction* other_inst = NonConstInput(context, constants[0], inst);
1123 if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1124 return false;
1125
1126 if (other_inst->opcode() == spv::Op::OpFAdd ||
1127 other_inst->opcode() == spv::Op::OpIAdd) {
1128 std::vector<const analysis::Constant*> other_constants =
1129 const_mgr->GetOperandConstants(other_inst);
1130 const analysis::Constant* const_input2 = ConstInput(other_constants);
1131 if (!const_input2) return false;
1132
1133 Instruction* non_const_input =
1134 NonConstInput(context, other_constants[0], other_inst);
1135 uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
1136 const_input1, const_input2);
1137 if (merged_id == 0) return false;
1138
1139 inst->SetInOperands(
1140 {{SPV_OPERAND_TYPE_ID, {non_const_input->result_id()}},
1141 {SPV_OPERAND_TYPE_ID, {merged_id}}});
1142 return true;
1143 }
1144 return false;
1145 };
1146 }
1147
1148 // Folds addition of a subtraction where each operation has a constant operand.
1149 // Cases:
1150 // (x - 2) + 2 = x + 0
1151 // (2 - x) + 2 = 4 - x
1152 // 2 + (x - 2) = x + 0
1153 // 2 + (2 - x) = 4 - x
MergeAddSubArithmetic()1154 FoldingRule MergeAddSubArithmetic() {
1155 return [](IRContext* context, Instruction* inst,
1156 const std::vector<const analysis::Constant*>& constants) {
1157 assert(inst->opcode() == spv::Op::OpFAdd ||
1158 inst->opcode() == spv::Op::OpIAdd);
1159 const analysis::Type* type =
1160 context->get_type_mgr()->GetType(inst->type_id());
1161 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1162 bool uses_float = HasFloatingPoint(type);
1163 if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1164
1165 uint32_t width = ElementWidth(type);
1166 if (width != 32 && width != 64) return false;
1167
1168 const analysis::Constant* const_input1 = ConstInput(constants);
1169 if (!const_input1) return false;
1170 Instruction* other_inst = NonConstInput(context, constants[0], inst);
1171 if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1172 return false;
1173
1174 if (other_inst->opcode() == spv::Op::OpFSub ||
1175 other_inst->opcode() == spv::Op::OpISub) {
1176 std::vector<const analysis::Constant*> other_constants =
1177 const_mgr->GetOperandConstants(other_inst);
1178 const analysis::Constant* const_input2 = ConstInput(other_constants);
1179 if (!const_input2) return false;
1180
1181 bool first_is_variable = other_constants[0] == nullptr;
1182 spv::Op op = inst->opcode();
1183 uint32_t op1 = 0;
1184 uint32_t op2 = 0;
1185 if (first_is_variable) {
1186 // Subtract constants. Non-constant operand is first.
1187 op1 = other_inst->GetSingleWordInOperand(0u);
1188 op2 = PerformOperation(const_mgr, other_inst->opcode(), const_input1,
1189 const_input2);
1190 } else {
1191 // Add constants. Constant operand is first. Change the opcode.
1192 op1 = PerformOperation(const_mgr, inst->opcode(), const_input1,
1193 const_input2);
1194 op2 = other_inst->GetSingleWordInOperand(1u);
1195 op = other_inst->opcode();
1196 }
1197 if (op1 == 0 || op2 == 0) return false;
1198
1199 inst->SetOpcode(op);
1200 inst->SetInOperands(
1201 {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
1202 return true;
1203 }
1204 return false;
1205 };
1206 }
1207
1208 // Folds subtraction of an addition where each operand has a constant operand.
1209 // Cases:
1210 // (x + 2) - 2 = x + 0
1211 // (2 + x) - 2 = x + 0
1212 // 2 - (x + 2) = 0 - x
1213 // 2 - (2 + x) = 0 - x
MergeSubAddArithmetic()1214 FoldingRule MergeSubAddArithmetic() {
1215 return [](IRContext* context, Instruction* inst,
1216 const std::vector<const analysis::Constant*>& constants) {
1217 assert(inst->opcode() == spv::Op::OpFSub ||
1218 inst->opcode() == spv::Op::OpISub);
1219 const analysis::Type* type =
1220 context->get_type_mgr()->GetType(inst->type_id());
1221 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1222 bool uses_float = HasFloatingPoint(type);
1223 if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1224
1225 uint32_t width = ElementWidth(type);
1226 if (width != 32 && width != 64) return false;
1227
1228 const analysis::Constant* const_input1 = ConstInput(constants);
1229 if (!const_input1) return false;
1230 Instruction* other_inst = NonConstInput(context, constants[0], inst);
1231 if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1232 return false;
1233
1234 if (other_inst->opcode() == spv::Op::OpFAdd ||
1235 other_inst->opcode() == spv::Op::OpIAdd) {
1236 std::vector<const analysis::Constant*> other_constants =
1237 const_mgr->GetOperandConstants(other_inst);
1238 const analysis::Constant* const_input2 = ConstInput(other_constants);
1239 if (!const_input2) return false;
1240
1241 Instruction* non_const_input =
1242 NonConstInput(context, other_constants[0], other_inst);
1243
1244 // If the first operand of the sub is not a constant, swap the constants
1245 // so the subtraction has the correct operands.
1246 if (constants[0] == nullptr) std::swap(const_input1, const_input2);
1247 // Subtract the constants.
1248 uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
1249 const_input1, const_input2);
1250 spv::Op op = inst->opcode();
1251 uint32_t op1 = 0;
1252 uint32_t op2 = 0;
1253 if (constants[0] == nullptr) {
1254 // Non-constant operand is first. Change the opcode.
1255 op1 = non_const_input->result_id();
1256 op2 = merged_id;
1257 op = other_inst->opcode();
1258 } else {
1259 // Constant operand is first.
1260 op1 = merged_id;
1261 op2 = non_const_input->result_id();
1262 }
1263 if (op1 == 0 || op2 == 0) return false;
1264
1265 inst->SetOpcode(op);
1266 inst->SetInOperands(
1267 {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
1268 return true;
1269 }
1270 return false;
1271 };
1272 }
1273
1274 // Folds subtraction of a subtraction where each operand has a constant operand.
1275 // Cases:
1276 // (x - 2) - 2 = x - 4
1277 // (2 - x) - 2 = 0 - x
1278 // 2 - (x - 2) = 4 - x
1279 // 2 - (2 - x) = x + 0
MergeSubSubArithmetic()1280 FoldingRule MergeSubSubArithmetic() {
1281 return [](IRContext* context, Instruction* inst,
1282 const std::vector<const analysis::Constant*>& constants) {
1283 assert(inst->opcode() == spv::Op::OpFSub ||
1284 inst->opcode() == spv::Op::OpISub);
1285 const analysis::Type* type =
1286 context->get_type_mgr()->GetType(inst->type_id());
1287 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1288 bool uses_float = HasFloatingPoint(type);
1289 if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1290
1291 uint32_t width = ElementWidth(type);
1292 if (width != 32 && width != 64) return false;
1293
1294 const analysis::Constant* const_input1 = ConstInput(constants);
1295 if (!const_input1) return false;
1296 Instruction* other_inst = NonConstInput(context, constants[0], inst);
1297 if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1298 return false;
1299
1300 if (other_inst->opcode() == spv::Op::OpFSub ||
1301 other_inst->opcode() == spv::Op::OpISub) {
1302 std::vector<const analysis::Constant*> other_constants =
1303 const_mgr->GetOperandConstants(other_inst);
1304 const analysis::Constant* const_input2 = ConstInput(other_constants);
1305 if (!const_input2) return false;
1306
1307 Instruction* non_const_input =
1308 NonConstInput(context, other_constants[0], other_inst);
1309
1310 // Merge the constants.
1311 uint32_t merged_id = 0;
1312 spv::Op merge_op = inst->opcode();
1313 if (other_constants[0] == nullptr) {
1314 merge_op = uses_float ? spv::Op::OpFAdd : spv::Op::OpIAdd;
1315 } else if (constants[0] == nullptr) {
1316 std::swap(const_input1, const_input2);
1317 }
1318 merged_id =
1319 PerformOperation(const_mgr, merge_op, const_input1, const_input2);
1320 if (merged_id == 0) return false;
1321
1322 spv::Op op = inst->opcode();
1323 if (constants[0] != nullptr && other_constants[0] != nullptr) {
1324 // Change the operation.
1325 op = uses_float ? spv::Op::OpFAdd : spv::Op::OpIAdd;
1326 }
1327
1328 uint32_t op1 = 0;
1329 uint32_t op2 = 0;
1330 if ((constants[0] == nullptr) ^ (other_constants[0] == nullptr)) {
1331 op1 = merged_id;
1332 op2 = non_const_input->result_id();
1333 } else {
1334 op1 = non_const_input->result_id();
1335 op2 = merged_id;
1336 }
1337
1338 inst->SetOpcode(op);
1339 inst->SetInOperands(
1340 {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
1341 return true;
1342 }
1343 return false;
1344 };
1345 }
1346
1347 // Helper function for MergeGenericAddSubArithmetic. If |addend| and
1348 // subtrahend of |sub| is the same, merge to copy of minuend of |sub|.
MergeGenericAddendSub(uint32_t addend,uint32_t sub,Instruction * inst)1349 bool MergeGenericAddendSub(uint32_t addend, uint32_t sub, Instruction* inst) {
1350 IRContext* context = inst->context();
1351 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1352 Instruction* sub_inst = def_use_mgr->GetDef(sub);
1353 if (sub_inst->opcode() != spv::Op::OpFSub &&
1354 sub_inst->opcode() != spv::Op::OpISub)
1355 return false;
1356 if (sub_inst->opcode() == spv::Op::OpFSub &&
1357 !sub_inst->IsFloatingPointFoldingAllowed())
1358 return false;
1359 if (addend != sub_inst->GetSingleWordInOperand(1)) return false;
1360 inst->SetOpcode(spv::Op::OpCopyObject);
1361 inst->SetInOperands(
1362 {{SPV_OPERAND_TYPE_ID, {sub_inst->GetSingleWordInOperand(0)}}});
1363 context->UpdateDefUse(inst);
1364 return true;
1365 }
1366
1367 // Folds addition of a subtraction where the subtrahend is equal to the
1368 // other addend. Return a copy of the minuend. Accepts generic (const and
1369 // non-const) operands.
1370 // Cases:
1371 // (a - b) + b = a
1372 // b + (a - b) = a
MergeGenericAddSubArithmetic()1373 FoldingRule MergeGenericAddSubArithmetic() {
1374 return [](IRContext* context, Instruction* inst,
1375 const std::vector<const analysis::Constant*>&) {
1376 assert(inst->opcode() == spv::Op::OpFAdd ||
1377 inst->opcode() == spv::Op::OpIAdd);
1378 const analysis::Type* type =
1379 context->get_type_mgr()->GetType(inst->type_id());
1380 bool uses_float = HasFloatingPoint(type);
1381 if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1382
1383 uint32_t width = ElementWidth(type);
1384 if (width != 32 && width != 64) return false;
1385
1386 uint32_t add_op0 = inst->GetSingleWordInOperand(0);
1387 uint32_t add_op1 = inst->GetSingleWordInOperand(1);
1388 if (MergeGenericAddendSub(add_op0, add_op1, inst)) return true;
1389 return MergeGenericAddendSub(add_op1, add_op0, inst);
1390 };
1391 }
1392
1393 // Helper function for FactorAddMuls. If |factor0_0| is the same as |factor1_0|,
1394 // generate |factor0_0| * (|factor0_1| + |factor1_1|).
FactorAddMulsOpnds(uint32_t factor0_0,uint32_t factor0_1,uint32_t factor1_0,uint32_t factor1_1,Instruction * inst)1395 bool FactorAddMulsOpnds(uint32_t factor0_0, uint32_t factor0_1,
1396 uint32_t factor1_0, uint32_t factor1_1,
1397 Instruction* inst) {
1398 IRContext* context = inst->context();
1399 if (factor0_0 != factor1_0) return false;
1400 InstructionBuilder ir_builder(
1401 context, inst,
1402 IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
1403 Instruction* new_add_inst = ir_builder.AddBinaryOp(
1404 inst->type_id(), inst->opcode(), factor0_1, factor1_1);
1405 inst->SetOpcode(inst->opcode() == spv::Op::OpFAdd ? spv::Op::OpFMul
1406 : spv::Op::OpIMul);
1407 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {factor0_0}},
1408 {SPV_OPERAND_TYPE_ID, {new_add_inst->result_id()}}});
1409 context->UpdateDefUse(inst);
1410 return true;
1411 }
1412
1413 // Perform the following factoring identity, handling all operand order
1414 // combinations: (a * b) + (a * c) = a * (b + c)
FactorAddMuls()1415 FoldingRule FactorAddMuls() {
1416 return [](IRContext* context, Instruction* inst,
1417 const std::vector<const analysis::Constant*>&) {
1418 assert(inst->opcode() == spv::Op::OpFAdd ||
1419 inst->opcode() == spv::Op::OpIAdd);
1420 const analysis::Type* type =
1421 context->get_type_mgr()->GetType(inst->type_id());
1422 bool uses_float = HasFloatingPoint(type);
1423 if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1424
1425 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1426 uint32_t add_op0 = inst->GetSingleWordInOperand(0);
1427 Instruction* add_op0_inst = def_use_mgr->GetDef(add_op0);
1428 if (add_op0_inst->opcode() != spv::Op::OpFMul &&
1429 add_op0_inst->opcode() != spv::Op::OpIMul)
1430 return false;
1431 uint32_t add_op1 = inst->GetSingleWordInOperand(1);
1432 Instruction* add_op1_inst = def_use_mgr->GetDef(add_op1);
1433 if (add_op1_inst->opcode() != spv::Op::OpFMul &&
1434 add_op1_inst->opcode() != spv::Op::OpIMul)
1435 return false;
1436
1437 // Only perform this optimization if both of the muls only have one use.
1438 // Otherwise this is a deoptimization in size and performance.
1439 if (def_use_mgr->NumUses(add_op0_inst) > 1) return false;
1440 if (def_use_mgr->NumUses(add_op1_inst) > 1) return false;
1441
1442 if (add_op0_inst->opcode() == spv::Op::OpFMul &&
1443 (!add_op0_inst->IsFloatingPointFoldingAllowed() ||
1444 !add_op1_inst->IsFloatingPointFoldingAllowed()))
1445 return false;
1446
1447 for (int i = 0; i < 2; i++) {
1448 for (int j = 0; j < 2; j++) {
1449 // Check if operand i in add_op0_inst matches operand j in add_op1_inst.
1450 if (FactorAddMulsOpnds(add_op0_inst->GetSingleWordInOperand(i),
1451 add_op0_inst->GetSingleWordInOperand(1 - i),
1452 add_op1_inst->GetSingleWordInOperand(j),
1453 add_op1_inst->GetSingleWordInOperand(1 - j),
1454 inst))
1455 return true;
1456 }
1457 }
1458 return false;
1459 };
1460 }
1461
IntMultipleBy1()1462 FoldingRule IntMultipleBy1() {
1463 return [](IRContext*, Instruction* inst,
1464 const std::vector<const analysis::Constant*>& constants) {
1465 assert(inst->opcode() == spv::Op::OpIMul &&
1466 "Wrong opcode. Should be OpIMul.");
1467 for (uint32_t i = 0; i < 2; i++) {
1468 if (constants[i] == nullptr) {
1469 continue;
1470 }
1471 const analysis::IntConstant* int_constant = constants[i]->AsIntConstant();
1472 if (int_constant) {
1473 uint32_t width = ElementWidth(int_constant->type());
1474 if (width != 32 && width != 64) return false;
1475 bool is_one = (width == 32) ? int_constant->GetU32BitValue() == 1u
1476 : int_constant->GetU64BitValue() == 1ull;
1477 if (is_one) {
1478 inst->SetOpcode(spv::Op::OpCopyObject);
1479 inst->SetInOperands(
1480 {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1 - i)}}});
1481 return true;
1482 }
1483 }
1484 }
1485 return false;
1486 };
1487 }
1488
1489 // Returns the number of elements that the |index|th in operand in |inst|
1490 // contributes to the result of |inst|. |inst| must be an
1491 // OpCompositeConstructInstruction.
GetNumOfElementsContributedByOperand(IRContext * context,const Instruction * inst,uint32_t index)1492 uint32_t GetNumOfElementsContributedByOperand(IRContext* context,
1493 const Instruction* inst,
1494 uint32_t index) {
1495 assert(inst->opcode() == spv::Op::OpCompositeConstruct);
1496 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1497 analysis::TypeManager* type_mgr = context->get_type_mgr();
1498
1499 analysis::Vector* result_type =
1500 type_mgr->GetType(inst->type_id())->AsVector();
1501 if (result_type == nullptr) {
1502 // If the result of the OpCompositeConstruct is not a vector then every
1503 // operands corresponds to a single element in the result.
1504 return 1;
1505 }
1506
1507 // If the result type is a vector then the operands are either scalars or
1508 // vectors. If it is a scalar, then it corresponds to a single element. If it
1509 // is a vector, then each element in the vector will be an element in the
1510 // result.
1511 uint32_t id = inst->GetSingleWordInOperand(index);
1512 Instruction* def = def_use_mgr->GetDef(id);
1513 analysis::Vector* type = type_mgr->GetType(def->type_id())->AsVector();
1514 if (type == nullptr) {
1515 return 1;
1516 }
1517 return type->element_count();
1518 }
1519
1520 // Returns the in-operands for an OpCompositeExtract instruction that are needed
1521 // to extract the |result_index|th element in the result of |inst| without using
1522 // the result of |inst|. Returns the empty vector if |result_index| is
1523 // out-of-bounds. |inst| must be an |OpCompositeConstruct| instruction.
GetExtractOperandsForElementOfCompositeConstruct(IRContext * context,const Instruction * inst,uint32_t result_index)1524 std::vector<Operand> GetExtractOperandsForElementOfCompositeConstruct(
1525 IRContext* context, const Instruction* inst, uint32_t result_index) {
1526 assert(inst->opcode() == spv::Op::OpCompositeConstruct);
1527 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1528 analysis::TypeManager* type_mgr = context->get_type_mgr();
1529
1530 analysis::Type* result_type = type_mgr->GetType(inst->type_id());
1531 if (result_type->AsVector() == nullptr) {
1532 if (result_index < inst->NumInOperands()) {
1533 uint32_t id = inst->GetSingleWordInOperand(result_index);
1534 return {Operand(SPV_OPERAND_TYPE_ID, {id})};
1535 }
1536 return {};
1537 }
1538
1539 // If the result type is a vector, then vector operands are concatenated.
1540 uint32_t total_element_count = 0;
1541 for (uint32_t idx = 0; idx < inst->NumInOperands(); ++idx) {
1542 uint32_t element_count =
1543 GetNumOfElementsContributedByOperand(context, inst, idx);
1544 total_element_count += element_count;
1545 if (result_index < total_element_count) {
1546 std::vector<Operand> operands;
1547 uint32_t id = inst->GetSingleWordInOperand(idx);
1548 Instruction* operand_def = def_use_mgr->GetDef(id);
1549 analysis::Type* operand_type = type_mgr->GetType(operand_def->type_id());
1550
1551 operands.push_back({SPV_OPERAND_TYPE_ID, {id}});
1552 if (operand_type->AsVector()) {
1553 uint32_t start_index_of_id = total_element_count - element_count;
1554 uint32_t index_into_id = result_index - start_index_of_id;
1555 operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {index_into_id}});
1556 }
1557 return operands;
1558 }
1559 }
1560 return {};
1561 }
1562
CompositeConstructFeedingExtract(IRContext * context,Instruction * inst,const std::vector<const analysis::Constant * > &)1563 bool CompositeConstructFeedingExtract(
1564 IRContext* context, Instruction* inst,
1565 const std::vector<const analysis::Constant*>&) {
1566 // If the input to an OpCompositeExtract is an OpCompositeConstruct,
1567 // then we can simply use the appropriate element in the construction.
1568 assert(inst->opcode() == spv::Op::OpCompositeExtract &&
1569 "Wrong opcode. Should be OpCompositeExtract.");
1570 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1571
1572 // If there are no index operands, then this rule cannot do anything.
1573 if (inst->NumInOperands() <= 1) {
1574 return false;
1575 }
1576
1577 uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1578 Instruction* cinst = def_use_mgr->GetDef(cid);
1579
1580 if (cinst->opcode() != spv::Op::OpCompositeConstruct) {
1581 return false;
1582 }
1583
1584 uint32_t index_into_result = inst->GetSingleWordInOperand(1);
1585 std::vector<Operand> operands =
1586 GetExtractOperandsForElementOfCompositeConstruct(context, cinst,
1587 index_into_result);
1588
1589 if (operands.empty()) {
1590 return false;
1591 }
1592
1593 // Add the remaining indices for extraction.
1594 for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
1595 operands.push_back(
1596 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {inst->GetSingleWordInOperand(i)}});
1597 }
1598
1599 if (operands.size() == 1) {
1600 // If there were no extra indices, then we have the final object. No need
1601 // to extract any more.
1602 inst->SetOpcode(spv::Op::OpCopyObject);
1603 }
1604
1605 inst->SetInOperands(std::move(operands));
1606 return true;
1607 }
1608
1609 // Walks the indexes chain from |start| to |end| of an OpCompositeInsert or
1610 // OpCompositeExtract instruction, and returns the type id of the final element
1611 // being accessed. Returns 0 if a valid type could not be found.
GetElementType(uint32_t type_id,Instruction::iterator start,Instruction::iterator end,const analysis::DefUseManager * def_use_manager)1612 uint32_t GetElementType(uint32_t type_id, Instruction::iterator start,
1613 Instruction::iterator end,
1614 const analysis::DefUseManager* def_use_manager) {
1615 for (auto index : make_range(std::move(start), std::move(end))) {
1616 const Instruction* type_inst = def_use_manager->GetDef(type_id);
1617 assert(index.type == SPV_OPERAND_TYPE_LITERAL_INTEGER &&
1618 index.words.size() == 1);
1619 if (type_inst->opcode() == spv::Op::OpTypeArray) {
1620 type_id = type_inst->GetSingleWordInOperand(0);
1621 } else if (type_inst->opcode() == spv::Op::OpTypeMatrix) {
1622 type_id = type_inst->GetSingleWordInOperand(0);
1623 } else if (type_inst->opcode() == spv::Op::OpTypeStruct) {
1624 type_id = type_inst->GetSingleWordInOperand(index.words[0]);
1625 } else {
1626 return 0;
1627 }
1628 }
1629 return type_id;
1630 }
1631
1632 // Returns true of |inst_1| and |inst_2| have the same indexes that will be used
1633 // to index into a composite object, excluding the last index. The two
1634 // instructions must have the same opcode, and be either OpCompositeExtract or
1635 // OpCompositeInsert instructions.
HaveSameIndexesExceptForLast(Instruction * inst_1,Instruction * inst_2)1636 bool HaveSameIndexesExceptForLast(Instruction* inst_1, Instruction* inst_2) {
1637 assert(inst_1->opcode() == inst_2->opcode() &&
1638 "Expecting the opcodes to be the same.");
1639 assert((inst_1->opcode() == spv::Op::OpCompositeInsert ||
1640 inst_1->opcode() == spv::Op::OpCompositeExtract) &&
1641 "Instructions must be OpCompositeInsert or OpCompositeExtract.");
1642
1643 if (inst_1->NumInOperands() != inst_2->NumInOperands()) {
1644 return false;
1645 }
1646
1647 uint32_t first_index_position =
1648 (inst_1->opcode() == spv::Op::OpCompositeInsert ? 2 : 1);
1649 for (uint32_t i = first_index_position; i < inst_1->NumInOperands() - 1;
1650 i++) {
1651 if (inst_1->GetSingleWordInOperand(i) !=
1652 inst_2->GetSingleWordInOperand(i)) {
1653 return false;
1654 }
1655 }
1656 return true;
1657 }
1658
1659 // If the OpCompositeConstruct is simply putting back together elements that
1660 // where extracted from the same source, we can simply reuse the source.
1661 //
1662 // This is a common code pattern because of the way that scalar replacement
1663 // works.
CompositeExtractFeedingConstruct(IRContext * context,Instruction * inst,const std::vector<const analysis::Constant * > &)1664 bool CompositeExtractFeedingConstruct(
1665 IRContext* context, Instruction* inst,
1666 const std::vector<const analysis::Constant*>&) {
1667 assert(inst->opcode() == spv::Op::OpCompositeConstruct &&
1668 "Wrong opcode. Should be OpCompositeConstruct.");
1669 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1670 uint32_t original_id = 0;
1671
1672 if (inst->NumInOperands() == 0) {
1673 // The struct being constructed has no members.
1674 return false;
1675 }
1676
1677 // Check each element to make sure they are:
1678 // - extractions
1679 // - extracting the same position they are inserting
1680 // - all extract from the same id.
1681 Instruction* first_element_inst = nullptr;
1682 for (uint32_t i = 0; i < inst->NumInOperands(); ++i) {
1683 const uint32_t element_id = inst->GetSingleWordInOperand(i);
1684 Instruction* element_inst = def_use_mgr->GetDef(element_id);
1685 if (first_element_inst == nullptr) {
1686 first_element_inst = element_inst;
1687 }
1688
1689 if (element_inst->opcode() != spv::Op::OpCompositeExtract) {
1690 return false;
1691 }
1692
1693 if (!HaveSameIndexesExceptForLast(element_inst, first_element_inst)) {
1694 return false;
1695 }
1696
1697 if (element_inst->GetSingleWordInOperand(element_inst->NumInOperands() -
1698 1) != i) {
1699 return false;
1700 }
1701
1702 if (i == 0) {
1703 original_id =
1704 element_inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1705 } else if (original_id !=
1706 element_inst->GetSingleWordInOperand(kExtractCompositeIdInIdx)) {
1707 return false;
1708 }
1709 }
1710
1711 // The last check it to see that the object being extracted from is the
1712 // correct type.
1713 Instruction* original_inst = def_use_mgr->GetDef(original_id);
1714 uint32_t original_type_id =
1715 GetElementType(original_inst->type_id(), first_element_inst->begin() + 3,
1716 first_element_inst->end() - 1, def_use_mgr);
1717
1718 if (inst->type_id() != original_type_id) {
1719 return false;
1720 }
1721
1722 if (first_element_inst->NumInOperands() == 2) {
1723 // Simplify by using the original object.
1724 inst->SetOpcode(spv::Op::OpCopyObject);
1725 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {original_id}}});
1726 return true;
1727 }
1728
1729 // Copies the original id and all indexes except for the last to the new
1730 // extract instruction.
1731 inst->SetOpcode(spv::Op::OpCompositeExtract);
1732 inst->SetInOperands(std::vector<Operand>(first_element_inst->begin() + 2,
1733 first_element_inst->end() - 1));
1734 return true;
1735 }
1736
InsertFeedingExtract()1737 FoldingRule InsertFeedingExtract() {
1738 return [](IRContext* context, Instruction* inst,
1739 const std::vector<const analysis::Constant*>&) {
1740 assert(inst->opcode() == spv::Op::OpCompositeExtract &&
1741 "Wrong opcode. Should be OpCompositeExtract.");
1742 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1743 uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1744 Instruction* cinst = def_use_mgr->GetDef(cid);
1745
1746 if (cinst->opcode() != spv::Op::OpCompositeInsert) {
1747 return false;
1748 }
1749
1750 // Find the first position where the list of insert and extract indicies
1751 // differ, if at all.
1752 uint32_t i;
1753 for (i = 1; i < inst->NumInOperands(); ++i) {
1754 if (i + 1 >= cinst->NumInOperands()) {
1755 break;
1756 }
1757
1758 if (inst->GetSingleWordInOperand(i) !=
1759 cinst->GetSingleWordInOperand(i + 1)) {
1760 break;
1761 }
1762 }
1763
1764 // We are extracting the element that was inserted.
1765 if (i == inst->NumInOperands() && i + 1 == cinst->NumInOperands()) {
1766 inst->SetOpcode(spv::Op::OpCopyObject);
1767 inst->SetInOperands(
1768 {{SPV_OPERAND_TYPE_ID,
1769 {cinst->GetSingleWordInOperand(kInsertObjectIdInIdx)}}});
1770 return true;
1771 }
1772
1773 // Extracting the value that was inserted along with values for the base
1774 // composite. Cannot do anything.
1775 if (i == inst->NumInOperands()) {
1776 return false;
1777 }
1778
1779 // Extracting an element of the value that was inserted. Extract from
1780 // that value directly.
1781 if (i + 1 == cinst->NumInOperands()) {
1782 std::vector<Operand> operands;
1783 operands.push_back(
1784 {SPV_OPERAND_TYPE_ID,
1785 {cinst->GetSingleWordInOperand(kInsertObjectIdInIdx)}});
1786 for (; i < inst->NumInOperands(); ++i) {
1787 operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
1788 {inst->GetSingleWordInOperand(i)}});
1789 }
1790 inst->SetInOperands(std::move(operands));
1791 return true;
1792 }
1793
1794 // Extracting a value that is disjoint from the element being inserted.
1795 // Rewrite the extract to use the composite input to the insert.
1796 std::vector<Operand> operands;
1797 operands.push_back(
1798 {SPV_OPERAND_TYPE_ID,
1799 {cinst->GetSingleWordInOperand(kInsertCompositeIdInIdx)}});
1800 for (i = 1; i < inst->NumInOperands(); ++i) {
1801 operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
1802 {inst->GetSingleWordInOperand(i)}});
1803 }
1804 inst->SetInOperands(std::move(operands));
1805 return true;
1806 };
1807 }
1808
1809 // When a VectorShuffle is feeding an Extract, we can extract from one of the
1810 // operands of the VectorShuffle. We just need to adjust the index in the
1811 // extract instruction.
VectorShuffleFeedingExtract()1812 FoldingRule VectorShuffleFeedingExtract() {
1813 return [](IRContext* context, Instruction* inst,
1814 const std::vector<const analysis::Constant*>&) {
1815 assert(inst->opcode() == spv::Op::OpCompositeExtract &&
1816 "Wrong opcode. Should be OpCompositeExtract.");
1817 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1818 analysis::TypeManager* type_mgr = context->get_type_mgr();
1819 uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1820 Instruction* cinst = def_use_mgr->GetDef(cid);
1821
1822 if (cinst->opcode() != spv::Op::OpVectorShuffle) {
1823 return false;
1824 }
1825
1826 // Find the size of the first vector operand of the VectorShuffle
1827 Instruction* first_input =
1828 def_use_mgr->GetDef(cinst->GetSingleWordInOperand(0));
1829 analysis::Type* first_input_type =
1830 type_mgr->GetType(first_input->type_id());
1831 assert(first_input_type->AsVector() &&
1832 "Input to vector shuffle should be vectors.");
1833 uint32_t first_input_size = first_input_type->AsVector()->element_count();
1834
1835 // Get index of the element the vector shuffle is placing in the position
1836 // being extracted.
1837 uint32_t new_index =
1838 cinst->GetSingleWordInOperand(2 + inst->GetSingleWordInOperand(1));
1839
1840 // Extracting an undefined value so fold this extract into an undef.
1841 const uint32_t undef_literal_value = 0xffffffff;
1842 if (new_index == undef_literal_value) {
1843 inst->SetOpcode(spv::Op::OpUndef);
1844 inst->SetInOperands({});
1845 return true;
1846 }
1847
1848 // Get the id of the of the vector the elemtent comes from, and update the
1849 // index if needed.
1850 uint32_t new_vector = 0;
1851 if (new_index < first_input_size) {
1852 new_vector = cinst->GetSingleWordInOperand(0);
1853 } else {
1854 new_vector = cinst->GetSingleWordInOperand(1);
1855 new_index -= first_input_size;
1856 }
1857
1858 // Update the extract instruction.
1859 inst->SetInOperand(kExtractCompositeIdInIdx, {new_vector});
1860 inst->SetInOperand(1, {new_index});
1861 return true;
1862 };
1863 }
1864
1865 // When an FMix with is feeding an Extract that extracts an element whose
1866 // corresponding |a| in the FMix is 0 or 1, we can extract from one of the
1867 // operands of the FMix.
FMixFeedingExtract()1868 FoldingRule FMixFeedingExtract() {
1869 return [](IRContext* context, Instruction* inst,
1870 const std::vector<const analysis::Constant*>&) {
1871 assert(inst->opcode() == spv::Op::OpCompositeExtract &&
1872 "Wrong opcode. Should be OpCompositeExtract.");
1873 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1874 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1875
1876 uint32_t composite_id =
1877 inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1878 Instruction* composite_inst = def_use_mgr->GetDef(composite_id);
1879
1880 if (composite_inst->opcode() != spv::Op::OpExtInst) {
1881 return false;
1882 }
1883
1884 uint32_t inst_set_id =
1885 context->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
1886
1887 if (composite_inst->GetSingleWordInOperand(kExtInstSetIdInIdx) !=
1888 inst_set_id ||
1889 composite_inst->GetSingleWordInOperand(kExtInstInstructionInIdx) !=
1890 GLSLstd450FMix) {
1891 return false;
1892 }
1893
1894 // Get the |a| for the FMix instruction.
1895 uint32_t a_id = composite_inst->GetSingleWordInOperand(kFMixAIdInIdx);
1896 std::unique_ptr<Instruction> a(inst->Clone(context));
1897 a->SetInOperand(kExtractCompositeIdInIdx, {a_id});
1898 context->get_instruction_folder().FoldInstruction(a.get());
1899
1900 if (a->opcode() != spv::Op::OpCopyObject) {
1901 return false;
1902 }
1903
1904 const analysis::Constant* a_const =
1905 const_mgr->FindDeclaredConstant(a->GetSingleWordInOperand(0));
1906
1907 if (!a_const) {
1908 return false;
1909 }
1910
1911 bool use_x = false;
1912
1913 assert(a_const->type()->AsFloat());
1914 double element_value = a_const->GetValueAsDouble();
1915 if (element_value == 0.0) {
1916 use_x = true;
1917 } else if (element_value == 1.0) {
1918 use_x = false;
1919 } else {
1920 return false;
1921 }
1922
1923 // Get the id of the of the vector the element comes from.
1924 uint32_t new_vector = 0;
1925 if (use_x) {
1926 new_vector = composite_inst->GetSingleWordInOperand(kFMixXIdInIdx);
1927 } else {
1928 new_vector = composite_inst->GetSingleWordInOperand(kFMixYIdInIdx);
1929 }
1930
1931 // Update the extract instruction.
1932 inst->SetInOperand(kExtractCompositeIdInIdx, {new_vector});
1933 return true;
1934 };
1935 }
1936
1937 // Returns the number of elements in the composite type |type|. Returns 0 if
1938 // |type| is a scalar value. Return UINT32_MAX when the size is unknown at
1939 // compile time.
GetNumberOfElements(const analysis::Type * type)1940 uint32_t GetNumberOfElements(const analysis::Type* type) {
1941 if (auto* vector_type = type->AsVector()) {
1942 return vector_type->element_count();
1943 }
1944 if (auto* matrix_type = type->AsMatrix()) {
1945 return matrix_type->element_count();
1946 }
1947 if (auto* struct_type = type->AsStruct()) {
1948 return static_cast<uint32_t>(struct_type->element_types().size());
1949 }
1950 if (auto* array_type = type->AsArray()) {
1951 if (array_type->length_info().words[0] ==
1952 analysis::Array::LengthInfo::kConstant &&
1953 array_type->length_info().words.size() == 2) {
1954 return array_type->length_info().words[1];
1955 }
1956 return UINT32_MAX;
1957 }
1958 return 0;
1959 }
1960
1961 // Returns a map with the set of values that were inserted into an object by
1962 // the chain of OpCompositeInsertInstruction starting with |inst|.
1963 // The map will map the index to the value inserted at that index. An empty map
1964 // will be returned if the map could not be properly generated.
GetInsertedValues(Instruction * inst)1965 std::map<uint32_t, uint32_t> GetInsertedValues(Instruction* inst) {
1966 analysis::DefUseManager* def_use_mgr = inst->context()->get_def_use_mgr();
1967 std::map<uint32_t, uint32_t> values_inserted;
1968 Instruction* current_inst = inst;
1969 while (current_inst->opcode() == spv::Op::OpCompositeInsert) {
1970 if (current_inst->NumInOperands() > inst->NumInOperands()) {
1971 // This is to catch the case
1972 // %2 = OpCompositeInsert %m2x2int %v2int_1_0 %m2x2int_undef 0
1973 // %3 = OpCompositeInsert %m2x2int %int_4 %2 0 0
1974 // %4 = OpCompositeInsert %m2x2int %v2int_2_3 %3 1
1975 // In this case we cannot do a single construct to get the matrix.
1976 uint32_t partially_inserted_element_index =
1977 current_inst->GetSingleWordInOperand(inst->NumInOperands() - 1);
1978 if (values_inserted.count(partially_inserted_element_index) == 0)
1979 return {};
1980 }
1981 if (HaveSameIndexesExceptForLast(inst, current_inst)) {
1982 values_inserted.insert(
1983 {current_inst->GetSingleWordInOperand(current_inst->NumInOperands() -
1984 1),
1985 current_inst->GetSingleWordInOperand(kInsertObjectIdInIdx)});
1986 }
1987 current_inst = def_use_mgr->GetDef(
1988 current_inst->GetSingleWordInOperand(kInsertCompositeIdInIdx));
1989 }
1990 return values_inserted;
1991 }
1992
1993 // Returns true of there is an entry in |values_inserted| for every element of
1994 // |Type|.
DoInsertedValuesCoverEntireObject(const analysis::Type * type,std::map<uint32_t,uint32_t> & values_inserted)1995 bool DoInsertedValuesCoverEntireObject(
1996 const analysis::Type* type, std::map<uint32_t, uint32_t>& values_inserted) {
1997 uint32_t container_size = GetNumberOfElements(type);
1998 if (container_size != values_inserted.size()) {
1999 return false;
2000 }
2001
2002 if (values_inserted.rbegin()->first >= container_size) {
2003 return false;
2004 }
2005 return true;
2006 }
2007
2008 // Returns id of the type of the element that immediately contains the element
2009 // being inserted by the OpCompositeInsert instruction |inst|. Returns 0 if it
2010 // could not be found.
GetContainerTypeId(Instruction * inst)2011 uint32_t GetContainerTypeId(Instruction* inst) {
2012 assert(inst->opcode() == spv::Op::OpCompositeInsert);
2013 analysis::DefUseManager* def_use_manager = inst->context()->get_def_use_mgr();
2014 uint32_t container_type_id = GetElementType(
2015 inst->type_id(), inst->begin() + 4, inst->end() - 1, def_use_manager);
2016 return container_type_id;
2017 }
2018
2019 // Returns an OpCompositeConstruct instruction that build an object with
2020 // |type_id| out of the values in |values_inserted|. Each value will be
2021 // placed at the index corresponding to the value. The new instruction will
2022 // be placed before |insert_before|.
BuildCompositeConstruct(uint32_t type_id,const std::map<uint32_t,uint32_t> & values_inserted,Instruction * insert_before)2023 Instruction* BuildCompositeConstruct(
2024 uint32_t type_id, const std::map<uint32_t, uint32_t>& values_inserted,
2025 Instruction* insert_before) {
2026 InstructionBuilder ir_builder(
2027 insert_before->context(), insert_before,
2028 IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
2029
2030 std::vector<uint32_t> ids_in_order;
2031 for (auto it : values_inserted) {
2032 ids_in_order.push_back(it.second);
2033 }
2034 Instruction* construct =
2035 ir_builder.AddCompositeConstruct(type_id, ids_in_order);
2036 return construct;
2037 }
2038
2039 // Replaces the OpCompositeInsert |inst| that inserts |construct| into the same
2040 // object as |inst| with final index removed. If the resulting
2041 // OpCompositeInsert instruction would have no remaining indexes, the
2042 // instruction is replaced with an OpCopyObject instead.
InsertConstructedObject(Instruction * inst,const Instruction * construct)2043 void InsertConstructedObject(Instruction* inst, const Instruction* construct) {
2044 if (inst->NumInOperands() == 3) {
2045 inst->SetOpcode(spv::Op::OpCopyObject);
2046 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {construct->result_id()}}});
2047 } else {
2048 inst->SetInOperand(kInsertObjectIdInIdx, {construct->result_id()});
2049 inst->RemoveOperand(inst->NumOperands() - 1);
2050 }
2051 }
2052
2053 // Replaces a series of |OpCompositeInsert| instruction that cover the entire
2054 // object with an |OpCompositeConstruct|.
CompositeInsertToCompositeConstruct(IRContext * context,Instruction * inst,const std::vector<const analysis::Constant * > &)2055 bool CompositeInsertToCompositeConstruct(
2056 IRContext* context, Instruction* inst,
2057 const std::vector<const analysis::Constant*>&) {
2058 assert(inst->opcode() == spv::Op::OpCompositeInsert &&
2059 "Wrong opcode. Should be OpCompositeInsert.");
2060 if (inst->NumInOperands() < 3) return false;
2061
2062 std::map<uint32_t, uint32_t> values_inserted = GetInsertedValues(inst);
2063 uint32_t container_type_id = GetContainerTypeId(inst);
2064 if (container_type_id == 0) {
2065 return false;
2066 }
2067
2068 analysis::TypeManager* type_mgr = context->get_type_mgr();
2069 const analysis::Type* container_type = type_mgr->GetType(container_type_id);
2070 assert(container_type && "GetContainerTypeId returned a bad id.");
2071 if (!DoInsertedValuesCoverEntireObject(container_type, values_inserted)) {
2072 return false;
2073 }
2074
2075 Instruction* construct =
2076 BuildCompositeConstruct(container_type_id, values_inserted, inst);
2077 InsertConstructedObject(inst, construct);
2078 return true;
2079 }
2080
RedundantPhi()2081 FoldingRule RedundantPhi() {
2082 // An OpPhi instruction where all values are the same or the result of the phi
2083 // itself, can be replaced by the value itself.
2084 return [](IRContext*, Instruction* inst,
2085 const std::vector<const analysis::Constant*>&) {
2086 assert(inst->opcode() == spv::Op::OpPhi &&
2087 "Wrong opcode. Should be OpPhi.");
2088
2089 uint32_t incoming_value = 0;
2090
2091 for (uint32_t i = 0; i < inst->NumInOperands(); i += 2) {
2092 uint32_t op_id = inst->GetSingleWordInOperand(i);
2093 if (op_id == inst->result_id()) {
2094 continue;
2095 }
2096
2097 if (incoming_value == 0) {
2098 incoming_value = op_id;
2099 } else if (op_id != incoming_value) {
2100 // Found two possible value. Can't simplify.
2101 return false;
2102 }
2103 }
2104
2105 if (incoming_value == 0) {
2106 // Code looks invalid. Don't do anything.
2107 return false;
2108 }
2109
2110 // We have a single incoming value. Simplify using that value.
2111 inst->SetOpcode(spv::Op::OpCopyObject);
2112 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {incoming_value}}});
2113 return true;
2114 };
2115 }
2116
BitCastScalarOrVector()2117 FoldingRule BitCastScalarOrVector() {
2118 return [](IRContext* context, Instruction* inst,
2119 const std::vector<const analysis::Constant*>& constants) {
2120 assert(inst->opcode() == spv::Op::OpBitcast && constants.size() == 1);
2121 if (constants[0] == nullptr) return false;
2122
2123 const analysis::Type* type =
2124 context->get_type_mgr()->GetType(inst->type_id());
2125 if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
2126 return false;
2127
2128 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
2129 std::vector<uint32_t> words =
2130 GetWordsFromNumericScalarOrVectorConstant(const_mgr, constants[0]);
2131 if (words.size() == 0) return false;
2132
2133 const analysis::Constant* bitcasted_constant =
2134 ConvertWordsToNumericScalarOrVectorConstant(const_mgr, words, type);
2135 if (!bitcasted_constant) return false;
2136
2137 auto new_feeder_id =
2138 const_mgr->GetDefiningInstruction(bitcasted_constant, inst->type_id())
2139 ->result_id();
2140 inst->SetOpcode(spv::Op::OpCopyObject);
2141 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {new_feeder_id}}});
2142 return true;
2143 };
2144 }
2145
RedundantSelect()2146 FoldingRule RedundantSelect() {
2147 // An OpSelect instruction where both values are the same or the condition is
2148 // constant can be replaced by one of the values
2149 return [](IRContext*, Instruction* inst,
2150 const std::vector<const analysis::Constant*>& constants) {
2151 assert(inst->opcode() == spv::Op::OpSelect &&
2152 "Wrong opcode. Should be OpSelect.");
2153 assert(inst->NumInOperands() == 3);
2154 assert(constants.size() == 3);
2155
2156 uint32_t true_id = inst->GetSingleWordInOperand(1);
2157 uint32_t false_id = inst->GetSingleWordInOperand(2);
2158
2159 if (true_id == false_id) {
2160 // Both results are the same, condition doesn't matter
2161 inst->SetOpcode(spv::Op::OpCopyObject);
2162 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}});
2163 return true;
2164 } else if (constants[0]) {
2165 const analysis::Type* type = constants[0]->type();
2166 if (type->AsBool()) {
2167 // Scalar constant value, select the corresponding value.
2168 inst->SetOpcode(spv::Op::OpCopyObject);
2169 if (constants[0]->AsNullConstant() ||
2170 !constants[0]->AsBoolConstant()->value()) {
2171 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {false_id}}});
2172 } else {
2173 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}});
2174 }
2175 return true;
2176 } else {
2177 assert(type->AsVector());
2178 if (constants[0]->AsNullConstant()) {
2179 // All values come from false id.
2180 inst->SetOpcode(spv::Op::OpCopyObject);
2181 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {false_id}}});
2182 return true;
2183 } else {
2184 // Convert to a vector shuffle.
2185 std::vector<Operand> ops;
2186 ops.push_back({SPV_OPERAND_TYPE_ID, {true_id}});
2187 ops.push_back({SPV_OPERAND_TYPE_ID, {false_id}});
2188 const analysis::VectorConstant* vector_const =
2189 constants[0]->AsVectorConstant();
2190 uint32_t size =
2191 static_cast<uint32_t>(vector_const->GetComponents().size());
2192 for (uint32_t i = 0; i != size; ++i) {
2193 const analysis::Constant* component =
2194 vector_const->GetComponents()[i];
2195 if (component->AsNullConstant() ||
2196 !component->AsBoolConstant()->value()) {
2197 // Selecting from the false vector which is the second input
2198 // vector to the shuffle. Offset the index by |size|.
2199 ops.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i + size}});
2200 } else {
2201 // Selecting from true vector which is the first input vector to
2202 // the shuffle.
2203 ops.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}});
2204 }
2205 }
2206
2207 inst->SetOpcode(spv::Op::OpVectorShuffle);
2208 inst->SetInOperands(std::move(ops));
2209 return true;
2210 }
2211 }
2212 }
2213
2214 return false;
2215 };
2216 }
2217
2218 enum class FloatConstantKind { Unknown, Zero, One };
2219
getFloatConstantKind(const analysis::Constant * constant)2220 FloatConstantKind getFloatConstantKind(const analysis::Constant* constant) {
2221 if (constant == nullptr) {
2222 return FloatConstantKind::Unknown;
2223 }
2224
2225 assert(HasFloatingPoint(constant->type()) && "Unexpected constant type");
2226
2227 if (constant->AsNullConstant()) {
2228 return FloatConstantKind::Zero;
2229 } else if (const analysis::VectorConstant* vc =
2230 constant->AsVectorConstant()) {
2231 const std::vector<const analysis::Constant*>& components =
2232 vc->GetComponents();
2233 assert(!components.empty());
2234
2235 FloatConstantKind kind = getFloatConstantKind(components[0]);
2236
2237 for (size_t i = 1; i < components.size(); ++i) {
2238 if (getFloatConstantKind(components[i]) != kind) {
2239 return FloatConstantKind::Unknown;
2240 }
2241 }
2242
2243 return kind;
2244 } else if (const analysis::FloatConstant* fc = constant->AsFloatConstant()) {
2245 if (fc->IsZero()) return FloatConstantKind::Zero;
2246
2247 uint32_t width = fc->type()->AsFloat()->width();
2248 if (width != 32 && width != 64) return FloatConstantKind::Unknown;
2249
2250 double value = (width == 64) ? fc->GetDoubleValue() : fc->GetFloatValue();
2251
2252 if (value == 0.0) {
2253 return FloatConstantKind::Zero;
2254 } else if (value == 1.0) {
2255 return FloatConstantKind::One;
2256 } else {
2257 return FloatConstantKind::Unknown;
2258 }
2259 } else {
2260 return FloatConstantKind::Unknown;
2261 }
2262 }
2263
RedundantFAdd()2264 FoldingRule RedundantFAdd() {
2265 return [](IRContext*, Instruction* inst,
2266 const std::vector<const analysis::Constant*>& constants) {
2267 assert(inst->opcode() == spv::Op::OpFAdd &&
2268 "Wrong opcode. Should be OpFAdd.");
2269 assert(constants.size() == 2);
2270
2271 if (!inst->IsFloatingPointFoldingAllowed()) {
2272 return false;
2273 }
2274
2275 FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
2276 FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
2277
2278 if (kind0 == FloatConstantKind::Zero || kind1 == FloatConstantKind::Zero) {
2279 inst->SetOpcode(spv::Op::OpCopyObject);
2280 inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
2281 {inst->GetSingleWordInOperand(
2282 kind0 == FloatConstantKind::Zero ? 1 : 0)}}});
2283 return true;
2284 }
2285
2286 return false;
2287 };
2288 }
2289
RedundantFSub()2290 FoldingRule RedundantFSub() {
2291 return [](IRContext*, Instruction* inst,
2292 const std::vector<const analysis::Constant*>& constants) {
2293 assert(inst->opcode() == spv::Op::OpFSub &&
2294 "Wrong opcode. Should be OpFSub.");
2295 assert(constants.size() == 2);
2296
2297 if (!inst->IsFloatingPointFoldingAllowed()) {
2298 return false;
2299 }
2300
2301 FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
2302 FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
2303
2304 if (kind0 == FloatConstantKind::Zero) {
2305 inst->SetOpcode(spv::Op::OpFNegate);
2306 inst->SetInOperands(
2307 {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1)}}});
2308 return true;
2309 }
2310
2311 if (kind1 == FloatConstantKind::Zero) {
2312 inst->SetOpcode(spv::Op::OpCopyObject);
2313 inst->SetInOperands(
2314 {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
2315 return true;
2316 }
2317
2318 return false;
2319 };
2320 }
2321
RedundantFMul()2322 FoldingRule RedundantFMul() {
2323 return [](IRContext*, Instruction* inst,
2324 const std::vector<const analysis::Constant*>& constants) {
2325 assert(inst->opcode() == spv::Op::OpFMul &&
2326 "Wrong opcode. Should be OpFMul.");
2327 assert(constants.size() == 2);
2328
2329 if (!inst->IsFloatingPointFoldingAllowed()) {
2330 return false;
2331 }
2332
2333 FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
2334 FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
2335
2336 if (kind0 == FloatConstantKind::Zero || kind1 == FloatConstantKind::Zero) {
2337 inst->SetOpcode(spv::Op::OpCopyObject);
2338 inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
2339 {inst->GetSingleWordInOperand(
2340 kind0 == FloatConstantKind::Zero ? 0 : 1)}}});
2341 return true;
2342 }
2343
2344 if (kind0 == FloatConstantKind::One || kind1 == FloatConstantKind::One) {
2345 inst->SetOpcode(spv::Op::OpCopyObject);
2346 inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
2347 {inst->GetSingleWordInOperand(
2348 kind0 == FloatConstantKind::One ? 1 : 0)}}});
2349 return true;
2350 }
2351
2352 return false;
2353 };
2354 }
2355
RedundantFDiv()2356 FoldingRule RedundantFDiv() {
2357 return [](IRContext*, Instruction* inst,
2358 const std::vector<const analysis::Constant*>& constants) {
2359 assert(inst->opcode() == spv::Op::OpFDiv &&
2360 "Wrong opcode. Should be OpFDiv.");
2361 assert(constants.size() == 2);
2362
2363 if (!inst->IsFloatingPointFoldingAllowed()) {
2364 return false;
2365 }
2366
2367 FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
2368 FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
2369
2370 if (kind0 == FloatConstantKind::Zero) {
2371 inst->SetOpcode(spv::Op::OpCopyObject);
2372 inst->SetInOperands(
2373 {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
2374 return true;
2375 }
2376
2377 if (kind1 == FloatConstantKind::One) {
2378 inst->SetOpcode(spv::Op::OpCopyObject);
2379 inst->SetInOperands(
2380 {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
2381 return true;
2382 }
2383
2384 return false;
2385 };
2386 }
2387
RedundantFMix()2388 FoldingRule RedundantFMix() {
2389 return [](IRContext* context, Instruction* inst,
2390 const std::vector<const analysis::Constant*>& constants) {
2391 assert(inst->opcode() == spv::Op::OpExtInst &&
2392 "Wrong opcode. Should be OpExtInst.");
2393
2394 if (!inst->IsFloatingPointFoldingAllowed()) {
2395 return false;
2396 }
2397
2398 uint32_t instSetId =
2399 context->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
2400
2401 if (inst->GetSingleWordInOperand(kExtInstSetIdInIdx) == instSetId &&
2402 inst->GetSingleWordInOperand(kExtInstInstructionInIdx) ==
2403 GLSLstd450FMix) {
2404 assert(constants.size() == 5);
2405
2406 FloatConstantKind kind4 = getFloatConstantKind(constants[4]);
2407
2408 if (kind4 == FloatConstantKind::Zero || kind4 == FloatConstantKind::One) {
2409 inst->SetOpcode(spv::Op::OpCopyObject);
2410 inst->SetInOperands(
2411 {{SPV_OPERAND_TYPE_ID,
2412 {inst->GetSingleWordInOperand(kind4 == FloatConstantKind::Zero
2413 ? kFMixXIdInIdx
2414 : kFMixYIdInIdx)}}});
2415 return true;
2416 }
2417 }
2418
2419 return false;
2420 };
2421 }
2422
2423 // This rule handles addition of zero for integers.
RedundantIAdd()2424 FoldingRule RedundantIAdd() {
2425 return [](IRContext* context, Instruction* inst,
2426 const std::vector<const analysis::Constant*>& constants) {
2427 assert(inst->opcode() == spv::Op::OpIAdd &&
2428 "Wrong opcode. Should be OpIAdd.");
2429
2430 uint32_t operand = std::numeric_limits<uint32_t>::max();
2431 const analysis::Type* operand_type = nullptr;
2432 if (constants[0] && constants[0]->IsZero()) {
2433 operand = inst->GetSingleWordInOperand(1);
2434 operand_type = constants[0]->type();
2435 } else if (constants[1] && constants[1]->IsZero()) {
2436 operand = inst->GetSingleWordInOperand(0);
2437 operand_type = constants[1]->type();
2438 }
2439
2440 if (operand != std::numeric_limits<uint32_t>::max()) {
2441 const analysis::Type* inst_type =
2442 context->get_type_mgr()->GetType(inst->type_id());
2443 if (inst_type->IsSame(operand_type)) {
2444 inst->SetOpcode(spv::Op::OpCopyObject);
2445 } else {
2446 inst->SetOpcode(spv::Op::OpBitcast);
2447 }
2448 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {operand}}});
2449 return true;
2450 }
2451 return false;
2452 };
2453 }
2454
2455 // This rule look for a dot with a constant vector containing a single 1 and
2456 // the rest 0s. This is the same as doing an extract.
DotProductDoingExtract()2457 FoldingRule DotProductDoingExtract() {
2458 return [](IRContext* context, Instruction* inst,
2459 const std::vector<const analysis::Constant*>& constants) {
2460 assert(inst->opcode() == spv::Op::OpDot &&
2461 "Wrong opcode. Should be OpDot.");
2462
2463 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
2464
2465 if (!inst->IsFloatingPointFoldingAllowed()) {
2466 return false;
2467 }
2468
2469 for (int i = 0; i < 2; ++i) {
2470 if (!constants[i]) {
2471 continue;
2472 }
2473
2474 const analysis::Vector* vector_type = constants[i]->type()->AsVector();
2475 assert(vector_type && "Inputs to OpDot must be vectors.");
2476 const analysis::Float* element_type =
2477 vector_type->element_type()->AsFloat();
2478 assert(element_type && "Inputs to OpDot must be vectors of floats.");
2479 uint32_t element_width = element_type->width();
2480 if (element_width != 32 && element_width != 64) {
2481 return false;
2482 }
2483
2484 std::vector<const analysis::Constant*> components;
2485 components = constants[i]->GetVectorComponents(const_mgr);
2486
2487 constexpr uint32_t kNotFound = std::numeric_limits<uint32_t>::max();
2488
2489 uint32_t component_with_one = kNotFound;
2490 bool all_others_zero = true;
2491 for (uint32_t j = 0; j < components.size(); ++j) {
2492 const analysis::Constant* element = components[j];
2493 double value =
2494 (element_width == 32 ? element->GetFloat() : element->GetDouble());
2495 if (value == 0.0) {
2496 continue;
2497 } else if (value == 1.0) {
2498 if (component_with_one == kNotFound) {
2499 component_with_one = j;
2500 } else {
2501 component_with_one = kNotFound;
2502 break;
2503 }
2504 } else {
2505 all_others_zero = false;
2506 break;
2507 }
2508 }
2509
2510 if (!all_others_zero || component_with_one == kNotFound) {
2511 continue;
2512 }
2513
2514 std::vector<Operand> operands;
2515 operands.push_back(
2516 {SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1u - i)}});
2517 operands.push_back(
2518 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {component_with_one}});
2519
2520 inst->SetOpcode(spv::Op::OpCompositeExtract);
2521 inst->SetInOperands(std::move(operands));
2522 return true;
2523 }
2524 return false;
2525 };
2526 }
2527
2528 // If we are storing an undef, then we can remove the store.
2529 //
2530 // TODO: We can do something similar for OpImageWrite, but checking for volatile
2531 // is complicated. Waiting to see if it is needed.
StoringUndef()2532 FoldingRule StoringUndef() {
2533 return [](IRContext* context, Instruction* inst,
2534 const std::vector<const analysis::Constant*>&) {
2535 assert(inst->opcode() == spv::Op::OpStore &&
2536 "Wrong opcode. Should be OpStore.");
2537
2538 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
2539
2540 // If this is a volatile store, the store cannot be removed.
2541 if (inst->NumInOperands() == 3) {
2542 if (inst->GetSingleWordInOperand(2) &
2543 uint32_t(spv::MemoryAccessMask::Volatile)) {
2544 return false;
2545 }
2546 }
2547
2548 uint32_t object_id = inst->GetSingleWordInOperand(kStoreObjectInIdx);
2549 Instruction* object_inst = def_use_mgr->GetDef(object_id);
2550 if (object_inst->opcode() == spv::Op::OpUndef) {
2551 inst->ToNop();
2552 return true;
2553 }
2554 return false;
2555 };
2556 }
2557
VectorShuffleFeedingShuffle()2558 FoldingRule VectorShuffleFeedingShuffle() {
2559 return [](IRContext* context, Instruction* inst,
2560 const std::vector<const analysis::Constant*>&) {
2561 assert(inst->opcode() == spv::Op::OpVectorShuffle &&
2562 "Wrong opcode. Should be OpVectorShuffle.");
2563
2564 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
2565 analysis::TypeManager* type_mgr = context->get_type_mgr();
2566
2567 Instruction* feeding_shuffle_inst =
2568 def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
2569 analysis::Vector* op0_type =
2570 type_mgr->GetType(feeding_shuffle_inst->type_id())->AsVector();
2571 uint32_t op0_length = op0_type->element_count();
2572
2573 bool feeder_is_op0 = true;
2574 if (feeding_shuffle_inst->opcode() != spv::Op::OpVectorShuffle) {
2575 feeding_shuffle_inst =
2576 def_use_mgr->GetDef(inst->GetSingleWordInOperand(1));
2577 feeder_is_op0 = false;
2578 }
2579
2580 if (feeding_shuffle_inst->opcode() != spv::Op::OpVectorShuffle) {
2581 return false;
2582 }
2583
2584 Instruction* feeder2 =
2585 def_use_mgr->GetDef(feeding_shuffle_inst->GetSingleWordInOperand(0));
2586 analysis::Vector* feeder_op0_type =
2587 type_mgr->GetType(feeder2->type_id())->AsVector();
2588 uint32_t feeder_op0_length = feeder_op0_type->element_count();
2589
2590 uint32_t new_feeder_id = 0;
2591 std::vector<Operand> new_operands;
2592 new_operands.resize(
2593 2, {SPV_OPERAND_TYPE_ID, {0}}); // Place holders for vector operands.
2594 const uint32_t undef_literal = 0xffffffff;
2595 for (uint32_t op = 2; op < inst->NumInOperands(); ++op) {
2596 uint32_t component_index = inst->GetSingleWordInOperand(op);
2597
2598 // Do not interpret the undefined value literal as coming from operand 1.
2599 if (component_index != undef_literal &&
2600 feeder_is_op0 == (component_index < op0_length)) {
2601 // This component comes from the feeding_shuffle_inst. Update
2602 // |component_index| to be the index into the operand of the feeder.
2603
2604 // Adjust component_index to get the index into the operands of the
2605 // feeding_shuffle_inst.
2606 if (component_index >= op0_length) {
2607 component_index -= op0_length;
2608 }
2609 component_index =
2610 feeding_shuffle_inst->GetSingleWordInOperand(component_index + 2);
2611
2612 // Check if we are using a component from the first or second operand of
2613 // the feeding instruction.
2614 if (component_index < feeder_op0_length) {
2615 if (new_feeder_id == 0) {
2616 // First time through, save the id of the operand the element comes
2617 // from.
2618 new_feeder_id = feeding_shuffle_inst->GetSingleWordInOperand(0);
2619 } else if (new_feeder_id !=
2620 feeding_shuffle_inst->GetSingleWordInOperand(0)) {
2621 // We need both elements of the feeding_shuffle_inst, so we cannot
2622 // fold.
2623 return false;
2624 }
2625 } else if (component_index != undef_literal) {
2626 if (new_feeder_id == 0) {
2627 // First time through, save the id of the operand the element comes
2628 // from.
2629 new_feeder_id = feeding_shuffle_inst->GetSingleWordInOperand(1);
2630 } else if (new_feeder_id !=
2631 feeding_shuffle_inst->GetSingleWordInOperand(1)) {
2632 // We need both elements of the feeding_shuffle_inst, so we cannot
2633 // fold.
2634 return false;
2635 }
2636 component_index -= feeder_op0_length;
2637 }
2638
2639 if (!feeder_is_op0 && component_index != undef_literal) {
2640 component_index += op0_length;
2641 }
2642 }
2643 new_operands.push_back(
2644 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {component_index}});
2645 }
2646
2647 if (new_feeder_id == 0) {
2648 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
2649 const analysis::Type* type =
2650 type_mgr->GetType(feeding_shuffle_inst->type_id());
2651 const analysis::Constant* null_const = const_mgr->GetConstant(type, {});
2652 new_feeder_id =
2653 const_mgr->GetDefiningInstruction(null_const, 0)->result_id();
2654 }
2655
2656 if (feeder_is_op0) {
2657 // If the size of the first vector operand changed then the indices
2658 // referring to the second operand need to be adjusted.
2659 Instruction* new_feeder_inst = def_use_mgr->GetDef(new_feeder_id);
2660 analysis::Type* new_feeder_type =
2661 type_mgr->GetType(new_feeder_inst->type_id());
2662 uint32_t new_op0_size = new_feeder_type->AsVector()->element_count();
2663 int32_t adjustment = op0_length - new_op0_size;
2664
2665 if (adjustment != 0) {
2666 for (uint32_t i = 2; i < new_operands.size(); i++) {
2667 uint32_t operand = inst->GetSingleWordInOperand(i);
2668 if (operand >= op0_length && operand != undef_literal) {
2669 new_operands[i].words[0] -= adjustment;
2670 }
2671 }
2672 }
2673
2674 new_operands[0].words[0] = new_feeder_id;
2675 new_operands[1] = inst->GetInOperand(1);
2676 } else {
2677 new_operands[1].words[0] = new_feeder_id;
2678 new_operands[0] = inst->GetInOperand(0);
2679 }
2680
2681 inst->SetInOperands(std::move(new_operands));
2682 return true;
2683 };
2684 }
2685
2686 // Removes duplicate ids from the interface list of an OpEntryPoint
2687 // instruction.
RemoveRedundantOperands()2688 FoldingRule RemoveRedundantOperands() {
2689 return [](IRContext*, Instruction* inst,
2690 const std::vector<const analysis::Constant*>&) {
2691 assert(inst->opcode() == spv::Op::OpEntryPoint &&
2692 "Wrong opcode. Should be OpEntryPoint.");
2693 bool has_redundant_operand = false;
2694 std::unordered_set<uint32_t> seen_operands;
2695 std::vector<Operand> new_operands;
2696
2697 new_operands.emplace_back(inst->GetOperand(0));
2698 new_operands.emplace_back(inst->GetOperand(1));
2699 new_operands.emplace_back(inst->GetOperand(2));
2700 for (uint32_t i = 3; i < inst->NumOperands(); ++i) {
2701 if (seen_operands.insert(inst->GetSingleWordOperand(i)).second) {
2702 new_operands.emplace_back(inst->GetOperand(i));
2703 } else {
2704 has_redundant_operand = true;
2705 }
2706 }
2707
2708 if (!has_redundant_operand) {
2709 return false;
2710 }
2711
2712 inst->SetInOperands(std::move(new_operands));
2713 return true;
2714 };
2715 }
2716
2717 // If an image instruction's operand is a constant, updates the image operand
2718 // flag from Offset to ConstOffset.
UpdateImageOperands()2719 FoldingRule UpdateImageOperands() {
2720 return [](IRContext*, Instruction* inst,
2721 const std::vector<const analysis::Constant*>& constants) {
2722 const auto opcode = inst->opcode();
2723 (void)opcode;
2724 assert((opcode == spv::Op::OpImageSampleImplicitLod ||
2725 opcode == spv::Op::OpImageSampleExplicitLod ||
2726 opcode == spv::Op::OpImageSampleDrefImplicitLod ||
2727 opcode == spv::Op::OpImageSampleDrefExplicitLod ||
2728 opcode == spv::Op::OpImageSampleProjImplicitLod ||
2729 opcode == spv::Op::OpImageSampleProjExplicitLod ||
2730 opcode == spv::Op::OpImageSampleProjDrefImplicitLod ||
2731 opcode == spv::Op::OpImageSampleProjDrefExplicitLod ||
2732 opcode == spv::Op::OpImageFetch ||
2733 opcode == spv::Op::OpImageGather ||
2734 opcode == spv::Op::OpImageDrefGather ||
2735 opcode == spv::Op::OpImageRead || opcode == spv::Op::OpImageWrite ||
2736 opcode == spv::Op::OpImageSparseSampleImplicitLod ||
2737 opcode == spv::Op::OpImageSparseSampleExplicitLod ||
2738 opcode == spv::Op::OpImageSparseSampleDrefImplicitLod ||
2739 opcode == spv::Op::OpImageSparseSampleDrefExplicitLod ||
2740 opcode == spv::Op::OpImageSparseSampleProjImplicitLod ||
2741 opcode == spv::Op::OpImageSparseSampleProjExplicitLod ||
2742 opcode == spv::Op::OpImageSparseSampleProjDrefImplicitLod ||
2743 opcode == spv::Op::OpImageSparseSampleProjDrefExplicitLod ||
2744 opcode == spv::Op::OpImageSparseFetch ||
2745 opcode == spv::Op::OpImageSparseGather ||
2746 opcode == spv::Op::OpImageSparseDrefGather ||
2747 opcode == spv::Op::OpImageSparseRead) &&
2748 "Wrong opcode. Should be an image instruction.");
2749
2750 int32_t operand_index = ImageOperandsMaskInOperandIndex(inst);
2751 if (operand_index >= 0) {
2752 auto image_operands = inst->GetSingleWordInOperand(operand_index);
2753 if (image_operands & uint32_t(spv::ImageOperandsMask::Offset)) {
2754 uint32_t offset_operand_index = operand_index + 1;
2755 if (image_operands & uint32_t(spv::ImageOperandsMask::Bias))
2756 offset_operand_index++;
2757 if (image_operands & uint32_t(spv::ImageOperandsMask::Lod))
2758 offset_operand_index++;
2759 if (image_operands & uint32_t(spv::ImageOperandsMask::Grad))
2760 offset_operand_index += 2;
2761 assert(((image_operands &
2762 uint32_t(spv::ImageOperandsMask::ConstOffset)) == 0) &&
2763 "Offset and ConstOffset may not be used together");
2764 if (offset_operand_index < inst->NumOperands()) {
2765 if (constants[offset_operand_index]) {
2766 if (constants[offset_operand_index]->IsZero()) {
2767 inst->RemoveInOperand(offset_operand_index);
2768 } else {
2769 image_operands = image_operands |
2770 uint32_t(spv::ImageOperandsMask::ConstOffset);
2771 }
2772 image_operands =
2773 image_operands & ~uint32_t(spv::ImageOperandsMask::Offset);
2774 inst->SetInOperand(operand_index, {image_operands});
2775 return true;
2776 }
2777 }
2778 }
2779 }
2780
2781 return false;
2782 };
2783 }
2784
2785 } // namespace
2786
AddFoldingRules()2787 void FoldingRules::AddFoldingRules() {
2788 // Add all folding rules to the list for the opcodes to which they apply.
2789 // Note that the order in which rules are added to the list matters. If a rule
2790 // applies to the instruction, the rest of the rules will not be attempted.
2791 // Take that into consideration.
2792 rules_[spv::Op::OpBitcast].push_back(BitCastScalarOrVector());
2793
2794 rules_[spv::Op::OpCompositeConstruct].push_back(
2795 CompositeExtractFeedingConstruct);
2796
2797 rules_[spv::Op::OpCompositeExtract].push_back(InsertFeedingExtract());
2798 rules_[spv::Op::OpCompositeExtract].push_back(
2799 CompositeConstructFeedingExtract);
2800 rules_[spv::Op::OpCompositeExtract].push_back(VectorShuffleFeedingExtract());
2801 rules_[spv::Op::OpCompositeExtract].push_back(FMixFeedingExtract());
2802
2803 rules_[spv::Op::OpCompositeInsert].push_back(
2804 CompositeInsertToCompositeConstruct);
2805
2806 rules_[spv::Op::OpDot].push_back(DotProductDoingExtract());
2807
2808 rules_[spv::Op::OpEntryPoint].push_back(RemoveRedundantOperands());
2809
2810 rules_[spv::Op::OpFAdd].push_back(RedundantFAdd());
2811 rules_[spv::Op::OpFAdd].push_back(MergeAddNegateArithmetic());
2812 rules_[spv::Op::OpFAdd].push_back(MergeAddAddArithmetic());
2813 rules_[spv::Op::OpFAdd].push_back(MergeAddSubArithmetic());
2814 rules_[spv::Op::OpFAdd].push_back(MergeGenericAddSubArithmetic());
2815 rules_[spv::Op::OpFAdd].push_back(FactorAddMuls());
2816
2817 rules_[spv::Op::OpFDiv].push_back(RedundantFDiv());
2818 rules_[spv::Op::OpFDiv].push_back(ReciprocalFDiv());
2819 rules_[spv::Op::OpFDiv].push_back(MergeDivDivArithmetic());
2820 rules_[spv::Op::OpFDiv].push_back(MergeDivMulArithmetic());
2821 rules_[spv::Op::OpFDiv].push_back(MergeDivNegateArithmetic());
2822
2823 rules_[spv::Op::OpFMul].push_back(RedundantFMul());
2824 rules_[spv::Op::OpFMul].push_back(MergeMulMulArithmetic());
2825 rules_[spv::Op::OpFMul].push_back(MergeMulDivArithmetic());
2826 rules_[spv::Op::OpFMul].push_back(MergeMulNegateArithmetic());
2827
2828 rules_[spv::Op::OpFNegate].push_back(MergeNegateArithmetic());
2829 rules_[spv::Op::OpFNegate].push_back(MergeNegateAddSubArithmetic());
2830 rules_[spv::Op::OpFNegate].push_back(MergeNegateMulDivArithmetic());
2831
2832 rules_[spv::Op::OpFSub].push_back(RedundantFSub());
2833 rules_[spv::Op::OpFSub].push_back(MergeSubNegateArithmetic());
2834 rules_[spv::Op::OpFSub].push_back(MergeSubAddArithmetic());
2835 rules_[spv::Op::OpFSub].push_back(MergeSubSubArithmetic());
2836
2837 rules_[spv::Op::OpIAdd].push_back(RedundantIAdd());
2838 rules_[spv::Op::OpIAdd].push_back(MergeAddNegateArithmetic());
2839 rules_[spv::Op::OpIAdd].push_back(MergeAddAddArithmetic());
2840 rules_[spv::Op::OpIAdd].push_back(MergeAddSubArithmetic());
2841 rules_[spv::Op::OpIAdd].push_back(MergeGenericAddSubArithmetic());
2842 rules_[spv::Op::OpIAdd].push_back(FactorAddMuls());
2843
2844 rules_[spv::Op::OpIMul].push_back(IntMultipleBy1());
2845 rules_[spv::Op::OpIMul].push_back(MergeMulMulArithmetic());
2846 rules_[spv::Op::OpIMul].push_back(MergeMulNegateArithmetic());
2847
2848 rules_[spv::Op::OpISub].push_back(MergeSubNegateArithmetic());
2849 rules_[spv::Op::OpISub].push_back(MergeSubAddArithmetic());
2850 rules_[spv::Op::OpISub].push_back(MergeSubSubArithmetic());
2851
2852 rules_[spv::Op::OpPhi].push_back(RedundantPhi());
2853
2854 rules_[spv::Op::OpSNegate].push_back(MergeNegateArithmetic());
2855 rules_[spv::Op::OpSNegate].push_back(MergeNegateMulDivArithmetic());
2856 rules_[spv::Op::OpSNegate].push_back(MergeNegateAddSubArithmetic());
2857
2858 rules_[spv::Op::OpSelect].push_back(RedundantSelect());
2859
2860 rules_[spv::Op::OpStore].push_back(StoringUndef());
2861
2862 rules_[spv::Op::OpVectorShuffle].push_back(VectorShuffleFeedingShuffle());
2863
2864 rules_[spv::Op::OpImageSampleImplicitLod].push_back(UpdateImageOperands());
2865 rules_[spv::Op::OpImageSampleExplicitLod].push_back(UpdateImageOperands());
2866 rules_[spv::Op::OpImageSampleDrefImplicitLod].push_back(
2867 UpdateImageOperands());
2868 rules_[spv::Op::OpImageSampleDrefExplicitLod].push_back(
2869 UpdateImageOperands());
2870 rules_[spv::Op::OpImageSampleProjImplicitLod].push_back(
2871 UpdateImageOperands());
2872 rules_[spv::Op::OpImageSampleProjExplicitLod].push_back(
2873 UpdateImageOperands());
2874 rules_[spv::Op::OpImageSampleProjDrefImplicitLod].push_back(
2875 UpdateImageOperands());
2876 rules_[spv::Op::OpImageSampleProjDrefExplicitLod].push_back(
2877 UpdateImageOperands());
2878 rules_[spv::Op::OpImageFetch].push_back(UpdateImageOperands());
2879 rules_[spv::Op::OpImageGather].push_back(UpdateImageOperands());
2880 rules_[spv::Op::OpImageDrefGather].push_back(UpdateImageOperands());
2881 rules_[spv::Op::OpImageRead].push_back(UpdateImageOperands());
2882 rules_[spv::Op::OpImageWrite].push_back(UpdateImageOperands());
2883 rules_[spv::Op::OpImageSparseSampleImplicitLod].push_back(
2884 UpdateImageOperands());
2885 rules_[spv::Op::OpImageSparseSampleExplicitLod].push_back(
2886 UpdateImageOperands());
2887 rules_[spv::Op::OpImageSparseSampleDrefImplicitLod].push_back(
2888 UpdateImageOperands());
2889 rules_[spv::Op::OpImageSparseSampleDrefExplicitLod].push_back(
2890 UpdateImageOperands());
2891 rules_[spv::Op::OpImageSparseSampleProjImplicitLod].push_back(
2892 UpdateImageOperands());
2893 rules_[spv::Op::OpImageSparseSampleProjExplicitLod].push_back(
2894 UpdateImageOperands());
2895 rules_[spv::Op::OpImageSparseSampleProjDrefImplicitLod].push_back(
2896 UpdateImageOperands());
2897 rules_[spv::Op::OpImageSparseSampleProjDrefExplicitLod].push_back(
2898 UpdateImageOperands());
2899 rules_[spv::Op::OpImageSparseFetch].push_back(UpdateImageOperands());
2900 rules_[spv::Op::OpImageSparseGather].push_back(UpdateImageOperands());
2901 rules_[spv::Op::OpImageSparseDrefGather].push_back(UpdateImageOperands());
2902 rules_[spv::Op::OpImageSparseRead].push_back(UpdateImageOperands());
2903
2904 FeatureManager* feature_manager = context_->get_feature_mgr();
2905 // Add rules for GLSLstd450
2906 uint32_t ext_inst_glslstd450_id =
2907 feature_manager->GetExtInstImportId_GLSLstd450();
2908 if (ext_inst_glslstd450_id != 0) {
2909 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMix}].push_back(
2910 RedundantFMix());
2911 }
2912 }
2913 } // namespace opt
2914 } // namespace spvtools
2915