• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright © 2018 Valve Corporation
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  *
23  */
24 
25 #include "aco_builder.h"
26 #include "aco_ir.h"
27 
28 #include "util/half_float.h"
29 #include "util/memstream.h"
30 
31 #include <algorithm>
32 #include <array>
33 #include <vector>
34 
35 namespace aco {
36 
37 #ifndef NDEBUG
38 void
perfwarn(Program * program,bool cond,const char * msg,Instruction * instr)39 perfwarn(Program* program, bool cond, const char* msg, Instruction* instr)
40 {
41    if (cond) {
42       char* out;
43       size_t outsize;
44       struct u_memstream mem;
45       u_memstream_open(&mem, &out, &outsize);
46       FILE* const memf = u_memstream_get(&mem);
47 
48       fprintf(memf, "%s: ", msg);
49       aco_print_instr(program->gfx_level, instr, memf);
50       u_memstream_close(&mem);
51 
52       aco_perfwarn(program, out);
53       free(out);
54 
55       if (debug_flags & DEBUG_PERFWARN)
56          exit(1);
57    }
58 }
59 #endif
60 
61 /**
62  * The optimizer works in 4 phases:
63  * (1) The first pass collects information for each ssa-def,
64  *     propagates reg->reg operands of the same type, inline constants
65  *     and neg/abs input modifiers.
66  * (2) The second pass combines instructions like mad, omod, clamp and
67  *     propagates sgpr's on VALU instructions.
68  *     This pass depends on information collected in the first pass.
69  * (3) The third pass goes backwards, and selects instructions,
70  *     i.e. decides if a mad instruction is profitable and eliminates dead code.
71  * (4) The fourth pass cleans up the sequence: literals get applied and dead
72  *     instructions are removed from the sequence.
73  */
74 
75 struct mad_info {
76    aco_ptr<Instruction> add_instr;
77    uint32_t mul_temp_id;
78    uint16_t literal_mask;
79    uint16_t fp16_mask;
80 
mad_infoaco::mad_info81    mad_info(aco_ptr<Instruction> instr, uint32_t id)
82        : add_instr(std::move(instr)), mul_temp_id(id), literal_mask(0), fp16_mask(0)
83    {}
84 };
85 
/* Single-bit flags that are OR-ed together in ssa_info::label.
 * Every enumerator is spelled with a 64-bit shift so that bit 31 and the
 * bits above it need no special treatment. Bits 11, 13 and 25 are unused. */
enum Label {
   label_vec = 1ull << 0,
   label_constant_32bit = 1ull << 1,
   /* label_{abs,neg,mul,omod2,omod4,omod5,clamp} are used for both 16 and
    * 32-bit operations but this shouldn't cause any issues because we don't
    * look through any conversions */
   label_abs = 1ull << 2,
   label_neg = 1ull << 3,
   label_mul = 1ull << 4,
   label_temp = 1ull << 5,
   label_literal = 1ull << 6,
   label_mad = 1ull << 7,
   label_omod2 = 1ull << 8,
   label_omod4 = 1ull << 9,
   label_omod5 = 1ull << 10,
   label_clamp = 1ull << 12,
   label_undefined = 1ull << 14,
   label_vcc = 1ull << 15,
   label_b2f = 1ull << 16,
   label_add_sub = 1ull << 17,
   label_bitwise = 1ull << 18,
   label_minmax = 1ull << 19,
   label_vopc = 1ull << 20,
   label_uniform_bool = 1ull << 21,
   label_constant_64bit = 1ull << 22,
   label_uniform_bitwise = 1ull << 23,
   label_scc_invert = 1ull << 24,
   label_scc_needed = 1ull << 26,
   label_b2i = 1ull << 27,
   label_fcanonicalize = 1ull << 28,
   label_constant_16bit = 1ull << 29,
   label_usedef = 1ull << 30, /* generic label */
   label_vop3p = 1ull << 31,
   label_canonicalized = 1ull << 32,
   label_extract = 1ull << 33,
   label_insert = 1ull << 34,
   label_dpp16 = 1ull << 35,
   label_dpp8 = 1ull << 36,
   label_f2f32 = 1ull << 37,
   label_f2f16 = 1ull << 38,
   label_split = 1ull << 39,
   label_subgroup_invocation = 1ull << 40,
};
129 
130 static constexpr uint64_t instr_usedef_labels =
131    label_vec | label_mul | label_add_sub | label_vop3p | label_bitwise | label_uniform_bitwise |
132    label_minmax | label_vopc | label_usedef | label_extract | label_dpp16 | label_dpp8 |
133    label_f2f32 | label_subgroup_invocation;
134 static constexpr uint64_t instr_mod_labels =
135    label_omod2 | label_omod4 | label_omod5 | label_clamp | label_insert | label_f2f16;
136 
137 static constexpr uint64_t instr_labels = instr_usedef_labels | instr_mod_labels | label_split;
138 static constexpr uint64_t temp_labels = label_abs | label_neg | label_temp | label_vcc | label_b2f |
139                                         label_uniform_bool | label_scc_invert | label_b2i |
140                                         label_fcanonicalize;
141 static constexpr uint32_t val_labels =
142    label_constant_32bit | label_constant_64bit | label_constant_16bit | label_literal | label_mad;
143 
144 static_assert((instr_labels & temp_labels) == 0, "labels cannot intersect");
145 static_assert((instr_labels & val_labels) == 0, "labels cannot intersect");
146 static_assert((temp_labels & val_labels) == 0, "labels cannot intersect");
147 
148 struct ssa_info {
149    uint64_t label;
150    union {
151       uint32_t val;
152       Temp temp;
153       Instruction* instr;
154    };
155 
ssa_infoaco::ssa_info156    ssa_info() : label(0) {}
157 
add_labelaco::ssa_info158    void add_label(Label new_label)
159    {
160       /* Since all the instr_usedef_labels use instr for the same thing
161        * (indicating the defining instruction), there is usually no need to
162        * clear any other instr labels. */
163       if (new_label & instr_usedef_labels)
164          label &= ~(instr_mod_labels | temp_labels | val_labels); /* instr, temp and val alias */
165 
166       if (new_label & instr_mod_labels) {
167          label &= ~instr_labels;
168          label &= ~(temp_labels | val_labels); /* instr, temp and val alias */
169       }
170 
171       if (new_label & temp_labels) {
172          label &= ~temp_labels;
173          label &= ~(instr_labels | val_labels); /* instr, temp and val alias */
174       }
175 
176       uint32_t const_labels =
177          label_literal | label_constant_32bit | label_constant_64bit | label_constant_16bit;
178       if (new_label & const_labels) {
179          label &= ~val_labels | const_labels;
180          label &= ~(instr_labels | temp_labels); /* instr, temp and val alias */
181       } else if (new_label & val_labels) {
182          label &= ~val_labels;
183          label &= ~(instr_labels | temp_labels); /* instr, temp and val alias */
184       }
185 
186       label |= new_label;
187    }
188 
set_vecaco::ssa_info189    void set_vec(Instruction* vec)
190    {
191       add_label(label_vec);
192       instr = vec;
193    }
194 
is_vecaco::ssa_info195    bool is_vec() { return label & label_vec; }
196 
set_constantaco::ssa_info197    void set_constant(amd_gfx_level gfx_level, uint64_t constant)
198    {
199       Operand op16 = Operand::c16(constant);
200       Operand op32 = Operand::get_const(gfx_level, constant, 4);
201       add_label(label_literal);
202       val = constant;
203 
204       /* check that no upper bits are lost in case of packed 16bit constants */
205       if (gfx_level >= GFX8 && !op16.isLiteral() &&
206           op16.constantValue16(true) == ((constant >> 16) & 0xffff))
207          add_label(label_constant_16bit);
208 
209       if (!op32.isLiteral())
210          add_label(label_constant_32bit);
211 
212       if (Operand::is_constant_representable(constant, 8))
213          add_label(label_constant_64bit);
214 
215       if (label & label_constant_64bit) {
216          val = Operand::c64(constant).constantValue();
217          if (val != constant)
218             label &= ~(label_literal | label_constant_16bit | label_constant_32bit);
219       }
220    }
221 
is_constantaco::ssa_info222    bool is_constant(unsigned bits)
223    {
224       switch (bits) {
225       case 8: return label & label_literal;
226       case 16: return label & label_constant_16bit;
227       case 32: return label & label_constant_32bit;
228       case 64: return label & label_constant_64bit;
229       }
230       return false;
231    }
232 
is_literalaco::ssa_info233    bool is_literal(unsigned bits)
234    {
235       bool is_lit = label & label_literal;
236       switch (bits) {
237       case 8: return false;
238       case 16: return is_lit && ~(label & label_constant_16bit);
239       case 32: return is_lit && ~(label & label_constant_32bit);
240       case 64: return false;
241       }
242       return false;
243    }
244 
is_constant_or_literalaco::ssa_info245    bool is_constant_or_literal(unsigned bits)
246    {
247       if (bits == 64)
248          return label & label_constant_64bit;
249       else
250          return label & label_literal;
251    }
252 
set_absaco::ssa_info253    void set_abs(Temp abs_temp)
254    {
255       add_label(label_abs);
256       temp = abs_temp;
257    }
258 
is_absaco::ssa_info259    bool is_abs() { return label & label_abs; }
260 
set_negaco::ssa_info261    void set_neg(Temp neg_temp)
262    {
263       add_label(label_neg);
264       temp = neg_temp;
265    }
266 
is_negaco::ssa_info267    bool is_neg() { return label & label_neg; }
268 
set_neg_absaco::ssa_info269    void set_neg_abs(Temp neg_abs_temp)
270    {
271       add_label((Label)((uint32_t)label_abs | (uint32_t)label_neg));
272       temp = neg_abs_temp;
273    }
274 
set_mulaco::ssa_info275    void set_mul(Instruction* mul)
276    {
277       add_label(label_mul);
278       instr = mul;
279    }
280 
is_mulaco::ssa_info281    bool is_mul() { return label & label_mul; }
282 
set_tempaco::ssa_info283    void set_temp(Temp tmp)
284    {
285       add_label(label_temp);
286       temp = tmp;
287    }
288 
is_tempaco::ssa_info289    bool is_temp() { return label & label_temp; }
290 
set_madaco::ssa_info291    void set_mad(uint32_t mad_info_idx)
292    {
293       add_label(label_mad);
294       val = mad_info_idx;
295    }
296 
is_madaco::ssa_info297    bool is_mad() { return label & label_mad; }
298 
set_omod2aco::ssa_info299    void set_omod2(Instruction* mul)
300    {
301       if (label & temp_labels)
302          return;
303       add_label(label_omod2);
304       instr = mul;
305    }
306 
is_omod2aco::ssa_info307    bool is_omod2() { return label & label_omod2; }
308 
set_omod4aco::ssa_info309    void set_omod4(Instruction* mul)
310    {
311       if (label & temp_labels)
312          return;
313       add_label(label_omod4);
314       instr = mul;
315    }
316 
is_omod4aco::ssa_info317    bool is_omod4() { return label & label_omod4; }
318 
set_omod5aco::ssa_info319    void set_omod5(Instruction* mul)
320    {
321       if (label & temp_labels)
322          return;
323       add_label(label_omod5);
324       instr = mul;
325    }
326 
is_omod5aco::ssa_info327    bool is_omod5() { return label & label_omod5; }
328 
set_clampaco::ssa_info329    void set_clamp(Instruction* med3)
330    {
331       if (label & temp_labels)
332          return;
333       add_label(label_clamp);
334       instr = med3;
335    }
336 
is_clampaco::ssa_info337    bool is_clamp() { return label & label_clamp; }
338 
set_f2f16aco::ssa_info339    void set_f2f16(Instruction* conv)
340    {
341       if (label & temp_labels)
342          return;
343       add_label(label_f2f16);
344       instr = conv;
345    }
346 
is_f2f16aco::ssa_info347    bool is_f2f16() { return label & label_f2f16; }
348 
set_undefinedaco::ssa_info349    void set_undefined() { add_label(label_undefined); }
350 
is_undefinedaco::ssa_info351    bool is_undefined() { return label & label_undefined; }
352 
set_vccaco::ssa_info353    void set_vcc(Temp vcc_val)
354    {
355       add_label(label_vcc);
356       temp = vcc_val;
357    }
358 
is_vccaco::ssa_info359    bool is_vcc() { return label & label_vcc; }
360 
set_b2faco::ssa_info361    void set_b2f(Temp b2f_val)
362    {
363       add_label(label_b2f);
364       temp = b2f_val;
365    }
366 
is_b2faco::ssa_info367    bool is_b2f() { return label & label_b2f; }
368 
set_add_subaco::ssa_info369    void set_add_sub(Instruction* add_sub_instr)
370    {
371       add_label(label_add_sub);
372       instr = add_sub_instr;
373    }
374 
is_add_subaco::ssa_info375    bool is_add_sub() { return label & label_add_sub; }
376 
set_bitwiseaco::ssa_info377    void set_bitwise(Instruction* bitwise_instr)
378    {
379       add_label(label_bitwise);
380       instr = bitwise_instr;
381    }
382 
is_bitwiseaco::ssa_info383    bool is_bitwise() { return label & label_bitwise; }
384 
set_uniform_bitwiseaco::ssa_info385    void set_uniform_bitwise() { add_label(label_uniform_bitwise); }
386 
is_uniform_bitwiseaco::ssa_info387    bool is_uniform_bitwise() { return label & label_uniform_bitwise; }
388 
set_minmaxaco::ssa_info389    void set_minmax(Instruction* minmax_instr)
390    {
391       add_label(label_minmax);
392       instr = minmax_instr;
393    }
394 
is_minmaxaco::ssa_info395    bool is_minmax() { return label & label_minmax; }
396 
set_vopcaco::ssa_info397    void set_vopc(Instruction* vopc_instr)
398    {
399       add_label(label_vopc);
400       instr = vopc_instr;
401    }
402 
is_vopcaco::ssa_info403    bool is_vopc() { return label & label_vopc; }
404 
set_scc_neededaco::ssa_info405    void set_scc_needed() { add_label(label_scc_needed); }
406 
is_scc_neededaco::ssa_info407    bool is_scc_needed() { return label & label_scc_needed; }
408 
set_scc_invertaco::ssa_info409    void set_scc_invert(Temp scc_inv)
410    {
411       add_label(label_scc_invert);
412       temp = scc_inv;
413    }
414 
is_scc_invertaco::ssa_info415    bool is_scc_invert() { return label & label_scc_invert; }
416 
set_uniform_boolaco::ssa_info417    void set_uniform_bool(Temp uniform_bool)
418    {
419       add_label(label_uniform_bool);
420       temp = uniform_bool;
421    }
422 
is_uniform_boolaco::ssa_info423    bool is_uniform_bool() { return label & label_uniform_bool; }
424 
set_b2iaco::ssa_info425    void set_b2i(Temp b2i_val)
426    {
427       add_label(label_b2i);
428       temp = b2i_val;
429    }
430 
is_b2iaco::ssa_info431    bool is_b2i() { return label & label_b2i; }
432 
set_usedefaco::ssa_info433    void set_usedef(Instruction* label_instr)
434    {
435       add_label(label_usedef);
436       instr = label_instr;
437    }
438 
is_usedefaco::ssa_info439    bool is_usedef() { return label & label_usedef; }
440 
set_vop3paco::ssa_info441    void set_vop3p(Instruction* vop3p_instr)
442    {
443       add_label(label_vop3p);
444       instr = vop3p_instr;
445    }
446 
is_vop3paco::ssa_info447    bool is_vop3p() { return label & label_vop3p; }
448 
set_fcanonicalizeaco::ssa_info449    void set_fcanonicalize(Temp tmp)
450    {
451       add_label(label_fcanonicalize);
452       temp = tmp;
453    }
454 
is_fcanonicalizeaco::ssa_info455    bool is_fcanonicalize() { return label & label_fcanonicalize; }
456 
set_canonicalizedaco::ssa_info457    void set_canonicalized() { add_label(label_canonicalized); }
458 
is_canonicalizedaco::ssa_info459    bool is_canonicalized() { return label & label_canonicalized; }
460 
set_f2f32aco::ssa_info461    void set_f2f32(Instruction* cvt)
462    {
463       add_label(label_f2f32);
464       instr = cvt;
465    }
466 
is_f2f32aco::ssa_info467    bool is_f2f32() { return label & label_f2f32; }
468 
set_extractaco::ssa_info469    void set_extract(Instruction* extract)
470    {
471       add_label(label_extract);
472       instr = extract;
473    }
474 
is_extractaco::ssa_info475    bool is_extract() { return label & label_extract; }
476 
set_insertaco::ssa_info477    void set_insert(Instruction* insert)
478    {
479       if (label & temp_labels)
480          return;
481       add_label(label_insert);
482       instr = insert;
483    }
484 
is_insertaco::ssa_info485    bool is_insert() { return label & label_insert; }
486 
set_dpp16aco::ssa_info487    void set_dpp16(Instruction* mov)
488    {
489       add_label(label_dpp16);
490       instr = mov;
491    }
492 
set_dpp8aco::ssa_info493    void set_dpp8(Instruction* mov)
494    {
495       add_label(label_dpp8);
496       instr = mov;
497    }
498 
is_dppaco::ssa_info499    bool is_dpp() { return label & (label_dpp16 | label_dpp8); }
is_dpp16aco::ssa_info500    bool is_dpp16() { return label & label_dpp16; }
is_dpp8aco::ssa_info501    bool is_dpp8() { return label & label_dpp8; }
502 
set_splitaco::ssa_info503    void set_split(Instruction* split)
504    {
505       add_label(label_split);
506       instr = split;
507    }
508 
is_splitaco::ssa_info509    bool is_split() { return label & label_split; }
510 
set_subgroup_invocationaco::ssa_info511    void set_subgroup_invocation(Instruction* label_instr)
512    {
513       add_label(label_subgroup_invocation);
514       instr = label_instr;
515    }
516 
is_subgroup_invocationaco::ssa_info517    bool is_subgroup_invocation() { return label & label_subgroup_invocation; }
518 };
519 
520 struct opt_ctx {
521    Program* program;
522    float_mode fp_mode;
523    std::vector<aco_ptr<Instruction>> instructions;
524    ssa_info* info;
525    std::pair<uint32_t, Temp> last_literal;
526    std::vector<mad_info> mad_infos;
527    std::vector<uint16_t> uses;
528 };
529 
530 bool
can_use_VOP3(opt_ctx & ctx,const aco_ptr<Instruction> & instr)531 can_use_VOP3(opt_ctx& ctx, const aco_ptr<Instruction>& instr)
532 {
533    if (instr->isVOP3())
534       return true;
535 
536    if (instr->isVOP3P())
537       return false;
538 
539    if (instr->operands.size() && instr->operands[0].isLiteral() && ctx.program->gfx_level < GFX10)
540       return false;
541 
542    if (instr->isSDWA())
543       return false;
544 
545    if (instr->isDPP() && ctx.program->gfx_level < GFX11)
546       return false;
547 
548    return instr->opcode != aco_opcode::v_madmk_f32 && instr->opcode != aco_opcode::v_madak_f32 &&
549           instr->opcode != aco_opcode::v_madmk_f16 && instr->opcode != aco_opcode::v_madak_f16 &&
550           instr->opcode != aco_opcode::v_fmamk_f32 && instr->opcode != aco_opcode::v_fmaak_f32 &&
551           instr->opcode != aco_opcode::v_fmamk_f16 && instr->opcode != aco_opcode::v_fmaak_f16 &&
552           instr->opcode != aco_opcode::v_permlane64_b32 &&
553           instr->opcode != aco_opcode::v_readlane_b32 &&
554           instr->opcode != aco_opcode::v_writelane_b32 &&
555           instr->opcode != aco_opcode::v_readfirstlane_b32;
556 }
557 
558 bool
pseudo_propagate_temp(opt_ctx & ctx,aco_ptr<Instruction> & instr,Temp temp,unsigned index)559 pseudo_propagate_temp(opt_ctx& ctx, aco_ptr<Instruction>& instr, Temp temp, unsigned index)
560 {
561    if (instr->definitions.empty())
562       return false;
563 
564    const bool vgpr =
565       instr->opcode == aco_opcode::p_as_uniform ||
566       std::all_of(instr->definitions.begin(), instr->definitions.end(),
567                   [](const Definition& def) { return def.regClass().type() == RegType::vgpr; });
568 
569    /* don't propagate VGPRs into SGPR instructions */
570    if (temp.type() == RegType::vgpr && !vgpr)
571       return false;
572 
573    bool can_accept_sgpr =
574       ctx.program->gfx_level >= GFX9 ||
575       std::none_of(instr->definitions.begin(), instr->definitions.end(),
576                    [](const Definition& def) { return def.regClass().is_subdword(); });
577 
578    switch (instr->opcode) {
579    case aco_opcode::p_phi:
580    case aco_opcode::p_linear_phi:
581    case aco_opcode::p_parallelcopy:
582    case aco_opcode::p_create_vector:
583       if (temp.bytes() != instr->operands[index].bytes())
584          return false;
585       break;
586    case aco_opcode::p_extract_vector:
587    case aco_opcode::p_extract:
588       if (temp.type() == RegType::sgpr && !can_accept_sgpr)
589          return false;
590       break;
591    case aco_opcode::p_split_vector: {
592       if (temp.type() == RegType::sgpr && !can_accept_sgpr)
593          return false;
594       /* don't increase the vector size */
595       if (temp.bytes() > instr->operands[index].bytes())
596          return false;
597       /* We can decrease the vector size as smaller temporaries are only
598        * propagated by p_as_uniform instructions.
599        * If this propagation leads to invalid IR or hits the assertion below,
600        * it means that some undefined bytes within a dword are begin accessed
601        * and a bug in instruction_selection is likely. */
602       int decrease = instr->operands[index].bytes() - temp.bytes();
603       while (decrease > 0) {
604          decrease -= instr->definitions.back().bytes();
605          instr->definitions.pop_back();
606       }
607       assert(decrease == 0);
608       break;
609    }
610    case aco_opcode::p_as_uniform:
611       if (temp.regClass() == instr->definitions[0].regClass())
612          instr->opcode = aco_opcode::p_parallelcopy;
613       break;
614    default: return false;
615    }
616 
617    instr->operands[index].setTemp(temp);
618    return true;
619 }
620 
621 /* This expects the DPP modifier to be removed. */
622 bool
can_apply_sgprs(opt_ctx & ctx,aco_ptr<Instruction> & instr)623 can_apply_sgprs(opt_ctx& ctx, aco_ptr<Instruction>& instr)
624 {
625    assert(instr->isVALU());
626    if (instr->isSDWA() && ctx.program->gfx_level < GFX9)
627       return false;
628    return instr->opcode != aco_opcode::v_readfirstlane_b32 &&
629           instr->opcode != aco_opcode::v_readlane_b32 &&
630           instr->opcode != aco_opcode::v_readlane_b32_e64 &&
631           instr->opcode != aco_opcode::v_writelane_b32 &&
632           instr->opcode != aco_opcode::v_writelane_b32_e64 &&
633           instr->opcode != aco_opcode::v_permlane16_b32 &&
634           instr->opcode != aco_opcode::v_permlanex16_b32 &&
635           instr->opcode != aco_opcode::v_permlane64_b32 &&
636           instr->opcode != aco_opcode::v_interp_p1_f32 &&
637           instr->opcode != aco_opcode::v_interp_p2_f32 &&
638           instr->opcode != aco_opcode::v_interp_mov_f32 &&
639           instr->opcode != aco_opcode::v_interp_p1ll_f16 &&
640           instr->opcode != aco_opcode::v_interp_p1lv_f16 &&
641           instr->opcode != aco_opcode::v_interp_p2_legacy_f16 &&
642           instr->opcode != aco_opcode::v_interp_p2_f16 &&
643           instr->opcode != aco_opcode::v_interp_p10_f32_inreg &&
644           instr->opcode != aco_opcode::v_interp_p2_f32_inreg &&
645           instr->opcode != aco_opcode::v_interp_p10_f16_f32_inreg &&
646           instr->opcode != aco_opcode::v_interp_p2_f16_f32_inreg &&
647           instr->opcode != aco_opcode::v_interp_p10_rtz_f16_f32_inreg &&
648           instr->opcode != aco_opcode::v_interp_p2_rtz_f16_f32_inreg &&
649           instr->opcode != aco_opcode::v_wmma_f32_16x16x16_f16 &&
650           instr->opcode != aco_opcode::v_wmma_f32_16x16x16_bf16 &&
651           instr->opcode != aco_opcode::v_wmma_f16_16x16x16_f16 &&
652           instr->opcode != aco_opcode::v_wmma_bf16_16x16x16_bf16 &&
653           instr->opcode != aco_opcode::v_wmma_i32_16x16x16_iu8 &&
654           instr->opcode != aco_opcode::v_wmma_i32_16x16x16_iu4;
655 }
656 
657 bool
is_operand_vgpr(Operand op)658 is_operand_vgpr(Operand op)
659 {
660    return op.isTemp() && op.getTemp().type() == RegType::vgpr;
661 }
662 
663 /* only covers special cases */
664 bool
alu_can_accept_constant(const aco_ptr<Instruction> & instr,unsigned operand)665 alu_can_accept_constant(const aco_ptr<Instruction>& instr, unsigned operand)
666 {
667    /* Fixed operands can't accept constants because we need them
668     * to be in their fixed register.
669     */
670    assert(instr->operands.size() > operand);
671    if (instr->operands[operand].isFixed())
672       return false;
673 
674    /* SOPP instructions can't use constants. */
675    if (instr->isSOPP())
676       return false;
677 
678    switch (instr->opcode) {
679    case aco_opcode::v_mac_f32:
680    case aco_opcode::v_writelane_b32:
681    case aco_opcode::v_writelane_b32_e64:
682    case aco_opcode::v_cndmask_b32: return operand != 2;
683    case aco_opcode::s_addk_i32:
684    case aco_opcode::s_mulk_i32:
685    case aco_opcode::p_extract_vector:
686    case aco_opcode::p_split_vector:
687    case aco_opcode::v_readlane_b32:
688    case aco_opcode::v_readlane_b32_e64:
689    case aco_opcode::v_readfirstlane_b32:
690    case aco_opcode::p_extract:
691    case aco_opcode::p_insert: return operand != 0;
692    case aco_opcode::p_bpermute_readlane:
693    case aco_opcode::p_bpermute_shared_vgpr:
694    case aco_opcode::p_bpermute_permlane:
695    case aco_opcode::p_interp_gfx11:
696    case aco_opcode::p_dual_src_export_gfx11:
697    case aco_opcode::v_interp_p1_f32:
698    case aco_opcode::v_interp_p2_f32:
699    case aco_opcode::v_interp_mov_f32:
700    case aco_opcode::v_interp_p1ll_f16:
701    case aco_opcode::v_interp_p1lv_f16:
702    case aco_opcode::v_interp_p2_legacy_f16:
703    case aco_opcode::v_interp_p10_f32_inreg:
704    case aco_opcode::v_interp_p2_f32_inreg:
705    case aco_opcode::v_interp_p10_f16_f32_inreg:
706    case aco_opcode::v_interp_p2_f16_f32_inreg:
707    case aco_opcode::v_interp_p10_rtz_f16_f32_inreg:
708    case aco_opcode::v_interp_p2_rtz_f16_f32_inreg:
709    case aco_opcode::v_wmma_f32_16x16x16_f16:
710    case aco_opcode::v_wmma_f32_16x16x16_bf16:
711    case aco_opcode::v_wmma_f16_16x16x16_f16:
712    case aco_opcode::v_wmma_bf16_16x16x16_bf16:
713    case aco_opcode::v_wmma_i32_16x16x16_iu8:
714    case aco_opcode::v_wmma_i32_16x16x16_iu4: return false;
715    default: return true;
716    }
717 }
718 
719 bool
valu_can_accept_vgpr(aco_ptr<Instruction> & instr,unsigned operand)720 valu_can_accept_vgpr(aco_ptr<Instruction>& instr, unsigned operand)
721 {
722    if (instr->opcode == aco_opcode::v_readlane_b32 ||
723        instr->opcode == aco_opcode::v_readlane_b32_e64 ||
724        instr->opcode == aco_opcode::v_writelane_b32 ||
725        instr->opcode == aco_opcode::v_writelane_b32_e64)
726       return operand != 1;
727    if (instr->opcode == aco_opcode::v_permlane16_b32 ||
728        instr->opcode == aco_opcode::v_permlanex16_b32)
729       return operand == 0;
730    return true;
731 }
732 
733 /* check constant bus and literal limitations */
734 bool
check_vop3_operands(opt_ctx & ctx,unsigned num_operands,Operand * operands)735 check_vop3_operands(opt_ctx& ctx, unsigned num_operands, Operand* operands)
736 {
737    int limit = ctx.program->gfx_level >= GFX10 ? 2 : 1;
738    Operand literal32(s1);
739    Operand literal64(s2);
740    unsigned num_sgprs = 0;
741    unsigned sgpr[] = {0, 0};
742 
743    for (unsigned i = 0; i < num_operands; i++) {
744       Operand op = operands[i];
745 
746       if (op.hasRegClass() && op.regClass().type() == RegType::sgpr) {
747          /* two reads of the same SGPR count as 1 to the limit */
748          if (op.tempId() != sgpr[0] && op.tempId() != sgpr[1]) {
749             if (num_sgprs < 2)
750                sgpr[num_sgprs++] = op.tempId();
751             limit--;
752             if (limit < 0)
753                return false;
754          }
755       } else if (op.isLiteral()) {
756          if (ctx.program->gfx_level < GFX10)
757             return false;
758 
759          if (!literal32.isUndefined() && literal32.constantValue() != op.constantValue())
760             return false;
761          if (!literal64.isUndefined() && literal64.constantValue() != op.constantValue())
762             return false;
763 
764          /* Any number of 32-bit literals counts as only 1 to the limit. Same
765           * (but separately) for 64-bit literals. */
766          if (op.size() == 1 && literal32.isUndefined()) {
767             limit--;
768             literal32 = op;
769          } else if (op.size() == 2 && literal64.isUndefined()) {
770             limit--;
771             literal64 = op;
772          }
773 
774          if (limit < 0)
775             return false;
776       }
777    }
778 
779    return true;
780 }
781 
782 bool
parse_base_offset(opt_ctx & ctx,Instruction * instr,unsigned op_index,Temp * base,uint32_t * offset,bool prevent_overflow)783 parse_base_offset(opt_ctx& ctx, Instruction* instr, unsigned op_index, Temp* base, uint32_t* offset,
784                   bool prevent_overflow)
785 {
786    Operand op = instr->operands[op_index];
787 
788    if (!op.isTemp())
789       return false;
790    Temp tmp = op.getTemp();
791    if (!ctx.info[tmp.id()].is_add_sub())
792       return false;
793 
794    Instruction* add_instr = ctx.info[tmp.id()].instr;
795 
796    unsigned mask = 0x3;
797    bool is_sub = false;
798    switch (add_instr->opcode) {
799    case aco_opcode::v_add_u32:
800    case aco_opcode::v_add_co_u32:
801    case aco_opcode::v_add_co_u32_e64:
802    case aco_opcode::s_add_i32:
803    case aco_opcode::s_add_u32: break;
804    case aco_opcode::v_sub_u32:
805    case aco_opcode::v_sub_i32:
806    case aco_opcode::v_sub_co_u32:
807    case aco_opcode::v_sub_co_u32_e64:
808    case aco_opcode::s_sub_u32:
809    case aco_opcode::s_sub_i32:
810       mask = 0x2;
811       is_sub = true;
812       break;
813    case aco_opcode::v_subrev_u32:
814    case aco_opcode::v_subrev_co_u32:
815    case aco_opcode::v_subrev_co_u32_e64:
816       mask = 0x1;
817       is_sub = true;
818       break;
819    default: return false;
820    }
821    if (prevent_overflow && !add_instr->definitions[0].isNUW())
822       return false;
823 
824    if (add_instr->usesModifiers())
825       return false;
826 
827    u_foreach_bit (i, mask) {
828       if (add_instr->operands[i].isConstant()) {
829          *offset = add_instr->operands[i].constantValue() * (uint32_t)(is_sub ? -1 : 1);
830       } else if (add_instr->operands[i].isTemp() &&
831                  ctx.info[add_instr->operands[i].tempId()].is_constant_or_literal(32)) {
832          *offset = ctx.info[add_instr->operands[i].tempId()].val * (uint32_t)(is_sub ? -1 : 1);
833       } else {
834          continue;
835       }
836       if (!add_instr->operands[!i].isTemp())
837          continue;
838 
839       uint32_t offset2 = 0;
840       if (parse_base_offset(ctx, add_instr, !i, base, &offset2, prevent_overflow)) {
841          *offset += offset2;
842       } else {
843          *base = add_instr->operands[!i].getTemp();
844       }
845       return true;
846    }
847 
848    return false;
849 }
850 
851 void
skip_smem_offset_align(opt_ctx & ctx,SMEM_instruction * smem)852 skip_smem_offset_align(opt_ctx& ctx, SMEM_instruction* smem)
853 {
854    bool soe = smem->operands.size() >= (!smem->definitions.empty() ? 3 : 4);
855    if (soe && !smem->operands[1].isConstant())
856       return;
857    /* We don't need to check the constant offset because the address seems to be calculated with
858     * (offset&-4 + const_offset&-4), not (offset+const_offset)&-4.
859     */
860 
861    Operand& op = smem->operands[soe ? smem->operands.size() - 1 : 1];
862    if (!op.isTemp() || !ctx.info[op.tempId()].is_bitwise())
863       return;
864 
865    Instruction* bitwise_instr = ctx.info[op.tempId()].instr;
866    if (bitwise_instr->opcode != aco_opcode::s_and_b32)
867       return;
868 
869    if (bitwise_instr->operands[0].constantEquals(-4) &&
870        bitwise_instr->operands[1].isOfType(op.regClass().type()))
871       op.setTemp(bitwise_instr->operands[1].getTemp());
872    else if (bitwise_instr->operands[1].constantEquals(-4) &&
873             bitwise_instr->operands[0].isOfType(op.regClass().type()))
874       op.setTemp(bitwise_instr->operands[0].getTemp());
875 }
876 
877 void
smem_combine(opt_ctx & ctx,aco_ptr<Instruction> & instr)878 smem_combine(opt_ctx& ctx, aco_ptr<Instruction>& instr)
879 {
880    /* skip &-4 before offset additions: load((a + 16) & -4, 0) */
881    if (!instr->operands.empty())
882       skip_smem_offset_align(ctx, &instr->smem());
883 
884    /* propagate constants and combine additions */
885    if (!instr->operands.empty() && instr->operands[1].isTemp()) {
886       SMEM_instruction& smem = instr->smem();
887       ssa_info info = ctx.info[instr->operands[1].tempId()];
888 
889       Temp base;
890       uint32_t offset;
891       if (info.is_constant_or_literal(32) &&
892           ((ctx.program->gfx_level == GFX6 && info.val <= 0x3FF) ||
893            (ctx.program->gfx_level == GFX7 && info.val <= 0xFFFFFFFF) ||
894            (ctx.program->gfx_level >= GFX8 && info.val <= 0xFFFFF))) {
895          instr->operands[1] = Operand::c32(info.val);
896       } else if (parse_base_offset(ctx, instr.get(), 1, &base, &offset, true) &&
897                  base.regClass() == s1 && offset <= 0xFFFFF && ctx.program->gfx_level >= GFX9 &&
898                  offset % 4u == 0) {
899          bool soe = smem.operands.size() >= (!smem.definitions.empty() ? 3 : 4);
900          if (soe) {
901             if (ctx.info[smem.operands.back().tempId()].is_constant_or_literal(32) &&
902                 ctx.info[smem.operands.back().tempId()].val == 0) {
903                smem.operands[1] = Operand::c32(offset);
904                smem.operands.back() = Operand(base);
905             }
906          } else {
907             SMEM_instruction* new_instr = create_instruction<SMEM_instruction>(
908                smem.opcode, Format::SMEM, smem.operands.size() + 1, smem.definitions.size());
909             new_instr->operands[0] = smem.operands[0];
910             new_instr->operands[1] = Operand::c32(offset);
911             if (smem.definitions.empty())
912                new_instr->operands[2] = smem.operands[2];
913             new_instr->operands.back() = Operand(base);
914             if (!smem.definitions.empty())
915                new_instr->definitions[0] = smem.definitions[0];
916             new_instr->sync = smem.sync;
917             new_instr->glc = smem.glc;
918             new_instr->dlc = smem.dlc;
919             new_instr->nv = smem.nv;
920             new_instr->disable_wqm = smem.disable_wqm;
921             instr.reset(new_instr);
922          }
923       }
924    }
925 
926    /* skip &-4 after offset additions: load(a & -4, 16) */
927    if (!instr->operands.empty())
928       skip_smem_offset_align(ctx, &instr->smem());
929 }
930 
931 Operand
get_constant_op(opt_ctx & ctx,ssa_info info,uint32_t bits)932 get_constant_op(opt_ctx& ctx, ssa_info info, uint32_t bits)
933 {
934    if (bits == 64)
935       return Operand::c32_or_c64(info.val, true);
936    return Operand::get_const(ctx.program->gfx_level, info.val, bits / 8u);
937 }
938 
939 void
propagate_constants_vop3p(opt_ctx & ctx,aco_ptr<Instruction> & instr,ssa_info & info,unsigned i)940 propagate_constants_vop3p(opt_ctx& ctx, aco_ptr<Instruction>& instr, ssa_info& info, unsigned i)
941 {
942    if (!info.is_constant_or_literal(32))
943       return;
944 
945    assert(instr->operands[i].isTemp());
946    unsigned bits = get_operand_size(instr, i);
947    if (info.is_constant(bits)) {
948       instr->operands[i] = get_constant_op(ctx, info, bits);
949       return;
950    }
951 
952    /* The accumulation operand of dot product instructions ignores opsel. */
953    bool cannot_use_opsel =
954       (instr->opcode == aco_opcode::v_dot4_i32_i8 || instr->opcode == aco_opcode::v_dot2_i32_i16 ||
955        instr->opcode == aco_opcode::v_dot4_i32_iu8 || instr->opcode == aco_opcode::v_dot4_u32_u8 ||
956        instr->opcode == aco_opcode::v_dot2_u32_u16) &&
957       i == 2;
958    if (cannot_use_opsel)
959       return;
960 
961    /* try to fold inline constants */
962    VALU_instruction* vop3p = &instr->valu();
963    bool opsel_lo = vop3p->opsel_lo[i];
964    bool opsel_hi = vop3p->opsel_hi[i];
965 
966    Operand const_op[2];
967    bool const_opsel[2] = {false, false};
968    for (unsigned j = 0; j < 2; j++) {
969       if ((unsigned)opsel_lo != j && (unsigned)opsel_hi != j)
970          continue; /* this half is unused */
971 
972       uint16_t val = info.val >> (j ? 16 : 0);
973       Operand op = Operand::get_const(ctx.program->gfx_level, val, bits / 8u);
974       if (bits == 32 && op.isLiteral()) /* try sign extension */
975          op = Operand::get_const(ctx.program->gfx_level, val | 0xffff0000, 4);
976       if (bits == 32 && op.isLiteral()) { /* try shifting left */
977          op = Operand::get_const(ctx.program->gfx_level, val << 16, 4);
978          const_opsel[j] = true;
979       }
980       if (op.isLiteral())
981          return;
982       const_op[j] = op;
983    }
984 
985    Operand const_lo = const_op[0];
986    Operand const_hi = const_op[1];
987    bool const_lo_opsel = const_opsel[0];
988    bool const_hi_opsel = const_opsel[1];
989 
990    if (opsel_lo == opsel_hi) {
991       /* use the single 16bit value */
992       instr->operands[i] = opsel_lo ? const_hi : const_lo;
993 
994       /* opsel must point the same for both halves */
995       opsel_lo = opsel_lo ? const_hi_opsel : const_lo_opsel;
996       opsel_hi = opsel_lo;
997    } else if (const_lo == const_hi) {
998       /* both constants are the same */
999       instr->operands[i] = const_lo;
1000 
1001       /* opsel must point the same for both halves */
1002       opsel_lo = const_lo_opsel;
1003       opsel_hi = const_lo_opsel;
1004    } else if (const_lo.constantValue16(const_lo_opsel) ==
1005               const_hi.constantValue16(!const_hi_opsel)) {
1006       instr->operands[i] = const_hi;
1007 
1008       /* redirect opsel selection */
1009       opsel_lo = opsel_lo ? const_hi_opsel : !const_hi_opsel;
1010       opsel_hi = opsel_hi ? const_hi_opsel : !const_hi_opsel;
1011    } else if (const_hi.constantValue16(const_hi_opsel) ==
1012               const_lo.constantValue16(!const_lo_opsel)) {
1013       instr->operands[i] = const_lo;
1014 
1015       /* redirect opsel selection */
1016       opsel_lo = opsel_lo ? !const_lo_opsel : const_lo_opsel;
1017       opsel_hi = opsel_hi ? !const_lo_opsel : const_lo_opsel;
1018    } else if (bits == 16 && const_lo.constantValue() == (const_hi.constantValue() ^ (1 << 15))) {
1019       assert(const_lo_opsel == false && const_hi_opsel == false);
1020 
1021       /* const_lo == -const_hi */
1022       if (!can_use_input_modifiers(ctx.program->gfx_level, instr->opcode, i))
1023          return;
1024 
1025       instr->operands[i] = Operand::c16(const_lo.constantValue() & 0x7FFF);
1026       bool neg_lo = const_lo.constantValue() & (1 << 15);
1027       vop3p->neg_lo[i] ^= opsel_lo ^ neg_lo;
1028       vop3p->neg_hi[i] ^= opsel_hi ^ neg_lo;
1029 
1030       /* opsel must point to lo for both operands */
1031       opsel_lo = false;
1032       opsel_hi = false;
1033    }
1034 
1035    vop3p->opsel_lo[i] = opsel_lo;
1036    vop3p->opsel_hi[i] = opsel_hi;
1037 }
1038 
1039 bool
fixed_to_exec(Operand op)1040 fixed_to_exec(Operand op)
1041 {
1042    return op.isFixed() && op.physReg() == exec;
1043 }
1044 
1045 SubdwordSel
parse_extract(Instruction * instr)1046 parse_extract(Instruction* instr)
1047 {
1048    if (instr->opcode == aco_opcode::p_extract) {
1049       unsigned size = instr->operands[2].constantValue() / 8;
1050       unsigned offset = instr->operands[1].constantValue() * size;
1051       bool sext = instr->operands[3].constantEquals(1);
1052       return SubdwordSel(size, offset, sext);
1053    } else if (instr->opcode == aco_opcode::p_insert && instr->operands[1].constantEquals(0)) {
1054       return instr->operands[2].constantEquals(8) ? SubdwordSel::ubyte : SubdwordSel::uword;
1055    } else if (instr->opcode == aco_opcode::p_extract_vector) {
1056       unsigned size = instr->definitions[0].bytes();
1057       unsigned offset = instr->operands[1].constantValue() * size;
1058       if (size <= 2)
1059          return SubdwordSel(size, offset, false);
1060    } else if (instr->opcode == aco_opcode::p_split_vector) {
1061       assert(instr->operands[0].bytes() == 4 && instr->definitions[1].bytes() == 2);
1062       return SubdwordSel(2, 2, false);
1063    }
1064 
1065    return SubdwordSel();
1066 }
1067 
1068 SubdwordSel
parse_insert(Instruction * instr)1069 parse_insert(Instruction* instr)
1070 {
1071    if (instr->opcode == aco_opcode::p_extract && instr->operands[3].constantEquals(0) &&
1072        instr->operands[1].constantEquals(0)) {
1073       return instr->operands[2].constantEquals(8) ? SubdwordSel::ubyte : SubdwordSel::uword;
1074    } else if (instr->opcode == aco_opcode::p_insert) {
1075       unsigned size = instr->operands[2].constantValue() / 8;
1076       unsigned offset = instr->operands[1].constantValue() * size;
1077       return SubdwordSel(size, offset, false);
1078    } else {
1079       return SubdwordSel();
1080    }
1081 }
1082 
1083 bool
can_apply_extract(opt_ctx & ctx,aco_ptr<Instruction> & instr,unsigned idx,ssa_info & info)1084 can_apply_extract(opt_ctx& ctx, aco_ptr<Instruction>& instr, unsigned idx, ssa_info& info)
1085 {
1086    Temp tmp = info.instr->operands[0].getTemp();
1087    SubdwordSel sel = parse_extract(info.instr);
1088 
1089    if (!sel) {
1090       return false;
1091    } else if (sel.size() == 4) {
1092       return true;
1093    } else if ((instr->opcode == aco_opcode::v_cvt_f32_u32 ||
1094                instr->opcode == aco_opcode::v_cvt_f32_i32) &&
1095               sel.size() == 1 && !sel.sign_extend()) {
1096       return true;
1097    } else if (instr->opcode == aco_opcode::v_lshlrev_b32 && instr->operands[0].isConstant() &&
1098               sel.offset() == 0 &&
1099               ((sel.size() == 2 && instr->operands[0].constantValue() >= 16u) ||
1100                (sel.size() == 1 && instr->operands[0].constantValue() >= 24u))) {
1101       return true;
1102    } else if (instr->opcode == aco_opcode::v_mul_u32_u24 && ctx.program->gfx_level >= GFX10 &&
1103               !instr->usesModifiers() && sel.size() == 2 && !sel.sign_extend() &&
1104               (instr->operands[!idx].is16bit() ||
1105                instr->operands[!idx].constantValue() <= UINT16_MAX)) {
1106       return true;
1107    } else if (idx < 2 && can_use_SDWA(ctx.program->gfx_level, instr, true) &&
1108               (tmp.type() == RegType::vgpr || ctx.program->gfx_level >= GFX9)) {
1109       if (instr->isSDWA() && instr->sdwa().sel[idx] != SubdwordSel::dword)
1110          return false;
1111       return true;
1112    } else if (instr->isVALU() && sel.size() == 2 && !instr->valu().opsel[idx] &&
1113               can_use_opsel(ctx.program->gfx_level, instr->opcode, idx)) {
1114       return true;
1115    } else if (instr->opcode == aco_opcode::p_extract) {
1116       SubdwordSel instrSel = parse_extract(instr.get());
1117 
1118       /* the outer offset must be within extracted range */
1119       if (instrSel.offset() >= sel.size())
1120          return false;
1121 
1122       /* don't remove the sign-extension when increasing the size further */
1123       if (instrSel.size() > sel.size() && !instrSel.sign_extend() && sel.sign_extend())
1124          return false;
1125 
1126       return true;
1127    }
1128 
1129    return false;
1130 }
1131 
1132 /* Combine an p_extract (or p_insert, in some cases) instruction with instr.
1133  * instr(p_extract(...)) -> instr()
1134  */
1135 void
apply_extract(opt_ctx & ctx,aco_ptr<Instruction> & instr,unsigned idx,ssa_info & info)1136 apply_extract(opt_ctx& ctx, aco_ptr<Instruction>& instr, unsigned idx, ssa_info& info)
1137 {
1138    Temp tmp = info.instr->operands[0].getTemp();
1139    SubdwordSel sel = parse_extract(info.instr);
1140    assert(sel);
1141 
1142    instr->operands[idx].set16bit(false);
1143    instr->operands[idx].set24bit(false);
1144 
1145    ctx.info[tmp.id()].label &= ~label_insert;
1146 
1147    if (sel.size() == 4) {
1148       /* full dword selection */
1149    } else if ((instr->opcode == aco_opcode::v_cvt_f32_u32 ||
1150                instr->opcode == aco_opcode::v_cvt_f32_i32) &&
1151               sel.size() == 1 && !sel.sign_extend()) {
1152       switch (sel.offset()) {
1153       case 0: instr->opcode = aco_opcode::v_cvt_f32_ubyte0; break;
1154       case 1: instr->opcode = aco_opcode::v_cvt_f32_ubyte1; break;
1155       case 2: instr->opcode = aco_opcode::v_cvt_f32_ubyte2; break;
1156       case 3: instr->opcode = aco_opcode::v_cvt_f32_ubyte3; break;
1157       }
1158    } else if (instr->opcode == aco_opcode::v_lshlrev_b32 && instr->operands[0].isConstant() &&
1159               sel.offset() == 0 &&
1160               ((sel.size() == 2 && instr->operands[0].constantValue() >= 16u) ||
1161                (sel.size() == 1 && instr->operands[0].constantValue() >= 24u))) {
1162       /* The undesirable upper bits are already shifted out. */
1163       return;
1164    } else if (instr->opcode == aco_opcode::v_mul_u32_u24 && ctx.program->gfx_level >= GFX10 &&
1165               !instr->usesModifiers() && sel.size() == 2 && !sel.sign_extend() &&
1166               (instr->operands[!idx].is16bit() ||
1167                instr->operands[!idx].constantValue() <= UINT16_MAX)) {
1168       Instruction* mad =
1169          create_instruction<VALU_instruction>(aco_opcode::v_mad_u32_u16, Format::VOP3, 3, 1);
1170       mad->definitions[0] = instr->definitions[0];
1171       mad->operands[0] = instr->operands[0];
1172       mad->operands[1] = instr->operands[1];
1173       mad->operands[2] = Operand::zero();
1174       mad->valu().opsel[idx] = sel.offset();
1175       mad->pass_flags = instr->pass_flags;
1176       instr.reset(mad);
1177    } else if (can_use_SDWA(ctx.program->gfx_level, instr, true) &&
1178               (tmp.type() == RegType::vgpr || ctx.program->gfx_level >= GFX9)) {
1179       convert_to_SDWA(ctx.program->gfx_level, instr);
1180       instr->sdwa().sel[idx] = sel;
1181    } else if (instr->isVALU()) {
1182       if (sel.offset()) {
1183          instr->valu().opsel[idx] = true;
1184 
1185          /* VOP12C cannot use opsel with SGPRs. */
1186          if (!instr->isVOP3() && !instr->isVINTERP_INREG() &&
1187              !info.instr->operands[0].isOfType(RegType::vgpr))
1188             instr->format = asVOP3(instr->format);
1189       }
1190    } else if (instr->opcode == aco_opcode::p_extract) {
1191       SubdwordSel instrSel = parse_extract(instr.get());
1192 
1193       unsigned size = std::min(sel.size(), instrSel.size());
1194       unsigned offset = sel.offset() + instrSel.offset();
1195       unsigned sign_extend =
1196          instrSel.sign_extend() && (sel.sign_extend() || instrSel.size() <= sel.size());
1197 
1198       instr->operands[1] = Operand::c32(offset / size);
1199       instr->operands[2] = Operand::c32(size * 8u);
1200       instr->operands[3] = Operand::c32(sign_extend);
1201       return;
1202    }
1203 
1204    /* These are the only labels worth keeping at the moment. */
1205    for (Definition& def : instr->definitions) {
1206       ctx.info[def.tempId()].label &=
1207          (label_mul | label_minmax | label_usedef | label_vopc | label_f2f32 | instr_mod_labels);
1208       if (ctx.info[def.tempId()].label & instr_usedef_labels)
1209          ctx.info[def.tempId()].instr = instr.get();
1210    }
1211 }
1212 
1213 void
check_sdwa_extract(opt_ctx & ctx,aco_ptr<Instruction> & instr)1214 check_sdwa_extract(opt_ctx& ctx, aco_ptr<Instruction>& instr)
1215 {
1216    for (unsigned i = 0; i < instr->operands.size(); i++) {
1217       Operand op = instr->operands[i];
1218       if (!op.isTemp())
1219          continue;
1220       ssa_info& info = ctx.info[op.tempId()];
1221       if (info.is_extract() && (info.instr->operands[0].getTemp().type() == RegType::vgpr ||
1222                                 op.getTemp().type() == RegType::sgpr)) {
1223          if (!can_apply_extract(ctx, instr, i, info))
1224             info.label &= ~label_extract;
1225       }
1226    }
1227 }
1228 
1229 bool
does_fp_op_flush_denorms(opt_ctx & ctx,aco_opcode op)1230 does_fp_op_flush_denorms(opt_ctx& ctx, aco_opcode op)
1231 {
1232    switch (op) {
1233    case aco_opcode::v_min_f32:
1234    case aco_opcode::v_max_f32:
1235    case aco_opcode::v_med3_f32:
1236    case aco_opcode::v_min3_f32:
1237    case aco_opcode::v_max3_f32:
1238    case aco_opcode::v_min_f16:
1239    case aco_opcode::v_max_f16: return ctx.program->gfx_level > GFX8;
1240    case aco_opcode::v_cndmask_b32:
1241    case aco_opcode::v_cndmask_b16:
1242    case aco_opcode::v_mov_b32:
1243    case aco_opcode::v_mov_b16: return false;
1244    default: return true;
1245    }
1246 }
1247 
1248 bool
can_eliminate_fcanonicalize(opt_ctx & ctx,aco_ptr<Instruction> & instr,Temp tmp,unsigned idx)1249 can_eliminate_fcanonicalize(opt_ctx& ctx, aco_ptr<Instruction>& instr, Temp tmp, unsigned idx)
1250 {
1251    float_mode* fp = &ctx.fp_mode;
1252    if (ctx.info[tmp.id()].is_canonicalized() ||
1253        (tmp.bytes() == 4 ? fp->denorm32 : fp->denorm16_64) == fp_denorm_keep)
1254       return true;
1255 
1256    aco_opcode op = instr->opcode;
1257    return can_use_input_modifiers(ctx.program->gfx_level, instr->opcode, idx) &&
1258           does_fp_op_flush_denorms(ctx, op);
1259 }
1260 
1261 bool
can_eliminate_and_exec(opt_ctx & ctx,Temp tmp,unsigned pass_flags)1262 can_eliminate_and_exec(opt_ctx& ctx, Temp tmp, unsigned pass_flags)
1263 {
1264    if (ctx.info[tmp.id()].is_vopc()) {
1265       Instruction* vopc_instr = ctx.info[tmp.id()].instr;
1266       /* Remove superfluous s_and when the VOPC instruction uses the same exec and thus
1267        * already produces the same result */
1268       return vopc_instr->pass_flags == pass_flags;
1269    }
1270    if (ctx.info[tmp.id()].is_bitwise()) {
1271       Instruction* instr = ctx.info[tmp.id()].instr;
1272       if (instr->operands.size() != 2 || instr->pass_flags != pass_flags)
1273          return false;
1274       if (!(instr->operands[0].isTemp() && instr->operands[1].isTemp()))
1275          return false;
1276       if (instr->opcode == aco_opcode::s_and_b32 || instr->opcode == aco_opcode::s_and_b64) {
1277          return can_eliminate_and_exec(ctx, instr->operands[0].getTemp(), pass_flags) ||
1278                 can_eliminate_and_exec(ctx, instr->operands[1].getTemp(), pass_flags);
1279       } else {
1280          return can_eliminate_and_exec(ctx, instr->operands[0].getTemp(), pass_flags) &&
1281                 can_eliminate_and_exec(ctx, instr->operands[1].getTemp(), pass_flags);
1282       }
1283    }
1284    return false;
1285 }
1286 
1287 bool
is_copy_label(opt_ctx & ctx,aco_ptr<Instruction> & instr,ssa_info & info,unsigned idx)1288 is_copy_label(opt_ctx& ctx, aco_ptr<Instruction>& instr, ssa_info& info, unsigned idx)
1289 {
1290    return info.is_temp() ||
1291           (info.is_fcanonicalize() && can_eliminate_fcanonicalize(ctx, instr, info.temp, idx));
1292 }
1293 
1294 bool
is_op_canonicalized(opt_ctx & ctx,Operand op)1295 is_op_canonicalized(opt_ctx& ctx, Operand op)
1296 {
1297    float_mode* fp = &ctx.fp_mode;
1298    if ((op.isTemp() && ctx.info[op.tempId()].is_canonicalized()) ||
1299        (op.bytes() == 4 ? fp->denorm32 : fp->denorm16_64) == fp_denorm_keep)
1300       return true;
1301 
1302    if (op.isConstant() || (op.isTemp() && ctx.info[op.tempId()].is_constant_or_literal(32))) {
1303       uint32_t val = op.isTemp() ? ctx.info[op.tempId()].val : op.constantValue();
1304       if (op.bytes() == 2)
1305          return (val & 0x7fff) == 0 || (val & 0x7fff) > 0x3ff;
1306       else if (op.bytes() == 4)
1307          return (val & 0x7fffffff) == 0 || (val & 0x7fffffff) > 0x7fffff;
1308    }
1309    return false;
1310 }
1311 
1312 bool
is_scratch_offset_valid(opt_ctx & ctx,Instruction * instr,int64_t offset0,int64_t offset1)1313 is_scratch_offset_valid(opt_ctx& ctx, Instruction* instr, int64_t offset0, int64_t offset1)
1314 {
1315    bool negative_unaligned_scratch_offset_bug = ctx.program->gfx_level == GFX10;
1316    int32_t min = ctx.program->dev.scratch_global_offset_min;
1317    int32_t max = ctx.program->dev.scratch_global_offset_max;
1318 
1319    int64_t offset = offset0 + offset1;
1320 
1321    bool has_vgpr_offset = instr && !instr->operands[0].isUndefined();
1322    if (negative_unaligned_scratch_offset_bug && has_vgpr_offset && offset < 0 && offset % 4)
1323       return false;
1324 
1325    return offset >= min && offset <= max;
1326 }
1327 
1328 bool
detect_clamp(Instruction * instr,unsigned * clamped_idx)1329 detect_clamp(Instruction* instr, unsigned* clamped_idx)
1330 {
1331    VALU_instruction& valu = instr->valu();
1332    if (valu.omod != 0 || valu.opsel != 0)
1333       return false;
1334 
1335    unsigned idx = 0;
1336    bool found_zero = false, found_one = false;
1337    bool is_fp16 = instr->opcode == aco_opcode::v_med3_f16;
1338    for (unsigned i = 0; i < 3; i++) {
1339       if (!valu.neg[i] && instr->operands[i].constantEquals(0))
1340          found_zero = true;
1341       else if (!valu.neg[i] &&
1342                instr->operands[i].constantEquals(is_fp16 ? 0x3c00 : 0x3f800000)) /* 1.0 */
1343          found_one = true;
1344       else
1345          idx = i;
1346    }
1347    if (found_zero && found_one && instr->operands[idx].isTemp()) {
1348       *clamped_idx = idx;
1349       return true;
1350    } else {
1351       return false;
1352    }
1353 }
1354 
1355 void
label_instruction(opt_ctx & ctx,aco_ptr<Instruction> & instr)1356 label_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
1357 {
1358    if (instr->isSALU() || instr->isVALU() || instr->isPseudo()) {
1359       ASSERTED bool all_const = false;
1360       for (Operand& op : instr->operands)
1361          all_const =
1362             all_const && (!op.isTemp() || ctx.info[op.tempId()].is_constant_or_literal(32));
1363       perfwarn(ctx.program, all_const, "All instruction operands are constant", instr.get());
1364 
1365       ASSERTED bool is_copy = instr->opcode == aco_opcode::s_mov_b32 ||
1366                               instr->opcode == aco_opcode::s_mov_b64 ||
1367                               instr->opcode == aco_opcode::v_mov_b32;
1368       perfwarn(ctx.program, is_copy && !instr->usesModifiers(), "Use p_parallelcopy instead",
1369                instr.get());
1370    }
1371 
1372    if (instr->isSMEM())
1373       smem_combine(ctx, instr);
1374 
1375    for (unsigned i = 0; i < instr->operands.size(); i++) {
1376       if (!instr->operands[i].isTemp())
1377          continue;
1378 
1379       ssa_info info = ctx.info[instr->operands[i].tempId()];
1380       /* propagate undef */
1381       if (info.is_undefined() && is_phi(instr))
1382          instr->operands[i] = Operand(instr->operands[i].regClass());
1383       /* propagate reg->reg of same type */
1384       while (info.is_temp() && info.temp.regClass() == instr->operands[i].getTemp().regClass()) {
1385          instr->operands[i].setTemp(ctx.info[instr->operands[i].tempId()].temp);
1386          info = ctx.info[info.temp.id()];
1387       }
1388 
1389       /* PSEUDO: propagate temporaries */
1390       if (instr->isPseudo()) {
1391          while (info.is_temp()) {
1392             pseudo_propagate_temp(ctx, instr, info.temp, i);
1393             info = ctx.info[info.temp.id()];
1394          }
1395       }
1396 
1397       /* SALU / PSEUDO: propagate inline constants */
1398       if (instr->isSALU() || instr->isPseudo()) {
1399          unsigned bits = get_operand_size(instr, i);
1400          if ((info.is_constant(bits) || (info.is_literal(bits) && instr->isPseudo())) &&
1401              alu_can_accept_constant(instr, i)) {
1402             instr->operands[i] = get_constant_op(ctx, info, bits);
1403             continue;
1404          }
1405       }
1406 
1407       /* VALU: propagate neg, abs & inline constants */
1408       else if (instr->isVALU()) {
1409          if (is_copy_label(ctx, instr, info, i) && info.temp.type() == RegType::vgpr &&
1410              valu_can_accept_vgpr(instr, i)) {
1411             instr->operands[i].setTemp(info.temp);
1412             info = ctx.info[info.temp.id()];
1413          }
1414          /* applying SGPRs to VOP1 doesn't increase code size and DCE is helped by doing it earlier */
1415          if (info.is_temp() && info.temp.type() == RegType::sgpr && can_apply_sgprs(ctx, instr) &&
1416              instr->operands.size() == 1) {
1417             instr->format = withoutDPP(instr->format);
1418             instr->operands[i].setTemp(info.temp);
1419             info = ctx.info[info.temp.id()];
1420          }
1421 
1422          /* for instructions other than v_cndmask_b32, the size of the instruction should match the
1423           * operand size */
1424          bool can_use_mod =
1425             instr->opcode != aco_opcode::v_cndmask_b32 || instr->operands[i].getTemp().bytes() == 4;
1426          can_use_mod &= can_use_input_modifiers(ctx.program->gfx_level, instr->opcode, i);
1427 
1428          bool packed_math = instr->isVOP3P() && instr->opcode != aco_opcode::v_fma_mix_f32 &&
1429                             instr->opcode != aco_opcode::v_fma_mixlo_f16 &&
1430                             instr->opcode != aco_opcode::v_fma_mixhi_f16;
1431 
1432          if (instr->isSDWA())
1433             can_use_mod &= instr->sdwa().sel[i].size() == 4;
1434          else if (instr->isVOP3P())
1435             can_use_mod &= !packed_math || !info.is_abs();
1436          else
1437             can_use_mod &= instr->isDPP16() || can_use_VOP3(ctx, instr);
1438 
1439          unsigned bits = get_operand_size(instr, i);
1440          can_use_mod &= instr->operands[i].bytes() * 8 == bits;
1441 
1442          if (info.is_neg() && can_use_mod &&
1443              can_eliminate_fcanonicalize(ctx, instr, info.temp, i)) {
1444             instr->operands[i].setTemp(info.temp);
1445             if (!packed_math && instr->valu().abs[i]) {
1446                /* fabs(fneg(a)) -> fabs(a) */
1447             } else if (instr->opcode == aco_opcode::v_add_f32) {
1448                instr->opcode = i ? aco_opcode::v_sub_f32 : aco_opcode::v_subrev_f32;
1449             } else if (instr->opcode == aco_opcode::v_add_f16) {
1450                instr->opcode = i ? aco_opcode::v_sub_f16 : aco_opcode::v_subrev_f16;
1451             } else if (packed_math) {
1452                /* Bit size compat should ensure this. */
1453                assert(!instr->valu().opsel_lo[i] && !instr->valu().opsel_hi[i]);
1454                instr->valu().neg_lo[i] ^= true;
1455                instr->valu().neg_hi[i] ^= true;
1456             } else {
1457                if (!instr->isDPP16() && can_use_VOP3(ctx, instr))
1458                   instr->format = asVOP3(instr->format);
1459                instr->valu().neg[i] ^= true;
1460             }
1461          }
1462          if (info.is_abs() && can_use_mod &&
1463              can_eliminate_fcanonicalize(ctx, instr, info.temp, i)) {
1464             if (!instr->isDPP16() && can_use_VOP3(ctx, instr))
1465                instr->format = asVOP3(instr->format);
1466             instr->operands[i] = Operand(info.temp);
1467             instr->valu().abs[i] = true;
1468             continue;
1469          }
1470 
1471          if (instr->isVOP3P()) {
1472             propagate_constants_vop3p(ctx, instr, info, i);
1473             continue;
1474          }
1475 
1476          if (info.is_constant(bits) && alu_can_accept_constant(instr, i) &&
1477              (!instr->isSDWA() || ctx.program->gfx_level >= GFX9) && (!instr->isDPP() || i != 1)) {
1478             Operand op = get_constant_op(ctx, info, bits);
1479             perfwarn(ctx.program, instr->opcode == aco_opcode::v_cndmask_b32 && i == 2,
1480                      "v_cndmask_b32 with a constant selector", instr.get());
1481             if (i == 0 || instr->isSDWA() || instr->opcode == aco_opcode::v_readlane_b32 ||
1482                 instr->opcode == aco_opcode::v_writelane_b32) {
1483                instr->format = withoutDPP(instr->format);
1484                instr->operands[i] = op;
1485                continue;
1486             } else if (!instr->isVOP3() && can_swap_operands(instr, &instr->opcode)) {
1487                instr->operands[i] = op;
1488                instr->valu().swapOperands(0, i);
1489                continue;
1490             } else if (can_use_VOP3(ctx, instr)) {
1491                instr->format = asVOP3(instr->format);
1492                instr->operands[i] = op;
1493                continue;
1494             }
1495          }
1496       }
1497 
1498       /* MUBUF: propagate constants and combine additions */
1499       else if (instr->isMUBUF()) {
1500          MUBUF_instruction& mubuf = instr->mubuf();
1501          Temp base;
1502          uint32_t offset;
1503          while (info.is_temp())
1504             info = ctx.info[info.temp.id()];
1505 
1506          /* According to AMDGPUDAGToDAGISel::SelectMUBUFScratchOffen(), vaddr
1507           * overflow for scratch accesses works only on GFX9+ and saddr overflow
1508           * never works. Since swizzling is the only thing that separates
1509           * scratch accesses and other accesses and swizzling changing how
1510           * addressing works significantly, this probably applies to swizzled
1511           * MUBUF accesses. */
1512          bool vaddr_prevent_overflow = mubuf.swizzled && ctx.program->gfx_level < GFX9;
1513 
1514          if (mubuf.offen && mubuf.idxen && i == 1 && info.is_vec() &&
1515              info.instr->operands.size() == 2 && info.instr->operands[0].isTemp() &&
1516              info.instr->operands[0].regClass() == v1 && info.instr->operands[1].isConstant() &&
1517              mubuf.offset + info.instr->operands[1].constantValue() < 4096) {
1518             instr->operands[1] = info.instr->operands[0];
1519             mubuf.offset += info.instr->operands[1].constantValue();
1520             mubuf.offen = false;
1521             continue;
1522          } else if (mubuf.offen && i == 1 && info.is_constant_or_literal(32) &&
1523                     mubuf.offset + info.val < 4096) {
1524             assert(!mubuf.idxen);
1525             instr->operands[1] = Operand(v1);
1526             mubuf.offset += info.val;
1527             mubuf.offen = false;
1528             continue;
1529          } else if (i == 2 && info.is_constant_or_literal(32) && mubuf.offset + info.val < 4096) {
1530             instr->operands[2] = Operand::c32(0);
1531             mubuf.offset += info.val;
1532             continue;
1533          } else if (mubuf.offen && i == 1 &&
1534                     parse_base_offset(ctx, instr.get(), i, &base, &offset,
1535                                       vaddr_prevent_overflow) &&
1536                     base.regClass() == v1 && mubuf.offset + offset < 4096) {
1537             assert(!mubuf.idxen);
1538             instr->operands[1].setTemp(base);
1539             mubuf.offset += offset;
1540             continue;
1541          } else if (i == 2 && parse_base_offset(ctx, instr.get(), i, &base, &offset, true) &&
1542                     base.regClass() == s1 && mubuf.offset + offset < 4096 && !mubuf.swizzled) {
1543             instr->operands[i].setTemp(base);
1544             mubuf.offset += offset;
1545             continue;
1546          }
1547       }
1548 
1549       else if (instr->isMTBUF()) {
1550          MTBUF_instruction& mtbuf = instr->mtbuf();
1551          while (info.is_temp())
1552             info = ctx.info[info.temp.id()];
1553 
1554          if (mtbuf.offen && mtbuf.idxen && i == 1 && info.is_vec() &&
1555              info.instr->operands.size() == 2 && info.instr->operands[0].isTemp() &&
1556              info.instr->operands[0].regClass() == v1 && info.instr->operands[1].isConstant() &&
1557              mtbuf.offset + info.instr->operands[1].constantValue() < 4096) {
1558             instr->operands[1] = info.instr->operands[0];
1559             mtbuf.offset += info.instr->operands[1].constantValue();
1560             mtbuf.offen = false;
1561             continue;
1562          }
1563       }
1564 
1565       /* SCRATCH: propagate constants and combine additions */
1566       else if (instr->isScratch()) {
1567          FLAT_instruction& scratch = instr->scratch();
1568          Temp base;
1569          uint32_t offset;
1570          while (info.is_temp())
1571             info = ctx.info[info.temp.id()];
1572 
1573          /* The hardware probably does: 'scratch_base + u2u64(saddr) + i2i64(offset)'. This means
1574           * we can't combine the addition if the unsigned addition overflows and offset is
1575           * positive. In theory, there is also issues if
1576           * 'ilt(offset, 0) && ige(saddr, 0) && ilt(saddr + offset, 0)', but that just
1577           * replaces an already out-of-bounds access with a larger one since 'saddr + offset'
1578           * would be larger than INT32_MAX.
1579           */
1580          if (i <= 1 && parse_base_offset(ctx, instr.get(), i, &base, &offset, true) &&
1581              base.regClass() == instr->operands[i].regClass() &&
1582              is_scratch_offset_valid(ctx, instr.get(), scratch.offset, (int32_t)offset)) {
1583             instr->operands[i].setTemp(base);
1584             scratch.offset += (int32_t)offset;
1585             continue;
1586          } else if (i <= 1 && parse_base_offset(ctx, instr.get(), i, &base, &offset, false) &&
1587                     base.regClass() == instr->operands[i].regClass() && (int32_t)offset < 0 &&
1588                     is_scratch_offset_valid(ctx, instr.get(), scratch.offset, (int32_t)offset)) {
1589             instr->operands[i].setTemp(base);
1590             scratch.offset += (int32_t)offset;
1591             continue;
1592          } else if (i <= 1 && info.is_constant_or_literal(32) &&
1593                     ctx.program->gfx_level >= GFX10_3 &&
1594                     is_scratch_offset_valid(ctx, NULL, scratch.offset, (int32_t)info.val)) {
1595             /* GFX10.3+ can disable both SADDR and ADDR. */
1596             instr->operands[i] = Operand(instr->operands[i].regClass());
1597             scratch.offset += (int32_t)info.val;
1598             continue;
1599          }
1600       }
1601 
1602       /* DS: combine additions */
1603       else if (instr->isDS()) {
1604 
1605          DS_instruction& ds = instr->ds();
1606          Temp base;
1607          uint32_t offset;
1608          bool has_usable_ds_offset = ctx.program->gfx_level >= GFX7;
1609          if (has_usable_ds_offset && i == 0 &&
1610              parse_base_offset(ctx, instr.get(), i, &base, &offset, false) &&
1611              base.regClass() == instr->operands[i].regClass() &&
1612              instr->opcode != aco_opcode::ds_swizzle_b32) {
1613             if (instr->opcode == aco_opcode::ds_write2_b32 ||
1614                 instr->opcode == aco_opcode::ds_read2_b32 ||
1615                 instr->opcode == aco_opcode::ds_write2_b64 ||
1616                 instr->opcode == aco_opcode::ds_read2_b64 ||
1617                 instr->opcode == aco_opcode::ds_write2st64_b32 ||
1618                 instr->opcode == aco_opcode::ds_read2st64_b32 ||
1619                 instr->opcode == aco_opcode::ds_write2st64_b64 ||
1620                 instr->opcode == aco_opcode::ds_read2st64_b64) {
1621                bool is64bit = instr->opcode == aco_opcode::ds_write2_b64 ||
1622                               instr->opcode == aco_opcode::ds_read2_b64 ||
1623                               instr->opcode == aco_opcode::ds_write2st64_b64 ||
1624                               instr->opcode == aco_opcode::ds_read2st64_b64;
1625                bool st64 = instr->opcode == aco_opcode::ds_write2st64_b32 ||
1626                            instr->opcode == aco_opcode::ds_read2st64_b32 ||
1627                            instr->opcode == aco_opcode::ds_write2st64_b64 ||
1628                            instr->opcode == aco_opcode::ds_read2st64_b64;
1629                unsigned shifts = (is64bit ? 3 : 2) + (st64 ? 6 : 0);
1630                unsigned mask = BITFIELD_MASK(shifts);
1631 
1632                if ((offset & mask) == 0 && ds.offset0 + (offset >> shifts) <= 255 &&
1633                    ds.offset1 + (offset >> shifts) <= 255) {
1634                   instr->operands[i].setTemp(base);
1635                   ds.offset0 += offset >> shifts;
1636                   ds.offset1 += offset >> shifts;
1637                }
1638             } else {
1639                if (ds.offset0 + offset <= 65535) {
1640                   instr->operands[i].setTemp(base);
1641                   ds.offset0 += offset;
1642                }
1643             }
1644          }
1645       }
1646 
1647       else if (instr->isBranch()) {
1648          if (ctx.info[instr->operands[0].tempId()].is_scc_invert()) {
1649             /* Flip the branch instruction to get rid of the scc_invert instruction */
1650             instr->opcode = instr->opcode == aco_opcode::p_cbranch_z ? aco_opcode::p_cbranch_nz
1651                                                                      : aco_opcode::p_cbranch_z;
1652             instr->operands[0].setTemp(ctx.info[instr->operands[0].tempId()].temp);
1653          }
1654       }
1655    }
1656 
1657    /* if this instruction doesn't define anything, return */
1658    if (instr->definitions.empty()) {
1659       check_sdwa_extract(ctx, instr);
1660       return;
1661    }
1662 
1663    if (instr->isVALU() || instr->isVINTRP()) {
1664       if (instr_info.can_use_output_modifiers[(int)instr->opcode] || instr->isVINTRP() ||
1665           instr->opcode == aco_opcode::v_cndmask_b32) {
1666          bool canonicalized = true;
1667          if (!does_fp_op_flush_denorms(ctx, instr->opcode)) {
1668             unsigned ops = instr->opcode == aco_opcode::v_cndmask_b32 ? 2 : instr->operands.size();
1669             for (unsigned i = 0; canonicalized && (i < ops); i++)
1670                canonicalized = is_op_canonicalized(ctx, instr->operands[i]);
1671          }
1672          if (canonicalized)
1673             ctx.info[instr->definitions[0].tempId()].set_canonicalized();
1674       }
1675 
1676       if (instr->isVOPC()) {
1677          ctx.info[instr->definitions[0].tempId()].set_vopc(instr.get());
1678          check_sdwa_extract(ctx, instr);
1679          return;
1680       }
1681       if (instr->isVOP3P()) {
1682          ctx.info[instr->definitions[0].tempId()].set_vop3p(instr.get());
1683          return;
1684       }
1685    }
1686 
1687    switch (instr->opcode) {
1688    case aco_opcode::p_create_vector: {
1689       bool copy_prop = instr->operands.size() == 1 && instr->operands[0].isTemp() &&
1690                        instr->operands[0].regClass() == instr->definitions[0].regClass();
1691       if (copy_prop) {
1692          ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
1693          break;
1694       }
1695 
1696       /* expand vector operands */
1697       std::vector<Operand> ops;
1698       unsigned offset = 0;
1699       for (const Operand& op : instr->operands) {
1700          /* ensure that any expanded operands are properly aligned */
1701          bool aligned = offset % 4 == 0 || op.bytes() < 4;
1702          offset += op.bytes();
1703          if (aligned && op.isTemp() && ctx.info[op.tempId()].is_vec()) {
1704             Instruction* vec = ctx.info[op.tempId()].instr;
1705             for (const Operand& vec_op : vec->operands)
1706                ops.emplace_back(vec_op);
1707          } else {
1708             ops.emplace_back(op);
1709          }
1710       }
1711 
1712       /* combine expanded operands to new vector */
1713       if (ops.size() != instr->operands.size()) {
1714          assert(ops.size() > instr->operands.size());
1715          Definition def = instr->definitions[0];
1716          instr.reset(create_instruction<Pseudo_instruction>(aco_opcode::p_create_vector,
1717                                                             Format::PSEUDO, ops.size(), 1));
1718          for (unsigned i = 0; i < ops.size(); i++) {
1719             if (ops[i].isTemp() && ctx.info[ops[i].tempId()].is_temp() &&
1720                 ops[i].regClass() == ctx.info[ops[i].tempId()].temp.regClass())
1721                ops[i].setTemp(ctx.info[ops[i].tempId()].temp);
1722             instr->operands[i] = ops[i];
1723          }
1724          instr->definitions[0] = def;
1725       } else {
1726          for (unsigned i = 0; i < ops.size(); i++) {
1727             assert(instr->operands[i] == ops[i]);
1728          }
1729       }
1730       ctx.info[instr->definitions[0].tempId()].set_vec(instr.get());
1731 
1732       if (instr->operands.size() == 2) {
1733          /* check if this is created from split_vector */
1734          if (instr->operands[1].isTemp() && ctx.info[instr->operands[1].tempId()].is_split()) {
1735             Instruction* split = ctx.info[instr->operands[1].tempId()].instr;
1736             if (instr->operands[0].isTemp() &&
1737                 instr->operands[0].getTemp() == split->definitions[0].getTemp())
1738                ctx.info[instr->definitions[0].tempId()].set_temp(split->operands[0].getTemp());
1739          }
1740       }
1741       break;
1742    }
1743    case aco_opcode::p_split_vector: {
1744       ssa_info& info = ctx.info[instr->operands[0].tempId()];
1745 
1746       if (info.is_constant_or_literal(32)) {
1747          uint64_t val = info.val;
1748          for (Definition def : instr->definitions) {
1749             uint32_t mask = u_bit_consecutive(0, def.bytes() * 8u);
1750             ctx.info[def.tempId()].set_constant(ctx.program->gfx_level, val & mask);
1751             val >>= def.bytes() * 8u;
1752          }
1753          break;
1754       } else if (!info.is_vec()) {
1755          if (instr->definitions.size() == 2 && instr->operands[0].isTemp() &&
1756              instr->definitions[0].bytes() == instr->definitions[1].bytes()) {
1757             ctx.info[instr->definitions[1].tempId()].set_split(instr.get());
1758             if (instr->operands[0].bytes() == 4) {
1759                /* D16 subdword split */
1760                ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
1761                ctx.info[instr->definitions[1].tempId()].set_extract(instr.get());
1762             }
1763          }
1764          break;
1765       }
1766 
1767       Instruction* vec = ctx.info[instr->operands[0].tempId()].instr;
1768       unsigned split_offset = 0;
1769       unsigned vec_offset = 0;
1770       unsigned vec_index = 0;
1771       for (unsigned i = 0; i < instr->definitions.size();
1772            split_offset += instr->definitions[i++].bytes()) {
1773          while (vec_offset < split_offset && vec_index < vec->operands.size())
1774             vec_offset += vec->operands[vec_index++].bytes();
1775 
1776          if (vec_offset != split_offset ||
1777              vec->operands[vec_index].bytes() != instr->definitions[i].bytes())
1778             continue;
1779 
1780          Operand vec_op = vec->operands[vec_index];
1781          if (vec_op.isConstant()) {
1782             ctx.info[instr->definitions[i].tempId()].set_constant(ctx.program->gfx_level,
1783                                                                   vec_op.constantValue64());
1784          } else if (vec_op.isUndefined()) {
1785             ctx.info[instr->definitions[i].tempId()].set_undefined();
1786          } else {
1787             assert(vec_op.isTemp());
1788             ctx.info[instr->definitions[i].tempId()].set_temp(vec_op.getTemp());
1789          }
1790       }
1791       break;
1792    }
1793    case aco_opcode::p_extract_vector: { /* mov */
1794       ssa_info& info = ctx.info[instr->operands[0].tempId()];
1795       const unsigned index = instr->operands[1].constantValue();
1796       const unsigned dst_offset = index * instr->definitions[0].bytes();
1797 
1798       if (info.is_vec()) {
1799          /* check if we index directly into a vector element */
1800          Instruction* vec = info.instr;
1801          unsigned offset = 0;
1802 
1803          for (const Operand& op : vec->operands) {
1804             if (offset < dst_offset) {
1805                offset += op.bytes();
1806                continue;
1807             } else if (offset != dst_offset || op.bytes() != instr->definitions[0].bytes()) {
1808                break;
1809             }
1810             instr->operands[0] = op;
1811             break;
1812          }
1813       } else if (info.is_constant_or_literal(32)) {
1814          /* propagate constants */
1815          uint32_t mask = u_bit_consecutive(0, instr->definitions[0].bytes() * 8u);
1816          uint32_t val = (info.val >> (dst_offset * 8u)) & mask;
1817          instr->operands[0] =
1818             Operand::get_const(ctx.program->gfx_level, val, instr->definitions[0].bytes());
1819          ;
1820       }
1821 
1822       if (instr->operands[0].bytes() != instr->definitions[0].bytes()) {
1823          if (instr->operands[0].size() != 1)
1824             break;
1825 
1826          if (index == 0)
1827             ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
1828          else
1829             ctx.info[instr->definitions[0].tempId()].set_extract(instr.get());
1830          break;
1831       }
1832 
1833       /* convert this extract into a copy instruction */
1834       instr->opcode = aco_opcode::p_parallelcopy;
1835       instr->operands.pop_back();
1836       FALLTHROUGH;
1837    }
1838    case aco_opcode::p_parallelcopy: /* propagate */
1839       if (instr->operands[0].isTemp() && ctx.info[instr->operands[0].tempId()].is_vec() &&
1840           instr->operands[0].regClass() != instr->definitions[0].regClass()) {
1841          /* We might not be able to copy-propagate if it's a SGPR->VGPR copy, so
1842           * duplicate the vector instead.
1843           */
1844          Instruction* vec = ctx.info[instr->operands[0].tempId()].instr;
1845          aco_ptr<Instruction> old_copy = std::move(instr);
1846 
1847          instr.reset(create_instruction<Pseudo_instruction>(
1848             aco_opcode::p_create_vector, Format::PSEUDO, vec->operands.size(), 1));
1849          instr->definitions[0] = old_copy->definitions[0];
1850          std::copy(vec->operands.begin(), vec->operands.end(), instr->operands.begin());
1851          for (unsigned i = 0; i < vec->operands.size(); i++) {
1852             Operand& op = instr->operands[i];
1853             if (op.isTemp() && ctx.info[op.tempId()].is_temp() &&
1854                 ctx.info[op.tempId()].temp.type() == instr->definitions[0].regClass().type())
1855                op.setTemp(ctx.info[op.tempId()].temp);
1856          }
1857          ctx.info[instr->definitions[0].tempId()].set_vec(instr.get());
1858          break;
1859       }
1860       FALLTHROUGH;
1861    case aco_opcode::p_as_uniform:
1862       if (instr->definitions[0].isFixed()) {
1863          /* don't copy-propagate copies into fixed registers */
1864       } else if (instr->operands[0].isConstant()) {
1865          ctx.info[instr->definitions[0].tempId()].set_constant(
1866             ctx.program->gfx_level, instr->operands[0].constantValue64());
1867       } else if (instr->operands[0].isTemp()) {
1868          ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
1869          if (ctx.info[instr->operands[0].tempId()].is_canonicalized())
1870             ctx.info[instr->definitions[0].tempId()].set_canonicalized();
1871       } else {
1872          assert(instr->operands[0].isFixed());
1873       }
1874       break;
1875    case aco_opcode::v_mov_b32:
1876       if (instr->isDPP16()) {
1877          /* anything else doesn't make sense in SSA */
1878          assert(instr->dpp16().row_mask == 0xf && instr->dpp16().bank_mask == 0xf);
1879          ctx.info[instr->definitions[0].tempId()].set_dpp16(instr.get());
1880       } else if (instr->isDPP8()) {
1881          ctx.info[instr->definitions[0].tempId()].set_dpp8(instr.get());
1882       }
1883       break;
1884    case aco_opcode::p_is_helper:
1885       if (!ctx.program->needs_wqm)
1886          ctx.info[instr->definitions[0].tempId()].set_constant(ctx.program->gfx_level, 0u);
1887       break;
1888    case aco_opcode::v_mul_f64: ctx.info[instr->definitions[0].tempId()].set_mul(instr.get()); break;
1889    case aco_opcode::v_mul_f16:
1890    case aco_opcode::v_mul_f32:
1891    case aco_opcode::v_mul_legacy_f32: { /* omod */
1892       ctx.info[instr->definitions[0].tempId()].set_mul(instr.get());
1893 
1894       /* TODO: try to move the negate/abs modifier to the consumer instead */
1895       bool uses_mods = instr->usesModifiers();
1896       bool fp16 = instr->opcode == aco_opcode::v_mul_f16;
1897 
1898       for (unsigned i = 0; i < 2; i++) {
1899          if (instr->operands[!i].isConstant() && instr->operands[i].isTemp()) {
1900             if (!instr->isDPP() && !instr->isSDWA() && !instr->valu().opsel &&
1901                 (instr->operands[!i].constantEquals(fp16 ? 0x3c00 : 0x3f800000) ||   /* 1.0 */
1902                  instr->operands[!i].constantEquals(fp16 ? 0xbc00 : 0xbf800000u))) { /* -1.0 */
1903                bool neg1 = instr->operands[!i].constantEquals(fp16 ? 0xbc00 : 0xbf800000u);
1904 
1905                VALU_instruction* valu = &instr->valu();
1906                if (valu->abs[!i] || valu->neg[!i] || valu->omod)
1907                   continue;
1908 
1909                bool abs = valu->abs[i];
1910                bool neg = neg1 ^ valu->neg[i];
1911                Temp other = instr->operands[i].getTemp();
1912 
1913                if (valu->clamp) {
1914                   if (!abs && !neg && other.type() == RegType::vgpr)
1915                      ctx.info[other.id()].set_clamp(instr.get());
1916                   continue;
1917                }
1918 
1919                if (abs && neg && other.type() == RegType::vgpr)
1920                   ctx.info[instr->definitions[0].tempId()].set_neg_abs(other);
1921                else if (abs && !neg && other.type() == RegType::vgpr)
1922                   ctx.info[instr->definitions[0].tempId()].set_abs(other);
1923                else if (!abs && neg && other.type() == RegType::vgpr)
1924                   ctx.info[instr->definitions[0].tempId()].set_neg(other);
1925                else if (!abs && !neg)
1926                   ctx.info[instr->definitions[0].tempId()].set_fcanonicalize(other);
1927             } else if (uses_mods || ((fp16 ? ctx.fp_mode.preserve_signed_zero_inf_nan16_64
1928                                            : ctx.fp_mode.preserve_signed_zero_inf_nan32) &&
1929                                      instr->opcode != aco_opcode::v_mul_legacy_f32)) {
1930                continue; /* omod uses a legacy multiplication. */
1931             } else if (instr->operands[!i].constantValue() == 0u) { /* 0.0 */
1932                ctx.info[instr->definitions[0].tempId()].set_constant(ctx.program->gfx_level, 0u);
1933             } else if ((fp16 ? ctx.fp_mode.denorm16_64 : ctx.fp_mode.denorm32) != fp_denorm_flush) {
1934                /* omod has no effect if denormals are enabled. */
1935                continue;
1936             } else if (instr->operands[!i].constantValue() ==
1937                        (fp16 ? 0x4000 : 0x40000000)) { /* 2.0 */
1938                ctx.info[instr->operands[i].tempId()].set_omod2(instr.get());
1939             } else if (instr->operands[!i].constantValue() ==
1940                        (fp16 ? 0x4400 : 0x40800000)) { /* 4.0 */
1941                ctx.info[instr->operands[i].tempId()].set_omod4(instr.get());
1942             } else if (instr->operands[!i].constantValue() ==
1943                        (fp16 ? 0x3800 : 0x3f000000)) { /* 0.5 */
1944                ctx.info[instr->operands[i].tempId()].set_omod5(instr.get());
1945             } else {
1946                continue;
1947             }
1948             break;
1949          }
1950       }
1951       break;
1952    }
1953    case aco_opcode::v_mul_lo_u16:
1954    case aco_opcode::v_mul_lo_u16_e64:
1955    case aco_opcode::v_mul_u32_u24:
1956       ctx.info[instr->definitions[0].tempId()].set_usedef(instr.get());
1957       break;
1958    case aco_opcode::v_med3_f16:
1959    case aco_opcode::v_med3_f32: { /* clamp */
1960       unsigned idx;
1961       if (detect_clamp(instr.get(), &idx) && !instr->valu().abs && !instr->valu().neg)
1962          ctx.info[instr->operands[idx].tempId()].set_clamp(instr.get());
1963       break;
1964    }
1965    case aco_opcode::v_cndmask_b32:
1966       if (instr->operands[0].constantEquals(0) && instr->operands[1].constantEquals(0xFFFFFFFF))
1967          ctx.info[instr->definitions[0].tempId()].set_vcc(instr->operands[2].getTemp());
1968       else if (instr->operands[0].constantEquals(0) &&
1969                instr->operands[1].constantEquals(0x3f800000u))
1970          ctx.info[instr->definitions[0].tempId()].set_b2f(instr->operands[2].getTemp());
1971       else if (instr->operands[0].constantEquals(0) && instr->operands[1].constantEquals(1))
1972          ctx.info[instr->definitions[0].tempId()].set_b2i(instr->operands[2].getTemp());
1973 
1974       break;
1975    case aco_opcode::v_cmp_lg_u32:
1976       if (instr->format == Format::VOPC && /* don't optimize VOP3 / SDWA / DPP */
1977           instr->operands[0].constantEquals(0) && instr->operands[1].isTemp() &&
1978           ctx.info[instr->operands[1].tempId()].is_vcc())
1979          ctx.info[instr->definitions[0].tempId()].set_temp(
1980             ctx.info[instr->operands[1].tempId()].temp);
1981       break;
1982    case aco_opcode::p_linear_phi: {
1983       /* lower_bool_phis() can create phis like this */
1984       bool all_same_temp = instr->operands[0].isTemp();
1985       /* this check is needed when moving uniform loop counters out of a divergent loop */
1986       if (all_same_temp)
1987          all_same_temp = instr->definitions[0].regClass() == instr->operands[0].regClass();
1988       for (unsigned i = 1; all_same_temp && (i < instr->operands.size()); i++) {
1989          if (!instr->operands[i].isTemp() ||
1990              instr->operands[i].tempId() != instr->operands[0].tempId())
1991             all_same_temp = false;
1992       }
1993       if (all_same_temp) {
1994          ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
1995       } else {
1996          bool all_undef = instr->operands[0].isUndefined();
1997          for (unsigned i = 1; all_undef && (i < instr->operands.size()); i++) {
1998             if (!instr->operands[i].isUndefined())
1999                all_undef = false;
2000          }
2001          if (all_undef)
2002             ctx.info[instr->definitions[0].tempId()].set_undefined();
2003       }
2004       break;
2005    }
2006    case aco_opcode::v_add_u32:
2007    case aco_opcode::v_add_co_u32:
2008    case aco_opcode::v_add_co_u32_e64:
2009    case aco_opcode::s_add_i32:
2010    case aco_opcode::s_add_u32:
2011    case aco_opcode::v_subbrev_co_u32:
2012    case aco_opcode::v_sub_u32:
2013    case aco_opcode::v_sub_i32:
2014    case aco_opcode::v_sub_co_u32:
2015    case aco_opcode::v_sub_co_u32_e64:
2016    case aco_opcode::s_sub_u32:
2017    case aco_opcode::s_sub_i32:
2018    case aco_opcode::v_subrev_u32:
2019    case aco_opcode::v_subrev_co_u32:
2020    case aco_opcode::v_subrev_co_u32_e64:
2021       ctx.info[instr->definitions[0].tempId()].set_add_sub(instr.get());
2022       break;
2023    case aco_opcode::s_not_b32:
2024    case aco_opcode::s_not_b64:
2025       if (!instr->operands[0].isTemp()) {
2026       } else if (ctx.info[instr->operands[0].tempId()].is_uniform_bool()) {
2027          ctx.info[instr->definitions[0].tempId()].set_uniform_bitwise();
2028          ctx.info[instr->definitions[1].tempId()].set_scc_invert(
2029             ctx.info[instr->operands[0].tempId()].temp);
2030       } else if (ctx.info[instr->operands[0].tempId()].is_uniform_bitwise()) {
2031          ctx.info[instr->definitions[0].tempId()].set_uniform_bitwise();
2032          ctx.info[instr->definitions[1].tempId()].set_scc_invert(
2033             ctx.info[instr->operands[0].tempId()].instr->definitions[1].getTemp());
2034       }
2035       ctx.info[instr->definitions[0].tempId()].set_bitwise(instr.get());
2036       break;
2037    case aco_opcode::s_and_b32:
2038    case aco_opcode::s_and_b64:
2039       if (fixed_to_exec(instr->operands[1]) && instr->operands[0].isTemp()) {
2040          if (ctx.info[instr->operands[0].tempId()].is_uniform_bool()) {
2041             /* Try to get rid of the superfluous s_cselect + s_and_b64 that comes from turning a
2042              * uniform bool into divergent */
2043             ctx.info[instr->definitions[1].tempId()].set_temp(
2044                ctx.info[instr->operands[0].tempId()].temp);
2045             ctx.info[instr->definitions[0].tempId()].set_uniform_bool(
2046                ctx.info[instr->operands[0].tempId()].temp);
2047             break;
2048          } else if (ctx.info[instr->operands[0].tempId()].is_uniform_bitwise()) {
2049             /* Try to get rid of the superfluous s_and_b64, since the uniform bitwise instruction
2050              * already produces the same SCC */
2051             ctx.info[instr->definitions[1].tempId()].set_temp(
2052                ctx.info[instr->operands[0].tempId()].instr->definitions[1].getTemp());
2053             ctx.info[instr->definitions[0].tempId()].set_uniform_bool(
2054                ctx.info[instr->operands[0].tempId()].instr->definitions[1].getTemp());
2055             break;
2056          } else if ((ctx.program->stage.num_sw_stages() > 1 ||
2057                      ctx.program->stage.hw == AC_HW_NEXT_GEN_GEOMETRY_SHADER) &&
2058                     instr->pass_flags == 1) {
2059             /* In case of merged shaders, pass_flags=1 means that all lanes are active (exec=-1), so
2060              * s_and is unnecessary. */
2061             ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
2062             break;
2063          }
2064       }
2065       FALLTHROUGH;
2066    case aco_opcode::s_or_b32:
2067    case aco_opcode::s_or_b64:
2068    case aco_opcode::s_xor_b32:
2069    case aco_opcode::s_xor_b64:
2070       if (std::all_of(instr->operands.begin(), instr->operands.end(),
2071                       [&ctx](const Operand& op)
2072                       {
2073                          return op.isTemp() && (ctx.info[op.tempId()].is_uniform_bool() ||
2074                                                 ctx.info[op.tempId()].is_uniform_bitwise());
2075                       })) {
2076          ctx.info[instr->definitions[0].tempId()].set_uniform_bitwise();
2077       }
2078       ctx.info[instr->definitions[0].tempId()].set_bitwise(instr.get());
2079       break;
2080    case aco_opcode::s_lshl_b32:
2081    case aco_opcode::v_or_b32:
2082    case aco_opcode::v_lshlrev_b32:
2083    case aco_opcode::v_bcnt_u32_b32:
2084    case aco_opcode::v_and_b32:
2085    case aco_opcode::v_xor_b32:
2086    case aco_opcode::v_not_b32:
2087       ctx.info[instr->definitions[0].tempId()].set_usedef(instr.get());
2088       break;
2089    case aco_opcode::v_min_f32:
2090    case aco_opcode::v_min_f16:
2091    case aco_opcode::v_min_u32:
2092    case aco_opcode::v_min_i32:
2093    case aco_opcode::v_min_u16:
2094    case aco_opcode::v_min_i16:
2095    case aco_opcode::v_min_u16_e64:
2096    case aco_opcode::v_min_i16_e64:
2097    case aco_opcode::v_max_f32:
2098    case aco_opcode::v_max_f16:
2099    case aco_opcode::v_max_u32:
2100    case aco_opcode::v_max_i32:
2101    case aco_opcode::v_max_u16:
2102    case aco_opcode::v_max_i16:
2103    case aco_opcode::v_max_u16_e64:
2104    case aco_opcode::v_max_i16_e64:
2105       ctx.info[instr->definitions[0].tempId()].set_minmax(instr.get());
2106       break;
2107    case aco_opcode::s_cselect_b64:
2108    case aco_opcode::s_cselect_b32:
2109       if (instr->operands[0].constantEquals((unsigned)-1) && instr->operands[1].constantEquals(0)) {
2110          /* Found a cselect that operates on a uniform bool that comes from eg. s_cmp */
2111          ctx.info[instr->definitions[0].tempId()].set_uniform_bool(instr->operands[2].getTemp());
2112       }
2113       if (instr->operands[2].isTemp() && ctx.info[instr->operands[2].tempId()].is_scc_invert()) {
2114          /* Flip the operands to get rid of the scc_invert instruction */
2115          std::swap(instr->operands[0], instr->operands[1]);
2116          instr->operands[2].setTemp(ctx.info[instr->operands[2].tempId()].temp);
2117       }
2118       break;
2119    case aco_opcode::s_mul_i32:
2120       /* Testing every uint32_t shows that 0x3f800000*n is never a denormal.
2121        * This pattern is created from a uniform nir_op_b2f. */
2122       if (instr->operands[0].constantEquals(0x3f800000u))
2123          ctx.info[instr->definitions[0].tempId()].set_canonicalized();
2124       break;
2125    case aco_opcode::p_extract: {
2126       if (instr->definitions[0].bytes() == 4) {
2127          ctx.info[instr->definitions[0].tempId()].set_extract(instr.get());
2128          if (instr->operands[0].regClass() == v1 && parse_insert(instr.get()))
2129             ctx.info[instr->operands[0].tempId()].set_insert(instr.get());
2130       }
2131       break;
2132    }
2133    case aco_opcode::p_insert: {
2134       if (instr->operands[0].bytes() == 4) {
2135          if (instr->operands[0].regClass() == v1)
2136             ctx.info[instr->operands[0].tempId()].set_insert(instr.get());
2137          if (parse_extract(instr.get()))
2138             ctx.info[instr->definitions[0].tempId()].set_extract(instr.get());
2139          ctx.info[instr->definitions[0].tempId()].set_bitwise(instr.get());
2140       }
2141       break;
2142    }
2143    case aco_opcode::ds_read_u8:
2144    case aco_opcode::ds_read_u8_d16:
2145    case aco_opcode::ds_read_u16:
2146    case aco_opcode::ds_read_u16_d16: {
2147       ctx.info[instr->definitions[0].tempId()].set_usedef(instr.get());
2148       break;
2149    }
2150    case aco_opcode::v_mbcnt_lo_u32_b32: {
2151       if (instr->operands[0].constantEquals(-1) && instr->operands[1].constantEquals(0)) {
2152          if (ctx.program->wave_size == 32)
2153             ctx.info[instr->definitions[0].tempId()].set_subgroup_invocation(instr.get());
2154          else
2155             ctx.info[instr->definitions[0].tempId()].set_usedef(instr.get());
2156       }
2157       break;
2158    }
2159    case aco_opcode::v_mbcnt_hi_u32_b32:
2160    case aco_opcode::v_mbcnt_hi_u32_b32_e64: {
2161       if (instr->operands[0].constantEquals(-1) && instr->operands[1].isTemp() &&
2162           ctx.info[instr->operands[1].tempId()].is_usedef()) {
2163          Instruction* usedef_instr = ctx.info[instr->operands[1].tempId()].instr;
2164          if (usedef_instr->opcode == aco_opcode::v_mbcnt_lo_u32_b32 &&
2165              usedef_instr->operands[0].constantEquals(-1) &&
2166              usedef_instr->operands[1].constantEquals(0))
2167             ctx.info[instr->definitions[0].tempId()].set_subgroup_invocation(instr.get());
2168       }
2169       break;
2170    }
2171    case aco_opcode::v_cvt_f16_f32: {
2172       if (instr->operands[0].isTemp())
2173          ctx.info[instr->operands[0].tempId()].set_f2f16(instr.get());
2174       break;
2175    }
2176    case aco_opcode::v_cvt_f32_f16: {
2177       if (instr->operands[0].isTemp())
2178          ctx.info[instr->definitions[0].tempId()].set_f2f32(instr.get());
2179       break;
2180    }
2181    default: break;
2182    }
2183 
2184    /* Don't remove label_extract if we can't apply the extract to
2185     * neg/abs instructions because we'll likely combine it into another valu. */
2186    if (!(ctx.info[instr->definitions[0].tempId()].label & (label_neg | label_abs)))
2187       check_sdwa_extract(ctx, instr);
2188 }
2189 
2190 unsigned
original_temp_id(opt_ctx & ctx,Temp tmp)2191 original_temp_id(opt_ctx& ctx, Temp tmp)
2192 {
2193    if (ctx.info[tmp.id()].is_temp())
2194       return ctx.info[tmp.id()].temp.id();
2195    else
2196       return tmp.id();
2197 }
2198 
2199 void
decrease_op_uses_if_dead(opt_ctx & ctx,Instruction * instr)2200 decrease_op_uses_if_dead(opt_ctx& ctx, Instruction* instr)
2201 {
2202    if (is_dead(ctx.uses, instr)) {
2203       for (const Operand& op : instr->operands) {
2204          if (op.isTemp())
2205             ctx.uses[op.tempId()]--;
2206       }
2207    }
2208 }
2209 
2210 void
decrease_uses(opt_ctx & ctx,Instruction * instr)2211 decrease_uses(opt_ctx& ctx, Instruction* instr)
2212 {
2213    ctx.uses[instr->definitions[0].tempId()]--;
2214    decrease_op_uses_if_dead(ctx, instr);
2215 }
2216 
2217 Operand
copy_operand(opt_ctx & ctx,Operand op)2218 copy_operand(opt_ctx& ctx, Operand op)
2219 {
2220    if (op.isTemp())
2221       ctx.uses[op.tempId()]++;
2222    return op;
2223 }
2224 
2225 Instruction*
follow_operand(opt_ctx & ctx,Operand op,bool ignore_uses=false)2226 follow_operand(opt_ctx& ctx, Operand op, bool ignore_uses = false)
2227 {
2228    if (!op.isTemp() || !(ctx.info[op.tempId()].label & instr_usedef_labels))
2229       return nullptr;
2230    if (!ignore_uses && ctx.uses[op.tempId()] > 1)
2231       return nullptr;
2232 
2233    Instruction* instr = ctx.info[op.tempId()].instr;
2234 
2235    if (instr->definitions.size() == 2) {
2236       assert(instr->definitions[0].isTemp() && instr->definitions[0].tempId() == op.tempId());
2237       if (instr->definitions[1].isTemp() && ctx.uses[instr->definitions[1].tempId()])
2238          return nullptr;
2239    }
2240 
2241    for (Operand& operand : instr->operands) {
2242       if (fixed_to_exec(operand))
2243          return nullptr;
2244    }
2245 
2246    return instr;
2247 }
2248 
2249 /* s_or_b64(neq(a, a), neq(b, b)) -> v_cmp_u_f32(a, b)
2250  * s_and_b64(eq(a, a), eq(b, b)) -> v_cmp_o_f32(a, b) */
2251 bool
combine_ordering_test(opt_ctx & ctx,aco_ptr<Instruction> & instr)2252 combine_ordering_test(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2253 {
2254    if (instr->definitions[0].regClass() != ctx.program->lane_mask)
2255       return false;
2256    if (instr->definitions[1].isTemp() && ctx.uses[instr->definitions[1].tempId()])
2257       return false;
2258 
2259    bool is_or = instr->opcode == aco_opcode::s_or_b64 || instr->opcode == aco_opcode::s_or_b32;
2260 
2261    bitarray8 opsel = 0;
2262    Instruction* op_instr[2];
2263    Temp op[2];
2264 
2265    unsigned bitsize = 0;
2266    for (unsigned i = 0; i < 2; i++) {
2267       op_instr[i] = follow_operand(ctx, instr->operands[i], true);
2268       if (!op_instr[i])
2269          return false;
2270 
2271       aco_opcode expected_cmp = is_or ? aco_opcode::v_cmp_neq_f32 : aco_opcode::v_cmp_eq_f32;
2272       unsigned op_bitsize = get_cmp_bitsize(op_instr[i]->opcode);
2273 
2274       if (get_f32_cmp(op_instr[i]->opcode) != expected_cmp)
2275          return false;
2276       if (bitsize && op_bitsize != bitsize)
2277          return false;
2278       if (!op_instr[i]->operands[0].isTemp() || !op_instr[i]->operands[1].isTemp())
2279          return false;
2280 
2281       if (op_instr[i]->isSDWA() || op_instr[i]->isDPP())
2282          return false;
2283 
2284       VALU_instruction& valu = op_instr[i]->valu();
2285       if (valu.neg[0] != valu.neg[1] || valu.abs[0] != valu.abs[1] ||
2286           valu.opsel[0] != valu.opsel[1])
2287          return false;
2288       opsel[i] = valu.opsel[0];
2289 
2290       Temp op0 = op_instr[i]->operands[0].getTemp();
2291       Temp op1 = op_instr[i]->operands[1].getTemp();
2292       if (original_temp_id(ctx, op0) != original_temp_id(ctx, op1))
2293          return false;
2294 
2295       op[i] = op1;
2296       bitsize = op_bitsize;
2297    }
2298 
2299    if (op[1].type() == RegType::sgpr) {
2300       std::swap(op[0], op[1]);
2301       opsel[0].swap(opsel[1]);
2302    }
2303    unsigned num_sgprs = (op[0].type() == RegType::sgpr) + (op[1].type() == RegType::sgpr);
2304    if (num_sgprs > (ctx.program->gfx_level >= GFX10 ? 2 : 1))
2305       return false;
2306 
2307    aco_opcode new_op = aco_opcode::num_opcodes;
2308    switch (bitsize) {
2309    case 16: new_op = is_or ? aco_opcode::v_cmp_u_f16 : aco_opcode::v_cmp_o_f16; break;
2310    case 32: new_op = is_or ? aco_opcode::v_cmp_u_f32 : aco_opcode::v_cmp_o_f32; break;
2311    case 64: new_op = is_or ? aco_opcode::v_cmp_u_f64 : aco_opcode::v_cmp_o_f64; break;
2312    }
2313    bool needs_vop3 = num_sgprs > 1 || (opsel[0] && op[0].type() != RegType::vgpr);
2314    VALU_instruction* new_instr = create_instruction<VALU_instruction>(
2315       new_op, needs_vop3 ? asVOP3(Format::VOPC) : Format::VOPC, 2, 1);
2316 
2317    new_instr->opsel = opsel;
2318    new_instr->operands[0] = copy_operand(ctx, Operand(op[0]));
2319    new_instr->operands[1] = copy_operand(ctx, Operand(op[1]));
2320    new_instr->definitions[0] = instr->definitions[0];
2321    new_instr->pass_flags = instr->pass_flags;
2322 
2323    decrease_uses(ctx, op_instr[0]);
2324    decrease_uses(ctx, op_instr[1]);
2325 
2326    ctx.info[instr->definitions[0].tempId()].label = 0;
2327    ctx.info[instr->definitions[0].tempId()].set_vopc(new_instr);
2328 
2329    instr.reset(new_instr);
2330 
2331    return true;
2332 }
2333 
2334 /* s_or_b64(v_cmp_u_f32(a, b), cmp(a, b)) -> get_unordered(cmp)(a, b)
2335  * s_and_b64(v_cmp_o_f32(a, b), cmp(a, b)) -> get_ordered(cmp)(a, b) */
2336 bool
combine_comparison_ordering(opt_ctx & ctx,aco_ptr<Instruction> & instr)2337 combine_comparison_ordering(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2338 {
2339    if (instr->definitions[0].regClass() != ctx.program->lane_mask)
2340       return false;
2341    if (instr->definitions[1].isTemp() && ctx.uses[instr->definitions[1].tempId()])
2342       return false;
2343 
2344    bool is_or = instr->opcode == aco_opcode::s_or_b64 || instr->opcode == aco_opcode::s_or_b32;
2345    aco_opcode expected_nan_test = is_or ? aco_opcode::v_cmp_u_f32 : aco_opcode::v_cmp_o_f32;
2346 
2347    Instruction* nan_test = follow_operand(ctx, instr->operands[0], true);
2348    Instruction* cmp = follow_operand(ctx, instr->operands[1], true);
2349    if (!nan_test || !cmp)
2350       return false;
2351    if (nan_test->isSDWA() || cmp->isSDWA())
2352       return false;
2353 
2354    if (get_f32_cmp(cmp->opcode) == expected_nan_test)
2355       std::swap(nan_test, cmp);
2356    else if (get_f32_cmp(nan_test->opcode) != expected_nan_test)
2357       return false;
2358 
2359    if (!is_fp_cmp(cmp->opcode) || get_cmp_bitsize(cmp->opcode) != get_cmp_bitsize(nan_test->opcode))
2360       return false;
2361 
2362    if (!nan_test->operands[0].isTemp() || !nan_test->operands[1].isTemp())
2363       return false;
2364    if (!cmp->operands[0].isTemp() || !cmp->operands[1].isTemp())
2365       return false;
2366 
2367    unsigned prop_cmp0 = original_temp_id(ctx, cmp->operands[0].getTemp());
2368    unsigned prop_cmp1 = original_temp_id(ctx, cmp->operands[1].getTemp());
2369    unsigned prop_nan0 = original_temp_id(ctx, nan_test->operands[0].getTemp());
2370    unsigned prop_nan1 = original_temp_id(ctx, nan_test->operands[1].getTemp());
2371    VALU_instruction& cmp_valu = cmp->valu();
2372    VALU_instruction& nan_valu = nan_test->valu();
2373    if ((prop_cmp0 != prop_nan0 || cmp_valu.opsel[0] != nan_valu.opsel[0]) &&
2374        (prop_cmp0 != prop_nan1 || cmp_valu.opsel[0] != nan_valu.opsel[1]))
2375       return false;
2376    if ((prop_cmp1 != prop_nan0 || cmp_valu.opsel[1] != nan_valu.opsel[0]) &&
2377        (prop_cmp1 != prop_nan1 || cmp_valu.opsel[1] != nan_valu.opsel[1]))
2378       return false;
2379    if (prop_cmp0 == prop_cmp1 && cmp_valu.opsel[0] == cmp_valu.opsel[1])
2380       return false;
2381 
2382    aco_opcode new_op = is_or ? get_unordered(cmp->opcode) : get_ordered(cmp->opcode);
2383    VALU_instruction* new_instr = create_instruction<VALU_instruction>(
2384       new_op, cmp->isVOP3() ? asVOP3(Format::VOPC) : Format::VOPC, 2, 1);
2385    new_instr->neg = cmp_valu.neg;
2386    new_instr->abs = cmp_valu.abs;
2387    new_instr->clamp = cmp_valu.clamp;
2388    new_instr->omod = cmp_valu.omod;
2389    new_instr->opsel = cmp_valu.opsel;
2390    new_instr->operands[0] = copy_operand(ctx, cmp->operands[0]);
2391    new_instr->operands[1] = copy_operand(ctx, cmp->operands[1]);
2392    new_instr->definitions[0] = instr->definitions[0];
2393    new_instr->pass_flags = instr->pass_flags;
2394 
2395    decrease_uses(ctx, nan_test);
2396    decrease_uses(ctx, cmp);
2397 
2398    ctx.info[instr->definitions[0].tempId()].label = 0;
2399    ctx.info[instr->definitions[0].tempId()].set_vopc(new_instr);
2400 
2401    instr.reset(new_instr);
2402 
2403    return true;
2404 }
2405 
2406 /* Optimize v_cmp of constant with subgroup invocation to a constant mask.
2407  * Ideally, we can trade v_cmp for a constant (or literal).
2408  * In a less ideal case, we trade v_cmp for a SALU instruction, which is still a win.
2409  */
2410 bool
optimize_cmp_subgroup_invocation(opt_ctx & ctx,aco_ptr<Instruction> & instr)2411 optimize_cmp_subgroup_invocation(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2412 {
2413    /* This optimization only applies to VOPC with 2 operands. */
2414    if (instr->operands.size() != 2)
2415       return false;
2416 
2417    /* Find the constant operand or return early if there isn't one. */
2418    const int const_op_idx = instr->operands[0].isConstant()   ? 0
2419                             : instr->operands[1].isConstant() ? 1
2420                                                               : -1;
2421    if (const_op_idx == -1)
2422       return false;
2423 
2424    /* Find the operand that has the subgroup invocation. */
2425    const int mbcnt_op_idx = 1 - const_op_idx;
2426    const Operand mbcnt_op = instr->operands[mbcnt_op_idx];
2427    if (!mbcnt_op.isTemp() || !ctx.info[mbcnt_op.tempId()].is_subgroup_invocation())
2428       return false;
2429 
2430    /* Adjust opcode so we don't have to care about const_op_idx below. */
2431    const aco_opcode op = const_op_idx == 0 ? get_swapped(instr->opcode) : instr->opcode;
2432    const unsigned wave_size = ctx.program->wave_size;
2433    const unsigned val = instr->operands[const_op_idx].constantValue();
2434 
2435    /* Find suitable constant bitmask corresponding to the value. */
2436    unsigned first_bit = 0, num_bits = 0;
2437    switch (op) {
2438    case aco_opcode::v_cmp_eq_u32:
2439    case aco_opcode::v_cmp_eq_i32:
2440       first_bit = val;
2441       num_bits = val >= wave_size ? 0 : 1;
2442       break;
2443    case aco_opcode::v_cmp_le_u32:
2444    case aco_opcode::v_cmp_le_i32:
2445       first_bit = 0;
2446       num_bits = val >= wave_size ? wave_size : (val + 1);
2447       break;
2448    case aco_opcode::v_cmp_lt_u32:
2449    case aco_opcode::v_cmp_lt_i32:
2450       first_bit = 0;
2451       num_bits = val >= wave_size ? wave_size : val;
2452       break;
2453    case aco_opcode::v_cmp_ge_u32:
2454    case aco_opcode::v_cmp_ge_i32:
2455       first_bit = val;
2456       num_bits = val >= wave_size ? 0 : (wave_size - val);
2457       break;
2458    case aco_opcode::v_cmp_gt_u32:
2459    case aco_opcode::v_cmp_gt_i32:
2460       first_bit = val + 1;
2461       num_bits = val >= wave_size ? 0 : (wave_size - val - 1);
2462       break;
2463    default: return false;
2464    }
2465 
2466    Instruction* cpy = NULL;
2467    const uint64_t mask = BITFIELD64_RANGE(first_bit, num_bits);
2468    if (wave_size == 64 && mask > 0x7fffffff && mask != -1ull) {
2469       /* Mask can't be represented as a 64-bit constant or literal, use s_bfm_b64. */
2470       cpy = create_instruction<SOP2_instruction>(aco_opcode::s_bfm_b64, Format::SOP2, 2, 1);
2471       cpy->operands[0] = Operand::c32(num_bits);
2472       cpy->operands[1] = Operand::c32(first_bit);
2473    } else {
2474       /* Copy mask as a literal constant. */
2475       cpy =
2476          create_instruction<Pseudo_instruction>(aco_opcode::p_parallelcopy, Format::PSEUDO, 1, 1);
2477       cpy->operands[0] = wave_size == 32 ? Operand::c32((uint32_t)mask) : Operand::c64(mask);
2478    }
2479 
2480    cpy->definitions[0] = instr->definitions[0];
2481    ctx.info[instr->definitions[0].tempId()].label = 0;
2482    decrease_uses(ctx, ctx.info[mbcnt_op.tempId()].instr);
2483    instr.reset(cpy);
2484 
2485    return true;
2486 }
2487 
2488 bool
is_operand_constant(opt_ctx & ctx,Operand op,unsigned bit_size,uint64_t * value)2489 is_operand_constant(opt_ctx& ctx, Operand op, unsigned bit_size, uint64_t* value)
2490 {
2491    if (op.isConstant()) {
2492       *value = op.constantValue64();
2493       return true;
2494    } else if (op.isTemp()) {
2495       unsigned id = original_temp_id(ctx, op.getTemp());
2496       if (!ctx.info[id].is_constant_or_literal(bit_size))
2497          return false;
2498       *value = get_constant_op(ctx, ctx.info[id], bit_size).constantValue64();
2499       return true;
2500    }
2501    return false;
2502 }
2503 
/* Returns whether the low `bit_size` bits of `value` encode a floating point
 * NaN, i.e. the exponent field is all ones and the mantissa is non-zero.
 * bit_size 16 and 32 select half/single layout, anything else is treated as
 * the 64-bit layout. */
bool
is_constant_nan(uint64_t value, unsigned bit_size)
{
   unsigned mantissa_bits;
   unsigned exponent_bits;
   switch (bit_size) {
   case 16:
      mantissa_bits = 10;
      exponent_bits = 5;
      break;
   case 32:
      mantissa_bits = 23;
      exponent_bits = 8;
      break;
   default:
      mantissa_bits = 52;
      exponent_bits = 11;
      break;
   }

   const uint64_t exponent_mask = (uint64_t{1} << exponent_bits) - 1;
   const uint64_t mantissa_mask = (uint64_t{1} << mantissa_bits) - 1;

   const uint64_t exponent = (value >> mantissa_bits) & exponent_mask;
   const uint64_t mantissa = value & mantissa_mask;

   return exponent == exponent_mask && mantissa != 0;
}
2514 
2515 /* s_or_b64(v_cmp_neq_f32(a, a), cmp(a, #b)) and b is not NaN -> get_unordered(cmp)(a, b)
2516  * s_and_b64(v_cmp_eq_f32(a, a), cmp(a, #b)) and b is not NaN -> get_ordered(cmp)(a, b) */
2517 bool
combine_constant_comparison_ordering(opt_ctx & ctx,aco_ptr<Instruction> & instr)2518 combine_constant_comparison_ordering(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2519 {
2520    if (instr->definitions[0].regClass() != ctx.program->lane_mask)
2521       return false;
2522    if (instr->definitions[1].isTemp() && ctx.uses[instr->definitions[1].tempId()])
2523       return false;
2524 
2525    bool is_or = instr->opcode == aco_opcode::s_or_b64 || instr->opcode == aco_opcode::s_or_b32;
2526 
2527    Instruction* nan_test = follow_operand(ctx, instr->operands[0], true);
2528    Instruction* cmp = follow_operand(ctx, instr->operands[1], true);
2529 
2530    if (!nan_test || !cmp || nan_test->isSDWA() || cmp->isSDWA() || nan_test->isDPP() ||
2531        cmp->isDPP())
2532       return false;
2533 
2534    aco_opcode expected_nan_test = is_or ? aco_opcode::v_cmp_neq_f32 : aco_opcode::v_cmp_eq_f32;
2535    if (get_f32_cmp(cmp->opcode) == expected_nan_test)
2536       std::swap(nan_test, cmp);
2537    else if (get_f32_cmp(nan_test->opcode) != expected_nan_test)
2538       return false;
2539 
2540    unsigned bit_size = get_cmp_bitsize(cmp->opcode);
2541    if (!is_fp_cmp(cmp->opcode) || get_cmp_bitsize(nan_test->opcode) != bit_size)
2542       return false;
2543 
2544    if (!nan_test->operands[0].isTemp() || !nan_test->operands[1].isTemp())
2545       return false;
2546    if (!cmp->operands[0].isTemp() && !cmp->operands[1].isTemp())
2547       return false;
2548 
2549    unsigned prop_nan0 = original_temp_id(ctx, nan_test->operands[0].getTemp());
2550    unsigned prop_nan1 = original_temp_id(ctx, nan_test->operands[1].getTemp());
2551    if (prop_nan0 != prop_nan1)
2552       return false;
2553 
2554    VALU_instruction& vop3 = nan_test->valu();
2555    if (vop3.neg[0] != vop3.neg[1] || vop3.abs[0] != vop3.abs[1] || vop3.opsel[0] != vop3.opsel[1])
2556       return false;
2557 
2558    int constant_operand = -1;
2559    for (unsigned i = 0; i < 2; i++) {
2560       if (cmp->operands[i].isTemp() &&
2561           original_temp_id(ctx, cmp->operands[i].getTemp()) == prop_nan0 &&
2562           cmp->valu().opsel[i] == nan_test->valu().opsel[0]) {
2563          constant_operand = !i;
2564          break;
2565       }
2566    }
2567    if (constant_operand == -1)
2568       return false;
2569 
2570    uint64_t constant_value;
2571    if (!is_operand_constant(ctx, cmp->operands[constant_operand], bit_size, &constant_value))
2572       return false;
2573    if (is_constant_nan(constant_value >> (cmp->valu().opsel[constant_operand] * 16), bit_size))
2574       return false;
2575 
2576    aco_opcode new_op = is_or ? get_unordered(cmp->opcode) : get_ordered(cmp->opcode);
2577    Instruction* new_instr = create_instruction<VALU_instruction>(new_op, cmp->format, 2, 1);
2578    new_instr->valu().neg = cmp->valu().neg;
2579    new_instr->valu().abs = cmp->valu().abs;
2580    new_instr->valu().clamp = cmp->valu().clamp;
2581    new_instr->valu().omod = cmp->valu().omod;
2582    new_instr->valu().opsel = cmp->valu().opsel;
2583    new_instr->operands[0] = copy_operand(ctx, cmp->operands[0]);
2584    new_instr->operands[1] = copy_operand(ctx, cmp->operands[1]);
2585    new_instr->definitions[0] = instr->definitions[0];
2586    new_instr->pass_flags = instr->pass_flags;
2587 
2588    decrease_uses(ctx, nan_test);
2589    decrease_uses(ctx, cmp);
2590 
2591    ctx.info[instr->definitions[0].tempId()].label = 0;
2592    ctx.info[instr->definitions[0].tempId()].set_vopc(new_instr);
2593 
2594    instr.reset(new_instr);
2595 
2596    return true;
2597 }
2598 
2599 /* s_not(cmp(a, b)) -> get_inverse(cmp)(a, b) */
2600 bool
combine_inverse_comparison(opt_ctx & ctx,aco_ptr<Instruction> & instr)2601 combine_inverse_comparison(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2602 {
2603    if (ctx.uses[instr->definitions[1].tempId()])
2604       return false;
2605    if (!instr->operands[0].isTemp() || ctx.uses[instr->operands[0].tempId()] != 1)
2606       return false;
2607 
2608    Instruction* cmp = follow_operand(ctx, instr->operands[0]);
2609    if (!cmp)
2610       return false;
2611 
2612    aco_opcode new_opcode = get_inverse(cmp->opcode);
2613    if (new_opcode == aco_opcode::num_opcodes)
2614       return false;
2615 
2616    /* Invert compare instruction and assign this instruction's definition */
2617    cmp->opcode = new_opcode;
2618    ctx.info[instr->definitions[0].tempId()] = ctx.info[cmp->definitions[0].tempId()];
2619    std::swap(instr->definitions[0], cmp->definitions[0]);
2620 
2621    ctx.uses[instr->operands[0].tempId()]--;
2622    return true;
2623 }
2624 
2625 /* op1(op2(1, 2), 0) if swap = false
2626  * op1(0, op2(1, 2)) if swap = true */
2627 bool
match_op3_for_vop3(opt_ctx & ctx,aco_opcode op1,aco_opcode op2,Instruction * op1_instr,bool swap,const char * shuffle_str,Operand operands[3],bitarray8 & neg,bitarray8 & abs,bitarray8 & opsel,bool * op1_clamp,uint8_t * op1_omod,bool * inbetween_neg,bool * inbetween_abs,bool * inbetween_opsel,bool * precise)2628 match_op3_for_vop3(opt_ctx& ctx, aco_opcode op1, aco_opcode op2, Instruction* op1_instr, bool swap,
2629                    const char* shuffle_str, Operand operands[3], bitarray8& neg, bitarray8& abs,
2630                    bitarray8& opsel, bool* op1_clamp, uint8_t* op1_omod, bool* inbetween_neg,
2631                    bool* inbetween_abs, bool* inbetween_opsel, bool* precise)
2632 {
2633    /* checks */
2634    if (op1_instr->opcode != op1)
2635       return false;
2636 
2637    Instruction* op2_instr = follow_operand(ctx, op1_instr->operands[swap]);
2638    if (!op2_instr || op2_instr->opcode != op2)
2639       return false;
2640 
2641    VALU_instruction* op1_valu = op1_instr->isVALU() ? &op1_instr->valu() : NULL;
2642    VALU_instruction* op2_valu = op2_instr->isVALU() ? &op2_instr->valu() : NULL;
2643 
2644    if (op1_instr->isSDWA() || op2_instr->isSDWA())
2645       return false;
2646    if (op1_instr->isDPP() || op2_instr->isDPP())
2647       return false;
2648 
2649    /* don't support inbetween clamp/omod */
2650    if (op2_valu && (op2_valu->clamp || op2_valu->omod))
2651       return false;
2652 
2653    /* get operands and modifiers and check inbetween modifiers */
2654    *op1_clamp = op1_valu ? (bool)op1_valu->clamp : false;
2655    *op1_omod = op1_valu ? (unsigned)op1_valu->omod : 0u;
2656 
2657    if (inbetween_neg)
2658       *inbetween_neg = op1_valu ? op1_valu->neg[swap] : false;
2659    else if (op1_valu && op1_valu->neg[swap])
2660       return false;
2661 
2662    if (inbetween_abs)
2663       *inbetween_abs = op1_valu ? op1_valu->abs[swap] : false;
2664    else if (op1_valu && op1_valu->abs[swap])
2665       return false;
2666 
2667    if (inbetween_opsel)
2668       *inbetween_opsel = op1_valu ? op1_valu->opsel[swap] : false;
2669    else if (op1_valu && op1_valu->opsel[swap])
2670       return false;
2671 
2672    *precise = op1_instr->definitions[0].isPrecise() || op2_instr->definitions[0].isPrecise();
2673 
2674    int shuffle[3];
2675    shuffle[shuffle_str[0] - '0'] = 0;
2676    shuffle[shuffle_str[1] - '0'] = 1;
2677    shuffle[shuffle_str[2] - '0'] = 2;
2678 
2679    operands[shuffle[0]] = op1_instr->operands[!swap];
2680    neg[shuffle[0]] = op1_valu ? op1_valu->neg[!swap] : false;
2681    abs[shuffle[0]] = op1_valu ? op1_valu->abs[!swap] : false;
2682    opsel[shuffle[0]] = op1_valu ? op1_valu->opsel[!swap] : false;
2683 
2684    for (unsigned i = 0; i < 2; i++) {
2685       operands[shuffle[i + 1]] = op2_instr->operands[i];
2686       neg[shuffle[i + 1]] = op2_valu ? op2_valu->neg[i] : false;
2687       abs[shuffle[i + 1]] = op2_valu ? op2_valu->abs[i] : false;
2688       opsel[shuffle[i + 1]] = op2_valu ? op2_valu->opsel[i] : false;
2689    }
2690 
2691    /* check operands */
2692    if (!check_vop3_operands(ctx, 3, operands))
2693       return false;
2694 
2695    return true;
2696 }
2697 
2698 void
create_vop3_for_op3(opt_ctx & ctx,aco_opcode opcode,aco_ptr<Instruction> & instr,Operand operands[3],uint8_t neg,uint8_t abs,uint8_t opsel,bool clamp,unsigned omod)2699 create_vop3_for_op3(opt_ctx& ctx, aco_opcode opcode, aco_ptr<Instruction>& instr,
2700                     Operand operands[3], uint8_t neg, uint8_t abs, uint8_t opsel, bool clamp,
2701                     unsigned omod)
2702 {
2703    VALU_instruction* new_instr = create_instruction<VALU_instruction>(opcode, Format::VOP3, 3, 1);
2704    new_instr->neg = neg;
2705    new_instr->abs = abs;
2706    new_instr->clamp = clamp;
2707    new_instr->omod = omod;
2708    new_instr->opsel = opsel;
2709    new_instr->operands[0] = operands[0];
2710    new_instr->operands[1] = operands[1];
2711    new_instr->operands[2] = operands[2];
2712    new_instr->definitions[0] = instr->definitions[0];
2713    new_instr->pass_flags = instr->pass_flags;
2714    ctx.info[instr->definitions[0].tempId()].label = 0;
2715 
2716    instr.reset(new_instr);
2717 }
2718 
2719 bool
combine_three_valu_op(opt_ctx & ctx,aco_ptr<Instruction> & instr,aco_opcode op2,aco_opcode new_op,const char * shuffle,uint8_t ops)2720 combine_three_valu_op(opt_ctx& ctx, aco_ptr<Instruction>& instr, aco_opcode op2, aco_opcode new_op,
2721                       const char* shuffle, uint8_t ops)
2722 {
2723    for (unsigned swap = 0; swap < 2; swap++) {
2724       if (!((1 << swap) & ops))
2725          continue;
2726 
2727       Operand operands[3];
2728       bool clamp, precise;
2729       bitarray8 neg = 0, abs = 0, opsel = 0;
2730       uint8_t omod = 0;
2731       if (match_op3_for_vop3(ctx, instr->opcode, op2, instr.get(), swap, shuffle, operands, neg,
2732                              abs, opsel, &clamp, &omod, NULL, NULL, NULL, &precise)) {
2733          ctx.uses[instr->operands[swap].tempId()]--;
2734          create_vop3_for_op3(ctx, new_op, instr, operands, neg, abs, opsel, clamp, omod);
2735          return true;
2736       }
2737    }
2738    return false;
2739 }
2740 
2741 /* creates v_lshl_add_u32, v_lshl_or_b32 or v_and_or_b32 */
2742 bool
combine_add_or_then_and_lshl(opt_ctx & ctx,aco_ptr<Instruction> & instr)2743 combine_add_or_then_and_lshl(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2744 {
2745    bool is_or = instr->opcode == aco_opcode::v_or_b32;
2746    aco_opcode new_op_lshl = is_or ? aco_opcode::v_lshl_or_b32 : aco_opcode::v_lshl_add_u32;
2747 
2748    if (is_or && combine_three_valu_op(ctx, instr, aco_opcode::s_and_b32, aco_opcode::v_and_or_b32,
2749                                       "120", 1 | 2))
2750       return true;
2751    if (is_or && combine_three_valu_op(ctx, instr, aco_opcode::v_and_b32, aco_opcode::v_and_or_b32,
2752                                       "120", 1 | 2))
2753       return true;
2754    if (combine_three_valu_op(ctx, instr, aco_opcode::s_lshl_b32, new_op_lshl, "120", 1 | 2))
2755       return true;
2756    if (combine_three_valu_op(ctx, instr, aco_opcode::v_lshlrev_b32, new_op_lshl, "210", 1 | 2))
2757       return true;
2758 
2759    if (instr->isSDWA() || instr->isDPP())
2760       return false;
2761 
2762    /* v_or_b32(p_extract(a, 0, 8/16, 0), b) -> v_and_or_b32(a, 0xff/0xffff, b)
2763     * v_or_b32(p_insert(a, 0, 8/16), b) -> v_and_or_b32(a, 0xff/0xffff, b)
2764     * v_or_b32(p_insert(a, 24/16, 8/16), b) -> v_lshl_or_b32(a, 24/16, b)
2765     * v_add_u32(p_insert(a, 24/16, 8/16), b) -> v_lshl_add_b32(a, 24/16, b)
2766     */
2767    for (unsigned i = 0; i < 2; i++) {
2768       Instruction* extins = follow_operand(ctx, instr->operands[i]);
2769       if (!extins)
2770          continue;
2771 
2772       aco_opcode op;
2773       Operand operands[3];
2774 
2775       if (extins->opcode == aco_opcode::p_insert &&
2776           (extins->operands[1].constantValue() + 1) * extins->operands[2].constantValue() == 32) {
2777          op = new_op_lshl;
2778          operands[1] =
2779             Operand::c32(extins->operands[1].constantValue() * extins->operands[2].constantValue());
2780       } else if (is_or &&
2781                  (extins->opcode == aco_opcode::p_insert ||
2782                   (extins->opcode == aco_opcode::p_extract &&
2783                    extins->operands[3].constantEquals(0))) &&
2784                  extins->operands[1].constantEquals(0)) {
2785          op = aco_opcode::v_and_or_b32;
2786          operands[1] = Operand::c32(extins->operands[2].constantEquals(8) ? 0xffu : 0xffffu);
2787       } else {
2788          continue;
2789       }
2790 
2791       operands[0] = extins->operands[0];
2792       operands[2] = instr->operands[!i];
2793 
2794       if (!check_vop3_operands(ctx, 3, operands))
2795          continue;
2796 
2797       uint8_t neg = 0, abs = 0, opsel = 0, omod = 0;
2798       bool clamp = false;
2799       if (instr->isVOP3())
2800          clamp = instr->valu().clamp;
2801 
2802       ctx.uses[instr->operands[i].tempId()]--;
2803       create_vop3_for_op3(ctx, op, instr, operands, neg, abs, opsel, clamp, omod);
2804       return true;
2805    }
2806 
2807    return false;
2808 }
2809 
2810 /* v_xor(a, s_not(b)) -> v_xnor(a, b)
2811  * v_xor(a, v_not(b)) -> v_xnor(a, b)
2812  */
2813 bool
combine_xor_not(opt_ctx & ctx,aco_ptr<Instruction> & instr)2814 combine_xor_not(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2815 {
2816    if (instr->usesModifiers())
2817       return false;
2818 
2819    for (unsigned i = 0; i < 2; i++) {
2820       Instruction* op_instr = follow_operand(ctx, instr->operands[i], true);
2821       if (!op_instr ||
2822           (op_instr->opcode != aco_opcode::v_not_b32 &&
2823            op_instr->opcode != aco_opcode::s_not_b32) ||
2824           op_instr->usesModifiers() || op_instr->operands[0].isLiteral())
2825          continue;
2826 
2827       instr->opcode = aco_opcode::v_xnor_b32;
2828       instr->operands[i] = copy_operand(ctx, op_instr->operands[0]);
2829       decrease_uses(ctx, op_instr);
2830       if (instr->operands[0].isOfType(RegType::vgpr))
2831          std::swap(instr->operands[0], instr->operands[1]);
2832       if (!instr->operands[1].isOfType(RegType::vgpr))
2833          instr->format = asVOP3(instr->format);
2834 
2835       return true;
2836    }
2837 
2838    return false;
2839 }
2840 
2841 /* v_not(v_xor(a, b)) -> v_xnor(a, b) */
2842 bool
combine_not_xor(opt_ctx & ctx,aco_ptr<Instruction> & instr)2843 combine_not_xor(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2844 {
2845    if (instr->usesModifiers())
2846       return false;
2847 
2848    Instruction* op_instr = follow_operand(ctx, instr->operands[0]);
2849    if (!op_instr || op_instr->opcode != aco_opcode::v_xor_b32 || op_instr->isSDWA())
2850       return false;
2851 
2852    ctx.uses[instr->operands[0].tempId()]--;
2853    std::swap(instr->definitions[0], op_instr->definitions[0]);
2854    op_instr->opcode = aco_opcode::v_xnor_b32;
2855 
2856    return true;
2857 }
2858 
2859 bool
combine_minmax(opt_ctx & ctx,aco_ptr<Instruction> & instr,aco_opcode opposite,aco_opcode op3src,aco_opcode minmax)2860 combine_minmax(opt_ctx& ctx, aco_ptr<Instruction>& instr, aco_opcode opposite, aco_opcode op3src,
2861                aco_opcode minmax)
2862 {
2863    /* TODO: this can handle SDWA min/max instructions by using opsel */
2864 
2865    /* min(min(a, b), c) -> min3(a, b, c)
2866     * max(max(a, b), c) -> max3(a, b, c)
2867     * gfx11: min(-min(a, b), c) -> maxmin(-a, -b, c)
2868     * gfx11: max(-max(a, b), c) -> minmax(-a, -b, c)
2869     */
2870    for (unsigned swap = 0; swap < 2; swap++) {
2871       Operand operands[3];
2872       bool clamp, precise;
2873       bitarray8 opsel = 0, neg = 0, abs = 0;
2874       uint8_t omod = 0;
2875       bool inbetween_neg;
2876       if (match_op3_for_vop3(ctx, instr->opcode, instr->opcode, instr.get(), swap, "120", operands,
2877                              neg, abs, opsel, &clamp, &omod, &inbetween_neg, NULL, NULL,
2878                              &precise) &&
2879           (!inbetween_neg ||
2880            (minmax != aco_opcode::num_opcodes && ctx.program->gfx_level >= GFX11))) {
2881          ctx.uses[instr->operands[swap].tempId()]--;
2882          if (inbetween_neg) {
2883             neg[0] = !neg[0];
2884             neg[1] = !neg[1];
2885             create_vop3_for_op3(ctx, minmax, instr, operands, neg, abs, opsel, clamp, omod);
2886          } else {
2887             create_vop3_for_op3(ctx, op3src, instr, operands, neg, abs, opsel, clamp, omod);
2888          }
2889          return true;
2890       }
2891    }
2892 
2893    /* min(-max(a, b), c) -> min3(-a, -b, c)
2894     * max(-min(a, b), c) -> max3(-a, -b, c)
2895     * gfx11: min(max(a, b), c) -> maxmin(a, b, c)
2896     * gfx11: max(min(a, b), c) -> minmax(a, b, c)
2897     */
2898    for (unsigned swap = 0; swap < 2; swap++) {
2899       Operand operands[3];
2900       bool clamp, precise;
2901       bitarray8 opsel = 0, neg = 0, abs = 0;
2902       uint8_t omod = 0;
2903       bool inbetween_neg;
2904       if (match_op3_for_vop3(ctx, instr->opcode, opposite, instr.get(), swap, "120", operands, neg,
2905                              abs, opsel, &clamp, &omod, &inbetween_neg, NULL, NULL, &precise) &&
2906           (inbetween_neg ||
2907            (minmax != aco_opcode::num_opcodes && ctx.program->gfx_level >= GFX11))) {
2908          ctx.uses[instr->operands[swap].tempId()]--;
2909          if (inbetween_neg) {
2910             neg[0] = !neg[0];
2911             neg[1] = !neg[1];
2912             create_vop3_for_op3(ctx, op3src, instr, operands, neg, abs, opsel, clamp, omod);
2913          } else {
2914             create_vop3_for_op3(ctx, minmax, instr, operands, neg, abs, opsel, clamp, omod);
2915          }
2916          return true;
2917       }
2918    }
2919    return false;
2920 }
2921 
2922 /* s_not_b32(s_and_b32(a, b)) -> s_nand_b32(a, b)
2923  * s_not_b32(s_or_b32(a, b)) -> s_nor_b32(a, b)
2924  * s_not_b32(s_xor_b32(a, b)) -> s_xnor_b32(a, b)
2925  * s_not_b64(s_and_b64(a, b)) -> s_nand_b64(a, b)
2926  * s_not_b64(s_or_b64(a, b)) -> s_nor_b64(a, b)
2927  * s_not_b64(s_xor_b64(a, b)) -> s_xnor_b64(a, b) */
2928 bool
combine_salu_not_bitwise(opt_ctx & ctx,aco_ptr<Instruction> & instr)2929 combine_salu_not_bitwise(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2930 {
2931    /* checks */
2932    if (!instr->operands[0].isTemp())
2933       return false;
2934    if (instr->definitions[1].isTemp() && ctx.uses[instr->definitions[1].tempId()])
2935       return false;
2936 
2937    Instruction* op2_instr = follow_operand(ctx, instr->operands[0]);
2938    if (!op2_instr)
2939       return false;
2940    switch (op2_instr->opcode) {
2941    case aco_opcode::s_and_b32:
2942    case aco_opcode::s_or_b32:
2943    case aco_opcode::s_xor_b32:
2944    case aco_opcode::s_and_b64:
2945    case aco_opcode::s_or_b64:
2946    case aco_opcode::s_xor_b64: break;
2947    default: return false;
2948    }
2949 
2950    /* create instruction */
2951    std::swap(instr->definitions[0], op2_instr->definitions[0]);
2952    std::swap(instr->definitions[1], op2_instr->definitions[1]);
2953    ctx.uses[instr->operands[0].tempId()]--;
2954    ctx.info[op2_instr->definitions[0].tempId()].label = 0;
2955 
2956    switch (op2_instr->opcode) {
2957    case aco_opcode::s_and_b32: op2_instr->opcode = aco_opcode::s_nand_b32; break;
2958    case aco_opcode::s_or_b32: op2_instr->opcode = aco_opcode::s_nor_b32; break;
2959    case aco_opcode::s_xor_b32: op2_instr->opcode = aco_opcode::s_xnor_b32; break;
2960    case aco_opcode::s_and_b64: op2_instr->opcode = aco_opcode::s_nand_b64; break;
2961    case aco_opcode::s_or_b64: op2_instr->opcode = aco_opcode::s_nor_b64; break;
2962    case aco_opcode::s_xor_b64: op2_instr->opcode = aco_opcode::s_xnor_b64; break;
2963    default: break;
2964    }
2965 
2966    return true;
2967 }
2968 
2969 /* s_and_b32(a, s_not_b32(b)) -> s_andn2_b32(a, b)
2970  * s_or_b32(a, s_not_b32(b)) -> s_orn2_b32(a, b)
2971  * s_and_b64(a, s_not_b64(b)) -> s_andn2_b64(a, b)
2972  * s_or_b64(a, s_not_b64(b)) -> s_orn2_b64(a, b) */
2973 bool
combine_salu_n2(opt_ctx & ctx,aco_ptr<Instruction> & instr)2974 combine_salu_n2(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2975 {
2976    if (instr->definitions[0].isTemp() && ctx.info[instr->definitions[0].tempId()].is_uniform_bool())
2977       return false;
2978 
2979    for (unsigned i = 0; i < 2; i++) {
2980       Instruction* op2_instr = follow_operand(ctx, instr->operands[i]);
2981       if (!op2_instr || (op2_instr->opcode != aco_opcode::s_not_b32 &&
2982                          op2_instr->opcode != aco_opcode::s_not_b64))
2983          continue;
2984       if (ctx.uses[op2_instr->definitions[1].tempId()])
2985          continue;
2986 
2987       if (instr->operands[!i].isLiteral() && op2_instr->operands[0].isLiteral() &&
2988           instr->operands[!i].constantValue() != op2_instr->operands[0].constantValue())
2989          continue;
2990 
2991       ctx.uses[instr->operands[i].tempId()]--;
2992       instr->operands[0] = instr->operands[!i];
2993       instr->operands[1] = op2_instr->operands[0];
2994       ctx.info[instr->definitions[0].tempId()].label = 0;
2995 
2996       switch (instr->opcode) {
2997       case aco_opcode::s_and_b32: instr->opcode = aco_opcode::s_andn2_b32; break;
2998       case aco_opcode::s_or_b32: instr->opcode = aco_opcode::s_orn2_b32; break;
2999       case aco_opcode::s_and_b64: instr->opcode = aco_opcode::s_andn2_b64; break;
3000       case aco_opcode::s_or_b64: instr->opcode = aco_opcode::s_orn2_b64; break;
3001       default: break;
3002       }
3003 
3004       return true;
3005    }
3006    return false;
3007 }
3008 
3009 /* s_add_{i32,u32}(a, s_lshl_b32(b, <n>)) -> s_lshl<n>_add_u32(a, b) */
3010 bool
combine_salu_lshl_add(opt_ctx & ctx,aco_ptr<Instruction> & instr)3011 combine_salu_lshl_add(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3012 {
3013    if (instr->opcode == aco_opcode::s_add_i32 && ctx.uses[instr->definitions[1].tempId()])
3014       return false;
3015 
3016    for (unsigned i = 0; i < 2; i++) {
3017       Instruction* op2_instr = follow_operand(ctx, instr->operands[i], true);
3018       if (!op2_instr || op2_instr->opcode != aco_opcode::s_lshl_b32 ||
3019           ctx.uses[op2_instr->definitions[1].tempId()])
3020          continue;
3021       if (!op2_instr->operands[1].isConstant())
3022          continue;
3023 
3024       uint32_t shift = op2_instr->operands[1].constantValue();
3025       if (shift < 1 || shift > 4)
3026          continue;
3027 
3028       if (instr->operands[!i].isLiteral() && op2_instr->operands[0].isLiteral() &&
3029           instr->operands[!i].constantValue() != op2_instr->operands[0].constantValue())
3030          continue;
3031 
3032       instr->operands[1] = instr->operands[!i];
3033       instr->operands[0] = copy_operand(ctx, op2_instr->operands[0]);
3034       decrease_uses(ctx, op2_instr);
3035       ctx.info[instr->definitions[0].tempId()].label = 0;
3036 
3037       instr->opcode = std::array<aco_opcode, 4>{
3038          aco_opcode::s_lshl1_add_u32, aco_opcode::s_lshl2_add_u32, aco_opcode::s_lshl3_add_u32,
3039          aco_opcode::s_lshl4_add_u32}[shift - 1];
3040 
3041       return true;
3042    }
3043    return false;
3044 }
3045 
3046 /* s_abs_i32(s_sub_[iu]32(a, b)) -> s_absdiff_i32(a, b)
3047  * s_abs_i32(s_add_[iu]32(a, #b)) -> s_absdiff_i32(a, -b)
3048  */
3049 bool
combine_sabsdiff(opt_ctx & ctx,aco_ptr<Instruction> & instr)3050 combine_sabsdiff(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3051 {
3052    if (!instr->operands[0].isTemp() || !ctx.info[instr->operands[0].tempId()].is_add_sub())
3053       return false;
3054 
3055    Instruction* op_instr = follow_operand(ctx, instr->operands[0], false);
3056    if (!op_instr)
3057       return false;
3058 
3059    if (op_instr->opcode == aco_opcode::s_add_i32 || op_instr->opcode == aco_opcode::s_add_u32) {
3060       for (unsigned i = 0; i < 2; i++) {
3061          uint64_t constant;
3062          if (op_instr->operands[!i].isLiteral() ||
3063              !is_operand_constant(ctx, op_instr->operands[i], 32, &constant))
3064             continue;
3065 
3066          if (op_instr->operands[i].isTemp())
3067             ctx.uses[op_instr->operands[i].tempId()]--;
3068          op_instr->operands[0] = op_instr->operands[!i];
3069          op_instr->operands[1] = Operand::c32(-int32_t(constant));
3070          goto use_absdiff;
3071       }
3072       return false;
3073    }
3074 
3075 use_absdiff:
3076    op_instr->opcode = aco_opcode::s_absdiff_i32;
3077    std::swap(instr->definitions[0], op_instr->definitions[0]);
3078    std::swap(instr->definitions[1], op_instr->definitions[1]);
3079    ctx.uses[instr->operands[0].tempId()]--;
3080 
3081    return true;
3082 }
3083 
3084 bool
combine_add_sub_b2i(opt_ctx & ctx,aco_ptr<Instruction> & instr,aco_opcode new_op,uint8_t ops)3085 combine_add_sub_b2i(opt_ctx& ctx, aco_ptr<Instruction>& instr, aco_opcode new_op, uint8_t ops)
3086 {
3087    if (instr->usesModifiers())
3088       return false;
3089 
3090    for (unsigned i = 0; i < 2; i++) {
3091       if (!((1 << i) & ops))
3092          continue;
3093       if (instr->operands[i].isTemp() && ctx.info[instr->operands[i].tempId()].is_b2i() &&
3094           ctx.uses[instr->operands[i].tempId()] == 1) {
3095 
3096          aco_ptr<Instruction> new_instr;
3097          if (instr->operands[!i].isTemp() &&
3098              instr->operands[!i].getTemp().type() == RegType::vgpr) {
3099             new_instr.reset(create_instruction<VALU_instruction>(new_op, Format::VOP2, 3, 2));
3100          } else if (ctx.program->gfx_level >= GFX10 ||
3101                     (instr->operands[!i].isConstant() && !instr->operands[!i].isLiteral())) {
3102             new_instr.reset(
3103                create_instruction<VALU_instruction>(new_op, asVOP3(Format::VOP2), 3, 2));
3104          } else {
3105             return false;
3106          }
3107          ctx.uses[instr->operands[i].tempId()]--;
3108          new_instr->definitions[0] = instr->definitions[0];
3109          if (instr->definitions.size() == 2) {
3110             new_instr->definitions[1] = instr->definitions[1];
3111          } else {
3112             new_instr->definitions[1] =
3113                Definition(ctx.program->allocateTmp(ctx.program->lane_mask));
3114             /* Make sure the uses vector is large enough and the number of
3115              * uses properly initialized to 0.
3116              */
3117             ctx.uses.push_back(0);
3118          }
3119          new_instr->operands[0] = Operand::zero();
3120          new_instr->operands[1] = instr->operands[!i];
3121          new_instr->operands[2] = Operand(ctx.info[instr->operands[i].tempId()].temp);
3122          new_instr->pass_flags = instr->pass_flags;
3123          instr = std::move(new_instr);
3124          ctx.info[instr->definitions[0].tempId()].set_add_sub(instr.get());
3125          return true;
3126       }
3127    }
3128 
3129    return false;
3130 }
3131 
3132 bool
combine_add_bcnt(opt_ctx & ctx,aco_ptr<Instruction> & instr)3133 combine_add_bcnt(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3134 {
3135    if (instr->usesModifiers())
3136       return false;
3137 
3138    for (unsigned i = 0; i < 2; i++) {
3139       Instruction* op_instr = follow_operand(ctx, instr->operands[i]);
3140       if (op_instr && op_instr->opcode == aco_opcode::v_bcnt_u32_b32 &&
3141           !op_instr->usesModifiers() && op_instr->operands[0].isTemp() &&
3142           op_instr->operands[0].getTemp().type() == RegType::vgpr &&
3143           op_instr->operands[1].constantEquals(0)) {
3144          aco_ptr<Instruction> new_instr{
3145             create_instruction<VALU_instruction>(aco_opcode::v_bcnt_u32_b32, Format::VOP3, 2, 1)};
3146          ctx.uses[instr->operands[i].tempId()]--;
3147          new_instr->operands[0] = op_instr->operands[0];
3148          new_instr->operands[1] = instr->operands[!i];
3149          new_instr->definitions[0] = instr->definitions[0];
3150          new_instr->pass_flags = instr->pass_flags;
3151          instr = std::move(new_instr);
3152          ctx.info[instr->definitions[0].tempId()].label = 0;
3153 
3154          return true;
3155       }
3156    }
3157 
3158    return false;
3159 }
3160 
3161 bool
get_minmax_info(aco_opcode op,aco_opcode * min,aco_opcode * max,aco_opcode * min3,aco_opcode * max3,aco_opcode * med3,aco_opcode * minmax,bool * some_gfx9_only)3162 get_minmax_info(aco_opcode op, aco_opcode* min, aco_opcode* max, aco_opcode* min3, aco_opcode* max3,
3163                 aco_opcode* med3, aco_opcode* minmax, bool* some_gfx9_only)
3164 {
3165    switch (op) {
3166 #define MINMAX(type, gfx9)                                                                         \
3167    case aco_opcode::v_min_##type:                                                                  \
3168    case aco_opcode::v_max_##type:                                                                  \
3169       *min = aco_opcode::v_min_##type;                                                             \
3170       *max = aco_opcode::v_max_##type;                                                             \
3171       *med3 = aco_opcode::v_med3_##type;                                                           \
3172       *min3 = aco_opcode::v_min3_##type;                                                           \
3173       *max3 = aco_opcode::v_max3_##type;                                                           \
3174       *minmax = op == *min ? aco_opcode::v_maxmin_##type : aco_opcode::v_minmax_##type;            \
3175       *some_gfx9_only = gfx9;                                                                      \
3176       return true;
3177 #define MINMAX_INT16(type, gfx9)                                                                   \
3178    case aco_opcode::v_min_##type:                                                                  \
3179    case aco_opcode::v_max_##type:                                                                  \
3180       *min = aco_opcode::v_min_##type;                                                             \
3181       *max = aco_opcode::v_max_##type;                                                             \
3182       *med3 = aco_opcode::v_med3_##type;                                                           \
3183       *min3 = aco_opcode::v_min3_##type;                                                           \
3184       *max3 = aco_opcode::v_max3_##type;                                                           \
3185       *minmax = aco_opcode::num_opcodes;                                                           \
3186       *some_gfx9_only = gfx9;                                                                      \
3187       return true;
3188 #define MINMAX_INT16_E64(type, gfx9)                                                               \
3189    case aco_opcode::v_min_##type##_e64:                                                            \
3190    case aco_opcode::v_max_##type##_e64:                                                            \
3191       *min = aco_opcode::v_min_##type##_e64;                                                       \
3192       *max = aco_opcode::v_max_##type##_e64;                                                       \
3193       *med3 = aco_opcode::v_med3_##type;                                                           \
3194       *min3 = aco_opcode::v_min3_##type;                                                           \
3195       *max3 = aco_opcode::v_max3_##type;                                                           \
3196       *minmax = aco_opcode::num_opcodes;                                                           \
3197       *some_gfx9_only = gfx9;                                                                      \
3198       return true;
3199       MINMAX(f32, false)
3200       MINMAX(u32, false)
3201       MINMAX(i32, false)
3202       MINMAX(f16, true)
3203       MINMAX_INT16(u16, true)
3204       MINMAX_INT16(i16, true)
3205       MINMAX_INT16_E64(u16, true)
3206       MINMAX_INT16_E64(i16, true)
3207 #undef MINMAX_INT16_E64
3208 #undef MINMAX_INT16
3209 #undef MINMAX
3210    default: return false;
3211    }
3212 }
3213 
3214 /* when ub > lb:
3215  * v_min_{f,u,i}{16,32}(v_max_{f,u,i}{16,32}(a, lb), ub) -> v_med3_{f,u,i}{16,32}(a, lb, ub)
3216  * v_max_{f,u,i}{16,32}(v_min_{f,u,i}{16,32}(a, ub), lb) -> v_med3_{f,u,i}{16,32}(a, lb, ub)
3217  */
3218 bool
combine_clamp(opt_ctx & ctx,aco_ptr<Instruction> & instr,aco_opcode min,aco_opcode max,aco_opcode med)3219 combine_clamp(opt_ctx& ctx, aco_ptr<Instruction>& instr, aco_opcode min, aco_opcode max,
3220               aco_opcode med)
3221 {
3222    /* TODO: GLSL's clamp(x, minVal, maxVal) and SPIR-V's
3223     * FClamp(x, minVal, maxVal)/NClamp(x, minVal, maxVal) are undefined if
3224     * minVal > maxVal, which means we can always select it to a v_med3_f32 */
3225    aco_opcode other_op;
3226    if (instr->opcode == min)
3227       other_op = max;
3228    else if (instr->opcode == max)
3229       other_op = min;
3230    else
3231       return false;
3232 
3233    for (unsigned swap = 0; swap < 2; swap++) {
3234       Operand operands[3];
3235       bool clamp, precise;
3236       bitarray8 opsel = 0, neg = 0, abs = 0;
3237       uint8_t omod = 0;
3238       if (match_op3_for_vop3(ctx, instr->opcode, other_op, instr.get(), swap, "012", operands, neg,
3239                              abs, opsel, &clamp, &omod, NULL, NULL, NULL, &precise)) {
3240          /* max(min(src, upper), lower) returns upper if src is NaN, but
3241           * med3(src, lower, upper) returns lower.
3242           */
3243          if (precise && instr->opcode != min &&
3244              (min == aco_opcode::v_min_f16 || min == aco_opcode::v_min_f32))
3245             continue;
3246 
3247          int const0_idx = -1, const1_idx = -1;
3248          uint32_t const0 = 0, const1 = 0;
3249          for (int i = 0; i < 3; i++) {
3250             uint32_t val;
3251             bool hi16 = opsel & (1 << i);
3252             if (operands[i].isConstant()) {
3253                val = hi16 ? operands[i].constantValue16(true) : operands[i].constantValue();
3254             } else if (operands[i].isTemp() &&
3255                        ctx.info[operands[i].tempId()].is_constant_or_literal(32)) {
3256                val = ctx.info[operands[i].tempId()].val >> (hi16 ? 16 : 0);
3257             } else {
3258                continue;
3259             }
3260             if (const0_idx >= 0) {
3261                const1_idx = i;
3262                const1 = val;
3263             } else {
3264                const0_idx = i;
3265                const0 = val;
3266             }
3267          }
3268          if (const0_idx < 0 || const1_idx < 0)
3269             continue;
3270 
3271          int lower_idx = const0_idx;
3272          switch (min) {
3273          case aco_opcode::v_min_f32:
3274          case aco_opcode::v_min_f16: {
3275             float const0_f, const1_f;
3276             if (min == aco_opcode::v_min_f32) {
3277                memcpy(&const0_f, &const0, 4);
3278                memcpy(&const1_f, &const1, 4);
3279             } else {
3280                const0_f = _mesa_half_to_float(const0);
3281                const1_f = _mesa_half_to_float(const1);
3282             }
3283             if (abs[const0_idx])
3284                const0_f = fabsf(const0_f);
3285             if (abs[const1_idx])
3286                const1_f = fabsf(const1_f);
3287             if (neg[const0_idx])
3288                const0_f = -const0_f;
3289             if (neg[const1_idx])
3290                const1_f = -const1_f;
3291             lower_idx = const0_f < const1_f ? const0_idx : const1_idx;
3292             break;
3293          }
3294          case aco_opcode::v_min_u32: {
3295             lower_idx = const0 < const1 ? const0_idx : const1_idx;
3296             break;
3297          }
3298          case aco_opcode::v_min_u16:
3299          case aco_opcode::v_min_u16_e64: {
3300             lower_idx = (uint16_t)const0 < (uint16_t)const1 ? const0_idx : const1_idx;
3301             break;
3302          }
3303          case aco_opcode::v_min_i32: {
3304             int32_t const0_i =
3305                const0 & 0x80000000u ? -2147483648 + (int32_t)(const0 & 0x7fffffffu) : const0;
3306             int32_t const1_i =
3307                const1 & 0x80000000u ? -2147483648 + (int32_t)(const1 & 0x7fffffffu) : const1;
3308             lower_idx = const0_i < const1_i ? const0_idx : const1_idx;
3309             break;
3310          }
3311          case aco_opcode::v_min_i16:
3312          case aco_opcode::v_min_i16_e64: {
3313             int16_t const0_i = const0 & 0x8000u ? -32768 + (int16_t)(const0 & 0x7fffu) : const0;
3314             int16_t const1_i = const1 & 0x8000u ? -32768 + (int16_t)(const1 & 0x7fffu) : const1;
3315             lower_idx = const0_i < const1_i ? const0_idx : const1_idx;
3316             break;
3317          }
3318          default: break;
3319          }
3320          int upper_idx = lower_idx == const0_idx ? const1_idx : const0_idx;
3321 
3322          if (instr->opcode == min) {
3323             if (upper_idx != 0 || lower_idx == 0)
3324                return false;
3325          } else {
3326             if (upper_idx == 0 || lower_idx != 0)
3327                return false;
3328          }
3329 
3330          ctx.uses[instr->operands[swap].tempId()]--;
3331          create_vop3_for_op3(ctx, med, instr, operands, neg, abs, opsel, clamp, omod);
3332 
3333          return true;
3334       }
3335    }
3336 
3337    return false;
3338 }
3339 
3340 void
apply_sgprs(opt_ctx & ctx,aco_ptr<Instruction> & instr)3341 apply_sgprs(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3342 {
3343    bool is_shift64 = instr->opcode == aco_opcode::v_lshlrev_b64 ||
3344                      instr->opcode == aco_opcode::v_lshrrev_b64 ||
3345                      instr->opcode == aco_opcode::v_ashrrev_i64;
3346 
3347    /* find candidates and create the set of sgprs already read */
3348    unsigned sgpr_ids[2] = {0, 0};
3349    uint32_t operand_mask = 0;
3350    bool has_literal = false;
3351    for (unsigned i = 0; i < instr->operands.size(); i++) {
3352       if (instr->operands[i].isLiteral())
3353          has_literal = true;
3354       if (!instr->operands[i].isTemp())
3355          continue;
3356       if (instr->operands[i].getTemp().type() == RegType::sgpr) {
3357          if (instr->operands[i].tempId() != sgpr_ids[0])
3358             sgpr_ids[!!sgpr_ids[0]] = instr->operands[i].tempId();
3359       }
3360       ssa_info& info = ctx.info[instr->operands[i].tempId()];
3361       if (is_copy_label(ctx, instr, info, i) && info.temp.type() == RegType::sgpr)
3362          operand_mask |= 1u << i;
3363       if (info.is_extract() && info.instr->operands[0].getTemp().type() == RegType::sgpr)
3364          operand_mask |= 1u << i;
3365    }
3366    unsigned max_sgprs = 1;
3367    if (ctx.program->gfx_level >= GFX10 && !is_shift64)
3368       max_sgprs = 2;
3369    if (has_literal)
3370       max_sgprs--;
3371 
3372    unsigned num_sgprs = !!sgpr_ids[0] + !!sgpr_ids[1];
3373 
3374    /* keep on applying sgprs until there is nothing left to be done */
3375    while (operand_mask) {
3376       uint32_t sgpr_idx = 0;
3377       uint32_t sgpr_info_id = 0;
3378       uint32_t mask = operand_mask;
3379       /* choose a sgpr */
3380       while (mask) {
3381          unsigned i = u_bit_scan(&mask);
3382          uint16_t uses = ctx.uses[instr->operands[i].tempId()];
3383          if (sgpr_info_id == 0 || uses < ctx.uses[sgpr_info_id]) {
3384             sgpr_idx = i;
3385             sgpr_info_id = instr->operands[i].tempId();
3386          }
3387       }
3388       operand_mask &= ~(1u << sgpr_idx);
3389 
3390       ssa_info& info = ctx.info[sgpr_info_id];
3391 
3392       /* Applying two sgprs require making it VOP3, so don't do it unless it's
3393        * definitively beneficial.
3394        * TODO: this is too conservative because later the use count could be reduced to 1 */
3395       if (!info.is_extract() && num_sgprs && ctx.uses[sgpr_info_id] > 1 && !instr->isVOP3() &&
3396           !instr->isSDWA() && instr->format != Format::VOP3P)
3397          break;
3398 
3399       Temp sgpr = info.is_extract() ? info.instr->operands[0].getTemp() : info.temp;
3400       bool new_sgpr = sgpr.id() != sgpr_ids[0] && sgpr.id() != sgpr_ids[1];
3401       if (new_sgpr && num_sgprs >= max_sgprs)
3402          continue;
3403 
3404       if (sgpr_idx == 0)
3405          instr->format = withoutDPP(instr->format);
3406 
3407       if (sgpr_idx == 1 && instr->isDPP())
3408          continue;
3409 
3410       if (sgpr_idx == 0 || instr->isVOP3() || instr->isSDWA() || instr->isVOP3P() ||
3411           info.is_extract()) {
3412          /* can_apply_extract() checks SGPR encoding restrictions */
3413          if (info.is_extract() && can_apply_extract(ctx, instr, sgpr_idx, info))
3414             apply_extract(ctx, instr, sgpr_idx, info);
3415          else if (info.is_extract())
3416             continue;
3417          instr->operands[sgpr_idx] = Operand(sgpr);
3418       } else if (can_swap_operands(instr, &instr->opcode) && !instr->valu().opsel[sgpr_idx]) {
3419          instr->operands[sgpr_idx] = instr->operands[0];
3420          instr->operands[0] = Operand(sgpr);
3421          instr->valu().opsel[0].swap(instr->valu().opsel[sgpr_idx]);
3422          /* swap bits using a 4-entry LUT */
3423          uint32_t swapped = (0x3120 >> (operand_mask & 0x3)) & 0xf;
3424          operand_mask = (operand_mask & ~0x3) | swapped;
3425       } else if (can_use_VOP3(ctx, instr) && !info.is_extract()) {
3426          instr->format = asVOP3(instr->format);
3427          instr->operands[sgpr_idx] = Operand(sgpr);
3428       } else {
3429          continue;
3430       }
3431 
3432       if (new_sgpr)
3433          sgpr_ids[num_sgprs++] = sgpr.id();
3434       ctx.uses[sgpr_info_id]--;
3435       ctx.uses[sgpr.id()]++;
3436 
3437       /* TODO: handle when it's a VGPR */
3438       if ((ctx.info[sgpr.id()].label & (label_extract | label_temp)) &&
3439           ctx.info[sgpr.id()].temp.type() == RegType::sgpr)
3440          operand_mask |= 1u << sgpr_idx;
3441    }
3442 }
3443 
3444 void
interp_p2_f32_inreg_to_fma_dpp(aco_ptr<Instruction> & instr)3445 interp_p2_f32_inreg_to_fma_dpp(aco_ptr<Instruction>& instr)
3446 {
3447    static_assert(sizeof(DPP16_instruction) == sizeof(VINTERP_inreg_instruction),
3448                  "Invalid instr cast.");
3449    instr->format = asVOP3(Format::DPP16);
3450    instr->opcode = aco_opcode::v_fma_f32;
3451    instr->dpp16().dpp_ctrl = dpp_quad_perm(2, 2, 2, 2);
3452    instr->dpp16().row_mask = 0xf;
3453    instr->dpp16().bank_mask = 0xf;
3454    instr->dpp16().bound_ctrl = 0;
3455    instr->dpp16().fetch_inactive = 1;
3456 }
3457 
3458 /* apply omod / clamp modifiers if the def is used only once and the instruction can have modifiers */
3459 bool
apply_omod_clamp(opt_ctx & ctx,aco_ptr<Instruction> & instr)3460 apply_omod_clamp(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3461 {
3462    if (instr->definitions.empty() || ctx.uses[instr->definitions[0].tempId()] != 1 ||
3463        !instr_info.can_use_output_modifiers[(int)instr->opcode])
3464       return false;
3465 
3466    bool can_vop3 = can_use_VOP3(ctx, instr);
3467    bool is_mad_mix =
3468       instr->opcode == aco_opcode::v_fma_mix_f32 || instr->opcode == aco_opcode::v_fma_mixlo_f16;
3469    bool needs_vop3 = !instr->isSDWA() && !instr->isVINTERP_INREG() && !is_mad_mix;
3470    if (needs_vop3 && !can_vop3)
3471       return false;
3472 
3473    /* SDWA omod is GFX9+. */
3474    bool can_use_omod =
3475       (can_vop3 || ctx.program->gfx_level >= GFX9) && !instr->isVOP3P() &&
3476       (!instr->isVINTERP_INREG() || instr->opcode == aco_opcode::v_interp_p2_f32_inreg);
3477 
3478    ssa_info& def_info = ctx.info[instr->definitions[0].tempId()];
3479 
3480    uint64_t omod_labels = label_omod2 | label_omod4 | label_omod5;
3481    if (!def_info.is_clamp() && !(can_use_omod && (def_info.label & omod_labels)))
3482       return false;
3483    /* if the omod/clamp instruction is dead, then the single user of this
3484     * instruction is a different instruction */
3485    if (!ctx.uses[def_info.instr->definitions[0].tempId()])
3486       return false;
3487 
3488    if (def_info.instr->definitions[0].bytes() != instr->definitions[0].bytes())
3489       return false;
3490 
3491    /* MADs/FMAs are created later, so we don't have to update the original add */
3492    assert(!ctx.info[instr->definitions[0].tempId()].is_mad());
3493 
3494    if (!def_info.is_clamp() && (instr->valu().clamp || instr->valu().omod))
3495       return false;
3496 
3497    if (needs_vop3)
3498       instr->format = asVOP3(instr->format);
3499 
3500    if (!def_info.is_clamp() && instr->opcode == aco_opcode::v_interp_p2_f32_inreg)
3501       interp_p2_f32_inreg_to_fma_dpp(instr);
3502 
3503    if (def_info.is_omod2())
3504       instr->valu().omod = 1;
3505    else if (def_info.is_omod4())
3506       instr->valu().omod = 2;
3507    else if (def_info.is_omod5())
3508       instr->valu().omod = 3;
3509    else if (def_info.is_clamp())
3510       instr->valu().clamp = true;
3511 
3512    instr->definitions[0].swapTemp(def_info.instr->definitions[0]);
3513    ctx.info[instr->definitions[0].tempId()].label &= label_clamp | label_insert | label_f2f16;
3514    ctx.uses[def_info.instr->definitions[0].tempId()]--;
3515 
3516    return true;
3517 }
3518 
3519 /* Combine an p_insert (or p_extract, in some cases) instruction with instr.
3520  * p_insert(instr(...)) -> instr_insert().
3521  */
3522 bool
apply_insert(opt_ctx & ctx,aco_ptr<Instruction> & instr)3523 apply_insert(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3524 {
3525    if (instr->definitions.empty() || ctx.uses[instr->definitions[0].tempId()] != 1)
3526       return false;
3527 
3528    ssa_info& def_info = ctx.info[instr->definitions[0].tempId()];
3529    if (!def_info.is_insert())
3530       return false;
3531    /* if the insert instruction is dead, then the single user of this
3532     * instruction is a different instruction */
3533    if (!ctx.uses[def_info.instr->definitions[0].tempId()])
3534       return false;
3535 
3536    /* MADs/FMAs are created later, so we don't have to update the original add */
3537    assert(!ctx.info[instr->definitions[0].tempId()].is_mad());
3538 
3539    SubdwordSel sel = parse_insert(def_info.instr);
3540    assert(sel);
3541 
3542    if (!can_use_SDWA(ctx.program->gfx_level, instr, true))
3543       return false;
3544 
3545    convert_to_SDWA(ctx.program->gfx_level, instr);
3546    if (instr->sdwa().dst_sel.size() != 4)
3547       return false;
3548    instr->sdwa().dst_sel = sel;
3549 
3550    instr->definitions[0].swapTemp(def_info.instr->definitions[0]);
3551    ctx.info[instr->definitions[0].tempId()].label = 0;
3552    ctx.uses[def_info.instr->definitions[0].tempId()]--;
3553 
3554    return true;
3555 }
3556 
3557 /* Remove superfluous extract after ds_read like so:
3558  * p_extract(ds_read_uN(), 0, N, 0) -> ds_read_uN()
3559  */
3560 bool
apply_ds_extract(opt_ctx & ctx,aco_ptr<Instruction> & extract)3561 apply_ds_extract(opt_ctx& ctx, aco_ptr<Instruction>& extract)
3562 {
3563    /* Check if p_extract has a usedef operand and is the only user. */
3564    if (!ctx.info[extract->operands[0].tempId()].is_usedef() ||
3565        ctx.uses[extract->operands[0].tempId()] > 1)
3566       return false;
3567 
3568    /* Check if the usedef is a DS instruction. */
3569    Instruction* ds = ctx.info[extract->operands[0].tempId()].instr;
3570    if (ds->format != Format::DS)
3571       return false;
3572 
3573    unsigned extract_idx = extract->operands[1].constantValue();
3574    unsigned bits_extracted = extract->operands[2].constantValue();
3575    unsigned sign_ext = extract->operands[3].constantValue();
3576    unsigned dst_bitsize = extract->definitions[0].bytes() * 8u;
3577 
3578    /* TODO: These are doable, but probably don't occur too often. */
3579    if (extract_idx || sign_ext || dst_bitsize != 32)
3580       return false;
3581 
3582    unsigned bits_loaded = 0;
3583    if (ds->opcode == aco_opcode::ds_read_u8 || ds->opcode == aco_opcode::ds_read_u8_d16)
3584       bits_loaded = 8;
3585    else if (ds->opcode == aco_opcode::ds_read_u16 || ds->opcode == aco_opcode::ds_read_u16_d16)
3586       bits_loaded = 16;
3587    else
3588       return false;
3589 
3590    /* Shrink the DS load if the extracted bit size is smaller. */
3591    bits_loaded = MIN2(bits_loaded, bits_extracted);
3592 
3593    /* Change the DS opcode so it writes the full register. */
3594    if (bits_loaded == 8)
3595       ds->opcode = aco_opcode::ds_read_u8;
3596    else if (bits_loaded == 16)
3597       ds->opcode = aco_opcode::ds_read_u16;
3598    else
3599       unreachable("Forgot to add DS opcode above.");
3600 
3601    /* The DS now produces the exact same thing as the extract, remove the extract. */
3602    std::swap(ds->definitions[0], extract->definitions[0]);
3603    ctx.uses[extract->definitions[0].tempId()] = 0;
3604    ctx.info[ds->definitions[0].tempId()].label = 0;
3605    return true;
3606 }
3607 
3608 /* v_and(a, v_subbrev_co(0, 0, vcc)) -> v_cndmask(0, a, vcc) */
3609 bool
combine_and_subbrev(opt_ctx & ctx,aco_ptr<Instruction> & instr)3610 combine_and_subbrev(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3611 {
3612    if (instr->usesModifiers())
3613       return false;
3614 
3615    for (unsigned i = 0; i < 2; i++) {
3616       Instruction* op_instr = follow_operand(ctx, instr->operands[i], true);
3617       if (op_instr && op_instr->opcode == aco_opcode::v_subbrev_co_u32 &&
3618           op_instr->operands[0].constantEquals(0) && op_instr->operands[1].constantEquals(0) &&
3619           !op_instr->usesModifiers()) {
3620 
3621          aco_ptr<Instruction> new_instr;
3622          if (instr->operands[!i].isTemp() &&
3623              instr->operands[!i].getTemp().type() == RegType::vgpr) {
3624             new_instr.reset(
3625                create_instruction<VALU_instruction>(aco_opcode::v_cndmask_b32, Format::VOP2, 3, 1));
3626          } else if (ctx.program->gfx_level >= GFX10 ||
3627                     (instr->operands[!i].isConstant() && !instr->operands[!i].isLiteral())) {
3628             new_instr.reset(create_instruction<VALU_instruction>(aco_opcode::v_cndmask_b32,
3629                                                                  asVOP3(Format::VOP2), 3, 1));
3630          } else {
3631             return false;
3632          }
3633 
3634          new_instr->operands[0] = Operand::zero();
3635          new_instr->operands[1] = instr->operands[!i];
3636          new_instr->operands[2] = copy_operand(ctx, op_instr->operands[2]);
3637          new_instr->definitions[0] = instr->definitions[0];
3638          new_instr->pass_flags = instr->pass_flags;
3639          instr = std::move(new_instr);
3640          decrease_uses(ctx, op_instr);
3641          ctx.info[instr->definitions[0].tempId()].label = 0;
3642          return true;
3643       }
3644    }
3645 
3646    return false;
3647 }
3648 
3649 /* v_and(a, not(b)) -> v_bfi_b32(b, 0, a)
3650  * v_or(a, not(b)) -> v_bfi_b32(b, a, -1)
3651  */
3652 bool
combine_v_andor_not(opt_ctx & ctx,aco_ptr<Instruction> & instr)3653 combine_v_andor_not(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3654 {
3655    if (instr->usesModifiers())
3656       return false;
3657 
3658    for (unsigned i = 0; i < 2; i++) {
3659       Instruction* op_instr = follow_operand(ctx, instr->operands[i], true);
3660       if (op_instr && !op_instr->usesModifiers() &&
3661           (op_instr->opcode == aco_opcode::v_not_b32 ||
3662            op_instr->opcode == aco_opcode::s_not_b32)) {
3663 
3664          Operand ops[3] = {
3665             op_instr->operands[0],
3666             Operand::zero(),
3667             instr->operands[!i],
3668          };
3669          if (instr->opcode == aco_opcode::v_or_b32) {
3670             ops[1] = instr->operands[!i];
3671             ops[2] = Operand::c32(-1);
3672          }
3673          if (!check_vop3_operands(ctx, 3, ops))
3674             continue;
3675 
3676          Instruction* new_instr =
3677             create_instruction<VALU_instruction>(aco_opcode::v_bfi_b32, Format::VOP3, 3, 1);
3678 
3679          if (op_instr->operands[0].isTemp())
3680             ctx.uses[op_instr->operands[0].tempId()]++;
3681          for (unsigned j = 0; j < 3; j++)
3682             new_instr->operands[j] = ops[j];
3683          new_instr->definitions[0] = instr->definitions[0];
3684          new_instr->pass_flags = instr->pass_flags;
3685          instr.reset(new_instr);
3686          decrease_uses(ctx, op_instr);
3687          ctx.info[instr->definitions[0].tempId()].label = 0;
3688          return true;
3689       }
3690    }
3691 
3692    return false;
3693 }
3694 
3695 /* v_add_co(c, s_lshl(a, b)) -> v_mad_u32_u24(a, 1<<b, c)
3696  * v_add_co(c, v_lshlrev(a, b)) -> v_mad_u32_u24(b, 1<<a, c)
3697  * v_sub(c, s_lshl(a, b)) -> v_mad_i32_i24(a, -(1<<b), c)
3698  * v_sub(c, v_lshlrev(a, b)) -> v_mad_i32_i24(b, -(1<<a), c)
3699  */
3700 bool
combine_add_lshl(opt_ctx & ctx,aco_ptr<Instruction> & instr,bool is_sub)3701 combine_add_lshl(opt_ctx& ctx, aco_ptr<Instruction>& instr, bool is_sub)
3702 {
3703    if (instr->usesModifiers())
3704       return false;
3705 
3706    /* Substractions: start at operand 1 to avoid mixup such as
3707     * turning v_sub(v_lshlrev(a, b), c) into v_mad_i32_i24(b, -(1<<a), c)
3708     */
3709    unsigned start_op_idx = is_sub ? 1 : 0;
3710 
3711    /* Don't allow 24-bit operands on subtraction because
3712     * v_mad_i32_i24 applies a sign extension.
3713     */
3714    bool allow_24bit = !is_sub;
3715 
3716    for (unsigned i = start_op_idx; i < 2; i++) {
3717       Instruction* op_instr = follow_operand(ctx, instr->operands[i]);
3718       if (!op_instr)
3719          continue;
3720 
3721       if (op_instr->opcode != aco_opcode::s_lshl_b32 &&
3722           op_instr->opcode != aco_opcode::v_lshlrev_b32)
3723          continue;
3724 
3725       int shift_op_idx = op_instr->opcode == aco_opcode::s_lshl_b32 ? 1 : 0;
3726 
3727       if (op_instr->operands[shift_op_idx].isConstant() &&
3728           ((allow_24bit && op_instr->operands[!shift_op_idx].is24bit()) ||
3729            op_instr->operands[!shift_op_idx].is16bit())) {
3730          uint32_t multiplier = 1 << (op_instr->operands[shift_op_idx].constantValue() % 32u);
3731          if (is_sub)
3732             multiplier = -multiplier;
3733          if (is_sub ? (multiplier < 0xff800000) : (multiplier > 0xffffff))
3734             continue;
3735 
3736          Operand ops[3] = {
3737             op_instr->operands[!shift_op_idx],
3738             Operand::c32(multiplier),
3739             instr->operands[!i],
3740          };
3741          if (!check_vop3_operands(ctx, 3, ops))
3742             return false;
3743 
3744          ctx.uses[instr->operands[i].tempId()]--;
3745 
3746          aco_opcode mad_op = is_sub ? aco_opcode::v_mad_i32_i24 : aco_opcode::v_mad_u32_u24;
3747          aco_ptr<VALU_instruction> new_instr{
3748             create_instruction<VALU_instruction>(mad_op, Format::VOP3, 3, 1)};
3749          for (unsigned op_idx = 0; op_idx < 3; ++op_idx)
3750             new_instr->operands[op_idx] = ops[op_idx];
3751          new_instr->definitions[0] = instr->definitions[0];
3752          new_instr->pass_flags = instr->pass_flags;
3753          instr = std::move(new_instr);
3754          ctx.info[instr->definitions[0].tempId()].label = 0;
3755          return true;
3756       }
3757    }
3758 
3759    return false;
3760 }
3761 
3762 void
propagate_swizzles(VALU_instruction * instr,bool opsel_lo,bool opsel_hi)3763 propagate_swizzles(VALU_instruction* instr, bool opsel_lo, bool opsel_hi)
3764 {
3765    /* propagate swizzles which apply to a result down to the instruction's operands:
3766     * result = a.xy + b.xx -> result.yx = a.yx + b.xx */
3767    uint8_t tmp_lo = instr->opsel_lo;
3768    uint8_t tmp_hi = instr->opsel_hi;
3769    uint8_t neg_lo = instr->neg_lo;
3770    uint8_t neg_hi = instr->neg_hi;
3771    if (opsel_lo == 1) {
3772       instr->opsel_lo = tmp_hi;
3773       instr->neg_lo = neg_hi;
3774    }
3775    if (opsel_hi == 0) {
3776       instr->opsel_hi = tmp_lo;
3777       instr->neg_hi = neg_lo;
3778    }
3779 }
3780 
3781 void
combine_vop3p(opt_ctx & ctx,aco_ptr<Instruction> & instr)3782 combine_vop3p(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3783 {
3784    VALU_instruction* vop3p = &instr->valu();
3785 
3786    /* apply clamp */
3787    if (instr->opcode == aco_opcode::v_pk_mul_f16 && instr->operands[1].constantEquals(0x3C00) &&
3788        vop3p->clamp && instr->operands[0].isTemp() && ctx.uses[instr->operands[0].tempId()] == 1 &&
3789        !vop3p->opsel_lo[1] && !vop3p->opsel_hi[1]) {
3790 
3791       ssa_info& info = ctx.info[instr->operands[0].tempId()];
3792       if (info.is_vop3p() && instr_info.can_use_output_modifiers[(int)info.instr->opcode]) {
3793          VALU_instruction* candidate = &ctx.info[instr->operands[0].tempId()].instr->valu();
3794          candidate->clamp = true;
3795          propagate_swizzles(candidate, vop3p->opsel_lo[0], vop3p->opsel_hi[0]);
3796          instr->definitions[0].swapTemp(candidate->definitions[0]);
3797          ctx.info[candidate->definitions[0].tempId()].instr = candidate;
3798          ctx.uses[instr->definitions[0].tempId()]--;
3799          return;
3800       }
3801    }
3802 
3803    /* check for fneg modifiers */
3804    for (unsigned i = 0; i < instr->operands.size(); i++) {
3805       if (!can_use_input_modifiers(ctx.program->gfx_level, instr->opcode, i))
3806          continue;
3807       Operand& op = instr->operands[i];
3808       if (!op.isTemp())
3809          continue;
3810 
3811       ssa_info& info = ctx.info[op.tempId()];
3812       if (info.is_vop3p() && info.instr->opcode == aco_opcode::v_pk_mul_f16 &&
3813           (info.instr->operands[0].constantEquals(0x3C00) ||
3814            info.instr->operands[1].constantEquals(0x3C00))) {
3815 
3816          VALU_instruction* fneg = &info.instr->valu();
3817 
3818          unsigned fneg_src = fneg->operands[0].constantEquals(0x3C00);
3819 
3820          if (fneg->opsel_lo[1 - fneg_src] || fneg->opsel_hi[1 - fneg_src])
3821             continue;
3822 
3823          Operand ops[3];
3824          for (unsigned j = 0; j < instr->operands.size(); j++)
3825             ops[j] = instr->operands[j];
3826          ops[i] = fneg->operands[fneg_src];
3827          if (!check_vop3_operands(ctx, instr->operands.size(), ops))
3828             continue;
3829 
3830          if (fneg->clamp)
3831             continue;
3832          instr->operands[i] = fneg->operands[fneg_src];
3833 
3834          /* opsel_lo/hi is either 0 or 1:
3835           * if 0 - pick selection from fneg->lo
3836           * if 1 - pick selection from fneg->hi
3837           */
3838          bool opsel_lo = vop3p->opsel_lo[i];
3839          bool opsel_hi = vop3p->opsel_hi[i];
3840          bool neg_lo = fneg->neg_lo[0] ^ fneg->neg_lo[1];
3841          bool neg_hi = fneg->neg_hi[0] ^ fneg->neg_hi[1];
3842          vop3p->neg_lo[i] ^= opsel_lo ? neg_hi : neg_lo;
3843          vop3p->neg_hi[i] ^= opsel_hi ? neg_hi : neg_lo;
3844          vop3p->opsel_lo[i] ^= opsel_lo ? !fneg->opsel_hi[fneg_src] : fneg->opsel_lo[fneg_src];
3845          vop3p->opsel_hi[i] ^= opsel_hi ? !fneg->opsel_hi[fneg_src] : fneg->opsel_lo[fneg_src];
3846 
3847          if (--ctx.uses[fneg->definitions[0].tempId()])
3848             ctx.uses[fneg->operands[fneg_src].tempId()]++;
3849       }
3850    }
3851 
3852    if (instr->opcode == aco_opcode::v_pk_add_f16 || instr->opcode == aco_opcode::v_pk_add_u16) {
3853       bool fadd = instr->opcode == aco_opcode::v_pk_add_f16;
3854       if (fadd && instr->definitions[0].isPrecise())
3855          return;
3856 
3857       Instruction* mul_instr = nullptr;
3858       unsigned add_op_idx = 0;
3859       bitarray8 mul_neg_lo = 0, mul_neg_hi = 0, mul_opsel_lo = 0, mul_opsel_hi = 0;
3860       uint32_t uses = UINT32_MAX;
3861 
3862       /* find the 'best' mul instruction to combine with the add */
3863       for (unsigned i = 0; i < 2; i++) {
3864          Instruction* op_instr = follow_operand(ctx, instr->operands[i], true);
3865          if (!op_instr)
3866             continue;
3867 
3868          if (ctx.info[instr->operands[i].tempId()].is_vop3p()) {
3869             if (fadd) {
3870                if (op_instr->opcode != aco_opcode::v_pk_mul_f16 ||
3871                    op_instr->definitions[0].isPrecise())
3872                   continue;
3873             } else {
3874                if (op_instr->opcode != aco_opcode::v_pk_mul_lo_u16)
3875                   continue;
3876             }
3877 
3878             Operand op[3] = {op_instr->operands[0], op_instr->operands[1], instr->operands[1 - i]};
3879             if (ctx.uses[instr->operands[i].tempId()] >= uses || !check_vop3_operands(ctx, 3, op))
3880                continue;
3881 
3882             /* no clamp allowed between mul and add */
3883             if (op_instr->valu().clamp)
3884                continue;
3885 
3886             mul_instr = op_instr;
3887             add_op_idx = 1 - i;
3888             uses = ctx.uses[instr->operands[i].tempId()];
3889             mul_neg_lo = mul_instr->valu().neg_lo;
3890             mul_neg_hi = mul_instr->valu().neg_hi;
3891             mul_opsel_lo = mul_instr->valu().opsel_lo;
3892             mul_opsel_hi = mul_instr->valu().opsel_hi;
3893          } else if (instr->operands[i].bytes() == 2) {
3894             if ((fadd && (op_instr->opcode != aco_opcode::v_mul_f16 ||
3895                           op_instr->definitions[0].isPrecise())) ||
3896                 (!fadd && op_instr->opcode != aco_opcode::v_mul_lo_u16 &&
3897                  op_instr->opcode != aco_opcode::v_mul_lo_u16_e64))
3898                continue;
3899 
3900             if (op_instr->valu().clamp || op_instr->valu().omod || op_instr->valu().abs)
3901                continue;
3902 
3903             if (op_instr->isDPP() || (op_instr->isSDWA() && (op_instr->sdwa().sel[0].size() < 2 ||
3904                                                              op_instr->sdwa().sel[1].size() < 2)))
3905                continue;
3906 
3907             Operand op[3] = {op_instr->operands[0], op_instr->operands[1], instr->operands[1 - i]};
3908             if (ctx.uses[instr->operands[i].tempId()] >= uses || !check_vop3_operands(ctx, 3, op))
3909                continue;
3910 
3911             mul_instr = op_instr;
3912             add_op_idx = 1 - i;
3913             uses = ctx.uses[instr->operands[i].tempId()];
3914             mul_neg_lo = mul_instr->valu().neg;
3915             mul_neg_hi = mul_instr->valu().neg;
3916             if (mul_instr->isSDWA()) {
3917                for (unsigned j = 0; j < 2; j++)
3918                   mul_opsel_lo[j] = mul_instr->sdwa().sel[j].offset();
3919             } else {
3920                mul_opsel_lo = mul_instr->valu().opsel;
3921             }
3922             mul_opsel_hi = mul_opsel_lo;
3923          }
3924       }
3925 
3926       if (!mul_instr)
3927          return;
3928 
3929       /* turn mul + packed add into v_pk_fma_f16 */
3930       aco_opcode mad = fadd ? aco_opcode::v_pk_fma_f16 : aco_opcode::v_pk_mad_u16;
3931       aco_ptr<VALU_instruction> fma{create_instruction<VALU_instruction>(mad, Format::VOP3P, 3, 1)};
3932       fma->operands[0] = copy_operand(ctx, mul_instr->operands[0]);
3933       fma->operands[1] = copy_operand(ctx, mul_instr->operands[1]);
3934       fma->operands[2] = instr->operands[add_op_idx];
3935       fma->clamp = vop3p->clamp;
3936       fma->neg_lo = mul_neg_lo;
3937       fma->neg_hi = mul_neg_hi;
3938       fma->opsel_lo = mul_opsel_lo;
3939       fma->opsel_hi = mul_opsel_hi;
3940       propagate_swizzles(fma.get(), vop3p->opsel_lo[1 - add_op_idx],
3941                          vop3p->opsel_hi[1 - add_op_idx]);
3942       fma->opsel_lo[2] = vop3p->opsel_lo[add_op_idx];
3943       fma->opsel_hi[2] = vop3p->opsel_hi[add_op_idx];
3944       fma->neg_lo[2] = vop3p->neg_lo[add_op_idx];
3945       fma->neg_hi[2] = vop3p->neg_hi[add_op_idx];
3946       fma->neg_lo[1] = fma->neg_lo[1] ^ vop3p->neg_lo[1 - add_op_idx];
3947       fma->neg_hi[1] = fma->neg_hi[1] ^ vop3p->neg_hi[1 - add_op_idx];
3948       fma->definitions[0] = instr->definitions[0];
3949       fma->pass_flags = instr->pass_flags;
3950       instr = std::move(fma);
3951       ctx.info[instr->definitions[0].tempId()].set_vop3p(instr.get());
3952       decrease_uses(ctx, mul_instr);
3953       return;
3954    }
3955 }
3956 
3957 bool
can_use_mad_mix(opt_ctx & ctx,aco_ptr<Instruction> & instr)3958 can_use_mad_mix(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3959 {
3960    if (ctx.program->gfx_level < GFX9)
3961       return false;
3962 
3963    /* v_mad_mix* on GFX9 always flushes denormals for 16-bit inputs/outputs */
3964    if (ctx.program->gfx_level == GFX9 && ctx.fp_mode.denorm16_64)
3965       return false;
3966 
3967    if (instr->valu().omod)
3968       return false;
3969 
3970    switch (instr->opcode) {
3971    case aco_opcode::v_add_f32:
3972    case aco_opcode::v_sub_f32:
3973    case aco_opcode::v_subrev_f32:
3974    case aco_opcode::v_mul_f32: return !instr->isSDWA() && !instr->isDPP();
3975    case aco_opcode::v_fma_f32:
3976       return ctx.program->dev.fused_mad_mix || !instr->definitions[0].isPrecise();
3977    case aco_opcode::v_fma_mix_f32:
3978    case aco_opcode::v_fma_mixlo_f16: return true;
3979    default: return false;
3980    }
3981 }
3982 
3983 void
to_mad_mix(opt_ctx & ctx,aco_ptr<Instruction> & instr)3984 to_mad_mix(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3985 {
3986    ctx.info[instr->definitions[0].tempId()].label &= label_f2f16 | label_clamp | label_mul;
3987 
3988    if (instr->opcode == aco_opcode::v_fma_f32) {
3989       instr->format = (Format)((uint32_t)withoutVOP3(instr->format) | (uint32_t)(Format::VOP3P));
3990       instr->opcode = aco_opcode::v_fma_mix_f32;
3991       return;
3992    }
3993 
3994    bool is_add = instr->opcode != aco_opcode::v_mul_f32;
3995 
3996    aco_ptr<VALU_instruction> vop3p{
3997       create_instruction<VALU_instruction>(aco_opcode::v_fma_mix_f32, Format::VOP3P, 3, 1)};
3998 
3999    for (unsigned i = 0; i < instr->operands.size(); i++) {
4000       vop3p->operands[is_add + i] = instr->operands[i];
4001       vop3p->neg_lo[is_add + i] = instr->valu().neg[i];
4002       vop3p->neg_hi[is_add + i] = instr->valu().abs[i];
4003    }
4004    if (instr->opcode == aco_opcode::v_mul_f32) {
4005       vop3p->operands[2] = Operand::zero();
4006       vop3p->neg_lo[2] = true;
4007    } else if (is_add) {
4008       vop3p->operands[0] = Operand::c32(0x3f800000);
4009       if (instr->opcode == aco_opcode::v_sub_f32)
4010          vop3p->neg_lo[2] ^= true;
4011       else if (instr->opcode == aco_opcode::v_subrev_f32)
4012          vop3p->neg_lo[1] ^= true;
4013    }
4014    vop3p->definitions[0] = instr->definitions[0];
4015    vop3p->clamp = instr->valu().clamp;
4016    vop3p->pass_flags = instr->pass_flags;
4017    instr = std::move(vop3p);
4018 
4019    if (ctx.info[instr->definitions[0].tempId()].label & label_mul)
4020       ctx.info[instr->definitions[0].tempId()].instr = instr.get();
4021 }
4022 
4023 bool
combine_output_conversion(opt_ctx & ctx,aco_ptr<Instruction> & instr)4024 combine_output_conversion(opt_ctx& ctx, aco_ptr<Instruction>& instr)
4025 {
4026    ssa_info& def_info = ctx.info[instr->definitions[0].tempId()];
4027    if (!def_info.is_f2f16())
4028       return false;
4029    Instruction* conv = def_info.instr;
4030 
4031    if (!ctx.uses[conv->definitions[0].tempId()] || ctx.uses[instr->definitions[0].tempId()] != 1)
4032       return false;
4033 
4034    if (conv->usesModifiers())
4035       return false;
4036 
4037    if (instr->opcode == aco_opcode::v_interp_p2_f32_inreg)
4038       interp_p2_f32_inreg_to_fma_dpp(instr);
4039 
4040    if (!can_use_mad_mix(ctx, instr))
4041       return false;
4042 
4043    if (!instr->isVOP3P())
4044       to_mad_mix(ctx, instr);
4045 
4046    instr->opcode = aco_opcode::v_fma_mixlo_f16;
4047    instr->definitions[0].swapTemp(conv->definitions[0]);
4048    if (conv->definitions[0].isPrecise())
4049       instr->definitions[0].setPrecise(true);
4050    ctx.info[instr->definitions[0].tempId()].label &= label_clamp;
4051    ctx.uses[conv->definitions[0].tempId()]--;
4052 
4053    return true;
4054 }
4055 
4056 void
combine_mad_mix(opt_ctx & ctx,aco_ptr<Instruction> & instr)4057 combine_mad_mix(opt_ctx& ctx, aco_ptr<Instruction>& instr)
4058 {
4059    if (!can_use_mad_mix(ctx, instr))
4060       return;
4061 
4062    for (unsigned i = 0; i < instr->operands.size(); i++) {
4063       if (!instr->operands[i].isTemp())
4064          continue;
4065       Temp tmp = instr->operands[i].getTemp();
4066       if (!ctx.info[tmp.id()].is_f2f32())
4067          continue;
4068 
4069       Instruction* conv = ctx.info[tmp.id()].instr;
4070       if (conv->valu().clamp || conv->valu().omod) {
4071          continue;
4072       } else if (conv->isSDWA() &&
4073                  (conv->sdwa().dst_sel.size() != 4 || conv->sdwa().sel[0].size() != 2)) {
4074          continue;
4075       } else if (conv->isDPP()) {
4076          continue;
4077       }
4078 
4079       if (get_operand_size(instr, i) != 32)
4080          continue;
4081 
4082       /* Conversion to VOP3P will add inline constant operands, but that shouldn't affect
4083        * check_vop3_operands(). */
4084       Operand op[3];
4085       for (unsigned j = 0; j < instr->operands.size(); j++)
4086          op[j] = instr->operands[j];
4087       op[i] = conv->operands[0];
4088       if (!check_vop3_operands(ctx, instr->operands.size(), op))
4089          continue;
4090       if (!conv->operands[0].isOfType(RegType::vgpr) && instr->isDPP())
4091          continue;
4092 
4093       if (!instr->isVOP3P()) {
4094          bool is_add =
4095             instr->opcode != aco_opcode::v_mul_f32 && instr->opcode != aco_opcode::v_fma_f32;
4096          to_mad_mix(ctx, instr);
4097          i += is_add;
4098       }
4099 
4100       if (--ctx.uses[tmp.id()])
4101          ctx.uses[conv->operands[0].tempId()]++;
4102       instr->operands[i].setTemp(conv->operands[0].getTemp());
4103       if (conv->definitions[0].isPrecise())
4104          instr->definitions[0].setPrecise(true);
4105       instr->valu().opsel_hi[i] = true;
4106       if (conv->isSDWA() && conv->sdwa().sel[0].offset() == 2)
4107          instr->valu().opsel_lo[i] = true;
4108       else
4109          instr->valu().opsel_lo[i] = conv->valu().opsel[0];
4110       bool neg = conv->valu().neg[0];
4111       bool abs = conv->valu().abs[0];
4112       if (!instr->valu().abs[i]) {
4113          instr->valu().neg[i] ^= neg;
4114          instr->valu().abs[i] = abs;
4115       }
4116    }
4117 }
4118 
4119 // TODO: we could possibly move the whole label_instruction pass to combine_instruction:
4120 // this would mean that we'd have to fix the instruction uses while value propagation
4121 
4122 /* also returns true for inf */
4123 bool
is_pow_of_two(opt_ctx & ctx,Operand op)4124 is_pow_of_two(opt_ctx& ctx, Operand op)
4125 {
4126    if (op.isTemp() && ctx.info[op.tempId()].is_constant_or_literal(op.bytes() * 8))
4127       return is_pow_of_two(ctx, get_constant_op(ctx, ctx.info[op.tempId()], op.bytes() * 8));
4128    else if (!op.isConstant())
4129       return false;
4130 
4131    uint64_t val = op.constantValue64();
4132 
4133    if (op.bytes() == 4) {
4134       uint32_t exponent = (val & 0x7f800000) >> 23;
4135       uint32_t fraction = val & 0x007fffff;
4136       return (exponent >= 127) && (fraction == 0);
4137    } else if (op.bytes() == 2) {
4138       uint32_t exponent = (val & 0x7c00) >> 10;
4139       uint32_t fraction = val & 0x03ff;
4140       return (exponent >= 15) && (fraction == 0);
4141    } else {
4142       assert(op.bytes() == 8);
4143       uint64_t exponent = (val & UINT64_C(0x7ff0000000000000)) >> 52;
4144       uint64_t fraction = val & UINT64_C(0x000fffffffffffff);
4145       return (exponent >= 1023) && (fraction == 0);
4146    }
4147 }
4148 
4149 void
combine_instruction(opt_ctx & ctx,aco_ptr<Instruction> & instr)4150 combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
4151 {
4152    if (instr->definitions.empty() || is_dead(ctx.uses, instr.get()))
4153       return;
4154 
4155    if (instr->isVALU()) {
4156       /* Apply SDWA. Do this after label_instruction() so it can remove
4157        * label_extract if not all instructions can take SDWA. */
4158       for (unsigned i = 0; i < instr->operands.size(); i++) {
4159          Operand& op = instr->operands[i];
4160          if (!op.isTemp())
4161             continue;
4162          ssa_info& info = ctx.info[op.tempId()];
4163          if (!info.is_extract())
4164             continue;
4165          /* if there are that many uses, there are likely better combinations */
4166          // TODO: delay applying extract to a point where we know better
4167          if (ctx.uses[op.tempId()] > 4) {
4168             info.label &= ~label_extract;
4169             continue;
4170          }
4171          if (info.is_extract() &&
4172              (info.instr->operands[0].getTemp().type() == RegType::vgpr ||
4173               instr->operands[i].getTemp().type() == RegType::sgpr) &&
4174              can_apply_extract(ctx, instr, i, info)) {
4175             /* Increase use count of the extract's operand if the extract still has uses. */
4176             apply_extract(ctx, instr, i, info);
4177             if (--ctx.uses[instr->operands[i].tempId()])
4178                ctx.uses[info.instr->operands[0].tempId()]++;
4179             instr->operands[i].setTemp(info.instr->operands[0].getTemp());
4180          }
4181       }
4182 
4183       if (can_apply_sgprs(ctx, instr))
4184          apply_sgprs(ctx, instr);
4185       combine_mad_mix(ctx, instr);
4186       while (apply_omod_clamp(ctx, instr) || combine_output_conversion(ctx, instr))
4187          ;
4188       apply_insert(ctx, instr);
4189    }
4190 
4191    if (instr->isVOP3P() && instr->opcode != aco_opcode::v_fma_mix_f32 &&
4192        instr->opcode != aco_opcode::v_fma_mixlo_f16)
4193       return combine_vop3p(ctx, instr);
4194 
4195    if (instr->isSDWA() || instr->isDPP())
4196       return;
4197 
4198    if (instr->opcode == aco_opcode::p_extract) {
4199       ssa_info& info = ctx.info[instr->operands[0].tempId()];
4200       if (info.is_extract() && can_apply_extract(ctx, instr, 0, info)) {
4201          apply_extract(ctx, instr, 0, info);
4202          if (--ctx.uses[instr->operands[0].tempId()])
4203             ctx.uses[info.instr->operands[0].tempId()]++;
4204          instr->operands[0].setTemp(info.instr->operands[0].getTemp());
4205       }
4206 
4207       apply_ds_extract(ctx, instr);
4208    }
4209 
4210    if (instr->isVOPC()) {
4211       if (optimize_cmp_subgroup_invocation(ctx, instr))
4212          return;
4213    }
4214 
4215    /* TODO: There are still some peephole optimizations that could be done:
4216     * - abs(a - b) -> s_absdiff_i32
4217     * - various patterns for s_bitcmp{0,1}_b32 and s_bitset{0,1}_b32
4218     * - patterns for v_alignbit_b32 and v_alignbyte_b32
4219     * These aren't probably too interesting though.
4220     * There are also patterns for v_cmp_class_f{16,32,64}. This is difficult but
4221     * probably more useful than the previously mentioned optimizations.
4222     * The various comparison optimizations also currently only work with 32-bit
4223     * floats. */
4224 
4225    /* neg(mul(a, b)) -> mul(neg(a), b), abs(mul(a, b)) -> mul(abs(a), abs(b)) */
4226    if ((ctx.info[instr->definitions[0].tempId()].label & (label_neg | label_abs)) &&
4227        ctx.uses[instr->operands[1].tempId()] == 1) {
4228       Temp val = ctx.info[instr->definitions[0].tempId()].temp;
4229 
4230       if (!ctx.info[val.id()].is_mul())
4231          return;
4232 
4233       Instruction* mul_instr = ctx.info[val.id()].instr;
4234 
4235       if (mul_instr->operands[0].isLiteral())
4236          return;
4237       if (mul_instr->valu().clamp)
4238          return;
4239       if (mul_instr->isSDWA() || mul_instr->isDPP())
4240          return;
4241       if (mul_instr->opcode == aco_opcode::v_mul_legacy_f32 &&
4242           ctx.fp_mode.preserve_signed_zero_inf_nan32)
4243          return;
4244       if (mul_instr->definitions[0].bytes() != instr->definitions[0].bytes())
4245          return;
4246 
4247       /* convert to mul(neg(a), b), mul(abs(a), abs(b)) or mul(neg(abs(a)), abs(b)) */
4248       ctx.uses[mul_instr->definitions[0].tempId()]--;
4249       Definition def = instr->definitions[0];
4250       bool is_neg = ctx.info[instr->definitions[0].tempId()].is_neg();
4251       bool is_abs = ctx.info[instr->definitions[0].tempId()].is_abs();
4252       uint32_t pass_flags = instr->pass_flags;
4253       Format format = mul_instr->format == Format::VOP2 ? asVOP3(Format::VOP2) : mul_instr->format;
4254       instr.reset(create_instruction<VALU_instruction>(mul_instr->opcode, format,
4255                                                        mul_instr->operands.size(), 1));
4256       std::copy(mul_instr->operands.cbegin(), mul_instr->operands.cend(), instr->operands.begin());
4257       instr->pass_flags = pass_flags;
4258       instr->definitions[0] = def;
4259       VALU_instruction& new_mul = instr->valu();
4260       VALU_instruction& mul = mul_instr->valu();
4261       new_mul.neg = mul.neg;
4262       new_mul.abs = mul.abs;
4263       new_mul.omod = mul.omod;
4264       new_mul.opsel = mul.opsel;
4265       new_mul.opsel_lo = mul.opsel_lo;
4266       new_mul.opsel_hi = mul.opsel_hi;
4267       if (is_abs) {
4268          new_mul.neg[0] = new_mul.neg[1] = false;
4269          new_mul.abs[0] = new_mul.abs[1] = true;
4270       }
4271       new_mul.neg[0] ^= is_neg;
4272       new_mul.clamp = false;
4273 
4274       ctx.info[instr->definitions[0].tempId()].set_mul(instr.get());
4275       return;
4276    }
4277 
4278    /* combine mul+add -> mad */
4279    bool is_add_mix =
4280       (instr->opcode == aco_opcode::v_fma_mix_f32 ||
4281        instr->opcode == aco_opcode::v_fma_mixlo_f16) &&
4282       !instr->valu().neg_lo[0] &&
4283       ((instr->operands[0].constantEquals(0x3f800000) && !instr->valu().opsel_hi[0]) ||
4284        (instr->operands[0].constantEquals(0x3C00) && instr->valu().opsel_hi[0] &&
4285         !instr->valu().opsel_lo[0]));
4286    bool mad32 = instr->opcode == aco_opcode::v_add_f32 || instr->opcode == aco_opcode::v_sub_f32 ||
4287                 instr->opcode == aco_opcode::v_subrev_f32;
4288    bool mad16 = instr->opcode == aco_opcode::v_add_f16 || instr->opcode == aco_opcode::v_sub_f16 ||
4289                 instr->opcode == aco_opcode::v_subrev_f16;
4290    bool mad64 = instr->opcode == aco_opcode::v_add_f64;
4291    if (is_add_mix || mad16 || mad32 || mad64) {
4292       Instruction* mul_instr = nullptr;
4293       unsigned add_op_idx = 0;
4294       uint32_t uses = UINT32_MAX;
4295       bool emit_fma = false;
4296       /* find the 'best' mul instruction to combine with the add */
4297       for (unsigned i = is_add_mix ? 1 : 0; i < instr->operands.size(); i++) {
4298          if (!instr->operands[i].isTemp() || !ctx.info[instr->operands[i].tempId()].is_mul())
4299             continue;
4300          ssa_info& info = ctx.info[instr->operands[i].tempId()];
4301 
4302          /* no clamp/omod allowed between mul and add */
4303          if (info.instr->isVOP3() && (info.instr->valu().clamp || info.instr->valu().omod))
4304             continue;
4305          if (info.instr->isVOP3P() && info.instr->valu().clamp)
4306             continue;
4307          /* v_fma_mix_f32/etc can't do omod */
4308          if (info.instr->isVOP3P() && instr->isVOP3() && instr->valu().omod)
4309             continue;
4310          /* don't promote fp16 to fp32 or remove fp32->fp16->fp32 conversions */
4311          if (is_add_mix && info.instr->definitions[0].bytes() == 2)
4312             continue;
4313 
4314          if (get_operand_size(instr, i) != info.instr->definitions[0].bytes() * 8)
4315             continue;
4316 
4317          bool legacy = info.instr->opcode == aco_opcode::v_mul_legacy_f32;
4318          bool mad_mix = is_add_mix || info.instr->isVOP3P();
4319 
4320          /* Multiplication by power-of-two should never need rounding. 1/power-of-two also works,
4321           * but using fma removes denormal flushing (0xfffffe * 0.5 + 0x810001a2).
4322           */
4323          bool is_fma_precise = is_pow_of_two(ctx, info.instr->operands[0]) ||
4324                                is_pow_of_two(ctx, info.instr->operands[1]);
4325 
4326          bool has_fma = mad16 || mad64 || (legacy && ctx.program->gfx_level >= GFX10_3) ||
4327                         (mad32 && !legacy && !mad_mix && ctx.program->dev.has_fast_fma32) ||
4328                         (mad_mix && ctx.program->dev.fused_mad_mix);
4329          bool has_mad = mad_mix ? !ctx.program->dev.fused_mad_mix
4330                                 : ((mad32 && ctx.program->gfx_level < GFX10_3) ||
4331                                    (mad16 && ctx.program->gfx_level <= GFX9));
4332          bool can_use_fma =
4333             has_fma &&
4334             (!(info.instr->definitions[0].isPrecise() || instr->definitions[0].isPrecise()) ||
4335              is_fma_precise);
4336          bool can_use_mad =
4337             has_mad && (mad_mix || mad32 ? ctx.fp_mode.denorm32 : ctx.fp_mode.denorm16_64) == 0;
4338          if (mad_mix && legacy)
4339             continue;
4340          if (!can_use_fma && !can_use_mad)
4341             continue;
4342 
4343          unsigned candidate_add_op_idx = is_add_mix ? (3 - i) : (1 - i);
4344          Operand op[3] = {info.instr->operands[0], info.instr->operands[1],
4345                           instr->operands[candidate_add_op_idx]};
4346          if (info.instr->isSDWA() || info.instr->isDPP() || !check_vop3_operands(ctx, 3, op) ||
4347              ctx.uses[instr->operands[i].tempId()] > uses)
4348             continue;
4349 
4350          if (ctx.uses[instr->operands[i].tempId()] == uses) {
4351             unsigned cur_idx = mul_instr->definitions[0].tempId();
4352             unsigned new_idx = info.instr->definitions[0].tempId();
4353             if (cur_idx > new_idx)
4354                continue;
4355          }
4356 
4357          mul_instr = info.instr;
4358          add_op_idx = candidate_add_op_idx;
4359          uses = ctx.uses[instr->operands[i].tempId()];
4360          emit_fma = !can_use_mad;
4361       }
4362 
4363       if (mul_instr) {
4364          /* turn mul+add into v_mad/v_fma */
4365          Operand op[3] = {mul_instr->operands[0], mul_instr->operands[1],
4366                           instr->operands[add_op_idx]};
4367          ctx.uses[mul_instr->definitions[0].tempId()]--;
4368          if (ctx.uses[mul_instr->definitions[0].tempId()]) {
4369             if (op[0].isTemp())
4370                ctx.uses[op[0].tempId()]++;
4371             if (op[1].isTemp())
4372                ctx.uses[op[1].tempId()]++;
4373          }
4374 
4375          bool neg[3] = {false, false, false};
4376          bool abs[3] = {false, false, false};
4377          unsigned omod = 0;
4378          bool clamp = false;
4379          bitarray8 opsel_lo = 0;
4380          bitarray8 opsel_hi = 0;
4381          bitarray8 opsel = 0;
4382          unsigned mul_op_idx = (instr->isVOP3P() ? 3 : 1) - add_op_idx;
4383 
4384          VALU_instruction& valu_mul = mul_instr->valu();
4385          neg[0] = valu_mul.neg[0];
4386          neg[1] = valu_mul.neg[1];
4387          abs[0] = valu_mul.abs[0];
4388          abs[1] = valu_mul.abs[1];
4389          opsel_lo = valu_mul.opsel_lo & 0x3;
4390          opsel_hi = valu_mul.opsel_hi & 0x3;
4391          opsel = valu_mul.opsel & 0x3;
4392 
4393          VALU_instruction& valu = instr->valu();
4394          neg[2] = valu.neg[add_op_idx];
4395          abs[2] = valu.abs[add_op_idx];
4396          opsel_lo[2] = valu.opsel_lo[add_op_idx];
4397          opsel_hi[2] = valu.opsel_hi[add_op_idx];
4398          opsel[2] = valu.opsel[add_op_idx];
4399          opsel[3] = valu.opsel[3];
4400          omod = valu.omod;
4401          clamp = valu.clamp;
4402          /* abs of the multiplication result */
4403          if (valu.abs[mul_op_idx]) {
4404             neg[0] = false;
4405             neg[1] = false;
4406             abs[0] = true;
4407             abs[1] = true;
4408          }
4409          /* neg of the multiplication result */
4410          neg[1] ^= valu.neg[mul_op_idx];
4411 
4412          if (instr->opcode == aco_opcode::v_sub_f32 || instr->opcode == aco_opcode::v_sub_f16)
4413             neg[1 + add_op_idx] = neg[1 + add_op_idx] ^ true;
4414          else if (instr->opcode == aco_opcode::v_subrev_f32 ||
4415                   instr->opcode == aco_opcode::v_subrev_f16)
4416             neg[2 - add_op_idx] = neg[2 - add_op_idx] ^ true;
4417 
4418          aco_ptr<Instruction> add_instr = std::move(instr);
4419          aco_ptr<VALU_instruction> mad;
4420          if (add_instr->isVOP3P() || mul_instr->isVOP3P()) {
4421             assert(!omod);
4422             assert(!opsel);
4423 
4424             aco_opcode mad_op = add_instr->definitions[0].bytes() == 2 ? aco_opcode::v_fma_mixlo_f16
4425                                                                        : aco_opcode::v_fma_mix_f32;
4426             mad.reset(create_instruction<VALU_instruction>(mad_op, Format::VOP3P, 3, 1));
4427          } else {
4428             assert(!opsel_lo);
4429             assert(!opsel_hi);
4430 
4431             aco_opcode mad_op = emit_fma ? aco_opcode::v_fma_f32 : aco_opcode::v_mad_f32;
4432             if (mul_instr->opcode == aco_opcode::v_mul_legacy_f32) {
4433                assert(emit_fma == (ctx.program->gfx_level >= GFX10_3));
4434                mad_op = emit_fma ? aco_opcode::v_fma_legacy_f32 : aco_opcode::v_mad_legacy_f32;
4435             } else if (mad16) {
4436                mad_op = emit_fma ? (ctx.program->gfx_level == GFX8 ? aco_opcode::v_fma_legacy_f16
4437                                                                    : aco_opcode::v_fma_f16)
4438                                  : (ctx.program->gfx_level == GFX8 ? aco_opcode::v_mad_legacy_f16
4439                                                                    : aco_opcode::v_mad_f16);
4440             } else if (mad64) {
4441                mad_op = aco_opcode::v_fma_f64;
4442             }
4443 
4444             mad.reset(create_instruction<VALU_instruction>(mad_op, Format::VOP3, 3, 1));
4445          }
4446 
4447          for (unsigned i = 0; i < 3; i++) {
4448             mad->operands[i] = op[i];
4449             mad->neg[i] = neg[i];
4450             mad->abs[i] = abs[i];
4451          }
4452          mad->omod = omod;
4453          mad->clamp = clamp;
4454          mad->opsel_lo = opsel_lo;
4455          mad->opsel_hi = opsel_hi;
4456          mad->opsel = opsel;
4457          mad->definitions[0] = add_instr->definitions[0];
4458          mad->definitions[0].setPrecise(add_instr->definitions[0].isPrecise() ||
4459                                         mul_instr->definitions[0].isPrecise());
4460          mad->pass_flags = add_instr->pass_flags;
4461 
4462          instr = std::move(mad);
4463 
4464          /* mark this ssa_def to be re-checked for profitability and literals */
4465          ctx.mad_infos.emplace_back(std::move(add_instr), mul_instr->definitions[0].tempId());
4466          ctx.info[instr->definitions[0].tempId()].set_mad(ctx.mad_infos.size() - 1);
4467          return;
4468       }
4469    }
4470    /* v_mul_f32(v_cndmask_b32(0, 1.0, cond), a) -> v_cndmask_b32(0, a, cond) */
4471    else if (((instr->opcode == aco_opcode::v_mul_f32 &&
4472               !ctx.fp_mode.preserve_signed_zero_inf_nan32) ||
4473              instr->opcode == aco_opcode::v_mul_legacy_f32) &&
4474             !instr->usesModifiers() && !ctx.fp_mode.must_flush_denorms32) {
4475       for (unsigned i = 0; i < 2; i++) {
4476          if (instr->operands[i].isTemp() && ctx.info[instr->operands[i].tempId()].is_b2f() &&
4477              ctx.uses[instr->operands[i].tempId()] == 1 && instr->operands[!i].isTemp() &&
4478              instr->operands[!i].getTemp().type() == RegType::vgpr) {
4479             ctx.uses[instr->operands[i].tempId()]--;
4480             ctx.uses[ctx.info[instr->operands[i].tempId()].temp.id()]++;
4481 
4482             aco_ptr<VALU_instruction> new_instr{
4483                create_instruction<VALU_instruction>(aco_opcode::v_cndmask_b32, Format::VOP2, 3, 1)};
4484             new_instr->operands[0] = Operand::zero();
4485             new_instr->operands[1] = instr->operands[!i];
4486             new_instr->operands[2] = Operand(ctx.info[instr->operands[i].tempId()].temp);
4487             new_instr->definitions[0] = instr->definitions[0];
4488             new_instr->pass_flags = instr->pass_flags;
4489             instr = std::move(new_instr);
4490             ctx.info[instr->definitions[0].tempId()].label = 0;
4491             return;
4492          }
4493       }
4494    } else if (instr->opcode == aco_opcode::v_or_b32 && ctx.program->gfx_level >= GFX9) {
4495       if (combine_three_valu_op(ctx, instr, aco_opcode::s_or_b32, aco_opcode::v_or3_b32, "012",
4496                                 1 | 2)) {
4497       } else if (combine_three_valu_op(ctx, instr, aco_opcode::v_or_b32, aco_opcode::v_or3_b32,
4498                                        "012", 1 | 2)) {
4499       } else if (combine_add_or_then_and_lshl(ctx, instr)) {
4500       } else if (combine_v_andor_not(ctx, instr)) {
4501       }
4502    } else if (instr->opcode == aco_opcode::v_xor_b32 && ctx.program->gfx_level >= GFX10) {
4503       if (combine_three_valu_op(ctx, instr, aco_opcode::v_xor_b32, aco_opcode::v_xor3_b32, "012",
4504                                 1 | 2)) {
4505       } else if (combine_three_valu_op(ctx, instr, aco_opcode::s_xor_b32, aco_opcode::v_xor3_b32,
4506                                        "012", 1 | 2)) {
4507       } else if (combine_xor_not(ctx, instr)) {
4508       }
4509    } else if (instr->opcode == aco_opcode::v_not_b32 && ctx.program->gfx_level >= GFX10) {
4510       combine_not_xor(ctx, instr);
4511    } else if (instr->opcode == aco_opcode::v_add_u16) {
4512       combine_three_valu_op(
4513          ctx, instr, aco_opcode::v_mul_lo_u16,
4514          ctx.program->gfx_level == GFX8 ? aco_opcode::v_mad_legacy_u16 : aco_opcode::v_mad_u16,
4515          "120", 1 | 2);
4516    } else if (instr->opcode == aco_opcode::v_add_u16_e64) {
4517       combine_three_valu_op(ctx, instr, aco_opcode::v_mul_lo_u16_e64, aco_opcode::v_mad_u16, "120",
4518                             1 | 2);
4519    } else if (instr->opcode == aco_opcode::v_add_u32) {
4520       if (combine_add_sub_b2i(ctx, instr, aco_opcode::v_addc_co_u32, 1 | 2)) {
4521       } else if (combine_add_bcnt(ctx, instr)) {
4522       } else if (combine_three_valu_op(ctx, instr, aco_opcode::v_mul_u32_u24,
4523                                        aco_opcode::v_mad_u32_u24, "120", 1 | 2)) {
4524       } else if (ctx.program->gfx_level >= GFX9 && !instr->usesModifiers()) {
4525          if (combine_three_valu_op(ctx, instr, aco_opcode::s_xor_b32, aco_opcode::v_xad_u32, "120",
4526                                    1 | 2)) {
4527          } else if (combine_three_valu_op(ctx, instr, aco_opcode::v_xor_b32, aco_opcode::v_xad_u32,
4528                                           "120", 1 | 2)) {
4529          } else if (combine_three_valu_op(ctx, instr, aco_opcode::s_add_i32, aco_opcode::v_add3_u32,
4530                                           "012", 1 | 2)) {
4531          } else if (combine_three_valu_op(ctx, instr, aco_opcode::s_add_u32, aco_opcode::v_add3_u32,
4532                                           "012", 1 | 2)) {
4533          } else if (combine_three_valu_op(ctx, instr, aco_opcode::v_add_u32, aco_opcode::v_add3_u32,
4534                                           "012", 1 | 2)) {
4535          } else if (combine_add_or_then_and_lshl(ctx, instr)) {
4536          }
4537       }
4538    } else if (instr->opcode == aco_opcode::v_add_co_u32 ||
4539               instr->opcode == aco_opcode::v_add_co_u32_e64) {
4540       bool carry_out = ctx.uses[instr->definitions[1].tempId()] > 0;
4541       if (combine_add_sub_b2i(ctx, instr, aco_opcode::v_addc_co_u32, 1 | 2)) {
4542       } else if (!carry_out && combine_add_bcnt(ctx, instr)) {
4543       } else if (!carry_out && combine_three_valu_op(ctx, instr, aco_opcode::v_mul_u32_u24,
4544                                                      aco_opcode::v_mad_u32_u24, "120", 1 | 2)) {
4545       } else if (!carry_out && combine_add_lshl(ctx, instr, false)) {
4546       }
4547    } else if (instr->opcode == aco_opcode::v_sub_u32 || instr->opcode == aco_opcode::v_sub_co_u32 ||
4548               instr->opcode == aco_opcode::v_sub_co_u32_e64) {
4549       bool carry_out =
4550          instr->opcode != aco_opcode::v_sub_u32 && ctx.uses[instr->definitions[1].tempId()] > 0;
4551       if (combine_add_sub_b2i(ctx, instr, aco_opcode::v_subbrev_co_u32, 2)) {
4552       } else if (!carry_out && combine_add_lshl(ctx, instr, true)) {
4553       }
4554    } else if (instr->opcode == aco_opcode::v_subrev_u32 ||
4555               instr->opcode == aco_opcode::v_subrev_co_u32 ||
4556               instr->opcode == aco_opcode::v_subrev_co_u32_e64) {
4557       combine_add_sub_b2i(ctx, instr, aco_opcode::v_subbrev_co_u32, 1);
4558    } else if (instr->opcode == aco_opcode::v_lshlrev_b32 && ctx.program->gfx_level >= GFX9) {
4559       combine_three_valu_op(ctx, instr, aco_opcode::v_add_u32, aco_opcode::v_add_lshl_u32, "120",
4560                             2);
4561    } else if ((instr->opcode == aco_opcode::s_add_u32 || instr->opcode == aco_opcode::s_add_i32) &&
4562               ctx.program->gfx_level >= GFX9) {
4563       combine_salu_lshl_add(ctx, instr);
4564    } else if (instr->opcode == aco_opcode::s_not_b32 || instr->opcode == aco_opcode::s_not_b64) {
4565       if (!combine_salu_not_bitwise(ctx, instr))
4566          combine_inverse_comparison(ctx, instr);
4567    } else if (instr->opcode == aco_opcode::s_and_b32 || instr->opcode == aco_opcode::s_or_b32 ||
4568               instr->opcode == aco_opcode::s_and_b64 || instr->opcode == aco_opcode::s_or_b64) {
4569       if (combine_ordering_test(ctx, instr)) {
4570       } else if (combine_comparison_ordering(ctx, instr)) {
4571       } else if (combine_constant_comparison_ordering(ctx, instr)) {
4572       } else if (combine_salu_n2(ctx, instr)) {
4573       }
4574    } else if (instr->opcode == aco_opcode::s_abs_i32) {
4575       combine_sabsdiff(ctx, instr);
4576    } else if (instr->opcode == aco_opcode::v_and_b32) {
4577       if (combine_and_subbrev(ctx, instr)) {
4578       } else if (combine_v_andor_not(ctx, instr)) {
4579       }
4580    } else if (instr->opcode == aco_opcode::v_fma_f32 || instr->opcode == aco_opcode::v_fma_f16) {
4581       /* set existing v_fma_f32 with label_mad so we can create v_fmamk_f32/v_fmaak_f32.
4582        * since ctx.uses[mad_info::mul_temp_id] is always 0, we don't have to worry about
4583        * select_instruction() using mad_info::add_instr.
4584        */
4585       ctx.mad_infos.emplace_back(nullptr, 0);
4586       ctx.info[instr->definitions[0].tempId()].set_mad(ctx.mad_infos.size() - 1);
4587    } else if (instr->opcode == aco_opcode::v_med3_f32 || instr->opcode == aco_opcode::v_med3_f16) {
4588       /* Optimize v_med3 to v_add so that it can be dual issued on GFX11. We start with v_med3 in
4589        * case omod can be applied.
4590        */
4591       unsigned idx;
4592       if (detect_clamp(instr.get(), &idx)) {
4593          instr->format = asVOP3(Format::VOP2);
4594          instr->operands[0] = instr->operands[idx];
4595          instr->operands[1] = Operand::zero();
4596          instr->opcode =
4597             instr->opcode == aco_opcode::v_med3_f32 ? aco_opcode::v_add_f32 : aco_opcode::v_add_f16;
4598          instr->valu().clamp = true;
4599          instr->valu().abs = (uint8_t)instr->valu().abs[idx];
4600          instr->valu().neg = (uint8_t)instr->valu().neg[idx];
4601          instr->operands.pop_back();
4602       }
4603    } else {
4604       aco_opcode min, max, min3, max3, med3, minmax;
4605       bool some_gfx9_only;
4606       if (get_minmax_info(instr->opcode, &min, &max, &min3, &max3, &med3, &minmax,
4607                           &some_gfx9_only) &&
4608           (!some_gfx9_only || ctx.program->gfx_level >= GFX9)) {
4609          if (combine_minmax(ctx, instr, instr->opcode == min ? max : min,
4610                             instr->opcode == min ? min3 : max3, minmax)) {
4611          } else {
4612             combine_clamp(ctx, instr, min, max, med3);
4613          }
4614       }
4615    }
4616 }
4617 
4618 bool
to_uniform_bool_instr(opt_ctx & ctx,aco_ptr<Instruction> & instr)4619 to_uniform_bool_instr(opt_ctx& ctx, aco_ptr<Instruction>& instr)
4620 {
4621    /* Check every operand to make sure they are suitable. */
4622    for (Operand& op : instr->operands) {
4623       if (!op.isTemp())
4624          return false;
4625       if (!ctx.info[op.tempId()].is_uniform_bool() && !ctx.info[op.tempId()].is_uniform_bitwise())
4626          return false;
4627    }
4628 
4629    switch (instr->opcode) {
4630    case aco_opcode::s_and_b32:
4631    case aco_opcode::s_and_b64: instr->opcode = aco_opcode::s_and_b32; break;
4632    case aco_opcode::s_or_b32:
4633    case aco_opcode::s_or_b64: instr->opcode = aco_opcode::s_or_b32; break;
4634    case aco_opcode::s_xor_b32:
4635    case aco_opcode::s_xor_b64: instr->opcode = aco_opcode::s_absdiff_i32; break;
4636    default:
4637       /* Don't transform other instructions. They are very unlikely to appear here. */
4638       return false;
4639    }
4640 
4641    for (Operand& op : instr->operands) {
4642       ctx.uses[op.tempId()]--;
4643 
4644       if (ctx.info[op.tempId()].is_uniform_bool()) {
4645          /* Just use the uniform boolean temp. */
4646          op.setTemp(ctx.info[op.tempId()].temp);
4647       } else if (ctx.info[op.tempId()].is_uniform_bitwise()) {
4648          /* Use the SCC definition of the predecessor instruction.
4649           * This allows the predecessor to get picked up by the same optimization (if it has no
4650           * divergent users), and it also makes sure that the current instruction will keep working
4651           * even if the predecessor won't be transformed.
4652           */
4653          Instruction* pred_instr = ctx.info[op.tempId()].instr;
4654          assert(pred_instr->definitions.size() >= 2);
4655          assert(pred_instr->definitions[1].isFixed() &&
4656                 pred_instr->definitions[1].physReg() == scc);
4657          op.setTemp(pred_instr->definitions[1].getTemp());
4658       } else {
4659          unreachable("Invalid operand on uniform bitwise instruction.");
4660       }
4661 
4662       ctx.uses[op.tempId()]++;
4663    }
4664 
4665    instr->definitions[0].setTemp(Temp(instr->definitions[0].tempId(), s1));
4666    assert(instr->operands[0].regClass() == s1);
4667    assert(instr->operands[1].regClass() == s1);
4668    return true;
4669 }
4670 
4671 void
select_instruction(opt_ctx & ctx,aco_ptr<Instruction> & instr)4672 select_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
4673 {
4674    const uint32_t threshold = 4;
4675 
4676    if (is_dead(ctx.uses, instr.get())) {
4677       instr.reset();
4678       return;
4679    }
4680 
4681    /* convert split_vector into a copy or extract_vector if only one definition is ever used */
4682    if (instr->opcode == aco_opcode::p_split_vector) {
4683       unsigned num_used = 0;
4684       unsigned idx = 0;
4685       unsigned split_offset = 0;
4686       for (unsigned i = 0, offset = 0; i < instr->definitions.size();
4687            offset += instr->definitions[i++].bytes()) {
4688          if (ctx.uses[instr->definitions[i].tempId()]) {
4689             num_used++;
4690             idx = i;
4691             split_offset = offset;
4692          }
4693       }
4694       bool done = false;
4695       if (num_used == 1 && ctx.info[instr->operands[0].tempId()].is_vec() &&
4696           ctx.uses[instr->operands[0].tempId()] == 1) {
4697          Instruction* vec = ctx.info[instr->operands[0].tempId()].instr;
4698 
4699          unsigned off = 0;
4700          Operand op;
4701          for (Operand& vec_op : vec->operands) {
4702             if (off == split_offset) {
4703                op = vec_op;
4704                break;
4705             }
4706             off += vec_op.bytes();
4707          }
4708          if (off != instr->operands[0].bytes() && op.bytes() == instr->definitions[idx].bytes()) {
4709             ctx.uses[instr->operands[0].tempId()]--;
4710             for (Operand& vec_op : vec->operands) {
4711                if (vec_op.isTemp())
4712                   ctx.uses[vec_op.tempId()]--;
4713             }
4714             if (op.isTemp())
4715                ctx.uses[op.tempId()]++;
4716 
4717             aco_ptr<Pseudo_instruction> extract{create_instruction<Pseudo_instruction>(
4718                aco_opcode::p_create_vector, Format::PSEUDO, 1, 1)};
4719             extract->operands[0] = op;
4720             extract->definitions[0] = instr->definitions[idx];
4721             instr = std::move(extract);
4722 
4723             done = true;
4724          }
4725       }
4726 
4727       if (!done && num_used == 1 &&
4728           instr->operands[0].bytes() % instr->definitions[idx].bytes() == 0 &&
4729           split_offset % instr->definitions[idx].bytes() == 0) {
4730          aco_ptr<Pseudo_instruction> extract{create_instruction<Pseudo_instruction>(
4731             aco_opcode::p_extract_vector, Format::PSEUDO, 2, 1)};
4732          extract->operands[0] = instr->operands[0];
4733          extract->operands[1] =
4734             Operand::c32((uint32_t)split_offset / instr->definitions[idx].bytes());
4735          extract->definitions[0] = instr->definitions[idx];
4736          instr = std::move(extract);
4737       }
4738    }
4739 
4740    mad_info* mad_info = NULL;
4741    if (!instr->definitions.empty() && ctx.info[instr->definitions[0].tempId()].is_mad()) {
4742       mad_info = &ctx.mad_infos[ctx.info[instr->definitions[0].tempId()].val];
4743       /* re-check mad instructions */
4744       if (ctx.uses[mad_info->mul_temp_id] && mad_info->add_instr) {
4745          ctx.uses[mad_info->mul_temp_id]++;
4746          if (instr->operands[0].isTemp())
4747             ctx.uses[instr->operands[0].tempId()]--;
4748          if (instr->operands[1].isTemp())
4749             ctx.uses[instr->operands[1].tempId()]--;
4750          instr.swap(mad_info->add_instr);
4751          mad_info = NULL;
4752       }
4753       /* check literals */
4754       else if (!instr->isDPP() && !instr->isVOP3P() && instr->opcode != aco_opcode::v_fma_f64 &&
4755                instr->opcode != aco_opcode::v_mad_legacy_f32 &&
4756                instr->opcode != aco_opcode::v_fma_legacy_f32) {
4757          /* FMA can only take literals on GFX10+ */
4758          if ((instr->opcode == aco_opcode::v_fma_f32 || instr->opcode == aco_opcode::v_fma_f16) &&
4759              ctx.program->gfx_level < GFX10)
4760             return;
4761          /* There are no v_fmaak_legacy_f16/v_fmamk_legacy_f16 and on chips where VOP3 can take
4762           * literals (GFX10+), these instructions don't exist.
4763           */
4764          if (instr->opcode == aco_opcode::v_fma_legacy_f16)
4765             return;
4766 
4767          uint32_t literal_mask = 0;
4768          uint32_t fp16_mask = 0;
4769          uint32_t sgpr_mask = 0;
4770          uint32_t vgpr_mask = 0;
4771          uint32_t literal_uses = UINT32_MAX;
4772          uint32_t literal_value = 0;
4773 
4774          /* Iterate in reverse to prefer v_madak/v_fmaak. */
4775          for (int i = 2; i >= 0; i--) {
4776             Operand& op = instr->operands[i];
4777             if (!op.isTemp())
4778                continue;
4779             if (ctx.info[op.tempId()].is_literal(get_operand_size(instr, i))) {
4780                uint32_t new_literal = ctx.info[op.tempId()].val;
4781                float value = uif(new_literal);
4782                uint16_t fp16_val = _mesa_float_to_half(value);
4783                bool is_denorm = (fp16_val & 0x7fff) != 0 && (fp16_val & 0x7fff) <= 0x3ff;
4784                if (_mesa_half_to_float(fp16_val) == value &&
4785                    (!is_denorm || (ctx.fp_mode.denorm16_64 & fp_denorm_keep_in)))
4786                   fp16_mask |= 1 << i;
4787 
4788                if (!literal_mask || literal_value == new_literal) {
4789                   literal_value = new_literal;
4790                   literal_uses = MIN2(literal_uses, ctx.uses[op.tempId()]);
4791                   literal_mask |= 1 << i;
4792                   continue;
4793                }
4794             }
4795             sgpr_mask |= op.isOfType(RegType::sgpr) << i;
4796             vgpr_mask |= op.isOfType(RegType::vgpr) << i;
4797          }
4798 
4799          /* The constant bus limitations before GFX10 disallows SGPRs. */
4800          if (sgpr_mask && ctx.program->gfx_level < GFX10)
4801             literal_mask = 0;
4802 
4803          /* Encoding needs a vgpr. */
4804          if (!vgpr_mask)
4805             literal_mask = 0;
4806 
4807          /* v_madmk/v_fmamk needs a vgpr in the third source. */
4808          if (!(literal_mask & 0b100) && !(vgpr_mask & 0b100))
4809             literal_mask = 0;
4810 
4811          /* opsel with GFX11+ is the only modifier supported by fmamk/fmaak*/
4812          if (instr->valu().abs || instr->valu().neg || instr->valu().omod || instr->valu().clamp ||
4813              (instr->valu().opsel && ctx.program->gfx_level < GFX11))
4814             literal_mask = 0;
4815 
4816          if (instr->valu().opsel & ~vgpr_mask)
4817             literal_mask = 0;
4818 
4819          /* We can't use three unique fp16 literals */
4820          if (fp16_mask == 0b111)
4821             fp16_mask = 0b11;
4822 
4823          if ((instr->opcode == aco_opcode::v_fma_f32 ||
4824               (instr->opcode == aco_opcode::v_mad_f32 && !instr->definitions[0].isPrecise())) &&
4825              !instr->valu().omod && ctx.program->gfx_level >= GFX10 &&
4826              util_bitcount(fp16_mask) > std::max<uint32_t>(util_bitcount(literal_mask), 1)) {
4827             assert(ctx.program->dev.fused_mad_mix);
4828             u_foreach_bit (i, fp16_mask)
4829                ctx.uses[instr->operands[i].tempId()]--;
4830             mad_info->fp16_mask = fp16_mask;
4831             return;
4832          }
4833 
4834          /* Limit the number of literals to apply to not increase the code
4835           * size too much, but always apply literals for v_mad->v_madak
4836           * because both instructions are 64-bit and this doesn't increase
4837           * code size.
4838           * TODO: try to apply the literals earlier to lower the number of
4839           * uses below threshold
4840           */
4841          if (literal_mask && (literal_uses < threshold || (literal_mask & 0b100))) {
4842             u_foreach_bit (i, literal_mask)
4843                ctx.uses[instr->operands[i].tempId()]--;
4844             mad_info->literal_mask = literal_mask;
4845             return;
4846          }
4847       }
4848    }
4849 
4850    /* Mark SCC needed, so the uniform boolean transformation won't swap the definitions
4851     * when it isn't beneficial */
4852    if (instr->isBranch() && instr->operands.size() && instr->operands[0].isTemp() &&
4853        instr->operands[0].isFixed() && instr->operands[0].physReg() == scc) {
4854       ctx.info[instr->operands[0].tempId()].set_scc_needed();
4855       return;
4856    } else if ((instr->opcode == aco_opcode::s_cselect_b64 ||
4857                instr->opcode == aco_opcode::s_cselect_b32) &&
4858               instr->operands[2].isTemp()) {
4859       ctx.info[instr->operands[2].tempId()].set_scc_needed();
4860    }
4861 
4862    /* check for literals */
4863    if (!instr->isSALU() && !instr->isVALU())
4864       return;
4865 
4866    /* Transform uniform bitwise boolean operations to 32-bit when there are no divergent uses. */
4867    if (instr->definitions.size() && ctx.uses[instr->definitions[0].tempId()] == 0 &&
4868        ctx.info[instr->definitions[0].tempId()].is_uniform_bitwise()) {
4869       bool transform_done = to_uniform_bool_instr(ctx, instr);
4870 
4871       if (transform_done && !ctx.info[instr->definitions[1].tempId()].is_scc_needed()) {
4872          /* Swap the two definition IDs in order to avoid overusing the SCC.
4873           * This reduces extra moves generated by RA. */
4874          uint32_t def0_id = instr->definitions[0].getTemp().id();
4875          uint32_t def1_id = instr->definitions[1].getTemp().id();
4876          instr->definitions[0].setTemp(Temp(def1_id, s1));
4877          instr->definitions[1].setTemp(Temp(def0_id, s1));
4878       }
4879 
4880       return;
4881    }
4882 
4883    /* This optimization is done late in order to be able to apply otherwise
4884     * unsafe optimizations such as the inverse comparison optimization.
4885     */
4886    if (instr->opcode == aco_opcode::s_and_b32 || instr->opcode == aco_opcode::s_and_b64) {
4887       if (instr->operands[0].isTemp() && fixed_to_exec(instr->operands[1]) &&
4888           ctx.uses[instr->operands[0].tempId()] == 1 &&
4889           ctx.uses[instr->definitions[1].tempId()] == 0 &&
4890           can_eliminate_and_exec(ctx, instr->operands[0].getTemp(), instr->pass_flags)) {
4891          ctx.uses[instr->operands[0].tempId()]--;
4892          ctx.info[instr->operands[0].tempId()].instr->definitions[0].setTemp(
4893             instr->definitions[0].getTemp());
4894          instr.reset();
4895          return;
4896       }
4897    }
4898 
4899    /* Combine DPP copies into VALU. This should be done after creating MAD/FMA. */
4900    if (instr->isVALU() && !instr->isDPP()) {
4901       for (unsigned i = 0; i < instr->operands.size(); i++) {
4902          if (!instr->operands[i].isTemp())
4903             continue;
4904          ssa_info info = ctx.info[instr->operands[i].tempId()];
4905 
4906          if (!info.is_dpp() || info.instr->pass_flags != instr->pass_flags)
4907             continue;
4908 
4909          /* We won't eliminate the DPP mov if the operand is used twice */
4910          bool op_used_twice = false;
4911          for (unsigned j = 0; j < instr->operands.size(); j++)
4912             op_used_twice |= i != j && instr->operands[i] == instr->operands[j];
4913          if (op_used_twice)
4914             continue;
4915 
4916          if (i != 0) {
4917             if (!can_swap_operands(instr, &instr->opcode, 0, i))
4918                continue;
4919             instr->valu().swapOperands(0, i);
4920          }
4921 
4922          if (!can_use_DPP(ctx.program->gfx_level, instr, info.is_dpp8()))
4923             continue;
4924 
4925          bool dpp8 = info.is_dpp8();
4926          bool input_mods = can_use_input_modifiers(ctx.program->gfx_level, instr->opcode, 0) &&
4927                            get_operand_size(instr, 0) == 32;
4928          bool mov_uses_mods = info.instr->valu().neg[0] || info.instr->valu().abs[0];
4929          if (((dpp8 && ctx.program->gfx_level < GFX11) || !input_mods) && mov_uses_mods)
4930             continue;
4931 
4932          convert_to_DPP(ctx.program->gfx_level, instr, dpp8);
4933 
4934          if (dpp8) {
4935             DPP8_instruction* dpp = &instr->dpp8();
4936             dpp->lane_sel = info.instr->dpp8().lane_sel;
4937             dpp->fetch_inactive = info.instr->dpp8().fetch_inactive;
4938             if (mov_uses_mods)
4939                instr->format = asVOP3(instr->format);
4940          } else {
4941             DPP16_instruction* dpp = &instr->dpp16();
4942             dpp->dpp_ctrl = info.instr->dpp16().dpp_ctrl;
4943             dpp->bound_ctrl = info.instr->dpp16().bound_ctrl;
4944             dpp->fetch_inactive = info.instr->dpp16().fetch_inactive;
4945          }
4946 
4947          instr->valu().neg[0] ^= info.instr->valu().neg[0] && !instr->valu().abs[0];
4948          instr->valu().abs[0] |= info.instr->valu().abs[0];
4949 
4950          if (--ctx.uses[info.instr->definitions[0].tempId()])
4951             ctx.uses[info.instr->operands[0].tempId()]++;
4952          instr->operands[0].setTemp(info.instr->operands[0].getTemp());
4953          break;
4954       }
4955    }
4956 
4957    /* Use v_fma_mix for f2f32/f2f16 if it has higher throughput.
4958     * Do this late to not disturb other optimizations.
4959     */
4960    if ((instr->opcode == aco_opcode::v_cvt_f32_f16 || instr->opcode == aco_opcode::v_cvt_f16_f32) &&
4961        ctx.program->gfx_level >= GFX11 && ctx.program->wave_size == 64 && !instr->valu().omod &&
4962        !instr->isDPP()) {
4963       bool is_f2f16 = instr->opcode == aco_opcode::v_cvt_f16_f32;
4964       Instruction* fma = create_instruction<VALU_instruction>(
4965          is_f2f16 ? aco_opcode::v_fma_mixlo_f16 : aco_opcode::v_fma_mix_f32, Format::VOP3P, 3, 1);
4966       fma->definitions[0] = instr->definitions[0];
4967       fma->operands[0] = instr->operands[0];
4968       fma->valu().opsel_hi[0] = !is_f2f16;
4969       fma->valu().opsel_lo[0] = instr->valu().opsel[0];
4970       fma->valu().clamp = instr->valu().clamp;
4971       fma->valu().abs[0] = instr->valu().abs[0];
4972       fma->valu().neg[0] = instr->valu().neg[0];
4973       fma->operands[1] = Operand::c32(fui(1.0f));
4974       fma->operands[2] = Operand::zero();
4975       /* fma_mix is only dual issued if dst and acc type match */
4976       fma->valu().opsel_hi[2] = is_f2f16;
4977       fma->valu().neg[2] = true;
4978       instr.reset(fma);
4979       ctx.info[instr->definitions[0].tempId()].label = 0;
4980    }
4981 
4982    if (instr->isSDWA() || (instr->isVOP3() && ctx.program->gfx_level < GFX10) ||
4983        (instr->isVOP3P() && ctx.program->gfx_level < GFX10))
4984       return; /* some encodings can't ever take literals */
4985 
4986    /* we do not apply the literals yet as we don't know if it is profitable */
4987    Operand current_literal(s1);
4988 
4989    unsigned literal_id = 0;
4990    unsigned literal_uses = UINT32_MAX;
4991    Operand literal(s1);
4992    unsigned num_operands = 1;
4993    if (instr->isSALU() || (ctx.program->gfx_level >= GFX10 &&
4994                            (can_use_VOP3(ctx, instr) || instr->isVOP3P()) && !instr->isDPP()))
4995       num_operands = instr->operands.size();
4996    /* catch VOP2 with a 3rd SGPR operand (e.g. v_cndmask_b32, v_addc_co_u32) */
4997    else if (instr->isVALU() && instr->operands.size() >= 3)
4998       return;
4999 
5000    unsigned sgpr_ids[2] = {0, 0};
5001    bool is_literal_sgpr = false;
5002    uint32_t mask = 0;
5003 
5004    /* choose a literal to apply */
5005    for (unsigned i = 0; i < num_operands; i++) {
5006       Operand op = instr->operands[i];
5007       unsigned bits = get_operand_size(instr, i);
5008 
5009       if (instr->isVALU() && op.isTemp() && op.getTemp().type() == RegType::sgpr &&
5010           op.tempId() != sgpr_ids[0])
5011          sgpr_ids[!!sgpr_ids[0]] = op.tempId();
5012 
5013       if (op.isLiteral()) {
5014          current_literal = op;
5015          continue;
5016       } else if (!op.isTemp() || !ctx.info[op.tempId()].is_literal(bits)) {
5017          continue;
5018       }
5019 
5020       if (!alu_can_accept_constant(instr, i))
5021          continue;
5022 
5023       if (ctx.uses[op.tempId()] < literal_uses) {
5024          is_literal_sgpr = op.getTemp().type() == RegType::sgpr;
5025          mask = 0;
5026          literal = Operand::c32(ctx.info[op.tempId()].val);
5027          literal_uses = ctx.uses[op.tempId()];
5028          literal_id = op.tempId();
5029       }
5030 
5031       mask |= (op.tempId() == literal_id) << i;
5032    }
5033 
5034    /* don't go over the constant bus limit */
5035    bool is_shift64 = instr->opcode == aco_opcode::v_lshlrev_b64 ||
5036                      instr->opcode == aco_opcode::v_lshrrev_b64 ||
5037                      instr->opcode == aco_opcode::v_ashrrev_i64;
5038    unsigned const_bus_limit = instr->isVALU() ? 1 : UINT32_MAX;
5039    if (ctx.program->gfx_level >= GFX10 && !is_shift64)
5040       const_bus_limit = 2;
5041 
5042    unsigned num_sgprs = !!sgpr_ids[0] + !!sgpr_ids[1];
5043    if (num_sgprs == const_bus_limit && !is_literal_sgpr)
5044       return;
5045 
5046    if (literal_id && literal_uses < threshold &&
5047        (current_literal.isUndefined() ||
5048         (current_literal.size() == literal.size() &&
5049          current_literal.constantValue() == literal.constantValue()))) {
5050       /* mark the literal to be applied */
5051       while (mask) {
5052          unsigned i = u_bit_scan(&mask);
5053          if (instr->operands[i].isTemp() && instr->operands[i].tempId() == literal_id)
5054             ctx.uses[instr->operands[i].tempId()]--;
5055       }
5056    }
5057 }
5058 
5059 static aco_opcode
sopk_opcode_for_sopc(aco_opcode opcode)5060 sopk_opcode_for_sopc(aco_opcode opcode)
5061 {
5062 #define CTOK(op)                                                                                   \
5063    case aco_opcode::s_cmp_##op##_i32: return aco_opcode::s_cmpk_##op##_i32;                        \
5064    case aco_opcode::s_cmp_##op##_u32: return aco_opcode::s_cmpk_##op##_u32;
5065    switch (opcode) {
5066       CTOK(eq)
5067       CTOK(lg)
5068       CTOK(gt)
5069       CTOK(ge)
5070       CTOK(lt)
5071       CTOK(le)
5072    default: return aco_opcode::num_opcodes;
5073    }
5074 #undef CTOK
5075 }
5076 
5077 static bool
sopc_is_signed(aco_opcode opcode)5078 sopc_is_signed(aco_opcode opcode)
5079 {
5080 #define SOPC(op)                                                                                   \
5081    case aco_opcode::s_cmp_##op##_i32: return true;                                                 \
5082    case aco_opcode::s_cmp_##op##_u32: return false;
5083    switch (opcode) {
5084       SOPC(eq)
5085       SOPC(lg)
5086       SOPC(gt)
5087       SOPC(ge)
5088       SOPC(lt)
5089       SOPC(le)
5090    default: unreachable("Not a valid SOPC instruction.");
5091    }
5092 #undef SOPC
5093 }
5094 
5095 static aco_opcode
sopc_32_swapped(aco_opcode opcode)5096 sopc_32_swapped(aco_opcode opcode)
5097 {
5098 #define SOPC(op1, op2)                                                                             \
5099    case aco_opcode::s_cmp_##op1##_i32: return aco_opcode::s_cmp_##op2##_i32;                       \
5100    case aco_opcode::s_cmp_##op1##_u32: return aco_opcode::s_cmp_##op2##_u32;
5101    switch (opcode) {
5102       SOPC(eq, eq)
5103       SOPC(lg, lg)
5104       SOPC(gt, lt)
5105       SOPC(ge, le)
5106       SOPC(lt, gt)
5107       SOPC(le, ge)
5108    default: return aco_opcode::num_opcodes;
5109    }
5110 #undef SOPC
5111 }
5112 
5113 static void
try_convert_sopc_to_sopk(aco_ptr<Instruction> & instr)5114 try_convert_sopc_to_sopk(aco_ptr<Instruction>& instr)
5115 {
5116    if (sopk_opcode_for_sopc(instr->opcode) == aco_opcode::num_opcodes)
5117       return;
5118 
5119    if (instr->operands[0].isLiteral()) {
5120       std::swap(instr->operands[0], instr->operands[1]);
5121       instr->opcode = sopc_32_swapped(instr->opcode);
5122    }
5123 
5124    if (!instr->operands[1].isLiteral())
5125       return;
5126 
5127    if (instr->operands[0].isFixed() && instr->operands[0].physReg() >= 128)
5128       return;
5129 
5130    uint32_t value = instr->operands[1].constantValue();
5131 
5132    const uint32_t i16_mask = 0xffff8000u;
5133 
5134    bool value_is_i16 = (value & i16_mask) == 0 || (value & i16_mask) == i16_mask;
5135    bool value_is_u16 = !(value & 0xffff0000u);
5136 
5137    if (!value_is_i16 && !value_is_u16)
5138       return;
5139 
5140    if (!value_is_i16 && sopc_is_signed(instr->opcode)) {
5141       if (instr->opcode == aco_opcode::s_cmp_lg_i32)
5142          instr->opcode = aco_opcode::s_cmp_lg_u32;
5143       else if (instr->opcode == aco_opcode::s_cmp_eq_i32)
5144          instr->opcode = aco_opcode::s_cmp_eq_u32;
5145       else
5146          return;
5147    } else if (!value_is_u16 && !sopc_is_signed(instr->opcode)) {
5148       if (instr->opcode == aco_opcode::s_cmp_lg_u32)
5149          instr->opcode = aco_opcode::s_cmp_lg_i32;
5150       else if (instr->opcode == aco_opcode::s_cmp_eq_u32)
5151          instr->opcode = aco_opcode::s_cmp_eq_i32;
5152       else
5153          return;
5154    }
5155 
5156    static_assert(sizeof(SOPK_instruction) <= sizeof(SOPC_instruction),
5157                  "Invalid direct instruction cast.");
5158    instr->format = Format::SOPK;
5159    SOPK_instruction* instr_sopk = &instr->sopk();
5160 
5161    instr_sopk->imm = instr_sopk->operands[1].constantValue() & 0xffff;
5162    instr_sopk->opcode = sopk_opcode_for_sopc(instr_sopk->opcode);
5163    instr_sopk->operands.pop_back();
5164 }
5165 
5166 static void
unswizzle_vop3p_literals(opt_ctx & ctx,aco_ptr<Instruction> & instr)5167 unswizzle_vop3p_literals(opt_ctx& ctx, aco_ptr<Instruction>& instr)
5168 {
5169    /* This opt is only beneficial for v_pk_fma_f16 because we can use v_pk_fmac_f16 if the
5170     * instruction doesn't use swizzles. */
5171    if (instr->opcode != aco_opcode::v_pk_fma_f16)
5172       return;
5173 
5174    VALU_instruction& vop3p = instr->valu();
5175 
5176    unsigned literal_swizzle = ~0u;
5177    for (unsigned i = 0; i < instr->operands.size(); i++) {
5178       if (!instr->operands[i].isLiteral())
5179          continue;
5180       unsigned new_swizzle = vop3p.opsel_lo[i] | (vop3p.opsel_hi[i] << 1);
5181       if (literal_swizzle != ~0u && new_swizzle != literal_swizzle)
5182          return; /* Literal swizzles conflict. */
5183       literal_swizzle = new_swizzle;
5184    }
5185 
5186    if (literal_swizzle == 0b10 || literal_swizzle == ~0u)
5187       return; /* already unswizzled */
5188 
5189    for (unsigned i = 0; i < instr->operands.size(); i++) {
5190       if (!instr->operands[i].isLiteral())
5191          continue;
5192       uint32_t literal = instr->operands[i].constantValue();
5193       literal = (literal >> (16 * (literal_swizzle & 0x1)) & 0xffff) |
5194                 (literal >> (8 * (literal_swizzle & 0x2)) << 16);
5195       instr->operands[i] = Operand::literal32(literal);
5196       vop3p.opsel_lo[i] = false;
5197       vop3p.opsel_hi[i] = true;
5198    }
5199 }
5200 
5201 void
apply_literals(opt_ctx & ctx,aco_ptr<Instruction> & instr)5202 apply_literals(opt_ctx& ctx, aco_ptr<Instruction>& instr)
5203 {
5204    /* Cleanup Dead Instructions */
5205    if (!instr)
5206       return;
5207 
5208    /* apply literals on MAD */
5209    if (!instr->definitions.empty() && ctx.info[instr->definitions[0].tempId()].is_mad()) {
5210       mad_info* info = &ctx.mad_infos[ctx.info[instr->definitions[0].tempId()].val];
5211       const bool madak = (info->literal_mask & 0b100);
5212       bool has_dead_literal = false;
5213       u_foreach_bit (i, info->literal_mask | info->fp16_mask)
5214          has_dead_literal |= ctx.uses[instr->operands[i].tempId()] == 0;
5215 
5216       if (has_dead_literal && info->fp16_mask) {
5217          instr->format = Format::VOP3P;
5218          instr->opcode = aco_opcode::v_fma_mix_f32;
5219 
5220          uint32_t literal = 0;
5221          bool second = false;
5222          u_foreach_bit (i, info->fp16_mask) {
5223             float value = uif(ctx.info[instr->operands[i].tempId()].val);
5224             literal |= _mesa_float_to_half(value) << (second * 16);
5225             instr->valu().opsel_lo[i] = second;
5226             instr->valu().opsel_hi[i] = true;
5227             second = true;
5228          }
5229 
5230          for (unsigned i = 0; i < 3; i++) {
5231             if (info->fp16_mask & (1 << i))
5232                instr->operands[i] = Operand::literal32(literal);
5233          }
5234 
5235          ctx.instructions.emplace_back(std::move(instr));
5236          return;
5237       }
5238 
5239       if (has_dead_literal || madak) {
5240          aco_opcode new_op = madak ? aco_opcode::v_madak_f32 : aco_opcode::v_madmk_f32;
5241          if (instr->opcode == aco_opcode::v_fma_f32)
5242             new_op = madak ? aco_opcode::v_fmaak_f32 : aco_opcode::v_fmamk_f32;
5243          else if (instr->opcode == aco_opcode::v_mad_f16 ||
5244                   instr->opcode == aco_opcode::v_mad_legacy_f16)
5245             new_op = madak ? aco_opcode::v_madak_f16 : aco_opcode::v_madmk_f16;
5246          else if (instr->opcode == aco_opcode::v_fma_f16)
5247             new_op = madak ? aco_opcode::v_fmaak_f16 : aco_opcode::v_fmamk_f16;
5248 
5249          uint32_t literal = ctx.info[instr->operands[ffs(info->literal_mask) - 1].tempId()].val;
5250          instr->format = Format::VOP2;
5251          instr->opcode = new_op;
5252          for (unsigned i = 0; i < 3; i++) {
5253             if (info->literal_mask & (1 << i))
5254                instr->operands[i] = Operand::literal32(literal);
5255          }
5256          if (madak) { /* add literal -> madak */
5257             if (!instr->operands[1].isOfType(RegType::vgpr))
5258                instr->valu().swapOperands(0, 1);
5259          } else { /* mul literal -> madmk */
5260             if (!(info->literal_mask & 0b10))
5261                instr->valu().swapOperands(0, 1);
5262             instr->valu().swapOperands(1, 2);
5263          }
5264          ctx.instructions.emplace_back(std::move(instr));
5265          return;
5266       }
5267    }
5268 
5269    /* apply literals on other SALU/VALU */
5270    if (instr->isSALU() || instr->isVALU()) {
5271       for (unsigned i = 0; i < instr->operands.size(); i++) {
5272          Operand op = instr->operands[i];
5273          unsigned bits = get_operand_size(instr, i);
5274          if (op.isTemp() && ctx.info[op.tempId()].is_literal(bits) && ctx.uses[op.tempId()] == 0) {
5275             Operand literal = Operand::literal32(ctx.info[op.tempId()].val);
5276             instr->format = withoutDPP(instr->format);
5277             if (instr->isVALU() && i > 0 && instr->format != Format::VOP3P)
5278                instr->format = asVOP3(instr->format);
5279             instr->operands[i] = literal;
5280          }
5281       }
5282    }
5283 
5284    if (instr->isSOPC())
5285       try_convert_sopc_to_sopk(instr);
5286 
5287    /* allow more s_addk_i32 optimizations if carry isn't used */
5288    if (instr->opcode == aco_opcode::s_add_u32 && ctx.uses[instr->definitions[1].tempId()] == 0 &&
5289        (instr->operands[0].isLiteral() || instr->operands[1].isLiteral()))
5290       instr->opcode = aco_opcode::s_add_i32;
5291 
5292    if (instr->isVOP3P())
5293       unswizzle_vop3p_literals(ctx, instr);
5294 
5295    ctx.instructions.emplace_back(std::move(instr));
5296 }
5297 
5298 void
optimize(Program * program)5299 optimize(Program* program)
5300 {
5301    opt_ctx ctx;
5302    ctx.program = program;
5303    std::vector<ssa_info> info(program->peekAllocationId());
5304    ctx.info = info.data();
5305 
5306    /* 1. Bottom-Up DAG pass (forward) to label all ssa-defs */
5307    for (Block& block : program->blocks) {
5308       ctx.fp_mode = block.fp_mode;
5309       for (aco_ptr<Instruction>& instr : block.instructions)
5310          label_instruction(ctx, instr);
5311    }
5312 
5313    ctx.uses = dead_code_analysis(program);
5314 
5315    /* 2. Combine v_mad, omod, clamp and propagate sgpr on VALU instructions */
5316    for (Block& block : program->blocks) {
5317       ctx.fp_mode = block.fp_mode;
5318       for (aco_ptr<Instruction>& instr : block.instructions)
5319          combine_instruction(ctx, instr);
5320    }
5321 
5322    /* 3. Top-Down DAG pass (backward) to select instructions (includes DCE) */
5323    for (auto block_rit = program->blocks.rbegin(); block_rit != program->blocks.rend();
5324         ++block_rit) {
5325       Block* block = &(*block_rit);
5326       ctx.fp_mode = block->fp_mode;
5327       for (auto instr_rit = block->instructions.rbegin(); instr_rit != block->instructions.rend();
5328            ++instr_rit)
5329          select_instruction(ctx, *instr_rit);
5330    }
5331 
5332    /* 4. Add literals to instructions */
5333    for (Block& block : program->blocks) {
5334       ctx.instructions.reserve(block.instructions.size());
5335       ctx.fp_mode = block.fp_mode;
5336       for (aco_ptr<Instruction>& instr : block.instructions)
5337          apply_literals(ctx, instr);
5338       block.instructions = std::move(ctx.instructions);
5339    }
5340 }
5341 
5342 } // namespace aco
5343