1 /*
2  * Copyright © 2018 Valve Corporation
3  *
4  * SPDX-License-Identifier: MIT
5  */
6 
7 #include "aco_builder.h"
8 #include "aco_ir.h"
9 
10 #include "util/half_float.h"
11 #include "util/memstream.h"
12 
13 #include <algorithm>
14 #include <array>
15 #include <vector>
16 
17 namespace aco {
18 
19 namespace {
20 /**
21  * The optimizer works in 4 phases:
22  * (1) The first pass collects information for each ssa-def,
23  *     propagates reg->reg operands of the same type, inlines constants,
24  *     and propagates neg/abs input modifiers.
25  * (2) The second pass combines instructions like mad, omod, clamp and
26  *     propagates SGPRs on VALU instructions.
27  *     This pass depends on information collected in the first pass.
28  * (3) The third pass goes backwards, and selects instructions,
29  *     i.e. decides if a mad instruction is profitable and eliminates dead code.
30  * (4) The fourth pass cleans up the sequence: literals get applied and dead
31  *     instructions are removed from the sequence.
32  */
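/* Illustrative sketch (hypothetical IR; the exact opcodes depend on the target):
 * the phases cooperate to fuse a multiply-add roughly like this:
 *    %t0 = v_mul_f32 %a, %b        ; pass 1 labels %t0 with label_mul
 *    %t1 = v_add_f32 %t0, %c       ; pass 2 records a mad_info candidate
 * => %t1 = v_fma_f32 %a, %b, %c    ; pass 3 keeps the fused form only if profitable
 */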
33 
34 struct mad_info {
35    aco_ptr<Instruction> add_instr;
36    uint32_t mul_temp_id;
37    uint16_t literal_mask;
38    uint16_t fp16_mask;
39 
40    mad_info(aco_ptr<Instruction> instr, uint32_t id)
41        : add_instr(std::move(instr)), mul_temp_id(id), literal_mask(0), fp16_mask(0)
42    {}
43 };
44 
45 enum Label {
46    label_vec = 1 << 0,
47    label_constant_32bit = 1 << 1,
48    /* label_{abs,neg,mul,omod2,omod4,omod5,clamp} are used for both 16 and
49     * 32-bit operations but this shouldn't cause any issues because we don't
50     * look through any conversions */
51    label_abs = 1 << 2,
52    label_neg = 1 << 3,
53    label_mul = 1 << 4,
54    label_temp = 1 << 5,
55    label_literal = 1 << 6,
56    label_mad = 1 << 7,
57    label_omod2 = 1 << 8,
58    label_omod4 = 1 << 9,
59    label_omod5 = 1 << 10,
60    label_clamp = 1 << 12,
61    label_b2f = 1 << 16,
62    label_add_sub = 1 << 17,
63    label_bitwise = 1 << 18,
64    label_minmax = 1 << 19,
65    label_vopc = 1 << 20,
66    label_uniform_bool = 1 << 21,
67    label_constant_64bit = 1 << 22,
68    label_uniform_bitwise = 1 << 23,
69    label_scc_invert = 1 << 24,
70    label_scc_needed = 1 << 26,
71    label_b2i = 1 << 27,
72    label_fcanonicalize = 1 << 28,
73    label_constant_16bit = 1 << 29,
74    label_usedef = 1 << 30,   /* generic label */
75    label_vop3p = 1ull << 31, /* 1ull to prevent sign extension */
76    label_canonicalized = 1ull << 32,
77    label_extract = 1ull << 33,
78    label_insert = 1ull << 34,
79    label_dpp16 = 1ull << 35,
80    label_dpp8 = 1ull << 36,
81    label_f2f32 = 1ull << 37,
82    label_f2f16 = 1ull << 38,
83    label_split = 1ull << 39,
84 };
85 
86 static constexpr uint64_t instr_usedef_labels =
87    label_vec | label_mul | label_add_sub | label_vop3p | label_bitwise | label_uniform_bitwise |
88    label_minmax | label_vopc | label_usedef | label_extract | label_dpp16 | label_dpp8 |
89    label_f2f32;
90 static constexpr uint64_t instr_mod_labels =
91    label_omod2 | label_omod4 | label_omod5 | label_clamp | label_insert | label_f2f16;
92 
93 static constexpr uint64_t instr_labels = instr_usedef_labels | instr_mod_labels | label_split;
94 static constexpr uint64_t temp_labels = label_abs | label_neg | label_temp | label_b2f |
95                                         label_uniform_bool | label_scc_invert | label_b2i |
96                                         label_fcanonicalize;
97 static constexpr uint32_t val_labels =
98    label_constant_32bit | label_constant_64bit | label_constant_16bit | label_literal | label_mad;
99 
100 static_assert((instr_labels & temp_labels) == 0, "labels cannot intersect");
101 static_assert((instr_labels & val_labels) == 0, "labels cannot intersect");
102 static_assert((temp_labels & val_labels) == 0, "labels cannot intersect");
103 
104 struct ssa_info {
105    uint64_t label;
106    union {
107       uint32_t val;
108       Temp temp;
109       Instruction* instr;
110    };
111 
112    ssa_info() : label(0) {}
113 
114    void add_label(Label new_label)
115    {
116       /* Since all the instr_usedef_labels use instr for the same thing
117        * (indicating the defining instruction), there is usually no need to
118        * clear any other instr labels. */
119       if (new_label & instr_usedef_labels)
120          label &= ~(instr_mod_labels | temp_labels | val_labels); /* instr, temp and val alias */
121 
122       if (new_label & instr_mod_labels) {
123          label &= ~instr_labels;
124          label &= ~(temp_labels | val_labels); /* instr, temp and val alias */
125       }
126 
127       if (new_label & temp_labels) {
128          label &= ~temp_labels;
129          label &= ~(instr_labels | val_labels); /* instr, temp and val alias */
130       }
131 
132       uint32_t const_labels =
133          label_literal | label_constant_32bit | label_constant_64bit | label_constant_16bit;
134       if (new_label & const_labels) {
135          label &= ~val_labels | const_labels;
136          label &= ~(instr_labels | temp_labels); /* instr, temp and val alias */
137       } else if (new_label & val_labels) {
138          label &= ~val_labels;
139          label &= ~(instr_labels | temp_labels); /* instr, temp and val alias */
140       }
141 
142       label |= new_label;
143    }
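   /* Worked example: because instr, temp and val share the union above, calling
    * add_label(label_mul) (an instr_usedef label) on an ssa_info that currently
    * carries label_temp first clears all temp_labels and val_labels, so stale
    * union contents are never read through the wrong member. */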
144 
145    void set_vec(Instruction* vec)
146    {
147       add_label(label_vec);
148       instr = vec;
149    }
150 
151    bool is_vec() { return label & label_vec; }
152 
153    void set_constant(amd_gfx_level gfx_level, uint64_t constant)
154    {
155       Operand op16 = Operand::c16(constant);
156       Operand op32 = Operand::get_const(gfx_level, constant, 4);
157       add_label(label_literal);
158       val = constant;
159 
160       /* check that no upper bits are lost in case of packed 16bit constants */
161       if (gfx_level >= GFX8 && !op16.isLiteral() &&
162           op16.constantValue16(true) == ((constant >> 16) & 0xffff))
163          add_label(label_constant_16bit);
164 
165       if (!op32.isLiteral())
166          add_label(label_constant_32bit);
167 
168       if (Operand::is_constant_representable(constant, 8))
169          add_label(label_constant_64bit);
170 
171       if (label & label_constant_64bit) {
172          val = Operand::c64(constant).constantValue();
173          if (val != constant)
174             label &= ~(label_literal | label_constant_16bit | label_constant_32bit);
175       }
176    }
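   /* Example with hypothetical values: set_constant() for the value 64 adds
    * label_constant_32bit in addition to label_literal because 64 is an inline
    * constant, while a value like 0x12345678 is only representable as a 32-bit
    * literal and does not get label_constant_32bit. */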
177 
178    bool is_constant(unsigned bits)
179    {
180       switch (bits) {
181       case 8: return label & label_literal;
182       case 16: return label & label_constant_16bit;
183       case 32: return label & label_constant_32bit;
184       case 64: return label & label_constant_64bit;
185       }
186       return false;
187    }
188 
189    bool is_literal(unsigned bits)
190    {
191       bool is_lit = label & label_literal;
192       switch (bits) {
193       case 8: return false;
194       case 16: return is_lit && !(label & label_constant_16bit);
195       case 32: return is_lit && !(label & label_constant_32bit);
196       case 64: return false;
197       }
198       return false;
199    }
200 
201    bool is_constant_or_literal(unsigned bits)
202    {
203       if (bits == 64)
204          return label & label_constant_64bit;
205       else
206          return label & label_literal;
207    }
208 
209    void set_abs(Temp abs_temp)
210    {
211       add_label(label_abs);
212       temp = abs_temp;
213    }
214 
215    bool is_abs() { return label & label_abs; }
216 
217    void set_neg(Temp neg_temp)
218    {
219       add_label(label_neg);
220       temp = neg_temp;
221    }
222 
223    bool is_neg() { return label & label_neg; }
224 
225    void set_neg_abs(Temp neg_abs_temp)
226    {
227       add_label((Label)((uint32_t)label_abs | (uint32_t)label_neg));
228       temp = neg_abs_temp;
229    }
230 
231    void set_mul(Instruction* mul)
232    {
233       add_label(label_mul);
234       instr = mul;
235    }
236 
237    bool is_mul() { return label & label_mul; }
238 
239    void set_temp(Temp tmp)
240    {
241       add_label(label_temp);
242       temp = tmp;
243    }
244 
245    bool is_temp() { return label & label_temp; }
246 
247    void set_mad(uint32_t mad_info_idx)
248    {
249       add_label(label_mad);
250       val = mad_info_idx;
251    }
252 
253    bool is_mad() { return label & label_mad; }
254 
255    void set_omod2(Instruction* mul)
256    {
257       if (label & temp_labels)
258          return;
259       add_label(label_omod2);
260       instr = mul;
261    }
262 
263    bool is_omod2() { return label & label_omod2; }
264 
265    void set_omod4(Instruction* mul)
266    {
267       if (label & temp_labels)
268          return;
269       add_label(label_omod4);
270       instr = mul;
271    }
272 
273    bool is_omod4() { return label & label_omod4; }
274 
275    void set_omod5(Instruction* mul)
276    {
277       if (label & temp_labels)
278          return;
279       add_label(label_omod5);
280       instr = mul;
281    }
282 
283    bool is_omod5() { return label & label_omod5; }
284 
285    void set_clamp(Instruction* med3)
286    {
287       if (label & temp_labels)
288          return;
289       add_label(label_clamp);
290       instr = med3;
291    }
292 
293    bool is_clamp() { return label & label_clamp; }
294 
295    void set_f2f16(Instruction* conv)
296    {
297       if (label & temp_labels)
298          return;
299       add_label(label_f2f16);
300       instr = conv;
301    }
302 
303    bool is_f2f16() { return label & label_f2f16; }
304 
305    void set_b2f(Temp b2f_val)
306    {
307       add_label(label_b2f);
308       temp = b2f_val;
309    }
310 
311    bool is_b2f() { return label & label_b2f; }
312 
313    void set_add_sub(Instruction* add_sub_instr)
314    {
315       add_label(label_add_sub);
316       instr = add_sub_instr;
317    }
318 
319    bool is_add_sub() { return label & label_add_sub; }
320 
321    void set_bitwise(Instruction* bitwise_instr)
322    {
323       add_label(label_bitwise);
324       instr = bitwise_instr;
325    }
326 
327    bool is_bitwise() { return label & label_bitwise; }
328 
329    void set_uniform_bitwise() { add_label(label_uniform_bitwise); }
330 
331    bool is_uniform_bitwise() { return label & label_uniform_bitwise; }
332 
333    void set_minmax(Instruction* minmax_instr)
334    {
335       add_label(label_minmax);
336       instr = minmax_instr;
337    }
338 
339    bool is_minmax() { return label & label_minmax; }
340 
341    void set_vopc(Instruction* vopc_instr)
342    {
343       add_label(label_vopc);
344       instr = vopc_instr;
345    }
346 
347    bool is_vopc() { return label & label_vopc; }
348 
349    void set_scc_needed() { add_label(label_scc_needed); }
350 
351    bool is_scc_needed() { return label & label_scc_needed; }
352 
353    void set_scc_invert(Temp scc_inv)
354    {
355       add_label(label_scc_invert);
356       temp = scc_inv;
357    }
358 
359    bool is_scc_invert() { return label & label_scc_invert; }
360 
361    void set_uniform_bool(Temp uniform_bool)
362    {
363       add_label(label_uniform_bool);
364       temp = uniform_bool;
365    }
366 
367    bool is_uniform_bool() { return label & label_uniform_bool; }
368 
369    void set_b2i(Temp b2i_val)
370    {
371       add_label(label_b2i);
372       temp = b2i_val;
373    }
374 
375    bool is_b2i() { return label & label_b2i; }
376 
377    void set_usedef(Instruction* label_instr)
378    {
379       add_label(label_usedef);
380       instr = label_instr;
381    }
382 
383    bool is_usedef() { return label & label_usedef; }
384 
385    void set_vop3p(Instruction* vop3p_instr)
386    {
387       add_label(label_vop3p);
388       instr = vop3p_instr;
389    }
390 
391    bool is_vop3p() { return label & label_vop3p; }
392 
393    void set_fcanonicalize(Temp tmp)
394    {
395       add_label(label_fcanonicalize);
396       temp = tmp;
397    }
398 
399    bool is_fcanonicalize() { return label & label_fcanonicalize; }
400 
401    void set_canonicalized() { add_label(label_canonicalized); }
402 
403    bool is_canonicalized() { return label & label_canonicalized; }
404 
405    void set_f2f32(Instruction* cvt)
406    {
407       add_label(label_f2f32);
408       instr = cvt;
409    }
410 
411    bool is_f2f32() { return label & label_f2f32; }
412 
413    void set_extract(Instruction* extract)
414    {
415       add_label(label_extract);
416       instr = extract;
417    }
418 
419    bool is_extract() { return label & label_extract; }
420 
421    void set_insert(Instruction* insert)
422    {
423       if (label & temp_labels)
424          return;
425       add_label(label_insert);
426       instr = insert;
427    }
428 
429    bool is_insert() { return label & label_insert; }
430 
431    void set_dpp16(Instruction* mov)
432    {
433       add_label(label_dpp16);
434       instr = mov;
435    }
436 
437    void set_dpp8(Instruction* mov)
438    {
439       add_label(label_dpp8);
440       instr = mov;
441    }
442 
443    bool is_dpp() { return label & (label_dpp16 | label_dpp8); }
444    bool is_dpp16() { return label & label_dpp16; }
445    bool is_dpp8() { return label & label_dpp8; }
446 
447    void set_split(Instruction* split)
448    {
449       add_label(label_split);
450       instr = split;
451    }
452 
453    bool is_split() { return label & label_split; }
454 };
455 
456 struct opt_ctx {
457    Program* program;
458    float_mode fp_mode;
459    std::vector<aco_ptr<Instruction>> instructions;
460    std::vector<ssa_info> info;
461    std::pair<uint32_t, Temp> last_literal;
462    std::vector<mad_info> mad_infos;
463    std::vector<uint16_t> uses;
464 };
465 
466 bool
467 can_use_VOP3(opt_ctx& ctx, const aco_ptr<Instruction>& instr)
468 {
469    if (instr->isVOP3())
470       return true;
471 
472    if (instr->isVOP3P() || instr->isVINTERP_INREG())
473       return false;
474 
475    if (instr->operands.size() && instr->operands[0].isLiteral() && ctx.program->gfx_level < GFX10)
476       return false;
477 
478    if (instr->isSDWA())
479       return false;
480 
481    if (instr->isDPP() && ctx.program->gfx_level < GFX11)
482       return false;
483 
484    return instr->opcode != aco_opcode::v_madmk_f32 && instr->opcode != aco_opcode::v_madak_f32 &&
485           instr->opcode != aco_opcode::v_madmk_f16 && instr->opcode != aco_opcode::v_madak_f16 &&
486           instr->opcode != aco_opcode::v_fmamk_f32 && instr->opcode != aco_opcode::v_fmaak_f32 &&
487           instr->opcode != aco_opcode::v_fmamk_f16 && instr->opcode != aco_opcode::v_fmaak_f16 &&
488           instr->opcode != aco_opcode::v_permlane64_b32 &&
489           instr->opcode != aco_opcode::v_readlane_b32 &&
490           instr->opcode != aco_opcode::v_writelane_b32 &&
491           instr->opcode != aco_opcode::v_readfirstlane_b32;
492 }
493 
494 bool
495 pseudo_propagate_temp(opt_ctx& ctx, aco_ptr<Instruction>& instr, Temp temp, unsigned index)
496 {
497    if (instr->definitions.empty())
498       return false;
499 
500    const bool vgpr =
501       instr->opcode == aco_opcode::p_as_uniform ||
502       std::all_of(instr->definitions.begin(), instr->definitions.end(),
503                   [](const Definition& def) { return def.regClass().type() == RegType::vgpr; });
504 
505    /* don't propagate VGPRs into SGPR instructions */
506    if (temp.type() == RegType::vgpr && !vgpr)
507       return false;
508 
509    bool can_accept_sgpr =
510       ctx.program->gfx_level >= GFX9 ||
511       std::none_of(instr->definitions.begin(), instr->definitions.end(),
512                    [](const Definition& def) { return def.regClass().is_subdword(); });
513 
514    switch (instr->opcode) {
515    case aco_opcode::p_phi:
516    case aco_opcode::p_linear_phi:
517    case aco_opcode::p_parallelcopy:
518    case aco_opcode::p_create_vector:
519    case aco_opcode::p_start_linear_vgpr:
520       if (temp.bytes() != instr->operands[index].bytes())
521          return false;
522       break;
523    case aco_opcode::p_extract_vector:
524    case aco_opcode::p_extract:
525       if (temp.type() == RegType::sgpr && !can_accept_sgpr)
526          return false;
527       break;
528    case aco_opcode::p_split_vector: {
529       if (temp.type() == RegType::sgpr && !can_accept_sgpr)
530          return false;
531       /* don't increase the vector size */
532       if (temp.bytes() > instr->operands[index].bytes())
533          return false;
534       /* We can decrease the vector size as smaller temporaries are only
535        * propagated by p_as_uniform instructions.
536        * If this propagation leads to invalid IR or hits the assertion below,
538        * it means that some undefined bytes within a dword are being accessed
538        * and a bug in instruction_selection is likely. */
539       int decrease = instr->operands[index].bytes() - temp.bytes();
540       while (decrease > 0) {
541          decrease -= instr->definitions.back().bytes();
542          instr->definitions.pop_back();
543       }
544       assert(decrease == 0);
545       break;
546    }
547    case aco_opcode::p_as_uniform:
548       if (temp.regClass() == instr->definitions[0].regClass())
549          instr->opcode = aco_opcode::p_parallelcopy;
550       break;
551    default: return false;
552    }
553 
554    instr->operands[index].setTemp(temp);
555    return true;
556 }
557 
558 /* This expects the DPP modifier to be removed. */
559 bool
560 can_apply_sgprs(opt_ctx& ctx, aco_ptr<Instruction>& instr)
561 {
562    assert(instr->isVALU());
563    if (instr->isSDWA() && ctx.program->gfx_level < GFX9)
564       return false;
565    return instr->opcode != aco_opcode::v_readfirstlane_b32 &&
566           instr->opcode != aco_opcode::v_readlane_b32 &&
567           instr->opcode != aco_opcode::v_readlane_b32_e64 &&
568           instr->opcode != aco_opcode::v_writelane_b32 &&
569           instr->opcode != aco_opcode::v_writelane_b32_e64 &&
570           instr->opcode != aco_opcode::v_permlane16_b32 &&
571           instr->opcode != aco_opcode::v_permlanex16_b32 &&
572           instr->opcode != aco_opcode::v_permlane64_b32 &&
573           instr->opcode != aco_opcode::v_interp_p1_f32 &&
574           instr->opcode != aco_opcode::v_interp_p2_f32 &&
575           instr->opcode != aco_opcode::v_interp_mov_f32 &&
576           instr->opcode != aco_opcode::v_interp_p1ll_f16 &&
577           instr->opcode != aco_opcode::v_interp_p1lv_f16 &&
578           instr->opcode != aco_opcode::v_interp_p2_legacy_f16 &&
579           instr->opcode != aco_opcode::v_interp_p2_f16 &&
580           instr->opcode != aco_opcode::v_interp_p2_hi_f16 &&
581           instr->opcode != aco_opcode::v_interp_p10_f32_inreg &&
582           instr->opcode != aco_opcode::v_interp_p2_f32_inreg &&
583           instr->opcode != aco_opcode::v_interp_p10_f16_f32_inreg &&
584           instr->opcode != aco_opcode::v_interp_p2_f16_f32_inreg &&
585           instr->opcode != aco_opcode::v_interp_p10_rtz_f16_f32_inreg &&
586           instr->opcode != aco_opcode::v_interp_p2_rtz_f16_f32_inreg &&
587           instr->opcode != aco_opcode::v_wmma_f32_16x16x16_f16 &&
588           instr->opcode != aco_opcode::v_wmma_f32_16x16x16_bf16 &&
589           instr->opcode != aco_opcode::v_wmma_f16_16x16x16_f16 &&
590           instr->opcode != aco_opcode::v_wmma_bf16_16x16x16_bf16 &&
591           instr->opcode != aco_opcode::v_wmma_i32_16x16x16_iu8 &&
592           instr->opcode != aco_opcode::v_wmma_i32_16x16x16_iu4;
593 }
594 
595 /* only covers special cases */
596 bool
597 alu_can_accept_constant(const aco_ptr<Instruction>& instr, unsigned operand)
598 {
599    /* Fixed operands can't accept constants because we need them
600     * to be in their fixed register.
601     */
602    assert(instr->operands.size() > operand);
603    if (instr->operands[operand].isFixed())
604       return false;
605 
606    /* SOPP instructions can't use constants. */
607    if (instr->isSOPP())
608       return false;
609 
610    switch (instr->opcode) {
611    case aco_opcode::v_s_exp_f16:
612    case aco_opcode::v_s_log_f16:
613    case aco_opcode::v_s_rcp_f16:
614    case aco_opcode::v_s_rsq_f16:
615    case aco_opcode::v_s_sqrt_f16:
616       /* These can't use inline constants on GFX12 but can use literals. We don't bother since they
617        * should be constant folded anyway. */
618       return false;
619    case aco_opcode::s_fmac_f16:
620    case aco_opcode::s_fmac_f32:
621    case aco_opcode::v_mac_f32:
622    case aco_opcode::v_writelane_b32:
623    case aco_opcode::v_writelane_b32_e64:
624    case aco_opcode::v_cndmask_b32: return operand != 2;
625    case aco_opcode::s_addk_i32:
626    case aco_opcode::s_mulk_i32:
627    case aco_opcode::p_extract_vector:
628    case aco_opcode::p_split_vector:
629    case aco_opcode::v_readlane_b32:
630    case aco_opcode::v_readlane_b32_e64:
631    case aco_opcode::v_readfirstlane_b32:
632    case aco_opcode::p_extract:
633    case aco_opcode::p_insert: return operand != 0;
634    case aco_opcode::p_bpermute_readlane:
635    case aco_opcode::p_bpermute_shared_vgpr:
636    case aco_opcode::p_bpermute_permlane:
637    case aco_opcode::p_interp_gfx11:
638    case aco_opcode::p_dual_src_export_gfx11:
639    case aco_opcode::v_interp_p1_f32:
640    case aco_opcode::v_interp_p2_f32:
641    case aco_opcode::v_interp_mov_f32:
642    case aco_opcode::v_interp_p1ll_f16:
643    case aco_opcode::v_interp_p1lv_f16:
644    case aco_opcode::v_interp_p2_legacy_f16:
645    case aco_opcode::v_interp_p10_f32_inreg:
646    case aco_opcode::v_interp_p2_f32_inreg:
647    case aco_opcode::v_interp_p10_f16_f32_inreg:
648    case aco_opcode::v_interp_p2_f16_f32_inreg:
649    case aco_opcode::v_interp_p10_rtz_f16_f32_inreg:
650    case aco_opcode::v_interp_p2_rtz_f16_f32_inreg:
651    case aco_opcode::v_wmma_f32_16x16x16_f16:
652    case aco_opcode::v_wmma_f32_16x16x16_bf16:
653    case aco_opcode::v_wmma_f16_16x16x16_f16:
654    case aco_opcode::v_wmma_bf16_16x16x16_bf16:
655    case aco_opcode::v_wmma_i32_16x16x16_iu8:
656    case aco_opcode::v_wmma_i32_16x16x16_iu4: return false;
657    default: return true;
658    }
659 }
660 
661 bool
662 valu_can_accept_vgpr(aco_ptr<Instruction>& instr, unsigned operand)
663 {
664    if (instr->opcode == aco_opcode::v_writelane_b32 ||
665        instr->opcode == aco_opcode::v_writelane_b32_e64)
666       return operand == 2;
667    if (instr->opcode == aco_opcode::v_permlane16_b32 ||
668        instr->opcode == aco_opcode::v_permlanex16_b32 ||
669        instr->opcode == aco_opcode::v_readlane_b32 ||
670        instr->opcode == aco_opcode::v_readlane_b32_e64)
671       return operand == 0;
672    return instr_info.classes[(int)instr->opcode] != instr_class::valu_pseudo_scalar_trans;
673 }
674 
675 /* check constant bus and literal limitations */
676 bool
677 check_vop3_operands(opt_ctx& ctx, unsigned num_operands, Operand* operands)
678 {
679    int limit = ctx.program->gfx_level >= GFX10 ? 2 : 1;
680    Operand literal32(s1);
681    Operand literal64(s2);
682    unsigned num_sgprs = 0;
683    unsigned sgpr[] = {0, 0};
684 
685    for (unsigned i = 0; i < num_operands; i++) {
686       Operand op = operands[i];
687 
688       if (op.hasRegClass() && op.regClass().type() == RegType::sgpr) {
689          /* two reads of the same SGPR count as 1 to the limit */
690          if (op.tempId() != sgpr[0] && op.tempId() != sgpr[1]) {
691             if (num_sgprs < 2)
692                sgpr[num_sgprs++] = op.tempId();
693             limit--;
694             if (limit < 0)
695                return false;
696          }
697       } else if (op.isLiteral()) {
698          if (ctx.program->gfx_level < GFX10)
699             return false;
700 
701          if (!literal32.isUndefined() && literal32.constantValue() != op.constantValue())
702             return false;
703          if (!literal64.isUndefined() && literal64.constantValue() != op.constantValue())
704             return false;
705 
706          /* Any number of 32-bit literals counts as only 1 to the limit. Same
707           * (but separately) for 64-bit literals. */
708          if (op.size() == 1 && literal32.isUndefined()) {
709             limit--;
710             literal32 = op;
711          } else if (op.size() == 2 && literal64.isUndefined()) {
712             limit--;
713             literal64 = op;
714          }
715 
716          if (limit < 0)
717             return false;
718       }
719    }
720 
721    return true;
722 }
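/* Worked example with hypothetical operands on GFX10+ (limit starts at 2):
 * { s[%a], s[%a], literal 0x12345678 } passes because the repeated SGPR read is
 * counted only once and the literal takes the second slot, whereas
 * { s[%a], s[%b], literal 0x12345678 } exceeds the constant bus limit. */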
723 
724 bool
725 parse_base_offset(opt_ctx& ctx, Instruction* instr, unsigned op_index, Temp* base, uint32_t* offset,
726                   bool prevent_overflow)
727 {
728    Operand op = instr->operands[op_index];
729 
730    if (!op.isTemp())
731       return false;
732    Temp tmp = op.getTemp();
733    if (!ctx.info[tmp.id()].is_add_sub())
734       return false;
735 
736    Instruction* add_instr = ctx.info[tmp.id()].instr;
737 
738    unsigned mask = 0x3;
739    bool is_sub = false;
740    switch (add_instr->opcode) {
741    case aco_opcode::v_add_u32:
742    case aco_opcode::v_add_co_u32:
743    case aco_opcode::v_add_co_u32_e64:
744    case aco_opcode::s_add_i32:
745    case aco_opcode::s_add_u32: break;
746    case aco_opcode::v_sub_u32:
747    case aco_opcode::v_sub_i32:
748    case aco_opcode::v_sub_co_u32:
749    case aco_opcode::v_sub_co_u32_e64:
750    case aco_opcode::s_sub_u32:
751    case aco_opcode::s_sub_i32:
752       mask = 0x2;
753       is_sub = true;
754       break;
755    case aco_opcode::v_subrev_u32:
756    case aco_opcode::v_subrev_co_u32:
757    case aco_opcode::v_subrev_co_u32_e64:
758       mask = 0x1;
759       is_sub = true;
760       break;
761    default: return false;
762    }
763    if (prevent_overflow && !add_instr->definitions[0].isNUW())
764       return false;
765 
766    if (add_instr->usesModifiers())
767       return false;
768 
769    u_foreach_bit (i, mask) {
770       if (add_instr->operands[i].isConstant()) {
771          *offset = add_instr->operands[i].constantValue() * (uint32_t)(is_sub ? -1 : 1);
772       } else if (add_instr->operands[i].isTemp() &&
773                  ctx.info[add_instr->operands[i].tempId()].is_constant_or_literal(32)) {
774          *offset = ctx.info[add_instr->operands[i].tempId()].val * (uint32_t)(is_sub ? -1 : 1);
775       } else {
776          continue;
777       }
778       if (!add_instr->operands[!i].isTemp())
779          continue;
780 
781       uint32_t offset2 = 0;
782       if (parse_base_offset(ctx, add_instr, !i, base, &offset2, prevent_overflow)) {
783          *offset += offset2;
784       } else {
785          *base = add_instr->operands[!i].getTemp();
786       }
787       return true;
788    }
789 
790    return false;
791 }
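/* Example with a hypothetical temp: if %t is defined by a no-unsigned-wrap
 * "v_add_u32 %t = %a, 16", then parse_base_offset(.., prevent_overflow=true)
 * yields base=%a and offset=16; the subtract opcodes negate the constant, and
 * nested additions accumulate into the same offset through the recursive call. */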
792 
793 void
794 skip_smem_offset_align(opt_ctx& ctx, SMEM_instruction* smem)
795 {
796    bool soe = smem->operands.size() >= (!smem->definitions.empty() ? 3 : 4);
797    if (soe && !smem->operands[1].isConstant())
798       return;
799    /* We don't need to check the constant offset because the address seems to be calculated with
800     * (offset&-4 + const_offset&-4), not (offset+const_offset)&-4.
801     */
802 
803    Operand& op = smem->operands[soe ? smem->operands.size() - 1 : 1];
804    if (!op.isTemp() || !ctx.info[op.tempId()].is_bitwise())
805       return;
806 
807    Instruction* bitwise_instr = ctx.info[op.tempId()].instr;
808    if (bitwise_instr->opcode != aco_opcode::s_and_b32)
809       return;
810 
811    if (bitwise_instr->operands[0].constantEquals(-4) &&
812        bitwise_instr->operands[1].isOfType(op.regClass().type()))
813       op.setTemp(bitwise_instr->operands[1].getTemp());
814    else if (bitwise_instr->operands[1].constantEquals(-4) &&
815             bitwise_instr->operands[0].isOfType(op.regClass().type()))
816       op.setTemp(bitwise_instr->operands[0].getTemp());
817 }
818 
819 void
820 smem_combine(opt_ctx& ctx, aco_ptr<Instruction>& instr)
821 {
822    /* skip &-4 before offset additions: load((a + 16) & -4, 0) */
823    if (!instr->operands.empty())
824       skip_smem_offset_align(ctx, &instr->smem());
825 
826    /* propagate constants and combine additions */
827    if (!instr->operands.empty() && instr->operands[1].isTemp()) {
828       SMEM_instruction& smem = instr->smem();
829       ssa_info info = ctx.info[instr->operands[1].tempId()];
830 
831       Temp base;
832       uint32_t offset;
833       if (info.is_constant_or_literal(32) &&
834           ((ctx.program->gfx_level == GFX6 && info.val <= 0x3FF) ||
835            (ctx.program->gfx_level == GFX7 && info.val <= 0xFFFFFFFF) ||
836            (ctx.program->gfx_level >= GFX8 && info.val <= 0xFFFFF))) {
837          instr->operands[1] = Operand::c32(info.val);
838       } else if (parse_base_offset(ctx, instr.get(), 1, &base, &offset, true) &&
839                  base.regClass() == s1 && offset <= 0xFFFFF && ctx.program->gfx_level >= GFX9 &&
840                  offset % 4u == 0) {
841          bool soe = smem.operands.size() >= (!smem.definitions.empty() ? 3 : 4);
842          if (soe) {
843             if (ctx.info[smem.operands.back().tempId()].is_constant_or_literal(32) &&
844                 ctx.info[smem.operands.back().tempId()].val == 0) {
845                smem.operands[1] = Operand::c32(offset);
846                smem.operands.back() = Operand(base);
847             }
848          } else {
849             Instruction* new_instr = create_instruction(
850                smem.opcode, Format::SMEM, smem.operands.size() + 1, smem.definitions.size());
851             new_instr->operands[0] = smem.operands[0];
852             new_instr->operands[1] = Operand::c32(offset);
853             if (smem.definitions.empty())
854                new_instr->operands[2] = smem.operands[2];
855             new_instr->operands.back() = Operand(base);
856             if (!smem.definitions.empty())
857                new_instr->definitions[0] = smem.definitions[0];
858             new_instr->smem().sync = smem.sync;
859             new_instr->smem().cache = smem.cache;
860             instr.reset(new_instr);
861          }
862       }
863    }
864 
865    /* skip &-4 after offset additions: load(a & -4, 16) */
866    if (!instr->operands.empty())
867       skip_smem_offset_align(ctx, &instr->smem());
868 }
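/* Example (hypothetical): on GFX9+, an s_load whose offset operand is the NUW
 * sum "%a + 16" can be rewritten so that 16 becomes the immediate offset and %a
 * is moved into the SOE operand, while an offset operand that is already a known
 * constant within the range checked above is folded directly into operand 1. */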
869 
870 Operand
871 get_constant_op(opt_ctx& ctx, ssa_info info, uint32_t bits)
872 {
873    if (bits == 64)
874       return Operand::c32_or_c64(info.val, true);
875    return Operand::get_const(ctx.program->gfx_level, info.val, bits / 8u);
876 }
877 
878 void
879 propagate_constants_vop3p(opt_ctx& ctx, aco_ptr<Instruction>& instr, ssa_info& info, unsigned i)
880 {
881    if (!info.is_constant_or_literal(32))
882       return;
883 
884    assert(instr->operands[i].isTemp());
885    unsigned bits = get_operand_size(instr, i);
886    if (info.is_constant(bits)) {
887       instr->operands[i] = get_constant_op(ctx, info, bits);
888       return;
889    }
890 
891    /* The accumulation operand of dot product instructions ignores opsel. */
892    bool cannot_use_opsel =
893       (instr->opcode == aco_opcode::v_dot4_i32_i8 || instr->opcode == aco_opcode::v_dot2_i32_i16 ||
894        instr->opcode == aco_opcode::v_dot4_i32_iu8 || instr->opcode == aco_opcode::v_dot4_u32_u8 ||
895        instr->opcode == aco_opcode::v_dot2_u32_u16) &&
896       i == 2;
897    if (cannot_use_opsel)
898       return;
899 
900    /* try to fold inline constants */
901    VALU_instruction* vop3p = &instr->valu();
902    bool opsel_lo = vop3p->opsel_lo[i];
903    bool opsel_hi = vop3p->opsel_hi[i];
904 
905    Operand const_op[2];
906    bool const_opsel[2] = {false, false};
907    for (unsigned j = 0; j < 2; j++) {
908       if ((unsigned)opsel_lo != j && (unsigned)opsel_hi != j)
909          continue; /* this half is unused */
910 
911       uint16_t val = info.val >> (j ? 16 : 0);
912       Operand op = Operand::get_const(ctx.program->gfx_level, val, bits / 8u);
913       if (bits == 32 && op.isLiteral()) /* try sign extension */
914          op = Operand::get_const(ctx.program->gfx_level, val | 0xffff0000, 4);
915       if (bits == 32 && op.isLiteral()) { /* try shifting left */
916          op = Operand::get_const(ctx.program->gfx_level, val << 16, 4);
917          const_opsel[j] = true;
918       }
919       if (op.isLiteral())
920          return;
921       const_op[j] = op;
922    }
923 
924    Operand const_lo = const_op[0];
925    Operand const_hi = const_op[1];
926    bool const_lo_opsel = const_opsel[0];
927    bool const_hi_opsel = const_opsel[1];
928 
929    if (opsel_lo == opsel_hi) {
930       /* use the single 16bit value */
931       instr->operands[i] = opsel_lo ? const_hi : const_lo;
932 
933       /* opsel must point the same for both halves */
934       opsel_lo = opsel_lo ? const_hi_opsel : const_lo_opsel;
935       opsel_hi = opsel_lo;
936    } else if (const_lo == const_hi) {
937       /* both constants are the same */
938       instr->operands[i] = const_lo;
939 
940       /* opsel must point the same for both halves */
941       opsel_lo = const_lo_opsel;
942       opsel_hi = const_lo_opsel;
943    } else if (const_lo.constantValue16(const_lo_opsel) ==
944               const_hi.constantValue16(!const_hi_opsel)) {
945       instr->operands[i] = const_hi;
946 
947       /* redirect opsel selection */
948       opsel_lo = opsel_lo ? const_hi_opsel : !const_hi_opsel;
949       opsel_hi = opsel_hi ? const_hi_opsel : !const_hi_opsel;
950    } else if (const_hi.constantValue16(const_hi_opsel) ==
951               const_lo.constantValue16(!const_lo_opsel)) {
952       instr->operands[i] = const_lo;
953 
954       /* redirect opsel selection */
955       opsel_lo = opsel_lo ? !const_lo_opsel : const_lo_opsel;
956       opsel_hi = opsel_hi ? !const_lo_opsel : const_lo_opsel;
957    } else if (bits == 16 && const_lo.constantValue() == (const_hi.constantValue() ^ (1 << 15))) {
958       assert(const_lo_opsel == false && const_hi_opsel == false);
959 
960       /* const_lo == -const_hi */
961       if (!can_use_input_modifiers(ctx.program->gfx_level, instr->opcode, i))
962          return;
963 
964       instr->operands[i] = Operand::c16(const_lo.constantValue() & 0x7FFF);
965       bool neg_lo = const_lo.constantValue() & (1 << 15);
966       vop3p->neg_lo[i] ^= opsel_lo ^ neg_lo;
967       vop3p->neg_hi[i] ^= opsel_hi ^ neg_lo;
968 
969       /* opsel must point to lo for both operands */
970       opsel_lo = false;
971       opsel_hi = false;
972    }
973 
974    vop3p->opsel_lo[i] = opsel_lo;
975    vop3p->opsel_hi[i] = opsel_hi;
976 }
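/* Worked example (hypothetical): a packed fp16 operand whose 32-bit value is
 * known to be 0x12343C00 and whose opsel selects the low half for both lanes
 * folds to the inline constant 1.0 (0x3C00), because only the used half needs
 * to be representable as an inline constant. */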
977 
978 bool
979 fixed_to_exec(Operand op)
980 {
981    return op.isFixed() && op.physReg() == exec;
982 }
983 
984 SubdwordSel
985 parse_extract(Instruction* instr)
986 {
987    if (instr->opcode == aco_opcode::p_extract) {
988       unsigned size = instr->operands[2].constantValue() / 8;
989       unsigned offset = instr->operands[1].constantValue() * size;
990       bool sext = instr->operands[3].constantEquals(1);
991       return SubdwordSel(size, offset, sext);
992    } else if (instr->opcode == aco_opcode::p_insert && instr->operands[1].constantEquals(0)) {
993       return instr->operands[2].constantEquals(8) ? SubdwordSel::ubyte : SubdwordSel::uword;
994    } else if (instr->opcode == aco_opcode::p_extract_vector) {
995       unsigned size = instr->definitions[0].bytes();
996       unsigned offset = instr->operands[1].constantValue() * size;
997       if (size <= 2)
998          return SubdwordSel(size, offset, false);
999    } else if (instr->opcode == aco_opcode::p_split_vector) {
1000       assert(instr->operands[0].bytes() == 4 && instr->definitions[1].bytes() == 2);
1001       return SubdwordSel(2, 2, false);
1002    }
1003 
1004    return SubdwordSel();
1005 }
1006 
1007 SubdwordSel
1008 parse_insert(Instruction* instr)
1009 {
1010    if (instr->opcode == aco_opcode::p_extract && instr->operands[3].constantEquals(0) &&
1011        instr->operands[1].constantEquals(0)) {
1012       return instr->operands[2].constantEquals(8) ? SubdwordSel::ubyte : SubdwordSel::uword;
1013    } else if (instr->opcode == aco_opcode::p_insert) {
1014       unsigned size = instr->operands[2].constantValue() / 8;
1015       unsigned offset = instr->operands[1].constantValue() * size;
1016       return SubdwordSel(size, offset, false);
1017    } else {
1018       return SubdwordSel();
1019    }
1020 }
1021 
1022 SubdwordSel
1023 apply_extract_twice(SubdwordSel first, Temp first_dst, SubdwordSel second, Temp second_dst)
1024 {
1025    /* the outer offset must be within extracted range */
1026    if (second.offset() >= first.size())
1027       return SubdwordSel();
1028 
1029    /* don't remove the sign-extension when increasing the size further */
1030    if (second.size() > first.size() && first.sign_extend() &&
1031        !(second.sign_extend() ||
1032          (second.size() == first_dst.bytes() && second.size() == second_dst.bytes())))
1033       return SubdwordSel();
1034 
1035    unsigned size = std::min(first.size(), second.size());
1036    unsigned offset = first.offset() + second.offset();
1037    bool sign_extend = second.size() <= first.size() ? second.sign_extend() : first.sign_extend();
1038    return SubdwordSel(size, offset, sign_extend);
1039 }
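/* Example: composing an unsigned-word select at offset 2 with a subsequent
 * unsigned-byte select at offset 0 yields a single ubyte select at offset 2
 * (size = min(2, 1), offset = 2 + 0, no sign extension). */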
1040 
1041 bool
1042 can_apply_extract(opt_ctx& ctx, aco_ptr<Instruction>& instr, unsigned idx, ssa_info& info)
1043 {
1044    Temp tmp = info.instr->operands[0].getTemp();
1045    SubdwordSel sel = parse_extract(info.instr);
1046 
1047    if (!sel) {
1048       return false;
1049    } else if (sel.size() == instr->operands[idx].bytes() && sel.size() == tmp.bytes() &&
1050               tmp.type() == instr->operands[idx].regClass().type()) {
1051       assert(tmp.type() != RegType::sgpr); /* No sub-dword SGPR regclasses */
1052       return true;
1053    } else if ((instr->opcode == aco_opcode::v_cvt_f32_u32 ||
1054                instr->opcode == aco_opcode::v_cvt_f32_i32 ||
1055                instr->opcode == aco_opcode::v_cvt_f32_ubyte0) &&
1056               sel.size() == 1 && !sel.sign_extend() && !instr->usesModifiers()) {
1057       return true;
1058    } else if (instr->opcode == aco_opcode::v_lshlrev_b32 && instr->operands[0].isConstant() &&
1059               sel.offset() == 0 && !instr->usesModifiers() &&
1060               ((sel.size() == 2 && instr->operands[0].constantValue() >= 16u) ||
1061                (sel.size() == 1 && instr->operands[0].constantValue() >= 24u))) {
1062       return true;
1063    } else if (instr->opcode == aco_opcode::v_mul_u32_u24 && ctx.program->gfx_level >= GFX10 &&
1064               !instr->usesModifiers() && sel.size() == 2 && !sel.sign_extend() &&
1065               (instr->operands[!idx].is16bit() ||
1066                (instr->operands[!idx].isConstant() &&
1067                 instr->operands[!idx].constantValue() <= UINT16_MAX))) {
1068       return true;
1069    } else if (idx < 2 && can_use_SDWA(ctx.program->gfx_level, instr, true) &&
1070               (tmp.type() == RegType::vgpr || ctx.program->gfx_level >= GFX9)) {
1071       if (instr->isSDWA()) {
1072          /* TODO: if we knew how many bytes this operand actually uses, we could have smaller
1073           * second_dst parameter and apply more sign-extended sels.
1074           */
1075          return apply_extract_twice(sel, instr->operands[idx].getTemp(), instr->sdwa().sel[idx],
1076                                     Temp(0, v1)) != SubdwordSel();
1077       }
1078       return true;
1079    } else if (instr->isVALU() && sel.size() == 2 && !instr->valu().opsel[idx] &&
1080               can_use_opsel(ctx.program->gfx_level, instr->opcode, idx)) {
1081       return true;
1082    } else if (instr->opcode == aco_opcode::s_pack_ll_b32_b16 && sel.size() == 2 &&
1083               (idx == 1 || ctx.program->gfx_level >= GFX11 || !sel.offset())) {
1084       return true;
1085    } else if (sel.size() == 2 && ((instr->opcode == aco_opcode::s_pack_lh_b32_b16 && idx == 0) ||
1086                                   (instr->opcode == aco_opcode::s_pack_hl_b32_b16 && idx == 1))) {
1087       return true;
1088    } else if (instr->opcode == aco_opcode::p_extract ||
1089               instr->opcode == aco_opcode::p_extract_vector) {
1090       if (ctx.program->gfx_level < GFX9 && !info.instr->operands[0].isOfType(RegType::vgpr) &&
1091           instr->definitions[0].regClass().is_subdword())
1092          return false;
1093 
1094       SubdwordSel instrSel = parse_extract(instr.get());
1095       return instrSel && apply_extract_twice(sel, instr->operands[idx].getTemp(), instrSel,
1096                                              instr->definitions[0].getTemp());
1097    }
1098 
1099    return false;
1100 }
1101 
1102 /* Combine a p_extract (or p_insert, in some cases) instruction with instr.
1103  * instr(p_extract(...)) -> instr()
1104  */
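/* Example (hypothetical IR): "v_add_u16 %r = p_extract(%a, 1, 16, unsigned), %b"
 * can become "v_add_u16 %r = %a, %b" with opsel set for that operand, leaving
 * the p_extract to be removed later as dead code. */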
1105 void
1106 apply_extract(opt_ctx& ctx, aco_ptr<Instruction>& instr, unsigned idx, ssa_info& info)
1107 {
1108    Temp tmp = info.instr->operands[0].getTemp();
1109    SubdwordSel sel = parse_extract(info.instr);
1110    assert(sel);
1111 
1112    instr->operands[idx].set16bit(false);
1113    instr->operands[idx].set24bit(false);
1114 
1115    ctx.info[tmp.id()].label &= ~label_insert;
1116 
1117    if (sel.size() == instr->operands[idx].bytes() && sel.size() == tmp.bytes() &&
1118        tmp.type() == instr->operands[idx].regClass().type()) {
1119       /* extract is a no-op */
1120    } else if ((instr->opcode == aco_opcode::v_cvt_f32_u32 ||
1121                instr->opcode == aco_opcode::v_cvt_f32_i32 ||
1122                instr->opcode == aco_opcode::v_cvt_f32_ubyte0) &&
1123               sel.size() == 1 && !sel.sign_extend() && !instr->usesModifiers()) {
1124       switch (sel.offset()) {
1125       case 0: instr->opcode = aco_opcode::v_cvt_f32_ubyte0; break;
1126       case 1: instr->opcode = aco_opcode::v_cvt_f32_ubyte1; break;
1127       case 2: instr->opcode = aco_opcode::v_cvt_f32_ubyte2; break;
1128       case 3: instr->opcode = aco_opcode::v_cvt_f32_ubyte3; break;
1129       }
1130    } else if (instr->opcode == aco_opcode::v_lshlrev_b32 && instr->operands[0].isConstant() &&
1131               sel.offset() == 0 && !instr->usesModifiers() &&
1132               ((sel.size() == 2 && instr->operands[0].constantValue() >= 16u) ||
1133                (sel.size() == 1 && instr->operands[0].constantValue() >= 24u))) {
1134       /* The undesirable upper bits are already shifted out. */
1135       if (!instr->isVOP3() && !info.instr->operands[0].isOfType(RegType::vgpr))
1136          instr->format = asVOP3(instr->format);
1137       return;
1138    } else if (instr->opcode == aco_opcode::v_mul_u32_u24 && ctx.program->gfx_level >= GFX10 &&
1139               !instr->usesModifiers() && sel.size() == 2 && !sel.sign_extend() &&
1140               (instr->operands[!idx].is16bit() ||
1141                instr->operands[!idx].constantValue() <= UINT16_MAX)) {
1142       Instruction* mad = create_instruction(aco_opcode::v_mad_u32_u16, Format::VOP3, 3, 1);
1143       mad->definitions[0] = instr->definitions[0];
1144       mad->operands[0] = instr->operands[0];
1145       mad->operands[1] = instr->operands[1];
1146       mad->operands[2] = Operand::zero();
1147       mad->valu().opsel[idx] = sel.offset();
1148       mad->pass_flags = instr->pass_flags;
1149       instr.reset(mad);
1150    } else if (can_use_SDWA(ctx.program->gfx_level, instr, true) &&
1151               (tmp.type() == RegType::vgpr || ctx.program->gfx_level >= GFX9)) {
1152       if (instr->isSDWA()) {
1153          instr->sdwa().sel[idx] = apply_extract_twice(sel, instr->operands[idx].getTemp(),
1154                                                       instr->sdwa().sel[idx], Temp(0, v1));
1155       } else {
1156          convert_to_SDWA(ctx.program->gfx_level, instr);
1157          instr->sdwa().sel[idx] = sel;
1158       }
1159    } else if (instr->isVALU()) {
1160       if (sel.offset()) {
1161          instr->valu().opsel[idx] = true;
1162 
1163          /* VOP12C cannot use opsel with SGPRs. */
1164          if (!instr->isVOP3() && !instr->isVINTERP_INREG() &&
1165              !info.instr->operands[0].isOfType(RegType::vgpr))
1166             instr->format = asVOP3(instr->format);
1167       }
1168    } else if (instr->opcode == aco_opcode::s_pack_ll_b32_b16) {
1169       if (sel.offset())
1170          instr->opcode = idx ? aco_opcode::s_pack_lh_b32_b16 : aco_opcode::s_pack_hl_b32_b16;
1171    } else if (instr->opcode == aco_opcode::s_pack_lh_b32_b16 ||
1172               instr->opcode == aco_opcode::s_pack_hl_b32_b16) {
1173       if (sel.offset())
1174          instr->opcode = aco_opcode::s_pack_hh_b32_b16;
1175    } else if (instr->opcode == aco_opcode::p_extract) {
1176       SubdwordSel instr_sel = parse_extract(instr.get());
1177       SubdwordSel new_sel = apply_extract_twice(sel, instr->operands[idx].getTemp(), instr_sel,
1178                                                 instr->definitions[0].getTemp());
1179       assert(new_sel.size() <= 2);
1180 
1181       instr->operands[1] = Operand::c32(new_sel.offset() / new_sel.size());
1182       instr->operands[2] = Operand::c32(new_sel.size() * 8u);
1183       instr->operands[3] = Operand::c32(new_sel.sign_extend());
1184       return;
1185    } else if (instr->opcode == aco_opcode::p_extract_vector) {
1186       SubdwordSel instrSel = parse_extract(instr.get());
1187       SubdwordSel new_sel = apply_extract_twice(sel, instr->operands[idx].getTemp(), instrSel,
1188                                                 instr->definitions[0].getTemp());
1189       assert(new_sel.size() <= 2);
1190 
1191       if (new_sel.size() == instr->definitions[0].bytes()) {
1192          instr->operands[1] = Operand::c32(new_sel.offset() / instr->definitions[0].bytes());
1193          return;
1194       } else {
1195          /* parse_extract() only succeeds with p_extract_vector for VGPR definitions because there
1196           * are no sub-dword SGPR regclasses. */
1197          assert(instr->definitions[0].regClass().type() != RegType::sgpr);
1198 
1199          Instruction* ext = create_instruction(aco_opcode::p_extract, Format::PSEUDO, 4, 1);
1200          ext->definitions[0] = instr->definitions[0];
1201          ext->operands[0] = instr->operands[0];
1202          ext->operands[1] = Operand::c32(new_sel.offset() / new_sel.size());
1203          ext->operands[2] = Operand::c32(new_sel.size() * 8u);
1204          ext->operands[3] = Operand::c32(new_sel.sign_extend());
1205          ext->pass_flags = instr->pass_flags;
1206          instr.reset(ext);
1207       }
1208    }
1209 
1210    /* These are the only labels worth keeping at the moment. */
1211    for (Definition& def : instr->definitions) {
1212       ctx.info[def.tempId()].label &=
1213          (label_mul | label_minmax | label_usedef | label_vopc | label_f2f32 | instr_mod_labels);
1214       if (ctx.info[def.tempId()].label & instr_usedef_labels)
1215          ctx.info[def.tempId()].instr = instr.get();
1216    }
1217 }
1218 
1219 void
1220 check_sdwa_extract(opt_ctx& ctx, aco_ptr<Instruction>& instr)
1221 {
1222    for (unsigned i = 0; i < instr->operands.size(); i++) {
1223       Operand op = instr->operands[i];
1224       if (!op.isTemp())
1225          continue;
1226       ssa_info& info = ctx.info[op.tempId()];
1227       if (info.is_extract() && (info.instr->operands[0].getTemp().type() == RegType::vgpr ||
1228                                 op.getTemp().type() == RegType::sgpr)) {
1229          if (!can_apply_extract(ctx, instr, i, info))
1230             info.label &= ~label_extract;
1231       }
1232    }
1233 }
1234 
1235 bool
1236 does_fp_op_flush_denorms(opt_ctx& ctx, aco_opcode op)
1237 {
1238    switch (op) {
1239    case aco_opcode::v_min_f32:
1240    case aco_opcode::v_max_f32:
1241    case aco_opcode::v_med3_f32:
1242    case aco_opcode::v_min3_f32:
1243    case aco_opcode::v_max3_f32:
1244    case aco_opcode::v_min_f16:
1245    case aco_opcode::v_max_f16: return ctx.program->gfx_level > GFX8;
1246    case aco_opcode::v_cndmask_b32:
1247    case aco_opcode::v_cndmask_b16:
1248    case aco_opcode::v_mov_b32:
1249    case aco_opcode::v_mov_b16: return false;
1250    default: return true;
1251    }
1252 }
1253 
1254 bool
1255 can_eliminate_fcanonicalize(opt_ctx& ctx, aco_ptr<Instruction>& instr, Temp tmp, unsigned idx)
1256 {
1257    float_mode* fp = &ctx.fp_mode;
1258    if (ctx.info[tmp.id()].is_canonicalized() ||
1259        (tmp.bytes() == 4 ? fp->denorm32 : fp->denorm16_64) == fp_denorm_keep)
1260       return true;
1261 
1262    aco_opcode op = instr->opcode;
1263    return can_use_input_modifiers(ctx.program->gfx_level, instr->opcode, idx) &&
1264           does_fp_op_flush_denorms(ctx, op);
1265 }
1266 
1267 bool
1268 can_eliminate_and_exec(opt_ctx& ctx, Temp tmp, unsigned pass_flags)
1269 {
1270    if (ctx.info[tmp.id()].is_vopc()) {
1271       Instruction* vopc_instr = ctx.info[tmp.id()].instr;
1272       /* Remove superfluous s_and when the VOPC instruction uses the same exec and thus
1273        * already produces the same result */
1274       return vopc_instr->pass_flags == pass_flags;
1275    }
1276    if (ctx.info[tmp.id()].is_bitwise()) {
1277       Instruction* instr = ctx.info[tmp.id()].instr;
1278       if (instr->operands.size() != 2 || instr->pass_flags != pass_flags)
1279          return false;
1280       if (!(instr->operands[0].isTemp() && instr->operands[1].isTemp()))
1281          return false;
1282       if (instr->opcode == aco_opcode::s_and_b32 || instr->opcode == aco_opcode::s_and_b64) {
1283          return can_eliminate_and_exec(ctx, instr->operands[0].getTemp(), pass_flags) ||
1284                 can_eliminate_and_exec(ctx, instr->operands[1].getTemp(), pass_flags);
1285       } else {
1286          return can_eliminate_and_exec(ctx, instr->operands[0].getTemp(), pass_flags) &&
1287                 can_eliminate_and_exec(ctx, instr->operands[1].getTemp(), pass_flags);
1288       }
1289    }
1290    return false;
1291 }
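/* Example: "s_and_b64 %r = %cmp, exec" is superfluous when %cmp was written by a
 * VOPC compare under the same exec mask (matching pass_flags), because inactive
 * lanes are already zero in %cmp. */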
1292 
1293 bool
1294 is_copy_label(opt_ctx& ctx, aco_ptr<Instruction>& instr, ssa_info& info, unsigned idx)
1295 {
1296    return info.is_temp() ||
1297           (info.is_fcanonicalize() && can_eliminate_fcanonicalize(ctx, instr, info.temp, idx));
1298 }
1299 
1300 bool
1301 is_op_canonicalized(opt_ctx& ctx, Operand op)
1302 {
1303    float_mode* fp = &ctx.fp_mode;
1304    if ((op.isTemp() && ctx.info[op.tempId()].is_canonicalized()) ||
1305        (op.bytes() == 4 ? fp->denorm32 : fp->denorm16_64) == fp_denorm_keep)
1306       return true;
1307 
1308    if (op.isConstant() || (op.isTemp() && ctx.info[op.tempId()].is_constant_or_literal(32))) {
1309       uint32_t val = op.isTemp() ? ctx.info[op.tempId()].val : op.constantValue();
1310       if (op.bytes() == 2)
1311          return (val & 0x7fff) == 0 || (val & 0x7fff) > 0x3ff;
1312       else if (op.bytes() == 4)
1313          return (val & 0x7fffffff) == 0 || (val & 0x7fffffff) > 0x7fffff;
1314    }
1315    return false;
1316 }
1317 
1318 bool
1319 is_scratch_offset_valid(opt_ctx& ctx, Instruction* instr, int64_t offset0, int64_t offset1)
1320 {
1321    bool negative_unaligned_scratch_offset_bug = ctx.program->gfx_level == GFX10;
1322    int32_t min = ctx.program->dev.scratch_global_offset_min;
1323    int32_t max = ctx.program->dev.scratch_global_offset_max;
1324 
1325    int64_t offset = offset0 + offset1;
1326 
1327    bool has_vgpr_offset = instr && !instr->operands[0].isUndefined();
1328    if (negative_unaligned_scratch_offset_bug && has_vgpr_offset && offset < 0 && offset % 4)
1329       return false;
1330 
1331    return offset >= min && offset <= max;
1332 }
1333 
1334 bool
1335 detect_clamp(Instruction* instr, unsigned* clamped_idx)
1336 {
1337    VALU_instruction& valu = instr->valu();
1338    if (valu.omod != 0 || valu.opsel != 0)
1339       return false;
1340 
1341    unsigned idx = 0;
1342    bool found_zero = false, found_one = false;
1343    bool is_fp16 = instr->opcode == aco_opcode::v_med3_f16;
1344    for (unsigned i = 0; i < 3; i++) {
1345       if (!valu.neg[i] && instr->operands[i].constantEquals(0))
1346          found_zero = true;
1347       else if (!valu.neg[i] &&
1348                instr->operands[i].constantEquals(is_fp16 ? 0x3c00 : 0x3f800000)) /* 1.0 */
1349          found_one = true;
1350       else
1351          idx = i;
1352    }
1353    if (found_zero && found_one && instr->operands[idx].isTemp()) {
1354       *clamped_idx = idx;
1355       return true;
1356    } else {
1357       return false;
1358    }
1359 }
1360 
1361 void
1362 label_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
1363 {
1364    if (instr->isSMEM())
1365       smem_combine(ctx, instr);
1366 
1367    for (unsigned i = 0; i < instr->operands.size(); i++) {
1368       if (!instr->operands[i].isTemp())
1369          continue;
1370 
1371       ssa_info info = ctx.info[instr->operands[i].tempId()];
1372       /* propagate reg->reg of same type */
1373       while (info.is_temp() && info.temp.regClass() == instr->operands[i].getTemp().regClass()) {
1374          instr->operands[i].setTemp(ctx.info[instr->operands[i].tempId()].temp);
1375          info = ctx.info[info.temp.id()];
1376       }
1377 
1378       /* PSEUDO: propagate temporaries */
1379       if (instr->isPseudo()) {
1380          while (info.is_temp()) {
1381             pseudo_propagate_temp(ctx, instr, info.temp, i);
1382             info = ctx.info[info.temp.id()];
1383          }
1384       }
1385 
1386       /* SALU / PSEUDO: propagate inline constants */
1387       if (instr->isSALU() || instr->isPseudo()) {
1388          unsigned bits = get_operand_size(instr, i);
1389          if ((info.is_constant(bits) || (info.is_literal(bits) && instr->isPseudo())) &&
1390              alu_can_accept_constant(instr, i)) {
1391             instr->operands[i] = get_constant_op(ctx, info, bits);
1392             continue;
1393          }
1394       }
1395 
1396       /* VALU: propagate neg, abs & inline constants */
1397       else if (instr->isVALU()) {
1398          if (is_copy_label(ctx, instr, info, i) && info.temp.type() == RegType::vgpr &&
1399              valu_can_accept_vgpr(instr, i)) {
1400             instr->operands[i].setTemp(info.temp);
1401             info = ctx.info[info.temp.id()];
1402          }
1403          /* applying SGPRs to VOP1 doesn't increase code size and DCE is helped by doing it earlier */
1404          if (info.is_temp() && info.temp.type() == RegType::sgpr && can_apply_sgprs(ctx, instr) &&
1405              instr->operands.size() == 1) {
1406             instr->format = withoutDPP(instr->format);
1407             instr->operands[i].setTemp(info.temp);
1408             info = ctx.info[info.temp.id()];
1409          }
1410 
1411          /* for instructions other than v_cndmask_b32, the size of the instruction should match the
1412           * operand size */
1413          bool can_use_mod =
1414             instr->opcode != aco_opcode::v_cndmask_b32 || instr->operands[i].getTemp().bytes() == 4;
1415          can_use_mod &= can_use_input_modifiers(ctx.program->gfx_level, instr->opcode, i);
1416 
1417          bool packed_math = instr->isVOP3P() && instr->opcode != aco_opcode::v_fma_mix_f32 &&
1418                             instr->opcode != aco_opcode::v_fma_mixlo_f16 &&
1419                             instr->opcode != aco_opcode::v_fma_mixhi_f16;
1420 
1421          if (instr->isSDWA())
1422             can_use_mod &= instr->sdwa().sel[i].size() == 4;
1423          else if (instr->isVOP3P())
1424             can_use_mod &= !packed_math || !info.is_abs();
1425          else if (instr->isVINTERP_INREG())
1426             can_use_mod &= !info.is_abs();
1427          else
1428             can_use_mod &= instr->isDPP16() || can_use_VOP3(ctx, instr);
1429 
1430          unsigned bits = get_operand_size(instr, i);
1431          can_use_mod &= instr->operands[i].bytes() * 8 == bits;
1432 
1433          if (info.is_neg() && can_use_mod &&
1434              can_eliminate_fcanonicalize(ctx, instr, info.temp, i)) {
1435             instr->operands[i].setTemp(info.temp);
1436             if (!packed_math && instr->valu().abs[i]) {
1437                /* fabs(fneg(a)) -> fabs(a) */
1438             } else if (instr->opcode == aco_opcode::v_add_f32) {
1439                instr->opcode = i ? aco_opcode::v_sub_f32 : aco_opcode::v_subrev_f32;
1440             } else if (instr->opcode == aco_opcode::v_add_f16) {
1441                instr->opcode = i ? aco_opcode::v_sub_f16 : aco_opcode::v_subrev_f16;
1442             } else if (packed_math) {
1443                /* Bit size compat should ensure this. */
1444                assert(!instr->valu().opsel_lo[i] && !instr->valu().opsel_hi[i]);
1445                instr->valu().neg_lo[i] ^= true;
1446                instr->valu().neg_hi[i] ^= true;
1447             } else {
1448                if (!instr->isDPP16() && can_use_VOP3(ctx, instr))
1449                   instr->format = asVOP3(instr->format);
1450                instr->valu().neg[i] ^= true;
1451             }
1452          }
1453          if (info.is_abs() && can_use_mod &&
1454              can_eliminate_fcanonicalize(ctx, instr, info.temp, i)) {
1455             if (!instr->isDPP16() && can_use_VOP3(ctx, instr))
1456                instr->format = asVOP3(instr->format);
1457             instr->operands[i] = Operand(info.temp);
1458             instr->valu().abs[i] = true;
1459             continue;
1460          }
1461 
1462          if (instr->isVOP3P()) {
1463             propagate_constants_vop3p(ctx, instr, info, i);
1464             continue;
1465          }
1466 
1467          if (info.is_constant(bits) && alu_can_accept_constant(instr, i) &&
1468              (!instr->isSDWA() || ctx.program->gfx_level >= GFX9) && (!instr->isDPP() || i != 1)) {
1469             Operand op = get_constant_op(ctx, info, bits);
1470             if (i == 0 || instr->isSDWA() || instr->opcode == aco_opcode::v_readlane_b32 ||
1471                 instr->opcode == aco_opcode::v_writelane_b32) {
1472                instr->format = withoutDPP(instr->format);
1473                instr->operands[i] = op;
1474                continue;
1475             } else if (!instr->isVOP3() && can_swap_operands(instr, &instr->opcode)) {
1476                instr->operands[i] = op;
1477                instr->valu().swapOperands(0, i);
1478                continue;
1479             } else if (can_use_VOP3(ctx, instr)) {
1480                instr->format = asVOP3(instr->format);
1481                instr->operands[i] = op;
1482                continue;
1483             }
1484          }
1485       }
1486 
1487       /* MUBUF: propagate constants and combine additions */
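      /* e.g. a constant voffset, or a voffset of the form (vgpr + constant),
       * can be folded into the 12-bit immediate offset as long as the sum
       * stays below 4096. */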
1488       else if (instr->isMUBUF()) {
1489          MUBUF_instruction& mubuf = instr->mubuf();
1490          Temp base;
1491          uint32_t offset;
1492          while (info.is_temp())
1493             info = ctx.info[info.temp.id()];
1494 
1495          bool swizzled = ctx.program->gfx_level >= GFX12 ? mubuf.cache.gfx12.swizzled
1496                                                          : (mubuf.cache.value & ac_swizzled);
1497          /* According to AMDGPUDAGToDAGISel::SelectMUBUFScratchOffen(), vaddr
1498           * overflow for scratch accesses works only on GFX9+ and saddr overflow
1499           * never works. Since swizzling is the only thing that separates
1500           * scratch accesses from other accesses, and swizzling significantly
1501           * changes how addressing works, this probably applies to swizzled
1502           * MUBUF accesses. */
1503          bool vaddr_prevent_overflow = swizzled && ctx.program->gfx_level < GFX9;
1504 
1505          if (mubuf.offen && mubuf.idxen && i == 1 && info.is_vec() &&
1506              info.instr->operands.size() == 2 && info.instr->operands[0].isTemp() &&
1507              info.instr->operands[0].regClass() == v1 && info.instr->operands[1].isConstant() &&
1508              mubuf.offset + info.instr->operands[1].constantValue() < 4096) {
1509             instr->operands[1] = info.instr->operands[0];
1510             mubuf.offset += info.instr->operands[1].constantValue();
1511             mubuf.offen = false;
1512             continue;
1513          } else if (mubuf.offen && i == 1 && info.is_constant_or_literal(32) &&
1514                     mubuf.offset + info.val < 4096) {
1515             assert(!mubuf.idxen);
1516             instr->operands[1] = Operand(v1);
1517             mubuf.offset += info.val;
1518             mubuf.offen = false;
1519             continue;
1520          } else if (i == 2 && info.is_constant_or_literal(32) && mubuf.offset + info.val < 4096) {
1521             instr->operands[2] = Operand::c32(0);
1522             mubuf.offset += info.val;
1523             continue;
1524          } else if (mubuf.offen && i == 1 &&
1525                     parse_base_offset(ctx, instr.get(), i, &base, &offset,
1526                                       vaddr_prevent_overflow) &&
1527                     base.regClass() == v1 && mubuf.offset + offset < 4096) {
1528             assert(!mubuf.idxen);
1529             instr->operands[1].setTemp(base);
1530             mubuf.offset += offset;
1531             continue;
1532          } else if (i == 2 && parse_base_offset(ctx, instr.get(), i, &base, &offset, true) &&
1533                     base.regClass() == s1 && mubuf.offset + offset < 4096 && !swizzled) {
1534             instr->operands[i].setTemp(base);
1535             mubuf.offset += offset;
1536             continue;
1537          }
1538       }
1539 
1540       else if (instr->isMTBUF()) {
1541          MTBUF_instruction& mtbuf = instr->mtbuf();
1542          while (info.is_temp())
1543             info = ctx.info[info.temp.id()];
1544 
1545          if (mtbuf.offen && mtbuf.idxen && i == 1 && info.is_vec() &&
1546              info.instr->operands.size() == 2 && info.instr->operands[0].isTemp() &&
1547              info.instr->operands[0].regClass() == v1 && info.instr->operands[1].isConstant() &&
1548              mtbuf.offset + info.instr->operands[1].constantValue() < 4096) {
1549             instr->operands[1] = info.instr->operands[0];
1550             mtbuf.offset += info.instr->operands[1].constantValue();
1551             mtbuf.offen = false;
1552             continue;
1553          }
1554       }
1555 
1556       /* SCRATCH: propagate constants and combine additions */
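      /* e.g. an (address + constant) computation can be folded into the signed
       * instruction offset when the result stays within
       * scratch_global_offset_min/max (see is_scratch_offset_valid). */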
1557       else if (instr->isScratch()) {
1558          FLAT_instruction& scratch = instr->scratch();
1559          Temp base;
1560          uint32_t offset;
1561          while (info.is_temp())
1562             info = ctx.info[info.temp.id()];
1563 
1564          /* The hardware probably does: 'scratch_base + u2u64(saddr) + i2i64(offset)'. This means
1565           * we can't combine the addition if the unsigned addition overflows and offset is
1566           * positive. In theory, there are also issues if
1567           * 'ilt(offset, 0) && ige(saddr, 0) && ilt(saddr + offset, 0)', but that just
1568           * replaces an already out-of-bounds access with a larger one since 'saddr + offset'
1569           * would be larger than INT32_MAX.
1570           */
1571          if (i <= 1 && parse_base_offset(ctx, instr.get(), i, &base, &offset, true) &&
1572              base.regClass() == instr->operands[i].regClass() &&
1573              is_scratch_offset_valid(ctx, instr.get(), scratch.offset, (int32_t)offset)) {
1574             instr->operands[i].setTemp(base);
1575             scratch.offset += (int32_t)offset;
1576             continue;
1577          } else if (i <= 1 && parse_base_offset(ctx, instr.get(), i, &base, &offset, false) &&
1578                     base.regClass() == instr->operands[i].regClass() && (int32_t)offset < 0 &&
1579                     is_scratch_offset_valid(ctx, instr.get(), scratch.offset, (int32_t)offset)) {
1580             instr->operands[i].setTemp(base);
1581             scratch.offset += (int32_t)offset;
1582             continue;
1583          } else if (i <= 1 && info.is_constant_or_literal(32) &&
1584                     ctx.program->gfx_level >= GFX10_3 &&
1585                     is_scratch_offset_valid(ctx, NULL, scratch.offset, (int32_t)info.val)) {
1586             /* GFX10.3+ can disable both SADDR and ADDR. */
1587             instr->operands[i] = Operand(instr->operands[i].regClass());
1588             scratch.offset += (int32_t)info.val;
1589             continue;
1590          }
1591       }
1592 
1593       /* DS: combine additions */
1594       else if (instr->isDS()) {
1595 
1596          DS_instruction& ds = instr->ds();
1597          Temp base;
1598          uint32_t offset;
1599          bool has_usable_ds_offset = ctx.program->gfx_level >= GFX7;
1600          if (has_usable_ds_offset && i == 0 &&
1601              parse_base_offset(ctx, instr.get(), i, &base, &offset, false) &&
1602              base.regClass() == instr->operands[i].regClass() &&
1603              instr->opcode != aco_opcode::ds_swizzle_b32) {
1604             if (instr->opcode == aco_opcode::ds_write2_b32 ||
1605                 instr->opcode == aco_opcode::ds_read2_b32 ||
1606                 instr->opcode == aco_opcode::ds_write2_b64 ||
1607                 instr->opcode == aco_opcode::ds_read2_b64 ||
1608                 instr->opcode == aco_opcode::ds_write2st64_b32 ||
1609                 instr->opcode == aco_opcode::ds_read2st64_b32 ||
1610                 instr->opcode == aco_opcode::ds_write2st64_b64 ||
1611                 instr->opcode == aco_opcode::ds_read2st64_b64) {
1612                bool is64bit = instr->opcode == aco_opcode::ds_write2_b64 ||
1613                               instr->opcode == aco_opcode::ds_read2_b64 ||
1614                               instr->opcode == aco_opcode::ds_write2st64_b64 ||
1615                               instr->opcode == aco_opcode::ds_read2st64_b64;
1616                bool st64 = instr->opcode == aco_opcode::ds_write2st64_b32 ||
1617                            instr->opcode == aco_opcode::ds_read2st64_b32 ||
1618                            instr->opcode == aco_opcode::ds_write2st64_b64 ||
1619                            instr->opcode == aco_opcode::ds_read2st64_b64;
1620                unsigned shifts = (is64bit ? 3 : 2) + (st64 ? 6 : 0);
1621                unsigned mask = BITFIELD_MASK(shifts);
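               /* read2/write2 offsets are encoded in units of the element size
                * (4 or 8 bytes, additionally x64 for the st64 variants), and
                * each of the two 8-bit offset fields holds at most 255 units. */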
1622 
1623                if ((offset & mask) == 0 && ds.offset0 + (offset >> shifts) <= 255 &&
1624                    ds.offset1 + (offset >> shifts) <= 255) {
1625                   instr->operands[i].setTemp(base);
1626                   ds.offset0 += offset >> shifts;
1627                   ds.offset1 += offset >> shifts;
1628                }
1629             } else {
1630                if (ds.offset0 + offset <= 65535) {
1631                   instr->operands[i].setTemp(base);
1632                   ds.offset0 += offset;
1633                }
1634             }
1635          }
1636       }
1637 
1638       else if (instr->isBranch()) {
1639          if (ctx.info[instr->operands[0].tempId()].is_scc_invert()) {
1640             /* Flip the branch instruction to get rid of the scc_invert instruction */
1641             instr->opcode = instr->opcode == aco_opcode::p_cbranch_z ? aco_opcode::p_cbranch_nz
1642                                                                      : aco_opcode::p_cbranch_z;
1643             instr->operands[0].setTemp(ctx.info[instr->operands[0].tempId()].temp);
1644          }
1645       }
1646    }
1647 
1648    /* if this instruction doesn't define anything, return */
1649    if (instr->definitions.empty()) {
1650       check_sdwa_extract(ctx, instr);
1651       return;
1652    }
1653 
1654    if (instr->isVALU() || instr->isVINTRP()) {
1655       if (instr_info.can_use_output_modifiers[(int)instr->opcode] || instr->isVINTRP() ||
1656           instr->opcode == aco_opcode::v_cndmask_b32) {
1657          bool canonicalized = true;
1658          if (!does_fp_op_flush_denorms(ctx, instr->opcode)) {
1659             unsigned ops = instr->opcode == aco_opcode::v_cndmask_b32 ? 2 : instr->operands.size();
1660             for (unsigned i = 0; canonicalized && (i < ops); i++)
1661                canonicalized = is_op_canonicalized(ctx, instr->operands[i]);
1662          }
1663          if (canonicalized)
1664             ctx.info[instr->definitions[0].tempId()].set_canonicalized();
1665       }
1666 
1667       if (instr->isVOPC()) {
1668          ctx.info[instr->definitions[0].tempId()].set_vopc(instr.get());
1669          check_sdwa_extract(ctx, instr);
1670          return;
1671       }
1672       if (instr->isVOP3P()) {
1673          ctx.info[instr->definitions[0].tempId()].set_vop3p(instr.get());
1674          return;
1675       }
1676    }
1677 
1678    switch (instr->opcode) {
1679    case aco_opcode::p_create_vector: {
1680       bool copy_prop = instr->operands.size() == 1 && instr->operands[0].isTemp() &&
1681                        instr->operands[0].regClass() == instr->definitions[0].regClass();
1682       if (copy_prop) {
1683          ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
1684          break;
1685       }
1686 
1687       /* expand vector operands */
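      /* e.g. p_create_vector(%ab, c), where %ab = p_create_vector(a, b),
       * becomes p_create_vector(a, b, c). */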
1688       std::vector<Operand> ops;
1689       unsigned offset = 0;
1690       for (const Operand& op : instr->operands) {
1691          /* ensure that any expanded operands are properly aligned */
1692          bool aligned = offset % 4 == 0 || op.bytes() < 4;
1693          offset += op.bytes();
1694          if (aligned && op.isTemp() && ctx.info[op.tempId()].is_vec()) {
1695             Instruction* vec = ctx.info[op.tempId()].instr;
1696             for (const Operand& vec_op : vec->operands)
1697                ops.emplace_back(vec_op);
1698          } else {
1699             ops.emplace_back(op);
1700          }
1701       }
1702 
1703       /* combine expanded operands to new vector */
1704       if (ops.size() != instr->operands.size()) {
1705          assert(ops.size() > instr->operands.size());
1706          Definition def = instr->definitions[0];
1707          instr.reset(
1708             create_instruction(aco_opcode::p_create_vector, Format::PSEUDO, ops.size(), 1));
1709          for (unsigned i = 0; i < ops.size(); i++) {
1710             if (ops[i].isTemp() && ctx.info[ops[i].tempId()].is_temp() &&
1711                 ops[i].regClass() == ctx.info[ops[i].tempId()].temp.regClass())
1712                ops[i].setTemp(ctx.info[ops[i].tempId()].temp);
1713             instr->operands[i] = ops[i];
1714          }
1715          instr->definitions[0] = def;
1716       } else {
1717          for (unsigned i = 0; i < ops.size(); i++) {
1718             assert(instr->operands[i] == ops[i]);
1719          }
1720       }
1721       ctx.info[instr->definitions[0].tempId()].set_vec(instr.get());
1722 
1723       if (instr->operands.size() == 2) {
1724          /* check if this is created from split_vector */
1725          if (instr->operands[1].isTemp() && ctx.info[instr->operands[1].tempId()].is_split()) {
1726             Instruction* split = ctx.info[instr->operands[1].tempId()].instr;
1727             if (instr->operands[0].isTemp() &&
1728                 instr->operands[0].getTemp() == split->definitions[0].getTemp())
1729                ctx.info[instr->definitions[0].tempId()].set_temp(split->operands[0].getTemp());
1730          }
1731       }
1732       break;
1733    }
1734    case aco_opcode::p_split_vector: {
1735       ssa_info& info = ctx.info[instr->operands[0].tempId()];
1736 
1737       if (info.is_constant_or_literal(32)) {
1738          uint64_t val = info.val;
1739          for (Definition def : instr->definitions) {
1740             uint32_t mask = u_bit_consecutive(0, def.bytes() * 8u);
1741             ctx.info[def.tempId()].set_constant(ctx.program->gfx_level, val & mask);
1742             val >>= def.bytes() * 8u;
1743          }
1744          break;
1745       } else if (!info.is_vec()) {
1746          if (instr->definitions.size() == 2 && instr->operands[0].isTemp() &&
1747              instr->definitions[0].bytes() == instr->definitions[1].bytes()) {
1748             ctx.info[instr->definitions[1].tempId()].set_split(instr.get());
1749             if (instr->operands[0].bytes() == 4) {
1750                /* D16 subdword split */
1751                ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
1752                ctx.info[instr->definitions[1].tempId()].set_extract(instr.get());
1753             }
1754          }
1755          break;
1756       }
1757 
1758       Instruction* vec = ctx.info[instr->operands[0].tempId()].instr;
1759       unsigned split_offset = 0;
1760       unsigned vec_offset = 0;
1761       unsigned vec_index = 0;
1762       for (unsigned i = 0; i < instr->definitions.size();
1763            split_offset += instr->definitions[i++].bytes()) {
1764          while (vec_offset < split_offset && vec_index < vec->operands.size())
1765             vec_offset += vec->operands[vec_index++].bytes();
1766 
1767          if (vec_offset != split_offset ||
1768              vec->operands[vec_index].bytes() != instr->definitions[i].bytes())
1769             continue;
1770 
1771          Operand vec_op = vec->operands[vec_index];
1772          if (vec_op.isConstant()) {
1773             ctx.info[instr->definitions[i].tempId()].set_constant(ctx.program->gfx_level,
1774                                                                   vec_op.constantValue64());
1775          } else if (vec_op.isTemp()) {
1776             ctx.info[instr->definitions[i].tempId()].set_temp(vec_op.getTemp());
1777          }
1778       }
1779       break;
1780    }
1781    case aco_opcode::p_extract_vector: { /* mov */
1782       const unsigned index = instr->operands[1].constantValue();
1783 
1784       if (instr->operands[0].isTemp()) {
1785          ssa_info& info = ctx.info[instr->operands[0].tempId()];
1786          const unsigned dst_offset = index * instr->definitions[0].bytes();
1787 
1788          if (info.is_vec()) {
1789             /* check if we index directly into a vector element */
1790             Instruction* vec = info.instr;
1791             unsigned offset = 0;
1792 
1793             for (const Operand& op : vec->operands) {
1794                if (offset < dst_offset) {
1795                   offset += op.bytes();
1796                   continue;
1797                } else if (offset != dst_offset || op.bytes() != instr->definitions[0].bytes()) {
1798                   break;
1799                }
1800                instr->operands[0] = op;
1801                break;
1802             }
1803          } else if (info.is_constant_or_literal(32)) {
1804             /* propagate constants */
1805             uint32_t mask = u_bit_consecutive(0, instr->definitions[0].bytes() * 8u);
1806             uint32_t val = (info.val >> (dst_offset * 8u)) & mask;
1807             instr->operands[0] =
1808                Operand::get_const(ctx.program->gfx_level, val, instr->definitions[0].bytes());
1810          }
1811       }
1812 
1813       if (instr->operands[0].bytes() != instr->definitions[0].bytes()) {
1814          if (instr->operands[0].size() != 1 || !instr->operands[0].isTemp())
1815             break;
1816 
1817          if (index == 0)
1818             ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
1819          else
1820             ctx.info[instr->definitions[0].tempId()].set_extract(instr.get());
1821          break;
1822       }
1823 
1824       /* convert this extract into a copy instruction */
1825       instr->opcode = aco_opcode::p_parallelcopy;
1826       instr->operands.pop_back();
1827       FALLTHROUGH;
1828    }
1829    case aco_opcode::p_parallelcopy: /* propagate */
1830       if (instr->operands[0].isTemp() && ctx.info[instr->operands[0].tempId()].is_vec() &&
1831           instr->operands[0].regClass() != instr->definitions[0].regClass()) {
1832          /* We might not be able to copy-propagate if it's a SGPR->VGPR copy, so
1833           * duplicate the vector instead.
1834           */
1835          Instruction* vec = ctx.info[instr->operands[0].tempId()].instr;
1836          aco_ptr<Instruction> old_copy = std::move(instr);
1837 
1838          instr.reset(create_instruction(aco_opcode::p_create_vector, Format::PSEUDO,
1839                                         vec->operands.size(), 1));
1840          instr->definitions[0] = old_copy->definitions[0];
1841          std::copy(vec->operands.begin(), vec->operands.end(), instr->operands.begin());
1842          for (unsigned i = 0; i < vec->operands.size(); i++) {
1843             Operand& op = instr->operands[i];
1844             if (op.isTemp() && ctx.info[op.tempId()].is_temp() &&
1845                 ctx.info[op.tempId()].temp.type() == instr->definitions[0].regClass().type())
1846                op.setTemp(ctx.info[op.tempId()].temp);
1847          }
1848          ctx.info[instr->definitions[0].tempId()].set_vec(instr.get());
1849          break;
1850       }
1851       FALLTHROUGH;
1852    case aco_opcode::p_as_uniform:
1853       if (instr->definitions[0].isFixed()) {
1854          /* don't copy-propagate copies into fixed registers */
1855       } else if (instr->operands[0].isConstant()) {
1856          ctx.info[instr->definitions[0].tempId()].set_constant(
1857             ctx.program->gfx_level, instr->operands[0].constantValue64());
1858       } else if (instr->operands[0].isTemp()) {
1859          ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
1860          if (ctx.info[instr->operands[0].tempId()].is_canonicalized())
1861             ctx.info[instr->definitions[0].tempId()].set_canonicalized();
1862       } else {
1863          assert(instr->operands[0].isFixed());
1864       }
1865       break;
1866    case aco_opcode::v_mov_b32:
1867       if (instr->isDPP16()) {
1868          /* anything else doesn't make sense in SSA */
1869          assert(instr->dpp16().row_mask == 0xf && instr->dpp16().bank_mask == 0xf);
1870          ctx.info[instr->definitions[0].tempId()].set_dpp16(instr.get());
1871       } else if (instr->isDPP8()) {
1872          ctx.info[instr->definitions[0].tempId()].set_dpp8(instr.get());
1873       }
1874       break;
1875    case aco_opcode::p_is_helper:
1876       if (!ctx.program->needs_wqm)
1877          ctx.info[instr->definitions[0].tempId()].set_constant(ctx.program->gfx_level, 0u);
1878       break;
1879    case aco_opcode::v_mul_f64_e64:
1880    case aco_opcode::v_mul_f64: ctx.info[instr->definitions[0].tempId()].set_mul(instr.get()); break;
1881    case aco_opcode::v_mul_f16:
1882    case aco_opcode::v_mul_f32:
1883    case aco_opcode::v_mul_legacy_f32: { /* omod */
1884       ctx.info[instr->definitions[0].tempId()].set_mul(instr.get());
1885 
1886       /* TODO: try to move the negate/abs modifier to the consumer instead */
1887       bool uses_mods = instr->usesModifiers();
1888       bool fp16 = instr->opcode == aco_opcode::v_mul_f16;
1889       unsigned denorm_mode = fp16 ? ctx.fp_mode.denorm16_64 : ctx.fp_mode.denorm32;
1890 
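      /* A multiply of a temporary by +/-1.0 can become a copy/neg/abs/clamp or
       * fcanonicalize label, and a multiply by 0.5, 2.0 or 4.0 can be folded
       * into the consumer as an output modifier (omod) when denormals are
       * flushed, e.g. v_mul_f32 %d, 2.0, %x -> omod2 label on %x. */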
1891       for (unsigned i = 0; i < 2; i++) {
1892          if (instr->operands[!i].isConstant() && instr->operands[i].isTemp()) {
1893             if (!instr->isDPP() && !instr->isSDWA() && !instr->valu().opsel &&
1894                 (instr->operands[!i].constantEquals(fp16 ? 0x3c00 : 0x3f800000) ||   /* 1.0 */
1895                  instr->operands[!i].constantEquals(fp16 ? 0xbc00 : 0xbf800000u))) { /* -1.0 */
1896                bool neg1 = instr->operands[!i].constantEquals(fp16 ? 0xbc00 : 0xbf800000u);
1897 
1898                VALU_instruction* valu = &instr->valu();
1899                if (valu->abs[!i] || valu->neg[!i] || valu->omod)
1900                   continue;
1901 
1902                bool abs = valu->abs[i];
1903                bool neg = neg1 ^ valu->neg[i];
1904                Temp other = instr->operands[i].getTemp();
1905 
1906                if (valu->clamp) {
1907                   if (!abs && !neg && other.type() == RegType::vgpr)
1908                      ctx.info[other.id()].set_clamp(instr.get());
1909                   continue;
1910                }
1911 
1912                if (abs && neg && other.type() == RegType::vgpr)
1913                   ctx.info[instr->definitions[0].tempId()].set_neg_abs(other);
1914                else if (abs && !neg && other.type() == RegType::vgpr)
1915                   ctx.info[instr->definitions[0].tempId()].set_abs(other);
1916                else if (!abs && neg && other.type() == RegType::vgpr)
1917                   ctx.info[instr->definitions[0].tempId()].set_neg(other);
1918                else if (!abs && !neg) {
1919                   if (denorm_mode == fp_denorm_keep || ctx.info[other.id()].is_canonicalized())
1920                      ctx.info[instr->definitions[0].tempId()].set_temp(other);
1921                   else
1922                      ctx.info[instr->definitions[0].tempId()].set_fcanonicalize(other);
1923                }
1924             } else if (uses_mods || (instr->definitions[0].isSZPreserve() &&
1925                                      instr->opcode != aco_opcode::v_mul_legacy_f32)) {
1926                continue; /* omod uses a legacy multiplication. */
1927             } else if (instr->operands[!i].constantValue() == 0u &&
1928                        ((!instr->definitions[0].isNaNPreserve() &&
1929                          !instr->definitions[0].isInfPreserve()) ||
1930                         instr->opcode == aco_opcode::v_mul_legacy_f32)) { /* 0.0 */
1931                ctx.info[instr->definitions[0].tempId()].set_constant(ctx.program->gfx_level, 0u);
1932             } else if (denorm_mode != fp_denorm_flush) {
1933                /* omod has no effect if denormals are enabled. */
1934                continue;
1935             } else if (instr->operands[!i].constantValue() ==
1936                        (fp16 ? 0x4000 : 0x40000000)) { /* 2.0 */
1937                ctx.info[instr->operands[i].tempId()].set_omod2(instr.get());
1938             } else if (instr->operands[!i].constantValue() ==
1939                        (fp16 ? 0x4400 : 0x40800000)) { /* 4.0 */
1940                ctx.info[instr->operands[i].tempId()].set_omod4(instr.get());
1941             } else if (instr->operands[!i].constantValue() ==
1942                        (fp16 ? 0x3800 : 0x3f000000)) { /* 0.5 */
1943                ctx.info[instr->operands[i].tempId()].set_omod5(instr.get());
1944             } else {
1945                continue;
1946             }
1947             break;
1948          }
1949       }
1950       break;
1951    }
1952    case aco_opcode::v_mul_lo_u16:
1953    case aco_opcode::v_mul_lo_u16_e64:
1954    case aco_opcode::v_mul_u32_u24:
1955       ctx.info[instr->definitions[0].tempId()].set_usedef(instr.get());
1956       break;
1957    case aco_opcode::v_med3_f16:
1958    case aco_opcode::v_med3_f32: { /* clamp */
1959       unsigned idx;
1960       if (detect_clamp(instr.get(), &idx) && !instr->valu().abs && !instr->valu().neg)
1961          ctx.info[instr->operands[idx].tempId()].set_clamp(instr.get());
1962       break;
1963    }
1964    case aco_opcode::v_cndmask_b32:
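      /* v_cndmask_b32(0, 1.0, %cc) converts a boolean to a float (b2f);
       * v_cndmask_b32(0, 1, %cc) converts it to an integer (b2i). */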
1965       if (instr->operands[0].constantEquals(0) && instr->operands[1].constantEquals(0x3f800000u))
1966          ctx.info[instr->definitions[0].tempId()].set_b2f(instr->operands[2].getTemp());
1967       else if (instr->operands[0].constantEquals(0) && instr->operands[1].constantEquals(1))
1968          ctx.info[instr->definitions[0].tempId()].set_b2i(instr->operands[2].getTemp());
1969 
1970       break;
1971    case aco_opcode::v_add_u32:
1972    case aco_opcode::v_add_co_u32:
1973    case aco_opcode::v_add_co_u32_e64:
1974    case aco_opcode::s_add_i32:
1975    case aco_opcode::s_add_u32:
1976    case aco_opcode::v_subbrev_co_u32:
1977    case aco_opcode::v_sub_u32:
1978    case aco_opcode::v_sub_i32:
1979    case aco_opcode::v_sub_co_u32:
1980    case aco_opcode::v_sub_co_u32_e64:
1981    case aco_opcode::s_sub_u32:
1982    case aco_opcode::s_sub_i32:
1983    case aco_opcode::v_subrev_u32:
1984    case aco_opcode::v_subrev_co_u32:
1985    case aco_opcode::v_subrev_co_u32_e64:
1986       ctx.info[instr->definitions[0].tempId()].set_add_sub(instr.get());
1987       break;
1988    case aco_opcode::s_not_b32:
1989    case aco_opcode::s_not_b64:
1990       if (!instr->operands[0].isTemp()) {
1991       } else if (ctx.info[instr->operands[0].tempId()].is_uniform_bool()) {
1992          ctx.info[instr->definitions[0].tempId()].set_uniform_bitwise();
1993          ctx.info[instr->definitions[1].tempId()].set_scc_invert(
1994             ctx.info[instr->operands[0].tempId()].temp);
1995       } else if (ctx.info[instr->operands[0].tempId()].is_uniform_bitwise()) {
1996          ctx.info[instr->definitions[0].tempId()].set_uniform_bitwise();
1997          ctx.info[instr->definitions[1].tempId()].set_scc_invert(
1998             ctx.info[instr->operands[0].tempId()].instr->definitions[1].getTemp());
1999       }
2000       ctx.info[instr->definitions[0].tempId()].set_bitwise(instr.get());
2001       break;
2002    case aco_opcode::s_and_b32:
2003    case aco_opcode::s_and_b64:
2004       if (fixed_to_exec(instr->operands[1]) && instr->operands[0].isTemp()) {
2005          if (ctx.info[instr->operands[0].tempId()].is_uniform_bool()) {
2006             /* Try to get rid of the superfluous s_cselect + s_and_b64 that comes from turning a
2007              * uniform bool into a divergent one */
2008             ctx.info[instr->definitions[1].tempId()].set_temp(
2009                ctx.info[instr->operands[0].tempId()].temp);
2010             ctx.info[instr->definitions[0].tempId()].set_uniform_bool(
2011                ctx.info[instr->operands[0].tempId()].temp);
2012             break;
2013          } else if (ctx.info[instr->operands[0].tempId()].is_uniform_bitwise()) {
2014             /* Try to get rid of the superfluous s_and_b64, since the uniform bitwise instruction
2015              * already produces the same SCC */
2016             ctx.info[instr->definitions[1].tempId()].set_temp(
2017                ctx.info[instr->operands[0].tempId()].instr->definitions[1].getTemp());
2018             ctx.info[instr->definitions[0].tempId()].set_uniform_bool(
2019                ctx.info[instr->operands[0].tempId()].instr->definitions[1].getTemp());
2020             break;
2021          } else if ((ctx.program->stage.num_sw_stages() > 1 ||
2022                      ctx.program->stage.hw == AC_HW_NEXT_GEN_GEOMETRY_SHADER) &&
2023                     instr->pass_flags == 1) {
2024             /* In case of merged shaders, pass_flags=1 means that all lanes are active (exec=-1), so
2025              * s_and is unnecessary. */
2026             ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
2027             break;
2028          }
2029       }
2030       FALLTHROUGH;
2031    case aco_opcode::s_or_b32:
2032    case aco_opcode::s_or_b64:
2033    case aco_opcode::s_xor_b32:
2034    case aco_opcode::s_xor_b64:
2035       if (std::all_of(instr->operands.begin(), instr->operands.end(),
2036                       [&ctx](const Operand& op)
2037                       {
2038                          return op.isTemp() && (ctx.info[op.tempId()].is_uniform_bool() ||
2039                                                 ctx.info[op.tempId()].is_uniform_bitwise());
2040                       })) {
2041          ctx.info[instr->definitions[0].tempId()].set_uniform_bitwise();
2042       }
2043       ctx.info[instr->definitions[0].tempId()].set_bitwise(instr.get());
2044       break;
2045    case aco_opcode::s_lshl_b32:
2046    case aco_opcode::v_or_b32:
2047    case aco_opcode::v_lshlrev_b32:
2048    case aco_opcode::v_bcnt_u32_b32:
2049    case aco_opcode::v_and_b32:
2050    case aco_opcode::v_xor_b32:
2051    case aco_opcode::v_not_b32:
2052       ctx.info[instr->definitions[0].tempId()].set_usedef(instr.get());
2053       break;
2054    case aco_opcode::v_min_f32:
2055    case aco_opcode::v_min_f16:
2056    case aco_opcode::v_min_u32:
2057    case aco_opcode::v_min_i32:
2058    case aco_opcode::v_min_u16:
2059    case aco_opcode::v_min_i16:
2060    case aco_opcode::v_min_u16_e64:
2061    case aco_opcode::v_min_i16_e64:
2062    case aco_opcode::v_max_f32:
2063    case aco_opcode::v_max_f16:
2064    case aco_opcode::v_max_u32:
2065    case aco_opcode::v_max_i32:
2066    case aco_opcode::v_max_u16:
2067    case aco_opcode::v_max_i16:
2068    case aco_opcode::v_max_u16_e64:
2069    case aco_opcode::v_max_i16_e64:
2070       ctx.info[instr->definitions[0].tempId()].set_minmax(instr.get());
2071       break;
2072    case aco_opcode::s_cselect_b64:
2073    case aco_opcode::s_cselect_b32:
2074       if (instr->operands[0].constantEquals((unsigned)-1) && instr->operands[1].constantEquals(0)) {
2075          /* Found a cselect that operates on a uniform bool that comes from eg. s_cmp */
2076          ctx.info[instr->definitions[0].tempId()].set_uniform_bool(instr->operands[2].getTemp());
2077       }
2078       if (instr->operands[2].isTemp() && ctx.info[instr->operands[2].tempId()].is_scc_invert()) {
2079          /* Flip the operands to get rid of the scc_invert instruction */
2080          std::swap(instr->operands[0], instr->operands[1]);
2081          instr->operands[2].setTemp(ctx.info[instr->operands[2].tempId()].temp);
2082       }
2083       break;
2084    case aco_opcode::s_mul_i32:
2085       /* Testing every uint32_t shows that 0x3f800000*n is never a denormal.
2086        * This pattern is created from a uniform nir_op_b2f. */
2087       if (instr->operands[0].constantEquals(0x3f800000u))
2088          ctx.info[instr->definitions[0].tempId()].set_canonicalized();
2089       break;
2090    case aco_opcode::p_extract: {
2091       if (instr->operands[0].isTemp()) {
2092          ctx.info[instr->definitions[0].tempId()].set_extract(instr.get());
2093          if (instr->definitions[0].bytes() == 4 && instr->operands[0].regClass() == v1 &&
2094              parse_insert(instr.get()))
2095             ctx.info[instr->operands[0].tempId()].set_insert(instr.get());
2096       }
2097       break;
2098    }
2099    case aco_opcode::p_insert: {
2100       if (instr->operands[0].isTemp()) {
2101          if (instr->operands[0].regClass() == v1)
2102             ctx.info[instr->operands[0].tempId()].set_insert(instr.get());
2103          if (parse_extract(instr.get()))
2104             ctx.info[instr->definitions[0].tempId()].set_extract(instr.get());
2105          ctx.info[instr->definitions[0].tempId()].set_bitwise(instr.get());
2106       }
2107       break;
2108    }
2109    case aco_opcode::ds_read_u8:
2110    case aco_opcode::ds_read_u8_d16:
2111    case aco_opcode::ds_read_u16:
2112    case aco_opcode::ds_read_u16_d16: {
2113       ctx.info[instr->definitions[0].tempId()].set_usedef(instr.get());
2114       break;
2115    }
2116    case aco_opcode::v_cvt_f16_f32: {
2117       if (instr->operands[0].isTemp()) {
2118          ssa_info& info = ctx.info[instr->operands[0].tempId()];
2119          if (!info.is_dpp() || info.instr->pass_flags != instr->pass_flags)
2120             info.set_f2f16(instr.get());
2121       }
2122       break;
2123    }
2124    case aco_opcode::v_cvt_f32_f16: {
2125       if (instr->operands[0].isTemp())
2126          ctx.info[instr->definitions[0].tempId()].set_f2f32(instr.get());
2127       break;
2128    }
2129    default: break;
2130    }
2131 
2132    /* Don't remove label_extract if we can't apply the extract to
2133     * neg/abs instructions because we'll likely combine it into another valu. */
2134    if (!(ctx.info[instr->definitions[0].tempId()].label & (label_neg | label_abs)))
2135       check_sdwa_extract(ctx, instr);
2136 }
2137 
2138 unsigned
2139 original_temp_id(opt_ctx& ctx, Temp tmp)
2140 {
2141    if (ctx.info[tmp.id()].is_temp())
2142       return ctx.info[tmp.id()].temp.id();
2143    else
2144       return tmp.id();
2145 }
2146 
2147 void
2148 decrease_op_uses_if_dead(opt_ctx& ctx, Instruction* instr)
2149 {
2150    if (is_dead(ctx.uses, instr)) {
2151       for (const Operand& op : instr->operands) {
2152          if (op.isTemp())
2153             ctx.uses[op.tempId()]--;
2154       }
2155    }
2156 }
2157 
2158 void
2159 decrease_uses(opt_ctx& ctx, Instruction* instr)
2160 {
2161    ctx.uses[instr->definitions[0].tempId()]--;
2162    decrease_op_uses_if_dead(ctx, instr);
2163 }
2164 
2165 Operand
2166 copy_operand(opt_ctx& ctx, Operand op)
2167 {
2168    if (op.isTemp())
2169       ctx.uses[op.tempId()]++;
2170    return op;
2171 }
2172 
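/* Returns the instruction labelled on 'op' when it is safe to fold into its
 * user: the def must carry a usedef label, normally be the operand's only use,
 * any second definition (e.g. carry-out/SCC) must be dead, and none of its
 * operands may be fixed to exec. Returns NULL otherwise. */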
2173 Instruction*
2174 follow_operand(opt_ctx& ctx, Operand op, bool ignore_uses = false)
2175 {
2176    if (!op.isTemp() || !(ctx.info[op.tempId()].label & instr_usedef_labels))
2177       return nullptr;
2178    if (!ignore_uses && ctx.uses[op.tempId()] > 1)
2179       return nullptr;
2180 
2181    Instruction* instr = ctx.info[op.tempId()].instr;
2182 
2183    if (instr->definitions.size() == 2) {
2184       unsigned idx = ctx.info[op.tempId()].label & label_split ? 1 : 0;
2185       assert(instr->definitions[idx].isTemp() && instr->definitions[idx].tempId() == op.tempId());
2186       if (instr->definitions[!idx].isTemp() && ctx.uses[instr->definitions[!idx].tempId()])
2187          return nullptr;
2188    }
2189 
2190    for (Operand& operand : instr->operands) {
2191       if (fixed_to_exec(operand))
2192          return nullptr;
2193    }
2194 
2195    return instr;
2196 }
2197 
2198 bool
2199 is_operand_constant(opt_ctx& ctx, Operand op, unsigned bit_size, uint64_t* value)
2200 {
2201    if (op.isConstant()) {
2202       *value = op.constantValue64();
2203       return true;
2204    } else if (op.isTemp()) {
2205       unsigned id = original_temp_id(ctx, op.getTemp());
2206       if (!ctx.info[id].is_constant_or_literal(bit_size))
2207          return false;
2208       *value = get_constant_op(ctx, ctx.info[id], bit_size).constantValue64();
2209       return true;
2210    }
2211    return false;
2212 }
2213 
2214 /* s_not(cmp(a, b)) -> get_vcmp_inverse(cmp)(a, b) */
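/* e.g. s_not_b64(v_cmp_lt_f32(a, b)) becomes v_cmp_nlt_f32(a, b), which keeps
 * the per-lane NaN behaviour of the negated comparison. */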
2215 bool
2216 combine_inverse_comparison(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2217 {
2218    if (ctx.uses[instr->definitions[1].tempId()])
2219       return false;
2220    if (!instr->operands[0].isTemp() || ctx.uses[instr->operands[0].tempId()] != 1)
2221       return false;
2222 
2223    Instruction* cmp = follow_operand(ctx, instr->operands[0]);
2224    if (!cmp)
2225       return false;
2226 
2227    aco_opcode new_opcode = get_vcmp_inverse(cmp->opcode);
2228    if (new_opcode == aco_opcode::num_opcodes)
2229       return false;
2230 
2231    /* Invert compare instruction and assign this instruction's definition */
2232    cmp->opcode = new_opcode;
2233    ctx.info[instr->definitions[0].tempId()] = ctx.info[cmp->definitions[0].tempId()];
2234    std::swap(instr->definitions[0], cmp->definitions[0]);
2235 
2236    ctx.uses[instr->operands[0].tempId()]--;
2237    return true;
2238 }
2239 
2240 /* op1(op2(1, 2), 0) if swap = false
2241  * op1(0, op2(1, 2)) if swap = true */
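/* shuffle_str[k] names which source ends up in result operand k: '0' is op1's
 * remaining operand, '1' and '2' are op2's operands. E.g. "120" produces
 * (op2.operands[0], op2.operands[1], op1's remaining operand). */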
2242 bool
2243 match_op3_for_vop3(opt_ctx& ctx, aco_opcode op1, aco_opcode op2, Instruction* op1_instr, bool swap,
2244                    const char* shuffle_str, Operand operands[3], bitarray8& neg, bitarray8& abs,
2245                    bitarray8& opsel, bool* op1_clamp, uint8_t* op1_omod, bool* inbetween_neg,
2246                    bool* inbetween_abs, bool* inbetween_opsel, bool* precise)
2247 {
2248    /* checks */
2249    if (op1_instr->opcode != op1)
2250       return false;
2251 
2252    Instruction* op2_instr = follow_operand(ctx, op1_instr->operands[swap]);
2253    if (!op2_instr || op2_instr->opcode != op2)
2254       return false;
2255 
2256    VALU_instruction* op1_valu = op1_instr->isVALU() ? &op1_instr->valu() : NULL;
2257    VALU_instruction* op2_valu = op2_instr->isVALU() ? &op2_instr->valu() : NULL;
2258 
2259    if (op1_instr->isSDWA() || op2_instr->isSDWA())
2260       return false;
2261    if (op1_instr->isDPP() || op2_instr->isDPP())
2262       return false;
2263 
2264    /* don't support inbetween clamp/omod */
2265    if (op2_valu && (op2_valu->clamp || op2_valu->omod))
2266       return false;
2267 
2268    /* get operands and modifiers and check inbetween modifiers */
2269    *op1_clamp = op1_valu ? (bool)op1_valu->clamp : false;
2270    *op1_omod = op1_valu ? (unsigned)op1_valu->omod : 0u;
2271 
2272    if (inbetween_neg)
2273       *inbetween_neg = op1_valu ? op1_valu->neg[swap] : false;
2274    else if (op1_valu && op1_valu->neg[swap])
2275       return false;
2276 
2277    if (inbetween_abs)
2278       *inbetween_abs = op1_valu ? op1_valu->abs[swap] : false;
2279    else if (op1_valu && op1_valu->abs[swap])
2280       return false;
2281 
2282    if (inbetween_opsel)
2283       *inbetween_opsel = op1_valu ? op1_valu->opsel[swap] : false;
2284    else if (op1_valu && op1_valu->opsel[swap])
2285       return false;
2286 
2287    *precise = op1_instr->definitions[0].isPrecise() || op2_instr->definitions[0].isPrecise();
2288 
2289    int shuffle[3];
2290    shuffle[shuffle_str[0] - '0'] = 0;
2291    shuffle[shuffle_str[1] - '0'] = 1;
2292    shuffle[shuffle_str[2] - '0'] = 2;
2293 
2294    operands[shuffle[0]] = op1_instr->operands[!swap];
2295    neg[shuffle[0]] = op1_valu ? op1_valu->neg[!swap] : false;
2296    abs[shuffle[0]] = op1_valu ? op1_valu->abs[!swap] : false;
2297    opsel[shuffle[0]] = op1_valu ? op1_valu->opsel[!swap] : false;
2298 
2299    for (unsigned i = 0; i < 2; i++) {
2300       operands[shuffle[i + 1]] = op2_instr->operands[i];
2301       neg[shuffle[i + 1]] = op2_valu ? op2_valu->neg[i] : false;
2302       abs[shuffle[i + 1]] = op2_valu ? op2_valu->abs[i] : false;
2303       opsel[shuffle[i + 1]] = op2_valu ? op2_valu->opsel[i] : false;
2304    }
2305 
2306    /* check operands */
2307    if (!check_vop3_operands(ctx, 3, operands))
2308       return false;
2309 
2310    return true;
2311 }
2312 
2313 void
2314 create_vop3_for_op3(opt_ctx& ctx, aco_opcode opcode, aco_ptr<Instruction>& instr,
2315                     Operand operands[3], uint8_t neg, uint8_t abs, uint8_t opsel, bool clamp,
2316                     unsigned omod)
2317 {
2318    Instruction* new_instr = create_instruction(opcode, Format::VOP3, 3, 1);
2319    new_instr->valu().neg = neg;
2320    new_instr->valu().abs = abs;
2321    new_instr->valu().clamp = clamp;
2322    new_instr->valu().omod = omod;
2323    new_instr->valu().opsel = opsel;
2324    new_instr->operands[0] = operands[0];
2325    new_instr->operands[1] = operands[1];
2326    new_instr->operands[2] = operands[2];
2327    new_instr->definitions[0] = instr->definitions[0];
2328    new_instr->pass_flags = instr->pass_flags;
2329    ctx.info[instr->definitions[0].tempId()].label = 0;
2330 
2331    instr.reset(new_instr);
2332 }
2333 
2334 bool
2335 combine_three_valu_op(opt_ctx& ctx, aco_ptr<Instruction>& instr, aco_opcode op2, aco_opcode new_op,
2336                       const char* shuffle, uint8_t ops)
2337 {
2338    for (unsigned swap = 0; swap < 2; swap++) {
2339       if (!((1 << swap) & ops))
2340          continue;
2341 
2342       Operand operands[3];
2343       bool clamp, precise;
2344       bitarray8 neg = 0, abs = 0, opsel = 0;
2345       uint8_t omod = 0;
2346       if (match_op3_for_vop3(ctx, instr->opcode, op2, instr.get(), swap, shuffle, operands, neg,
2347                              abs, opsel, &clamp, &omod, NULL, NULL, NULL, &precise)) {
2348          ctx.uses[instr->operands[swap].tempId()]--;
2349          create_vop3_for_op3(ctx, new_op, instr, operands, neg, abs, opsel, clamp, omod);
2350          return true;
2351       }
2352    }
2353    return false;
2354 }
2355 
2356 /* creates v_lshl_add_u32, v_lshl_or_b32 or v_and_or_b32 */
2357 bool
2358 combine_add_or_then_and_lshl(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2359 {
2360    bool is_or = instr->opcode == aco_opcode::v_or_b32;
2361    aco_opcode new_op_lshl = is_or ? aco_opcode::v_lshl_or_b32 : aco_opcode::v_lshl_add_u32;
2362 
2363    if (is_or && combine_three_valu_op(ctx, instr, aco_opcode::s_and_b32, aco_opcode::v_and_or_b32,
2364                                       "120", 1 | 2))
2365       return true;
2366    if (is_or && combine_three_valu_op(ctx, instr, aco_opcode::v_and_b32, aco_opcode::v_and_or_b32,
2367                                       "120", 1 | 2))
2368       return true;
2369    if (combine_three_valu_op(ctx, instr, aco_opcode::s_lshl_b32, new_op_lshl, "120", 1 | 2))
2370       return true;
2371    if (combine_three_valu_op(ctx, instr, aco_opcode::v_lshlrev_b32, new_op_lshl, "210", 1 | 2))
2372       return true;
2373 
2374    if (instr->isSDWA() || instr->isDPP())
2375       return false;
2376 
2377    /* v_or_b32(p_extract(a, 0, 8/16, 0), b) -> v_and_or_b32(a, 0xff/0xffff, b)
2378     * v_or_b32(p_insert(a, 0, 8/16), b) -> v_and_or_b32(a, 0xff/0xffff, b)
2379     * v_or_b32(p_insert(a, 24/16, 8/16), b) -> v_lshl_or_b32(a, 24/16, b)
2380     * v_add_u32(p_insert(a, 24/16, 8/16), b) -> v_lshl_add_b32(a, 24/16, b)
2381     */
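   /* Rationale: p_insert(a, idx, bits) with (idx + 1) * bits == 32 places 'a'
    * in the topmost bits, which is the same as a left shift by idx * bits;
    * inserting/extracting the low 8/16 bits is treated the same as ANDing
    * with 0xff/0xffff. */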
2382    for (unsigned i = 0; i < 2; i++) {
2383       Instruction* extins = follow_operand(ctx, instr->operands[i]);
2384       if (!extins)
2385          continue;
2386 
2387       aco_opcode op;
2388       Operand operands[3];
2389 
2390       if (extins->opcode == aco_opcode::p_insert &&
2391           (extins->operands[1].constantValue() + 1) * extins->operands[2].constantValue() == 32) {
2392          op = new_op_lshl;
2393          operands[1] =
2394             Operand::c32(extins->operands[1].constantValue() * extins->operands[2].constantValue());
2395       } else if (is_or &&
2396                  (extins->opcode == aco_opcode::p_insert ||
2397                   (extins->opcode == aco_opcode::p_extract &&
2398                    extins->operands[3].constantEquals(0))) &&
2399                  extins->operands[1].constantEquals(0)) {
2400          op = aco_opcode::v_and_or_b32;
2401          operands[1] = Operand::c32(extins->operands[2].constantEquals(8) ? 0xffu : 0xffffu);
2402       } else {
2403          continue;
2404       }
2405 
2406       operands[0] = extins->operands[0];
2407       operands[2] = instr->operands[!i];
2408 
2409       if (!check_vop3_operands(ctx, 3, operands))
2410          continue;
2411 
2412       uint8_t neg = 0, abs = 0, opsel = 0, omod = 0;
2413       bool clamp = false;
2414       if (instr->isVOP3())
2415          clamp = instr->valu().clamp;
2416 
2417       ctx.uses[instr->operands[i].tempId()]--;
2418       create_vop3_for_op3(ctx, op, instr, operands, neg, abs, opsel, clamp, omod);
2419       return true;
2420    }
2421 
2422    return false;
2423 }
2424 
2425 /* v_xor(a, s_not(b)) -> v_xnor(a, b)
2426  * v_xor(a, v_not(b)) -> v_xnor(a, b)
2427  */
2428 bool
2429 combine_xor_not(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2430 {
2431    if (instr->usesModifiers())
2432       return false;
2433 
2434    for (unsigned i = 0; i < 2; i++) {
2435       Instruction* op_instr = follow_operand(ctx, instr->operands[i], true);
2436       if (!op_instr ||
2437           (op_instr->opcode != aco_opcode::v_not_b32 &&
2438            op_instr->opcode != aco_opcode::s_not_b32) ||
2439           op_instr->usesModifiers() || op_instr->operands[0].isLiteral())
2440          continue;
2441 
2442       instr->opcode = aco_opcode::v_xnor_b32;
2443       instr->operands[i] = copy_operand(ctx, op_instr->operands[0]);
2444       decrease_uses(ctx, op_instr);
2445       if (instr->operands[0].isOfType(RegType::vgpr))
2446          std::swap(instr->operands[0], instr->operands[1]);
2447       if (!instr->operands[1].isOfType(RegType::vgpr))
2448          instr->format = asVOP3(instr->format);
2449 
2450       return true;
2451    }
2452 
2453    return false;
2454 }
2455 
2456 /* v_not(v_xor(a, b)) -> v_xnor(a, b) */
2457 bool
2458 combine_not_xor(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2459 {
2460    if (instr->usesModifiers())
2461       return false;
2462 
2463    Instruction* op_instr = follow_operand(ctx, instr->operands[0]);
2464    if (!op_instr || op_instr->opcode != aco_opcode::v_xor_b32 || op_instr->isSDWA())
2465       return false;
2466 
2467    ctx.uses[instr->operands[0].tempId()]--;
2468    std::swap(instr->definitions[0], op_instr->definitions[0]);
2469    op_instr->opcode = aco_opcode::v_xnor_b32;
2470    ctx.info[op_instr->definitions[0].tempId()].label = 0;
2471 
2472    return true;
2473 }
2474 
2475 bool
2476 combine_minmax(opt_ctx& ctx, aco_ptr<Instruction>& instr, aco_opcode opposite, aco_opcode op3src,
2477                aco_opcode minmax)
2478 {
2479    /* TODO: this can handle SDWA min/max instructions by using opsel */
2480 
2481    /* min(min(a, b), c) -> min3(a, b, c)
2482     * max(max(a, b), c) -> max3(a, b, c)
2483     * gfx11: min(-min(a, b), c) -> maxmin(-a, -b, c)
2484     * gfx11: max(-max(a, b), c) -> minmax(-a, -b, c)
2485     */
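   /* Since -min(a, b) == max(-a, -b), an in-between negate is folded by
    * negating the two inner operands and picking the opcode that matches the
    * resulting min/max nesting (min3/max3 or the GFX11 fused maxmin/minmax). */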
2486    for (unsigned swap = 0; swap < 2; swap++) {
2487       Operand operands[3];
2488       bool clamp, precise;
2489       bitarray8 opsel = 0, neg = 0, abs = 0;
2490       uint8_t omod = 0;
2491       bool inbetween_neg;
2492       if (match_op3_for_vop3(ctx, instr->opcode, instr->opcode, instr.get(), swap, "120", operands,
2493                              neg, abs, opsel, &clamp, &omod, &inbetween_neg, NULL, NULL,
2494                              &precise) &&
2495           (!inbetween_neg ||
2496            (minmax != aco_opcode::num_opcodes && ctx.program->gfx_level >= GFX11))) {
2497          ctx.uses[instr->operands[swap].tempId()]--;
2498          if (inbetween_neg) {
2499             neg[0] = !neg[0];
2500             neg[1] = !neg[1];
2501             create_vop3_for_op3(ctx, minmax, instr, operands, neg, abs, opsel, clamp, omod);
2502          } else {
2503             create_vop3_for_op3(ctx, op3src, instr, operands, neg, abs, opsel, clamp, omod);
2504          }
2505          return true;
2506       }
2507    }
2508 
2509    /* min(-max(a, b), c) -> min3(-a, -b, c)
2510     * max(-min(a, b), c) -> max3(-a, -b, c)
2511     * gfx11: min(max(a, b), c) -> maxmin(a, b, c)
2512     * gfx11: max(min(a, b), c) -> minmax(a, b, c)
2513     */
2514    for (unsigned swap = 0; swap < 2; swap++) {
2515       Operand operands[3];
2516       bool clamp, precise;
2517       bitarray8 opsel = 0, neg = 0, abs = 0;
2518       uint8_t omod = 0;
2519       bool inbetween_neg;
2520       if (match_op3_for_vop3(ctx, instr->opcode, opposite, instr.get(), swap, "120", operands, neg,
2521                              abs, opsel, &clamp, &omod, &inbetween_neg, NULL, NULL, &precise) &&
2522           (inbetween_neg ||
2523            (minmax != aco_opcode::num_opcodes && ctx.program->gfx_level >= GFX11))) {
2524          ctx.uses[instr->operands[swap].tempId()]--;
2525          if (inbetween_neg) {
2526             neg[0] = !neg[0];
2527             neg[1] = !neg[1];
2528             create_vop3_for_op3(ctx, op3src, instr, operands, neg, abs, opsel, clamp, omod);
2529          } else {
2530             create_vop3_for_op3(ctx, minmax, instr, operands, neg, abs, opsel, clamp, omod);
2531          }
2532          return true;
2533       }
2534    }
2535    return false;
2536 }
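
/* Illustrative sketch (editor's addition): min3/max3 simply flatten the nested
 * two-operand form, so the fold above is value-preserving; with sample values: */
static_assert(std::min(std::min(3, 7), 5) == std::min({3, 7, 5}),
              "min(min(a, b), c) == min3(a, b, c) sketch");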
2537 
2538 /* s_not_b32(s_and_b32(a, b)) -> s_nand_b32(a, b)
2539  * s_not_b32(s_or_b32(a, b)) -> s_nor_b32(a, b)
2540  * s_not_b32(s_xor_b32(a, b)) -> s_xnor_b32(a, b)
2541  * s_not_b64(s_and_b64(a, b)) -> s_nand_b64(a, b)
2542  * s_not_b64(s_or_b64(a, b)) -> s_nor_b64(a, b)
2543  * s_not_b64(s_xor_b64(a, b)) -> s_xnor_b64(a, b) */
2544 bool
2545 combine_salu_not_bitwise(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2546 {
2547    /* checks */
2548    if (!instr->operands[0].isTemp())
2549       return false;
2550    if (instr->definitions[1].isTemp() && ctx.uses[instr->definitions[1].tempId()])
2551       return false;
2552 
2553    Instruction* op2_instr = follow_operand(ctx, instr->operands[0]);
2554    if (!op2_instr)
2555       return false;
2556    switch (op2_instr->opcode) {
2557    case aco_opcode::s_and_b32:
2558    case aco_opcode::s_or_b32:
2559    case aco_opcode::s_xor_b32:
2560    case aco_opcode::s_and_b64:
2561    case aco_opcode::s_or_b64:
2562    case aco_opcode::s_xor_b64: break;
2563    default: return false;
2564    }
2565 
2566    /* create instruction */
2567    std::swap(instr->definitions[0], op2_instr->definitions[0]);
2568    std::swap(instr->definitions[1], op2_instr->definitions[1]);
2569    ctx.uses[instr->operands[0].tempId()]--;
2570    ctx.info[op2_instr->definitions[0].tempId()].label = 0;
2571 
2572    switch (op2_instr->opcode) {
2573    case aco_opcode::s_and_b32: op2_instr->opcode = aco_opcode::s_nand_b32; break;
2574    case aco_opcode::s_or_b32: op2_instr->opcode = aco_opcode::s_nor_b32; break;
2575    case aco_opcode::s_xor_b32: op2_instr->opcode = aco_opcode::s_xnor_b32; break;
2576    case aco_opcode::s_and_b64: op2_instr->opcode = aco_opcode::s_nand_b64; break;
2577    case aco_opcode::s_or_b64: op2_instr->opcode = aco_opcode::s_nor_b64; break;
2578    case aco_opcode::s_xor_b64: op2_instr->opcode = aco_opcode::s_xnor_b64; break;
2579    default: break;
2580    }
2581 
2582    return true;
2583 }
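
/* Illustrative sketch (editor's addition): s_nand/s_nor/s_xnor are the negated forms by
 * definition; De Morgan gives the equivalent expanded form, shown with sample values: */
static_assert((~(0xf0f0u & 0x0ff0u)) == ((~0xf0f0u) | (~0x0ff0u)),
              "s_not(s_and(a, b)) == s_nand(a, b) sketch");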
2584 
2585 /* s_and_b32(a, s_not_b32(b)) -> s_andn2_b32(a, b)
2586  * s_or_b32(a, s_not_b32(b)) -> s_orn2_b32(a, b)
2587  * s_and_b64(a, s_not_b64(b)) -> s_andn2_b64(a, b)
2588  * s_or_b64(a, s_not_b64(b)) -> s_orn2_b64(a, b) */
2589 bool
2590 combine_salu_n2(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2591 {
2592    if (instr->definitions[0].isTemp() && ctx.info[instr->definitions[0].tempId()].is_uniform_bool())
2593       return false;
2594 
2595    for (unsigned i = 0; i < 2; i++) {
2596       Instruction* op2_instr = follow_operand(ctx, instr->operands[i]);
2597       if (!op2_instr || (op2_instr->opcode != aco_opcode::s_not_b32 &&
2598                          op2_instr->opcode != aco_opcode::s_not_b64))
2599          continue;
2600       if (ctx.uses[op2_instr->definitions[1].tempId()])
2601          continue;
2602 
2603       if (instr->operands[!i].isLiteral() && op2_instr->operands[0].isLiteral() &&
2604           instr->operands[!i].constantValue() != op2_instr->operands[0].constantValue())
2605          continue;
2606 
2607       ctx.uses[instr->operands[i].tempId()]--;
2608       instr->operands[0] = instr->operands[!i];
2609       instr->operands[1] = op2_instr->operands[0];
2610       ctx.info[instr->definitions[0].tempId()].label = 0;
2611 
2612       switch (instr->opcode) {
2613       case aco_opcode::s_and_b32: instr->opcode = aco_opcode::s_andn2_b32; break;
2614       case aco_opcode::s_or_b32: instr->opcode = aco_opcode::s_orn2_b32; break;
2615       case aco_opcode::s_and_b64: instr->opcode = aco_opcode::s_andn2_b64; break;
2616       case aco_opcode::s_or_b64: instr->opcode = aco_opcode::s_orn2_b64; break;
2617       default: break;
2618       }
2619 
2620       return true;
2621    }
2622    return false;
2623 }
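
/* Illustrative sketch (editor's addition): andn2/orn2 negate only their second source,
 * matching the and/or-of-not pattern above; with sample values: */
static_assert((0xff00ff00u & ~0x0f0f0f0fu) == 0xf000f000u,
              "s_and(a, s_not(b)) == s_andn2(a, b) sketch");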
2624 
2625 /* s_add_{i32,u32}(a, s_lshl_b32(b, <n>)) -> s_lshl<n>_add_u32(a, b) */
2626 bool
2627 combine_salu_lshl_add(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2628 {
2629    if (instr->opcode == aco_opcode::s_add_i32 && ctx.uses[instr->definitions[1].tempId()])
2630       return false;
2631 
2632    for (unsigned i = 0; i < 2; i++) {
2633       Instruction* op2_instr = follow_operand(ctx, instr->operands[i], true);
2634       if (!op2_instr || op2_instr->opcode != aco_opcode::s_lshl_b32 ||
2635           ctx.uses[op2_instr->definitions[1].tempId()])
2636          continue;
2637       if (!op2_instr->operands[1].isConstant())
2638          continue;
2639 
2640       uint32_t shift = op2_instr->operands[1].constantValue();
2641       if (shift < 1 || shift > 4)
2642          continue;
2643 
2644       if (instr->operands[!i].isLiteral() && op2_instr->operands[0].isLiteral() &&
2645           instr->operands[!i].constantValue() != op2_instr->operands[0].constantValue())
2646          continue;
2647 
2648       instr->operands[1] = instr->operands[!i];
2649       instr->operands[0] = copy_operand(ctx, op2_instr->operands[0]);
2650       decrease_uses(ctx, op2_instr);
2651       ctx.info[instr->definitions[0].tempId()].label = 0;
2652 
2653       instr->opcode = std::array<aco_opcode, 4>{
2654          aco_opcode::s_lshl1_add_u32, aco_opcode::s_lshl2_add_u32, aco_opcode::s_lshl3_add_u32,
2655          aco_opcode::s_lshl4_add_u32}[shift - 1];
2656 
2657       return true;
2658    }
2659    return false;
2660 }
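
/* Illustrative sketch (editor's addition): s_lshl<n>_add_u32 computes (b << n) + a, so the
 * fold is exact for shift amounts 1..4; e.g. with n = 2: */
static_assert(10u + (3u << 2) == 22u,
              "s_add(a, s_lshl(b, 2)) == s_lshl2_add_u32(b, a) sketch");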
2661 
2662 /* s_abs_i32(s_sub_[iu]32(a, b)) -> s_absdiff_i32(a, b)
2663  * s_abs_i32(s_add_[iu]32(a, #b)) -> s_absdiff_i32(a, -b)
2664  */
2665 bool
2666 combine_sabsdiff(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2667 {
2668    if (!instr->operands[0].isTemp() || !ctx.info[instr->operands[0].tempId()].is_add_sub())
2669       return false;
2670 
2671    Instruction* op_instr = follow_operand(ctx, instr->operands[0], false);
2672    if (!op_instr)
2673       return false;
2674 
2675    if (op_instr->opcode == aco_opcode::s_add_i32 || op_instr->opcode == aco_opcode::s_add_u32) {
2676       for (unsigned i = 0; i < 2; i++) {
2677          uint64_t constant;
2678          if (op_instr->operands[!i].isLiteral() ||
2679              !is_operand_constant(ctx, op_instr->operands[i], 32, &constant))
2680             continue;
2681 
2682          if (op_instr->operands[i].isTemp())
2683             ctx.uses[op_instr->operands[i].tempId()]--;
2684          op_instr->operands[0] = op_instr->operands[!i];
2685          op_instr->operands[1] = Operand::c32(-int32_t(constant));
2686          goto use_absdiff;
2687       }
2688       return false;
2689    }
2690 
2691 use_absdiff:
2692    op_instr->opcode = aco_opcode::s_absdiff_i32;
2693    std::swap(instr->definitions[0], op_instr->definitions[0]);
2694    std::swap(instr->definitions[1], op_instr->definitions[1]);
2695    ctx.uses[instr->operands[0].tempId()]--;
2696    ctx.info[op_instr->definitions[0].tempId()].label = 0;
2697 
2698    return true;
2699 }
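
/* Illustrative sketch (editor's addition): s_absdiff_i32 computes |a - b|, so taking the
 * absolute value of a subtraction (or of an addition of a negated constant) is equivalent: */
static_assert(((7 - 19) < 0 ? -(7 - 19) : (7 - 19)) == 12,
              "s_abs(s_sub(a, b)) == s_absdiff(a, b) sketch");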
2700 
2701 bool
2702 combine_add_sub_b2i(opt_ctx& ctx, aco_ptr<Instruction>& instr, aco_opcode new_op, uint8_t ops)
2703 {
2704    if (instr->usesModifiers())
2705       return false;
2706 
2707    for (unsigned i = 0; i < 2; i++) {
2708       if (!((1 << i) & ops))
2709          continue;
2710       if (instr->operands[i].isTemp() && ctx.info[instr->operands[i].tempId()].is_b2i() &&
2711           ctx.uses[instr->operands[i].tempId()] == 1) {
2712 
2713          aco_ptr<Instruction> new_instr;
2714          if (instr->operands[!i].isTemp() &&
2715              instr->operands[!i].getTemp().type() == RegType::vgpr) {
2716             new_instr.reset(create_instruction(new_op, Format::VOP2, 3, 2));
2717          } else if (ctx.program->gfx_level >= GFX10 ||
2718                     (instr->operands[!i].isConstant() && !instr->operands[!i].isLiteral())) {
2719             new_instr.reset(create_instruction(new_op, asVOP3(Format::VOP2), 3, 2));
2720          } else {
2721             return false;
2722          }
2723          ctx.uses[instr->operands[i].tempId()]--;
2724          new_instr->definitions[0] = instr->definitions[0];
2725          if (instr->definitions.size() == 2) {
2726             new_instr->definitions[1] = instr->definitions[1];
2727          } else {
2728             new_instr->definitions[1] =
2729                Definition(ctx.program->allocateTmp(ctx.program->lane_mask));
2730             /* Make sure the uses vector is large enough and the number of
2731              * uses is properly initialized to 0.
2732              */
2733             ctx.uses.push_back(0);
2734             ctx.info.push_back(ssa_info{});
2735          }
2736          new_instr->operands[0] = Operand::zero();
2737          new_instr->operands[1] = instr->operands[!i];
2738          new_instr->operands[2] = Operand(ctx.info[instr->operands[i].tempId()].temp);
2739          new_instr->pass_flags = instr->pass_flags;
2740          instr = std::move(new_instr);
2741          ctx.info[instr->definitions[0].tempId()].set_add_sub(instr.get());
2742          return true;
2743       }
2744    }
2745 
2746    return false;
2747 }
2748 
2749 bool
2750 combine_add_bcnt(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2751 {
2752    if (instr->usesModifiers())
2753       return false;
2754 
2755    for (unsigned i = 0; i < 2; i++) {
2756       Instruction* op_instr = follow_operand(ctx, instr->operands[i]);
2757       if (op_instr && op_instr->opcode == aco_opcode::v_bcnt_u32_b32 &&
2758           !op_instr->usesModifiers() && op_instr->operands[0].isTemp() &&
2759           op_instr->operands[0].getTemp().type() == RegType::vgpr &&
2760           op_instr->operands[1].constantEquals(0)) {
2761          aco_ptr<Instruction> new_instr{
2762             create_instruction(aco_opcode::v_bcnt_u32_b32, Format::VOP3, 2, 1)};
2763          ctx.uses[instr->operands[i].tempId()]--;
2764          new_instr->operands[0] = op_instr->operands[0];
2765          new_instr->operands[1] = instr->operands[!i];
2766          new_instr->definitions[0] = instr->definitions[0];
2767          new_instr->pass_flags = instr->pass_flags;
2768          instr = std::move(new_instr);
2769          ctx.info[instr->definitions[0].tempId()].label = 0;
2770 
2771          return true;
2772       }
2773    }
2774 
2775    return false;
2776 }
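
/* Illustrative sketch (editor's addition): v_bcnt_u32_b32(a, b) computes popcount(a) + b,
 * so an add of the accumulator can be folded into the second operand. example_popcount is
 * a hypothetical stand-in for the hardware population count. */
constexpr unsigned
example_popcount(uint32_t v)
{
   return v ? (v & 1u) + example_popcount(v >> 1) : 0u;
}
static_assert((example_popcount(0xf0u) + 0u) + 3u == 7u,
              "v_add(v_bcnt(a, 0), b) == v_bcnt(a, b) sketch");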
2777 
2778 bool
2779 get_minmax_info(aco_opcode op, aco_opcode* min, aco_opcode* max, aco_opcode* min3, aco_opcode* max3,
2780                 aco_opcode* med3, aco_opcode* minmax, bool* some_gfx9_only)
2781 {
2782    switch (op) {
2783 #define MINMAX(type, gfx9)                                                                         \
2784    case aco_opcode::v_min_##type:                                                                  \
2785    case aco_opcode::v_max_##type:                                                                  \
2786       *min = aco_opcode::v_min_##type;                                                             \
2787       *max = aco_opcode::v_max_##type;                                                             \
2788       *med3 = aco_opcode::v_med3_##type;                                                           \
2789       *min3 = aco_opcode::v_min3_##type;                                                           \
2790       *max3 = aco_opcode::v_max3_##type;                                                           \
2791       *minmax = op == *min ? aco_opcode::v_maxmin_##type : aco_opcode::v_minmax_##type;            \
2792       *some_gfx9_only = gfx9;                                                                      \
2793       return true;
2794 #define MINMAX_INT16(type, gfx9)                                                                   \
2795    case aco_opcode::v_min_##type:                                                                  \
2796    case aco_opcode::v_max_##type:                                                                  \
2797       *min = aco_opcode::v_min_##type;                                                             \
2798       *max = aco_opcode::v_max_##type;                                                             \
2799       *med3 = aco_opcode::v_med3_##type;                                                           \
2800       *min3 = aco_opcode::v_min3_##type;                                                           \
2801       *max3 = aco_opcode::v_max3_##type;                                                           \
2802       *minmax = aco_opcode::num_opcodes;                                                           \
2803       *some_gfx9_only = gfx9;                                                                      \
2804       return true;
2805 #define MINMAX_INT16_E64(type, gfx9)                                                               \
2806    case aco_opcode::v_min_##type##_e64:                                                            \
2807    case aco_opcode::v_max_##type##_e64:                                                            \
2808       *min = aco_opcode::v_min_##type##_e64;                                                       \
2809       *max = aco_opcode::v_max_##type##_e64;                                                       \
2810       *med3 = aco_opcode::v_med3_##type;                                                           \
2811       *min3 = aco_opcode::v_min3_##type;                                                           \
2812       *max3 = aco_opcode::v_max3_##type;                                                           \
2813       *minmax = aco_opcode::num_opcodes;                                                           \
2814       *some_gfx9_only = gfx9;                                                                      \
2815       return true;
2816       MINMAX(f32, false)
2817       MINMAX(u32, false)
2818       MINMAX(i32, false)
2819       MINMAX(f16, true)
2820       MINMAX_INT16(u16, true)
2821       MINMAX_INT16(i16, true)
2822       MINMAX_INT16_E64(u16, true)
2823       MINMAX_INT16_E64(i16, true)
2824 #undef MINMAX_INT16_E64
2825 #undef MINMAX_INT16
2826 #undef MINMAX
2827    default: return false;
2828    }
2829 }
2830 
2831 /* when ub > lb:
2832  * v_min_{f,u,i}{16,32}(v_max_{f,u,i}{16,32}(a, lb), ub) -> v_med3_{f,u,i}{16,32}(a, lb, ub)
2833  * v_max_{f,u,i}{16,32}(v_min_{f,u,i}{16,32}(a, ub), lb) -> v_med3_{f,u,i}{16,32}(a, lb, ub)
2834  */
2835 bool
2836 combine_clamp(opt_ctx& ctx, aco_ptr<Instruction>& instr, aco_opcode min, aco_opcode max,
2837               aco_opcode med)
2838 {
2839    /* TODO: GLSL's clamp(x, minVal, maxVal) and SPIR-V's
2840     * FClamp(x, minVal, maxVal)/NClamp(x, minVal, maxVal) are undefined if
2841     * minVal > maxVal, which means we can always select it to a v_med3_f32 */
2842    aco_opcode other_op;
2843    if (instr->opcode == min)
2844       other_op = max;
2845    else if (instr->opcode == max)
2846       other_op = min;
2847    else
2848       return false;
2849 
2850    for (unsigned swap = 0; swap < 2; swap++) {
2851       Operand operands[3];
2852       bool clamp, precise;
2853       bitarray8 opsel = 0, neg = 0, abs = 0;
2854       uint8_t omod = 0;
2855       if (match_op3_for_vop3(ctx, instr->opcode, other_op, instr.get(), swap, "012", operands, neg,
2856                              abs, opsel, &clamp, &omod, NULL, NULL, NULL, &precise)) {
2857          /* max(min(src, upper), lower) returns upper if src is NaN, but
2858           * med3(src, lower, upper) returns lower.
2859           */
2860          if (precise && instr->opcode != min &&
2861              (min == aco_opcode::v_min_f16 || min == aco_opcode::v_min_f32))
2862             continue;
2863 
2864          int const0_idx = -1, const1_idx = -1;
2865          uint32_t const0 = 0, const1 = 0;
2866          for (int i = 0; i < 3; i++) {
2867             uint32_t val;
2868             bool hi16 = opsel & (1 << i);
2869             if (operands[i].isConstant()) {
2870                val = hi16 ? operands[i].constantValue16(true) : operands[i].constantValue();
2871             } else if (operands[i].isTemp() &&
2872                        ctx.info[operands[i].tempId()].is_constant_or_literal(32)) {
2873                val = ctx.info[operands[i].tempId()].val >> (hi16 ? 16 : 0);
2874             } else {
2875                continue;
2876             }
2877             if (const0_idx >= 0) {
2878                const1_idx = i;
2879                const1 = val;
2880             } else {
2881                const0_idx = i;
2882                const0 = val;
2883             }
2884          }
2885          if (const0_idx < 0 || const1_idx < 0)
2886             continue;
2887 
2888          int lower_idx = const0_idx;
2889          switch (min) {
2890          case aco_opcode::v_min_f32:
2891          case aco_opcode::v_min_f16: {
2892             float const0_f, const1_f;
2893             if (min == aco_opcode::v_min_f32) {
2894                memcpy(&const0_f, &const0, 4);
2895                memcpy(&const1_f, &const1, 4);
2896             } else {
2897                const0_f = _mesa_half_to_float(const0);
2898                const1_f = _mesa_half_to_float(const1);
2899             }
2900             if (abs[const0_idx])
2901                const0_f = fabsf(const0_f);
2902             if (abs[const1_idx])
2903                const1_f = fabsf(const1_f);
2904             if (neg[const0_idx])
2905                const0_f = -const0_f;
2906             if (neg[const1_idx])
2907                const1_f = -const1_f;
2908             lower_idx = const0_f < const1_f ? const0_idx : const1_idx;
2909             break;
2910          }
2911          case aco_opcode::v_min_u32: {
2912             lower_idx = const0 < const1 ? const0_idx : const1_idx;
2913             break;
2914          }
2915          case aco_opcode::v_min_u16:
2916          case aco_opcode::v_min_u16_e64: {
2917             lower_idx = (uint16_t)const0 < (uint16_t)const1 ? const0_idx : const1_idx;
2918             break;
2919          }
2920          case aco_opcode::v_min_i32: {
2921             int32_t const0_i =
2922                const0 & 0x80000000u ? -2147483648 + (int32_t)(const0 & 0x7fffffffu) : const0;
2923             int32_t const1_i =
2924                const1 & 0x80000000u ? -2147483648 + (int32_t)(const1 & 0x7fffffffu) : const1;
2925             lower_idx = const0_i < const1_i ? const0_idx : const1_idx;
2926             break;
2927          }
2928          case aco_opcode::v_min_i16:
2929          case aco_opcode::v_min_i16_e64: {
2930             int16_t const0_i = const0 & 0x8000u ? -32768 + (int16_t)(const0 & 0x7fffu) : const0;
2931             int16_t const1_i = const1 & 0x8000u ? -32768 + (int16_t)(const1 & 0x7fffu) : const1;
2932             lower_idx = const0_i < const1_i ? const0_idx : const1_idx;
2933             break;
2934          }
2935          default: break;
2936          }
2937          int upper_idx = lower_idx == const0_idx ? const1_idx : const0_idx;
2938 
2939          if (instr->opcode == min) {
2940             if (upper_idx != 0 || lower_idx == 0)
2941                return false;
2942          } else {
2943             if (upper_idx == 0 || lower_idx != 0)
2944                return false;
2945          }
2946 
2947          ctx.uses[instr->operands[swap].tempId()]--;
2948          create_vop3_for_op3(ctx, med, instr, operands, neg, abs, opsel, clamp, omod);
2949 
2950          return true;
2951       }
2952    }
2953 
2954    return false;
2955 }
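
/* Illustrative sketch (editor's addition): when lb <= ub, both clamp spellings pick the
 * median of {x, lb, ub}, which is what v_med3 computes; with x = 9, lb = 2, ub = 5: */
static_assert(std::min(std::max(9, 2), 5) == 5 && std::max(std::min(9, 5), 2) == 5,
              "min(max(x, lb), ub) == med3(x, lb, ub) sketch");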
2956 
2957 void
2958 apply_sgprs(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2959 {
2960    bool is_shift64 = instr->opcode == aco_opcode::v_lshlrev_b64_e64 ||
2961                      instr->opcode == aco_opcode::v_lshlrev_b64 ||
2962                      instr->opcode == aco_opcode::v_lshrrev_b64 ||
2963                      instr->opcode == aco_opcode::v_ashrrev_i64;
2964 
2965    /* find candidates and create the set of sgprs already read */
2966    unsigned sgpr_ids[2] = {0, 0};
2967    uint32_t operand_mask = 0;
2968    bool has_literal = false;
2969    for (unsigned i = 0; i < instr->operands.size(); i++) {
2970       if (instr->operands[i].isLiteral())
2971          has_literal = true;
2972       if (!instr->operands[i].isTemp())
2973          continue;
2974       if (instr->operands[i].getTemp().type() == RegType::sgpr) {
2975          if (instr->operands[i].tempId() != sgpr_ids[0])
2976             sgpr_ids[!!sgpr_ids[0]] = instr->operands[i].tempId();
2977       }
2978       ssa_info& info = ctx.info[instr->operands[i].tempId()];
2979       if (is_copy_label(ctx, instr, info, i) && info.temp.type() == RegType::sgpr)
2980          operand_mask |= 1u << i;
2981       if (info.is_extract() && info.instr->operands[0].getTemp().type() == RegType::sgpr)
2982          operand_mask |= 1u << i;
2983    }
2984    unsigned max_sgprs = 1;
2985    if (ctx.program->gfx_level >= GFX10 && !is_shift64)
2986       max_sgprs = 2;
2987    if (has_literal)
2988       max_sgprs--;
2989 
2990    unsigned num_sgprs = !!sgpr_ids[0] + !!sgpr_ids[1];
2991 
2992    /* keep on applying sgprs until there is nothing left to be done */
2993    while (operand_mask) {
2994       uint32_t sgpr_idx = 0;
2995       uint32_t sgpr_info_id = 0;
2996       uint32_t mask = operand_mask;
2997       /* choose a sgpr */
2998       while (mask) {
2999          unsigned i = u_bit_scan(&mask);
3000          uint16_t uses = ctx.uses[instr->operands[i].tempId()];
3001          if (sgpr_info_id == 0 || uses < ctx.uses[sgpr_info_id]) {
3002             sgpr_idx = i;
3003             sgpr_info_id = instr->operands[i].tempId();
3004          }
3005       }
3006       operand_mask &= ~(1u << sgpr_idx);
3007 
3008       ssa_info& info = ctx.info[sgpr_info_id];
3009 
3010       Temp sgpr = info.is_extract() ? info.instr->operands[0].getTemp() : info.temp;
3011       bool new_sgpr = sgpr.id() != sgpr_ids[0] && sgpr.id() != sgpr_ids[1];
3012       if (new_sgpr && num_sgprs >= max_sgprs)
3013          continue;
3014 
3015       if (sgpr_idx == 0)
3016          instr->format = withoutDPP(instr->format);
3017 
3018       if (sgpr_idx == 1 && instr->isDPP())
3019          continue;
3020 
3021       if (sgpr_idx == 0 || instr->isVOP3() || instr->isSDWA() || instr->isVOP3P() ||
3022           info.is_extract()) {
3023          /* can_apply_extract() checks SGPR encoding restrictions */
3024          if (info.is_extract() && can_apply_extract(ctx, instr, sgpr_idx, info))
3025             apply_extract(ctx, instr, sgpr_idx, info);
3026          else if (info.is_extract())
3027             continue;
3028          instr->operands[sgpr_idx] = Operand(sgpr);
3029       } else if (can_swap_operands(instr, &instr->opcode) && !instr->valu().opsel[sgpr_idx]) {
3030          instr->operands[sgpr_idx] = instr->operands[0];
3031          instr->operands[0] = Operand(sgpr);
3032          instr->valu().opsel[0].swap(instr->valu().opsel[sgpr_idx]);
3033          /* swap bits using a 4-entry LUT */
3034          uint32_t swapped = (0x3120 >> (operand_mask & 0x3)) & 0xf;
3035          operand_mask = (operand_mask & ~0x3) | swapped;
3036       } else if (can_use_VOP3(ctx, instr) && !info.is_extract()) {
3037          instr->format = asVOP3(instr->format);
3038          instr->operands[sgpr_idx] = Operand(sgpr);
3039       } else {
3040          continue;
3041       }
3042 
3043       if (new_sgpr)
3044          sgpr_ids[num_sgprs++] = sgpr.id();
3045       ctx.uses[sgpr_info_id]--;
3046       ctx.uses[sgpr.id()]++;
3047 
3048       /* TODO: handle when it's a VGPR */
3049       if ((ctx.info[sgpr.id()].label & (label_extract | label_temp)) &&
3050           ctx.info[sgpr.id()].temp.type() == RegType::sgpr)
3051          operand_mask |= 1u << sgpr_idx;
3052    }
3053 }
3054 
3055 bool
3056 interp_can_become_fma(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3057 {
3058    if (instr->opcode != aco_opcode::v_interp_p2_f32_inreg)
3059       return false;
3060 
3061    instr->opcode = aco_opcode::v_fma_f32;
3062    instr->format = Format::VOP3;
3063    bool dpp_allowed = can_use_DPP(ctx.program->gfx_level, instr, false);
3064    instr->opcode = aco_opcode::v_interp_p2_f32_inreg;
3065    instr->format = Format::VINTERP_INREG;
3066 
3067    return dpp_allowed;
3068 }
3069 
3070 void
3071 interp_p2_f32_inreg_to_fma_dpp(aco_ptr<Instruction>& instr)
3072 {
3073    static_assert(sizeof(DPP16_instruction) == sizeof(VINTERP_inreg_instruction),
3074                  "Invalid instr cast.");
3075    instr->format = asVOP3(Format::DPP16);
3076    instr->opcode = aco_opcode::v_fma_f32;
3077    instr->dpp16().dpp_ctrl = dpp_quad_perm(2, 2, 2, 2);
3078    instr->dpp16().row_mask = 0xf;
3079    instr->dpp16().bank_mask = 0xf;
3080    instr->dpp16().bound_ctrl = 0;
3081    instr->dpp16().fetch_inactive = 1;
3082 }
3083 
3084 /* apply omod / clamp modifiers if the def is used only once and the instruction can have modifiers */
3085 bool
3086 apply_omod_clamp(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3087 {
3088    if (instr->definitions.empty() || ctx.uses[instr->definitions[0].tempId()] != 1 ||
3089        !instr_info.can_use_output_modifiers[(int)instr->opcode])
3090       return false;
3091 
3092    bool can_vop3 = can_use_VOP3(ctx, instr);
3093    bool is_mad_mix =
3094       instr->opcode == aco_opcode::v_fma_mix_f32 || instr->opcode == aco_opcode::v_fma_mixlo_f16;
3095    bool needs_vop3 = !instr->isSDWA() && !instr->isVINTERP_INREG() && !is_mad_mix;
3096    if (needs_vop3 && !can_vop3)
3097       return false;
3098 
3099    /* SDWA omod is GFX9+. */
3100    bool can_use_omod = (can_vop3 || ctx.program->gfx_level >= GFX9) && !instr->isVOP3P() &&
3101                        (!instr->isVINTERP_INREG() || interp_can_become_fma(ctx, instr));
3102 
3103    ssa_info& def_info = ctx.info[instr->definitions[0].tempId()];
3104 
3105    uint64_t omod_labels = label_omod2 | label_omod4 | label_omod5;
3106    if (!def_info.is_clamp() && !(can_use_omod && (def_info.label & omod_labels)))
3107       return false;
3108    /* if the omod/clamp instruction is dead, then the single user of this
3109     * instruction is a different instruction */
3110    if (!ctx.uses[def_info.instr->definitions[0].tempId()])
3111       return false;
3112 
3113    if (def_info.instr->definitions[0].bytes() != instr->definitions[0].bytes())
3114       return false;
3115 
3116    /* MADs/FMAs are created later, so we don't have to update the original add */
3117    assert(!ctx.info[instr->definitions[0].tempId()].is_mad());
3118 
3119    if (!def_info.is_clamp() && (instr->valu().clamp || instr->valu().omod))
3120       return false;
3121 
3122    if (needs_vop3)
3123       instr->format = asVOP3(instr->format);
3124 
3125    if (!def_info.is_clamp() && instr->opcode == aco_opcode::v_interp_p2_f32_inreg)
3126       interp_p2_f32_inreg_to_fma_dpp(instr);
3127 
3128    if (def_info.is_omod2())
3129       instr->valu().omod = 1;
3130    else if (def_info.is_omod4())
3131       instr->valu().omod = 2;
3132    else if (def_info.is_omod5())
3133       instr->valu().omod = 3;
3134    else if (def_info.is_clamp())
3135       instr->valu().clamp = true;
3136 
3137    instr->definitions[0].swapTemp(def_info.instr->definitions[0]);
3138    ctx.info[instr->definitions[0].tempId()].label &= label_clamp | label_insert | label_f2f16;
3139    ctx.uses[def_info.instr->definitions[0].tempId()]--;
3140 
3141    return true;
3142 }
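
/* Illustrative sketch (editor's addition): the output modifier scales the result by 2, 4 or
 * 0.5 (omod = 1, 2, 3 respectively), which is why multiplies by those constants can be
 * folded into the producing instruction; exactly representable sample values: */
static_assert(1.5f * 2.0f == 3.0f && 1.5f * 4.0f == 6.0f && 1.5f * 0.5f == 0.75f,
              "omod2/omod4/omod5 scaling sketch");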
3143 
3144 /* Combine a p_insert (or, in some cases, a p_extract) instruction with instr.
3145  * p_insert(instr(...)) -> instr_insert().
3146  */
3147 bool
3148 apply_insert(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3149 {
3150    if (instr->definitions.empty() || ctx.uses[instr->definitions[0].tempId()] != 1)
3151       return false;
3152 
3153    ssa_info& def_info = ctx.info[instr->definitions[0].tempId()];
3154    if (!def_info.is_insert())
3155       return false;
3156    /* if the insert instruction is dead, then the single user of this
3157     * instruction is a different instruction */
3158    if (!ctx.uses[def_info.instr->definitions[0].tempId()])
3159       return false;
3160 
3161    /* MADs/FMAs are created later, so we don't have to update the original add */
3162    assert(!ctx.info[instr->definitions[0].tempId()].is_mad());
3163 
3164    SubdwordSel sel = parse_insert(def_info.instr);
3165    assert(sel);
3166 
3167    if (!can_use_SDWA(ctx.program->gfx_level, instr, true))
3168       return false;
3169 
3170    convert_to_SDWA(ctx.program->gfx_level, instr);
3171    if (instr->sdwa().dst_sel.size() != 4)
3172       return false;
3173    instr->sdwa().dst_sel = sel;
3174 
3175    instr->definitions[0].swapTemp(def_info.instr->definitions[0]);
3176    ctx.info[instr->definitions[0].tempId()].label = 0;
3177    ctx.uses[def_info.instr->definitions[0].tempId()]--;
3178 
3179    return true;
3180 }
3181 
3182 /* Remove superfluous extract after ds_read like so:
3183  * p_extract(ds_read_uN(), 0, N, 0) -> ds_read_uN()
3184  */
3185 bool
3186 apply_ds_extract(opt_ctx& ctx, aco_ptr<Instruction>& extract)
3187 {
3188    /* Check if p_extract has a usedef operand and is the only user. */
3189    if (!ctx.info[extract->operands[0].tempId()].is_usedef() ||
3190        ctx.uses[extract->operands[0].tempId()] > 1)
3191       return false;
3192 
3193    /* Check if the usedef is a DS instruction. */
3194    Instruction* ds = ctx.info[extract->operands[0].tempId()].instr;
3195    if (ds->format != Format::DS)
3196       return false;
3197 
3198    unsigned extract_idx = extract->operands[1].constantValue();
3199    unsigned bits_extracted = extract->operands[2].constantValue();
3200    unsigned sign_ext = extract->operands[3].constantValue();
3201    unsigned dst_bitsize = extract->definitions[0].bytes() * 8u;
3202 
3203    /* TODO: These are doable, but probably don't occur too often. */
3204    if (extract_idx || sign_ext || dst_bitsize != 32)
3205       return false;
3206 
3207    unsigned bits_loaded = 0;
3208    if (ds->opcode == aco_opcode::ds_read_u8 || ds->opcode == aco_opcode::ds_read_u8_d16)
3209       bits_loaded = 8;
3210    else if (ds->opcode == aco_opcode::ds_read_u16 || ds->opcode == aco_opcode::ds_read_u16_d16)
3211       bits_loaded = 16;
3212    else
3213       return false;
3214 
3215    /* Shrink the DS load if the extracted bit size is smaller. */
3216    bits_loaded = MIN2(bits_loaded, bits_extracted);
3217 
3218    /* Change the DS opcode so it writes the full register. */
3219    if (bits_loaded == 8)
3220       ds->opcode = aco_opcode::ds_read_u8;
3221    else if (bits_loaded == 16)
3222       ds->opcode = aco_opcode::ds_read_u16;
3223    else
3224       unreachable("Forgot to add DS opcode above.");
3225 
3226    /* The DS now produces the exact same thing as the extract, remove the extract. */
3227    std::swap(ds->definitions[0], extract->definitions[0]);
3228    ctx.uses[extract->definitions[0].tempId()] = 0;
3229    ctx.info[ds->definitions[0].tempId()].label = 0;
3230    return true;
3231 }
3232 
3233 /* v_and(a, v_subbrev_co(0, 0, vcc)) -> v_cndmask(0, a, vcc) */
3234 bool
3235 combine_and_subbrev(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3236 {
3237    if (instr->usesModifiers())
3238       return false;
3239 
3240    for (unsigned i = 0; i < 2; i++) {
3241       Instruction* op_instr = follow_operand(ctx, instr->operands[i], true);
3242       if (op_instr && op_instr->opcode == aco_opcode::v_subbrev_co_u32 &&
3243           op_instr->operands[0].constantEquals(0) && op_instr->operands[1].constantEquals(0) &&
3244           !op_instr->usesModifiers()) {
3245 
3246          aco_ptr<Instruction> new_instr;
3247          if (instr->operands[!i].isTemp() &&
3248              instr->operands[!i].getTemp().type() == RegType::vgpr) {
3249             new_instr.reset(create_instruction(aco_opcode::v_cndmask_b32, Format::VOP2, 3, 1));
3250          } else if (ctx.program->gfx_level >= GFX10 ||
3251                     (instr->operands[!i].isConstant() && !instr->operands[!i].isLiteral())) {
3252             new_instr.reset(
3253                create_instruction(aco_opcode::v_cndmask_b32, asVOP3(Format::VOP2), 3, 1));
3254          } else {
3255             return false;
3256          }
3257 
3258          new_instr->operands[0] = Operand::zero();
3259          new_instr->operands[1] = instr->operands[!i];
3260          new_instr->operands[2] = copy_operand(ctx, op_instr->operands[2]);
3261          new_instr->definitions[0] = instr->definitions[0];
3262          new_instr->pass_flags = instr->pass_flags;
3263          instr = std::move(new_instr);
3264          decrease_uses(ctx, op_instr);
3265          ctx.info[instr->definitions[0].tempId()].label = 0;
3266          return true;
3267       }
3268    }
3269 
3270    return false;
3271 }
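
/* Illustrative sketch (editor's addition): v_subbrev_co_u32(0, 0, cc) evaluates 0 - 0 - cc,
 * i.e. all-ones when the carry-in is set and zero otherwise, so ANDing with it selects
 * between a and 0 just like the v_cndmask form above; arbitrary sample value: */
static_assert((0xdeadbeefu & (0u - 0u - 1u)) == 0xdeadbeefu &&
                 (0xdeadbeefu & (0u - 0u - 0u)) == 0u,
              "v_and(a, v_subbrev_co(0, 0, cc)) == v_cndmask(0, a, cc) sketch");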
3272 
3273 /* v_and(a, not(b)) -> v_bfi_b32(b, 0, a)
3274  * v_or(a, not(b)) -> v_bfi_b32(b, a, -1)
3275  */
3276 bool
3277 combine_v_andor_not(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3278 {
3279    if (instr->usesModifiers())
3280       return false;
3281 
3282    for (unsigned i = 0; i < 2; i++) {
3283       Instruction* op_instr = follow_operand(ctx, instr->operands[i], true);
3284       if (op_instr && !op_instr->usesModifiers() &&
3285           (op_instr->opcode == aco_opcode::v_not_b32 ||
3286            op_instr->opcode == aco_opcode::s_not_b32)) {
3287 
3288          Operand ops[3] = {
3289             op_instr->operands[0],
3290             Operand::zero(),
3291             instr->operands[!i],
3292          };
3293          if (instr->opcode == aco_opcode::v_or_b32) {
3294             ops[1] = instr->operands[!i];
3295             ops[2] = Operand::c32(-1);
3296          }
3297          if (!check_vop3_operands(ctx, 3, ops))
3298             continue;
3299 
3300          Instruction* new_instr = create_instruction(aco_opcode::v_bfi_b32, Format::VOP3, 3, 1);
3301 
3302          if (op_instr->operands[0].isTemp())
3303             ctx.uses[op_instr->operands[0].tempId()]++;
3304          for (unsigned j = 0; j < 3; j++)
3305             new_instr->operands[j] = ops[j];
3306          new_instr->definitions[0] = instr->definitions[0];
3307          new_instr->pass_flags = instr->pass_flags;
3308          instr.reset(new_instr);
3309          decrease_uses(ctx, op_instr);
3310          ctx.info[instr->definitions[0].tempId()].label = 0;
3311          return true;
3312       }
3313    }
3314 
3315    return false;
3316 }
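
/* Illustrative sketch (editor's addition): v_bfi_b32(mask, x, y) == (mask & x) | (~mask & y),
 * so both rewrites above are value-preserving; with sample values a = 0x1234, b = 0x0f0f: */
static_assert(((0x0f0fu & 0u) | (~0x0f0fu & 0x1234u)) == (0x1234u & ~0x0f0fu) &&
                 ((0x0f0fu & 0x1234u) | (~0x0f0fu & ~0u)) == (0x1234u | ~0x0f0fu),
              "v_and(a, not(b)) == v_bfi(b, 0, a) / v_or(a, not(b)) == v_bfi(b, a, -1) sketch");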
3317 
3318 /* v_add_co(c, s_lshl(a, b)) -> v_mad_u32_u24(a, 1<<b, c)
3319  * v_add_co(c, v_lshlrev(a, b)) -> v_mad_u32_u24(b, 1<<a, c)
3320  * v_sub(c, s_lshl(a, b)) -> v_mad_i32_i24(a, -(1<<b), c)
3321  * v_sub(c, v_lshlrev(a, b)) -> v_mad_i32_i24(b, -(1<<a), c)
3322  */
3323 bool
3324 combine_add_lshl(opt_ctx& ctx, aco_ptr<Instruction>& instr, bool is_sub)
3325 {
3326    if (instr->usesModifiers())
3327       return false;
3328 
3329    /* Subtractions: start at operand 1 to avoid mix-ups such as
3330     * turning v_sub(v_lshlrev(a, b), c) into v_mad_i32_i24(b, -(1<<a), c)
3331     */
3332    unsigned start_op_idx = is_sub ? 1 : 0;
3333 
3334    /* Don't allow 24-bit operands on subtraction because
3335     * v_mad_i32_i24 applies a sign extension.
3336     */
3337    bool allow_24bit = !is_sub;
3338 
3339    for (unsigned i = start_op_idx; i < 2; i++) {
3340       Instruction* op_instr = follow_operand(ctx, instr->operands[i]);
3341       if (!op_instr)
3342          continue;
3343 
3344       if (op_instr->opcode != aco_opcode::s_lshl_b32 &&
3345           op_instr->opcode != aco_opcode::v_lshlrev_b32)
3346          continue;
3347 
3348       int shift_op_idx = op_instr->opcode == aco_opcode::s_lshl_b32 ? 1 : 0;
3349 
3350       if (op_instr->operands[shift_op_idx].isConstant() &&
3351           ((allow_24bit && op_instr->operands[!shift_op_idx].is24bit()) ||
3352            op_instr->operands[!shift_op_idx].is16bit())) {
3353          uint32_t multiplier = 1 << (op_instr->operands[shift_op_idx].constantValue() % 32u);
3354          if (is_sub)
3355             multiplier = -multiplier;
3356          if (is_sub ? (multiplier < 0xff800000) : (multiplier > 0xffffff))
3357             continue;
3358 
3359          Operand ops[3] = {
3360             op_instr->operands[!shift_op_idx],
3361             Operand::c32(multiplier),
3362             instr->operands[!i],
3363          };
3364          if (!check_vop3_operands(ctx, 3, ops))
3365             return false;
3366 
3367          ctx.uses[instr->operands[i].tempId()]--;
3368 
3369          aco_opcode mad_op = is_sub ? aco_opcode::v_mad_i32_i24 : aco_opcode::v_mad_u32_u24;
3370          aco_ptr<Instruction> new_instr{create_instruction(mad_op, Format::VOP3, 3, 1)};
3371          for (unsigned op_idx = 0; op_idx < 3; ++op_idx)
3372             new_instr->operands[op_idx] = ops[op_idx];
3373          new_instr->definitions[0] = instr->definitions[0];
3374          new_instr->pass_flags = instr->pass_flags;
3375          instr = std::move(new_instr);
3376          ctx.info[instr->definitions[0].tempId()].label = 0;
3377          return true;
3378       }
3379    }
3380 
3381    return false;
3382 }
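
/* Illustrative sketch (editor's addition): a left shift by a constant is a multiplication by
 * a power of two, which is how the add/sub above becomes a 24-bit mad; e.g. with b = 3: */
static_assert(100u + (5u << 3) == 5u * (1u << 3) + 100u,
              "v_add(c, lshl(a, 3)) == v_mad_u32_u24(a, 1 << 3, c) sketch");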
3383 
3384 void
3385 propagate_swizzles(VALU_instruction* instr, bool opsel_lo, bool opsel_hi)
3386 {
3387    /* propagate swizzles which apply to a result down to the instruction's operands:
3388     * result = a.xy + b.xx -> result.yx = a.yx + b.xx */
3389    uint8_t tmp_lo = instr->opsel_lo;
3390    uint8_t tmp_hi = instr->opsel_hi;
3391    uint8_t neg_lo = instr->neg_lo;
3392    uint8_t neg_hi = instr->neg_hi;
3393    if (opsel_lo == 1) {
3394       instr->opsel_lo = tmp_hi;
3395       instr->neg_lo = neg_hi;
3396    }
3397    if (opsel_hi == 0) {
3398       instr->opsel_hi = tmp_lo;
3399       instr->neg_hi = neg_lo;
3400    }
3401 }
3402 
3403 void
3404 combine_vop3p(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3405 {
3406    VALU_instruction* vop3p = &instr->valu();
3407 
3408    /* apply clamp */
3409    if (instr->opcode == aco_opcode::v_pk_mul_f16 && instr->operands[1].constantEquals(0x3C00) &&
3410        vop3p->clamp && instr->operands[0].isTemp() && ctx.uses[instr->operands[0].tempId()] == 1 &&
3411        !vop3p->opsel_lo[1] && !vop3p->opsel_hi[1]) {
3412 
3413       ssa_info& info = ctx.info[instr->operands[0].tempId()];
3414       if (info.is_vop3p() && instr_info.can_use_output_modifiers[(int)info.instr->opcode]) {
3415          VALU_instruction* candidate = &ctx.info[instr->operands[0].tempId()].instr->valu();
3416          candidate->clamp = true;
3417          propagate_swizzles(candidate, vop3p->opsel_lo[0], vop3p->opsel_hi[0]);
3418          instr->definitions[0].swapTemp(candidate->definitions[0]);
3419          ctx.info[candidate->definitions[0].tempId()].instr = candidate;
3420          ctx.uses[instr->definitions[0].tempId()]--;
3421          return;
3422       }
3423    }
3424 
3425    /* check for fneg modifiers */
3426    for (unsigned i = 0; i < instr->operands.size(); i++) {
3427       if (!can_use_input_modifiers(ctx.program->gfx_level, instr->opcode, i))
3428          continue;
3429       Operand& op = instr->operands[i];
3430       if (!op.isTemp())
3431          continue;
3432 
3433       ssa_info& info = ctx.info[op.tempId()];
3434       if (info.is_vop3p() && info.instr->opcode == aco_opcode::v_pk_mul_f16 &&
3435           (info.instr->operands[0].constantEquals(0x3C00) ||
3436            info.instr->operands[1].constantEquals(0x3C00))) {
3437 
3438          VALU_instruction* fneg = &info.instr->valu();
3439 
3440          unsigned fneg_src = fneg->operands[0].constantEquals(0x3C00);
3441 
3442          if (fneg->opsel_lo[1 - fneg_src] || fneg->opsel_hi[1 - fneg_src])
3443             continue;
3444 
3445          Operand ops[3];
3446          for (unsigned j = 0; j < instr->operands.size(); j++)
3447             ops[j] = instr->operands[j];
3448          ops[i] = fneg->operands[fneg_src];
3449          if (!check_vop3_operands(ctx, instr->operands.size(), ops))
3450             continue;
3451 
3452          if (fneg->clamp)
3453             continue;
3454          instr->operands[i] = fneg->operands[fneg_src];
3455 
3456          /* opsel_lo/hi is either 0 or 1:
3457           * if 0 - pick selection from fneg->lo
3458           * if 1 - pick selection from fneg->hi
3459           */
3460          bool opsel_lo = vop3p->opsel_lo[i];
3461          bool opsel_hi = vop3p->opsel_hi[i];
3462          bool neg_lo = fneg->neg_lo[0] ^ fneg->neg_lo[1];
3463          bool neg_hi = fneg->neg_hi[0] ^ fneg->neg_hi[1];
3464          vop3p->neg_lo[i] ^= opsel_lo ? neg_hi : neg_lo;
3465          vop3p->neg_hi[i] ^= opsel_hi ? neg_hi : neg_lo;
3466          vop3p->opsel_lo[i] ^= opsel_lo ? !fneg->opsel_hi[fneg_src] : fneg->opsel_lo[fneg_src];
3467          vop3p->opsel_hi[i] ^= opsel_hi ? !fneg->opsel_hi[fneg_src] : fneg->opsel_lo[fneg_src];
3468 
3469          if (--ctx.uses[fneg->definitions[0].tempId()])
3470             ctx.uses[fneg->operands[fneg_src].tempId()]++;
3471       }
3472    }
3473 
3474    if (instr->opcode == aco_opcode::v_pk_add_f16 || instr->opcode == aco_opcode::v_pk_add_u16) {
3475       bool fadd = instr->opcode == aco_opcode::v_pk_add_f16;
3476       if (fadd && instr->definitions[0].isPrecise())
3477          return;
3478       if (!fadd && instr->valu().clamp)
3479          return;
3480 
3481       Instruction* mul_instr = nullptr;
3482       unsigned add_op_idx = 0;
3483       bitarray8 mul_neg_lo = 0, mul_neg_hi = 0, mul_opsel_lo = 0, mul_opsel_hi = 0;
3484       uint32_t uses = UINT32_MAX;
3485 
3486       /* find the 'best' mul instruction to combine with the add */
3487       for (unsigned i = 0; i < 2; i++) {
3488          Instruction* op_instr = follow_operand(ctx, instr->operands[i], true);
3489          if (!op_instr)
3490             continue;
3491 
3492          if (ctx.info[instr->operands[i].tempId()].is_vop3p()) {
3493             if (fadd) {
3494                if (op_instr->opcode != aco_opcode::v_pk_mul_f16 ||
3495                    op_instr->definitions[0].isPrecise())
3496                   continue;
3497             } else {
3498                if (op_instr->opcode != aco_opcode::v_pk_mul_lo_u16)
3499                   continue;
3500             }
3501 
3502             Operand op[3] = {op_instr->operands[0], op_instr->operands[1], instr->operands[1 - i]};
3503             if (ctx.uses[instr->operands[i].tempId()] >= uses || !check_vop3_operands(ctx, 3, op))
3504                continue;
3505 
3506             /* no clamp allowed between mul and add */
3507             if (op_instr->valu().clamp)
3508                continue;
3509 
3510             mul_instr = op_instr;
3511             add_op_idx = 1 - i;
3512             uses = ctx.uses[instr->operands[i].tempId()];
3513             mul_neg_lo = mul_instr->valu().neg_lo;
3514             mul_neg_hi = mul_instr->valu().neg_hi;
3515             mul_opsel_lo = mul_instr->valu().opsel_lo;
3516             mul_opsel_hi = mul_instr->valu().opsel_hi;
3517          } else if (instr->operands[i].bytes() == 2) {
3518             if ((fadd && (op_instr->opcode != aco_opcode::v_mul_f16 ||
3519                           op_instr->definitions[0].isPrecise())) ||
3520                 (!fadd && op_instr->opcode != aco_opcode::v_mul_lo_u16 &&
3521                  op_instr->opcode != aco_opcode::v_mul_lo_u16_e64))
3522                continue;
3523 
3524             if (op_instr->valu().clamp || op_instr->valu().omod || op_instr->valu().abs)
3525                continue;
3526 
3527             if (op_instr->isDPP() || (op_instr->isSDWA() && (op_instr->sdwa().sel[0].size() < 2 ||
3528                                                              op_instr->sdwa().sel[1].size() < 2)))
3529                continue;
3530 
3531             Operand op[3] = {op_instr->operands[0], op_instr->operands[1], instr->operands[1 - i]};
3532             if (ctx.uses[instr->operands[i].tempId()] >= uses || !check_vop3_operands(ctx, 3, op))
3533                continue;
3534 
3535             mul_instr = op_instr;
3536             add_op_idx = 1 - i;
3537             uses = ctx.uses[instr->operands[i].tempId()];
3538             mul_neg_lo = mul_instr->valu().neg;
3539             mul_neg_hi = mul_instr->valu().neg;
3540             if (mul_instr->isSDWA()) {
3541                for (unsigned j = 0; j < 2; j++)
3542                   mul_opsel_lo[j] = mul_instr->sdwa().sel[j].offset();
3543             } else {
3544                mul_opsel_lo = mul_instr->valu().opsel;
3545             }
3546             mul_opsel_hi = mul_opsel_lo;
3547          }
3548       }
3549 
3550       if (!mul_instr)
3551          return;
3552 
3553       /* turn mul + packed add into v_pk_fma_f16 */
3554       aco_opcode mad = fadd ? aco_opcode::v_pk_fma_f16 : aco_opcode::v_pk_mad_u16;
3555       aco_ptr<Instruction> fma{create_instruction(mad, Format::VOP3P, 3, 1)};
3556       fma->operands[0] = copy_operand(ctx, mul_instr->operands[0]);
3557       fma->operands[1] = copy_operand(ctx, mul_instr->operands[1]);
3558       fma->operands[2] = instr->operands[add_op_idx];
3559       fma->valu().clamp = vop3p->clamp;
3560       fma->valu().neg_lo = mul_neg_lo;
3561       fma->valu().neg_hi = mul_neg_hi;
3562       fma->valu().opsel_lo = mul_opsel_lo;
3563       fma->valu().opsel_hi = mul_opsel_hi;
3564       propagate_swizzles(&fma->valu(), vop3p->opsel_lo[1 - add_op_idx],
3565                          vop3p->opsel_hi[1 - add_op_idx]);
3566       fma->valu().opsel_lo[2] = vop3p->opsel_lo[add_op_idx];
3567       fma->valu().opsel_hi[2] = vop3p->opsel_hi[add_op_idx];
3568       fma->valu().neg_lo[2] = vop3p->neg_lo[add_op_idx];
3569       fma->valu().neg_hi[2] = vop3p->neg_hi[add_op_idx];
3570       fma->valu().neg_lo[1] = fma->valu().neg_lo[1] ^ vop3p->neg_lo[1 - add_op_idx];
3571       fma->valu().neg_hi[1] = fma->valu().neg_hi[1] ^ vop3p->neg_hi[1 - add_op_idx];
3572       fma->definitions[0] = instr->definitions[0];
3573       fma->pass_flags = instr->pass_flags;
3574       instr = std::move(fma);
3575       ctx.info[instr->definitions[0].tempId()].set_vop3p(instr.get());
3576       decrease_uses(ctx, mul_instr);
3577       return;
3578    }
3579 }
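
/* Illustrative sketch (editor's addition): per 16-bit lane, the packed mul + add pair computes
 * a * b + c, which is also what v_pk_fma_f16/v_pk_mad_u16 compute (the fused form can only
 * differ by rounding, hence the isPrecise() checks above); exactly representable values: */
static_assert(2.0f * 3.0f + 1.0f == 7.0f, "v_pk_mul + v_pk_add -> v_pk_fma sketch");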
3580 
3581 bool
3582 can_use_mad_mix(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3583 {
3584    if (ctx.program->gfx_level < GFX9)
3585       return false;
3586 
3587    /* v_mad_mix* on GFX9 always flushes denormals for 16-bit inputs/outputs */
3588    if (ctx.program->gfx_level == GFX9 && ctx.fp_mode.denorm16_64)
3589       return false;
3590 
3591    if (instr->valu().omod)
3592       return false;
3593 
3594    switch (instr->opcode) {
3595    case aco_opcode::v_add_f32:
3596    case aco_opcode::v_sub_f32:
3597    case aco_opcode::v_subrev_f32:
3598    case aco_opcode::v_mul_f32: return !instr->isSDWA() && !instr->isDPP();
3599    case aco_opcode::v_fma_f32:
3600       return ctx.program->dev.fused_mad_mix || !instr->definitions[0].isPrecise();
3601    case aco_opcode::v_fma_mix_f32:
3602    case aco_opcode::v_fma_mixlo_f16: return true;
3603    default: return false;
3604    }
3605 }
3606 
3607 void
3608 to_mad_mix(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3609 {
3610    ctx.info[instr->definitions[0].tempId()].label &= label_f2f16 | label_clamp | label_mul;
3611 
3612    if (instr->opcode == aco_opcode::v_fma_f32) {
3613       instr->format = (Format)((uint32_t)withoutVOP3(instr->format) | (uint32_t)(Format::VOP3P));
3614       instr->opcode = aco_opcode::v_fma_mix_f32;
3615       return;
3616    }
3617 
3618    bool is_add = instr->opcode != aco_opcode::v_mul_f32;
3619 
3620    aco_ptr<Instruction> vop3p{create_instruction(aco_opcode::v_fma_mix_f32, Format::VOP3P, 3, 1)};
3621 
3622    for (unsigned i = 0; i < instr->operands.size(); i++) {
3623       vop3p->operands[is_add + i] = instr->operands[i];
3624       vop3p->valu().neg_lo[is_add + i] = instr->valu().neg[i];
3625       vop3p->valu().neg_hi[is_add + i] = instr->valu().abs[i];
3626    }
3627    if (instr->opcode == aco_opcode::v_mul_f32) {
3628       vop3p->operands[2] = Operand::zero();
3629       vop3p->valu().neg_lo[2] = true;
3630    } else if (is_add) {
3631       vop3p->operands[0] = Operand::c32(0x3f800000);
3632       if (instr->opcode == aco_opcode::v_sub_f32)
3633          vop3p->valu().neg_lo[2] ^= true;
3634       else if (instr->opcode == aco_opcode::v_subrev_f32)
3635          vop3p->valu().neg_lo[1] ^= true;
3636    }
3637    vop3p->definitions[0] = instr->definitions[0];
3638    vop3p->valu().clamp = instr->valu().clamp;
3639    vop3p->pass_flags = instr->pass_flags;
3640    instr = std::move(vop3p);
3641 
3642    if (ctx.info[instr->definitions[0].tempId()].label & label_mul)
3643       ctx.info[instr->definitions[0].tempId()].instr = instr.get();
3644 }
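
/* Illustrative sketch (editor's addition): the VOP2/VOP3 add forms map onto v_fma_mix_f32 as
 * 1.0 * a + b, with negate modifiers covering sub/subrev, e.g.: */
static_assert(1.0f * 6.0f + (-2.0f) == 6.0f - 2.0f,
              "v_sub_f32(a, b) -> v_fma_mix_f32(1.0, a, -b) sketch");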
3645 
3646 bool
3647 combine_output_conversion(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3648 {
3649    ssa_info& def_info = ctx.info[instr->definitions[0].tempId()];
3650    if (!def_info.is_f2f16())
3651       return false;
3652    Instruction* conv = def_info.instr;
3653 
3654    if (!ctx.uses[conv->definitions[0].tempId()] || ctx.uses[instr->definitions[0].tempId()] != 1)
3655       return false;
3656 
3657    if (conv->usesModifiers())
3658       return false;
3659 
3660    if (interp_can_become_fma(ctx, instr))
3661       interp_p2_f32_inreg_to_fma_dpp(instr);
3662 
3663    if (!can_use_mad_mix(ctx, instr))
3664       return false;
3665 
3666    if (!instr->isVOP3P())
3667       to_mad_mix(ctx, instr);
3668 
3669    instr->opcode = aco_opcode::v_fma_mixlo_f16;
3670    instr->definitions[0].swapTemp(conv->definitions[0]);
3671    if (conv->definitions[0].isPrecise())
3672       instr->definitions[0].setPrecise(true);
3673    ctx.info[instr->definitions[0].tempId()].label &= label_clamp;
3674    ctx.uses[conv->definitions[0].tempId()]--;
3675 
3676    return true;
3677 }
3678 
3679 void
3680 combine_mad_mix(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3681 {
3682    if (!can_use_mad_mix(ctx, instr))
3683       return;
3684 
3685    for (unsigned i = 0; i < instr->operands.size(); i++) {
3686       if (!instr->operands[i].isTemp())
3687          continue;
3688       Temp tmp = instr->operands[i].getTemp();
3689       if (!ctx.info[tmp.id()].is_f2f32())
3690          continue;
3691 
3692       Instruction* conv = ctx.info[tmp.id()].instr;
3693       if (conv->valu().clamp || conv->valu().omod) {
3694          continue;
3695       } else if (conv->isSDWA() &&
3696                  (conv->sdwa().dst_sel.size() != 4 || conv->sdwa().sel[0].size() != 2)) {
3697          continue;
3698       } else if (conv->isDPP()) {
3699          continue;
3700       }
3701 
3702       if (get_operand_size(instr, i) != 32)
3703          continue;
3704 
3705       /* Conversion to VOP3P will add inline constant operands, but that shouldn't affect
3706        * check_vop3_operands(). */
3707       Operand op[3];
3708       for (unsigned j = 0; j < instr->operands.size(); j++)
3709          op[j] = instr->operands[j];
3710       op[i] = conv->operands[0];
3711       if (!check_vop3_operands(ctx, instr->operands.size(), op))
3712          continue;
3713       if (!conv->operands[0].isOfType(RegType::vgpr) && instr->isDPP())
3714          continue;
3715 
3716       if (!instr->isVOP3P()) {
3717          bool is_add =
3718             instr->opcode != aco_opcode::v_mul_f32 && instr->opcode != aco_opcode::v_fma_f32;
3719          to_mad_mix(ctx, instr);
3720          i += is_add;
3721       }
3722 
3723       if (--ctx.uses[tmp.id()])
3724          ctx.uses[conv->operands[0].tempId()]++;
3725       instr->operands[i].setTemp(conv->operands[0].getTemp());
3726       if (conv->definitions[0].isPrecise())
3727          instr->definitions[0].setPrecise(true);
3728       instr->valu().opsel_hi[i] = true;
3729       if (conv->isSDWA() && conv->sdwa().sel[0].offset() == 2)
3730          instr->valu().opsel_lo[i] = true;
3731       else
3732          instr->valu().opsel_lo[i] = conv->valu().opsel[0];
3733       bool neg = conv->valu().neg[0];
3734       bool abs = conv->valu().abs[0];
3735       if (!instr->valu().abs[i]) {
3736          instr->valu().neg[i] ^= neg;
3737          instr->valu().abs[i] = abs;
3738       }
3739    }
3740 }
3741 
3742 // TODO: we could possibly move the whole label_instruction pass to combine_instruction:
3743 // this would mean that we'd have to fix the instruction uses during value propagation
3744 
3745 /* also returns true for inf */
3746 bool
3747 is_pow_of_two(opt_ctx& ctx, Operand op)
3748 {
3749    if (op.isTemp() && ctx.info[op.tempId()].is_constant_or_literal(op.bytes() * 8))
3750       return is_pow_of_two(ctx, get_constant_op(ctx, ctx.info[op.tempId()], op.bytes() * 8));
3751    else if (!op.isConstant())
3752       return false;
3753 
3754    uint64_t val = op.constantValue64();
3755 
3756    if (op.bytes() == 4) {
3757       uint32_t exponent = (val & 0x7f800000) >> 23;
3758       uint32_t fraction = val & 0x007fffff;
3759       return (exponent >= 127) && (fraction == 0);
3760    } else if (op.bytes() == 2) {
3761       uint32_t exponent = (val & 0x7c00) >> 10;
3762       uint32_t fraction = val & 0x03ff;
3763       return (exponent >= 15) && (fraction == 0);
3764    } else {
3765       assert(op.bytes() == 8);
3766       uint64_t exponent = (val & UINT64_C(0x7ff0000000000000)) >> 52;
3767       uint64_t fraction = val & UINT64_C(0x000fffffffffffff);
3768       return (exponent >= 1023) && (fraction == 0);
3769    }
3770 }
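
/* Illustrative sketch (editor's addition): a zero fraction with an exponent at or above the
 * bias encodes 2^k for k >= 0 (or infinity). 0x41000000 is the IEEE-754 encoding of 8.0f: */
static_assert((0x41000000u & 0x007fffffu) == 0u && ((0x41000000u & 0x7f800000u) >> 23) >= 127u,
              "is_pow_of_two() accepts 8.0f (a power of two) sketch");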
3771 
3772 void
3773 combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3774 {
3775    if (instr->definitions.empty() || is_dead(ctx.uses, instr.get()))
3776       return;
3777 
3778    if (instr->isVALU() || instr->isSALU()) {
3779       /* Apply SDWA. Do this after label_instruction() so it can remove
3780        * label_extract if not all instructions can take SDWA. */
3781       for (unsigned i = 0; i < instr->operands.size(); i++) {
3782          Operand& op = instr->operands[i];
3783          if (!op.isTemp())
3784             continue;
3785          ssa_info& info = ctx.info[op.tempId()];
3786          if (!info.is_extract())
3787             continue;
3788          /* if there are that many uses, there are likely better combinations */
3789          // TODO: delay applying extract to a point where we know better
3790          if (ctx.uses[op.tempId()] > 4) {
3791             info.label &= ~label_extract;
3792             continue;
3793          }
3794          if (info.is_extract() &&
3795              (info.instr->operands[0].getTemp().type() == RegType::vgpr ||
3796               instr->operands[i].getTemp().type() == RegType::sgpr) &&
3797              can_apply_extract(ctx, instr, i, info)) {
3798             /* Increase use count of the extract's operand if the extract still has uses. */
3799             apply_extract(ctx, instr, i, info);
3800             if (--ctx.uses[instr->operands[i].tempId()])
3801                ctx.uses[info.instr->operands[0].tempId()]++;
3802             instr->operands[i].setTemp(info.instr->operands[0].getTemp());
3803          }
3804       }
3805    }
3806 
3807    if (instr->isVALU()) {
3808       if (can_apply_sgprs(ctx, instr))
3809          apply_sgprs(ctx, instr);
3810       combine_mad_mix(ctx, instr);
3811       while (apply_omod_clamp(ctx, instr) || combine_output_conversion(ctx, instr))
3812          ;
3813       apply_insert(ctx, instr);
3814    }
3815 
3816    if (instr->isVOP3P() && instr->opcode != aco_opcode::v_fma_mix_f32 &&
3817        instr->opcode != aco_opcode::v_fma_mixlo_f16)
3818       return combine_vop3p(ctx, instr);
3819 
3820    if (instr->isSDWA() || instr->isDPP())
3821       return;
3822 
3823    if (instr->opcode == aco_opcode::p_extract || instr->opcode == aco_opcode::p_extract_vector) {
3824       ssa_info& info = ctx.info[instr->operands[0].tempId()];
3825       if (info.is_extract() && can_apply_extract(ctx, instr, 0, info)) {
3826          apply_extract(ctx, instr, 0, info);
3827          if (--ctx.uses[instr->operands[0].tempId()])
3828             ctx.uses[info.instr->operands[0].tempId()]++;
3829          instr->operands[0].setTemp(info.instr->operands[0].getTemp());
3830       }
3831 
3832       if (instr->opcode == aco_opcode::p_extract)
3833          apply_ds_extract(ctx, instr);
3834    }
3835 
3836    /* TODO: There are still some peephole optimizations that could be done:
3837     * - abs(a - b) -> s_absdiff_i32
3838     * - various patterns for s_bitcmp{0,1}_b32 and s_bitset{0,1}_b32
3839     * - patterns for v_alignbit_b32 and v_alignbyte_b32
3840     * These probably aren't too interesting, though.
3841     * There are also patterns for v_cmp_class_f{16,32,64}. This is difficult but
3842     * probably more useful than the previously mentioned optimizations.
3843     * The various comparison optimizations also currently only work with 32-bit
3844     * floats. */
3845 
3846    /* neg(mul(a, b)) -> mul(neg(a), b), abs(mul(a, b)) -> mul(abs(a), abs(b)) */
3847    if ((ctx.info[instr->definitions[0].tempId()].label & (label_neg | label_abs)) &&
3848        ctx.uses[instr->operands[1].tempId()] == 1) {
3849       Temp val = ctx.info[instr->definitions[0].tempId()].temp;
3850 
3851       if (!ctx.info[val.id()].is_mul())
3852          return;
3853 
3854       Instruction* mul_instr = ctx.info[val.id()].instr;
3855 
3856       if (mul_instr->operands[0].isLiteral())
3857          return;
3858       if (mul_instr->valu().clamp)
3859          return;
3860       if (mul_instr->isSDWA() || mul_instr->isDPP())
3861          return;
3862       if (mul_instr->opcode == aco_opcode::v_mul_legacy_f32 &&
3863           mul_instr->definitions[0].isSZPreserve())
3864          return;
3865       if (mul_instr->definitions[0].bytes() != instr->definitions[0].bytes())
3866          return;
3867 
3868       /* convert to mul(neg(a), b), mul(abs(a), abs(b)) or mul(neg(abs(a)), abs(b)) */
3869       ctx.uses[mul_instr->definitions[0].tempId()]--;
3870       Definition def = instr->definitions[0];
3871       bool is_neg = ctx.info[instr->definitions[0].tempId()].is_neg();
3872       bool is_abs = ctx.info[instr->definitions[0].tempId()].is_abs();
3873       uint32_t pass_flags = instr->pass_flags;
3874       Format format = mul_instr->format == Format::VOP2 ? asVOP3(Format::VOP2) : mul_instr->format;
3875       instr.reset(create_instruction(mul_instr->opcode, format, mul_instr->operands.size(), 1));
3876       std::copy(mul_instr->operands.cbegin(), mul_instr->operands.cend(), instr->operands.begin());
3877       instr->pass_flags = pass_flags;
3878       instr->definitions[0] = def;
3879       VALU_instruction& new_mul = instr->valu();
3880       VALU_instruction& mul = mul_instr->valu();
3881       new_mul.neg = mul.neg;
3882       new_mul.abs = mul.abs;
3883       new_mul.omod = mul.omod;
3884       new_mul.opsel = mul.opsel;
3885       new_mul.opsel_lo = mul.opsel_lo;
3886       new_mul.opsel_hi = mul.opsel_hi;
3887       if (is_abs) {
3888          new_mul.neg[0] = new_mul.neg[1] = false;
3889          new_mul.abs[0] = new_mul.abs[1] = true;
3890       }
3891       new_mul.neg[0] ^= is_neg;
3892       new_mul.clamp = false;
3893 
3894       ctx.info[instr->definitions[0].tempId()].set_mul(instr.get());
3895       return;
3896    }
3897 
3898    /* combine mul+add -> mad */
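        /* A v_fma_mix whose first source is the constant 1.0 (without modifiers)
         * behaves as an add of its other two sources, so it is treated like an
         * add here as well. */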
3899    bool is_add_mix =
3900       (instr->opcode == aco_opcode::v_fma_mix_f32 ||
3901        instr->opcode == aco_opcode::v_fma_mixlo_f16) &&
3902       !instr->valu().neg_lo[0] &&
3903       ((instr->operands[0].constantEquals(0x3f800000) && !instr->valu().opsel_hi[0]) ||
3904        (instr->operands[0].constantEquals(0x3C00) && instr->valu().opsel_hi[0] &&
3905         !instr->valu().opsel_lo[0]));
3906    bool mad32 = instr->opcode == aco_opcode::v_add_f32 || instr->opcode == aco_opcode::v_sub_f32 ||
3907                 instr->opcode == aco_opcode::v_subrev_f32;
3908    bool mad16 = instr->opcode == aco_opcode::v_add_f16 || instr->opcode == aco_opcode::v_sub_f16 ||
3909                 instr->opcode == aco_opcode::v_subrev_f16;
3910    bool mad64 =
3911       instr->opcode == aco_opcode::v_add_f64_e64 || instr->opcode == aco_opcode::v_add_f64;
3912    if (is_add_mix || mad16 || mad32 || mad64) {
3913       Instruction* mul_instr = nullptr;
3914       unsigned add_op_idx = 0;
3915       uint32_t uses = UINT32_MAX;
3916       bool emit_fma = false;
3917       /* find the 'best' mul instruction to combine with the add */
3918       for (unsigned i = is_add_mix ? 1 : 0; i < instr->operands.size(); i++) {
3919          if (!instr->operands[i].isTemp() || !ctx.info[instr->operands[i].tempId()].is_mul())
3920             continue;
3921          ssa_info& info = ctx.info[instr->operands[i].tempId()];
3922 
3923          /* no clamp/omod allowed between mul and add */
3924          if (info.instr->isVOP3() && (info.instr->valu().clamp || info.instr->valu().omod))
3925             continue;
3926          if (info.instr->isVOP3P() && info.instr->valu().clamp)
3927             continue;
3928          /* v_fma_mix_f32/etc can't do omod */
3929          if (info.instr->isVOP3P() && instr->isVOP3() && instr->valu().omod)
3930             continue;
3931          /* don't promote fp16 to fp32 or remove fp32->fp16->fp32 conversions */
3932          if (is_add_mix && info.instr->definitions[0].bytes() == 2)
3933             continue;
3934 
3935          if (get_operand_size(instr, i) != info.instr->definitions[0].bytes() * 8)
3936             continue;
3937 
3938          bool legacy = info.instr->opcode == aco_opcode::v_mul_legacy_f32;
3939          bool mad_mix = is_add_mix || info.instr->isVOP3P();
3940 
3941          /* Multiplication by power-of-two should never need rounding. 1/power-of-two also works,
3942           * but using fma removes denormal flushing (0xfffffe * 0.5 + 0x810001a2).
3943           */
3944          bool is_fma_precise = is_pow_of_two(ctx, info.instr->operands[0]) ||
3945                                is_pow_of_two(ctx, info.instr->operands[1]);
3946 
3947          bool has_fma = mad16 || mad64 || (legacy && ctx.program->gfx_level >= GFX10_3) ||
3948                         (mad32 && !legacy && !mad_mix && ctx.program->dev.has_fast_fma32) ||
3949                         (mad_mix && ctx.program->dev.fused_mad_mix);
3950          bool has_mad = mad_mix ? !ctx.program->dev.fused_mad_mix
3951                                 : ((mad32 && ctx.program->gfx_level < GFX10_3) ||
3952                                    (mad16 && ctx.program->gfx_level <= GFX9));
3953          bool can_use_fma =
3954             has_fma &&
3955             (!(info.instr->definitions[0].isPrecise() || instr->definitions[0].isPrecise()) ||
3956              is_fma_precise);
3957          bool can_use_mad =
3958             has_mad && (mad_mix || mad32 ? ctx.fp_mode.denorm32 : ctx.fp_mode.denorm16_64) == 0;
3959          if (mad_mix && legacy)
3960             continue;
3961          if (!can_use_fma && !can_use_mad)
3962             continue;
3963 
3964          unsigned candidate_add_op_idx = is_add_mix ? (3 - i) : (1 - i);
3965          Operand op[3] = {info.instr->operands[0], info.instr->operands[1],
3966                           instr->operands[candidate_add_op_idx]};
3967          if (info.instr->isSDWA() || info.instr->isDPP() || !check_vop3_operands(ctx, 3, op) ||
3968              ctx.uses[instr->operands[i].tempId()] > uses)
3969             continue;
3970 
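              /* Ties in use count are broken by temp id; the mul with the
               * higher id wins. */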
3971          if (ctx.uses[instr->operands[i].tempId()] == uses) {
3972             unsigned cur_idx = mul_instr->definitions[0].tempId();
3973             unsigned new_idx = info.instr->definitions[0].tempId();
3974             if (cur_idx > new_idx)
3975                continue;
3976          }
3977 
3978          mul_instr = info.instr;
3979          add_op_idx = candidate_add_op_idx;
3980          uses = ctx.uses[instr->operands[i].tempId()];
3981          emit_fma = !can_use_mad;
3982       }
3983 
3984       if (mul_instr) {
3985          /* turn mul+add into v_mad/v_fma */
3986          Operand op[3] = {mul_instr->operands[0], mul_instr->operands[1],
3987                           instr->operands[add_op_idx]};
3988          ctx.uses[mul_instr->definitions[0].tempId()]--;
3989          if (ctx.uses[mul_instr->definitions[0].tempId()]) {
3990             if (op[0].isTemp())
3991                ctx.uses[op[0].tempId()]++;
3992             if (op[1].isTemp())
3993                ctx.uses[op[1].tempId()]++;
3994          }
3995 
3996          bool neg[3] = {false, false, false};
3997          bool abs[3] = {false, false, false};
3998          unsigned omod = 0;
3999          bool clamp = false;
4000          bitarray8 opsel_lo = 0;
4001          bitarray8 opsel_hi = 0;
4002          bitarray8 opsel = 0;
4003          unsigned mul_op_idx = (instr->isVOP3P() ? 3 : 1) - add_op_idx;
4004 
4005          VALU_instruction& valu_mul = mul_instr->valu();
4006          neg[0] = valu_mul.neg[0];
4007          neg[1] = valu_mul.neg[1];
4008          abs[0] = valu_mul.abs[0];
4009          abs[1] = valu_mul.abs[1];
4010          opsel_lo = valu_mul.opsel_lo & 0x3;
4011          opsel_hi = valu_mul.opsel_hi & 0x3;
4012          opsel = valu_mul.opsel & 0x3;
4013 
4014          VALU_instruction& valu = instr->valu();
4015          neg[2] = valu.neg[add_op_idx];
4016          abs[2] = valu.abs[add_op_idx];
4017          opsel_lo[2] = valu.opsel_lo[add_op_idx];
4018          opsel_hi[2] = valu.opsel_hi[add_op_idx];
4019          opsel[2] = valu.opsel[add_op_idx];
4020          opsel[3] = valu.opsel[3];
4021          omod = valu.omod;
4022          clamp = valu.clamp;
4023          /* abs of the multiplication result */
4024          if (valu.abs[mul_op_idx]) {
4025             neg[0] = false;
4026             neg[1] = false;
4027             abs[0] = true;
4028             abs[1] = true;
4029          }
4030          /* neg of the multiplication result */
4031          neg[1] ^= valu.neg[mul_op_idx];
4032 
4033          if (instr->opcode == aco_opcode::v_sub_f32 || instr->opcode == aco_opcode::v_sub_f16)
4034             neg[1 + add_op_idx] = neg[1 + add_op_idx] ^ true;
4035          else if (instr->opcode == aco_opcode::v_subrev_f32 ||
4036                   instr->opcode == aco_opcode::v_subrev_f16)
4037             neg[2 - add_op_idx] = neg[2 - add_op_idx] ^ true;
4038 
4039          aco_ptr<Instruction> add_instr = std::move(instr);
4040          aco_ptr<Instruction> mad;
4041          if (add_instr->isVOP3P() || mul_instr->isVOP3P()) {
4042             assert(!omod);
4043             assert(!opsel);
4044 
4045             aco_opcode mad_op = add_instr->definitions[0].bytes() == 2 ? aco_opcode::v_fma_mixlo_f16
4046                                                                        : aco_opcode::v_fma_mix_f32;
4047             mad.reset(create_instruction(mad_op, Format::VOP3P, 3, 1));
4048          } else {
4049             assert(!opsel_lo);
4050             assert(!opsel_hi);
4051 
4052             aco_opcode mad_op = emit_fma ? aco_opcode::v_fma_f32 : aco_opcode::v_mad_f32;
4053             if (mul_instr->opcode == aco_opcode::v_mul_legacy_f32) {
4054                assert(emit_fma == (ctx.program->gfx_level >= GFX10_3));
4055                mad_op = emit_fma ? aco_opcode::v_fma_legacy_f32 : aco_opcode::v_mad_legacy_f32;
4056             } else if (mad16) {
4057                mad_op = emit_fma ? (ctx.program->gfx_level == GFX8 ? aco_opcode::v_fma_legacy_f16
4058                                                                    : aco_opcode::v_fma_f16)
4059                                  : (ctx.program->gfx_level == GFX8 ? aco_opcode::v_mad_legacy_f16
4060                                                                    : aco_opcode::v_mad_f16);
4061             } else if (mad64) {
4062                mad_op = aco_opcode::v_fma_f64;
4063             }
4064 
4065             mad.reset(create_instruction(mad_op, Format::VOP3, 3, 1));
4066          }
4067 
4068          for (unsigned i = 0; i < 3; i++) {
4069             mad->operands[i] = op[i];
4070             mad->valu().neg[i] = neg[i];
4071             mad->valu().abs[i] = abs[i];
4072          }
4073          mad->valu().omod = omod;
4074          mad->valu().clamp = clamp;
4075          mad->valu().opsel_lo = opsel_lo;
4076          mad->valu().opsel_hi = opsel_hi;
4077          mad->valu().opsel = opsel;
4078          mad->definitions[0] = add_instr->definitions[0];
4079          mad->definitions[0].setPrecise(add_instr->definitions[0].isPrecise() ||
4080                                         mul_instr->definitions[0].isPrecise());
4081          mad->pass_flags = add_instr->pass_flags;
4082 
4083          instr = std::move(mad);
4084 
4085          /* mark this ssa_def to be re-checked for profitability and literals */
4086          ctx.mad_infos.emplace_back(std::move(add_instr), mul_instr->definitions[0].tempId());
4087          ctx.info[instr->definitions[0].tempId()].set_mad(ctx.mad_infos.size() - 1);
4088          return;
4089       }
4090    }
4091    /* v_mul_f32(v_cndmask_b32(0, 1.0, cond), a) -> v_cndmask_b32(0, a, cond) */
4092    else if (((instr->opcode == aco_opcode::v_mul_f32 && !instr->definitions[0].isNaNPreserve() &&
4093               !instr->definitions[0].isInfPreserve()) ||
4094              (instr->opcode == aco_opcode::v_mul_legacy_f32 &&
4095               !instr->definitions[0].isSZPreserve())) &&
4096             !instr->usesModifiers() && !ctx.fp_mode.must_flush_denorms32) {
4097       for (unsigned i = 0; i < 2; i++) {
4098          if (instr->operands[i].isTemp() && ctx.info[instr->operands[i].tempId()].is_b2f() &&
4099              ctx.uses[instr->operands[i].tempId()] == 1 && instr->operands[!i].isTemp() &&
4100              instr->operands[!i].getTemp().type() == RegType::vgpr) {
4101             ctx.uses[instr->operands[i].tempId()]--;
4102             ctx.uses[ctx.info[instr->operands[i].tempId()].temp.id()]++;
4103 
4104             aco_ptr<Instruction> new_instr{
4105                create_instruction(aco_opcode::v_cndmask_b32, Format::VOP2, 3, 1)};
4106             new_instr->operands[0] = Operand::zero();
4107             new_instr->operands[1] = instr->operands[!i];
4108             new_instr->operands[2] = Operand(ctx.info[instr->operands[i].tempId()].temp);
4109             new_instr->definitions[0] = instr->definitions[0];
4110             new_instr->pass_flags = instr->pass_flags;
4111             instr = std::move(new_instr);
4112             ctx.info[instr->definitions[0].tempId()].label = 0;
4113             return;
4114          }
4115       }
4116    } else if (instr->opcode == aco_opcode::v_or_b32 && ctx.program->gfx_level >= GFX9) {
4117       if (combine_three_valu_op(ctx, instr, aco_opcode::s_or_b32, aco_opcode::v_or3_b32, "012",
4118                                 1 | 2)) {
4119       } else if (combine_three_valu_op(ctx, instr, aco_opcode::v_or_b32, aco_opcode::v_or3_b32,
4120                                        "012", 1 | 2)) {
4121       } else if (combine_add_or_then_and_lshl(ctx, instr)) {
4122       } else if (combine_v_andor_not(ctx, instr)) {
4123       }
4124    } else if (instr->opcode == aco_opcode::v_xor_b32 && ctx.program->gfx_level >= GFX10) {
4125       if (combine_three_valu_op(ctx, instr, aco_opcode::v_xor_b32, aco_opcode::v_xor3_b32, "012",
4126                                 1 | 2)) {
4127       } else if (combine_three_valu_op(ctx, instr, aco_opcode::s_xor_b32, aco_opcode::v_xor3_b32,
4128                                        "012", 1 | 2)) {
4129       } else if (combine_xor_not(ctx, instr)) {
4130       }
4131    } else if (instr->opcode == aco_opcode::v_not_b32 && ctx.program->gfx_level >= GFX10) {
4132       combine_not_xor(ctx, instr);
4133    } else if (instr->opcode == aco_opcode::v_add_u16 && !instr->valu().clamp) {
4134       combine_three_valu_op(
4135          ctx, instr, aco_opcode::v_mul_lo_u16,
4136          ctx.program->gfx_level == GFX8 ? aco_opcode::v_mad_legacy_u16 : aco_opcode::v_mad_u16,
4137          "120", 1 | 2);
4138    } else if (instr->opcode == aco_opcode::v_add_u16_e64 && !instr->valu().clamp) {
4139       combine_three_valu_op(ctx, instr, aco_opcode::v_mul_lo_u16_e64, aco_opcode::v_mad_u16, "120",
4140                             1 | 2);
4141    } else if (instr->opcode == aco_opcode::v_add_u32 && !instr->usesModifiers()) {
4142       if (combine_add_sub_b2i(ctx, instr, aco_opcode::v_addc_co_u32, 1 | 2)) {
4143       } else if (combine_add_bcnt(ctx, instr)) {
4144       } else if (combine_three_valu_op(ctx, instr, aco_opcode::v_mul_u32_u24,
4145                                        aco_opcode::v_mad_u32_u24, "120", 1 | 2)) {
4146       } else if (combine_three_valu_op(ctx, instr, aco_opcode::v_mul_i32_i24,
4147                                        aco_opcode::v_mad_i32_i24, "120", 1 | 2)) {
4148       } else if (ctx.program->gfx_level >= GFX9) {
4149          if (combine_three_valu_op(ctx, instr, aco_opcode::s_xor_b32, aco_opcode::v_xad_u32, "120",
4150                                    1 | 2)) {
4151          } else if (combine_three_valu_op(ctx, instr, aco_opcode::v_xor_b32, aco_opcode::v_xad_u32,
4152                                           "120", 1 | 2)) {
4153          } else if (combine_three_valu_op(ctx, instr, aco_opcode::s_add_i32, aco_opcode::v_add3_u32,
4154                                           "012", 1 | 2)) {
4155          } else if (combine_three_valu_op(ctx, instr, aco_opcode::s_add_u32, aco_opcode::v_add3_u32,
4156                                           "012", 1 | 2)) {
4157          } else if (combine_three_valu_op(ctx, instr, aco_opcode::v_add_u32, aco_opcode::v_add3_u32,
4158                                           "012", 1 | 2)) {
4159          } else if (combine_add_or_then_and_lshl(ctx, instr)) {
4160          }
4161       }
4162    } else if ((instr->opcode == aco_opcode::v_add_co_u32 ||
4163                instr->opcode == aco_opcode::v_add_co_u32_e64) &&
4164               !instr->usesModifiers()) {
4165       bool carry_out = ctx.uses[instr->definitions[1].tempId()] > 0;
4166       if (combine_add_sub_b2i(ctx, instr, aco_opcode::v_addc_co_u32, 1 | 2)) {
4167       } else if (!carry_out && combine_add_bcnt(ctx, instr)) {
4168       } else if (!carry_out && combine_three_valu_op(ctx, instr, aco_opcode::v_mul_u32_u24,
4169                                                      aco_opcode::v_mad_u32_u24, "120", 1 | 2)) {
4170       } else if (!carry_out && combine_three_valu_op(ctx, instr, aco_opcode::v_mul_i32_i24,
4171                                                      aco_opcode::v_mad_i32_i24, "120", 1 | 2)) {
4172       } else if (!carry_out && combine_add_lshl(ctx, instr, false)) {
4173       }
4174    } else if (instr->opcode == aco_opcode::v_sub_u32 || instr->opcode == aco_opcode::v_sub_co_u32 ||
4175               instr->opcode == aco_opcode::v_sub_co_u32_e64) {
4176       bool carry_out =
4177          instr->opcode != aco_opcode::v_sub_u32 && ctx.uses[instr->definitions[1].tempId()] > 0;
4178       if (combine_add_sub_b2i(ctx, instr, aco_opcode::v_subbrev_co_u32, 2)) {
4179       } else if (!carry_out && combine_add_lshl(ctx, instr, true)) {
4180       }
4181    } else if (instr->opcode == aco_opcode::v_subrev_u32 ||
4182               instr->opcode == aco_opcode::v_subrev_co_u32 ||
4183               instr->opcode == aco_opcode::v_subrev_co_u32_e64) {
4184       combine_add_sub_b2i(ctx, instr, aco_opcode::v_subbrev_co_u32, 1);
4185    } else if (instr->opcode == aco_opcode::v_lshlrev_b32 && ctx.program->gfx_level >= GFX9) {
4186       combine_three_valu_op(ctx, instr, aco_opcode::v_add_u32, aco_opcode::v_add_lshl_u32, "120",
4187                             2);
4188    } else if ((instr->opcode == aco_opcode::s_add_u32 || instr->opcode == aco_opcode::s_add_i32) &&
4189               ctx.program->gfx_level >= GFX9) {
4190       combine_salu_lshl_add(ctx, instr);
4191    } else if (instr->opcode == aco_opcode::s_not_b32 || instr->opcode == aco_opcode::s_not_b64) {
4192       if (!combine_salu_not_bitwise(ctx, instr))
4193          combine_inverse_comparison(ctx, instr);
4194    } else if (instr->opcode == aco_opcode::s_and_b32 || instr->opcode == aco_opcode::s_or_b32 ||
4195               instr->opcode == aco_opcode::s_and_b64 || instr->opcode == aco_opcode::s_or_b64) {
4196       combine_salu_n2(ctx, instr);
4197    } else if (instr->opcode == aco_opcode::s_abs_i32) {
4198       combine_sabsdiff(ctx, instr);
4199    } else if (instr->opcode == aco_opcode::v_and_b32) {
4200       if (combine_and_subbrev(ctx, instr)) {
4201       } else if (combine_v_andor_not(ctx, instr)) {
4202       }
4203    } else if (instr->opcode == aco_opcode::v_fma_f32 || instr->opcode == aco_opcode::v_fma_f16) {
4204       /* Label existing v_fma_f32/f16 with label_mad so we can create v_fmamk/v_fmaak.
4205        * Since ctx.uses[mad_info::mul_temp_id] is always 0, we don't have to worry about
4206        * select_instruction() using mad_info::add_instr.
4207        */
4208       ctx.mad_infos.emplace_back(nullptr, 0);
4209       ctx.info[instr->definitions[0].tempId()].set_mad(ctx.mad_infos.size() - 1);
4210    } else if (instr->opcode == aco_opcode::v_med3_f32 || instr->opcode == aco_opcode::v_med3_f16) {
4211       /* Optimize v_med3 to v_add so that it can be dual issued on GFX11. We start with v_med3 in
4212        * case omod can be applied.
4213        */
4214       unsigned idx;
4215       if (detect_clamp(instr.get(), &idx)) {
4216          instr->format = asVOP3(Format::VOP2);
4217          instr->operands[0] = instr->operands[idx];
4218          instr->operands[1] = Operand::zero();
4219          instr->opcode =
4220             instr->opcode == aco_opcode::v_med3_f32 ? aco_opcode::v_add_f32 : aco_opcode::v_add_f16;
4221          instr->valu().clamp = true;
4222          instr->valu().abs = (uint8_t)instr->valu().abs[idx];
4223          instr->valu().neg = (uint8_t)instr->valu().neg[idx];
4224          instr->operands.pop_back();
4225       }
4226    } else {
4227       aco_opcode min, max, min3, max3, med3, minmax;
4228       bool some_gfx9_only;
4229       if (get_minmax_info(instr->opcode, &min, &max, &min3, &max3, &med3, &minmax,
4230                           &some_gfx9_only) &&
4231           (!some_gfx9_only || ctx.program->gfx_level >= GFX9)) {
4232          if (combine_minmax(ctx, instr, instr->opcode == min ? max : min,
4233                             instr->opcode == min ? min3 : max3, minmax)) {
4234          } else {
4235             combine_clamp(ctx, instr, min, max, med3);
4236          }
4237       }
4238    }
4239 }
4240 
4241 struct remat_entry {
4242    Instruction* instr;
4243    uint32_t block;
4244 };
4245 
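     /* Matches single-operand p_parallelcopies of a constant into a temporary;
      * these are the copies that rematerialize_constants() re-emits per block. */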
4246 inline bool
4247 is_constant(Instruction* instr)
4248 {
4249    if (instr->opcode != aco_opcode::p_parallelcopy || instr->operands.size() != 1)
4250       return false;
4251 
4252    return instr->operands[0].isConstant() && instr->definitions[0].isTemp();
4253 }
4254 
4255 void
4256 remat_constants_instr(opt_ctx& ctx, aco::map<Temp, remat_entry>& constants, Instruction* instr,
4257                       uint32_t block_idx)
4258 {
4259    for (Operand& op : instr->operands) {
4260       if (!op.isTemp())
4261          continue;
4262 
4263       auto it = constants.find(op.getTemp());
4264       if (it == constants.end())
4265          continue;
4266 
4267       /* Check if we already emitted the same constant in this block. */
4268       if (it->second.block != block_idx) {
4269          /* Rematerialize the constant. */
4270          Builder bld(ctx.program, &ctx.instructions);
4271          Operand const_op = it->second.instr->operands[0];
4272          it->second.instr = bld.copy(bld.def(op.regClass()), const_op);
4273          it->second.block = block_idx;
4274          ctx.uses.push_back(0);
4275          ctx.info.push_back(ctx.info[op.tempId()]);
4276       }
4277 
4278       /* Use the rematerialized constant and update information about latest use. */
4279       if (op.getTemp() != it->second.instr->definitions[0].getTemp()) {
4280          ctx.uses[op.tempId()]--;
4281          op.setTemp(it->second.instr->definitions[0].getTemp());
4282          ctx.uses[op.tempId()]++;
4283       }
4284    }
4285 }
4286 
4287 /**
4288  * This pass implements a simple constant rematerialization.
4289  * As common subexpression elimination (CSE) might increase the live-ranges
4290  * of loaded constants over large distances, this pass splits the live-ranges
4291  * again by re-emitting constants in every basic block.
4292  */
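     /* Illustrative example (pseudo-IR, not exact ACO syntax):
      *    BB0:  s1: %c = p_parallelcopy 0x4f800000
      *    ...
      *    BB7:  v1: %x = v_mul_f32 %c, %y
      * becomes
      *    BB7:  s1: %c2 = p_parallelcopy 0x4f800000
      *          v1: %x  = v_mul_f32 %c2, %y
      * so %c no longer needs to stay live across BB1..BB6. */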
4293 void
4294 rematerialize_constants(opt_ctx& ctx)
4295 {
4296    aco::monotonic_buffer_resource memory(1024);
4297    aco::map<Temp, remat_entry> constants(memory);
4298 
4299    for (Block& block : ctx.program->blocks) {
4300       if (block.logical_idom == -1)
4301          continue;
4302 
4303       if (block.logical_idom == (int)block.index)
4304          constants.clear();
4305 
4306       ctx.instructions.reserve(block.instructions.size());
4307 
4308       for (aco_ptr<Instruction>& instr : block.instructions) {
4309          if (is_dead(ctx.uses, instr.get()))
4310             continue;
4311 
4312          if (is_constant(instr.get())) {
4313             Temp tmp = instr->definitions[0].getTemp();
4314             constants[tmp] = {instr.get(), block.index};
4315          } else if (!is_phi(instr)) {
4316             remat_constants_instr(ctx, constants, instr.get(), block.index);
4317          }
4318 
4319          ctx.instructions.emplace_back(instr.release());
4320       }
4321 
4322       block.instructions = std::move(ctx.instructions);
4323    }
4324 }
4325 
4326 bool
4327 to_uniform_bool_instr(opt_ctx& ctx, aco_ptr<Instruction>& instr)
4328 {
4329    /* Check every operand to make sure they are suitable. */
4330    for (Operand& op : instr->operands) {
4331       if (!op.isTemp())
4332          return false;
4333       if (!ctx.info[op.tempId()].is_uniform_bool() && !ctx.info[op.tempId()].is_uniform_bitwise())
4334          return false;
4335    }
4336 
4337    switch (instr->opcode) {
4338    case aco_opcode::s_and_b32:
4339    case aco_opcode::s_and_b64: instr->opcode = aco_opcode::s_and_b32; break;
4340    case aco_opcode::s_or_b32:
4341    case aco_opcode::s_or_b64: instr->opcode = aco_opcode::s_or_b32; break;
4342    case aco_opcode::s_xor_b32:
4343    case aco_opcode::s_xor_b64: instr->opcode = aco_opcode::s_absdiff_i32; break;
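        /* For operands restricted to 0/1, |a - b| equals a ^ b. */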
4344    default:
4345       /* Don't transform other instructions. They are very unlikely to appear here. */
4346       return false;
4347    }
4348 
4349    for (Operand& op : instr->operands) {
4350       ctx.uses[op.tempId()]--;
4351 
4352       if (ctx.info[op.tempId()].is_uniform_bool()) {
4353          /* Just use the uniform boolean temp. */
4354          op.setTemp(ctx.info[op.tempId()].temp);
4355       } else if (ctx.info[op.tempId()].is_uniform_bitwise()) {
4356          /* Use the SCC definition of the predecessor instruction.
4357           * This allows the predecessor to get picked up by the same optimization (if it has no
4358           * divergent users), and it also makes sure that the current instruction will keep working
4359           * even if the predecessor won't be transformed.
4360           */
4361          Instruction* pred_instr = ctx.info[op.tempId()].instr;
4362          assert(pred_instr->definitions.size() >= 2);
4363          assert(pred_instr->definitions[1].isFixed() &&
4364                 pred_instr->definitions[1].physReg() == scc);
4365          op.setTemp(pred_instr->definitions[1].getTemp());
4366       } else {
4367          unreachable("Invalid operand on uniform bitwise instruction.");
4368       }
4369 
4370       ctx.uses[op.tempId()]++;
4371    }
4372 
4373    instr->definitions[0].setTemp(Temp(instr->definitions[0].tempId(), s1));
4374    ctx.program->temp_rc[instr->definitions[0].tempId()] = s1;
4375    assert(instr->operands[0].regClass() == s1);
4376    assert(instr->operands[1].regClass() == s1);
4377    return true;
4378 }
4379 
4380 void
4381 select_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
4382 {
4383    const uint32_t threshold = 4;
4384 
4385    if (is_dead(ctx.uses, instr.get())) {
4386       instr.reset();
4387       return;
4388    }
4389 
4390    /* convert split_vector into a copy or extract_vector if only one definition is ever used */
4391    if (instr->opcode == aco_opcode::p_split_vector) {
4392       unsigned num_used = 0;
4393       unsigned idx = 0;
4394       unsigned split_offset = 0;
4395       for (unsigned i = 0, offset = 0; i < instr->definitions.size();
4396            offset += instr->definitions[i++].bytes()) {
4397          if (ctx.uses[instr->definitions[i].tempId()]) {
4398             num_used++;
4399             idx = i;
4400             split_offset = offset;
4401          }
4402       }
4403       bool done = false;
4404       if (num_used == 1 && ctx.info[instr->operands[0].tempId()].is_vec() &&
4405           ctx.uses[instr->operands[0].tempId()] == 1) {
4406          Instruction* vec = ctx.info[instr->operands[0].tempId()].instr;
4407 
4408          unsigned off = 0;
4409          Operand op;
4410          for (Operand& vec_op : vec->operands) {
4411             if (off == split_offset) {
4412                op = vec_op;
4413                break;
4414             }
4415             off += vec_op.bytes();
4416          }
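              /* off can only equal the full vector size if no operand starts
               * exactly at split_offset; in that case a single operand cannot
               * be forwarded. */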
4417          if (off != instr->operands[0].bytes() && op.bytes() == instr->definitions[idx].bytes()) {
4418             ctx.uses[instr->operands[0].tempId()]--;
4419             for (Operand& vec_op : vec->operands) {
4420                if (vec_op.isTemp())
4421                   ctx.uses[vec_op.tempId()]--;
4422             }
4423             if (op.isTemp())
4424                ctx.uses[op.tempId()]++;
4425 
4426             aco_ptr<Instruction> copy{
4427                create_instruction(aco_opcode::p_parallelcopy, Format::PSEUDO, 1, 1)};
4428             copy->operands[0] = op;
4429             copy->definitions[0] = instr->definitions[idx];
4430             instr = std::move(copy);
4431 
4432             done = true;
4433          }
4434       }
4435 
4436       if (!done && num_used == 1 &&
4437           instr->operands[0].bytes() % instr->definitions[idx].bytes() == 0 &&
4438           split_offset % instr->definitions[idx].bytes() == 0) {
4439          aco_ptr<Instruction> extract{
4440             create_instruction(aco_opcode::p_extract_vector, Format::PSEUDO, 2, 1)};
4441          extract->operands[0] = instr->operands[0];
4442          extract->operands[1] =
4443             Operand::c32((uint32_t)split_offset / instr->definitions[idx].bytes());
4444          extract->definitions[0] = instr->definitions[idx];
4445          instr = std::move(extract);
4446       }
4447    }
4448 
4449    mad_info* mad_info = NULL;
4450    if (!instr->definitions.empty() && ctx.info[instr->definitions[0].tempId()].is_mad()) {
4451       mad_info = &ctx.mad_infos[ctx.info[instr->definitions[0].tempId()].val];
4452       /* re-check mad instructions */
4453       if (ctx.uses[mad_info->mul_temp_id] && mad_info->add_instr) {
4454          ctx.uses[mad_info->mul_temp_id]++;
4455          if (instr->operands[0].isTemp())
4456             ctx.uses[instr->operands[0].tempId()]--;
4457          if (instr->operands[1].isTemp())
4458             ctx.uses[instr->operands[1].tempId()]--;
4459          instr.swap(mad_info->add_instr);
4460          mad_info = NULL;
4461       }
4462       /* check literals */
4463       else if (!instr->isDPP() && !instr->isVOP3P() && instr->opcode != aco_opcode::v_fma_f64 &&
4464                instr->opcode != aco_opcode::v_mad_legacy_f32 &&
4465                instr->opcode != aco_opcode::v_fma_legacy_f32) {
4466          /* FMA can only take literals on GFX10+ */
4467          if ((instr->opcode == aco_opcode::v_fma_f32 || instr->opcode == aco_opcode::v_fma_f16) &&
4468              ctx.program->gfx_level < GFX10)
4469             return;
4470          /* There are no v_fmaak_legacy_f16/v_fmamk_legacy_f16 and on chips where VOP3 can take
4471           * literals (GFX10+), these instructions don't exist.
4472           */
4473          if (instr->opcode == aco_opcode::v_fma_legacy_f16)
4474             return;
4475 
4476          uint32_t literal_mask = 0;
4477          uint32_t fp16_mask = 0;
4478          uint32_t sgpr_mask = 0;
4479          uint32_t vgpr_mask = 0;
4480          uint32_t literal_uses = UINT32_MAX;
4481          uint32_t literal_value = 0;
4482 
4483          /* Iterate in reverse to prefer v_madak/v_fmaak. */
4484          for (int i = 2; i >= 0; i--) {
4485             Operand& op = instr->operands[i];
4486             if (!op.isTemp())
4487                continue;
4488             if (ctx.info[op.tempId()].is_literal(get_operand_size(instr, i))) {
4489                uint32_t new_literal = ctx.info[op.tempId()].val;
4490                float value = uif(new_literal);
4491                uint16_t fp16_val = _mesa_float_to_half(value);
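                    /* fp16 denormals have a zero exponent field and a non-zero
                     * mantissa, i.e. (bits & 0x7fff) lies in [1, 0x3ff]. */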
4492                bool is_denorm = (fp16_val & 0x7fff) != 0 && (fp16_val & 0x7fff) <= 0x3ff;
4493                if (_mesa_half_to_float(fp16_val) == value &&
4494                    (!is_denorm || (ctx.fp_mode.denorm16_64 & fp_denorm_keep_in)))
4495                   fp16_mask |= 1 << i;
4496 
4497                if (!literal_mask || literal_value == new_literal) {
4498                   literal_value = new_literal;
4499                   literal_uses = MIN2(literal_uses, ctx.uses[op.tempId()]);
4500                   literal_mask |= 1 << i;
4501                   continue;
4502                }
4503             }
4504             sgpr_mask |= op.isOfType(RegType::sgpr) << i;
4505             vgpr_mask |= op.isOfType(RegType::vgpr) << i;
4506          }
4507 
4508          /* The constant bus limitations before GFX10 disallow SGPRs. */
4509          if (sgpr_mask && ctx.program->gfx_level < GFX10)
4510             literal_mask = 0;
4511 
4512          /* Encoding needs a vgpr. */
4513          if (!vgpr_mask)
4514             literal_mask = 0;
4515 
4516          /* v_madmk/v_fmamk needs a vgpr in the third source. */
4517          if (!(literal_mask & 0b100) && !(vgpr_mask & 0b100))
4518             literal_mask = 0;
4519 
4520          /* opsel (GFX11+ only) is the only modifier supported by fmamk/fmaak. */
4521          if (instr->valu().abs || instr->valu().neg || instr->valu().omod || instr->valu().clamp ||
4522              (instr->valu().opsel && ctx.program->gfx_level < GFX11))
4523             literal_mask = 0;
4524 
4525          if (instr->valu().opsel & ~vgpr_mask)
4526             literal_mask = 0;
4527 
4528          /* We can't use three unique fp16 literals */
4529          if (fp16_mask == 0b111)
4530             fp16_mask = 0b11;
4531 
4532          if ((instr->opcode == aco_opcode::v_fma_f32 ||
4533               (instr->opcode == aco_opcode::v_mad_f32 && !instr->definitions[0].isPrecise())) &&
4534              !instr->valu().omod && ctx.program->gfx_level >= GFX10 &&
4535              util_bitcount(fp16_mask) > std::max<uint32_t>(util_bitcount(literal_mask), 1)) {
4536             assert(ctx.program->dev.fused_mad_mix);
4537             u_foreach_bit (i, fp16_mask)
4538                ctx.uses[instr->operands[i].tempId()]--;
4539             mad_info->fp16_mask = fp16_mask;
4540             return;
4541          }
4542 
4543          /* Limit the number of literals to apply to not increase the code
4544           * size too much, but always apply literals for v_mad->v_madak
4545           * because both instructions are 64-bit and this doesn't increase
4546           * code size.
4547           * TODO: try to apply the literals earlier to lower the number of
4548           * uses below threshold
4549           */
4550          if (literal_mask && (literal_uses < threshold || (literal_mask & 0b100))) {
4551             u_foreach_bit (i, literal_mask)
4552                ctx.uses[instr->operands[i].tempId()]--;
4553             mad_info->literal_mask = literal_mask;
4554             return;
4555          }
4556       }
4557    }
4558 
4559    /* Mark SCC needed, so the uniform boolean transformation won't swap the definitions
4560     * when it isn't beneficial */
4561    if (instr->isBranch() && instr->operands.size() && instr->operands[0].isTemp() &&
4562        instr->operands[0].isFixed() && instr->operands[0].physReg() == scc) {
4563       ctx.info[instr->operands[0].tempId()].set_scc_needed();
4564       return;
4565    } else if ((instr->opcode == aco_opcode::s_cselect_b64 ||
4566                instr->opcode == aco_opcode::s_cselect_b32) &&
4567               instr->operands[2].isTemp()) {
4568       ctx.info[instr->operands[2].tempId()].set_scc_needed();
4569    }
4570 
4571    /* check for literals */
4572    if (!instr->isSALU() && !instr->isVALU())
4573       return;
4574 
4575    /* Transform uniform bitwise boolean operations to 32-bit when there are no divergent uses. */
4576    if (instr->definitions.size() && ctx.uses[instr->definitions[0].tempId()] == 0 &&
4577        ctx.info[instr->definitions[0].tempId()].is_uniform_bitwise()) {
4578       bool transform_done = to_uniform_bool_instr(ctx, instr);
4579 
4580       if (transform_done && !ctx.info[instr->definitions[1].tempId()].is_scc_needed()) {
4581          /* Swap the two definition IDs in order to avoid overusing the SCC.
4582           * This reduces extra moves generated by RA. */
4583          uint32_t def0_id = instr->definitions[0].getTemp().id();
4584          uint32_t def1_id = instr->definitions[1].getTemp().id();
4585          instr->definitions[0].setTemp(Temp(def1_id, s1));
4586          instr->definitions[1].setTemp(Temp(def0_id, s1));
4587       }
4588 
4589       return;
4590    }
4591 
4592    /* This optimization is done late in order to be able to apply otherwise
4593     * unsafe optimizations such as the inverse comparison optimization.
4594     */
4595    if (instr->opcode == aco_opcode::s_and_b32 || instr->opcode == aco_opcode::s_and_b64) {
4596       if (instr->operands[0].isTemp() && fixed_to_exec(instr->operands[1]) &&
4597           ctx.uses[instr->operands[0].tempId()] == 1 &&
4598           ctx.uses[instr->definitions[1].tempId()] == 0 &&
4599           can_eliminate_and_exec(ctx, instr->operands[0].getTemp(), instr->pass_flags)) {
4600          ctx.uses[instr->operands[0].tempId()]--;
4601          ctx.info[instr->operands[0].tempId()].instr->definitions[0].setTemp(
4602             instr->definitions[0].getTemp());
4603          instr.reset();
4604          return;
4605       }
4606    }
4607 
4608    /* Combine DPP copies into VALU. This should be done after creating MAD/FMA. */
4609    if (instr->isVALU() && !instr->isDPP()) {
4610       for (unsigned i = 0; i < instr->operands.size(); i++) {
4611          if (!instr->operands[i].isTemp())
4612             continue;
4613          ssa_info info = ctx.info[instr->operands[i].tempId()];
4614 
4615          if (!info.is_dpp() || info.instr->pass_flags != instr->pass_flags)
4616             continue;
4617 
4618          /* We won't eliminate the DPP mov if the operand is used twice */
4619          bool op_used_twice = false;
4620          for (unsigned j = 0; j < instr->operands.size(); j++)
4621             op_used_twice |= i != j && instr->operands[i] == instr->operands[j];
4622          if (op_used_twice)
4623             continue;
4624 
4625          if (i != 0) {
4626             if (!can_swap_operands(instr, &instr->opcode, 0, i))
4627                continue;
4628             instr->valu().swapOperands(0, i);
4629          }
4630 
4631          if (!can_use_DPP(ctx.program->gfx_level, instr, info.is_dpp8()))
4632             continue;
4633 
4634          bool dpp8 = info.is_dpp8();
4635          bool input_mods = can_use_input_modifiers(ctx.program->gfx_level, instr->opcode, 0) &&
4636                            get_operand_size(instr, 0) == 32;
4637          bool mov_uses_mods = info.instr->valu().neg[0] || info.instr->valu().abs[0];
4638          if (((dpp8 && ctx.program->gfx_level < GFX11) || !input_mods) && mov_uses_mods)
4639             continue;
4640 
4641          convert_to_DPP(ctx.program->gfx_level, instr, dpp8);
4642 
4643          if (dpp8) {
4644             DPP8_instruction* dpp = &instr->dpp8();
4645             dpp->lane_sel = info.instr->dpp8().lane_sel;
4646             dpp->fetch_inactive = info.instr->dpp8().fetch_inactive;
4647             if (mov_uses_mods)
4648                instr->format = asVOP3(instr->format);
4649          } else {
4650             DPP16_instruction* dpp = &instr->dpp16();
4651             dpp->dpp_ctrl = info.instr->dpp16().dpp_ctrl;
4652             dpp->bound_ctrl = info.instr->dpp16().bound_ctrl;
4653             dpp->fetch_inactive = info.instr->dpp16().fetch_inactive;
4654          }
4655 
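              /* Fold the DPP mov's input modifiers into operand 0; an abs
               * already present on this instruction makes the mov's neg
               * irrelevant. */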
4656          instr->valu().neg[0] ^= info.instr->valu().neg[0] && !instr->valu().abs[0];
4657          instr->valu().abs[0] |= info.instr->valu().abs[0];
4658 
4659          if (--ctx.uses[info.instr->definitions[0].tempId()])
4660             ctx.uses[info.instr->operands[0].tempId()]++;
4661          instr->operands[0].setTemp(info.instr->operands[0].getTemp());
4662          break;
4663       }
4664    }
4665 
4666    /* Use v_fma_mix for f2f32/f2f16 if it has higher throughput.
4667     * Do this late to not disturb other optimizations.
4668     */
4669    if ((instr->opcode == aco_opcode::v_cvt_f32_f16 || instr->opcode == aco_opcode::v_cvt_f16_f32) &&
4670        ctx.program->gfx_level >= GFX11 && ctx.program->wave_size == 64 && !instr->valu().omod &&
4671        !instr->isDPP()) {
4672       bool is_f2f16 = instr->opcode == aco_opcode::v_cvt_f16_f32;
4673       Instruction* fma = create_instruction(
4674          is_f2f16 ? aco_opcode::v_fma_mixlo_f16 : aco_opcode::v_fma_mix_f32, Format::VOP3P, 3, 1);
4675       fma->definitions[0] = instr->definitions[0];
4676       fma->operands[0] = instr->operands[0];
4677       fma->valu().opsel_hi[0] = !is_f2f16;
4678       fma->valu().opsel_lo[0] = instr->valu().opsel[0];
4679       fma->valu().clamp = instr->valu().clamp;
4680       fma->valu().abs[0] = instr->valu().abs[0];
4681       fma->valu().neg[0] = instr->valu().neg[0];
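           /* Compute src0 * 1.0 + (-0.0); adding negative zero keeps the result
            * unchanged, including the sign of a zero result. */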
4682       fma->operands[1] = Operand::c32(fui(1.0f));
4683       fma->operands[2] = Operand::zero();
4684       fma->valu().neg[2] = true;
4685       instr.reset(fma);
4686       ctx.info[instr->definitions[0].tempId()].label = 0;
4687    }
4688 
4689    if (instr->isSDWA() || (instr->isVOP3() && ctx.program->gfx_level < GFX10) ||
4690        (instr->isVOP3P() && ctx.program->gfx_level < GFX10))
4691       return; /* some encodings can't ever take literals */
4692 
4693    /* we do not apply the literals yet as we don't know if it is profitable */
4694    Operand current_literal(s1);
4695 
4696    unsigned literal_id = 0;
4697    unsigned literal_uses = UINT32_MAX;
4698    Operand literal(s1);
4699    unsigned num_operands = 1;
4700    if (instr->isSALU() || (ctx.program->gfx_level >= GFX10 &&
4701                            (can_use_VOP3(ctx, instr) || instr->isVOP3P()) && !instr->isDPP()))
4702       num_operands = instr->operands.size();
4703    /* catch VOP2 with a 3rd SGPR operand (e.g. v_cndmask_b32, v_addc_co_u32) */
4704    else if (instr->isVALU() && instr->operands.size() >= 3)
4705       return;
4706 
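        /* Track up to two distinct SGPR operands; they count against the VALU
         * constant bus limit checked below. */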
4707    unsigned sgpr_ids[2] = {0, 0};
4708    bool is_literal_sgpr = false;
4709    uint32_t mask = 0;
4710 
4711    /* choose a literal to apply */
4712    for (unsigned i = 0; i < num_operands; i++) {
4713       Operand op = instr->operands[i];
4714       unsigned bits = get_operand_size(instr, i);
4715 
4716       if (instr->isVALU() && op.isTemp() && op.getTemp().type() == RegType::sgpr &&
4717           op.tempId() != sgpr_ids[0])
4718          sgpr_ids[!!sgpr_ids[0]] = op.tempId();
4719 
4720       if (op.isLiteral()) {
4721          current_literal = op;
4722          continue;
4723       } else if (!op.isTemp() || !ctx.info[op.tempId()].is_literal(bits)) {
4724          continue;
4725       }
4726 
4727       if (!alu_can_accept_constant(instr, i))
4728          continue;
4729 
4730       if (ctx.uses[op.tempId()] < literal_uses) {
4731          is_literal_sgpr = op.getTemp().type() == RegType::sgpr;
4732          mask = 0;
4733          literal = Operand::c32(ctx.info[op.tempId()].val);
4734          literal_uses = ctx.uses[op.tempId()];
4735          literal_id = op.tempId();
4736       }
4737 
4738       mask |= (op.tempId() == literal_id) << i;
4739    }
4740 
4741    /* don't go over the constant bus limit */
4742    bool is_shift64 = instr->opcode == aco_opcode::v_lshlrev_b64_e64 ||
4743                      instr->opcode == aco_opcode::v_lshlrev_b64 ||
4744                      instr->opcode == aco_opcode::v_lshrrev_b64 ||
4745                      instr->opcode == aco_opcode::v_ashrrev_i64;
4746    unsigned const_bus_limit = instr->isVALU() ? 1 : UINT32_MAX;
4747    if (ctx.program->gfx_level >= GFX10 && !is_shift64)
4748       const_bus_limit = 2;
4749 
4750    unsigned num_sgprs = !!sgpr_ids[0] + !!sgpr_ids[1];
4751    if (num_sgprs == const_bus_limit && !is_literal_sgpr)
4752       return;
4753 
4754    if (literal_id && literal_uses < threshold &&
4755        (current_literal.isUndefined() ||
4756         (current_literal.size() == literal.size() &&
4757          current_literal.constantValue() == literal.constantValue()))) {
4758       /* mark the literal to be applied */
4759       while (mask) {
4760          unsigned i = u_bit_scan(&mask);
4761          if (instr->operands[i].isTemp() && instr->operands[i].tempId() == literal_id)
4762             ctx.uses[instr->operands[i].tempId()]--;
4763       }
4764    }
4765 }
4766 
4767 static aco_opcode
4768 sopk_opcode_for_sopc(aco_opcode opcode)
4769 {
4770 #define CTOK(op)                                                                                   \
4771    case aco_opcode::s_cmp_##op##_i32: return aco_opcode::s_cmpk_##op##_i32;                        \
4772    case aco_opcode::s_cmp_##op##_u32: return aco_opcode::s_cmpk_##op##_u32;
4773    switch (opcode) {
4774       CTOK(eq)
4775       CTOK(lg)
4776       CTOK(gt)
4777       CTOK(ge)
4778       CTOK(lt)
4779       CTOK(le)
4780    default: return aco_opcode::num_opcodes;
4781    }
4782 #undef CTOK
4783 }
4784 
4785 static bool
4786 sopc_is_signed(aco_opcode opcode)
4787 {
4788 #define SOPC(op)                                                                                   \
4789    case aco_opcode::s_cmp_##op##_i32: return true;                                                 \
4790    case aco_opcode::s_cmp_##op##_u32: return false;
4791    switch (opcode) {
4792       SOPC(eq)
4793       SOPC(lg)
4794       SOPC(gt)
4795       SOPC(ge)
4796       SOPC(lt)
4797       SOPC(le)
4798    default: unreachable("Not a valid SOPC instruction.");
4799    }
4800 #undef SOPC
4801 }
4802 
4803 static aco_opcode
4804 sopc_32_swapped(aco_opcode opcode)
4805 {
4806 #define SOPC(op1, op2)                                                                             \
4807    case aco_opcode::s_cmp_##op1##_i32: return aco_opcode::s_cmp_##op2##_i32;                       \
4808    case aco_opcode::s_cmp_##op1##_u32: return aco_opcode::s_cmp_##op2##_u32;
4809    switch (opcode) {
4810       SOPC(eq, eq)
4811       SOPC(lg, lg)
4812       SOPC(gt, lt)
4813       SOPC(ge, le)
4814       SOPC(lt, gt)
4815       SOPC(le, ge)
4816    default: return aco_opcode::num_opcodes;
4817    }
4818 #undef SOPC
4819 }
4820 
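     /* Converting SOPC with a small literal to SOPK drops the trailing 32-bit
      * literal dword, e.g. (assuming the immediate fits into 16 bits):
      *    s_cmp_eq_u32 s0, 0x1234  ->  s_cmpk_eq_u32 s0, 0x1234
      */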
4821 static void
4822 try_convert_sopc_to_sopk(aco_ptr<Instruction>& instr)
4823 {
4824    if (sopk_opcode_for_sopc(instr->opcode) == aco_opcode::num_opcodes)
4825       return;
4826 
4827    if (instr->operands[0].isLiteral()) {
4828       std::swap(instr->operands[0], instr->operands[1]);
4829       instr->opcode = sopc_32_swapped(instr->opcode);
4830    }
4831 
4832    if (!instr->operands[1].isLiteral())
4833       return;
4834 
4835    if (instr->operands[0].isFixed() && instr->operands[0].physReg() >= 128)
4836       return;
4837 
4838    uint32_t value = instr->operands[1].constantValue();
4839 
4840    const uint32_t i16_mask = 0xffff8000u;
4841 
4842    bool value_is_i16 = (value & i16_mask) == 0 || (value & i16_mask) == i16_mask;
4843    bool value_is_u16 = !(value & 0xffff0000u);
4844 
4845    if (!value_is_i16 && !value_is_u16)
4846       return;
4847 
4848    if (!value_is_i16 && sopc_is_signed(instr->opcode)) {
4849       if (instr->opcode == aco_opcode::s_cmp_lg_i32)
4850          instr->opcode = aco_opcode::s_cmp_lg_u32;
4851       else if (instr->opcode == aco_opcode::s_cmp_eq_i32)
4852          instr->opcode = aco_opcode::s_cmp_eq_u32;
4853       else
4854          return;
4855    } else if (!value_is_u16 && !sopc_is_signed(instr->opcode)) {
4856       if (instr->opcode == aco_opcode::s_cmp_lg_u32)
4857          instr->opcode = aco_opcode::s_cmp_lg_i32;
4858       else if (instr->opcode == aco_opcode::s_cmp_eq_u32)
4859          instr->opcode = aco_opcode::s_cmp_eq_i32;
4860       else
4861          return;
4862    }
4863 
4864    instr->format = Format::SOPK;
4865    SALU_instruction* instr_sopk = &instr->salu();
4866 
4867    instr_sopk->imm = instr_sopk->operands[1].constantValue() & 0xffff;
4868    instr_sopk->opcode = sopk_opcode_for_sopc(instr_sopk->opcode);
4869    instr_sopk->operands.pop_back();
4870 }
4871 
4872 static void
4873 opt_fma_mix_acc(opt_ctx& ctx, aco_ptr<Instruction>& instr)
4874 {
4875    /* fma_mix is only dual-issued on GFX11 if dst and acc types match */
4876    bool f2f16 = instr->opcode == aco_opcode::v_fma_mixlo_f16;
4877 
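        /* For v_fma_mix*, opsel_hi[i] selects whether source i is read as fp16
         * (opsel_lo[i] then picks the half) or as fp32, so opsel_hi[2] == f2f16
         * means the accumulator type already matches the destination type. */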
4878    if (instr->valu().opsel_hi[2] == f2f16 || instr->isDPP())
4879       return;
4880 
4881    bool is_add = false;
4882    for (unsigned i = 0; i < 2; i++) {
4883       uint32_t one = instr->valu().opsel_hi[i] ? 0x3800 : 0x3f800000;
4884       is_add = instr->operands[i].constantEquals(one) && !instr->valu().neg[i] &&
4885                !instr->valu().opsel_lo[i];
4886       if (is_add) {
4887          instr->valu().swapOperands(0, i);
4888          break;
4889       }
4890    }
4891 
4892    if (is_add && instr->valu().opsel_hi[1] == f2f16) {
4893       instr->valu().swapOperands(1, 2);
4894       return;
4895    }
4896 
4897    unsigned literal_count = instr->operands[0].isLiteral() + instr->operands[1].isLiteral() +
4898                             instr->operands[2].isLiteral();
4899 
4900    if (!f2f16 || literal_count > 1)
4901       return;
4902 
4903    /* try to convert constant operand to fp16 */
4904    for (unsigned i = 2 - is_add; i < 3; i++) {
4905       if (!instr->operands[i].isConstant())
4906          continue;
4907 
4908       float value = uif(instr->operands[i].constantValue());
4909       uint16_t fp16_val = _mesa_float_to_half(value);
4910       bool is_denorm = (fp16_val & 0x7fff) != 0 && (fp16_val & 0x7fff) <= 0x3ff;
4911 
4912       if (_mesa_half_to_float(fp16_val) != value ||
4913           (is_denorm && !(ctx.fp_mode.denorm16_64 & fp_denorm_keep_in)))
4914          continue;
4915 
4916       instr->valu().swapOperands(i, 2);
4917 
4918       Operand op16 = Operand::c16(fp16_val);
4919       assert(!op16.isLiteral() || instr->operands[2].isLiteral());
4920 
4921       instr->operands[2] = op16;
4922       instr->valu().opsel_lo[2] = false;
4923       instr->valu().opsel_hi[2] = true;
4924       return;
4925    }
4926 }
4927 
4928 void
4929 apply_literals(opt_ctx& ctx, aco_ptr<Instruction>& instr)
4930 {
4931    /* Cleanup Dead Instructions */
4932    if (!instr)
4933       return;
4934 
4935    /* apply literals on MAD */
4936    if (!instr->definitions.empty() && ctx.info[instr->definitions[0].tempId()].is_mad()) {
4937       mad_info* info = &ctx.mad_infos[ctx.info[instr->definitions[0].tempId()].val];
4938       const bool madak = (info->literal_mask & 0b100);
4939       bool has_dead_literal = false;
4940       u_foreach_bit (i, info->literal_mask | info->fp16_mask)
4941          has_dead_literal |= ctx.uses[instr->operands[i].tempId()] == 0;
4942 
4943       if (has_dead_literal && info->fp16_mask) {
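              /* Switch to v_fma_mix and pack up to two fp16 constants into the
               * low and high halves of a single 32-bit literal, selected via
               * opsel_lo/opsel_hi. */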
4944          instr->format = Format::VOP3P;
4945          instr->opcode = aco_opcode::v_fma_mix_f32;
4946 
4947          uint32_t literal = 0;
4948          bool second = false;
4949          u_foreach_bit (i, info->fp16_mask) {
4950             float value = uif(ctx.info[instr->operands[i].tempId()].val);
4951             literal |= _mesa_float_to_half(value) << (second * 16);
4952             instr->valu().opsel_lo[i] = second;
4953             instr->valu().opsel_hi[i] = true;
4954             second = true;
4955          }
4956 
4957          for (unsigned i = 0; i < 3; i++) {
4958             if (info->fp16_mask & (1 << i))
4959                instr->operands[i] = Operand::literal32(literal);
4960          }
4961 
4962          ctx.instructions.emplace_back(std::move(instr));
4963          return;
4964       }
4965 
4966       if (has_dead_literal || madak) {
4967          aco_opcode new_op = madak ? aco_opcode::v_madak_f32 : aco_opcode::v_madmk_f32;
4968          if (instr->opcode == aco_opcode::v_fma_f32)
4969             new_op = madak ? aco_opcode::v_fmaak_f32 : aco_opcode::v_fmamk_f32;
4970          else if (instr->opcode == aco_opcode::v_mad_f16 ||
4971                   instr->opcode == aco_opcode::v_mad_legacy_f16)
4972             new_op = madak ? aco_opcode::v_madak_f16 : aco_opcode::v_madmk_f16;
4973          else if (instr->opcode == aco_opcode::v_fma_f16)
4974             new_op = madak ? aco_opcode::v_fmaak_f16 : aco_opcode::v_fmamk_f16;
4975 
4976          uint32_t literal = ctx.info[instr->operands[ffs(info->literal_mask) - 1].tempId()].val;
4977          instr->format = Format::VOP2;
4978          instr->opcode = new_op;
4979          for (unsigned i = 0; i < 3; i++) {
4980             if (info->literal_mask & (1 << i))
4981                instr->operands[i] = Operand::literal32(literal);
4982          }
4983          if (madak) { /* add literal -> madak */
4984             if (!instr->operands[1].isOfType(RegType::vgpr))
4985                instr->valu().swapOperands(0, 1);
4986          } else { /* mul literal -> madmk */
4987             if (!(info->literal_mask & 0b10))
4988                instr->valu().swapOperands(0, 1);
4989             instr->valu().swapOperands(1, 2);
4990          }
4991          ctx.instructions.emplace_back(std::move(instr));
4992          return;
4993       }
4994    }
4995 
4996    /* apply literals on other SALU/VALU */
4997    if (instr->isSALU() || instr->isVALU()) {
4998       for (unsigned i = 0; i < instr->operands.size(); i++) {
4999          Operand op = instr->operands[i];
5000          unsigned bits = get_operand_size(instr, i);
5001          if (op.isTemp() && ctx.info[op.tempId()].is_literal(bits) && ctx.uses[op.tempId()] == 0) {
5002             Operand literal = Operand::literal32(ctx.info[op.tempId()].val);
5003             instr->format = withoutDPP(instr->format);
5004             if (instr->isVALU() && i > 0 && instr->format != Format::VOP3P)
5005                instr->format = asVOP3(instr->format);
5006             instr->operands[i] = literal;
5007          }
5008       }
5009    }
5010 
5011    if (instr->isSOPC() && ctx.program->gfx_level < GFX12)
5012       try_convert_sopc_to_sopk(instr);
5013 
5014    if (instr->opcode == aco_opcode::v_fma_mixlo_f16 || instr->opcode == aco_opcode::v_fma_mix_f32)
5015       opt_fma_mix_acc(ctx, instr);
5016 
5017    ctx.instructions.emplace_back(std::move(instr));
5018 }
5019 
5020 } /* end namespace */
5021 
5022 void
5023 optimize(Program* program)
5024 {
5025    opt_ctx ctx;
5026    ctx.program = program;
5027    ctx.info = std::vector<ssa_info>(program->peekAllocationId());
5028 
5029    /* 1. Bottom-Up DAG pass (forward) to label all ssa-defs */
5030    for (Block& block : program->blocks) {
5031       ctx.fp_mode = block.fp_mode;
5032       for (aco_ptr<Instruction>& instr : block.instructions)
5033          label_instruction(ctx, instr);
5034    }
5035 
5036    ctx.uses = dead_code_analysis(program);
5037 
5038    /* 2. Rematerialize constants in every block. */
5039    rematerialize_constants(ctx);
5040 
5041    /* 3. Combine v_mad, omod, clamp and propagate sgpr on VALU instructions */
5042    for (Block& block : program->blocks) {
5043       ctx.fp_mode = block.fp_mode;
5044       for (aco_ptr<Instruction>& instr : block.instructions)
5045          combine_instruction(ctx, instr);
5046    }
5047 
5048    /* 4. Top-Down DAG pass (backward) to select instructions (includes DCE) */
5049    for (auto block_rit = program->blocks.rbegin(); block_rit != program->blocks.rend();
5050         ++block_rit) {
5051       Block* block = &(*block_rit);
5052       ctx.fp_mode = block->fp_mode;
5053       for (auto instr_rit = block->instructions.rbegin(); instr_rit != block->instructions.rend();
5054            ++instr_rit)
5055          select_instruction(ctx, *instr_rit);
5056    }
5057 
5058    /* 5. Add literals to instructions */
5059    for (Block& block : program->blocks) {
5060       ctx.instructions.reserve(block.instructions.size());
5061       ctx.fp_mode = block.fp_mode;
5062       for (aco_ptr<Instruction>& instr : block.instructions)
5063          apply_literals(ctx, instr);
5064       block.instructions = std::move(ctx.instructions);
5065    }
5066 }
5067 
5068 } // namespace aco
5069