1 // Copyright (c) 2018 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/opt/folding_rules.h"
16
17 #include <limits>
18 #include <memory>
19 #include <utility>
20
21 #include "ir_builder.h"
22 #include "source/latest_version_glsl_std_450_header.h"
23 #include "source/opt/ir_context.h"
24
25 namespace spvtools {
26 namespace opt {
27 namespace {
28
// In-operand indices used when reading instruction operands with
// GetSingleWordInOperand().
// OpCompositeExtract: the composite being extracted from.
constexpr uint32_t kExtractCompositeIdInIdx = 0;
// OpCompositeInsert: the object being inserted, then the composite it is
// inserted into.
constexpr uint32_t kInsertObjectIdInIdx = 0;
constexpr uint32_t kInsertCompositeIdInIdx = 1;
// OpExtInst: the extended-instruction-set id, then the instruction number
// within that set.
constexpr uint32_t kExtInstSetIdInIdx = 0;
constexpr uint32_t kExtInstInstructionInIdx = 1;
// FMix (an OpExtInst): operands x, y and a, which follow the set id and
// instruction number.
constexpr uint32_t kFMixXIdInIdx = 2;
constexpr uint32_t kFMixYIdInIdx = 3;
constexpr uint32_t kFMixAIdInIdx = 4;
// OpStore: the object being stored (in-operand 0 is the pointer).
constexpr uint32_t kStoreObjectInIdx = 1;
38
39 // Some image instructions may contain an "image operands" argument.
40 // Returns the operand index for the "image operands".
41 // Returns -1 if the instruction does not have image operands.
ImageOperandsMaskInOperandIndex(Instruction * inst)42 int32_t ImageOperandsMaskInOperandIndex(Instruction* inst) {
43 const auto opcode = inst->opcode();
44 switch (opcode) {
45 case spv::Op::OpImageSampleImplicitLod:
46 case spv::Op::OpImageSampleExplicitLod:
47 case spv::Op::OpImageSampleProjImplicitLod:
48 case spv::Op::OpImageSampleProjExplicitLod:
49 case spv::Op::OpImageFetch:
50 case spv::Op::OpImageRead:
51 case spv::Op::OpImageSparseSampleImplicitLod:
52 case spv::Op::OpImageSparseSampleExplicitLod:
53 case spv::Op::OpImageSparseSampleProjImplicitLod:
54 case spv::Op::OpImageSparseSampleProjExplicitLod:
55 case spv::Op::OpImageSparseFetch:
56 case spv::Op::OpImageSparseRead:
57 return inst->NumOperands() > 4 ? 2 : -1;
58 case spv::Op::OpImageSampleDrefImplicitLod:
59 case spv::Op::OpImageSampleDrefExplicitLod:
60 case spv::Op::OpImageSampleProjDrefImplicitLod:
61 case spv::Op::OpImageSampleProjDrefExplicitLod:
62 case spv::Op::OpImageGather:
63 case spv::Op::OpImageDrefGather:
64 case spv::Op::OpImageSparseSampleDrefImplicitLod:
65 case spv::Op::OpImageSparseSampleDrefExplicitLod:
66 case spv::Op::OpImageSparseSampleProjDrefImplicitLod:
67 case spv::Op::OpImageSparseSampleProjDrefExplicitLod:
68 case spv::Op::OpImageSparseGather:
69 case spv::Op::OpImageSparseDrefGather:
70 return inst->NumOperands() > 5 ? 3 : -1;
71 case spv::Op::OpImageWrite:
72 return inst->NumOperands() > 3 ? 3 : -1;
73 default:
74 return -1;
75 }
76 }
77
78 // Returns the element width of |type|.
ElementWidth(const analysis::Type * type)79 uint32_t ElementWidth(const analysis::Type* type) {
80 if (const analysis::Vector* vec_type = type->AsVector()) {
81 return ElementWidth(vec_type->element_type());
82 } else if (const analysis::Float* float_type = type->AsFloat()) {
83 return float_type->width();
84 } else {
85 assert(type->AsInteger());
86 return type->AsInteger()->width();
87 }
88 }
89
90 // Returns true if |type| is Float or a vector of Float.
HasFloatingPoint(const analysis::Type * type)91 bool HasFloatingPoint(const analysis::Type* type) {
92 if (type->AsFloat()) {
93 return true;
94 } else if (const analysis::Vector* vec_type = type->AsVector()) {
95 return vec_type->element_type()->AsFloat() != nullptr;
96 }
97
98 return false;
99 }
100
// Returns false if |val| is NaN, infinite or subnormal, and true for every
// other classification (normal numbers and zero).
template <typename T>
bool IsValidResult(T val) {
  const int category = std::fpclassify(val);
  const bool rejected = category == FP_NAN || category == FP_INFINITE ||
                        category == FP_SUBNORMAL;
  return !rejected;
}
114
ConstInput(const std::vector<const analysis::Constant * > & constants)115 const analysis::Constant* ConstInput(
116 const std::vector<const analysis::Constant*>& constants) {
117 return constants[0] ? constants[0] : constants[1];
118 }
119
NonConstInput(IRContext * context,const analysis::Constant * c,Instruction * inst)120 Instruction* NonConstInput(IRContext* context, const analysis::Constant* c,
121 Instruction* inst) {
122 uint32_t in_op = c ? 1u : 0u;
123 return context->get_def_use_mgr()->GetDef(
124 inst->GetSingleWordInOperand(in_op));
125 }
126
ExtractInts(uint64_t val)127 std::vector<uint32_t> ExtractInts(uint64_t val) {
128 std::vector<uint32_t> words;
129 words.push_back(static_cast<uint32_t>(val));
130 words.push_back(static_cast<uint32_t>(val >> 32));
131 return words;
132 }
133
GetWordsFromScalarIntConstant(const analysis::IntConstant * c)134 std::vector<uint32_t> GetWordsFromScalarIntConstant(
135 const analysis::IntConstant* c) {
136 assert(c != nullptr);
137 uint32_t width = c->type()->AsInteger()->width();
138 assert(width == 8 || width == 16 || width == 32 || width == 64);
139 if (width == 64) {
140 uint64_t uval = static_cast<uint64_t>(c->GetU64());
141 return ExtractInts(uval);
142 }
143 // Section 2.2.1 of the SPIR-V spec guarantees that all integer types
144 // smaller than 32-bits are automatically zero or sign extended to 32-bits.
145 return {c->GetU32BitValue()};
146 }
147
GetWordsFromScalarFloatConstant(const analysis::FloatConstant * c)148 std::vector<uint32_t> GetWordsFromScalarFloatConstant(
149 const analysis::FloatConstant* c) {
150 assert(c != nullptr);
151 uint32_t width = c->type()->AsFloat()->width();
152 assert(width == 16 || width == 32 || width == 64);
153 if (width == 64) {
154 utils::FloatProxy<double> result(c->GetDouble());
155 return result.GetWords();
156 }
157 // Section 2.2.1 of the SPIR-V spec guarantees that all floating-point types
158 // smaller than 32-bits are automatically zero extended to 32-bits.
159 return {c->GetU32BitValue()};
160 }
161
GetWordsFromNumericScalarOrVectorConstant(analysis::ConstantManager * const_mgr,const analysis::Constant * c)162 std::vector<uint32_t> GetWordsFromNumericScalarOrVectorConstant(
163 analysis::ConstantManager* const_mgr, const analysis::Constant* c) {
164 if (const auto* float_constant = c->AsFloatConstant()) {
165 return GetWordsFromScalarFloatConstant(float_constant);
166 } else if (const auto* int_constant = c->AsIntConstant()) {
167 return GetWordsFromScalarIntConstant(int_constant);
168 } else if (const auto* vec_constant = c->AsVectorConstant()) {
169 std::vector<uint32_t> words;
170 for (const auto* comp : vec_constant->GetComponents()) {
171 auto comp_in_words =
172 GetWordsFromNumericScalarOrVectorConstant(const_mgr, comp);
173 words.insert(words.end(), comp_in_words.begin(), comp_in_words.end());
174 }
175 return words;
176 }
177 return {};
178 }
179
ConvertWordsToNumericScalarOrVectorConstant(analysis::ConstantManager * const_mgr,const std::vector<uint32_t> & words,const analysis::Type * type)180 const analysis::Constant* ConvertWordsToNumericScalarOrVectorConstant(
181 analysis::ConstantManager* const_mgr, const std::vector<uint32_t>& words,
182 const analysis::Type* type) {
183 if (type->AsInteger() || type->AsFloat())
184 return const_mgr->GetConstant(type, words);
185 if (const auto* vec_type = type->AsVector())
186 return const_mgr->GetNumericVectorConstantWithWords(vec_type, words);
187 return nullptr;
188 }
189
190 // Returns the negation of |c|. |c| must be a 32 or 64 bit floating point
191 // constant.
NegateFloatingPointConstant(analysis::ConstantManager * const_mgr,const analysis::Constant * c)192 uint32_t NegateFloatingPointConstant(analysis::ConstantManager* const_mgr,
193 const analysis::Constant* c) {
194 assert(c);
195 assert(c->type()->AsFloat());
196 uint32_t width = c->type()->AsFloat()->width();
197 assert(width == 32 || width == 64);
198 std::vector<uint32_t> words;
199 if (width == 64) {
200 utils::FloatProxy<double> result(c->GetDouble() * -1.0);
201 words = result.GetWords();
202 } else {
203 utils::FloatProxy<float> result(c->GetFloat() * -1.0f);
204 words = result.GetWords();
205 }
206
207 const analysis::Constant* negated_const =
208 const_mgr->GetConstant(c->type(), std::move(words));
209 return const_mgr->GetDefiningInstruction(negated_const)->result_id();
210 }
211
212 // Negates the integer constant |c|. Returns the id of the defining instruction.
NegateIntegerConstant(analysis::ConstantManager * const_mgr,const analysis::Constant * c)213 uint32_t NegateIntegerConstant(analysis::ConstantManager* const_mgr,
214 const analysis::Constant* c) {
215 assert(c);
216 assert(c->type()->AsInteger());
217 uint32_t width = c->type()->AsInteger()->width();
218 assert(width == 32 || width == 64);
219 std::vector<uint32_t> words;
220 if (width == 64) {
221 uint64_t uval = static_cast<uint64_t>(0 - c->GetU64());
222 words = ExtractInts(uval);
223 } else {
224 words.push_back(static_cast<uint32_t>(0 - c->GetU32()));
225 }
226
227 const analysis::Constant* negated_const =
228 const_mgr->GetConstant(c->type(), std::move(words));
229 return const_mgr->GetDefiningInstruction(negated_const)->result_id();
230 }
231
232 // Negates the vector constant |c|. Returns the id of the defining instruction.
NegateVectorConstant(analysis::ConstantManager * const_mgr,const analysis::Constant * c)233 uint32_t NegateVectorConstant(analysis::ConstantManager* const_mgr,
234 const analysis::Constant* c) {
235 assert(const_mgr && c);
236 assert(c->type()->AsVector());
237 if (c->AsNullConstant()) {
238 // 0.0 vs -0.0 shouldn't matter.
239 return const_mgr->GetDefiningInstruction(c)->result_id();
240 } else {
241 const analysis::Type* component_type =
242 c->AsVectorConstant()->component_type();
243 std::vector<uint32_t> words;
244 for (auto& comp : c->AsVectorConstant()->GetComponents()) {
245 if (component_type->AsFloat()) {
246 words.push_back(NegateFloatingPointConstant(const_mgr, comp));
247 } else {
248 assert(component_type->AsInteger());
249 words.push_back(NegateIntegerConstant(const_mgr, comp));
250 }
251 }
252
253 const analysis::Constant* negated_const =
254 const_mgr->GetConstant(c->type(), std::move(words));
255 return const_mgr->GetDefiningInstruction(negated_const)->result_id();
256 }
257 }
258
259 // Negates |c|. Returns the id of the defining instruction.
NegateConstant(analysis::ConstantManager * const_mgr,const analysis::Constant * c)260 uint32_t NegateConstant(analysis::ConstantManager* const_mgr,
261 const analysis::Constant* c) {
262 if (c->type()->AsVector()) {
263 return NegateVectorConstant(const_mgr, c);
264 } else if (c->type()->AsFloat()) {
265 return NegateFloatingPointConstant(const_mgr, c);
266 } else {
267 assert(c->type()->AsInteger());
268 return NegateIntegerConstant(const_mgr, c);
269 }
270 }
271
272 // Takes the reciprocal of |c|. |c|'s type must be Float or a vector of Float.
273 // Returns 0 if the reciprocal is NaN, infinite or subnormal.
Reciprocal(analysis::ConstantManager * const_mgr,const analysis::Constant * c)274 uint32_t Reciprocal(analysis::ConstantManager* const_mgr,
275 const analysis::Constant* c) {
276 assert(const_mgr && c);
277 assert(c->type()->AsFloat());
278
279 uint32_t width = c->type()->AsFloat()->width();
280 assert(width == 32 || width == 64);
281 std::vector<uint32_t> words;
282
283 if (c->IsZero()) {
284 return 0;
285 }
286
287 if (width == 64) {
288 spvtools::utils::FloatProxy<double> result(1.0 / c->GetDouble());
289 if (!IsValidResult(result.getAsFloat())) return 0;
290 words = result.GetWords();
291 } else {
292 spvtools::utils::FloatProxy<float> result(1.0f / c->GetFloat());
293 if (!IsValidResult(result.getAsFloat())) return 0;
294 words = result.GetWords();
295 }
296
297 const analysis::Constant* negated_const =
298 const_mgr->GetConstant(c->type(), std::move(words));
299 return const_mgr->GetDefiningInstruction(negated_const)->result_id();
300 }
301
302 // Replaces fdiv where second operand is constant with fmul.
ReciprocalFDiv()303 FoldingRule ReciprocalFDiv() {
304 return [](IRContext* context, Instruction* inst,
305 const std::vector<const analysis::Constant*>& constants) {
306 assert(inst->opcode() == spv::Op::OpFDiv);
307 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
308 const analysis::Type* type =
309 context->get_type_mgr()->GetType(inst->type_id());
310 if (!inst->IsFloatingPointFoldingAllowed()) return false;
311
312 uint32_t width = ElementWidth(type);
313 if (width != 32 && width != 64) return false;
314
315 if (constants[1] != nullptr) {
316 uint32_t id = 0;
317 if (const analysis::VectorConstant* vector_const =
318 constants[1]->AsVectorConstant()) {
319 std::vector<uint32_t> neg_ids;
320 for (auto& comp : vector_const->GetComponents()) {
321 id = Reciprocal(const_mgr, comp);
322 if (id == 0) return false;
323 neg_ids.push_back(id);
324 }
325 const analysis::Constant* negated_const =
326 const_mgr->GetConstant(constants[1]->type(), std::move(neg_ids));
327 id = const_mgr->GetDefiningInstruction(negated_const)->result_id();
328 } else if (constants[1]->AsFloatConstant()) {
329 id = Reciprocal(const_mgr, constants[1]);
330 if (id == 0) return false;
331 } else {
332 // Don't fold a null constant.
333 return false;
334 }
335 inst->SetOpcode(spv::Op::OpFMul);
336 inst->SetInOperands(
337 {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0u)}},
338 {SPV_OPERAND_TYPE_ID, {id}}});
339 return true;
340 }
341
342 return false;
343 };
344 }
345
346 // Elides consecutive negate instructions.
MergeNegateArithmetic()347 FoldingRule MergeNegateArithmetic() {
348 return [](IRContext* context, Instruction* inst,
349 const std::vector<const analysis::Constant*>& constants) {
350 assert(inst->opcode() == spv::Op::OpFNegate ||
351 inst->opcode() == spv::Op::OpSNegate);
352 (void)constants;
353 const analysis::Type* type =
354 context->get_type_mgr()->GetType(inst->type_id());
355 if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
356 return false;
357
358 Instruction* op_inst =
359 context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
360 if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
361 return false;
362
363 if (op_inst->opcode() == inst->opcode()) {
364 // Elide negates.
365 inst->SetOpcode(spv::Op::OpCopyObject);
366 inst->SetInOperands(
367 {{SPV_OPERAND_TYPE_ID, {op_inst->GetSingleWordInOperand(0u)}}});
368 return true;
369 }
370
371 return false;
372 };
373 }
374
375 // Merges negate into a mul or div operation if that operation contains a
376 // constant operand.
377 // Cases:
378 // -(x * 2) = x * -2
379 // -(2 * x) = x * -2
380 // -(x / 2) = x / -2
381 // -(2 / x) = -2 / x
MergeNegateMulDivArithmetic()382 FoldingRule MergeNegateMulDivArithmetic() {
383 return [](IRContext* context, Instruction* inst,
384 const std::vector<const analysis::Constant*>& constants) {
385 assert(inst->opcode() == spv::Op::OpFNegate ||
386 inst->opcode() == spv::Op::OpSNegate);
387 (void)constants;
388 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
389 const analysis::Type* type =
390 context->get_type_mgr()->GetType(inst->type_id());
391 if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
392 return false;
393
394 Instruction* op_inst =
395 context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
396 if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
397 return false;
398
399 uint32_t width = ElementWidth(type);
400 if (width != 32 && width != 64) return false;
401
402 spv::Op opcode = op_inst->opcode();
403 if (opcode == spv::Op::OpFMul || opcode == spv::Op::OpFDiv ||
404 opcode == spv::Op::OpIMul || opcode == spv::Op::OpSDiv ||
405 opcode == spv::Op::OpUDiv) {
406 std::vector<const analysis::Constant*> op_constants =
407 const_mgr->GetOperandConstants(op_inst);
408 // Merge negate into mul or div if one operand is constant.
409 if (op_constants[0] || op_constants[1]) {
410 bool zero_is_variable = op_constants[0] == nullptr;
411 const analysis::Constant* c = ConstInput(op_constants);
412 uint32_t neg_id = NegateConstant(const_mgr, c);
413 uint32_t non_const_id = zero_is_variable
414 ? op_inst->GetSingleWordInOperand(0u)
415 : op_inst->GetSingleWordInOperand(1u);
416 // Change this instruction to a mul/div.
417 inst->SetOpcode(op_inst->opcode());
418 if (opcode == spv::Op::OpFDiv || opcode == spv::Op::OpUDiv ||
419 opcode == spv::Op::OpSDiv) {
420 uint32_t op0 = zero_is_variable ? non_const_id : neg_id;
421 uint32_t op1 = zero_is_variable ? neg_id : non_const_id;
422 inst->SetInOperands(
423 {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}});
424 } else {
425 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
426 {SPV_OPERAND_TYPE_ID, {neg_id}}});
427 }
428 return true;
429 }
430 }
431
432 return false;
433 };
434 }
435
436 // Merges negate into a add or sub operation if that operation contains a
437 // constant operand.
438 // Cases:
439 // -(x + 2) = -2 - x
440 // -(2 + x) = -2 - x
441 // -(x - 2) = 2 - x
442 // -(2 - x) = x - 2
MergeNegateAddSubArithmetic()443 FoldingRule MergeNegateAddSubArithmetic() {
444 return [](IRContext* context, Instruction* inst,
445 const std::vector<const analysis::Constant*>& constants) {
446 assert(inst->opcode() == spv::Op::OpFNegate ||
447 inst->opcode() == spv::Op::OpSNegate);
448 (void)constants;
449 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
450 const analysis::Type* type =
451 context->get_type_mgr()->GetType(inst->type_id());
452 if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
453 return false;
454
455 Instruction* op_inst =
456 context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
457 if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
458 return false;
459
460 uint32_t width = ElementWidth(type);
461 if (width != 32 && width != 64) return false;
462
463 if (op_inst->opcode() == spv::Op::OpFAdd ||
464 op_inst->opcode() == spv::Op::OpFSub ||
465 op_inst->opcode() == spv::Op::OpIAdd ||
466 op_inst->opcode() == spv::Op::OpISub) {
467 std::vector<const analysis::Constant*> op_constants =
468 const_mgr->GetOperandConstants(op_inst);
469 if (op_constants[0] || op_constants[1]) {
470 bool zero_is_variable = op_constants[0] == nullptr;
471 bool is_add = (op_inst->opcode() == spv::Op::OpFAdd) ||
472 (op_inst->opcode() == spv::Op::OpIAdd);
473 bool swap_operands = !is_add || zero_is_variable;
474 bool negate_const = is_add;
475 const analysis::Constant* c = ConstInput(op_constants);
476 uint32_t const_id = 0;
477 if (negate_const) {
478 const_id = NegateConstant(const_mgr, c);
479 } else {
480 const_id = zero_is_variable ? op_inst->GetSingleWordInOperand(1u)
481 : op_inst->GetSingleWordInOperand(0u);
482 }
483
484 // Swap operands if necessary and make the instruction a subtraction.
485 uint32_t op0 =
486 zero_is_variable ? op_inst->GetSingleWordInOperand(0u) : const_id;
487 uint32_t op1 =
488 zero_is_variable ? const_id : op_inst->GetSingleWordInOperand(1u);
489 if (swap_operands) std::swap(op0, op1);
490 inst->SetOpcode(HasFloatingPoint(type) ? spv::Op::OpFSub
491 : spv::Op::OpISub);
492 inst->SetInOperands(
493 {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}});
494 return true;
495 }
496 }
497
498 return false;
499 };
500 }
501
502 // Returns true if |c| has a zero element.
HasZero(const analysis::Constant * c)503 bool HasZero(const analysis::Constant* c) {
504 if (c->AsNullConstant()) {
505 return true;
506 }
507 if (const analysis::VectorConstant* vec_const = c->AsVectorConstant()) {
508 for (auto& comp : vec_const->GetComponents())
509 if (HasZero(comp)) return true;
510 } else {
511 assert(c->AsScalarConstant());
512 return c->AsScalarConstant()->IsZero();
513 }
514
515 return false;
516 }
517
518 // Performs |input1| |opcode| |input2| and returns the merged constant result
519 // id. Returns 0 if the result is not a valid value. The input types must be
520 // Float.
PerformFloatingPointOperation(analysis::ConstantManager * const_mgr,spv::Op opcode,const analysis::Constant * input1,const analysis::Constant * input2)521 uint32_t PerformFloatingPointOperation(analysis::ConstantManager* const_mgr,
522 spv::Op opcode,
523 const analysis::Constant* input1,
524 const analysis::Constant* input2) {
525 const analysis::Type* type = input1->type();
526 assert(type->AsFloat());
527 uint32_t width = type->AsFloat()->width();
528 assert(width == 32 || width == 64);
529 std::vector<uint32_t> words;
530 #define FOLD_OP(op) \
531 if (width == 64) { \
532 utils::FloatProxy<double> val = \
533 input1->GetDouble() op input2->GetDouble(); \
534 double dval = val.getAsFloat(); \
535 if (!IsValidResult(dval)) return 0; \
536 words = val.GetWords(); \
537 } else { \
538 utils::FloatProxy<float> val = input1->GetFloat() op input2->GetFloat(); \
539 float fval = val.getAsFloat(); \
540 if (!IsValidResult(fval)) return 0; \
541 words = val.GetWords(); \
542 } \
543 static_assert(true, "require extra semicolon")
544 switch (opcode) {
545 case spv::Op::OpFMul:
546 FOLD_OP(*);
547 break;
548 case spv::Op::OpFDiv:
549 if (HasZero(input2)) return 0;
550 FOLD_OP(/);
551 break;
552 case spv::Op::OpFAdd:
553 FOLD_OP(+);
554 break;
555 case spv::Op::OpFSub:
556 FOLD_OP(-);
557 break;
558 default:
559 assert(false && "Unexpected operation");
560 break;
561 }
562 #undef FOLD_OP
563 const analysis::Constant* merged_const = const_mgr->GetConstant(type, words);
564 return const_mgr->GetDefiningInstruction(merged_const)->result_id();
565 }
566
567 // Performs |input1| |opcode| |input2| and returns the merged constant result
568 // id. Returns 0 if the result is not a valid value. The input types must be
569 // Integers.
PerformIntegerOperation(analysis::ConstantManager * const_mgr,spv::Op opcode,const analysis::Constant * input1,const analysis::Constant * input2)570 uint32_t PerformIntegerOperation(analysis::ConstantManager* const_mgr,
571 spv::Op opcode,
572 const analysis::Constant* input1,
573 const analysis::Constant* input2) {
574 assert(input1->type()->AsInteger());
575 const analysis::Integer* type = input1->type()->AsInteger();
576 uint32_t width = type->AsInteger()->width();
577 assert(width == 32 || width == 64);
578 std::vector<uint32_t> words;
579 // Regardless of the sign of the constant, folding is performed on an unsigned
580 // interpretation of the constant data. This avoids signed integer overflow
581 // while folding, and works because sign is irrelevant for the IAdd, ISub and
582 // IMul instructions.
583 #define FOLD_OP(op) \
584 if (width == 64) { \
585 uint64_t val = input1->GetU64() op input2->GetU64(); \
586 words = ExtractInts(val); \
587 } else { \
588 uint32_t val = input1->GetU32() op input2->GetU32(); \
589 words.push_back(val); \
590 } \
591 static_assert(true, "require extra semicolon")
592 switch (opcode) {
593 case spv::Op::OpIMul:
594 FOLD_OP(*);
595 break;
596 case spv::Op::OpSDiv:
597 case spv::Op::OpUDiv:
598 assert(false && "Should not merge integer division");
599 break;
600 case spv::Op::OpIAdd:
601 FOLD_OP(+);
602 break;
603 case spv::Op::OpISub:
604 FOLD_OP(-);
605 break;
606 default:
607 assert(false && "Unexpected operation");
608 break;
609 }
610 #undef FOLD_OP
611 const analysis::Constant* merged_const = const_mgr->GetConstant(type, words);
612 return const_mgr->GetDefiningInstruction(merged_const)->result_id();
613 }
614
615 // Performs |input1| |opcode| |input2| and returns the merged constant result
616 // id. Returns 0 if the result is not a valid value. The input types must be
617 // Integers, Floats or Vectors of such.
PerformOperation(analysis::ConstantManager * const_mgr,spv::Op opcode,const analysis::Constant * input1,const analysis::Constant * input2)618 uint32_t PerformOperation(analysis::ConstantManager* const_mgr, spv::Op opcode,
619 const analysis::Constant* input1,
620 const analysis::Constant* input2) {
621 assert(input1 && input2);
622 const analysis::Type* type = input1->type();
623 std::vector<uint32_t> words;
624 if (const analysis::Vector* vector_type = type->AsVector()) {
625 const analysis::Type* ele_type = vector_type->element_type();
626 for (uint32_t i = 0; i != vector_type->element_count(); ++i) {
627 uint32_t id = 0;
628
629 const analysis::Constant* input1_comp = nullptr;
630 if (const analysis::VectorConstant* input1_vector =
631 input1->AsVectorConstant()) {
632 input1_comp = input1_vector->GetComponents()[i];
633 } else {
634 assert(input1->AsNullConstant());
635 input1_comp = const_mgr->GetConstant(ele_type, {});
636 }
637
638 const analysis::Constant* input2_comp = nullptr;
639 if (const analysis::VectorConstant* input2_vector =
640 input2->AsVectorConstant()) {
641 input2_comp = input2_vector->GetComponents()[i];
642 } else {
643 assert(input2->AsNullConstant());
644 input2_comp = const_mgr->GetConstant(ele_type, {});
645 }
646
647 if (ele_type->AsFloat()) {
648 id = PerformFloatingPointOperation(const_mgr, opcode, input1_comp,
649 input2_comp);
650 } else {
651 assert(ele_type->AsInteger());
652 id = PerformIntegerOperation(const_mgr, opcode, input1_comp,
653 input2_comp);
654 }
655 if (id == 0) return 0;
656 words.push_back(id);
657 }
658 const analysis::Constant* merged_const =
659 const_mgr->GetConstant(type, words);
660 return const_mgr->GetDefiningInstruction(merged_const)->result_id();
661 } else if (type->AsFloat()) {
662 return PerformFloatingPointOperation(const_mgr, opcode, input1, input2);
663 } else {
664 assert(type->AsInteger());
665 return PerformIntegerOperation(const_mgr, opcode, input1, input2);
666 }
667 }
668
669 // Merges consecutive multiplies where each contains one constant operand.
670 // Cases:
671 // 2 * (x * 2) = x * 4
672 // 2 * (2 * x) = x * 4
673 // (x * 2) * 2 = x * 4
674 // (2 * x) * 2 = x * 4
MergeMulMulArithmetic()675 FoldingRule MergeMulMulArithmetic() {
676 return [](IRContext* context, Instruction* inst,
677 const std::vector<const analysis::Constant*>& constants) {
678 assert(inst->opcode() == spv::Op::OpFMul ||
679 inst->opcode() == spv::Op::OpIMul);
680 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
681 const analysis::Type* type =
682 context->get_type_mgr()->GetType(inst->type_id());
683 if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
684 return false;
685
686 uint32_t width = ElementWidth(type);
687 if (width != 32 && width != 64) return false;
688
689 // Determine the constant input and the variable input in |inst|.
690 const analysis::Constant* const_input1 = ConstInput(constants);
691 if (!const_input1) return false;
692 Instruction* other_inst = NonConstInput(context, constants[0], inst);
693 if (HasFloatingPoint(type) && !other_inst->IsFloatingPointFoldingAllowed())
694 return false;
695
696 if (other_inst->opcode() == inst->opcode()) {
697 std::vector<const analysis::Constant*> other_constants =
698 const_mgr->GetOperandConstants(other_inst);
699 const analysis::Constant* const_input2 = ConstInput(other_constants);
700 if (!const_input2) return false;
701
702 bool other_first_is_variable = other_constants[0] == nullptr;
703 uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
704 const_input1, const_input2);
705 if (merged_id == 0) return false;
706
707 uint32_t non_const_id = other_first_is_variable
708 ? other_inst->GetSingleWordInOperand(0u)
709 : other_inst->GetSingleWordInOperand(1u);
710 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
711 {SPV_OPERAND_TYPE_ID, {merged_id}}});
712 return true;
713 }
714
715 return false;
716 };
717 }
718
719 // Merges divides into subsequent multiplies if each instruction contains one
720 // constant operand. Does not support integer operations.
721 // Cases:
722 // 2 * (x / 2) = x * 1
723 // 2 * (2 / x) = 4 / x
724 // (x / 2) * 2 = x * 1
725 // (2 / x) * 2 = 4 / x
726 // (y / x) * x = y
727 // x * (y / x) = y
MergeMulDivArithmetic()728 FoldingRule MergeMulDivArithmetic() {
729 return [](IRContext* context, Instruction* inst,
730 const std::vector<const analysis::Constant*>& constants) {
731 assert(inst->opcode() == spv::Op::OpFMul);
732 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
733 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
734
735 const analysis::Type* type =
736 context->get_type_mgr()->GetType(inst->type_id());
737 if (!inst->IsFloatingPointFoldingAllowed()) return false;
738
739 uint32_t width = ElementWidth(type);
740 if (width != 32 && width != 64) return false;
741
742 for (uint32_t i = 0; i < 2; i++) {
743 uint32_t op_id = inst->GetSingleWordInOperand(i);
744 Instruction* op_inst = def_use_mgr->GetDef(op_id);
745 if (op_inst->opcode() == spv::Op::OpFDiv) {
746 if (op_inst->GetSingleWordInOperand(1) ==
747 inst->GetSingleWordInOperand(1 - i)) {
748 inst->SetOpcode(spv::Op::OpCopyObject);
749 inst->SetInOperands(
750 {{SPV_OPERAND_TYPE_ID, {op_inst->GetSingleWordInOperand(0)}}});
751 return true;
752 }
753 }
754 }
755
756 const analysis::Constant* const_input1 = ConstInput(constants);
757 if (!const_input1) return false;
758 Instruction* other_inst = NonConstInput(context, constants[0], inst);
759 if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
760
761 if (other_inst->opcode() == spv::Op::OpFDiv) {
762 std::vector<const analysis::Constant*> other_constants =
763 const_mgr->GetOperandConstants(other_inst);
764 const analysis::Constant* const_input2 = ConstInput(other_constants);
765 if (!const_input2 || HasZero(const_input2)) return false;
766
767 bool other_first_is_variable = other_constants[0] == nullptr;
768 // If the variable value is the second operand of the divide, multiply
769 // the constants together. Otherwise divide the constants.
770 uint32_t merged_id = PerformOperation(
771 const_mgr,
772 other_first_is_variable ? other_inst->opcode() : inst->opcode(),
773 const_input1, const_input2);
774 if (merged_id == 0) return false;
775
776 uint32_t non_const_id = other_first_is_variable
777 ? other_inst->GetSingleWordInOperand(0u)
778 : other_inst->GetSingleWordInOperand(1u);
779
780 // If the variable value is on the second operand of the div, then this
781 // operation is a div. Otherwise it should be a multiply.
782 inst->SetOpcode(other_first_is_variable ? inst->opcode()
783 : other_inst->opcode());
784 if (other_first_is_variable) {
785 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
786 {SPV_OPERAND_TYPE_ID, {merged_id}}});
787 } else {
788 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {merged_id}},
789 {SPV_OPERAND_TYPE_ID, {non_const_id}}});
790 }
791 return true;
792 }
793
794 return false;
795 };
796 }
797
798 // Merges multiply of constant and negation.
799 // Cases:
800 // (-x) * 2 = x * -2
801 // 2 * (-x) = x * -2
MergeMulNegateArithmetic()802 FoldingRule MergeMulNegateArithmetic() {
803 return [](IRContext* context, Instruction* inst,
804 const std::vector<const analysis::Constant*>& constants) {
805 assert(inst->opcode() == spv::Op::OpFMul ||
806 inst->opcode() == spv::Op::OpIMul);
807 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
808 const analysis::Type* type =
809 context->get_type_mgr()->GetType(inst->type_id());
810 bool uses_float = HasFloatingPoint(type);
811 if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
812
813 uint32_t width = ElementWidth(type);
814 if (width != 32 && width != 64) return false;
815
816 const analysis::Constant* const_input1 = ConstInput(constants);
817 if (!const_input1) return false;
818 Instruction* other_inst = NonConstInput(context, constants[0], inst);
819 if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
820 return false;
821
822 if (other_inst->opcode() == spv::Op::OpFNegate ||
823 other_inst->opcode() == spv::Op::OpSNegate) {
824 uint32_t neg_id = NegateConstant(const_mgr, const_input1);
825
826 inst->SetInOperands(
827 {{SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}},
828 {SPV_OPERAND_TYPE_ID, {neg_id}}});
829 return true;
830 }
831
832 return false;
833 };
834 }
835
836 // Merges consecutive divides if each instruction contains one constant operand.
837 // Does not support integer division.
838 // Cases:
839 // 2 / (x / 2) = 4 / x
840 // 4 / (2 / x) = 2 * x
841 // (4 / x) / 2 = 2 / x
842 // (x / 2) / 2 = x / 4
MergeDivDivArithmetic()843 FoldingRule MergeDivDivArithmetic() {
844 return [](IRContext* context, Instruction* inst,
845 const std::vector<const analysis::Constant*>& constants) {
846 assert(inst->opcode() == spv::Op::OpFDiv);
847 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
848 const analysis::Type* type =
849 context->get_type_mgr()->GetType(inst->type_id());
850 if (!inst->IsFloatingPointFoldingAllowed()) return false;
851
852 uint32_t width = ElementWidth(type);
853 if (width != 32 && width != 64) return false;
854
855 const analysis::Constant* const_input1 = ConstInput(constants);
856 if (!const_input1 || HasZero(const_input1)) return false;
857 Instruction* other_inst = NonConstInput(context, constants[0], inst);
858 if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
859
860 bool first_is_variable = constants[0] == nullptr;
861 if (other_inst->opcode() == inst->opcode()) {
862 std::vector<const analysis::Constant*> other_constants =
863 const_mgr->GetOperandConstants(other_inst);
864 const analysis::Constant* const_input2 = ConstInput(other_constants);
865 if (!const_input2 || HasZero(const_input2)) return false;
866
867 bool other_first_is_variable = other_constants[0] == nullptr;
868
869 spv::Op merge_op = inst->opcode();
870 if (other_first_is_variable) {
871 // Constants magnify.
872 merge_op = spv::Op::OpFMul;
873 }
874
875 // This is an x / (*) case. Swap the inputs. Doesn't harm multiply
876 // because it is commutative.
877 if (first_is_variable) std::swap(const_input1, const_input2);
878 uint32_t merged_id =
879 PerformOperation(const_mgr, merge_op, const_input1, const_input2);
880 if (merged_id == 0) return false;
881
882 uint32_t non_const_id = other_first_is_variable
883 ? other_inst->GetSingleWordInOperand(0u)
884 : other_inst->GetSingleWordInOperand(1u);
885
886 spv::Op op = inst->opcode();
887 if (!first_is_variable && !other_first_is_variable) {
888 // Effectively div of 1/x, so change to multiply.
889 op = spv::Op::OpFMul;
890 }
891
892 uint32_t op1 = merged_id;
893 uint32_t op2 = non_const_id;
894 if (first_is_variable && other_first_is_variable) std::swap(op1, op2);
895 inst->SetOpcode(op);
896 inst->SetInOperands(
897 {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
898 return true;
899 }
900
901 return false;
902 };
903 }
904
905 // Fold multiplies succeeded by divides where each instruction contains a
906 // constant operand. Does not support integer divide.
907 // Cases:
908 // 4 / (x * 2) = 2 / x
909 // 4 / (2 * x) = 2 / x
910 // (x * 4) / 2 = x * 2
911 // (4 * x) / 2 = x * 2
912 // (x * y) / x = y
913 // (y * x) / x = y
MergeDivMulArithmetic()914 FoldingRule MergeDivMulArithmetic() {
915 return [](IRContext* context, Instruction* inst,
916 const std::vector<const analysis::Constant*>& constants) {
917 assert(inst->opcode() == spv::Op::OpFDiv);
918 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
919 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
920
921 const analysis::Type* type =
922 context->get_type_mgr()->GetType(inst->type_id());
923 if (!inst->IsFloatingPointFoldingAllowed()) return false;
924
925 uint32_t width = ElementWidth(type);
926 if (width != 32 && width != 64) return false;
927
928 uint32_t op_id = inst->GetSingleWordInOperand(0);
929 Instruction* op_inst = def_use_mgr->GetDef(op_id);
930
931 if (op_inst->opcode() == spv::Op::OpFMul) {
932 for (uint32_t i = 0; i < 2; i++) {
933 if (op_inst->GetSingleWordInOperand(i) ==
934 inst->GetSingleWordInOperand(1)) {
935 inst->SetOpcode(spv::Op::OpCopyObject);
936 inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
937 {op_inst->GetSingleWordInOperand(1 - i)}}});
938 return true;
939 }
940 }
941 }
942
943 const analysis::Constant* const_input1 = ConstInput(constants);
944 if (!const_input1 || HasZero(const_input1)) return false;
945 Instruction* other_inst = NonConstInput(context, constants[0], inst);
946 if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
947
948 bool first_is_variable = constants[0] == nullptr;
949 if (other_inst->opcode() == spv::Op::OpFMul) {
950 std::vector<const analysis::Constant*> other_constants =
951 const_mgr->GetOperandConstants(other_inst);
952 const analysis::Constant* const_input2 = ConstInput(other_constants);
953 if (!const_input2) return false;
954
955 bool other_first_is_variable = other_constants[0] == nullptr;
956
957 // This is an x / (*) case. Swap the inputs.
958 if (first_is_variable) std::swap(const_input1, const_input2);
959 uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
960 const_input1, const_input2);
961 if (merged_id == 0) return false;
962
963 uint32_t non_const_id = other_first_is_variable
964 ? other_inst->GetSingleWordInOperand(0u)
965 : other_inst->GetSingleWordInOperand(1u);
966
967 uint32_t op1 = merged_id;
968 uint32_t op2 = non_const_id;
969 if (first_is_variable) std::swap(op1, op2);
970
971 // Convert to multiply
972 if (first_is_variable) inst->SetOpcode(other_inst->opcode());
973 inst->SetInOperands(
974 {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
975 return true;
976 }
977
978 return false;
979 };
980 }
981
982 // Fold divides of a constant and a negation.
983 // Cases:
984 // (-x) / 2 = x / -2
985 // 2 / (-x) = -2 / x
MergeDivNegateArithmetic()986 FoldingRule MergeDivNegateArithmetic() {
987 return [](IRContext* context, Instruction* inst,
988 const std::vector<const analysis::Constant*>& constants) {
989 assert(inst->opcode() == spv::Op::OpFDiv);
990 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
991 if (!inst->IsFloatingPointFoldingAllowed()) return false;
992
993 const analysis::Constant* const_input1 = ConstInput(constants);
994 if (!const_input1) return false;
995 Instruction* other_inst = NonConstInput(context, constants[0], inst);
996 if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
997
998 bool first_is_variable = constants[0] == nullptr;
999 if (other_inst->opcode() == spv::Op::OpFNegate) {
1000 uint32_t neg_id = NegateConstant(const_mgr, const_input1);
1001
1002 if (first_is_variable) {
1003 inst->SetInOperands(
1004 {{SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}},
1005 {SPV_OPERAND_TYPE_ID, {neg_id}}});
1006 } else {
1007 inst->SetInOperands(
1008 {{SPV_OPERAND_TYPE_ID, {neg_id}},
1009 {SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}});
1010 }
1011 return true;
1012 }
1013
1014 return false;
1015 };
1016 }
1017
1018 // Folds addition of a constant and a negation.
1019 // Cases:
1020 // (-x) + 2 = 2 - x
1021 // 2 + (-x) = 2 - x
MergeAddNegateArithmetic()1022 FoldingRule MergeAddNegateArithmetic() {
1023 return [](IRContext* context, Instruction* inst,
1024 const std::vector<const analysis::Constant*>& constants) {
1025 assert(inst->opcode() == spv::Op::OpFAdd ||
1026 inst->opcode() == spv::Op::OpIAdd);
1027 const analysis::Type* type =
1028 context->get_type_mgr()->GetType(inst->type_id());
1029 bool uses_float = HasFloatingPoint(type);
1030 if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1031
1032 const analysis::Constant* const_input1 = ConstInput(constants);
1033 if (!const_input1) return false;
1034 Instruction* other_inst = NonConstInput(context, constants[0], inst);
1035 if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1036 return false;
1037
1038 if (other_inst->opcode() == spv::Op::OpSNegate ||
1039 other_inst->opcode() == spv::Op::OpFNegate) {
1040 inst->SetOpcode(HasFloatingPoint(type) ? spv::Op::OpFSub
1041 : spv::Op::OpISub);
1042 uint32_t const_id = constants[0] ? inst->GetSingleWordInOperand(0u)
1043 : inst->GetSingleWordInOperand(1u);
1044 inst->SetInOperands(
1045 {{SPV_OPERAND_TYPE_ID, {const_id}},
1046 {SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}});
1047 return true;
1048 }
1049 return false;
1050 };
1051 }
1052
1053 // Folds subtraction of a constant and a negation.
1054 // Cases:
1055 // (-x) - 2 = -2 - x
1056 // 2 - (-x) = x + 2
MergeSubNegateArithmetic()1057 FoldingRule MergeSubNegateArithmetic() {
1058 return [](IRContext* context, Instruction* inst,
1059 const std::vector<const analysis::Constant*>& constants) {
1060 assert(inst->opcode() == spv::Op::OpFSub ||
1061 inst->opcode() == spv::Op::OpISub);
1062 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1063 const analysis::Type* type =
1064 context->get_type_mgr()->GetType(inst->type_id());
1065 bool uses_float = HasFloatingPoint(type);
1066 if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1067
1068 uint32_t width = ElementWidth(type);
1069 if (width != 32 && width != 64) return false;
1070
1071 const analysis::Constant* const_input1 = ConstInput(constants);
1072 if (!const_input1) return false;
1073 Instruction* other_inst = NonConstInput(context, constants[0], inst);
1074 if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1075 return false;
1076
1077 if (other_inst->opcode() == spv::Op::OpSNegate ||
1078 other_inst->opcode() == spv::Op::OpFNegate) {
1079 uint32_t op1 = 0;
1080 uint32_t op2 = 0;
1081 spv::Op opcode = inst->opcode();
1082 if (constants[0] != nullptr) {
1083 op1 = other_inst->GetSingleWordInOperand(0u);
1084 op2 = inst->GetSingleWordInOperand(0u);
1085 opcode = HasFloatingPoint(type) ? spv::Op::OpFAdd : spv::Op::OpIAdd;
1086 } else {
1087 op1 = NegateConstant(const_mgr, const_input1);
1088 op2 = other_inst->GetSingleWordInOperand(0u);
1089 }
1090
1091 inst->SetOpcode(opcode);
1092 inst->SetInOperands(
1093 {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
1094 return true;
1095 }
1096 return false;
1097 };
1098 }
1099
1100 // Folds addition of an addition where each operation has a constant operand.
1101 // Cases:
1102 // (x + 2) + 2 = x + 4
1103 // (2 + x) + 2 = x + 4
1104 // 2 + (x + 2) = x + 4
1105 // 2 + (2 + x) = x + 4
MergeAddAddArithmetic()1106 FoldingRule MergeAddAddArithmetic() {
1107 return [](IRContext* context, Instruction* inst,
1108 const std::vector<const analysis::Constant*>& constants) {
1109 assert(inst->opcode() == spv::Op::OpFAdd ||
1110 inst->opcode() == spv::Op::OpIAdd);
1111 const analysis::Type* type =
1112 context->get_type_mgr()->GetType(inst->type_id());
1113 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1114 bool uses_float = HasFloatingPoint(type);
1115 if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1116
1117 uint32_t width = ElementWidth(type);
1118 if (width != 32 && width != 64) return false;
1119
1120 const analysis::Constant* const_input1 = ConstInput(constants);
1121 if (!const_input1) return false;
1122 Instruction* other_inst = NonConstInput(context, constants[0], inst);
1123 if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1124 return false;
1125
1126 if (other_inst->opcode() == spv::Op::OpFAdd ||
1127 other_inst->opcode() == spv::Op::OpIAdd) {
1128 std::vector<const analysis::Constant*> other_constants =
1129 const_mgr->GetOperandConstants(other_inst);
1130 const analysis::Constant* const_input2 = ConstInput(other_constants);
1131 if (!const_input2) return false;
1132
1133 Instruction* non_const_input =
1134 NonConstInput(context, other_constants[0], other_inst);
1135 uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
1136 const_input1, const_input2);
1137 if (merged_id == 0) return false;
1138
1139 inst->SetInOperands(
1140 {{SPV_OPERAND_TYPE_ID, {non_const_input->result_id()}},
1141 {SPV_OPERAND_TYPE_ID, {merged_id}}});
1142 return true;
1143 }
1144 return false;
1145 };
1146 }
1147
1148 // Folds addition of a subtraction where each operation has a constant operand.
1149 // Cases:
1150 // (x - 2) + 2 = x + 0
1151 // (2 - x) + 2 = 4 - x
1152 // 2 + (x - 2) = x + 0
1153 // 2 + (2 - x) = 4 - x
MergeAddSubArithmetic()1154 FoldingRule MergeAddSubArithmetic() {
1155 return [](IRContext* context, Instruction* inst,
1156 const std::vector<const analysis::Constant*>& constants) {
1157 assert(inst->opcode() == spv::Op::OpFAdd ||
1158 inst->opcode() == spv::Op::OpIAdd);
1159 const analysis::Type* type =
1160 context->get_type_mgr()->GetType(inst->type_id());
1161 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1162 bool uses_float = HasFloatingPoint(type);
1163 if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1164
1165 uint32_t width = ElementWidth(type);
1166 if (width != 32 && width != 64) return false;
1167
1168 const analysis::Constant* const_input1 = ConstInput(constants);
1169 if (!const_input1) return false;
1170 Instruction* other_inst = NonConstInput(context, constants[0], inst);
1171 if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1172 return false;
1173
1174 if (other_inst->opcode() == spv::Op::OpFSub ||
1175 other_inst->opcode() == spv::Op::OpISub) {
1176 std::vector<const analysis::Constant*> other_constants =
1177 const_mgr->GetOperandConstants(other_inst);
1178 const analysis::Constant* const_input2 = ConstInput(other_constants);
1179 if (!const_input2) return false;
1180
1181 bool first_is_variable = other_constants[0] == nullptr;
1182 spv::Op op = inst->opcode();
1183 uint32_t op1 = 0;
1184 uint32_t op2 = 0;
1185 if (first_is_variable) {
1186 // Subtract constants. Non-constant operand is first.
1187 op1 = other_inst->GetSingleWordInOperand(0u);
1188 op2 = PerformOperation(const_mgr, other_inst->opcode(), const_input1,
1189 const_input2);
1190 } else {
1191 // Add constants. Constant operand is first. Change the opcode.
1192 op1 = PerformOperation(const_mgr, inst->opcode(), const_input1,
1193 const_input2);
1194 op2 = other_inst->GetSingleWordInOperand(1u);
1195 op = other_inst->opcode();
1196 }
1197 if (op1 == 0 || op2 == 0) return false;
1198
1199 inst->SetOpcode(op);
1200 inst->SetInOperands(
1201 {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
1202 return true;
1203 }
1204 return false;
1205 };
1206 }
1207
1208 // Folds subtraction of an addition where each operand has a constant operand.
1209 // Cases:
1210 // (x + 2) - 2 = x + 0
1211 // (2 + x) - 2 = x + 0
1212 // 2 - (x + 2) = 0 - x
1213 // 2 - (2 + x) = 0 - x
MergeSubAddArithmetic()1214 FoldingRule MergeSubAddArithmetic() {
1215 return [](IRContext* context, Instruction* inst,
1216 const std::vector<const analysis::Constant*>& constants) {
1217 assert(inst->opcode() == spv::Op::OpFSub ||
1218 inst->opcode() == spv::Op::OpISub);
1219 const analysis::Type* type =
1220 context->get_type_mgr()->GetType(inst->type_id());
1221 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1222 bool uses_float = HasFloatingPoint(type);
1223 if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1224
1225 uint32_t width = ElementWidth(type);
1226 if (width != 32 && width != 64) return false;
1227
1228 const analysis::Constant* const_input1 = ConstInput(constants);
1229 if (!const_input1) return false;
1230 Instruction* other_inst = NonConstInput(context, constants[0], inst);
1231 if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1232 return false;
1233
1234 if (other_inst->opcode() == spv::Op::OpFAdd ||
1235 other_inst->opcode() == spv::Op::OpIAdd) {
1236 std::vector<const analysis::Constant*> other_constants =
1237 const_mgr->GetOperandConstants(other_inst);
1238 const analysis::Constant* const_input2 = ConstInput(other_constants);
1239 if (!const_input2) return false;
1240
1241 Instruction* non_const_input =
1242 NonConstInput(context, other_constants[0], other_inst);
1243
1244 // If the first operand of the sub is not a constant, swap the constants
1245 // so the subtraction has the correct operands.
1246 if (constants[0] == nullptr) std::swap(const_input1, const_input2);
1247 // Subtract the constants.
1248 uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
1249 const_input1, const_input2);
1250 spv::Op op = inst->opcode();
1251 uint32_t op1 = 0;
1252 uint32_t op2 = 0;
1253 if (constants[0] == nullptr) {
1254 // Non-constant operand is first. Change the opcode.
1255 op1 = non_const_input->result_id();
1256 op2 = merged_id;
1257 op = other_inst->opcode();
1258 } else {
1259 // Constant operand is first.
1260 op1 = merged_id;
1261 op2 = non_const_input->result_id();
1262 }
1263 if (op1 == 0 || op2 == 0) return false;
1264
1265 inst->SetOpcode(op);
1266 inst->SetInOperands(
1267 {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
1268 return true;
1269 }
1270 return false;
1271 };
1272 }
1273
1274 // Folds subtraction of a subtraction where each operand has a constant operand.
1275 // Cases:
1276 // (x - 2) - 2 = x - 4
1277 // (2 - x) - 2 = 0 - x
1278 // 2 - (x - 2) = 4 - x
1279 // 2 - (2 - x) = x + 0
MergeSubSubArithmetic()1280 FoldingRule MergeSubSubArithmetic() {
1281 return [](IRContext* context, Instruction* inst,
1282 const std::vector<const analysis::Constant*>& constants) {
1283 assert(inst->opcode() == spv::Op::OpFSub ||
1284 inst->opcode() == spv::Op::OpISub);
1285 const analysis::Type* type =
1286 context->get_type_mgr()->GetType(inst->type_id());
1287 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1288 bool uses_float = HasFloatingPoint(type);
1289 if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1290
1291 uint32_t width = ElementWidth(type);
1292 if (width != 32 && width != 64) return false;
1293
1294 const analysis::Constant* const_input1 = ConstInput(constants);
1295 if (!const_input1) return false;
1296 Instruction* other_inst = NonConstInput(context, constants[0], inst);
1297 if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1298 return false;
1299
1300 if (other_inst->opcode() == spv::Op::OpFSub ||
1301 other_inst->opcode() == spv::Op::OpISub) {
1302 std::vector<const analysis::Constant*> other_constants =
1303 const_mgr->GetOperandConstants(other_inst);
1304 const analysis::Constant* const_input2 = ConstInput(other_constants);
1305 if (!const_input2) return false;
1306
1307 Instruction* non_const_input =
1308 NonConstInput(context, other_constants[0], other_inst);
1309
1310 // Merge the constants.
1311 uint32_t merged_id = 0;
1312 spv::Op merge_op = inst->opcode();
1313 if (other_constants[0] == nullptr) {
1314 merge_op = uses_float ? spv::Op::OpFAdd : spv::Op::OpIAdd;
1315 } else if (constants[0] == nullptr) {
1316 std::swap(const_input1, const_input2);
1317 }
1318 merged_id =
1319 PerformOperation(const_mgr, merge_op, const_input1, const_input2);
1320 if (merged_id == 0) return false;
1321
1322 spv::Op op = inst->opcode();
1323 if (constants[0] != nullptr && other_constants[0] != nullptr) {
1324 // Change the operation.
1325 op = uses_float ? spv::Op::OpFAdd : spv::Op::OpIAdd;
1326 }
1327
1328 uint32_t op1 = 0;
1329 uint32_t op2 = 0;
1330 if ((constants[0] == nullptr) ^ (other_constants[0] == nullptr)) {
1331 op1 = merged_id;
1332 op2 = non_const_input->result_id();
1333 } else {
1334 op1 = non_const_input->result_id();
1335 op2 = merged_id;
1336 }
1337
1338 inst->SetOpcode(op);
1339 inst->SetInOperands(
1340 {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
1341 return true;
1342 }
1343 return false;
1344 };
1345 }
1346
1347 // Helper function for MergeGenericAddSubArithmetic. If |addend| and
1348 // subtrahend of |sub| is the same, merge to copy of minuend of |sub|.
MergeGenericAddendSub(uint32_t addend,uint32_t sub,Instruction * inst)1349 bool MergeGenericAddendSub(uint32_t addend, uint32_t sub, Instruction* inst) {
1350 IRContext* context = inst->context();
1351 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1352 Instruction* sub_inst = def_use_mgr->GetDef(sub);
1353 if (sub_inst->opcode() != spv::Op::OpFSub &&
1354 sub_inst->opcode() != spv::Op::OpISub)
1355 return false;
1356 if (sub_inst->opcode() == spv::Op::OpFSub &&
1357 !sub_inst->IsFloatingPointFoldingAllowed())
1358 return false;
1359 if (addend != sub_inst->GetSingleWordInOperand(1)) return false;
1360 inst->SetOpcode(spv::Op::OpCopyObject);
1361 inst->SetInOperands(
1362 {{SPV_OPERAND_TYPE_ID, {sub_inst->GetSingleWordInOperand(0)}}});
1363 context->UpdateDefUse(inst);
1364 return true;
1365 }
1366
1367 // Folds addition of a subtraction where the subtrahend is equal to the
1368 // other addend. Return a copy of the minuend. Accepts generic (const and
1369 // non-const) operands.
1370 // Cases:
1371 // (a - b) + b = a
1372 // b + (a - b) = a
MergeGenericAddSubArithmetic()1373 FoldingRule MergeGenericAddSubArithmetic() {
1374 return [](IRContext* context, Instruction* inst,
1375 const std::vector<const analysis::Constant*>&) {
1376 assert(inst->opcode() == spv::Op::OpFAdd ||
1377 inst->opcode() == spv::Op::OpIAdd);
1378 const analysis::Type* type =
1379 context->get_type_mgr()->GetType(inst->type_id());
1380 bool uses_float = HasFloatingPoint(type);
1381 if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1382
1383 uint32_t width = ElementWidth(type);
1384 if (width != 32 && width != 64) return false;
1385
1386 uint32_t add_op0 = inst->GetSingleWordInOperand(0);
1387 uint32_t add_op1 = inst->GetSingleWordInOperand(1);
1388 if (MergeGenericAddendSub(add_op0, add_op1, inst)) return true;
1389 return MergeGenericAddendSub(add_op1, add_op0, inst);
1390 };
1391 }
1392
1393 // Helper function for FactorAddMuls. If |factor0_0| is the same as |factor1_0|,
1394 // generate |factor0_0| * (|factor0_1| + |factor1_1|).
FactorAddMulsOpnds(uint32_t factor0_0,uint32_t factor0_1,uint32_t factor1_0,uint32_t factor1_1,Instruction * inst)1395 bool FactorAddMulsOpnds(uint32_t factor0_0, uint32_t factor0_1,
1396 uint32_t factor1_0, uint32_t factor1_1,
1397 Instruction* inst) {
1398 IRContext* context = inst->context();
1399 if (factor0_0 != factor1_0) return false;
1400 InstructionBuilder ir_builder(
1401 context, inst,
1402 IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
1403 Instruction* new_add_inst = ir_builder.AddBinaryOp(
1404 inst->type_id(), inst->opcode(), factor0_1, factor1_1);
1405 inst->SetOpcode(inst->opcode() == spv::Op::OpFAdd ? spv::Op::OpFMul
1406 : spv::Op::OpIMul);
1407 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {factor0_0}},
1408 {SPV_OPERAND_TYPE_ID, {new_add_inst->result_id()}}});
1409 context->UpdateDefUse(inst);
1410 return true;
1411 }
1412
1413 // Perform the following factoring identity, handling all operand order
1414 // combinations: (a * b) + (a * c) = a * (b + c)
FactorAddMuls()1415 FoldingRule FactorAddMuls() {
1416 return [](IRContext* context, Instruction* inst,
1417 const std::vector<const analysis::Constant*>&) {
1418 assert(inst->opcode() == spv::Op::OpFAdd ||
1419 inst->opcode() == spv::Op::OpIAdd);
1420 const analysis::Type* type =
1421 context->get_type_mgr()->GetType(inst->type_id());
1422 bool uses_float = HasFloatingPoint(type);
1423 if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1424
1425 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1426 uint32_t add_op0 = inst->GetSingleWordInOperand(0);
1427 Instruction* add_op0_inst = def_use_mgr->GetDef(add_op0);
1428 if (add_op0_inst->opcode() != spv::Op::OpFMul &&
1429 add_op0_inst->opcode() != spv::Op::OpIMul)
1430 return false;
1431 uint32_t add_op1 = inst->GetSingleWordInOperand(1);
1432 Instruction* add_op1_inst = def_use_mgr->GetDef(add_op1);
1433 if (add_op1_inst->opcode() != spv::Op::OpFMul &&
1434 add_op1_inst->opcode() != spv::Op::OpIMul)
1435 return false;
1436
1437 // Only perform this optimization if both of the muls only have one use.
1438 // Otherwise this is a deoptimization in size and performance.
1439 if (def_use_mgr->NumUses(add_op0_inst) > 1) return false;
1440 if (def_use_mgr->NumUses(add_op1_inst) > 1) return false;
1441
1442 if (add_op0_inst->opcode() == spv::Op::OpFMul &&
1443 (!add_op0_inst->IsFloatingPointFoldingAllowed() ||
1444 !add_op1_inst->IsFloatingPointFoldingAllowed()))
1445 return false;
1446
1447 for (int i = 0; i < 2; i++) {
1448 for (int j = 0; j < 2; j++) {
1449 // Check if operand i in add_op0_inst matches operand j in add_op1_inst.
1450 if (FactorAddMulsOpnds(add_op0_inst->GetSingleWordInOperand(i),
1451 add_op0_inst->GetSingleWordInOperand(1 - i),
1452 add_op1_inst->GetSingleWordInOperand(j),
1453 add_op1_inst->GetSingleWordInOperand(1 - j),
1454 inst))
1455 return true;
1456 }
1457 }
1458 return false;
1459 };
1460 }
1461
1462 // Replaces |inst| inplace with an FMA instruction |(x*y)+a|.
ReplaceWithFma(Instruction * inst,uint32_t x,uint32_t y,uint32_t a)1463 void ReplaceWithFma(Instruction* inst, uint32_t x, uint32_t y, uint32_t a) {
1464 uint32_t ext =
1465 inst->context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
1466
1467 if (ext == 0) {
1468 inst->context()->AddExtInstImport("GLSL.std.450");
1469 ext = inst->context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
1470 assert(ext != 0 &&
1471 "Could not add the GLSL.std.450 extended instruction set");
1472 }
1473
1474 std::vector<Operand> operands;
1475 operands.push_back({SPV_OPERAND_TYPE_ID, {ext}});
1476 operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {GLSLstd450Fma}});
1477 operands.push_back({SPV_OPERAND_TYPE_ID, {x}});
1478 operands.push_back({SPV_OPERAND_TYPE_ID, {y}});
1479 operands.push_back({SPV_OPERAND_TYPE_ID, {a}});
1480
1481 inst->SetOpcode(spv::Op::OpExtInst);
1482 inst->SetInOperands(std::move(operands));
1483 }
1484
1485 // Folds a multiple and add into an Fma.
1486 //
1487 // Cases:
1488 // (x * y) + a = Fma x y a
1489 // a + (x * y) = Fma x y a
MergeMulAddArithmetic(IRContext * context,Instruction * inst,const std::vector<const analysis::Constant * > &)1490 bool MergeMulAddArithmetic(IRContext* context, Instruction* inst,
1491 const std::vector<const analysis::Constant*>&) {
1492 assert(inst->opcode() == spv::Op::OpFAdd);
1493
1494 if (!inst->IsFloatingPointFoldingAllowed()) {
1495 return false;
1496 }
1497
1498 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1499 for (int i = 0; i < 2; i++) {
1500 uint32_t op_id = inst->GetSingleWordInOperand(i);
1501 Instruction* op_inst = def_use_mgr->GetDef(op_id);
1502
1503 if (op_inst->opcode() != spv::Op::OpFMul) {
1504 continue;
1505 }
1506
1507 if (!op_inst->IsFloatingPointFoldingAllowed()) {
1508 continue;
1509 }
1510
1511 uint32_t x = op_inst->GetSingleWordInOperand(0);
1512 uint32_t y = op_inst->GetSingleWordInOperand(1);
1513 uint32_t a = inst->GetSingleWordInOperand((i + 1) % 2);
1514 ReplaceWithFma(inst, x, y, a);
1515 return true;
1516 }
1517 return false;
1518 }
1519
1520 // Replaces |sub| inplace with an FMA instruction |(x*y)+a| where |a| first gets
1521 // negated if |negate_addition| is true, otherwise |x| gets negated.
ReplaceWithFmaAndNegate(Instruction * sub,uint32_t x,uint32_t y,uint32_t a,bool negate_addition)1522 void ReplaceWithFmaAndNegate(Instruction* sub, uint32_t x, uint32_t y,
1523 uint32_t a, bool negate_addition) {
1524 uint32_t ext =
1525 sub->context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
1526
1527 if (ext == 0) {
1528 sub->context()->AddExtInstImport("GLSL.std.450");
1529 ext = sub->context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
1530 assert(ext != 0 &&
1531 "Could not add the GLSL.std.450 extended instruction set");
1532 }
1533
1534 InstructionBuilder ir_builder(
1535 sub->context(), sub,
1536 IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
1537
1538 Instruction* neg = ir_builder.AddUnaryOp(sub->type_id(), spv::Op::OpFNegate,
1539 negate_addition ? a : x);
1540 uint32_t neg_op = neg->result_id(); // -a : -x
1541
1542 std::vector<Operand> operands;
1543 operands.push_back({SPV_OPERAND_TYPE_ID, {ext}});
1544 operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {GLSLstd450Fma}});
1545 operands.push_back({SPV_OPERAND_TYPE_ID, {negate_addition ? x : neg_op}});
1546 operands.push_back({SPV_OPERAND_TYPE_ID, {y}});
1547 operands.push_back({SPV_OPERAND_TYPE_ID, {negate_addition ? neg_op : a}});
1548
1549 sub->SetOpcode(spv::Op::OpExtInst);
1550 sub->SetInOperands(std::move(operands));
1551 }
1552
1553 // Folds a multiply and subtract into an Fma and negation.
1554 //
1555 // Cases:
1556 // (x * y) - a = Fma x y -a
1557 // a - (x * y) = Fma -x y a
MergeMulSubArithmetic(IRContext * context,Instruction * sub,const std::vector<const analysis::Constant * > &)1558 bool MergeMulSubArithmetic(IRContext* context, Instruction* sub,
1559 const std::vector<const analysis::Constant*>&) {
1560 assert(sub->opcode() == spv::Op::OpFSub);
1561
1562 if (!sub->IsFloatingPointFoldingAllowed()) {
1563 return false;
1564 }
1565
1566 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1567 for (int i = 0; i < 2; i++) {
1568 uint32_t op_id = sub->GetSingleWordInOperand(i);
1569 Instruction* mul = def_use_mgr->GetDef(op_id);
1570
1571 if (mul->opcode() != spv::Op::OpFMul) {
1572 continue;
1573 }
1574
1575 if (!mul->IsFloatingPointFoldingAllowed()) {
1576 continue;
1577 }
1578
1579 uint32_t x = mul->GetSingleWordInOperand(0);
1580 uint32_t y = mul->GetSingleWordInOperand(1);
1581 uint32_t a = sub->GetSingleWordInOperand((i + 1) % 2);
1582 ReplaceWithFmaAndNegate(sub, x, y, a, i == 0);
1583 return true;
1584 }
1585 return false;
1586 }
1587
IntMultipleBy1()1588 FoldingRule IntMultipleBy1() {
1589 return [](IRContext*, Instruction* inst,
1590 const std::vector<const analysis::Constant*>& constants) {
1591 assert(inst->opcode() == spv::Op::OpIMul &&
1592 "Wrong opcode. Should be OpIMul.");
1593 for (uint32_t i = 0; i < 2; i++) {
1594 if (constants[i] == nullptr) {
1595 continue;
1596 }
1597 const analysis::IntConstant* int_constant = constants[i]->AsIntConstant();
1598 if (int_constant) {
1599 uint32_t width = ElementWidth(int_constant->type());
1600 if (width != 32 && width != 64) return false;
1601 bool is_one = (width == 32) ? int_constant->GetU32BitValue() == 1u
1602 : int_constant->GetU64BitValue() == 1ull;
1603 if (is_one) {
1604 inst->SetOpcode(spv::Op::OpCopyObject);
1605 inst->SetInOperands(
1606 {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1 - i)}}});
1607 return true;
1608 }
1609 }
1610 }
1611 return false;
1612 };
1613 }
1614
1615 // Returns the number of elements that the |index|th in operand in |inst|
1616 // contributes to the result of |inst|. |inst| must be an
1617 // OpCompositeConstructInstruction.
GetNumOfElementsContributedByOperand(IRContext * context,const Instruction * inst,uint32_t index)1618 uint32_t GetNumOfElementsContributedByOperand(IRContext* context,
1619 const Instruction* inst,
1620 uint32_t index) {
1621 assert(inst->opcode() == spv::Op::OpCompositeConstruct);
1622 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1623 analysis::TypeManager* type_mgr = context->get_type_mgr();
1624
1625 analysis::Vector* result_type =
1626 type_mgr->GetType(inst->type_id())->AsVector();
1627 if (result_type == nullptr) {
1628 // If the result of the OpCompositeConstruct is not a vector then every
1629 // operands corresponds to a single element in the result.
1630 return 1;
1631 }
1632
1633 // If the result type is a vector then the operands are either scalars or
1634 // vectors. If it is a scalar, then it corresponds to a single element. If it
1635 // is a vector, then each element in the vector will be an element in the
1636 // result.
1637 uint32_t id = inst->GetSingleWordInOperand(index);
1638 Instruction* def = def_use_mgr->GetDef(id);
1639 analysis::Vector* type = type_mgr->GetType(def->type_id())->AsVector();
1640 if (type == nullptr) {
1641 return 1;
1642 }
1643 return type->element_count();
1644 }
1645
1646 // Returns the in-operands for an OpCompositeExtract instruction that are needed
1647 // to extract the |result_index|th element in the result of |inst| without using
1648 // the result of |inst|. Returns the empty vector if |result_index| is
1649 // out-of-bounds. |inst| must be an |OpCompositeConstruct| instruction.
GetExtractOperandsForElementOfCompositeConstruct(IRContext * context,const Instruction * inst,uint32_t result_index)1650 std::vector<Operand> GetExtractOperandsForElementOfCompositeConstruct(
1651 IRContext* context, const Instruction* inst, uint32_t result_index) {
1652 assert(inst->opcode() == spv::Op::OpCompositeConstruct);
1653 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1654 analysis::TypeManager* type_mgr = context->get_type_mgr();
1655
1656 analysis::Type* result_type = type_mgr->GetType(inst->type_id());
1657 if (result_type->AsVector() == nullptr) {
1658 if (result_index < inst->NumInOperands()) {
1659 uint32_t id = inst->GetSingleWordInOperand(result_index);
1660 return {Operand(SPV_OPERAND_TYPE_ID, {id})};
1661 }
1662 return {};
1663 }
1664
1665 // If the result type is a vector, then vector operands are concatenated.
1666 uint32_t total_element_count = 0;
1667 for (uint32_t idx = 0; idx < inst->NumInOperands(); ++idx) {
1668 uint32_t element_count =
1669 GetNumOfElementsContributedByOperand(context, inst, idx);
1670 total_element_count += element_count;
1671 if (result_index < total_element_count) {
1672 std::vector<Operand> operands;
1673 uint32_t id = inst->GetSingleWordInOperand(idx);
1674 Instruction* operand_def = def_use_mgr->GetDef(id);
1675 analysis::Type* operand_type = type_mgr->GetType(operand_def->type_id());
1676
1677 operands.push_back({SPV_OPERAND_TYPE_ID, {id}});
1678 if (operand_type->AsVector()) {
1679 uint32_t start_index_of_id = total_element_count - element_count;
1680 uint32_t index_into_id = result_index - start_index_of_id;
1681 operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {index_into_id}});
1682 }
1683 return operands;
1684 }
1685 }
1686 return {};
1687 }
1688
CompositeConstructFeedingExtract(IRContext * context,Instruction * inst,const std::vector<const analysis::Constant * > &)1689 bool CompositeConstructFeedingExtract(
1690 IRContext* context, Instruction* inst,
1691 const std::vector<const analysis::Constant*>&) {
1692 // If the input to an OpCompositeExtract is an OpCompositeConstruct,
1693 // then we can simply use the appropriate element in the construction.
1694 assert(inst->opcode() == spv::Op::OpCompositeExtract &&
1695 "Wrong opcode. Should be OpCompositeExtract.");
1696 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1697
1698 // If there are no index operands, then this rule cannot do anything.
1699 if (inst->NumInOperands() <= 1) {
1700 return false;
1701 }
1702
1703 uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1704 Instruction* cinst = def_use_mgr->GetDef(cid);
1705
1706 if (cinst->opcode() != spv::Op::OpCompositeConstruct) {
1707 return false;
1708 }
1709
1710 uint32_t index_into_result = inst->GetSingleWordInOperand(1);
1711 std::vector<Operand> operands =
1712 GetExtractOperandsForElementOfCompositeConstruct(context, cinst,
1713 index_into_result);
1714
1715 if (operands.empty()) {
1716 return false;
1717 }
1718
1719 // Add the remaining indices for extraction.
1720 for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
1721 operands.push_back(
1722 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {inst->GetSingleWordInOperand(i)}});
1723 }
1724
1725 if (operands.size() == 1) {
1726 // If there were no extra indices, then we have the final object. No need
1727 // to extract any more.
1728 inst->SetOpcode(spv::Op::OpCopyObject);
1729 }
1730
1731 inst->SetInOperands(std::move(operands));
1732 return true;
1733 }
1734
1735 // Walks the indexes chain from |start| to |end| of an OpCompositeInsert or
1736 // OpCompositeExtract instruction, and returns the type of the final element
1737 // being accessed.
GetElementType(uint32_t type_id,Instruction::iterator start,Instruction::iterator end,const analysis::TypeManager * type_mgr)1738 const analysis::Type* GetElementType(uint32_t type_id,
1739 Instruction::iterator start,
1740 Instruction::iterator end,
1741 const analysis::TypeManager* type_mgr) {
1742 const analysis::Type* type = type_mgr->GetType(type_id);
1743 for (auto index : make_range(std::move(start), std::move(end))) {
1744 assert(index.type == SPV_OPERAND_TYPE_LITERAL_INTEGER &&
1745 index.words.size() == 1);
1746 if (auto* array_type = type->AsArray()) {
1747 type = array_type->element_type();
1748 } else if (auto* matrix_type = type->AsMatrix()) {
1749 type = matrix_type->element_type();
1750 } else if (auto* struct_type = type->AsStruct()) {
1751 type = struct_type->element_types()[index.words[0]];
1752 } else {
1753 type = nullptr;
1754 }
1755 }
1756 return type;
1757 }
1758
1759 // Returns true of |inst_1| and |inst_2| have the same indexes that will be used
1760 // to index into a composite object, excluding the last index. The two
1761 // instructions must have the same opcode, and be either OpCompositeExtract or
1762 // OpCompositeInsert instructions.
HaveSameIndexesExceptForLast(Instruction * inst_1,Instruction * inst_2)1763 bool HaveSameIndexesExceptForLast(Instruction* inst_1, Instruction* inst_2) {
1764 assert(inst_1->opcode() == inst_2->opcode() &&
1765 "Expecting the opcodes to be the same.");
1766 assert((inst_1->opcode() == spv::Op::OpCompositeInsert ||
1767 inst_1->opcode() == spv::Op::OpCompositeExtract) &&
1768 "Instructions must be OpCompositeInsert or OpCompositeExtract.");
1769
1770 if (inst_1->NumInOperands() != inst_2->NumInOperands()) {
1771 return false;
1772 }
1773
1774 uint32_t first_index_position =
1775 (inst_1->opcode() == spv::Op::OpCompositeInsert ? 2 : 1);
1776 for (uint32_t i = first_index_position; i < inst_1->NumInOperands() - 1;
1777 i++) {
1778 if (inst_1->GetSingleWordInOperand(i) !=
1779 inst_2->GetSingleWordInOperand(i)) {
1780 return false;
1781 }
1782 }
1783 return true;
1784 }
1785
1786 // If the OpCompositeConstruct is simply putting back together elements that
1787 // where extracted from the same source, we can simply reuse the source.
1788 //
1789 // This is a common code pattern because of the way that scalar replacement
1790 // works.
CompositeExtractFeedingConstruct(IRContext * context,Instruction * inst,const std::vector<const analysis::Constant * > &)1791 bool CompositeExtractFeedingConstruct(
1792 IRContext* context, Instruction* inst,
1793 const std::vector<const analysis::Constant*>&) {
1794 assert(inst->opcode() == spv::Op::OpCompositeConstruct &&
1795 "Wrong opcode. Should be OpCompositeConstruct.");
1796 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1797 uint32_t original_id = 0;
1798
1799 if (inst->NumInOperands() == 0) {
1800 // The struct being constructed has no members.
1801 return false;
1802 }
1803
1804 // Check each element to make sure they are:
1805 // - extractions
1806 // - extracting the same position they are inserting
1807 // - all extract from the same id.
1808 Instruction* first_element_inst = nullptr;
1809 for (uint32_t i = 0; i < inst->NumInOperands(); ++i) {
1810 const uint32_t element_id = inst->GetSingleWordInOperand(i);
1811 Instruction* element_inst = def_use_mgr->GetDef(element_id);
1812 if (first_element_inst == nullptr) {
1813 first_element_inst = element_inst;
1814 }
1815
1816 if (element_inst->opcode() != spv::Op::OpCompositeExtract) {
1817 return false;
1818 }
1819
1820 if (!HaveSameIndexesExceptForLast(element_inst, first_element_inst)) {
1821 return false;
1822 }
1823
1824 if (element_inst->GetSingleWordInOperand(element_inst->NumInOperands() -
1825 1) != i) {
1826 return false;
1827 }
1828
1829 if (i == 0) {
1830 original_id =
1831 element_inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1832 } else if (original_id !=
1833 element_inst->GetSingleWordInOperand(kExtractCompositeIdInIdx)) {
1834 return false;
1835 }
1836 }
1837
1838 // The last check it to see that the object being extracted from is the
1839 // correct type.
1840 Instruction* original_inst = def_use_mgr->GetDef(original_id);
1841 analysis::TypeManager* type_mgr = context->get_type_mgr();
1842 const analysis::Type* original_type =
1843 GetElementType(original_inst->type_id(), first_element_inst->begin() + 3,
1844 first_element_inst->end() - 1, type_mgr);
1845
1846 if (original_type == nullptr) {
1847 return false;
1848 }
1849
1850 if (inst->type_id() != type_mgr->GetId(original_type)) {
1851 return false;
1852 }
1853
1854 if (first_element_inst->NumInOperands() == 2) {
1855 // Simplify by using the original object.
1856 inst->SetOpcode(spv::Op::OpCopyObject);
1857 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {original_id}}});
1858 return true;
1859 }
1860
1861 // Copies the original id and all indexes except for the last to the new
1862 // extract instruction.
1863 inst->SetOpcode(spv::Op::OpCompositeExtract);
1864 inst->SetInOperands(std::vector<Operand>(first_element_inst->begin() + 2,
1865 first_element_inst->end() - 1));
1866 return true;
1867 }
1868
InsertFeedingExtract()1869 FoldingRule InsertFeedingExtract() {
1870 return [](IRContext* context, Instruction* inst,
1871 const std::vector<const analysis::Constant*>&) {
1872 assert(inst->opcode() == spv::Op::OpCompositeExtract &&
1873 "Wrong opcode. Should be OpCompositeExtract.");
1874 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1875 uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1876 Instruction* cinst = def_use_mgr->GetDef(cid);
1877
1878 if (cinst->opcode() != spv::Op::OpCompositeInsert) {
1879 return false;
1880 }
1881
1882 // Find the first position where the list of insert and extract indicies
1883 // differ, if at all.
1884 uint32_t i;
1885 for (i = 1; i < inst->NumInOperands(); ++i) {
1886 if (i + 1 >= cinst->NumInOperands()) {
1887 break;
1888 }
1889
1890 if (inst->GetSingleWordInOperand(i) !=
1891 cinst->GetSingleWordInOperand(i + 1)) {
1892 break;
1893 }
1894 }
1895
1896 // We are extracting the element that was inserted.
1897 if (i == inst->NumInOperands() && i + 1 == cinst->NumInOperands()) {
1898 inst->SetOpcode(spv::Op::OpCopyObject);
1899 inst->SetInOperands(
1900 {{SPV_OPERAND_TYPE_ID,
1901 {cinst->GetSingleWordInOperand(kInsertObjectIdInIdx)}}});
1902 return true;
1903 }
1904
1905 // Extracting the value that was inserted along with values for the base
1906 // composite. Cannot do anything.
1907 if (i == inst->NumInOperands()) {
1908 return false;
1909 }
1910
1911 // Extracting an element of the value that was inserted. Extract from
1912 // that value directly.
1913 if (i + 1 == cinst->NumInOperands()) {
1914 std::vector<Operand> operands;
1915 operands.push_back(
1916 {SPV_OPERAND_TYPE_ID,
1917 {cinst->GetSingleWordInOperand(kInsertObjectIdInIdx)}});
1918 for (; i < inst->NumInOperands(); ++i) {
1919 operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
1920 {inst->GetSingleWordInOperand(i)}});
1921 }
1922 inst->SetInOperands(std::move(operands));
1923 return true;
1924 }
1925
1926 // Extracting a value that is disjoint from the element being inserted.
1927 // Rewrite the extract to use the composite input to the insert.
1928 std::vector<Operand> operands;
1929 operands.push_back(
1930 {SPV_OPERAND_TYPE_ID,
1931 {cinst->GetSingleWordInOperand(kInsertCompositeIdInIdx)}});
1932 for (i = 1; i < inst->NumInOperands(); ++i) {
1933 operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
1934 {inst->GetSingleWordInOperand(i)}});
1935 }
1936 inst->SetInOperands(std::move(operands));
1937 return true;
1938 };
1939 }
1940
1941 // When a VectorShuffle is feeding an Extract, we can extract from one of the
1942 // operands of the VectorShuffle. We just need to adjust the index in the
1943 // extract instruction.
VectorShuffleFeedingExtract()1944 FoldingRule VectorShuffleFeedingExtract() {
1945 return [](IRContext* context, Instruction* inst,
1946 const std::vector<const analysis::Constant*>&) {
1947 assert(inst->opcode() == spv::Op::OpCompositeExtract &&
1948 "Wrong opcode. Should be OpCompositeExtract.");
1949 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1950 analysis::TypeManager* type_mgr = context->get_type_mgr();
1951 uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1952 Instruction* cinst = def_use_mgr->GetDef(cid);
1953
1954 if (cinst->opcode() != spv::Op::OpVectorShuffle) {
1955 return false;
1956 }
1957
1958 // Find the size of the first vector operand of the VectorShuffle
1959 Instruction* first_input =
1960 def_use_mgr->GetDef(cinst->GetSingleWordInOperand(0));
1961 analysis::Type* first_input_type =
1962 type_mgr->GetType(first_input->type_id());
1963 assert(first_input_type->AsVector() &&
1964 "Input to vector shuffle should be vectors.");
1965 uint32_t first_input_size = first_input_type->AsVector()->element_count();
1966
1967 // Get index of the element the vector shuffle is placing in the position
1968 // being extracted.
1969 uint32_t new_index =
1970 cinst->GetSingleWordInOperand(2 + inst->GetSingleWordInOperand(1));
1971
1972 // Extracting an undefined value so fold this extract into an undef.
1973 const uint32_t undef_literal_value = 0xffffffff;
1974 if (new_index == undef_literal_value) {
1975 inst->SetOpcode(spv::Op::OpUndef);
1976 inst->SetInOperands({});
1977 return true;
1978 }
1979
1980 // Get the id of the of the vector the elemtent comes from, and update the
1981 // index if needed.
1982 uint32_t new_vector = 0;
1983 if (new_index < first_input_size) {
1984 new_vector = cinst->GetSingleWordInOperand(0);
1985 } else {
1986 new_vector = cinst->GetSingleWordInOperand(1);
1987 new_index -= first_input_size;
1988 }
1989
1990 // Update the extract instruction.
1991 inst->SetInOperand(kExtractCompositeIdInIdx, {new_vector});
1992 inst->SetInOperand(1, {new_index});
1993 return true;
1994 };
1995 }
1996
1997 // When an FMix with is feeding an Extract that extracts an element whose
1998 // corresponding |a| in the FMix is 0 or 1, we can extract from one of the
1999 // operands of the FMix.
FMixFeedingExtract()2000 FoldingRule FMixFeedingExtract() {
2001 return [](IRContext* context, Instruction* inst,
2002 const std::vector<const analysis::Constant*>&) {
2003 assert(inst->opcode() == spv::Op::OpCompositeExtract &&
2004 "Wrong opcode. Should be OpCompositeExtract.");
2005 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
2006 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
2007
2008 uint32_t composite_id =
2009 inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
2010 Instruction* composite_inst = def_use_mgr->GetDef(composite_id);
2011
2012 if (composite_inst->opcode() != spv::Op::OpExtInst) {
2013 return false;
2014 }
2015
2016 uint32_t inst_set_id =
2017 context->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
2018
2019 if (composite_inst->GetSingleWordInOperand(kExtInstSetIdInIdx) !=
2020 inst_set_id ||
2021 composite_inst->GetSingleWordInOperand(kExtInstInstructionInIdx) !=
2022 GLSLstd450FMix) {
2023 return false;
2024 }
2025
2026 // Get the |a| for the FMix instruction.
2027 uint32_t a_id = composite_inst->GetSingleWordInOperand(kFMixAIdInIdx);
2028 std::unique_ptr<Instruction> a(inst->Clone(context));
2029 a->SetInOperand(kExtractCompositeIdInIdx, {a_id});
2030 context->get_instruction_folder().FoldInstruction(a.get());
2031
2032 if (a->opcode() != spv::Op::OpCopyObject) {
2033 return false;
2034 }
2035
2036 const analysis::Constant* a_const =
2037 const_mgr->FindDeclaredConstant(a->GetSingleWordInOperand(0));
2038
2039 if (!a_const) {
2040 return false;
2041 }
2042
2043 bool use_x = false;
2044
2045 assert(a_const->type()->AsFloat());
2046 double element_value = a_const->GetValueAsDouble();
2047 if (element_value == 0.0) {
2048 use_x = true;
2049 } else if (element_value == 1.0) {
2050 use_x = false;
2051 } else {
2052 return false;
2053 }
2054
2055 // Get the id of the of the vector the element comes from.
2056 uint32_t new_vector = 0;
2057 if (use_x) {
2058 new_vector = composite_inst->GetSingleWordInOperand(kFMixXIdInIdx);
2059 } else {
2060 new_vector = composite_inst->GetSingleWordInOperand(kFMixYIdInIdx);
2061 }
2062
2063 // Update the extract instruction.
2064 inst->SetInOperand(kExtractCompositeIdInIdx, {new_vector});
2065 return true;
2066 };
2067 }
2068
2069 // Returns the number of elements in the composite type |type|. Returns 0 if
2070 // |type| is a scalar value. Return UINT32_MAX when the size is unknown at
2071 // compile time.
GetNumberOfElements(const analysis::Type * type)2072 uint32_t GetNumberOfElements(const analysis::Type* type) {
2073 if (auto* vector_type = type->AsVector()) {
2074 return vector_type->element_count();
2075 }
2076 if (auto* matrix_type = type->AsMatrix()) {
2077 return matrix_type->element_count();
2078 }
2079 if (auto* struct_type = type->AsStruct()) {
2080 return static_cast<uint32_t>(struct_type->element_types().size());
2081 }
2082 if (auto* array_type = type->AsArray()) {
2083 if (array_type->length_info().words[0] ==
2084 analysis::Array::LengthInfo::kConstant &&
2085 array_type->length_info().words.size() == 2) {
2086 return array_type->length_info().words[1];
2087 }
2088 return UINT32_MAX;
2089 }
2090 return 0;
2091 }
2092
2093 // Returns a map with the set of values that were inserted into an object by
2094 // the chain of OpCompositeInsertInstruction starting with |inst|.
2095 // The map will map the index to the value inserted at that index. An empty map
2096 // will be returned if the map could not be properly generated.
GetInsertedValues(Instruction * inst)2097 std::map<uint32_t, uint32_t> GetInsertedValues(Instruction* inst) {
2098 analysis::DefUseManager* def_use_mgr = inst->context()->get_def_use_mgr();
2099 std::map<uint32_t, uint32_t> values_inserted;
2100 Instruction* current_inst = inst;
2101 while (current_inst->opcode() == spv::Op::OpCompositeInsert) {
2102 if (current_inst->NumInOperands() > inst->NumInOperands()) {
2103 // This is to catch the case
2104 // %2 = OpCompositeInsert %m2x2int %v2int_1_0 %m2x2int_undef 0
2105 // %3 = OpCompositeInsert %m2x2int %int_4 %2 0 0
2106 // %4 = OpCompositeInsert %m2x2int %v2int_2_3 %3 1
2107 // In this case we cannot do a single construct to get the matrix.
2108 uint32_t partially_inserted_element_index =
2109 current_inst->GetSingleWordInOperand(inst->NumInOperands() - 1);
2110 if (values_inserted.count(partially_inserted_element_index) == 0)
2111 return {};
2112 }
2113 if (HaveSameIndexesExceptForLast(inst, current_inst)) {
2114 values_inserted.insert(
2115 {current_inst->GetSingleWordInOperand(current_inst->NumInOperands() -
2116 1),
2117 current_inst->GetSingleWordInOperand(kInsertObjectIdInIdx)});
2118 }
2119 current_inst = def_use_mgr->GetDef(
2120 current_inst->GetSingleWordInOperand(kInsertCompositeIdInIdx));
2121 }
2122 return values_inserted;
2123 }
2124
2125 // Returns true of there is an entry in |values_inserted| for every element of
2126 // |Type|.
DoInsertedValuesCoverEntireObject(const analysis::Type * type,std::map<uint32_t,uint32_t> & values_inserted)2127 bool DoInsertedValuesCoverEntireObject(
2128 const analysis::Type* type, std::map<uint32_t, uint32_t>& values_inserted) {
2129 uint32_t container_size = GetNumberOfElements(type);
2130 if (container_size != values_inserted.size()) {
2131 return false;
2132 }
2133
2134 if (values_inserted.rbegin()->first >= container_size) {
2135 return false;
2136 }
2137 return true;
2138 }
2139
2140 // Returns the type of the element that immediately contains the element being
2141 // inserted by the OpCompositeInsert instruction |inst|.
GetContainerType(Instruction * inst)2142 const analysis::Type* GetContainerType(Instruction* inst) {
2143 assert(inst->opcode() == spv::Op::OpCompositeInsert);
2144 analysis::TypeManager* type_mgr = inst->context()->get_type_mgr();
2145 return GetElementType(inst->type_id(), inst->begin() + 4, inst->end() - 1,
2146 type_mgr);
2147 }
2148
2149 // Returns an OpCompositeConstruct instruction that build an object with
2150 // |type_id| out of the values in |values_inserted|. Each value will be
2151 // placed at the index corresponding to the value. The new instruction will
2152 // be placed before |insert_before|.
BuildCompositeConstruct(uint32_t type_id,const std::map<uint32_t,uint32_t> & values_inserted,Instruction * insert_before)2153 Instruction* BuildCompositeConstruct(
2154 uint32_t type_id, const std::map<uint32_t, uint32_t>& values_inserted,
2155 Instruction* insert_before) {
2156 InstructionBuilder ir_builder(
2157 insert_before->context(), insert_before,
2158 IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
2159
2160 std::vector<uint32_t> ids_in_order;
2161 for (auto it : values_inserted) {
2162 ids_in_order.push_back(it.second);
2163 }
2164 Instruction* construct =
2165 ir_builder.AddCompositeConstruct(type_id, ids_in_order);
2166 return construct;
2167 }
2168
2169 // Replaces the OpCompositeInsert |inst| that inserts |construct| into the same
2170 // object as |inst| with final index removed. If the resulting
2171 // OpCompositeInsert instruction would have no remaining indexes, the
2172 // instruction is replaced with an OpCopyObject instead.
InsertConstructedObject(Instruction * inst,const Instruction * construct)2173 void InsertConstructedObject(Instruction* inst, const Instruction* construct) {
2174 if (inst->NumInOperands() == 3) {
2175 inst->SetOpcode(spv::Op::OpCopyObject);
2176 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {construct->result_id()}}});
2177 } else {
2178 inst->SetInOperand(kInsertObjectIdInIdx, {construct->result_id()});
2179 inst->RemoveOperand(inst->NumOperands() - 1);
2180 }
2181 }
2182
2183 // Replaces a series of |OpCompositeInsert| instruction that cover the entire
2184 // object with an |OpCompositeConstruct|.
CompositeInsertToCompositeConstruct(IRContext * context,Instruction * inst,const std::vector<const analysis::Constant * > &)2185 bool CompositeInsertToCompositeConstruct(
2186 IRContext* context, Instruction* inst,
2187 const std::vector<const analysis::Constant*>&) {
2188 assert(inst->opcode() == spv::Op::OpCompositeInsert &&
2189 "Wrong opcode. Should be OpCompositeInsert.");
2190 if (inst->NumInOperands() < 3) return false;
2191
2192 std::map<uint32_t, uint32_t> values_inserted = GetInsertedValues(inst);
2193 const analysis::Type* container_type = GetContainerType(inst);
2194 if (container_type == nullptr) {
2195 return false;
2196 }
2197
2198 if (!DoInsertedValuesCoverEntireObject(container_type, values_inserted)) {
2199 return false;
2200 }
2201
2202 analysis::TypeManager* type_mgr = context->get_type_mgr();
2203 Instruction* construct = BuildCompositeConstruct(
2204 type_mgr->GetId(container_type), values_inserted, inst);
2205 InsertConstructedObject(inst, construct);
2206 return true;
2207 }
2208
RedundantPhi()2209 FoldingRule RedundantPhi() {
2210 // An OpPhi instruction where all values are the same or the result of the phi
2211 // itself, can be replaced by the value itself.
2212 return [](IRContext*, Instruction* inst,
2213 const std::vector<const analysis::Constant*>&) {
2214 assert(inst->opcode() == spv::Op::OpPhi &&
2215 "Wrong opcode. Should be OpPhi.");
2216
2217 uint32_t incoming_value = 0;
2218
2219 for (uint32_t i = 0; i < inst->NumInOperands(); i += 2) {
2220 uint32_t op_id = inst->GetSingleWordInOperand(i);
2221 if (op_id == inst->result_id()) {
2222 continue;
2223 }
2224
2225 if (incoming_value == 0) {
2226 incoming_value = op_id;
2227 } else if (op_id != incoming_value) {
2228 // Found two possible value. Can't simplify.
2229 return false;
2230 }
2231 }
2232
2233 if (incoming_value == 0) {
2234 // Code looks invalid. Don't do anything.
2235 return false;
2236 }
2237
2238 // We have a single incoming value. Simplify using that value.
2239 inst->SetOpcode(spv::Op::OpCopyObject);
2240 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {incoming_value}}});
2241 return true;
2242 };
2243 }
2244
BitCastScalarOrVector()2245 FoldingRule BitCastScalarOrVector() {
2246 return [](IRContext* context, Instruction* inst,
2247 const std::vector<const analysis::Constant*>& constants) {
2248 assert(inst->opcode() == spv::Op::OpBitcast && constants.size() == 1);
2249 if (constants[0] == nullptr) return false;
2250
2251 const analysis::Type* type =
2252 context->get_type_mgr()->GetType(inst->type_id());
2253 if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
2254 return false;
2255
2256 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
2257 std::vector<uint32_t> words =
2258 GetWordsFromNumericScalarOrVectorConstant(const_mgr, constants[0]);
2259 if (words.size() == 0) return false;
2260
2261 const analysis::Constant* bitcasted_constant =
2262 ConvertWordsToNumericScalarOrVectorConstant(const_mgr, words, type);
2263 if (!bitcasted_constant) return false;
2264
2265 auto new_feeder_id =
2266 const_mgr->GetDefiningInstruction(bitcasted_constant, inst->type_id())
2267 ->result_id();
2268 inst->SetOpcode(spv::Op::OpCopyObject);
2269 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {new_feeder_id}}});
2270 return true;
2271 };
2272 }
2273
RedundantSelect()2274 FoldingRule RedundantSelect() {
2275 // An OpSelect instruction where both values are the same or the condition is
2276 // constant can be replaced by one of the values
2277 return [](IRContext*, Instruction* inst,
2278 const std::vector<const analysis::Constant*>& constants) {
2279 assert(inst->opcode() == spv::Op::OpSelect &&
2280 "Wrong opcode. Should be OpSelect.");
2281 assert(inst->NumInOperands() == 3);
2282 assert(constants.size() == 3);
2283
2284 uint32_t true_id = inst->GetSingleWordInOperand(1);
2285 uint32_t false_id = inst->GetSingleWordInOperand(2);
2286
2287 if (true_id == false_id) {
2288 // Both results are the same, condition doesn't matter
2289 inst->SetOpcode(spv::Op::OpCopyObject);
2290 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}});
2291 return true;
2292 } else if (constants[0]) {
2293 const analysis::Type* type = constants[0]->type();
2294 if (type->AsBool()) {
2295 // Scalar constant value, select the corresponding value.
2296 inst->SetOpcode(spv::Op::OpCopyObject);
2297 if (constants[0]->AsNullConstant() ||
2298 !constants[0]->AsBoolConstant()->value()) {
2299 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {false_id}}});
2300 } else {
2301 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}});
2302 }
2303 return true;
2304 } else {
2305 assert(type->AsVector());
2306 if (constants[0]->AsNullConstant()) {
2307 // All values come from false id.
2308 inst->SetOpcode(spv::Op::OpCopyObject);
2309 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {false_id}}});
2310 return true;
2311 } else {
2312 // Convert to a vector shuffle.
2313 std::vector<Operand> ops;
2314 ops.push_back({SPV_OPERAND_TYPE_ID, {true_id}});
2315 ops.push_back({SPV_OPERAND_TYPE_ID, {false_id}});
2316 const analysis::VectorConstant* vector_const =
2317 constants[0]->AsVectorConstant();
2318 uint32_t size =
2319 static_cast<uint32_t>(vector_const->GetComponents().size());
2320 for (uint32_t i = 0; i != size; ++i) {
2321 const analysis::Constant* component =
2322 vector_const->GetComponents()[i];
2323 if (component->AsNullConstant() ||
2324 !component->AsBoolConstant()->value()) {
2325 // Selecting from the false vector which is the second input
2326 // vector to the shuffle. Offset the index by |size|.
2327 ops.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i + size}});
2328 } else {
2329 // Selecting from true vector which is the first input vector to
2330 // the shuffle.
2331 ops.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}});
2332 }
2333 }
2334
2335 inst->SetOpcode(spv::Op::OpVectorShuffle);
2336 inst->SetInOperands(std::move(ops));
2337 return true;
2338 }
2339 }
2340 }
2341
2342 return false;
2343 };
2344 }
2345
// Classification of a floating-point constant, produced by
// getFloatConstantKind and consumed by the FAdd/FSub/FMul/FDiv/FMix folding
// rules.
enum class FloatConstantKind {
  Unknown,  // Not a constant, or a value that is neither 0 nor 1.
  Zero,     // Zero (a null constant, or every vector component zero).
  One       // Exactly 1.0 (every vector component, for vectors).
};
2347
getFloatConstantKind(const analysis::Constant * constant)2348 FloatConstantKind getFloatConstantKind(const analysis::Constant* constant) {
2349 if (constant == nullptr) {
2350 return FloatConstantKind::Unknown;
2351 }
2352
2353 assert(HasFloatingPoint(constant->type()) && "Unexpected constant type");
2354
2355 if (constant->AsNullConstant()) {
2356 return FloatConstantKind::Zero;
2357 } else if (const analysis::VectorConstant* vc =
2358 constant->AsVectorConstant()) {
2359 const std::vector<const analysis::Constant*>& components =
2360 vc->GetComponents();
2361 assert(!components.empty());
2362
2363 FloatConstantKind kind = getFloatConstantKind(components[0]);
2364
2365 for (size_t i = 1; i < components.size(); ++i) {
2366 if (getFloatConstantKind(components[i]) != kind) {
2367 return FloatConstantKind::Unknown;
2368 }
2369 }
2370
2371 return kind;
2372 } else if (const analysis::FloatConstant* fc = constant->AsFloatConstant()) {
2373 if (fc->IsZero()) return FloatConstantKind::Zero;
2374
2375 uint32_t width = fc->type()->AsFloat()->width();
2376 if (width != 32 && width != 64) return FloatConstantKind::Unknown;
2377
2378 double value = (width == 64) ? fc->GetDoubleValue() : fc->GetFloatValue();
2379
2380 if (value == 0.0) {
2381 return FloatConstantKind::Zero;
2382 } else if (value == 1.0) {
2383 return FloatConstantKind::One;
2384 } else {
2385 return FloatConstantKind::Unknown;
2386 }
2387 } else {
2388 return FloatConstantKind::Unknown;
2389 }
2390 }
2391
RedundantFAdd()2392 FoldingRule RedundantFAdd() {
2393 return [](IRContext*, Instruction* inst,
2394 const std::vector<const analysis::Constant*>& constants) {
2395 assert(inst->opcode() == spv::Op::OpFAdd &&
2396 "Wrong opcode. Should be OpFAdd.");
2397 assert(constants.size() == 2);
2398
2399 if (!inst->IsFloatingPointFoldingAllowed()) {
2400 return false;
2401 }
2402
2403 FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
2404 FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
2405
2406 if (kind0 == FloatConstantKind::Zero || kind1 == FloatConstantKind::Zero) {
2407 inst->SetOpcode(spv::Op::OpCopyObject);
2408 inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
2409 {inst->GetSingleWordInOperand(
2410 kind0 == FloatConstantKind::Zero ? 1 : 0)}}});
2411 return true;
2412 }
2413
2414 return false;
2415 };
2416 }
2417
RedundantFSub()2418 FoldingRule RedundantFSub() {
2419 return [](IRContext*, Instruction* inst,
2420 const std::vector<const analysis::Constant*>& constants) {
2421 assert(inst->opcode() == spv::Op::OpFSub &&
2422 "Wrong opcode. Should be OpFSub.");
2423 assert(constants.size() == 2);
2424
2425 if (!inst->IsFloatingPointFoldingAllowed()) {
2426 return false;
2427 }
2428
2429 FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
2430 FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
2431
2432 if (kind0 == FloatConstantKind::Zero) {
2433 inst->SetOpcode(spv::Op::OpFNegate);
2434 inst->SetInOperands(
2435 {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1)}}});
2436 return true;
2437 }
2438
2439 if (kind1 == FloatConstantKind::Zero) {
2440 inst->SetOpcode(spv::Op::OpCopyObject);
2441 inst->SetInOperands(
2442 {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
2443 return true;
2444 }
2445
2446 return false;
2447 };
2448 }
2449
RedundantFMul()2450 FoldingRule RedundantFMul() {
2451 return [](IRContext*, Instruction* inst,
2452 const std::vector<const analysis::Constant*>& constants) {
2453 assert(inst->opcode() == spv::Op::OpFMul &&
2454 "Wrong opcode. Should be OpFMul.");
2455 assert(constants.size() == 2);
2456
2457 if (!inst->IsFloatingPointFoldingAllowed()) {
2458 return false;
2459 }
2460
2461 FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
2462 FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
2463
2464 if (kind0 == FloatConstantKind::Zero || kind1 == FloatConstantKind::Zero) {
2465 inst->SetOpcode(spv::Op::OpCopyObject);
2466 inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
2467 {inst->GetSingleWordInOperand(
2468 kind0 == FloatConstantKind::Zero ? 0 : 1)}}});
2469 return true;
2470 }
2471
2472 if (kind0 == FloatConstantKind::One || kind1 == FloatConstantKind::One) {
2473 inst->SetOpcode(spv::Op::OpCopyObject);
2474 inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
2475 {inst->GetSingleWordInOperand(
2476 kind0 == FloatConstantKind::One ? 1 : 0)}}});
2477 return true;
2478 }
2479
2480 return false;
2481 };
2482 }
2483
RedundantFDiv()2484 FoldingRule RedundantFDiv() {
2485 return [](IRContext*, Instruction* inst,
2486 const std::vector<const analysis::Constant*>& constants) {
2487 assert(inst->opcode() == spv::Op::OpFDiv &&
2488 "Wrong opcode. Should be OpFDiv.");
2489 assert(constants.size() == 2);
2490
2491 if (!inst->IsFloatingPointFoldingAllowed()) {
2492 return false;
2493 }
2494
2495 FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
2496 FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
2497
2498 if (kind0 == FloatConstantKind::Zero) {
2499 inst->SetOpcode(spv::Op::OpCopyObject);
2500 inst->SetInOperands(
2501 {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
2502 return true;
2503 }
2504
2505 if (kind1 == FloatConstantKind::One) {
2506 inst->SetOpcode(spv::Op::OpCopyObject);
2507 inst->SetInOperands(
2508 {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
2509 return true;
2510 }
2511
2512 return false;
2513 };
2514 }
2515
RedundantFMix()2516 FoldingRule RedundantFMix() {
2517 return [](IRContext* context, Instruction* inst,
2518 const std::vector<const analysis::Constant*>& constants) {
2519 assert(inst->opcode() == spv::Op::OpExtInst &&
2520 "Wrong opcode. Should be OpExtInst.");
2521
2522 if (!inst->IsFloatingPointFoldingAllowed()) {
2523 return false;
2524 }
2525
2526 uint32_t instSetId =
2527 context->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
2528
2529 if (inst->GetSingleWordInOperand(kExtInstSetIdInIdx) == instSetId &&
2530 inst->GetSingleWordInOperand(kExtInstInstructionInIdx) ==
2531 GLSLstd450FMix) {
2532 assert(constants.size() == 5);
2533
2534 FloatConstantKind kind4 = getFloatConstantKind(constants[4]);
2535
2536 if (kind4 == FloatConstantKind::Zero || kind4 == FloatConstantKind::One) {
2537 inst->SetOpcode(spv::Op::OpCopyObject);
2538 inst->SetInOperands(
2539 {{SPV_OPERAND_TYPE_ID,
2540 {inst->GetSingleWordInOperand(kind4 == FloatConstantKind::Zero
2541 ? kFMixXIdInIdx
2542 : kFMixYIdInIdx)}}});
2543 return true;
2544 }
2545 }
2546
2547 return false;
2548 };
2549 }
2550
2551 // This rule handles addition of zero for integers.
RedundantIAdd()2552 FoldingRule RedundantIAdd() {
2553 return [](IRContext* context, Instruction* inst,
2554 const std::vector<const analysis::Constant*>& constants) {
2555 assert(inst->opcode() == spv::Op::OpIAdd &&
2556 "Wrong opcode. Should be OpIAdd.");
2557
2558 uint32_t operand = std::numeric_limits<uint32_t>::max();
2559 const analysis::Type* operand_type = nullptr;
2560 if (constants[0] && constants[0]->IsZero()) {
2561 operand = inst->GetSingleWordInOperand(1);
2562 operand_type = constants[0]->type();
2563 } else if (constants[1] && constants[1]->IsZero()) {
2564 operand = inst->GetSingleWordInOperand(0);
2565 operand_type = constants[1]->type();
2566 }
2567
2568 if (operand != std::numeric_limits<uint32_t>::max()) {
2569 const analysis::Type* inst_type =
2570 context->get_type_mgr()->GetType(inst->type_id());
2571 if (inst_type->IsSame(operand_type)) {
2572 inst->SetOpcode(spv::Op::OpCopyObject);
2573 } else {
2574 inst->SetOpcode(spv::Op::OpBitcast);
2575 }
2576 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {operand}}});
2577 return true;
2578 }
2579 return false;
2580 };
2581 }
2582
2583 // This rule look for a dot with a constant vector containing a single 1 and
2584 // the rest 0s. This is the same as doing an extract.
DotProductDoingExtract()2585 FoldingRule DotProductDoingExtract() {
2586 return [](IRContext* context, Instruction* inst,
2587 const std::vector<const analysis::Constant*>& constants) {
2588 assert(inst->opcode() == spv::Op::OpDot &&
2589 "Wrong opcode. Should be OpDot.");
2590
2591 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
2592
2593 if (!inst->IsFloatingPointFoldingAllowed()) {
2594 return false;
2595 }
2596
2597 for (int i = 0; i < 2; ++i) {
2598 if (!constants[i]) {
2599 continue;
2600 }
2601
2602 const analysis::Vector* vector_type = constants[i]->type()->AsVector();
2603 assert(vector_type && "Inputs to OpDot must be vectors.");
2604 const analysis::Float* element_type =
2605 vector_type->element_type()->AsFloat();
2606 assert(element_type && "Inputs to OpDot must be vectors of floats.");
2607 uint32_t element_width = element_type->width();
2608 if (element_width != 32 && element_width != 64) {
2609 return false;
2610 }
2611
2612 std::vector<const analysis::Constant*> components;
2613 components = constants[i]->GetVectorComponents(const_mgr);
2614
2615 constexpr uint32_t kNotFound = std::numeric_limits<uint32_t>::max();
2616
2617 uint32_t component_with_one = kNotFound;
2618 bool all_others_zero = true;
2619 for (uint32_t j = 0; j < components.size(); ++j) {
2620 const analysis::Constant* element = components[j];
2621 double value =
2622 (element_width == 32 ? element->GetFloat() : element->GetDouble());
2623 if (value == 0.0) {
2624 continue;
2625 } else if (value == 1.0) {
2626 if (component_with_one == kNotFound) {
2627 component_with_one = j;
2628 } else {
2629 component_with_one = kNotFound;
2630 break;
2631 }
2632 } else {
2633 all_others_zero = false;
2634 break;
2635 }
2636 }
2637
2638 if (!all_others_zero || component_with_one == kNotFound) {
2639 continue;
2640 }
2641
2642 std::vector<Operand> operands;
2643 operands.push_back(
2644 {SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1u - i)}});
2645 operands.push_back(
2646 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {component_with_one}});
2647
2648 inst->SetOpcode(spv::Op::OpCompositeExtract);
2649 inst->SetInOperands(std::move(operands));
2650 return true;
2651 }
2652 return false;
2653 };
2654 }
2655
2656 // If we are storing an undef, then we can remove the store.
2657 //
2658 // TODO: We can do something similar for OpImageWrite, but checking for volatile
2659 // is complicated. Waiting to see if it is needed.
StoringUndef()2660 FoldingRule StoringUndef() {
2661 return [](IRContext* context, Instruction* inst,
2662 const std::vector<const analysis::Constant*>&) {
2663 assert(inst->opcode() == spv::Op::OpStore &&
2664 "Wrong opcode. Should be OpStore.");
2665
2666 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
2667
2668 // If this is a volatile store, the store cannot be removed.
2669 if (inst->NumInOperands() == 3) {
2670 if (inst->GetSingleWordInOperand(2) &
2671 uint32_t(spv::MemoryAccessMask::Volatile)) {
2672 return false;
2673 }
2674 }
2675
2676 uint32_t object_id = inst->GetSingleWordInOperand(kStoreObjectInIdx);
2677 Instruction* object_inst = def_use_mgr->GetDef(object_id);
2678 if (object_inst->opcode() == spv::Op::OpUndef) {
2679 inst->ToNop();
2680 return true;
2681 }
2682 return false;
2683 };
2684 }
2685
// Folds an OpVectorShuffle |inst| whose first operand (or, failing that, its
// second operand) is itself an OpVectorShuffle, called "the feeding shuffle"
// below. Every component that |inst| selects from the feeding shuffle is traced
// back to the vector operand of the feeding shuffle it came from. If all of
// them resolve to the same vector, |inst| is rewritten to read that vector
// directly and its component indices are remapped. If none of them resolves to
// a vector (e.g. they are all the undefined index 0xFFFFFFFF), a null constant
// of the feeding shuffle's type is used instead.
// Returns false without modifying |inst| when neither operand is a shuffle, or
// when components from two different vectors of the feeding shuffle are needed.
FoldingRule VectorShuffleFeedingShuffle() {
  return [](IRContext* context, Instruction* inst,
            const std::vector<const analysis::Constant*>&) {
    assert(inst->opcode() == spv::Op::OpVectorShuffle &&
           "Wrong opcode. Should be OpVectorShuffle.");

    analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
    analysis::TypeManager* type_mgr = context->get_type_mgr();

    // |op0_length| is the component count of |inst|'s first vector operand.
    // It is computed before |feeding_shuffle_inst| may be re-pointed at the
    // second operand, and is used for index arithmetic in both cases.
    Instruction* feeding_shuffle_inst =
        def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
    analysis::Vector* op0_type =
        type_mgr->GetType(feeding_shuffle_inst->type_id())->AsVector();
    uint32_t op0_length = op0_type->element_count();

    bool feeder_is_op0 = true;
    if (feeding_shuffle_inst->opcode() != spv::Op::OpVectorShuffle) {
      feeding_shuffle_inst =
          def_use_mgr->GetDef(inst->GetSingleWordInOperand(1));
      feeder_is_op0 = false;
    }

    if (feeding_shuffle_inst->opcode() != spv::Op::OpVectorShuffle) {
      return false;
    }

    // Length of the feeding shuffle's first vector operand. Indices below it
    // select from that operand; indices at or above it select from its second.
    Instruction* feeder2 =
        def_use_mgr->GetDef(feeding_shuffle_inst->GetSingleWordInOperand(0));
    analysis::Vector* feeder_op0_type =
        type_mgr->GetType(feeder2->type_id())->AsVector();
    uint32_t feeder_op0_length = feeder_op0_type->element_count();

    // Id of the single vector all traced components must come from (0 = none
    // seen yet).
    uint32_t new_feeder_id = 0;
    std::vector<Operand> new_operands;
    new_operands.resize(
        2, {SPV_OPERAND_TYPE_ID, {0}});  // Place holders for vector operands.
    const uint32_t undef_literal = 0xffffffff;
    for (uint32_t op = 2; op < inst->NumInOperands(); ++op) {
      uint32_t component_index = inst->GetSingleWordInOperand(op);

      // Do not interpret the undefined value literal as coming from operand 1.
      if (component_index != undef_literal &&
          feeder_is_op0 == (component_index < op0_length)) {
        // This component comes from the feeding_shuffle_inst. Update
        // |component_index| to be the index into the operand of the feeder.

        // Adjust component_index to get the index into the operands of the
        // feeding_shuffle_inst.
        if (component_index >= op0_length) {
          component_index -= op0_length;
        }
        component_index =
            feeding_shuffle_inst->GetSingleWordInOperand(component_index + 2);

        // Check if we are using a component from the first or second operand of
        // the feeding instruction.
        if (component_index < feeder_op0_length) {
          if (new_feeder_id == 0) {
            // First time through, save the id of the operand the element comes
            // from.
            new_feeder_id = feeding_shuffle_inst->GetSingleWordInOperand(0);
          } else if (new_feeder_id !=
                     feeding_shuffle_inst->GetSingleWordInOperand(0)) {
            // We need both elements of the feeding_shuffle_inst, so we cannot
            // fold.
            return false;
          }
        } else if (component_index != undef_literal) {
          if (new_feeder_id == 0) {
            // First time through, save the id of the operand the element comes
            // from.
            new_feeder_id = feeding_shuffle_inst->GetSingleWordInOperand(1);
          } else if (new_feeder_id !=
                     feeding_shuffle_inst->GetSingleWordInOperand(1)) {
            // We need both elements of the feeding_shuffle_inst, so we cannot
            // fold.
            return false;
          }
          // Re-base the index so it counts from the start of the feeding
          // shuffle's second operand.
          component_index -= feeder_op0_length;
        }

        // When the feeder replaces |inst|'s second operand, its indices sit
        // after the first operand's |op0_length| components.
        if (!feeder_is_op0 && component_index != undef_literal) {
          component_index += op0_length;
        }
      }
      new_operands.push_back(
          {SPV_OPERAND_TYPE_LITERAL_INTEGER, {component_index}});
    }

    // No component traced back to a vector of the feeding shuffle. Substitute
    // a null constant of the feeding shuffle's type so the operand stays valid.
    if (new_feeder_id == 0) {
      analysis::ConstantManager* const_mgr = context->get_constant_mgr();
      const analysis::Type* type =
          type_mgr->GetType(feeding_shuffle_inst->type_id());
      const analysis::Constant* null_const = const_mgr->GetConstant(type, {});
      new_feeder_id =
          const_mgr->GetDefiningInstruction(null_const, 0)->result_id();
    }

    if (feeder_is_op0) {
      // If the size of the first vector operand changed then the indices
      // referring to the second operand need to be adjusted.
      Instruction* new_feeder_inst = def_use_mgr->GetDef(new_feeder_id);
      analysis::Type* new_feeder_type =
          type_mgr->GetType(new_feeder_inst->type_id());
      uint32_t new_op0_size = new_feeder_type->AsVector()->element_count();
      int32_t adjustment = op0_length - new_op0_size;

      if (adjustment != 0) {
        // new_operands[i] holds the (possibly rewritten) index for in-operand
        // i. Indices that originally pointed past |op0_length| were copied
        // unchanged above and are shifted here.
        for (uint32_t i = 2; i < new_operands.size(); i++) {
          uint32_t operand = inst->GetSingleWordInOperand(i);
          if (operand >= op0_length && operand != undef_literal) {
            new_operands[i].words[0] -= adjustment;
          }
        }
      }

      new_operands[0].words[0] = new_feeder_id;
      new_operands[1] = inst->GetInOperand(1);
    } else {
      new_operands[1].words[0] = new_feeder_id;
      new_operands[0] = inst->GetInOperand(0);
    }

    inst->SetInOperands(std::move(new_operands));
    return true;
  };
}
2813
2814 // Removes duplicate ids from the interface list of an OpEntryPoint
2815 // instruction.
RemoveRedundantOperands()2816 FoldingRule RemoveRedundantOperands() {
2817 return [](IRContext*, Instruction* inst,
2818 const std::vector<const analysis::Constant*>&) {
2819 assert(inst->opcode() == spv::Op::OpEntryPoint &&
2820 "Wrong opcode. Should be OpEntryPoint.");
2821 bool has_redundant_operand = false;
2822 std::unordered_set<uint32_t> seen_operands;
2823 std::vector<Operand> new_operands;
2824
2825 new_operands.emplace_back(inst->GetOperand(0));
2826 new_operands.emplace_back(inst->GetOperand(1));
2827 new_operands.emplace_back(inst->GetOperand(2));
2828 for (uint32_t i = 3; i < inst->NumOperands(); ++i) {
2829 if (seen_operands.insert(inst->GetSingleWordOperand(i)).second) {
2830 new_operands.emplace_back(inst->GetOperand(i));
2831 } else {
2832 has_redundant_operand = true;
2833 }
2834 }
2835
2836 if (!has_redundant_operand) {
2837 return false;
2838 }
2839
2840 inst->SetInOperands(std::move(new_operands));
2841 return true;
2842 };
2843 }
2844
2845 // If an image instruction's operand is a constant, updates the image operand
2846 // flag from Offset to ConstOffset.
UpdateImageOperands()2847 FoldingRule UpdateImageOperands() {
2848 return [](IRContext*, Instruction* inst,
2849 const std::vector<const analysis::Constant*>& constants) {
2850 const auto opcode = inst->opcode();
2851 (void)opcode;
2852 assert((opcode == spv::Op::OpImageSampleImplicitLod ||
2853 opcode == spv::Op::OpImageSampleExplicitLod ||
2854 opcode == spv::Op::OpImageSampleDrefImplicitLod ||
2855 opcode == spv::Op::OpImageSampleDrefExplicitLod ||
2856 opcode == spv::Op::OpImageSampleProjImplicitLod ||
2857 opcode == spv::Op::OpImageSampleProjExplicitLod ||
2858 opcode == spv::Op::OpImageSampleProjDrefImplicitLod ||
2859 opcode == spv::Op::OpImageSampleProjDrefExplicitLod ||
2860 opcode == spv::Op::OpImageFetch ||
2861 opcode == spv::Op::OpImageGather ||
2862 opcode == spv::Op::OpImageDrefGather ||
2863 opcode == spv::Op::OpImageRead || opcode == spv::Op::OpImageWrite ||
2864 opcode == spv::Op::OpImageSparseSampleImplicitLod ||
2865 opcode == spv::Op::OpImageSparseSampleExplicitLod ||
2866 opcode == spv::Op::OpImageSparseSampleDrefImplicitLod ||
2867 opcode == spv::Op::OpImageSparseSampleDrefExplicitLod ||
2868 opcode == spv::Op::OpImageSparseSampleProjImplicitLod ||
2869 opcode == spv::Op::OpImageSparseSampleProjExplicitLod ||
2870 opcode == spv::Op::OpImageSparseSampleProjDrefImplicitLod ||
2871 opcode == spv::Op::OpImageSparseSampleProjDrefExplicitLod ||
2872 opcode == spv::Op::OpImageSparseFetch ||
2873 opcode == spv::Op::OpImageSparseGather ||
2874 opcode == spv::Op::OpImageSparseDrefGather ||
2875 opcode == spv::Op::OpImageSparseRead) &&
2876 "Wrong opcode. Should be an image instruction.");
2877
2878 int32_t operand_index = ImageOperandsMaskInOperandIndex(inst);
2879 if (operand_index >= 0) {
2880 auto image_operands = inst->GetSingleWordInOperand(operand_index);
2881 if (image_operands & uint32_t(spv::ImageOperandsMask::Offset)) {
2882 uint32_t offset_operand_index = operand_index + 1;
2883 if (image_operands & uint32_t(spv::ImageOperandsMask::Bias))
2884 offset_operand_index++;
2885 if (image_operands & uint32_t(spv::ImageOperandsMask::Lod))
2886 offset_operand_index++;
2887 if (image_operands & uint32_t(spv::ImageOperandsMask::Grad))
2888 offset_operand_index += 2;
2889 assert(((image_operands &
2890 uint32_t(spv::ImageOperandsMask::ConstOffset)) == 0) &&
2891 "Offset and ConstOffset may not be used together");
2892 if (offset_operand_index < inst->NumOperands()) {
2893 if (constants[offset_operand_index]) {
2894 if (constants[offset_operand_index]->IsZero()) {
2895 inst->RemoveInOperand(offset_operand_index);
2896 } else {
2897 image_operands = image_operands |
2898 uint32_t(spv::ImageOperandsMask::ConstOffset);
2899 }
2900 image_operands =
2901 image_operands & ~uint32_t(spv::ImageOperandsMask::Offset);
2902 inst->SetInOperand(operand_index, {image_operands});
2903 return true;
2904 }
2905 }
2906 }
2907 }
2908
2909 return false;
2910 };
2911 }
2912
2913 } // namespace
2914
AddFoldingRules()2915 void FoldingRules::AddFoldingRules() {
2916 // Add all folding rules to the list for the opcodes to which they apply.
2917 // Note that the order in which rules are added to the list matters. If a rule
2918 // applies to the instruction, the rest of the rules will not be attempted.
2919 // Take that into consideration.
2920 rules_[spv::Op::OpBitcast].push_back(BitCastScalarOrVector());
2921
2922 rules_[spv::Op::OpCompositeConstruct].push_back(
2923 CompositeExtractFeedingConstruct);
2924
2925 rules_[spv::Op::OpCompositeExtract].push_back(InsertFeedingExtract());
2926 rules_[spv::Op::OpCompositeExtract].push_back(
2927 CompositeConstructFeedingExtract);
2928 rules_[spv::Op::OpCompositeExtract].push_back(VectorShuffleFeedingExtract());
2929 rules_[spv::Op::OpCompositeExtract].push_back(FMixFeedingExtract());
2930
2931 rules_[spv::Op::OpCompositeInsert].push_back(
2932 CompositeInsertToCompositeConstruct);
2933
2934 rules_[spv::Op::OpDot].push_back(DotProductDoingExtract());
2935
2936 rules_[spv::Op::OpEntryPoint].push_back(RemoveRedundantOperands());
2937
2938 rules_[spv::Op::OpFAdd].push_back(RedundantFAdd());
2939 rules_[spv::Op::OpFAdd].push_back(MergeAddNegateArithmetic());
2940 rules_[spv::Op::OpFAdd].push_back(MergeAddAddArithmetic());
2941 rules_[spv::Op::OpFAdd].push_back(MergeAddSubArithmetic());
2942 rules_[spv::Op::OpFAdd].push_back(MergeGenericAddSubArithmetic());
2943 rules_[spv::Op::OpFAdd].push_back(FactorAddMuls());
2944 rules_[spv::Op::OpFAdd].push_back(MergeMulAddArithmetic);
2945
2946 rules_[spv::Op::OpFDiv].push_back(RedundantFDiv());
2947 rules_[spv::Op::OpFDiv].push_back(ReciprocalFDiv());
2948 rules_[spv::Op::OpFDiv].push_back(MergeDivDivArithmetic());
2949 rules_[spv::Op::OpFDiv].push_back(MergeDivMulArithmetic());
2950 rules_[spv::Op::OpFDiv].push_back(MergeDivNegateArithmetic());
2951
2952 rules_[spv::Op::OpFMul].push_back(RedundantFMul());
2953 rules_[spv::Op::OpFMul].push_back(MergeMulMulArithmetic());
2954 rules_[spv::Op::OpFMul].push_back(MergeMulDivArithmetic());
2955 rules_[spv::Op::OpFMul].push_back(MergeMulNegateArithmetic());
2956
2957 rules_[spv::Op::OpFNegate].push_back(MergeNegateArithmetic());
2958 rules_[spv::Op::OpFNegate].push_back(MergeNegateAddSubArithmetic());
2959 rules_[spv::Op::OpFNegate].push_back(MergeNegateMulDivArithmetic());
2960
2961 rules_[spv::Op::OpFSub].push_back(RedundantFSub());
2962 rules_[spv::Op::OpFSub].push_back(MergeSubNegateArithmetic());
2963 rules_[spv::Op::OpFSub].push_back(MergeSubAddArithmetic());
2964 rules_[spv::Op::OpFSub].push_back(MergeSubSubArithmetic());
2965 rules_[spv::Op::OpFSub].push_back(MergeMulSubArithmetic);
2966
2967 rules_[spv::Op::OpIAdd].push_back(RedundantIAdd());
2968 rules_[spv::Op::OpIAdd].push_back(MergeAddNegateArithmetic());
2969 rules_[spv::Op::OpIAdd].push_back(MergeAddAddArithmetic());
2970 rules_[spv::Op::OpIAdd].push_back(MergeAddSubArithmetic());
2971 rules_[spv::Op::OpIAdd].push_back(MergeGenericAddSubArithmetic());
2972 rules_[spv::Op::OpIAdd].push_back(FactorAddMuls());
2973
2974 rules_[spv::Op::OpIMul].push_back(IntMultipleBy1());
2975 rules_[spv::Op::OpIMul].push_back(MergeMulMulArithmetic());
2976 rules_[spv::Op::OpIMul].push_back(MergeMulNegateArithmetic());
2977
2978 rules_[spv::Op::OpISub].push_back(MergeSubNegateArithmetic());
2979 rules_[spv::Op::OpISub].push_back(MergeSubAddArithmetic());
2980 rules_[spv::Op::OpISub].push_back(MergeSubSubArithmetic());
2981
2982 rules_[spv::Op::OpPhi].push_back(RedundantPhi());
2983
2984 rules_[spv::Op::OpSNegate].push_back(MergeNegateArithmetic());
2985 rules_[spv::Op::OpSNegate].push_back(MergeNegateMulDivArithmetic());
2986 rules_[spv::Op::OpSNegate].push_back(MergeNegateAddSubArithmetic());
2987
2988 rules_[spv::Op::OpSelect].push_back(RedundantSelect());
2989
2990 rules_[spv::Op::OpStore].push_back(StoringUndef());
2991
2992 rules_[spv::Op::OpVectorShuffle].push_back(VectorShuffleFeedingShuffle());
2993
2994 rules_[spv::Op::OpImageSampleImplicitLod].push_back(UpdateImageOperands());
2995 rules_[spv::Op::OpImageSampleExplicitLod].push_back(UpdateImageOperands());
2996 rules_[spv::Op::OpImageSampleDrefImplicitLod].push_back(
2997 UpdateImageOperands());
2998 rules_[spv::Op::OpImageSampleDrefExplicitLod].push_back(
2999 UpdateImageOperands());
3000 rules_[spv::Op::OpImageSampleProjImplicitLod].push_back(
3001 UpdateImageOperands());
3002 rules_[spv::Op::OpImageSampleProjExplicitLod].push_back(
3003 UpdateImageOperands());
3004 rules_[spv::Op::OpImageSampleProjDrefImplicitLod].push_back(
3005 UpdateImageOperands());
3006 rules_[spv::Op::OpImageSampleProjDrefExplicitLod].push_back(
3007 UpdateImageOperands());
3008 rules_[spv::Op::OpImageFetch].push_back(UpdateImageOperands());
3009 rules_[spv::Op::OpImageGather].push_back(UpdateImageOperands());
3010 rules_[spv::Op::OpImageDrefGather].push_back(UpdateImageOperands());
3011 rules_[spv::Op::OpImageRead].push_back(UpdateImageOperands());
3012 rules_[spv::Op::OpImageWrite].push_back(UpdateImageOperands());
3013 rules_[spv::Op::OpImageSparseSampleImplicitLod].push_back(
3014 UpdateImageOperands());
3015 rules_[spv::Op::OpImageSparseSampleExplicitLod].push_back(
3016 UpdateImageOperands());
3017 rules_[spv::Op::OpImageSparseSampleDrefImplicitLod].push_back(
3018 UpdateImageOperands());
3019 rules_[spv::Op::OpImageSparseSampleDrefExplicitLod].push_back(
3020 UpdateImageOperands());
3021 rules_[spv::Op::OpImageSparseSampleProjImplicitLod].push_back(
3022 UpdateImageOperands());
3023 rules_[spv::Op::OpImageSparseSampleProjExplicitLod].push_back(
3024 UpdateImageOperands());
3025 rules_[spv::Op::OpImageSparseSampleProjDrefImplicitLod].push_back(
3026 UpdateImageOperands());
3027 rules_[spv::Op::OpImageSparseSampleProjDrefExplicitLod].push_back(
3028 UpdateImageOperands());
3029 rules_[spv::Op::OpImageSparseFetch].push_back(UpdateImageOperands());
3030 rules_[spv::Op::OpImageSparseGather].push_back(UpdateImageOperands());
3031 rules_[spv::Op::OpImageSparseDrefGather].push_back(UpdateImageOperands());
3032 rules_[spv::Op::OpImageSparseRead].push_back(UpdateImageOperands());
3033
3034 FeatureManager* feature_manager = context_->get_feature_mgr();
3035 // Add rules for GLSLstd450
3036 uint32_t ext_inst_glslstd450_id =
3037 feature_manager->GetExtInstImportId_GLSLstd450();
3038 if (ext_inst_glslstd450_id != 0) {
3039 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMix}].push_back(
3040 RedundantFMix());
3041 }
3042 }
3043 } // namespace opt
3044 } // namespace spvtools
3045