1 /*
2  * Copyright © 2018 Valve Corporation
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  *
23  */
24 
25 #include "aco_builder.h"
26 #include "aco_ir.h"
27 
28 #include "util/half_float.h"
29 #include "util/memstream.h"
30 
31 #include <algorithm>
32 #include <array>
33 #include <vector>
34 
35 namespace aco {
36 
37 #ifndef NDEBUG
38 void
39 perfwarn(Program* program, bool cond, const char* msg, Instruction* instr)
40 {
41    if (cond) {
42       char* out;
43       size_t outsize;
44       struct u_memstream mem;
45       u_memstream_open(&mem, &out, &outsize);
46       FILE* const memf = u_memstream_get(&mem);
47 
48       fprintf(memf, "%s: ", msg);
49       aco_print_instr(instr, memf);
50       u_memstream_close(&mem);
51 
52       aco_perfwarn(program, out);
53       free(out);
54 
55       if (debug_flags & DEBUG_PERFWARN)
56          exit(1);
57    }
58 }
59 #endif
60 
61 /**
62  * The optimizer works in 4 phases:
63  * (1) The first pass collects information for each ssa-def,
64  *     propagates reg->reg operands of the same type, inline constants
65  *     and neg/abs input modifiers.
66  * (2) The second pass combines instructions like mad, omod, clamp and
67  *     propagates SGPRs on VALU instructions.
68  *     This pass depends on information collected in the first pass.
69  * (3) The third pass goes backwards, and selects instructions,
70  *     i.e. decides if a mad instruction is profitable and eliminates dead code.
71  * (4) The fourth pass cleans up the sequence: literals get applied and dead
72  *     instructions are removed from the sequence.
73  */
74 
75 struct mad_info {
76    aco_ptr<Instruction> add_instr;
77    uint32_t mul_temp_id;
78    uint16_t literal_idx;
79    bool check_literal;
80 
81    mad_info(aco_ptr<Instruction> instr, uint32_t id)
82        : add_instr(std::move(instr)), mul_temp_id(id), literal_idx(0), check_literal(false)
83    {}
84 };
85 
86 enum Label {
87    label_vec = 1 << 0,
88    label_constant_32bit = 1 << 1,
89    /* label_{abs,neg,mul,omod2,omod4,omod5,clamp} are used for both 16 and
90     * 32-bit operations but this shouldn't cause any issues because we don't
91     * look through any conversions */
92    label_abs = 1 << 2,
93    label_neg = 1 << 3,
94    label_mul = 1 << 4,
95    label_temp = 1 << 5,
96    label_literal = 1 << 6,
97    label_mad = 1 << 7,
98    label_omod2 = 1 << 8,
99    label_omod4 = 1 << 9,
100    label_omod5 = 1 << 10,
101    label_clamp = 1 << 12,
102    label_undefined = 1 << 14,
103    label_vcc = 1 << 15,
104    label_b2f = 1 << 16,
105    label_add_sub = 1 << 17,
106    label_bitwise = 1 << 18,
107    label_minmax = 1 << 19,
108    label_vopc = 1 << 20,
109    label_uniform_bool = 1 << 21,
110    label_constant_64bit = 1 << 22,
111    label_uniform_bitwise = 1 << 23,
112    label_scc_invert = 1 << 24,
113    label_scc_needed = 1 << 26,
114    label_b2i = 1 << 27,
115    label_fcanonicalize = 1 << 28,
116    label_constant_16bit = 1 << 29,
117    label_usedef = 1 << 30,   /* generic label */
118    label_vop3p = 1ull << 31, /* 1ull to prevent sign extension */
119    label_canonicalized = 1ull << 32,
120    label_extract = 1ull << 33,
121    label_insert = 1ull << 34,
122    label_dpp16 = 1ull << 35,
123    label_dpp8 = 1ull << 36,
124    label_f2f32 = 1ull << 37,
125    label_f2f16 = 1ull << 38,
126    label_split = 1ull << 39,
127 };
128 
129 static constexpr uint64_t instr_usedef_labels =
130    label_vec | label_mul | label_mad | label_add_sub | label_vop3p | label_bitwise |
131    label_uniform_bitwise | label_minmax | label_vopc | label_usedef | label_extract | label_dpp16 |
132    label_dpp8 | label_f2f32;
133 static constexpr uint64_t instr_mod_labels =
134    label_omod2 | label_omod4 | label_omod5 | label_clamp | label_insert | label_f2f16;
135 
136 static constexpr uint64_t instr_labels = instr_usedef_labels | instr_mod_labels | label_split;
137 static constexpr uint64_t temp_labels = label_abs | label_neg | label_temp | label_vcc | label_b2f |
138                                         label_uniform_bool | label_scc_invert | label_b2i |
139                                         label_fcanonicalize;
140 static constexpr uint32_t val_labels =
141    label_constant_32bit | label_constant_64bit | label_constant_16bit | label_literal;
142 
143 static_assert((instr_labels & temp_labels) == 0, "labels cannot intersect");
144 static_assert((instr_labels & val_labels) == 0, "labels cannot intersect");
145 static_assert((temp_labels & val_labels) == 0, "labels cannot intersect");
146 
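/* An ssa_info stores at most one kind of payload, chosen by the label bits:
 * instr_labels mean 'instr' points at an associated instruction (the defining
 * one for usedef labels), temp_labels mean 'temp' holds an equivalent
 * temporary, and val_labels mean 'val' holds a constant. Because the payloads
 * share a union, the three groups must stay disjoint, as the static_asserts
 * above check. */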
147 struct ssa_info {
148    uint64_t label;
149    union {
150       uint32_t val;
151       Temp temp;
152       Instruction* instr;
153    };
154 
155    ssa_info() : label(0) {}
156 
157    void add_label(Label new_label)
158    {
159       /* Since all the instr_usedef_labels use instr for the same thing
160        * (indicating the defining instruction), there is usually no need to
161        * clear any other instr labels. */
162       if (new_label & instr_usedef_labels)
163          label &= ~(instr_mod_labels | temp_labels | val_labels); /* instr, temp and val alias */
164 
165       if (new_label & instr_mod_labels) {
166          label &= ~instr_labels;
167          label &= ~(temp_labels | val_labels); /* instr, temp and val alias */
168       }
169 
170       if (new_label & temp_labels) {
171          label &= ~temp_labels;
172          label &= ~(instr_labels | val_labels); /* instr, temp and val alias */
173       }
174 
175       uint32_t const_labels =
176          label_literal | label_constant_32bit | label_constant_64bit | label_constant_16bit;
177       if (new_label & const_labels) {
178          label &= ~val_labels | const_labels;
179          label &= ~(instr_labels | temp_labels); /* instr, temp and val alias */
180       } else if (new_label & val_labels) {
181          label &= ~val_labels;
182          label &= ~(instr_labels | temp_labels); /* instr, temp and val alias */
183       }
184 
185       label |= new_label;
186    }
187 
188    void set_vec(Instruction* vec)
189    {
190       add_label(label_vec);
191       instr = vec;
192    }
193 
194    bool is_vec() { return label & label_vec; }
195 
196    void set_constant(amd_gfx_level gfx_level, uint64_t constant)
197    {
198       Operand op16 = Operand::c16(constant);
199       Operand op32 = Operand::get_const(gfx_level, constant, 4);
200       add_label(label_literal);
201       val = constant;
202 
203       /* check that no upper bits are lost in case of packed 16bit constants */
204       if (gfx_level >= GFX8 && !op16.isLiteral() &&
205           op16.constantValue16(true) == ((constant >> 16) & 0xffff))
206          add_label(label_constant_16bit);
207 
208       if (!op32.isLiteral())
209          add_label(label_constant_32bit);
210 
211       if (Operand::is_constant_representable(constant, 8))
212          add_label(label_constant_64bit);
213 
214       if (label & label_constant_64bit) {
215          val = Operand::c64(constant).constantValue();
216          if (val != constant)
217             label &= ~(label_literal | label_constant_16bit | label_constant_32bit);
218       }
219    }
220 
221    bool is_constant(unsigned bits)
222    {
223       switch (bits) {
224       case 8: return label & label_literal;
225       case 16: return label & label_constant_16bit;
226       case 32: return label & label_constant_32bit;
227       case 64: return label & label_constant_64bit;
228       }
229       return false;
230    }
231 
232    bool is_literal(unsigned bits)
233    {
234       bool is_lit = label & label_literal;
235       switch (bits) {
236       case 8: return false;
237       case 16: return is_lit && !(label & label_constant_16bit);
238       case 32: return is_lit && !(label & label_constant_32bit);
239       case 64: return false;
240       }
241       return false;
242    }
243 
244    bool is_constant_or_literal(unsigned bits)
245    {
246       if (bits == 64)
247          return label & label_constant_64bit;
248       else
249          return label & label_literal;
250    }
251 
252    void set_abs(Temp abs_temp)
253    {
254       add_label(label_abs);
255       temp = abs_temp;
256    }
257 
258    bool is_abs() { return label & label_abs; }
259 
260    void set_neg(Temp neg_temp)
261    {
262       add_label(label_neg);
263       temp = neg_temp;
264    }
265 
266    bool is_neg() { return label & label_neg; }
267 
268    void set_neg_abs(Temp neg_abs_temp)
269    {
270       add_label((Label)((uint32_t)label_abs | (uint32_t)label_neg));
271       temp = neg_abs_temp;
272    }
273 
274    void set_mul(Instruction* mul)
275    {
276       add_label(label_mul);
277       instr = mul;
278    }
279 
280    bool is_mul() { return label & label_mul; }
281 
282    void set_temp(Temp tmp)
283    {
284       add_label(label_temp);
285       temp = tmp;
286    }
287 
288    bool is_temp() { return label & label_temp; }
289 
290    void set_mad(Instruction* mad, uint32_t mad_info_idx)
291    {
292       add_label(label_mad);
293       mad->pass_flags = mad_info_idx;
294       instr = mad;
295    }
296 
297    bool is_mad() { return label & label_mad; }
298 
299    void set_omod2(Instruction* mul)
300    {
301       add_label(label_omod2);
302       instr = mul;
303    }
304 
305    bool is_omod2() { return label & label_omod2; }
306 
307    void set_omod4(Instruction* mul)
308    {
309       add_label(label_omod4);
310       instr = mul;
311    }
312 
313    bool is_omod4() { return label & label_omod4; }
314 
315    void set_omod5(Instruction* mul)
316    {
317       add_label(label_omod5);
318       instr = mul;
319    }
320 
321    bool is_omod5() { return label & label_omod5; }
322 
323    void set_clamp(Instruction* med3)
324    {
325       add_label(label_clamp);
326       instr = med3;
327    }
328 
329    bool is_clamp() { return label & label_clamp; }
330 
331    void set_f2f16(Instruction* conv)
332    {
333       add_label(label_f2f16);
334       instr = conv;
335    }
336 
337    bool is_f2f16() { return label & label_f2f16; }
338 
339    void set_undefined() { add_label(label_undefined); }
340 
341    bool is_undefined() { return label & label_undefined; }
342 
343    void set_vcc(Temp vcc_val)
344    {
345       add_label(label_vcc);
346       temp = vcc_val;
347    }
348 
349    bool is_vcc() { return label & label_vcc; }
350 
351    void set_b2f(Temp b2f_val)
352    {
353       add_label(label_b2f);
354       temp = b2f_val;
355    }
356 
357    bool is_b2f() { return label & label_b2f; }
358 
359    void set_add_sub(Instruction* add_sub_instr)
360    {
361       add_label(label_add_sub);
362       instr = add_sub_instr;
363    }
364 
365    bool is_add_sub() { return label & label_add_sub; }
366 
367    void set_bitwise(Instruction* bitwise_instr)
368    {
369       add_label(label_bitwise);
370       instr = bitwise_instr;
371    }
372 
373    bool is_bitwise() { return label & label_bitwise; }
374 
375    void set_uniform_bitwise() { add_label(label_uniform_bitwise); }
376 
377    bool is_uniform_bitwise() { return label & label_uniform_bitwise; }
378 
379    void set_minmax(Instruction* minmax_instr)
380    {
381       add_label(label_minmax);
382       instr = minmax_instr;
383    }
384 
385    bool is_minmax() { return label & label_minmax; }
386 
387    void set_vopc(Instruction* vopc_instr)
388    {
389       add_label(label_vopc);
390       instr = vopc_instr;
391    }
392 
393    bool is_vopc() { return label & label_vopc; }
394 
395    void set_scc_needed() { add_label(label_scc_needed); }
396 
397    bool is_scc_needed() { return label & label_scc_needed; }
398 
399    void set_scc_invert(Temp scc_inv)
400    {
401       add_label(label_scc_invert);
402       temp = scc_inv;
403    }
404 
405    bool is_scc_invert() { return label & label_scc_invert; }
406 
407    void set_uniform_bool(Temp uniform_bool)
408    {
409       add_label(label_uniform_bool);
410       temp = uniform_bool;
411    }
412 
413    bool is_uniform_bool() { return label & label_uniform_bool; }
414 
415    void set_b2i(Temp b2i_val)
416    {
417       add_label(label_b2i);
418       temp = b2i_val;
419    }
420 
421    bool is_b2i() { return label & label_b2i; }
422 
423    void set_usedef(Instruction* label_instr)
424    {
425       add_label(label_usedef);
426       instr = label_instr;
427    }
428 
429    bool is_usedef() { return label & label_usedef; }
430 
431    void set_vop3p(Instruction* vop3p_instr)
432    {
433       add_label(label_vop3p);
434       instr = vop3p_instr;
435    }
436 
437    bool is_vop3p() { return label & label_vop3p; }
438 
439    void set_fcanonicalize(Temp tmp)
440    {
441       add_label(label_fcanonicalize);
442       temp = tmp;
443    }
444 
445    bool is_fcanonicalize() { return label & label_fcanonicalize; }
446 
447    void set_canonicalized() { add_label(label_canonicalized); }
448 
449    bool is_canonicalized() { return label & label_canonicalized; }
450 
451    void set_f2f32(Instruction* cvt)
452    {
453       add_label(label_f2f32);
454       instr = cvt;
455    }
456 
457    bool is_f2f32() { return label & label_f2f32; }
458 
459    void set_extract(Instruction* extract)
460    {
461       add_label(label_extract);
462       instr = extract;
463    }
464 
465    bool is_extract() { return label & label_extract; }
466 
467    void set_insert(Instruction* insert)
468    {
469       add_label(label_insert);
470       instr = insert;
471    }
472 
473    bool is_insert() { return label & label_insert; }
474 
475    void set_dpp16(Instruction* mov)
476    {
477       add_label(label_dpp16);
478       instr = mov;
479    }
480 
481    void set_dpp8(Instruction* mov)
482    {
483       add_label(label_dpp8);
484       instr = mov;
485    }
486 
487    bool is_dpp() { return label & (label_dpp16 | label_dpp8); }
488    bool is_dpp16() { return label & label_dpp16; }
489    bool is_dpp8() { return label & label_dpp8; }
490 
491    void set_split(Instruction* split)
492    {
493       add_label(label_split);
494       instr = split;
495    }
496 
497    bool is_split() { return label & label_split; }
498 };
499 
500 struct opt_ctx {
501    Program* program;
502    float_mode fp_mode;
503    std::vector<aco_ptr<Instruction>> instructions;
504    ssa_info* info;
505    std::pair<uint32_t, Temp> last_literal;
506    std::vector<mad_info> mad_infos;
507    std::vector<uint16_t> uses;
508 };
509 
510 bool
511 can_use_VOP3(opt_ctx& ctx, const aco_ptr<Instruction>& instr)
512 {
513    if (instr->isVOP3())
514       return true;
515 
516    if (instr->isVOP3P())
517       return false;
518 
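   /* The VOP3 encoding cannot hold a literal constant before GFX10, so an
    * instruction that already reads a literal cannot be promoted there. */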
519    if (instr->operands.size() && instr->operands[0].isLiteral() && ctx.program->gfx_level < GFX10)
520       return false;
521 
522    if (instr->isDPP() || instr->isSDWA())
523       return false;
524 
525    return instr->opcode != aco_opcode::v_madmk_f32 && instr->opcode != aco_opcode::v_madak_f32 &&
526           instr->opcode != aco_opcode::v_madmk_f16 && instr->opcode != aco_opcode::v_madak_f16 &&
527           instr->opcode != aco_opcode::v_fmamk_f32 && instr->opcode != aco_opcode::v_fmaak_f32 &&
528           instr->opcode != aco_opcode::v_fmamk_f16 && instr->opcode != aco_opcode::v_fmaak_f16 &&
529           instr->opcode != aco_opcode::v_readlane_b32 &&
530           instr->opcode != aco_opcode::v_writelane_b32 &&
531           instr->opcode != aco_opcode::v_readfirstlane_b32;
532 }
533 
534 bool
535 pseudo_propagate_temp(opt_ctx& ctx, aco_ptr<Instruction>& instr, Temp temp, unsigned index)
536 {
537    if (instr->definitions.empty())
538       return false;
539 
540    const bool vgpr =
541       instr->opcode == aco_opcode::p_as_uniform ||
542       std::all_of(instr->definitions.begin(), instr->definitions.end(),
543                   [](const Definition& def) { return def.regClass().type() == RegType::vgpr; });
544 
545    /* don't propagate VGPRs into SGPR instructions */
546    if (temp.type() == RegType::vgpr && !vgpr)
547       return false;
548 
549    bool can_accept_sgpr =
550       ctx.program->gfx_level >= GFX9 ||
551       std::none_of(instr->definitions.begin(), instr->definitions.end(),
552                    [](const Definition& def) { return def.regClass().is_subdword(); });
553 
554    switch (instr->opcode) {
555    case aco_opcode::p_phi:
556    case aco_opcode::p_linear_phi:
557    case aco_opcode::p_parallelcopy:
558    case aco_opcode::p_create_vector:
559       if (temp.bytes() != instr->operands[index].bytes())
560          return false;
561       break;
562    case aco_opcode::p_extract_vector:
563    case aco_opcode::p_extract:
564       if (temp.type() == RegType::sgpr && !can_accept_sgpr)
565          return false;
566       break;
567    case aco_opcode::p_split_vector: {
568       if (temp.type() == RegType::sgpr && !can_accept_sgpr)
569          return false;
570       /* don't increase the vector size */
571       if (temp.bytes() > instr->operands[index].bytes())
572          return false;
573       /* We can decrease the vector size as smaller temporaries are only
574        * propagated by p_as_uniform instructions.
575        * If this propagation leads to invalid IR or hits the assertion below,
576        * it means that some undefined bytes within a dword are being accessed
577        * and a bug in instruction_selection is likely. */
578       int decrease = instr->operands[index].bytes() - temp.bytes();
579       while (decrease > 0) {
580          decrease -= instr->definitions.back().bytes();
581          instr->definitions.pop_back();
582       }
583       assert(decrease == 0);
584       break;
585    }
586    case aco_opcode::p_as_uniform:
587       if (temp.regClass() == instr->definitions[0].regClass())
588          instr->opcode = aco_opcode::p_parallelcopy;
589       break;
590    default: return false;
591    }
592 
593    instr->operands[index].setTemp(temp);
594    return true;
595 }
596 
597 /* This expects the DPP modifier to be removed. */
598 bool
599 can_apply_sgprs(opt_ctx& ctx, aco_ptr<Instruction>& instr)
600 {
601    if (instr->isSDWA() && ctx.program->gfx_level < GFX9)
602       return false;
603    return instr->opcode != aco_opcode::v_readfirstlane_b32 &&
604           instr->opcode != aco_opcode::v_readlane_b32 &&
605           instr->opcode != aco_opcode::v_readlane_b32_e64 &&
606           instr->opcode != aco_opcode::v_writelane_b32 &&
607           instr->opcode != aco_opcode::v_writelane_b32_e64 &&
608           instr->opcode != aco_opcode::v_permlane16_b32 &&
609           instr->opcode != aco_opcode::v_permlanex16_b32;
610 }
611 
612 void
613 to_VOP3(opt_ctx& ctx, aco_ptr<Instruction>& instr)
614 {
615    if (instr->isVOP3())
616       return;
617 
618    aco_ptr<Instruction> tmp = std::move(instr);
619    Format format = asVOP3(tmp->format);
620    instr.reset(create_instruction<VOP3_instruction>(tmp->opcode, format, tmp->operands.size(),
621                                                     tmp->definitions.size()));
622    std::copy(tmp->operands.cbegin(), tmp->operands.cend(), instr->operands.begin());
623    for (unsigned i = 0; i < instr->definitions.size(); i++) {
624       instr->definitions[i] = tmp->definitions[i];
625       if (instr->definitions[i].isTemp()) {
626          ssa_info& info = ctx.info[instr->definitions[i].tempId()];
627          if (info.label & instr_usedef_labels && info.instr == tmp.get())
628             info.instr = instr.get();
629       }
630    }
631    /* we don't need to update any instr_mod_labels because they either haven't
632     * been applied yet or this instruction isn't dead and so they've been ignored */
633 
634    instr->pass_flags = tmp->pass_flags;
635 }
636 
637 bool
638 is_operand_vgpr(Operand op)
639 {
640    return op.isTemp() && op.getTemp().type() == RegType::vgpr;
641 }
642 
643 void
644 to_SDWA(opt_ctx& ctx, aco_ptr<Instruction>& instr)
645 {
646    aco_ptr<Instruction> tmp = convert_to_SDWA(ctx.program->gfx_level, instr);
647    if (!tmp)
648       return;
649 
650    for (unsigned i = 0; i < instr->definitions.size(); i++) {
651       ssa_info& info = ctx.info[instr->definitions[i].tempId()];
652       if (info.label & instr_labels && info.instr == tmp.get())
653          info.instr = instr.get();
654    }
655 }
656 
657 /* only covers special cases */
658 bool
659 alu_can_accept_constant(aco_opcode opcode, unsigned operand)
660 {
661    switch (opcode) {
662    case aco_opcode::v_interp_p2_f32:
663    case aco_opcode::v_mac_f32:
664    case aco_opcode::v_writelane_b32:
665    case aco_opcode::v_writelane_b32_e64:
666    case aco_opcode::v_cndmask_b32: return operand != 2;
667    case aco_opcode::s_addk_i32:
668    case aco_opcode::s_mulk_i32:
669    case aco_opcode::p_wqm:
670    case aco_opcode::p_extract_vector:
671    case aco_opcode::p_split_vector:
672    case aco_opcode::v_readlane_b32:
673    case aco_opcode::v_readlane_b32_e64:
674    case aco_opcode::v_readfirstlane_b32:
675    case aco_opcode::p_extract:
676    case aco_opcode::p_insert: return operand != 0;
677    default: return true;
678    }
679 }
680 
681 bool
682 valu_can_accept_vgpr(aco_ptr<Instruction>& instr, unsigned operand)
683 {
684    if (instr->opcode == aco_opcode::v_readlane_b32 ||
685        instr->opcode == aco_opcode::v_readlane_b32_e64 ||
686        instr->opcode == aco_opcode::v_writelane_b32 ||
687        instr->opcode == aco_opcode::v_writelane_b32_e64)
688       return operand != 1;
689    if (instr->opcode == aco_opcode::v_permlane16_b32 ||
690        instr->opcode == aco_opcode::v_permlanex16_b32)
691       return operand == 0;
692    return true;
693 }
694 
695 /* check constant bus and literal limitations */
696 bool
697 check_vop3_operands(opt_ctx& ctx, unsigned num_operands, Operand* operands)
698 {
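   /* VALU instructions can read at most one constant-bus source (SGPR or
    * literal) per instruction before GFX10, and two from GFX10 onwards. */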
699    int limit = ctx.program->gfx_level >= GFX10 ? 2 : 1;
700    Operand literal32(s1);
701    Operand literal64(s2);
702    unsigned num_sgprs = 0;
703    unsigned sgpr[] = {0, 0};
704 
705    for (unsigned i = 0; i < num_operands; i++) {
706       Operand op = operands[i];
707 
708       if (op.hasRegClass() && op.regClass().type() == RegType::sgpr) {
709          /* two reads of the same SGPR count as 1 to the limit */
710          if (op.tempId() != sgpr[0] && op.tempId() != sgpr[1]) {
711             if (num_sgprs < 2)
712                sgpr[num_sgprs++] = op.tempId();
713             limit--;
714             if (limit < 0)
715                return false;
716          }
717       } else if (op.isLiteral()) {
718          if (ctx.program->gfx_level < GFX10)
719             return false;
720 
721          if (!literal32.isUndefined() && literal32.constantValue() != op.constantValue())
722             return false;
723          if (!literal64.isUndefined() && literal64.constantValue() != op.constantValue())
724             return false;
725 
726          /* Any number of 32-bit literals counts as only 1 to the limit. Same
727           * (but separately) for 64-bit literals. */
728          if (op.size() == 1 && literal32.isUndefined()) {
729             limit--;
730             literal32 = op;
731          } else if (op.size() == 2 && literal64.isUndefined()) {
732             limit--;
733             literal64 = op;
734          }
735 
736          if (limit < 0)
737             return false;
738       }
739    }
740 
741    return true;
742 }
743 
744 bool
745 parse_base_offset(opt_ctx& ctx, Instruction* instr, unsigned op_index, Temp* base, uint32_t* offset,
746                   bool prevent_overflow)
747 {
748    Operand op = instr->operands[op_index];
749 
750    if (!op.isTemp())
751       return false;
752    Temp tmp = op.getTemp();
753    if (!ctx.info[tmp.id()].is_add_sub())
754       return false;
755 
756    Instruction* add_instr = ctx.info[tmp.id()].instr;
757 
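   /* 'mask' selects which operands may be folded into the constant offset:
    * both operands for additions, only the subtrahend (negated) for
    * subtractions. */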
758    unsigned mask = 0x3;
759    bool is_sub = false;
760    switch (add_instr->opcode) {
761    case aco_opcode::v_add_u32:
762    case aco_opcode::v_add_co_u32:
763    case aco_opcode::v_add_co_u32_e64:
764    case aco_opcode::s_add_i32:
765    case aco_opcode::s_add_u32: break;
766    case aco_opcode::v_sub_u32:
767    case aco_opcode::v_sub_i32:
768    case aco_opcode::v_sub_co_u32:
769    case aco_opcode::v_sub_co_u32_e64:
770    case aco_opcode::s_sub_u32:
771    case aco_opcode::s_sub_i32:
772       mask = 0x2;
773       is_sub = true;
774       break;
775    case aco_opcode::v_subrev_u32:
776    case aco_opcode::v_subrev_co_u32:
777    case aco_opcode::v_subrev_co_u32_e64:
778       mask = 0x1;
779       is_sub = true;
780       break;
781    default: return false;
782    }
783    if (prevent_overflow && !add_instr->definitions[0].isNUW())
784       return false;
785 
786    if (add_instr->usesModifiers())
787       return false;
788 
789    u_foreach_bit (i, mask) {
790       if (add_instr->operands[i].isConstant()) {
791          *offset = add_instr->operands[i].constantValue() * (uint32_t)(is_sub ? -1 : 1);
792       } else if (add_instr->operands[i].isTemp() &&
793                  ctx.info[add_instr->operands[i].tempId()].is_constant_or_literal(32)) {
794          *offset = ctx.info[add_instr->operands[i].tempId()].val * (uint32_t)(is_sub ? -1 : 1);
795       } else {
796          continue;
797       }
798       if (!add_instr->operands[!i].isTemp())
799          continue;
800 
801       uint32_t offset2 = 0;
802       if (parse_base_offset(ctx, add_instr, !i, base, &offset2, prevent_overflow)) {
803          *offset += offset2;
804       } else {
805          *base = add_instr->operands[!i].getTemp();
806       }
807       return true;
808    }
809 
810    return false;
811 }
812 
813 void
814 skip_smem_offset_align(opt_ctx& ctx, SMEM_instruction* smem)
815 {
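   /* Detect whether a trailing soffset operand (SOE) is present: loads carry
    * {base, offset[, soffset]} and stores {base, offset, data[, soffset]}. */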
816    bool soe = smem->operands.size() >= (!smem->definitions.empty() ? 3 : 4);
817    if (soe && !smem->operands[1].isConstant())
818       return;
819    /* We don't need to check the constant offset because the address seems to be calculated with
820     * (offset&-4 + const_offset&-4), not (offset+const_offset)&-4.
821     */
822 
823    Operand& op = smem->operands[soe ? smem->operands.size() - 1 : 1];
824    if (!op.isTemp() || !ctx.info[op.tempId()].is_bitwise())
825       return;
826 
827    Instruction* bitwise_instr = ctx.info[op.tempId()].instr;
828    if (bitwise_instr->opcode != aco_opcode::s_and_b32)
829       return;
830 
831    if (bitwise_instr->operands[0].constantEquals(-4) &&
832        bitwise_instr->operands[1].isOfType(op.regClass().type()))
833       op.setTemp(bitwise_instr->operands[1].getTemp());
834    else if (bitwise_instr->operands[1].constantEquals(-4) &&
835             bitwise_instr->operands[0].isOfType(op.regClass().type()))
836       op.setTemp(bitwise_instr->operands[0].getTemp());
837 }
838 
839 void
840 smem_combine(opt_ctx& ctx, aco_ptr<Instruction>& instr)
841 {
842    /* skip &-4 before offset additions: load((a + 16) & -4, 0) */
843    if (!instr->operands.empty())
844       skip_smem_offset_align(ctx, &instr->smem());
845 
846    /* propagate constants and combine additions */
847    if (!instr->operands.empty() && instr->operands[1].isTemp()) {
848       SMEM_instruction& smem = instr->smem();
849       ssa_info info = ctx.info[instr->operands[1].tempId()];
850 
851       Temp base;
852       uint32_t offset;
853       bool prevent_overflow = smem.operands[0].size() > 2 || smem.prevent_overflow;
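      /* The bounds below reflect the per-generation range of the SMEM
       * immediate offset field; larger offsets must stay in a register. */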
854       if (info.is_constant_or_literal(32) &&
855           ((ctx.program->gfx_level == GFX6 && info.val <= 0x3FF) ||
856            (ctx.program->gfx_level == GFX7 && info.val <= 0xFFFFFFFF) ||
857            (ctx.program->gfx_level >= GFX8 && info.val <= 0xFFFFF))) {
858          instr->operands[1] = Operand::c32(info.val);
859       } else if (parse_base_offset(ctx, instr.get(), 1, &base, &offset, prevent_overflow) &&
860                  base.regClass() == s1 && offset <= 0xFFFFF && ctx.program->gfx_level >= GFX9 &&
861                  offset % 4u == 0) {
862          bool soe = smem.operands.size() >= (!smem.definitions.empty() ? 3 : 4);
863          if (soe) {
864             if (ctx.info[smem.operands.back().tempId()].is_constant_or_literal(32) &&
865                 ctx.info[smem.operands.back().tempId()].val == 0) {
866                smem.operands[1] = Operand::c32(offset);
867                smem.operands.back() = Operand(base);
868             }
869          } else {
870             SMEM_instruction* new_instr = create_instruction<SMEM_instruction>(
871                smem.opcode, Format::SMEM, smem.operands.size() + 1, smem.definitions.size());
872             new_instr->operands[0] = smem.operands[0];
873             new_instr->operands[1] = Operand::c32(offset);
874             if (smem.definitions.empty())
875                new_instr->operands[2] = smem.operands[2];
876             new_instr->operands.back() = Operand(base);
877             if (!smem.definitions.empty())
878                new_instr->definitions[0] = smem.definitions[0];
879             new_instr->sync = smem.sync;
880             new_instr->glc = smem.glc;
881             new_instr->dlc = smem.dlc;
882             new_instr->nv = smem.nv;
883             new_instr->disable_wqm = smem.disable_wqm;
884             instr.reset(new_instr);
885          }
886       }
887    }
888 
889    /* skip &-4 after offset additions: load(a & -4, 16) */
890    if (!instr->operands.empty())
891       skip_smem_offset_align(ctx, &instr->smem());
892 }
893 
894 unsigned
895 get_operand_size(aco_ptr<Instruction>& instr, unsigned index)
896 {
897    if (instr->isPseudo())
898       return instr->operands[index].bytes() * 8u;
899    else if (instr->opcode == aco_opcode::v_mad_u64_u32 ||
900             instr->opcode == aco_opcode::v_mad_i64_i32)
901       return index == 2 ? 64 : 32;
902    else if (instr->opcode == aco_opcode::v_fma_mix_f32 ||
903             instr->opcode == aco_opcode::v_fma_mixlo_f16)
904       return instr->vop3p().opsel_hi & (1u << index) ? 16 : 32;
905    else if (instr->isVALU() || instr->isSALU())
906       return instr_info.operand_size[(int)instr->opcode];
907    else
908       return 0;
909 }
910 
911 Operand
912 get_constant_op(opt_ctx& ctx, ssa_info info, uint32_t bits)
913 {
914    if (bits == 64)
915       return Operand::c32_or_c64(info.val, true);
916    return Operand::get_const(ctx.program->gfx_level, info.val, bits / 8u);
917 }
918 
919 void
920 propagate_constants_vop3p(opt_ctx& ctx, aco_ptr<Instruction>& instr, ssa_info& info, unsigned i)
921 {
922    if (!info.is_constant_or_literal(32))
923       return;
924 
925    assert(instr->operands[i].isTemp());
926    unsigned bits = get_operand_size(instr, i);
927    if (info.is_constant(bits)) {
928       instr->operands[i] = get_constant_op(ctx, info, bits);
929       return;
930    }
931 
932    /* The accumulation operand of dot product instructions ignores opsel. */
933    bool cannot_use_opsel =
934       (instr->opcode == aco_opcode::v_dot4_i32_i8 || instr->opcode == aco_opcode::v_dot2_i32_i16 ||
935        instr->opcode == aco_opcode::v_dot4_u32_u8 || instr->opcode == aco_opcode::v_dot2_u32_u16) &&
936       i == 2;
937    if (cannot_use_opsel)
938       return;
939 
940    /* try to fold inline constants */
941    VOP3P_instruction* vop3p = &instr->vop3p();
942    bool opsel_lo = (vop3p->opsel_lo >> i) & 1;
943    bool opsel_hi = (vop3p->opsel_hi >> i) & 1;
944 
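   /* For each 16-bit half the instruction actually reads (per opsel), try to
    * represent it as an inline constant, possibly by sign-extending it or by
    * shifting it into the high half and adjusting opsel. */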
945    Operand const_op[2];
946    bool const_opsel[2] = {false, false};
947    for (unsigned j = 0; j < 2; j++) {
948       if ((unsigned)opsel_lo != j && (unsigned)opsel_hi != j)
949          continue; /* this half is unused */
950 
951       uint16_t val = info.val >> (j ? 16 : 0);
952       Operand op = Operand::get_const(ctx.program->gfx_level, val, bits / 8u);
953       if (bits == 32 && op.isLiteral()) /* try sign extension */
954          op = Operand::get_const(ctx.program->gfx_level, val | 0xffff0000, 4);
955       if (bits == 32 && op.isLiteral()) { /* try shifting left */
956          op = Operand::get_const(ctx.program->gfx_level, val << 16, 4);
957          const_opsel[j] = true;
958       }
959       if (op.isLiteral())
960          return;
961       const_op[j] = op;
962    }
963 
964    Operand const_lo = const_op[0];
965    Operand const_hi = const_op[1];
966    bool const_lo_opsel = const_opsel[0];
967    bool const_hi_opsel = const_opsel[1];
968 
969    if (opsel_lo == opsel_hi) {
970       /* use the single 16bit value */
971       instr->operands[i] = opsel_lo ? const_hi : const_lo;
972 
973       /* opsel must point the same for both halves */
974       opsel_lo = opsel_lo ? const_hi_opsel : const_lo_opsel;
975       opsel_hi = opsel_lo;
976    } else if (const_lo == const_hi) {
977       /* both constants are the same */
978       instr->operands[i] = const_lo;
979 
980       /* opsel must point the same for both halves */
981       opsel_lo = const_lo_opsel;
982       opsel_hi = const_lo_opsel;
983    } else if (const_lo.constantValue16(const_lo_opsel) ==
984               const_hi.constantValue16(!const_hi_opsel)) {
985       instr->operands[i] = const_hi;
986 
987       /* redirect opsel selection */
988       opsel_lo = opsel_lo ? const_hi_opsel : !const_hi_opsel;
989       opsel_hi = opsel_hi ? const_hi_opsel : !const_hi_opsel;
990    } else if (const_hi.constantValue16(const_hi_opsel) ==
991               const_lo.constantValue16(!const_lo_opsel)) {
992       instr->operands[i] = const_lo;
993 
994       /* redirect opsel selection */
995       opsel_lo = opsel_lo ? !const_lo_opsel : const_lo_opsel;
996       opsel_hi = opsel_hi ? !const_lo_opsel : const_lo_opsel;
997    } else if (bits == 16 && const_lo.constantValue() == (const_hi.constantValue() ^ (1 << 15))) {
998       assert(const_lo_opsel == false && const_hi_opsel == false);
999 
1000       /* const_lo == -const_hi */
1001       if (!instr_info.can_use_input_modifiers[(int)instr->opcode])
1002          return;
1003 
1004       instr->operands[i] = Operand::c16(const_lo.constantValue() & 0x7FFF);
1005       bool neg_lo = const_lo.constantValue() & (1 << 15);
1006       vop3p->neg_lo[i] ^= opsel_lo ^ neg_lo;
1007       vop3p->neg_hi[i] ^= opsel_hi ^ neg_lo;
1008 
1009       /* opsel must point to lo for both operands */
1010       opsel_lo = false;
1011       opsel_hi = false;
1012    }
1013 
1014    vop3p->opsel_lo = opsel_lo ? (vop3p->opsel_lo | (1 << i)) : (vop3p->opsel_lo & ~(1 << i));
1015    vop3p->opsel_hi = opsel_hi ? (vop3p->opsel_hi | (1 << i)) : (vop3p->opsel_hi & ~(1 << i));
1016 }
1017 
1018 bool
1019 fixed_to_exec(Operand op)
1020 {
1021    return op.isFixed() && op.physReg() == exec;
1022 }
1023 
1024 SubdwordSel
1025 parse_extract(Instruction* instr)
1026 {
1027    if (instr->opcode == aco_opcode::p_extract) {
1028       unsigned size = instr->operands[2].constantValue() / 8;
1029       unsigned offset = instr->operands[1].constantValue() * size;
1030       bool sext = instr->operands[3].constantEquals(1);
1031       return SubdwordSel(size, offset, sext);
1032    } else if (instr->opcode == aco_opcode::p_insert && instr->operands[1].constantEquals(0)) {
1033       return instr->operands[2].constantEquals(8) ? SubdwordSel::ubyte : SubdwordSel::uword;
1034    } else if (instr->opcode == aco_opcode::p_extract_vector) {
1035       unsigned size = instr->definitions[0].bytes();
1036       unsigned offset = instr->operands[1].constantValue() * size;
1037       if (size <= 2)
1038          return SubdwordSel(size, offset, false);
1039    } else if (instr->opcode == aco_opcode::p_split_vector) {
1040       assert(instr->operands[0].bytes() == 4 && instr->definitions[1].bytes() == 2);
1041       return SubdwordSel(2, 2, false);
1042    }
1043 
1044    return SubdwordSel();
1045 }
1046 
1047 SubdwordSel
1048 parse_insert(Instruction* instr)
1049 {
1050    if (instr->opcode == aco_opcode::p_extract && instr->operands[3].constantEquals(0) &&
1051        instr->operands[1].constantEquals(0)) {
1052       return instr->operands[2].constantEquals(8) ? SubdwordSel::ubyte : SubdwordSel::uword;
1053    } else if (instr->opcode == aco_opcode::p_insert) {
1054       unsigned size = instr->operands[2].constantValue() / 8;
1055       unsigned offset = instr->operands[1].constantValue() * size;
1056       return SubdwordSel(size, offset, false);
1057    } else {
1058       return SubdwordSel();
1059    }
1060 }
1061 
1062 bool
1063 can_apply_extract(opt_ctx& ctx, aco_ptr<Instruction>& instr, unsigned idx, ssa_info& info)
1064 {
1065    if (idx >= 2)
1066       return false;
1067 
1068    Temp tmp = info.instr->operands[0].getTemp();
1069    SubdwordSel sel = parse_extract(info.instr);
1070 
1071    if (!sel) {
1072       return false;
1073    } else if (sel.size() == 4) {
1074       return true;
1075    } else if (instr->opcode == aco_opcode::v_cvt_f32_u32 && sel.size() == 1 && !sel.sign_extend()) {
1076       return true;
1077    } else if (can_use_SDWA(ctx.program->gfx_level, instr, true) &&
1078               (tmp.type() == RegType::vgpr || ctx.program->gfx_level >= GFX9)) {
1079       if (instr->isSDWA() && instr->sdwa().sel[idx] != SubdwordSel::dword)
1080          return false;
1081       return true;
1082    } else if (instr->isVOP3() && sel.size() == 2 &&
1083               can_use_opsel(ctx.program->gfx_level, instr->opcode, idx) &&
1084               !(instr->vop3().opsel & (1 << idx))) {
1085       return true;
1086    } else if (instr->opcode == aco_opcode::p_extract) {
1087       SubdwordSel instrSel = parse_extract(instr.get());
1088 
1089       /* the outer offset must be within the extracted range */
1090       if (instrSel.offset() >= sel.size())
1091          return false;
1092 
1093       /* don't remove the sign-extension when increasing the size further */
1094       if (instrSel.size() > sel.size() && !instrSel.sign_extend() && sel.sign_extend())
1095          return false;
1096 
1097       return true;
1098    }
1099 
1100    return false;
1101 }
1102 
1103 /* Combine a p_extract (or p_insert, in some cases) instruction with instr.
1104  * instr(p_extract(...)) -> instr()
1105  */
1106 void
1107 apply_extract(opt_ctx& ctx, aco_ptr<Instruction>& instr, unsigned idx, ssa_info& info)
1108 {
1109    Temp tmp = info.instr->operands[0].getTemp();
1110    SubdwordSel sel = parse_extract(info.instr);
1111    assert(sel);
1112 
1113    instr->operands[idx].set16bit(false);
1114    instr->operands[idx].set24bit(false);
1115 
1116    ctx.info[tmp.id()].label &= ~label_insert;
1117 
1118    if (sel.size() == 4) {
1119       /* full dword selection */
1120    } else if (instr->opcode == aco_opcode::v_cvt_f32_u32 && sel.size() == 1 && !sel.sign_extend()) {
1121       switch (sel.offset()) {
1122       case 0: instr->opcode = aco_opcode::v_cvt_f32_ubyte0; break;
1123       case 1: instr->opcode = aco_opcode::v_cvt_f32_ubyte1; break;
1124       case 2: instr->opcode = aco_opcode::v_cvt_f32_ubyte2; break;
1125       case 3: instr->opcode = aco_opcode::v_cvt_f32_ubyte3; break;
1126       }
1127    } else if (instr->opcode == aco_opcode::v_lshlrev_b32 && instr->operands[0].isConstant() &&
1128               sel.offset() == 0 &&
1129               ((sel.size() == 2 && instr->operands[0].constantValue() >= 16u) ||
1130                (sel.size() == 1 && instr->operands[0].constantValue() >= 24u))) {
1131       /* The undesirable upper bits are already shifted out. */
1132       return;
1133    } else if (can_use_SDWA(ctx.program->gfx_level, instr, true) &&
1134               (tmp.type() == RegType::vgpr || ctx.program->gfx_level >= GFX9)) {
1135       to_SDWA(ctx, instr);
1136       static_cast<SDWA_instruction*>(instr.get())->sel[idx] = sel;
1137    } else if (instr->isVOP3()) {
1138       if (sel.offset())
1139          instr->vop3().opsel |= 1 << idx;
1140    } else if (instr->opcode == aco_opcode::p_extract) {
1141       SubdwordSel instrSel = parse_extract(instr.get());
1142 
1143       unsigned size = std::min(sel.size(), instrSel.size());
1144       unsigned offset = sel.offset() + instrSel.offset();
1145       unsigned sign_extend =
1146          instrSel.sign_extend() && (sel.sign_extend() || instrSel.size() <= sel.size());
1147 
1148       instr->operands[1] = Operand::c32(offset / size);
1149       instr->operands[2] = Operand::c32(size * 8u);
1150       instr->operands[3] = Operand::c32(sign_extend);
1151       return;
1152    }
1153 
1154    /* Output modifier, label_vopc and label_f2f32 seem to be the only ones worth keeping at the
1155     * moment
1156     */
1157    for (Definition& def : instr->definitions)
1158       ctx.info[def.tempId()].label &= (label_vopc | label_f2f32 | instr_mod_labels);
1159 }
1160 
1161 void
1162 check_sdwa_extract(opt_ctx& ctx, aco_ptr<Instruction>& instr)
1163 {
1164    for (unsigned i = 0; i < instr->operands.size(); i++) {
1165       Operand op = instr->operands[i];
1166       if (!op.isTemp())
1167          continue;
1168       ssa_info& info = ctx.info[op.tempId()];
1169       if (info.is_extract() && (info.instr->operands[0].getTemp().type() == RegType::vgpr ||
1170                                 op.getTemp().type() == RegType::sgpr)) {
1171          if (!can_apply_extract(ctx, instr, i, info))
1172             info.label &= ~label_extract;
1173       }
1174    }
1175 }
1176 
1177 bool
1178 does_fp_op_flush_denorms(opt_ctx& ctx, aco_opcode op)
1179 {
1180    if (ctx.program->gfx_level <= GFX8) {
1181       switch (op) {
1182       case aco_opcode::v_min_f32:
1183       case aco_opcode::v_max_f32:
1184       case aco_opcode::v_med3_f32:
1185       case aco_opcode::v_min3_f32:
1186       case aco_opcode::v_max3_f32:
1187       case aco_opcode::v_min_f16:
1188       case aco_opcode::v_max_f16: return false;
1189       default: break;
1190       }
1191    }
1192    return op != aco_opcode::v_cndmask_b32;
1193 }
1194 
1195 bool
1196 can_eliminate_fcanonicalize(opt_ctx& ctx, aco_ptr<Instruction>& instr, Temp tmp)
1197 {
1198    float_mode* fp = &ctx.fp_mode;
1199    if (ctx.info[tmp.id()].is_canonicalized() ||
1200        (tmp.bytes() == 4 ? fp->denorm32 : fp->denorm16_64) == fp_denorm_keep)
1201       return true;
1202 
1203    aco_opcode op = instr->opcode;
1204    return instr_info.can_use_input_modifiers[(int)op] && does_fp_op_flush_denorms(ctx, op);
1205 }
1206 
1207 bool
1208 can_eliminate_and_exec(opt_ctx& ctx, Temp tmp, unsigned pass_flags)
1209 {
1210    if (ctx.info[tmp.id()].is_vopc()) {
1211       Instruction* vopc_instr = ctx.info[tmp.id()].instr;
1212       /* Remove superfluous s_and when the VOPC instruction uses the same exec and thus
1213        * already produces the same result */
1214       return vopc_instr->pass_flags == pass_flags;
1215    }
1216    if (ctx.info[tmp.id()].is_bitwise()) {
1217       Instruction* instr = ctx.info[tmp.id()].instr;
1218       if (instr->operands.size() != 2 || instr->pass_flags != pass_flags)
1219          return false;
1220       if (!(instr->operands[0].isTemp() && instr->operands[1].isTemp()))
1221          return false;
1222       return can_eliminate_and_exec(ctx, instr->operands[0].getTemp(), pass_flags) &&
1223              can_eliminate_and_exec(ctx, instr->operands[1].getTemp(), pass_flags);
1224    }
1225    return false;
1226 }
1227 
1228 bool
1229 is_copy_label(opt_ctx& ctx, aco_ptr<Instruction>& instr, ssa_info& info)
1230 {
1231    return info.is_temp() ||
1232           (info.is_fcanonicalize() && can_eliminate_fcanonicalize(ctx, instr, info.temp));
1233 }
1234 
1235 bool
1236 is_op_canonicalized(opt_ctx& ctx, Operand op)
1237 {
1238    float_mode* fp = &ctx.fp_mode;
1239    if ((op.isTemp() && ctx.info[op.tempId()].is_canonicalized()) ||
1240        (op.bytes() == 4 ? fp->denorm32 : fp->denorm16_64) == fp_denorm_keep)
1241       return true;
1242 
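   /* A constant is already canonical if it is +/-0.0 or not a denormal, i.e.
    * its exponent field is non-zero (the mantissa is 10 bits for fp16 and
    * 23 bits for fp32, hence the 0x3ff and 0x7fffff thresholds). */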
1243    if (op.isConstant() || (op.isTemp() && ctx.info[op.tempId()].is_constant_or_literal(32))) {
1244       uint32_t val = op.isTemp() ? ctx.info[op.tempId()].val : op.constantValue();
1245       if (op.bytes() == 2)
1246          return (val & 0x7fff) == 0 || (val & 0x7fff) > 0x3ff;
1247       else if (op.bytes() == 4)
1248          return (val & 0x7fffffff) == 0 || (val & 0x7fffffff) > 0x7fffff;
1249    }
1250    return false;
1251 }
1252 
1253 bool
1254 is_scratch_offset_valid(opt_ctx& ctx, Instruction* instr, int32_t offset)
1255 {
1256    bool negative_unaligned_scratch_offset_bug = ctx.program->gfx_level == GFX10;
1257    int32_t min = ctx.program->dev.scratch_global_offset_min;
1258    int32_t max = ctx.program->dev.scratch_global_offset_max;
1259 
1260    bool has_vgpr_offset = instr && !instr->operands[0].isUndefined();
1261    if (negative_unaligned_scratch_offset_bug && has_vgpr_offset && offset < 0 && offset % 4)
1262       return false;
1263 
1264    return offset >= min && offset <= max;
1265 }
1266 
1267 void
1268 label_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
1269 {
1270    if (instr->isSALU() || instr->isVALU() || instr->isPseudo()) {
1271       ASSERTED bool all_const = false;
1272       for (Operand& op : instr->operands)
1273          all_const =
1274             all_const && (!op.isTemp() || ctx.info[op.tempId()].is_constant_or_literal(32));
1275       perfwarn(ctx.program, all_const, "All instruction operands are constant", instr.get());
1276 
1277       ASSERTED bool is_copy = instr->opcode == aco_opcode::s_mov_b32 ||
1278                               instr->opcode == aco_opcode::s_mov_b64 ||
1279                               instr->opcode == aco_opcode::v_mov_b32;
1280       perfwarn(ctx.program, is_copy && !instr->usesModifiers(), "Use p_parallelcopy instead",
1281                instr.get());
1282    }
1283 
1284    if (instr->isSMEM())
1285       smem_combine(ctx, instr);
1286 
1287    for (unsigned i = 0; i < instr->operands.size(); i++) {
1288       if (!instr->operands[i].isTemp())
1289          continue;
1290 
1291       ssa_info info = ctx.info[instr->operands[i].tempId()];
1292       /* propagate undef */
1293       if (info.is_undefined() && is_phi(instr))
1294          instr->operands[i] = Operand(instr->operands[i].regClass());
1295       /* propagate reg->reg of same type */
1296       while (info.is_temp() && info.temp.regClass() == instr->operands[i].getTemp().regClass()) {
1297          instr->operands[i].setTemp(ctx.info[instr->operands[i].tempId()].temp);
1298          info = ctx.info[info.temp.id()];
1299       }
1300 
1301       /* PSEUDO: propagate temporaries */
1302       if (instr->isPseudo()) {
1303          while (info.is_temp()) {
1304             pseudo_propagate_temp(ctx, instr, info.temp, i);
1305             info = ctx.info[info.temp.id()];
1306          }
1307       }
1308 
1309       /* SALU / PSEUDO: propagate inline constants */
1310       if (instr->isSALU() || instr->isPseudo()) {
1311          unsigned bits = get_operand_size(instr, i);
1312          if ((info.is_constant(bits) || (info.is_literal(bits) && instr->isPseudo())) &&
1313              !instr->operands[i].isFixed() && alu_can_accept_constant(instr->opcode, i)) {
1314             instr->operands[i] = get_constant_op(ctx, info, bits);
1315             continue;
1316          }
1317       }
1318 
1319       /* VALU: propagate neg, abs & inline constants */
1320       else if (instr->isVALU()) {
1321          if (is_copy_label(ctx, instr, info) && info.temp.type() == RegType::vgpr &&
1322              valu_can_accept_vgpr(instr, i)) {
1323             instr->operands[i].setTemp(info.temp);
1324             info = ctx.info[info.temp.id()];
1325          }
1326          /* applying SGPRs to VOP1 doesn't increase code size and DCE is helped by doing it earlier */
1327          if (info.is_temp() && info.temp.type() == RegType::sgpr && can_apply_sgprs(ctx, instr) &&
1328              instr->operands.size() == 1) {
1329             instr->format = withoutDPP(instr->format);
1330             instr->operands[i].setTemp(info.temp);
1331             info = ctx.info[info.temp.id()];
1332          }
1333 
1334          /* for instructions other than v_cndmask_b32, the size of the instruction should match the
1335           * operand size */
1336          unsigned can_use_mod =
1337             instr->opcode != aco_opcode::v_cndmask_b32 || instr->operands[i].getTemp().bytes() == 4;
1338          can_use_mod = can_use_mod && instr_info.can_use_input_modifiers[(int)instr->opcode];
1339 
1340          if (instr->isSDWA())
1341             can_use_mod = can_use_mod && instr->sdwa().sel[i].size() == 4;
1342          else
1343             can_use_mod = can_use_mod && (instr->isDPP16() || can_use_VOP3(ctx, instr));
1344 
1345          unsigned bits = get_operand_size(instr, i);
1346          bool mod_bitsize_compat = instr->operands[i].bytes() * 8 == bits;
1347 
1348          if (info.is_neg() && instr->opcode == aco_opcode::v_add_f32 && mod_bitsize_compat) {
1349             instr->opcode = i ? aco_opcode::v_sub_f32 : aco_opcode::v_subrev_f32;
1350             instr->operands[i].setTemp(info.temp);
1351          } else if (info.is_neg() && instr->opcode == aco_opcode::v_add_f16 && mod_bitsize_compat) {
1352             instr->opcode = i ? aco_opcode::v_sub_f16 : aco_opcode::v_subrev_f16;
1353             instr->operands[i].setTemp(info.temp);
1354          } else if (info.is_neg() && can_use_mod && mod_bitsize_compat &&
1355                     can_eliminate_fcanonicalize(ctx, instr, info.temp)) {
1356             if (!instr->isDPP() && !instr->isSDWA())
1357                to_VOP3(ctx, instr);
1358             instr->operands[i].setTemp(info.temp);
1359             if (instr->isDPP16() && !instr->dpp16().abs[i])
1360                instr->dpp16().neg[i] = true;
1361             else if (instr->isSDWA() && !instr->sdwa().abs[i])
1362                instr->sdwa().neg[i] = true;
1363             else if (instr->isVOP3() && !instr->vop3().abs[i])
1364                instr->vop3().neg[i] = true;
1365          }
1366          if (info.is_abs() && can_use_mod && mod_bitsize_compat &&
1367              can_eliminate_fcanonicalize(ctx, instr, info.temp)) {
1368             if (!instr->isDPP() && !instr->isSDWA())
1369                to_VOP3(ctx, instr);
1370             instr->operands[i] = Operand(info.temp);
1371             if (instr->isDPP16())
1372                instr->dpp16().abs[i] = true;
1373             else if (instr->isSDWA())
1374                instr->sdwa().abs[i] = true;
1375             else
1376                instr->vop3().abs[i] = true;
1377             continue;
1378          }
1379 
1380          if (instr->isVOP3P()) {
1381             propagate_constants_vop3p(ctx, instr, info, i);
1382             continue;
1383          }
1384 
1385          if (info.is_constant(bits) && alu_can_accept_constant(instr->opcode, i) &&
1386              (!instr->isSDWA() || ctx.program->gfx_level >= GFX9)) {
1387             Operand op = get_constant_op(ctx, info, bits);
1388             perfwarn(ctx.program, instr->opcode == aco_opcode::v_cndmask_b32 && i == 2,
1389                      "v_cndmask_b32 with a constant selector", instr.get());
1390             if (i == 0 || instr->isSDWA() || instr->opcode == aco_opcode::v_readlane_b32 ||
1391                 instr->opcode == aco_opcode::v_writelane_b32) {
1392                instr->format = withoutDPP(instr->format);
1393                instr->operands[i] = op;
1394                continue;
1395             } else if (!instr->isVOP3() && can_swap_operands(instr, &instr->opcode)) {
1396                instr->operands[i] = instr->operands[0];
1397                instr->operands[0] = op;
1398                continue;
1399             } else if (can_use_VOP3(ctx, instr)) {
1400                to_VOP3(ctx, instr);
1401                instr->operands[i] = op;
1402                continue;
1403             }
1404          }
1405       }
1406 
1407       /* MUBUF: propagate constants and combine additions */
1408       else if (instr->isMUBUF()) {
1409          MUBUF_instruction& mubuf = instr->mubuf();
1410          Temp base;
1411          uint32_t offset;
1412          while (info.is_temp())
1413             info = ctx.info[info.temp.id()];
1414 
1415          /* According to AMDGPUDAGToDAGISel::SelectMUBUFScratchOffen(), vaddr
1416           * overflow for scratch accesses works only on GFX9+, and saddr overflow
1417           * never works. Since swizzling is the only thing that separates
1418           * scratch accesses from other accesses, and swizzling significantly
1419           * changes how addressing works, this probably applies to swizzled
1420           * MUBUF accesses. */
1421          bool vaddr_prevent_overflow = mubuf.swizzled && ctx.program->gfx_level < GFX9;
1422 
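         /* Illustrative sketch (hypothetical temps and offsets) of the folding
          * below, assuming parse_base_offset() recognizes the addition:
          *    %vaddr = v_add_u32 %base, 64
          *    buffer_load_dword %dst, %rsrc, %vaddr, 0 offen offset:16
          * can become
          *    buffer_load_dword %dst, %rsrc, %base, 0 offen offset:80
          * as long as the combined immediate offset stays below 4096. */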
1423          if (mubuf.offen && i == 1 && info.is_constant_or_literal(32) &&
1424              mubuf.offset + info.val < 4096) {
1425             assert(!mubuf.idxen);
1426             instr->operands[1] = Operand(v1);
1427             mubuf.offset += info.val;
1428             mubuf.offen = false;
1429             continue;
1430          } else if (i == 2 && info.is_constant_or_literal(32) && mubuf.offset + info.val < 4096) {
1431             instr->operands[2] = Operand::c32(0);
1432             mubuf.offset += info.val;
1433             continue;
1434          } else if (mubuf.offen && i == 1 &&
1435                     parse_base_offset(ctx, instr.get(), i, &base, &offset,
1436                                       vaddr_prevent_overflow) &&
1437                     base.regClass() == v1 && mubuf.offset + offset < 4096) {
1438             assert(!mubuf.idxen);
1439             instr->operands[1].setTemp(base);
1440             mubuf.offset += offset;
1441             continue;
1442          } else if (i == 2 && parse_base_offset(ctx, instr.get(), i, &base, &offset, true) &&
1443                     base.regClass() == s1 && mubuf.offset + offset < 4096) {
1444             instr->operands[i].setTemp(base);
1445             mubuf.offset += offset;
1446             continue;
1447          }
1448       }
1449 
1450       /* SCRATCH: propagate constants and combine additions */
1451       else if (instr->isScratch()) {
1452          FLAT_instruction& scratch = instr->scratch();
1453          Temp base;
1454          uint32_t offset;
1455          while (info.is_temp())
1456             info = ctx.info[info.temp.id()];
1457 
1458          if (i <= 1 && parse_base_offset(ctx, instr.get(), i, &base, &offset, false) &&
1459              base.regClass() == instr->operands[i].regClass() &&
1460              is_scratch_offset_valid(ctx, instr.get(), scratch.offset + (int32_t)offset)) {
1461             instr->operands[i].setTemp(base);
1462             scratch.offset += (int32_t)offset;
1463             continue;
1464          } else if (i <= 1 && info.is_constant_or_literal(32) &&
1465                     ctx.program->gfx_level >= GFX10_3 &&
1466                     is_scratch_offset_valid(ctx, NULL, scratch.offset + (int32_t)info.val)) {
1467             /* GFX10.3+ can disable both SADDR and ADDR. */
1468             instr->operands[i] = Operand(instr->operands[i].regClass());
1469             scratch.offset += (int32_t)info.val;
1470             continue;
1471          }
1472       }
1473 
1474       /* DS: combine additions */
1475       else if (instr->isDS()) {
1476 
1477          DS_instruction& ds = instr->ds();
1478          Temp base;
1479          uint32_t offset;
1480          bool has_usable_ds_offset = ctx.program->gfx_level >= GFX7;
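         /* Illustrative sketch (hypothetical temps and offsets): for
          *    %addr = v_add_u32 %base, 8
          *    ds_read2_b32 %dst, %addr offset0:0 offset1:1
          * the added constant 8 is dword-aligned, so it folds into both 8-bit
          * offsets (8 >> 2 = 2):
          *    ds_read2_b32 %dst, %base offset0:2 offset1:3
          * provided neither resulting offset exceeds 255. */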
1481          if (has_usable_ds_offset && i == 0 &&
1482              parse_base_offset(ctx, instr.get(), i, &base, &offset, false) &&
1483              base.regClass() == instr->operands[i].regClass() &&
1484              instr->opcode != aco_opcode::ds_swizzle_b32) {
1485             if (instr->opcode == aco_opcode::ds_write2_b32 ||
1486                 instr->opcode == aco_opcode::ds_read2_b32 ||
1487                 instr->opcode == aco_opcode::ds_write2_b64 ||
1488                 instr->opcode == aco_opcode::ds_read2_b64 ||
1489                 instr->opcode == aco_opcode::ds_write2st64_b32 ||
1490                 instr->opcode == aco_opcode::ds_read2st64_b32 ||
1491                 instr->opcode == aco_opcode::ds_write2st64_b64 ||
1492                 instr->opcode == aco_opcode::ds_read2st64_b64) {
1493                bool is64bit = instr->opcode == aco_opcode::ds_write2_b64 ||
1494                               instr->opcode == aco_opcode::ds_read2_b64 ||
1495                               instr->opcode == aco_opcode::ds_write2st64_b64 ||
1496                               instr->opcode == aco_opcode::ds_read2st64_b64;
1497                bool st64 = instr->opcode == aco_opcode::ds_write2st64_b32 ||
1498                            instr->opcode == aco_opcode::ds_read2st64_b32 ||
1499                            instr->opcode == aco_opcode::ds_write2st64_b64 ||
1500                            instr->opcode == aco_opcode::ds_read2st64_b64;
1501                unsigned shifts = (is64bit ? 3 : 2) + (st64 ? 6 : 0);
1502                unsigned mask = BITFIELD_MASK(shifts);
1503 
1504                if ((offset & mask) == 0 && ds.offset0 + (offset >> shifts) <= 255 &&
1505                    ds.offset1 + (offset >> shifts) <= 255) {
1506                   instr->operands[i].setTemp(base);
1507                   ds.offset0 += offset >> shifts;
1508                   ds.offset1 += offset >> shifts;
1509                }
1510             } else {
1511                if (ds.offset0 + offset <= 65535) {
1512                   instr->operands[i].setTemp(base);
1513                   ds.offset0 += offset;
1514                }
1515             }
1516          }
1517       }
1518 
1519       else if (instr->isBranch()) {
1520          if (ctx.info[instr->operands[0].tempId()].is_scc_invert()) {
1521             /* Flip the branch instruction to get rid of the scc_invert instruction */
1522             instr->opcode = instr->opcode == aco_opcode::p_cbranch_z ? aco_opcode::p_cbranch_nz
1523                                                                      : aco_opcode::p_cbranch_z;
1524             instr->operands[0].setTemp(ctx.info[instr->operands[0].tempId()].temp);
1525          }
1526       }
1527    }
1528 
1529    /* if this instruction doesn't define anything, return */
1530    if (instr->definitions.empty()) {
1531       check_sdwa_extract(ctx, instr);
1532       return;
1533    }
1534 
1535    if (instr->isVALU() || instr->isVINTRP()) {
1536       if (instr_info.can_use_output_modifiers[(int)instr->opcode] || instr->isVINTRP() ||
1537           instr->opcode == aco_opcode::v_cndmask_b32) {
1538          bool canonicalized = true;
1539          if (!does_fp_op_flush_denorms(ctx, instr->opcode)) {
1540             unsigned ops = instr->opcode == aco_opcode::v_cndmask_b32 ? 2 : instr->operands.size();
1541             for (unsigned i = 0; canonicalized && (i < ops); i++)
1542                canonicalized = is_op_canonicalized(ctx, instr->operands[i]);
1543          }
1544          if (canonicalized)
1545             ctx.info[instr->definitions[0].tempId()].set_canonicalized();
1546       }
1547 
1548       if (instr->isVOPC()) {
1549          ctx.info[instr->definitions[0].tempId()].set_vopc(instr.get());
1550          check_sdwa_extract(ctx, instr);
1551          return;
1552       }
1553       if (instr->isVOP3P()) {
1554          ctx.info[instr->definitions[0].tempId()].set_vop3p(instr.get());
1555          return;
1556       }
1557    }
1558 
1559    switch (instr->opcode) {
1560    case aco_opcode::p_create_vector: {
1561       bool copy_prop = instr->operands.size() == 1 && instr->operands[0].isTemp() &&
1562                        instr->operands[0].regClass() == instr->definitions[0].regClass();
1563       if (copy_prop) {
1564          ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
1565          break;
1566       }
1567 
1568       /* expand vector operands */
1569       std::vector<Operand> ops;
1570       unsigned offset = 0;
1571       for (const Operand& op : instr->operands) {
1572          /* ensure that any expanded operands are properly aligned */
1573          bool aligned = offset % 4 == 0 || op.bytes() < 4;
1574          offset += op.bytes();
1575          if (aligned && op.isTemp() && ctx.info[op.tempId()].is_vec()) {
1576             Instruction* vec = ctx.info[op.tempId()].instr;
1577             for (const Operand& vec_op : vec->operands)
1578                ops.emplace_back(vec_op);
1579          } else {
1580             ops.emplace_back(op);
1581          }
1582       }
1583 
1584       /* combine expanded operands to new vector */
1585       if (ops.size() != instr->operands.size()) {
1586          assert(ops.size() > instr->operands.size());
1587          Definition def = instr->definitions[0];
1588          instr.reset(create_instruction<Pseudo_instruction>(aco_opcode::p_create_vector,
1589                                                             Format::PSEUDO, ops.size(), 1));
1590          for (unsigned i = 0; i < ops.size(); i++) {
1591             if (ops[i].isTemp() && ctx.info[ops[i].tempId()].is_temp() &&
1592                 ops[i].regClass() == ctx.info[ops[i].tempId()].temp.regClass())
1593                ops[i].setTemp(ctx.info[ops[i].tempId()].temp);
1594             instr->operands[i] = ops[i];
1595          }
1596          instr->definitions[0] = def;
1597       } else {
1598          for (unsigned i = 0; i < ops.size(); i++) {
1599             assert(instr->operands[i] == ops[i]);
1600          }
1601       }
1602       ctx.info[instr->definitions[0].tempId()].set_vec(instr.get());
1603 
1604       if (instr->operands.size() == 2) {
1605          /* check if this is created from split_vector */
1606          if (instr->operands[1].isTemp() && ctx.info[instr->operands[1].tempId()].is_split()) {
1607             Instruction* split = ctx.info[instr->operands[1].tempId()].instr;
1608             if (instr->operands[0].isTemp() &&
1609                 instr->operands[0].getTemp() == split->definitions[0].getTemp())
1610                ctx.info[instr->definitions[0].tempId()].set_temp(split->operands[0].getTemp());
1611          }
1612       }
1613       break;
1614    }
1615    case aco_opcode::p_split_vector: {
1616       ssa_info& info = ctx.info[instr->operands[0].tempId()];
1617 
1618       if (info.is_constant_or_literal(32)) {
1619          uint64_t val = info.val;
1620          for (Definition def : instr->definitions) {
1621             uint32_t mask = u_bit_consecutive(0, def.bytes() * 8u);
1622             ctx.info[def.tempId()].set_constant(ctx.program->gfx_level, val & mask);
1623             val >>= def.bytes() * 8u;
1624          }
1625          break;
1626       } else if (!info.is_vec()) {
1627          if (instr->definitions.size() == 2 && instr->operands[0].isTemp() &&
1628              instr->definitions[0].bytes() == instr->definitions[1].bytes()) {
1629             ctx.info[instr->definitions[1].tempId()].set_split(instr.get());
1630             if (instr->operands[0].bytes() == 4) {
1631                /* D16 subdword split */
1632                ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
1633                ctx.info[instr->definitions[1].tempId()].set_extract(instr.get());
1634             }
1635          }
1636          break;
1637       }
1638 
1639       Instruction* vec = ctx.info[instr->operands[0].tempId()].instr;
1640       unsigned split_offset = 0;
1641       unsigned vec_offset = 0;
1642       unsigned vec_index = 0;
1643       for (unsigned i = 0; i < instr->definitions.size();
1644            split_offset += instr->definitions[i++].bytes()) {
1645          while (vec_offset < split_offset && vec_index < vec->operands.size())
1646             vec_offset += vec->operands[vec_index++].bytes();
1647 
1648          if (vec_offset != split_offset ||
1649              vec->operands[vec_index].bytes() != instr->definitions[i].bytes())
1650             continue;
1651 
1652          Operand vec_op = vec->operands[vec_index];
1653          if (vec_op.isConstant()) {
1654             ctx.info[instr->definitions[i].tempId()].set_constant(ctx.program->gfx_level,
1655                                                                   vec_op.constantValue64());
1656          } else if (vec_op.isUndefined()) {
1657             ctx.info[instr->definitions[i].tempId()].set_undefined();
1658          } else {
1659             assert(vec_op.isTemp());
1660             ctx.info[instr->definitions[i].tempId()].set_temp(vec_op.getTemp());
1661          }
1662       }
1663       break;
1664    }
1665    case aco_opcode::p_extract_vector: { /* mov */
1666       ssa_info& info = ctx.info[instr->operands[0].tempId()];
1667       const unsigned index = instr->operands[1].constantValue();
1668       const unsigned dst_offset = index * instr->definitions[0].bytes();
1669 
1670       if (info.is_vec()) {
1671          /* check if we index directly into a vector element */
1672          Instruction* vec = info.instr;
1673          unsigned offset = 0;
1674 
1675          for (const Operand& op : vec->operands) {
1676             if (offset < dst_offset) {
1677                offset += op.bytes();
1678                continue;
1679             } else if (offset != dst_offset || op.bytes() != instr->definitions[0].bytes()) {
1680                break;
1681             }
1682             instr->operands[0] = op;
1683             break;
1684          }
1685       } else if (info.is_constant_or_literal(32)) {
1686          /* propagate constants */
1687          uint32_t mask = u_bit_consecutive(0, instr->definitions[0].bytes() * 8u);
1688          uint32_t val = (info.val >> (dst_offset * 8u)) & mask;
1689          instr->operands[0] =
1690             Operand::get_const(ctx.program->gfx_level, val, instr->definitions[0].bytes());
1692       }
1693 
1694       if (instr->operands[0].bytes() != instr->definitions[0].bytes()) {
1695          if (instr->operands[0].size() != 1)
1696             break;
1697 
1698          if (index == 0)
1699             ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
1700          else
1701             ctx.info[instr->definitions[0].tempId()].set_extract(instr.get());
1702          break;
1703       }
1704 
1705       /* convert this extract into a copy instruction */
1706       instr->opcode = aco_opcode::p_parallelcopy;
1707       instr->operands.pop_back();
1708       FALLTHROUGH;
1709    }
1710    case aco_opcode::p_parallelcopy: /* propagate */
1711       if (instr->operands[0].isTemp() && ctx.info[instr->operands[0].tempId()].is_vec() &&
1712           instr->operands[0].regClass() != instr->definitions[0].regClass()) {
1713          /* We might not be able to copy-propagate if it's an SGPR->VGPR copy, so
1714           * duplicate the vector instead.
1715           */
1716          Instruction* vec = ctx.info[instr->operands[0].tempId()].instr;
1717          aco_ptr<Instruction> old_copy = std::move(instr);
1718 
1719          instr.reset(create_instruction<Pseudo_instruction>(
1720             aco_opcode::p_create_vector, Format::PSEUDO, vec->operands.size(), 1));
1721          instr->definitions[0] = old_copy->definitions[0];
1722          std::copy(vec->operands.begin(), vec->operands.end(), instr->operands.begin());
1723          for (unsigned i = 0; i < vec->operands.size(); i++) {
1724             Operand& op = instr->operands[i];
1725             if (op.isTemp() && ctx.info[op.tempId()].is_temp() &&
1726                 ctx.info[op.tempId()].temp.type() == instr->definitions[0].regClass().type())
1727                op.setTemp(ctx.info[op.tempId()].temp);
1728          }
1729          ctx.info[instr->definitions[0].tempId()].set_vec(instr.get());
1730          break;
1731       }
1732       FALLTHROUGH;
1733    case aco_opcode::p_as_uniform:
1734       if (instr->definitions[0].isFixed()) {
1735          /* don't copy-propagate copies into fixed registers */
1736       } else if (instr->usesModifiers()) {
1737          // TODO
1738       } else if (instr->operands[0].isConstant()) {
1739          ctx.info[instr->definitions[0].tempId()].set_constant(
1740             ctx.program->gfx_level, instr->operands[0].constantValue64());
1741       } else if (instr->operands[0].isTemp()) {
1742          ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
1743          if (ctx.info[instr->operands[0].tempId()].is_canonicalized())
1744             ctx.info[instr->definitions[0].tempId()].set_canonicalized();
1745       } else {
1746          assert(instr->operands[0].isFixed());
1747       }
1748       break;
1749    case aco_opcode::v_mov_b32:
1750       if (instr->isDPP16()) {
1751          /* anything else doesn't make sense in SSA */
1752          assert(instr->dpp16().row_mask == 0xf && instr->dpp16().bank_mask == 0xf);
1753          ctx.info[instr->definitions[0].tempId()].set_dpp16(instr.get());
1754       } else if (instr->isDPP8()) {
1755          ctx.info[instr->definitions[0].tempId()].set_dpp8(instr.get());
1756       }
1757       break;
1758    case aco_opcode::p_is_helper:
1759       if (!ctx.program->needs_wqm)
1760          ctx.info[instr->definitions[0].tempId()].set_constant(ctx.program->gfx_level, 0u);
1761       break;
1762    case aco_opcode::v_mul_f64: ctx.info[instr->definitions[0].tempId()].set_mul(instr.get()); break;
1763    case aco_opcode::v_mul_f16:
1764    case aco_opcode::v_mul_f32:
1765    case aco_opcode::v_mul_legacy_f32: { /* omod */
1766       ctx.info[instr->definitions[0].tempId()].set_mul(instr.get());
1767 
1768       /* TODO: try to move the negate/abs modifier to the consumer instead */
1769       bool uses_mods = instr->usesModifiers();
1770       bool fp16 = instr->opcode == aco_opcode::v_mul_f16;
1771 
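      /* Illustrative sketch (hypothetical temps): for v_mul_f32 %t = 2.0 * %a,
       * %a is only labelled with omod2 here; the later combine pass may then
       * fold the multiply into the instruction producing %a as an output
       * modifier (omod:*2), if the float mode allows it. */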
1772       for (unsigned i = 0; i < 2; i++) {
1773          if (instr->operands[!i].isConstant() && instr->operands[i].isTemp()) {
1774             if (!instr->isDPP() && !instr->isSDWA() &&
1775                 (instr->operands[!i].constantEquals(fp16 ? 0x3c00 : 0x3f800000) ||   /* 1.0 */
1776                  instr->operands[!i].constantEquals(fp16 ? 0xbc00 : 0xbf800000u))) { /* -1.0 */
1777                bool neg1 = instr->operands[!i].constantEquals(fp16 ? 0xbc00 : 0xbf800000u);
1778 
1779                VOP3_instruction* vop3 = instr->isVOP3() ? &instr->vop3() : NULL;
1780                if (vop3 && (vop3->abs[!i] || vop3->neg[!i] || vop3->clamp || vop3->omod))
1781                   continue;
1782 
1783                bool abs = vop3 && vop3->abs[i];
1784                bool neg = neg1 ^ (vop3 && vop3->neg[i]);
1785 
1786                Temp other = instr->operands[i].getTemp();
1787                if (abs && neg && other.type() == RegType::vgpr)
1788                   ctx.info[instr->definitions[0].tempId()].set_neg_abs(other);
1789                else if (abs && !neg && other.type() == RegType::vgpr)
1790                   ctx.info[instr->definitions[0].tempId()].set_abs(other);
1791                else if (!abs && neg && other.type() == RegType::vgpr)
1792                   ctx.info[instr->definitions[0].tempId()].set_neg(other);
1793                else if (!abs && !neg)
1794                   ctx.info[instr->definitions[0].tempId()].set_fcanonicalize(other);
1795             } else if (uses_mods) {
1796                continue;
1797             } else if (instr->operands[!i].constantValue() ==
1798                        (fp16 ? 0x4000 : 0x40000000)) { /* 2.0 */
1799                ctx.info[instr->operands[i].tempId()].set_omod2(instr.get());
1800             } else if (instr->operands[!i].constantValue() ==
1801                        (fp16 ? 0x4400 : 0x40800000)) { /* 4.0 */
1802                ctx.info[instr->operands[i].tempId()].set_omod4(instr.get());
1803             } else if (instr->operands[!i].constantValue() ==
1804                        (fp16 ? 0x3800 : 0x3f000000)) { /* 0.5 */
1805                ctx.info[instr->operands[i].tempId()].set_omod5(instr.get());
1806             } else if (instr->operands[!i].constantValue() == 0u &&
1807                        (!(fp16 ? ctx.fp_mode.preserve_signed_zero_inf_nan16_64
1808                                : ctx.fp_mode.preserve_signed_zero_inf_nan32) ||
1809                         instr->opcode == aco_opcode::v_mul_legacy_f32)) { /* 0.0 */
1810                ctx.info[instr->definitions[0].tempId()].set_constant(ctx.program->gfx_level, 0u);
1811             } else {
1812                continue;
1813             }
1814             break;
1815          }
1816       }
1817       break;
1818    }
1819    case aco_opcode::v_mul_lo_u16:
1820    case aco_opcode::v_mul_lo_u16_e64:
1821    case aco_opcode::v_mul_u32_u24:
1822       ctx.info[instr->definitions[0].tempId()].set_usedef(instr.get());
1823       break;
1824    case aco_opcode::v_med3_f16:
1825    case aco_opcode::v_med3_f32: { /* clamp */
1826       VOP3_instruction& vop3 = instr->vop3();
1827       if (vop3.abs[0] || vop3.abs[1] || vop3.abs[2] || vop3.neg[0] || vop3.neg[1] || vop3.neg[2] ||
1828           vop3.omod != 0 || vop3.opsel != 0)
1829          break;
1830 
1831       unsigned idx = 0;
1832       bool found_zero = false, found_one = false;
1833       bool is_fp16 = instr->opcode == aco_opcode::v_med3_f16;
1834       for (unsigned i = 0; i < 3; i++) {
1835          if (instr->operands[i].constantEquals(0))
1836             found_zero = true;
1837          else if (instr->operands[i].constantEquals(is_fp16 ? 0x3c00 : 0x3f800000)) /* 1.0 */
1838             found_one = true;
1839          else
1840             idx = i;
1841       }
1842       if (found_zero && found_one && instr->operands[idx].isTemp())
1843          ctx.info[instr->operands[idx].tempId()].set_clamp(instr.get());
1844       break;
1845    }
1846    case aco_opcode::v_cndmask_b32:
1847       if (instr->operands[0].constantEquals(0) && instr->operands[1].constantEquals(0xFFFFFFFF))
1848          ctx.info[instr->definitions[0].tempId()].set_vcc(instr->operands[2].getTemp());
1849       else if (instr->operands[0].constantEquals(0) &&
1850                instr->operands[1].constantEquals(0x3f800000u))
1851          ctx.info[instr->definitions[0].tempId()].set_b2f(instr->operands[2].getTemp());
1852       else if (instr->operands[0].constantEquals(0) && instr->operands[1].constantEquals(1))
1853          ctx.info[instr->definitions[0].tempId()].set_b2i(instr->operands[2].getTemp());
1854 
1855       break;
1856    case aco_opcode::v_cmp_lg_u32:
1857       if (instr->format == Format::VOPC && /* don't optimize VOP3 / SDWA / DPP */
1858           instr->operands[0].constantEquals(0) && instr->operands[1].isTemp() &&
1859           ctx.info[instr->operands[1].tempId()].is_vcc())
1860          ctx.info[instr->definitions[0].tempId()].set_temp(
1861             ctx.info[instr->operands[1].tempId()].temp);
1862       break;
1863    case aco_opcode::p_linear_phi: {
1864       /* lower_bool_phis() can create phis like this */
1865       bool all_same_temp = instr->operands[0].isTemp();
1866       /* this check is needed when moving uniform loop counters out of a divergent loop */
1867       if (all_same_temp)
1868          all_same_temp = instr->definitions[0].regClass() == instr->operands[0].regClass();
1869       for (unsigned i = 1; all_same_temp && (i < instr->operands.size()); i++) {
1870          if (!instr->operands[i].isTemp() ||
1871              instr->operands[i].tempId() != instr->operands[0].tempId())
1872             all_same_temp = false;
1873       }
1874       if (all_same_temp) {
1875          ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
1876       } else {
1877          bool all_undef = instr->operands[0].isUndefined();
1878          for (unsigned i = 1; all_undef && (i < instr->operands.size()); i++) {
1879             if (!instr->operands[i].isUndefined())
1880                all_undef = false;
1881          }
1882          if (all_undef)
1883             ctx.info[instr->definitions[0].tempId()].set_undefined();
1884       }
1885       break;
1886    }
1887    case aco_opcode::v_add_u32:
1888    case aco_opcode::v_add_co_u32:
1889    case aco_opcode::v_add_co_u32_e64:
1890    case aco_opcode::s_add_i32:
1891    case aco_opcode::s_add_u32:
1892    case aco_opcode::v_subbrev_co_u32:
1893    case aco_opcode::v_sub_u32:
1894    case aco_opcode::v_sub_i32:
1895    case aco_opcode::v_sub_co_u32:
1896    case aco_opcode::v_sub_co_u32_e64:
1897    case aco_opcode::s_sub_u32:
1898    case aco_opcode::s_sub_i32:
1899    case aco_opcode::v_subrev_u32:
1900    case aco_opcode::v_subrev_co_u32:
1901    case aco_opcode::v_subrev_co_u32_e64:
1902       ctx.info[instr->definitions[0].tempId()].set_add_sub(instr.get());
1903       break;
1904    case aco_opcode::s_not_b32:
1905    case aco_opcode::s_not_b64:
1906       if (ctx.info[instr->operands[0].tempId()].is_uniform_bool()) {
1907          ctx.info[instr->definitions[0].tempId()].set_uniform_bitwise();
1908          ctx.info[instr->definitions[1].tempId()].set_scc_invert(
1909             ctx.info[instr->operands[0].tempId()].temp);
1910       } else if (ctx.info[instr->operands[0].tempId()].is_uniform_bitwise()) {
1911          ctx.info[instr->definitions[0].tempId()].set_uniform_bitwise();
1912          ctx.info[instr->definitions[1].tempId()].set_scc_invert(
1913             ctx.info[instr->operands[0].tempId()].instr->definitions[1].getTemp());
1914       }
1915       ctx.info[instr->definitions[0].tempId()].set_bitwise(instr.get());
1916       break;
1917    case aco_opcode::s_and_b32:
1918    case aco_opcode::s_and_b64:
1919       if (fixed_to_exec(instr->operands[1]) && instr->operands[0].isTemp()) {
1920          if (ctx.info[instr->operands[0].tempId()].is_uniform_bool()) {
1921             /* Try to get rid of the superfluous s_cselect + s_and_b64 that comes from turning a
1922              * uniform bool into divergent */
1923             ctx.info[instr->definitions[1].tempId()].set_temp(
1924                ctx.info[instr->operands[0].tempId()].temp);
1925             ctx.info[instr->definitions[0].tempId()].set_uniform_bool(
1926                ctx.info[instr->operands[0].tempId()].temp);
1927             break;
1928          } else if (ctx.info[instr->operands[0].tempId()].is_uniform_bitwise()) {
1929             /* Try to get rid of the superfluous s_and_b64, since the uniform bitwise instruction
1930              * already produces the same SCC */
1931             ctx.info[instr->definitions[1].tempId()].set_temp(
1932                ctx.info[instr->operands[0].tempId()].instr->definitions[1].getTemp());
1933             ctx.info[instr->definitions[0].tempId()].set_uniform_bool(
1934                ctx.info[instr->operands[0].tempId()].instr->definitions[1].getTemp());
1935             break;
1936          } else if ((ctx.program->stage.num_sw_stages() > 1 ||
1937                      ctx.program->stage.hw == HWStage::NGG) &&
1938                     instr->pass_flags == 1) {
1939             /* In case of merged shaders, pass_flags=1 means that all lanes are active (exec=-1), so
1940              * s_and is unnecessary. */
1941             ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
1942             break;
1943          } else if (can_eliminate_and_exec(ctx, instr->operands[0].getTemp(), instr->pass_flags)) {
1944             ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
1945             break;
1946          }
1947       }
1948       FALLTHROUGH;
1949    case aco_opcode::s_or_b32:
1950    case aco_opcode::s_or_b64:
1951    case aco_opcode::s_xor_b32:
1952    case aco_opcode::s_xor_b64:
1953       if (std::all_of(instr->operands.begin(), instr->operands.end(),
1954                       [&ctx](const Operand& op)
1955                       {
1956                          return op.isTemp() && (ctx.info[op.tempId()].is_uniform_bool() ||
1957                                                 ctx.info[op.tempId()].is_uniform_bitwise());
1958                       })) {
1959          ctx.info[instr->definitions[0].tempId()].set_uniform_bitwise();
1960       }
1961       ctx.info[instr->definitions[0].tempId()].set_bitwise(instr.get());
1962       break;
1963    case aco_opcode::s_lshl_b32:
1964    case aco_opcode::v_or_b32:
1965    case aco_opcode::v_lshlrev_b32:
1966    case aco_opcode::v_bcnt_u32_b32:
1967    case aco_opcode::v_and_b32:
1968    case aco_opcode::v_xor_b32:
1969       ctx.info[instr->definitions[0].tempId()].set_usedef(instr.get());
1970       break;
1971    case aco_opcode::v_min_f32:
1972    case aco_opcode::v_min_f16:
1973    case aco_opcode::v_min_u32:
1974    case aco_opcode::v_min_i32:
1975    case aco_opcode::v_min_u16:
1976    case aco_opcode::v_min_i16:
1977    case aco_opcode::v_min_u16_e64:
1978    case aco_opcode::v_min_i16_e64:
1979    case aco_opcode::v_max_f32:
1980    case aco_opcode::v_max_f16:
1981    case aco_opcode::v_max_u32:
1982    case aco_opcode::v_max_i32:
1983    case aco_opcode::v_max_u16:
1984    case aco_opcode::v_max_i16:
1985    case aco_opcode::v_max_u16_e64:
1986    case aco_opcode::v_max_i16_e64:
1987       ctx.info[instr->definitions[0].tempId()].set_minmax(instr.get());
1988       break;
1989    case aco_opcode::s_cselect_b64:
1990    case aco_opcode::s_cselect_b32:
1991       if (instr->operands[0].constantEquals((unsigned)-1) && instr->operands[1].constantEquals(0)) {
1992          /* Found a cselect that operates on a uniform bool that comes from e.g. s_cmp */
1993          ctx.info[instr->definitions[0].tempId()].set_uniform_bool(instr->operands[2].getTemp());
1994       }
1995       if (instr->operands[2].isTemp() && ctx.info[instr->operands[2].tempId()].is_scc_invert()) {
1996          /* Flip the operands to get rid of the scc_invert instruction */
1997          std::swap(instr->operands[0], instr->operands[1]);
1998          instr->operands[2].setTemp(ctx.info[instr->operands[2].tempId()].temp);
1999       }
2000       break;
2001    case aco_opcode::p_wqm:
2002       if (instr->operands[0].isTemp() && ctx.info[instr->operands[0].tempId()].is_scc_invert()) {
2003          ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
2004       }
2005       break;
2006    case aco_opcode::s_mul_i32:
2007       /* Testing every uint32_t shows that 0x3f800000*n is never a denormal.
2008        * This pattern is created from a uniform nir_op_b2f. */
2009       if (instr->operands[0].constantEquals(0x3f800000u))
2010          ctx.info[instr->definitions[0].tempId()].set_canonicalized();
2011       break;
2012    case aco_opcode::p_extract: {
2013       if (instr->definitions[0].bytes() == 4) {
2014          ctx.info[instr->definitions[0].tempId()].set_extract(instr.get());
2015          if (instr->operands[0].regClass() == v1 && parse_insert(instr.get()))
2016             ctx.info[instr->operands[0].tempId()].set_insert(instr.get());
2017       }
2018       break;
2019    }
2020    case aco_opcode::p_insert: {
2021       if (instr->operands[0].bytes() == 4) {
2022          if (instr->operands[0].regClass() == v1)
2023             ctx.info[instr->operands[0].tempId()].set_insert(instr.get());
2024          if (parse_extract(instr.get()))
2025             ctx.info[instr->definitions[0].tempId()].set_extract(instr.get());
2026          ctx.info[instr->definitions[0].tempId()].set_bitwise(instr.get());
2027       }
2028       break;
2029    }
2030    case aco_opcode::ds_read_u8:
2031    case aco_opcode::ds_read_u8_d16:
2032    case aco_opcode::ds_read_u16:
2033    case aco_opcode::ds_read_u16_d16: {
2034       ctx.info[instr->definitions[0].tempId()].set_usedef(instr.get());
2035       break;
2036    }
2037    case aco_opcode::v_cvt_f16_f32: {
2038       if (instr->operands[0].isTemp())
2039          ctx.info[instr->operands[0].tempId()].set_f2f16(instr.get());
2040       break;
2041    }
2042    case aco_opcode::v_cvt_f32_f16: {
2043       if (instr->operands[0].isTemp())
2044          ctx.info[instr->definitions[0].tempId()].set_f2f32(instr.get());
2045       break;
2046    }
2047    default: break;
2048    }
2049 
2050    /* Don't remove label_extract if we can't apply the extract to
2051     * neg/abs instructions because we'll likely combine it into another valu. */
2052    if (!(ctx.info[instr->definitions[0].tempId()].label & (label_neg | label_abs)))
2053       check_sdwa_extract(ctx, instr);
2054 }
2055 
2056 unsigned
2057 original_temp_id(opt_ctx& ctx, Temp tmp)
2058 {
2059    if (ctx.info[tmp.id()].is_temp())
2060       return ctx.info[tmp.id()].temp.id();
2061    else
2062       return tmp.id();
2063 }
2064 
2065 void
2066 decrease_uses(opt_ctx& ctx, Instruction* instr)
2067 {
2068    if (!--ctx.uses[instr->definitions[0].tempId()]) {
2069       for (const Operand& op : instr->operands) {
2070          if (op.isTemp())
2071             ctx.uses[op.tempId()]--;
2072       }
2073    }
2074 }
2075 
2076 Instruction*
2077 follow_operand(opt_ctx& ctx, Operand op, bool ignore_uses = false)
2078 {
2079    if (!op.isTemp() || !(ctx.info[op.tempId()].label & instr_usedef_labels))
2080       return nullptr;
2081    if (!ignore_uses && ctx.uses[op.tempId()] > 1)
2082       return nullptr;
2083 
2084    Instruction* instr = ctx.info[op.tempId()].instr;
2085 
2086    if (instr->definitions.size() == 2) {
2087       assert(instr->definitions[0].isTemp() && instr->definitions[0].tempId() == op.tempId());
2088       if (instr->definitions[1].isTemp() && ctx.uses[instr->definitions[1].tempId()])
2089          return nullptr;
2090    }
2091 
2092    return instr;
2093 }
2094 
2095 /* s_or_b64(neq(a, a), neq(b, b)) -> v_cmp_u_f32(a, b)
2096  * s_and_b64(eq(a, a), eq(b, b)) -> v_cmp_o_f32(a, b) */
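/* Illustrative sketch (hypothetical temps):
 *    %x = v_cmp_neq_f32 %a, %a      ; true iff a is NaN
 *    %y = v_cmp_neq_f32 %b, %b      ; true iff b is NaN
 *    %m = s_or_b64 %x, %y
 * collapses into the single unordered compare
 *    %m = v_cmp_u_f32 %a, %b
 * and the s_and_b64/v_cmp_eq variant collapses into v_cmp_o_f32. */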
2097 bool
2098 combine_ordering_test(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2099 {
2100    if (instr->definitions[0].regClass() != ctx.program->lane_mask)
2101       return false;
2102    if (instr->definitions[1].isTemp() && ctx.uses[instr->definitions[1].tempId()])
2103       return false;
2104 
2105    bool is_or = instr->opcode == aco_opcode::s_or_b64 || instr->opcode == aco_opcode::s_or_b32;
2106 
2107    bool neg[2] = {false, false};
2108    bool abs[2] = {false, false};
2109    uint8_t opsel = 0;
2110    Instruction* op_instr[2];
2111    Temp op[2];
2112 
2113    unsigned bitsize = 0;
2114    for (unsigned i = 0; i < 2; i++) {
2115       op_instr[i] = follow_operand(ctx, instr->operands[i], true);
2116       if (!op_instr[i])
2117          return false;
2118 
2119       aco_opcode expected_cmp = is_or ? aco_opcode::v_cmp_neq_f32 : aco_opcode::v_cmp_eq_f32;
2120       unsigned op_bitsize = get_cmp_bitsize(op_instr[i]->opcode);
2121 
2122       if (get_f32_cmp(op_instr[i]->opcode) != expected_cmp)
2123          return false;
2124       if (bitsize && op_bitsize != bitsize)
2125          return false;
2126       if (!op_instr[i]->operands[0].isTemp() || !op_instr[i]->operands[1].isTemp())
2127          return false;
2128 
2129       if (op_instr[i]->isVOP3()) {
2130          VOP3_instruction& vop3 = op_instr[i]->vop3();
2131          if (vop3.neg[0] != vop3.neg[1] || vop3.abs[0] != vop3.abs[1] || vop3.opsel == 1 ||
2132              vop3.opsel == 2)
2133             return false;
2134          neg[i] = vop3.neg[0];
2135          abs[i] = vop3.abs[0];
2136          opsel |= (vop3.opsel & 1) << i;
2137       } else if (op_instr[i]->isSDWA()) {
2138          return false;
2139       }
2140 
2141       Temp op0 = op_instr[i]->operands[0].getTemp();
2142       Temp op1 = op_instr[i]->operands[1].getTemp();
2143       if (original_temp_id(ctx, op0) != original_temp_id(ctx, op1))
2144          return false;
2145 
2146       op[i] = op1;
2147       bitsize = op_bitsize;
2148    }
2149 
2150    if (op[1].type() == RegType::sgpr)
2151       std::swap(op[0], op[1]);
2152    unsigned num_sgprs = (op[0].type() == RegType::sgpr) + (op[1].type() == RegType::sgpr);
2153    if (num_sgprs > (ctx.program->gfx_level >= GFX10 ? 2 : 1))
2154       return false;
2155 
2156    ctx.uses[op[0].id()]++;
2157    ctx.uses[op[1].id()]++;
2158    decrease_uses(ctx, op_instr[0]);
2159    decrease_uses(ctx, op_instr[1]);
2160 
2161    aco_opcode new_op = aco_opcode::num_opcodes;
2162    switch (bitsize) {
2163    case 16: new_op = is_or ? aco_opcode::v_cmp_u_f16 : aco_opcode::v_cmp_o_f16; break;
2164    case 32: new_op = is_or ? aco_opcode::v_cmp_u_f32 : aco_opcode::v_cmp_o_f32; break;
2165    case 64: new_op = is_or ? aco_opcode::v_cmp_u_f64 : aco_opcode::v_cmp_o_f64; break;
2166    }
2167    Instruction* new_instr;
2168    if (neg[0] || neg[1] || abs[0] || abs[1] || opsel || num_sgprs > 1) {
2169       VOP3_instruction* vop3 =
2170          create_instruction<VOP3_instruction>(new_op, asVOP3(Format::VOPC), 2, 1);
2171       for (unsigned i = 0; i < 2; i++) {
2172          vop3->neg[i] = neg[i];
2173          vop3->abs[i] = abs[i];
2174       }
2175       vop3->opsel = opsel;
2176       new_instr = static_cast<Instruction*>(vop3);
2177    } else {
2178       new_instr = create_instruction<VOPC_instruction>(new_op, Format::VOPC, 2, 1);
2179    }
2180    new_instr->operands[0] = Operand(op[0]);
2181    new_instr->operands[1] = Operand(op[1]);
2182    new_instr->definitions[0] = instr->definitions[0];
2183 
2184    ctx.info[instr->definitions[0].tempId()].label = 0;
2185    ctx.info[instr->definitions[0].tempId()].set_vopc(new_instr);
2186 
2187    instr.reset(new_instr);
2188 
2189    return true;
2190 }
2191 
2192 /* s_or_b64(v_cmp_u_f32(a, b), cmp(a, b)) -> get_unordered(cmp)(a, b)
2193  * s_and_b64(v_cmp_o_f32(a, b), cmp(a, b)) -> get_ordered(cmp)(a, b) */
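/* Illustrative sketch (hypothetical temps):
 *    %u = v_cmp_u_f32 %a, %b        ; unordered: a or b is NaN
 *    %c = v_cmp_lt_f32 %a, %b
 *    %m = s_or_b64 %u, %c
 * collapses into
 *    %m = v_cmp_nge_f32 %a, %b      ; get_unordered(lt) == "unordered or less than"
 */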
2194 bool
2195 combine_comparison_ordering(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2196 {
2197    if (instr->definitions[0].regClass() != ctx.program->lane_mask)
2198       return false;
2199    if (instr->definitions[1].isTemp() && ctx.uses[instr->definitions[1].tempId()])
2200       return false;
2201 
2202    bool is_or = instr->opcode == aco_opcode::s_or_b64 || instr->opcode == aco_opcode::s_or_b32;
2203    aco_opcode expected_nan_test = is_or ? aco_opcode::v_cmp_u_f32 : aco_opcode::v_cmp_o_f32;
2204 
2205    Instruction* nan_test = follow_operand(ctx, instr->operands[0], true);
2206    Instruction* cmp = follow_operand(ctx, instr->operands[1], true);
2207    if (!nan_test || !cmp)
2208       return false;
2209    if (nan_test->isSDWA() || cmp->isSDWA())
2210       return false;
2211 
2212    if (get_f32_cmp(cmp->opcode) == expected_nan_test)
2213       std::swap(nan_test, cmp);
2214    else if (get_f32_cmp(nan_test->opcode) != expected_nan_test)
2215       return false;
2216 
2217    if (!is_cmp(cmp->opcode) || get_cmp_bitsize(cmp->opcode) != get_cmp_bitsize(nan_test->opcode))
2218       return false;
2219 
2220    if (!nan_test->operands[0].isTemp() || !nan_test->operands[1].isTemp())
2221       return false;
2222    if (!cmp->operands[0].isTemp() || !cmp->operands[1].isTemp())
2223       return false;
2224 
2225    unsigned prop_cmp0 = original_temp_id(ctx, cmp->operands[0].getTemp());
2226    unsigned prop_cmp1 = original_temp_id(ctx, cmp->operands[1].getTemp());
2227    unsigned prop_nan0 = original_temp_id(ctx, nan_test->operands[0].getTemp());
2228    unsigned prop_nan1 = original_temp_id(ctx, nan_test->operands[1].getTemp());
2229    if (prop_cmp0 != prop_nan0 && prop_cmp0 != prop_nan1)
2230       return false;
2231    if (prop_cmp1 != prop_nan0 && prop_cmp1 != prop_nan1)
2232       return false;
2233 
2234    ctx.uses[cmp->operands[0].tempId()]++;
2235    ctx.uses[cmp->operands[1].tempId()]++;
2236    decrease_uses(ctx, nan_test);
2237    decrease_uses(ctx, cmp);
2238 
2239    aco_opcode new_op = is_or ? get_unordered(cmp->opcode) : get_ordered(cmp->opcode);
2240    Instruction* new_instr;
2241    if (cmp->isVOP3()) {
2242       VOP3_instruction* new_vop3 =
2243          create_instruction<VOP3_instruction>(new_op, asVOP3(Format::VOPC), 2, 1);
2244       VOP3_instruction& cmp_vop3 = cmp->vop3();
2245       memcpy(new_vop3->abs, cmp_vop3.abs, sizeof(new_vop3->abs));
2246       memcpy(new_vop3->neg, cmp_vop3.neg, sizeof(new_vop3->neg));
2247       new_vop3->clamp = cmp_vop3.clamp;
2248       new_vop3->omod = cmp_vop3.omod;
2249       new_vop3->opsel = cmp_vop3.opsel;
2250       new_instr = new_vop3;
2251    } else {
2252       new_instr = create_instruction<VOPC_instruction>(new_op, Format::VOPC, 2, 1);
2253    }
2254    new_instr->operands[0] = cmp->operands[0];
2255    new_instr->operands[1] = cmp->operands[1];
2256    new_instr->definitions[0] = instr->definitions[0];
2257 
2258    ctx.info[instr->definitions[0].tempId()].label = 0;
2259    ctx.info[instr->definitions[0].tempId()].set_vopc(new_instr);
2260 
2261    instr.reset(new_instr);
2262 
2263    return true;
2264 }
2265 
2266 bool
2267 is_operand_constant(opt_ctx& ctx, Operand op, unsigned bit_size, uint64_t* value)
2268 {
2269    if (op.isConstant()) {
2270       *value = op.constantValue64();
2271       return true;
2272    } else if (op.isTemp()) {
2273       unsigned id = original_temp_id(ctx, op.getTemp());
2274       if (!ctx.info[id].is_constant_or_literal(bit_size))
2275          return false;
2276       *value = get_constant_op(ctx, ctx.info[id], bit_size).constantValue64();
2277       return true;
2278    }
2279    return false;
2280 }
2281 
2282 bool
2283 is_constant_nan(uint64_t value, unsigned bit_size)
2284 {
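   /* e.g. for 32 bits: 0x7fc00000 has exponent 0xff and a non-zero mantissa,
    * so it is a NaN, while 0x7f800000 (+Inf) has a zero mantissa and is not. */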
2285    if (bit_size == 16)
2286       return ((value >> 10) & 0x1f) == 0x1f && (value & 0x3ff);
2287    else if (bit_size == 32)
2288       return ((value >> 23) & 0xff) == 0xff && (value & 0x7fffff);
2289    else
2290       return ((value >> 52) & 0x7ff) == 0x7ff && (value & 0xfffffffffffff);
2291 }
2292 
2293 /* s_or_b64(v_cmp_neq_f32(a, a), cmp(a, #b)) and b is not NaN -> get_unordered(cmp)(a, b)
2294  * s_and_b64(v_cmp_eq_f32(a, a), cmp(a, #b)) and b is not NaN -> get_ordered(cmp)(a, b) */
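/* Illustrative sketch (hypothetical temps and constant):
 *    %n = v_cmp_neq_f32 %a, %a      ; true iff a is NaN
 *    %c = v_cmp_lt_f32 %a, 1.0
 *    %m = s_or_b64 %n, %c
 * collapses into
 *    %m = v_cmp_nge_f32 %a, 1.0
 * which is safe because the constant operand (1.0) can never be NaN. */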
2295 bool
2296 combine_constant_comparison_ordering(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2297 {
2298    if (instr->definitions[0].regClass() != ctx.program->lane_mask)
2299       return false;
2300    if (instr->definitions[1].isTemp() && ctx.uses[instr->definitions[1].tempId()])
2301       return false;
2302 
2303    bool is_or = instr->opcode == aco_opcode::s_or_b64 || instr->opcode == aco_opcode::s_or_b32;
2304 
2305    Instruction* nan_test = follow_operand(ctx, instr->operands[0], true);
2306    Instruction* cmp = follow_operand(ctx, instr->operands[1], true);
2307 
2308    if (!nan_test || !cmp || nan_test->isSDWA() || cmp->isSDWA())
2309       return false;
2312 
2313    aco_opcode expected_nan_test = is_or ? aco_opcode::v_cmp_neq_f32 : aco_opcode::v_cmp_eq_f32;
2314    if (get_f32_cmp(cmp->opcode) == expected_nan_test)
2315       std::swap(nan_test, cmp);
2316    else if (get_f32_cmp(nan_test->opcode) != expected_nan_test)
2317       return false;
2318 
2319    unsigned bit_size = get_cmp_bitsize(cmp->opcode);
2320    if (!is_cmp(cmp->opcode) || get_cmp_bitsize(nan_test->opcode) != bit_size)
2321       return false;
2322 
2323    if (!nan_test->operands[0].isTemp() || !nan_test->operands[1].isTemp())
2324       return false;
2325    if (!cmp->operands[0].isTemp() && !cmp->operands[1].isTemp())
2326       return false;
2327 
2328    unsigned prop_nan0 = original_temp_id(ctx, nan_test->operands[0].getTemp());
2329    unsigned prop_nan1 = original_temp_id(ctx, nan_test->operands[1].getTemp());
2330    if (prop_nan0 != prop_nan1)
2331       return false;
2332 
2333    if (nan_test->isVOP3()) {
2334       VOP3_instruction& vop3 = nan_test->vop3();
2335       if (vop3.neg[0] != vop3.neg[1] || vop3.abs[0] != vop3.abs[1] || vop3.opsel == 1 ||
2336           vop3.opsel == 2)
2337          return false;
2338    }
2339 
2340    int constant_operand = -1;
2341    for (unsigned i = 0; i < 2; i++) {
2342       if (cmp->operands[i].isTemp() &&
2343           original_temp_id(ctx, cmp->operands[i].getTemp()) == prop_nan0) {
2344          constant_operand = !i;
2345          break;
2346       }
2347    }
2348    if (constant_operand == -1)
2349       return false;
2350 
2351    uint64_t constant_value;
2352    if (!is_operand_constant(ctx, cmp->operands[constant_operand], bit_size, &constant_value))
2353       return false;
2354    if (is_constant_nan(constant_value, bit_size))
2355       return false;
2356 
2357    if (cmp->operands[0].isTemp())
2358       ctx.uses[cmp->operands[0].tempId()]++;
2359    if (cmp->operands[1].isTemp())
2360       ctx.uses[cmp->operands[1].tempId()]++;
2361    decrease_uses(ctx, nan_test);
2362    decrease_uses(ctx, cmp);
2363 
2364    aco_opcode new_op = is_or ? get_unordered(cmp->opcode) : get_ordered(cmp->opcode);
2365    Instruction* new_instr;
2366    if (cmp->isVOP3()) {
2367       VOP3_instruction* new_vop3 =
2368          create_instruction<VOP3_instruction>(new_op, asVOP3(Format::VOPC), 2, 1);
2369       VOP3_instruction& cmp_vop3 = cmp->vop3();
2370       memcpy(new_vop3->abs, cmp_vop3.abs, sizeof(new_vop3->abs));
2371       memcpy(new_vop3->neg, cmp_vop3.neg, sizeof(new_vop3->neg));
2372       new_vop3->clamp = cmp_vop3.clamp;
2373       new_vop3->omod = cmp_vop3.omod;
2374       new_vop3->opsel = cmp_vop3.opsel;
2375       new_instr = new_vop3;
2376    } else {
2377       new_instr = create_instruction<VOPC_instruction>(new_op, Format::VOPC, 2, 1);
2378    }
2379    new_instr->operands[0] = cmp->operands[0];
2380    new_instr->operands[1] = cmp->operands[1];
2381    new_instr->definitions[0] = instr->definitions[0];
2382 
2383    ctx.info[instr->definitions[0].tempId()].label = 0;
2384    ctx.info[instr->definitions[0].tempId()].set_vopc(new_instr);
2385 
2386    instr.reset(new_instr);
2387 
2388    return true;
2389 }
2390 
2391 /* s_andn2(exec, cmp(a, b)) -> get_inverse(cmp)(a, b) */
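/* Illustrative sketch (hypothetical temps):
 *    %x = v_cmp_lt_f32 %a, %b
 *    %y, %scc = s_andn2_b64 %exec, %x    ; active lanes where the compare is false
 * becomes
 *    %y = v_cmp_nlt_f32 %a, %b
 * with v_cmp_nlt being the opcode returned by get_inverse(v_cmp_lt). */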
2392 bool
2393 combine_inverse_comparison(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2394 {
2395    if (!instr->operands[0].isFixed() || instr->operands[0].physReg() != exec)
2396       return false;
2397    if (ctx.uses[instr->definitions[1].tempId()])
2398       return false;
2399 
2400    Instruction* cmp = follow_operand(ctx, instr->operands[1]);
2401    if (!cmp)
2402       return false;
2403 
2404    aco_opcode new_opcode = get_inverse(cmp->opcode);
2405    if (new_opcode == aco_opcode::num_opcodes)
2406       return false;
2407 
2408    if (cmp->operands[0].isTemp())
2409       ctx.uses[cmp->operands[0].tempId()]++;
2410    if (cmp->operands[1].isTemp())
2411       ctx.uses[cmp->operands[1].tempId()]++;
2412    decrease_uses(ctx, cmp);
2413 
2414    /* This creates a new instruction instead of modifying the existing
2415     * comparison so that the comparison is done with the correct exec mask. */
2416    Instruction* new_instr;
2417    if (cmp->isVOP3()) {
2418       VOP3_instruction* new_vop3 =
2419          create_instruction<VOP3_instruction>(new_opcode, asVOP3(Format::VOPC), 2, 1);
2420       VOP3_instruction& cmp_vop3 = cmp->vop3();
2421       memcpy(new_vop3->abs, cmp_vop3.abs, sizeof(new_vop3->abs));
2422       memcpy(new_vop3->neg, cmp_vop3.neg, sizeof(new_vop3->neg));
2423       new_vop3->clamp = cmp_vop3.clamp;
2424       new_vop3->omod = cmp_vop3.omod;
2425       new_vop3->opsel = cmp_vop3.opsel;
2426       new_instr = new_vop3;
2427    } else if (cmp->isSDWA()) {
2428       SDWA_instruction* new_sdwa = create_instruction<SDWA_instruction>(
2429          new_opcode, (Format)((uint16_t)Format::SDWA | (uint16_t)Format::VOPC), 2, 1);
2430       SDWA_instruction& cmp_sdwa = cmp->sdwa();
2431       memcpy(new_sdwa->abs, cmp_sdwa.abs, sizeof(new_sdwa->abs));
2432       memcpy(new_sdwa->sel, cmp_sdwa.sel, sizeof(new_sdwa->sel));
2433       memcpy(new_sdwa->neg, cmp_sdwa.neg, sizeof(new_sdwa->neg));
2434       new_sdwa->dst_sel = cmp_sdwa.dst_sel;
2435       new_sdwa->clamp = cmp_sdwa.clamp;
2436       new_sdwa->omod = cmp_sdwa.omod;
2437       new_instr = new_sdwa;
2438    } else if (cmp->isDPP16()) {
2439       DPP16_instruction* new_dpp = create_instruction<DPP16_instruction>(
2440          new_opcode, (Format)((uint16_t)Format::DPP16 | (uint16_t)Format::VOPC), 2, 1);
2441       DPP16_instruction& cmp_dpp = cmp->dpp16();
2442       memcpy(new_dpp->abs, cmp_dpp.abs, sizeof(new_dpp->abs));
2443       memcpy(new_dpp->neg, cmp_dpp.neg, sizeof(new_dpp->neg));
2444       new_dpp->dpp_ctrl = cmp_dpp.dpp_ctrl;
2445       new_dpp->row_mask = cmp_dpp.row_mask;
2446       new_dpp->bank_mask = cmp_dpp.bank_mask;
2447       new_dpp->bound_ctrl = cmp_dpp.bound_ctrl;
2448       new_instr = new_dpp;
2449    } else if (cmp->isDPP8()) {
2450       DPP8_instruction* new_dpp = create_instruction<DPP8_instruction>(
2451          new_opcode, (Format)((uint16_t)Format::DPP8 | (uint16_t)Format::VOPC), 2, 1);
2452       DPP8_instruction& cmp_dpp = cmp->dpp8();
2453       memcpy(new_dpp->lane_sel, cmp_dpp.lane_sel, sizeof(new_dpp->lane_sel));
2454       new_instr = new_dpp;
2455    } else {
2456       new_instr = create_instruction<VOPC_instruction>(new_opcode, Format::VOPC, 2, 1);
2457    }
2458    new_instr->operands[0] = cmp->operands[0];
2459    new_instr->operands[1] = cmp->operands[1];
2460    new_instr->definitions[0] = instr->definitions[0];
2461 
2462    ctx.info[instr->definitions[0].tempId()].label = 0;
2463    ctx.info[instr->definitions[0].tempId()].set_vopc(new_instr);
2464 
2465    instr.reset(new_instr);
2466 
2467    return true;
2468 }
2469 
2470 /* op1(op2(1, 2), 0) if swap = false
2471  * op1(0, op2(1, 2)) if swap = true */
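/* shuffle_str gives, for each operand slot of the resulting VOP3 instruction,
 * which source value fills it: '0' is op1's remaining operand, '1'/'2' are
 * op2's operands. Illustrative sketch (hypothetical temps): matching
 *    %t = v_lshlrev_b32 2, %a
 *    %r = v_or_b32 %t, %b
 * with shuffle_str "210" yields the operand order (%a, 2, %b), suitable for
 * v_lshl_or_b32, i.e. (%a << 2) | %b. */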
2472 bool
2473 match_op3_for_vop3(opt_ctx& ctx, aco_opcode op1, aco_opcode op2, Instruction* op1_instr, bool swap,
2474                    const char* shuffle_str, Operand operands[3], bool neg[3], bool abs[3],
2475                    uint8_t* opsel, bool* op1_clamp, uint8_t* op1_omod, bool* inbetween_neg,
2476                    bool* inbetween_abs, bool* inbetween_opsel, bool* precise)
2477 {
2478    /* checks */
2479    if (op1_instr->opcode != op1)
2480       return false;
2481 
2482    Instruction* op2_instr = follow_operand(ctx, op1_instr->operands[swap]);
2483    if (!op2_instr || op2_instr->opcode != op2)
2484       return false;
2485    if (fixed_to_exec(op2_instr->operands[0]) || fixed_to_exec(op2_instr->operands[1]))
2486       return false;
2487 
2488    VOP3_instruction* op1_vop3 = op1_instr->isVOP3() ? &op1_instr->vop3() : NULL;
2489    VOP3_instruction* op2_vop3 = op2_instr->isVOP3() ? &op2_instr->vop3() : NULL;
2490 
2491    if (op1_instr->isSDWA() || op2_instr->isSDWA())
2492       return false;
2493    if (op1_instr->isDPP() || op2_instr->isDPP())
2494       return false;
2495 
2496    /* don't support inbetween clamp/omod */
2497    if (op2_vop3 && (op2_vop3->clamp || op2_vop3->omod))
2498       return false;
2499 
2500    /* get operands and modifiers and check inbetween modifiers */
2501    *op1_clamp = op1_vop3 ? op1_vop3->clamp : false;
2502    *op1_omod = op1_vop3 ? op1_vop3->omod : 0u;
2503 
2504    if (inbetween_neg)
2505       *inbetween_neg = op1_vop3 ? op1_vop3->neg[swap] : false;
2506    else if (op1_vop3 && op1_vop3->neg[swap])
2507       return false;
2508 
2509    if (inbetween_abs)
2510       *inbetween_abs = op1_vop3 ? op1_vop3->abs[swap] : false;
2511    else if (op1_vop3 && op1_vop3->abs[swap])
2512       return false;
2513 
2514    if (inbetween_opsel)
2515       *inbetween_opsel = op1_vop3 ? op1_vop3->opsel & (1 << (unsigned)swap) : false;
2516    else if (op1_vop3 && op1_vop3->opsel & (1 << (unsigned)swap))
2517       return false;
2518 
2519    *precise = op1_instr->definitions[0].isPrecise() || op2_instr->definitions[0].isPrecise();
2520 
2521    int shuffle[3];
2522    shuffle[shuffle_str[0] - '0'] = 0;
2523    shuffle[shuffle_str[1] - '0'] = 1;
2524    shuffle[shuffle_str[2] - '0'] = 2;
2525 
2526    operands[shuffle[0]] = op1_instr->operands[!swap];
2527    neg[shuffle[0]] = op1_vop3 ? op1_vop3->neg[!swap] : false;
2528    abs[shuffle[0]] = op1_vop3 ? op1_vop3->abs[!swap] : false;
2529    if (op1_vop3 && (op1_vop3->opsel & (1 << (unsigned)!swap)))
2530       *opsel |= 1 << shuffle[0];
2531 
2532    for (unsigned i = 0; i < 2; i++) {
2533       operands[shuffle[i + 1]] = op2_instr->operands[i];
2534       neg[shuffle[i + 1]] = op2_vop3 ? op2_vop3->neg[i] : false;
2535       abs[shuffle[i + 1]] = op2_vop3 ? op2_vop3->abs[i] : false;
2536       if (op2_vop3 && op2_vop3->opsel & (1 << i))
2537          *opsel |= 1 << shuffle[i + 1];
2538    }
2539 
2540    /* check operands */
2541    if (!check_vop3_operands(ctx, 3, operands))
2542       return false;
2543 
2544    return true;
2545 }
2546 
2547 void
2548 create_vop3_for_op3(opt_ctx& ctx, aco_opcode opcode, aco_ptr<Instruction>& instr,
2549                     Operand operands[3], bool neg[3], bool abs[3], uint8_t opsel, bool clamp,
2550                     unsigned omod)
2551 {
2552    VOP3_instruction* new_instr = create_instruction<VOP3_instruction>(opcode, Format::VOP3, 3, 1);
2553    memcpy(new_instr->abs, abs, sizeof(bool[3]));
2554    memcpy(new_instr->neg, neg, sizeof(bool[3]));
2555    new_instr->clamp = clamp;
2556    new_instr->omod = omod;
2557    new_instr->opsel = opsel;
2558    new_instr->operands[0] = operands[0];
2559    new_instr->operands[1] = operands[1];
2560    new_instr->operands[2] = operands[2];
2561    new_instr->definitions[0] = instr->definitions[0];
2562    ctx.info[instr->definitions[0].tempId()].label = 0;
2563 
2564    instr.reset(new_instr);
2565 }
2566 
2567 bool
2568 combine_three_valu_op(opt_ctx& ctx, aco_ptr<Instruction>& instr, aco_opcode op2, aco_opcode new_op,
2569                       const char* shuffle, uint8_t ops)
2570 {
2571    for (unsigned swap = 0; swap < 2; swap++) {
2572       if (!((1 << swap) & ops))
2573          continue;
2574 
2575       Operand operands[3];
2576       bool neg[3], abs[3], clamp, precise;
2577       uint8_t opsel = 0, omod = 0;
2578       if (match_op3_for_vop3(ctx, instr->opcode, op2, instr.get(), swap, shuffle, operands, neg,
2579                              abs, &opsel, &clamp, &omod, NULL, NULL, NULL, &precise)) {
2580          ctx.uses[instr->operands[swap].tempId()]--;
2581          create_vop3_for_op3(ctx, new_op, instr, operands, neg, abs, opsel, clamp, omod);
2582          return true;
2583       }
2584    }
2585    return false;
2586 }
2587 
2588 /* creates v_lshl_add_u32, v_lshl_or_b32 or v_and_or_b32 */
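/* Illustrative sketch (hypothetical temps):
 *    %t = v_and_b32 0xff, %a
 *    %r = v_or_b32 %t, %b
 * is fused into the single VOP3 instruction
 *    %r = v_and_or_b32 0xff, %a, %b     ; (0xff & %a) | %b
 * assuming %t has no other uses. */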
2589 bool
2590 combine_add_or_then_and_lshl(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2591 {
2592    bool is_or = instr->opcode == aco_opcode::v_or_b32;
2593    aco_opcode new_op_lshl = is_or ? aco_opcode::v_lshl_or_b32 : aco_opcode::v_lshl_add_u32;
2594 
2595    if (is_or && combine_three_valu_op(ctx, instr, aco_opcode::s_and_b32, aco_opcode::v_and_or_b32,
2596                                       "120", 1 | 2))
2597       return true;
2598    if (is_or && combine_three_valu_op(ctx, instr, aco_opcode::v_and_b32, aco_opcode::v_and_or_b32,
2599                                       "120", 1 | 2))
2600       return true;
2601    if (combine_three_valu_op(ctx, instr, aco_opcode::s_lshl_b32, new_op_lshl, "120", 1 | 2))
2602       return true;
2603    if (combine_three_valu_op(ctx, instr, aco_opcode::v_lshlrev_b32, new_op_lshl, "210", 1 | 2))
2604       return true;
2605 
2606    if (instr->isSDWA() || instr->isDPP())
2607       return false;
2608 
2609    /* v_or_b32(p_extract(a, 0, 8/16, 0), b) -> v_and_or_b32(a, 0xff/0xffff, b)
2610     * v_or_b32(p_insert(a, 0, 8/16), b) -> v_and_or_b32(a, 0xff/0xffff, b)
2611     * v_or_b32(p_insert(a, 24/16, 8/16), b) -> v_lshl_or_b32(a, 24/16, b)
2612     * v_add_u32(p_insert(a, 24/16, 8/16), b) -> v_lshl_add_u32(a, 24/16, b)
2613     */
2614    for (unsigned i = 0; i < 2; i++) {
2615       Instruction* extins = follow_operand(ctx, instr->operands[i]);
2616       if (!extins)
2617          continue;
2618 
2619       aco_opcode op;
2620       Operand operands[3];
2621 
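      /* A p_insert whose field ends at bit 32, i.e. (index + 1) * bits == 32, only moves
       * its source into the top bits, so it is just a left shift by index * bits and can
       * become the lshl half of the combined instruction. */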
2622       if (extins->opcode == aco_opcode::p_insert &&
2623           (extins->operands[1].constantValue() + 1) * extins->operands[2].constantValue() == 32) {
2624          op = new_op_lshl;
2625          operands[1] =
2626             Operand::c32(extins->operands[1].constantValue() * extins->operands[2].constantValue());
2627       } else if (is_or &&
2628                  (extins->opcode == aco_opcode::p_insert ||
2629                   (extins->opcode == aco_opcode::p_extract &&
2630                    extins->operands[3].constantEquals(0))) &&
2631                  extins->operands[1].constantEquals(0)) {
2632          op = aco_opcode::v_and_or_b32;
2633          operands[1] = Operand::c32(extins->operands[2].constantEquals(8) ? 0xffu : 0xffffu);
2634       } else {
2635          continue;
2636       }
2637 
2638       operands[0] = extins->operands[0];
2639       operands[2] = instr->operands[!i];
2640 
2641       if (!check_vop3_operands(ctx, 3, operands))
2642          continue;
2643 
2644       bool neg[3] = {}, abs[3] = {};
2645       uint8_t opsel = 0, omod = 0;
2646       bool clamp = false;
2647       if (instr->isVOP3())
2648          clamp = instr->vop3().clamp;
2649 
2650       ctx.uses[instr->operands[i].tempId()]--;
2651       create_vop3_for_op3(ctx, op, instr, operands, neg, abs, opsel, clamp, omod);
2652       return true;
2653    }
2654 
2655    return false;
2656 }
2657 
2658 bool
2659 combine_minmax(opt_ctx& ctx, aco_ptr<Instruction>& instr, aco_opcode opposite, aco_opcode minmax3)
2660 {
2661    /* TODO: this can handle SDWA min/max instructions by using opsel */
2662    if (combine_three_valu_op(ctx, instr, instr->opcode, minmax3, "012", 1 | 2))
2663       return true;
2664 
2665    /* min(-max(a, b), c) -> min3(c, -a, -b) *
2666     * max(-min(a, b), c) -> max3(c, -a, -b) */
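   /* This works because -max(a, b) == min(-a, -b): match_op3_for_vop3() signals via
    * inbetween_neg that the inner result was negated, and we push the negation onto a and b. */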
2667    for (unsigned swap = 0; swap < 2; swap++) {
2668       Operand operands[3];
2669       bool neg[3], abs[3], clamp, precise;
2670       uint8_t opsel = 0, omod = 0;
2671       bool inbetween_neg;
2672       if (match_op3_for_vop3(ctx, instr->opcode, opposite, instr.get(), swap, "012", operands, neg,
2673                              abs, &opsel, &clamp, &omod, &inbetween_neg, NULL, NULL, &precise) &&
2674           inbetween_neg) {
2675          ctx.uses[instr->operands[swap].tempId()]--;
2676          neg[1] = !neg[1];
2677          neg[2] = !neg[2];
2678          create_vop3_for_op3(ctx, minmax3, instr, operands, neg, abs, opsel, clamp, omod);
2679          return true;
2680       }
2681    }
2682    return false;
2683 }
2684 
2685 /* s_not_b32(s_and_b32(a, b)) -> s_nand_b32(a, b)
2686  * s_not_b32(s_or_b32(a, b)) -> s_nor_b32(a, b)
2687  * s_not_b32(s_xor_b32(a, b)) -> s_xnor_b32(a, b)
2688  * s_not_b64(s_and_b64(a, b)) -> s_nand_b64(a, b)
2689  * s_not_b64(s_or_b64(a, b)) -> s_nor_b64(a, b)
2690  * s_not_b64(s_xor_b64(a, b)) -> s_xnor_b64(a, b) */
2691 bool
2692 combine_salu_not_bitwise(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2693 {
2694    /* checks */
2695    if (!instr->operands[0].isTemp())
2696       return false;
2697    if (instr->definitions[1].isTemp() && ctx.uses[instr->definitions[1].tempId()])
2698       return false;
2699 
2700    Instruction* op2_instr = follow_operand(ctx, instr->operands[0]);
2701    if (!op2_instr)
2702       return false;
2703    switch (op2_instr->opcode) {
2704    case aco_opcode::s_and_b32:
2705    case aco_opcode::s_or_b32:
2706    case aco_opcode::s_xor_b32:
2707    case aco_opcode::s_and_b64:
2708    case aco_opcode::s_or_b64:
2709    case aco_opcode::s_xor_b64: break;
2710    default: return false;
2711    }
2712 
2713    /* create instruction */
2714    std::swap(instr->definitions[0], op2_instr->definitions[0]);
2715    std::swap(instr->definitions[1], op2_instr->definitions[1]);
2716    ctx.uses[instr->operands[0].tempId()]--;
2717    ctx.info[op2_instr->definitions[0].tempId()].label = 0;
2718 
2719    switch (op2_instr->opcode) {
2720    case aco_opcode::s_and_b32: op2_instr->opcode = aco_opcode::s_nand_b32; break;
2721    case aco_opcode::s_or_b32: op2_instr->opcode = aco_opcode::s_nor_b32; break;
2722    case aco_opcode::s_xor_b32: op2_instr->opcode = aco_opcode::s_xnor_b32; break;
2723    case aco_opcode::s_and_b64: op2_instr->opcode = aco_opcode::s_nand_b64; break;
2724    case aco_opcode::s_or_b64: op2_instr->opcode = aco_opcode::s_nor_b64; break;
2725    case aco_opcode::s_xor_b64: op2_instr->opcode = aco_opcode::s_xnor_b64; break;
2726    default: break;
2727    }
2728 
2729    return true;
2730 }
2731 
2732 /* s_and_b32(a, s_not_b32(b)) -> s_andn2_b32(a, b)
2733  * s_or_b32(a, s_not_b32(b)) -> s_orn2_b32(a, b)
2734  * s_and_b64(a, s_not_b64(b)) -> s_andn2_b64(a, b)
2735  * s_or_b64(a, s_not_b64(b)) -> s_orn2_b64(a, b) */
2736 bool
2737 combine_salu_n2(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2738 {
2739    if (instr->definitions[0].isTemp() && ctx.info[instr->definitions[0].tempId()].is_uniform_bool())
2740       return false;
2741 
2742    for (unsigned i = 0; i < 2; i++) {
2743       Instruction* op2_instr = follow_operand(ctx, instr->operands[i]);
2744       if (!op2_instr || (op2_instr->opcode != aco_opcode::s_not_b32 &&
2745                          op2_instr->opcode != aco_opcode::s_not_b64))
2746          continue;
2747       if (ctx.uses[op2_instr->definitions[1].tempId()] || fixed_to_exec(op2_instr->operands[0]))
2748          continue;
2749 
2750       if (instr->operands[!i].isLiteral() && op2_instr->operands[0].isLiteral() &&
2751           instr->operands[!i].constantValue() != op2_instr->operands[0].constantValue())
2752          continue;
2753 
2754       ctx.uses[instr->operands[i].tempId()]--;
2755       instr->operands[0] = instr->operands[!i];
2756       instr->operands[1] = op2_instr->operands[0];
2757       ctx.info[instr->definitions[0].tempId()].label = 0;
2758 
2759       switch (instr->opcode) {
2760       case aco_opcode::s_and_b32: instr->opcode = aco_opcode::s_andn2_b32; break;
2761       case aco_opcode::s_or_b32: instr->opcode = aco_opcode::s_orn2_b32; break;
2762       case aco_opcode::s_and_b64: instr->opcode = aco_opcode::s_andn2_b64; break;
2763       case aco_opcode::s_or_b64: instr->opcode = aco_opcode::s_orn2_b64; break;
2764       default: break;
2765       }
2766 
2767       return true;
2768    }
2769    return false;
2770 }
2771 
2772 /* s_add_{i32,u32}(a, s_lshl_b32(b, <n>)) -> s_lshl<n>_add_u32(a, b) */
2773 bool
2774 combine_salu_lshl_add(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2775 {
2776    if (instr->opcode == aco_opcode::s_add_i32 && ctx.uses[instr->definitions[1].tempId()])
2777       return false;
2778 
2779    for (unsigned i = 0; i < 2; i++) {
2780       Instruction* op2_instr = follow_operand(ctx, instr->operands[i], true);
2781       if (!op2_instr || op2_instr->opcode != aco_opcode::s_lshl_b32 ||
2782           ctx.uses[op2_instr->definitions[1].tempId()])
2783          continue;
2784       if (!op2_instr->operands[1].isConstant() || fixed_to_exec(op2_instr->operands[0]))
2785          continue;
2786 
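      /* only shift amounts 1..4 have fused s_lshl<n>_add_u32 opcodes */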
2787       uint32_t shift = op2_instr->operands[1].constantValue();
2788       if (shift < 1 || shift > 4)
2789          continue;
2790 
2791       if (instr->operands[!i].isLiteral() && op2_instr->operands[0].isLiteral() &&
2792           instr->operands[!i].constantValue() != op2_instr->operands[0].constantValue())
2793          continue;
2794 
2795       ctx.uses[instr->operands[i].tempId()]--;
2796       instr->operands[1] = instr->operands[!i];
2797       instr->operands[0] = op2_instr->operands[0];
2798       ctx.info[instr->definitions[0].tempId()].label = 0;
2799 
2800       instr->opcode = std::array<aco_opcode, 4>{
2801          aco_opcode::s_lshl1_add_u32, aco_opcode::s_lshl2_add_u32, aco_opcode::s_lshl3_add_u32,
2802          aco_opcode::s_lshl4_add_u32}[shift - 1];
2803 
2804       return true;
2805    }
2806    return false;
2807 }
2808 
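/* If one operand is labelled b2i (a boolean converted to 0/1), fold it into the carry
 * input of a carry-consuming opcode supplied by the caller (presumably v_addc_co_u32 or
 * v_subbrev_co_u32): operand 0 becomes 0, operand 1 the other source and operand 2 the
 * original boolean. */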
2809 bool
2810 combine_add_sub_b2i(opt_ctx& ctx, aco_ptr<Instruction>& instr, aco_opcode new_op, uint8_t ops)
2811 {
2812    if (instr->usesModifiers())
2813       return false;
2814 
2815    for (unsigned i = 0; i < 2; i++) {
2816       if (!((1 << i) & ops))
2817          continue;
2818       if (instr->operands[i].isTemp() && ctx.info[instr->operands[i].tempId()].is_b2i() &&
2819           ctx.uses[instr->operands[i].tempId()] == 1) {
2820 
2821          aco_ptr<Instruction> new_instr;
2822          if (instr->operands[!i].isTemp() &&
2823              instr->operands[!i].getTemp().type() == RegType::vgpr) {
2824             new_instr.reset(create_instruction<VOP2_instruction>(new_op, Format::VOP2, 3, 2));
2825          } else if (ctx.program->gfx_level >= GFX10 ||
2826                     (instr->operands[!i].isConstant() && !instr->operands[!i].isLiteral())) {
2827             new_instr.reset(
2828                create_instruction<VOP3_instruction>(new_op, asVOP3(Format::VOP2), 3, 2));
2829          } else {
2830             return false;
2831          }
2832          ctx.uses[instr->operands[i].tempId()]--;
2833          new_instr->definitions[0] = instr->definitions[0];
2834          if (instr->definitions.size() == 2) {
2835             new_instr->definitions[1] = instr->definitions[1];
2836          } else {
2837             new_instr->definitions[1] =
2838                Definition(ctx.program->allocateTmp(ctx.program->lane_mask));
2839             /* Make sure the uses vector is large enough and the number of
2840              * uses is properly initialized to 0.
2841              */
2842             ctx.uses.push_back(0);
2843          }
2844          new_instr->operands[0] = Operand::zero();
2845          new_instr->operands[1] = instr->operands[!i];
2846          new_instr->operands[2] = Operand(ctx.info[instr->operands[i].tempId()].temp);
2847          instr = std::move(new_instr);
2848          ctx.info[instr->definitions[0].tempId()].set_add_sub(instr.get());
2849          return true;
2850       }
2851    }
2852 
2853    return false;
2854 }
2855 
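/* v_add_u32(v_bcnt_u32_b32(a, 0), b) -> v_bcnt_u32_b32(a, b): the second operand of
 * v_bcnt_u32_b32 is an accumulator that is added to the popcount result. */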
2856 bool
2857 combine_add_bcnt(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2858 {
2859    if (instr->usesModifiers())
2860       return false;
2861 
2862    for (unsigned i = 0; i < 2; i++) {
2863       Instruction* op_instr = follow_operand(ctx, instr->operands[i]);
2864       if (op_instr && op_instr->opcode == aco_opcode::v_bcnt_u32_b32 &&
2865           !op_instr->usesModifiers() && op_instr->operands[0].isTemp() &&
2866           op_instr->operands[0].getTemp().type() == RegType::vgpr &&
2867           op_instr->operands[1].constantEquals(0)) {
2868          aco_ptr<Instruction> new_instr{
2869             create_instruction<VOP3_instruction>(aco_opcode::v_bcnt_u32_b32, Format::VOP3, 2, 1)};
2870          ctx.uses[instr->operands[i].tempId()]--;
2871          new_instr->operands[0] = op_instr->operands[0];
2872          new_instr->operands[1] = instr->operands[!i];
2873          new_instr->definitions[0] = instr->definitions[0];
2874          instr = std::move(new_instr);
2875          ctx.info[instr->definitions[0].tempId()].label = 0;
2876 
2877          return true;
2878       }
2879    }
2880 
2881    return false;
2882 }
2883 
2884 bool
2885 get_minmax_info(aco_opcode op, aco_opcode* min, aco_opcode* max, aco_opcode* min3, aco_opcode* max3,
2886                 aco_opcode* med3, bool* some_gfx9_only)
2887 {
2888    switch (op) {
2889 #define MINMAX(type, gfx9)                                                                         \
2890    case aco_opcode::v_min_##type:                                                                  \
2891    case aco_opcode::v_max_##type:                                                                  \
2892       *min = aco_opcode::v_min_##type;                                                             \
2893       *max = aco_opcode::v_max_##type;                                                             \
2894       *med3 = aco_opcode::v_med3_##type;                                                           \
2895       *min3 = aco_opcode::v_min3_##type;                                                           \
2896       *max3 = aco_opcode::v_max3_##type;                                                           \
2897       *some_gfx9_only = gfx9;                                                                      \
2898       return true;
2899 #define MINMAX_E64(type, gfx9)                                                                     \
2900    case aco_opcode::v_min_##type##_e64:                                                            \
2901    case aco_opcode::v_max_##type##_e64:                                                            \
2902       *min = aco_opcode::v_min_##type##_e64;                                                       \
2903       *max = aco_opcode::v_max_##type##_e64;                                                       \
2904       *med3 = aco_opcode::v_med3_##type;                                                           \
2905       *min3 = aco_opcode::v_min3_##type;                                                           \
2906       *max3 = aco_opcode::v_max3_##type;                                                           \
2907       *some_gfx9_only = gfx9;                                                                      \
2908       return true;
2909       MINMAX(f32, false)
2910       MINMAX(u32, false)
2911       MINMAX(i32, false)
2912       MINMAX(f16, true)
2913       MINMAX(u16, true)
2914       MINMAX(i16, true)
2915       MINMAX_E64(u16, true)
2916       MINMAX_E64(i16, true)
2917 #undef MINMAX_E64
2918 #undef MINMAX
2919    default: return false;
2920    }
2921 }
2922 
2923 /* when ub > lb:
2924  * v_min_{f,u,i}{16,32}(v_max_{f,u,i}{16,32}(a, lb), ub) -> v_med3_{f,u,i}{16,32}(a, lb, ub)
2925  * v_max_{f,u,i}{16,32}(v_min_{f,u,i}{16,32}(a, ub), lb) -> v_med3_{f,u,i}{16,32}(a, lb, ub)
2926  */
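/* For example, clamp(x, -1.0, 1.0) lowered as v_min_f32(v_max_f32(x, -1.0), 1.0)
 * becomes v_med3_f32(x, -1.0, 1.0) when both bounds are known constants. */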
2927 bool
2928 combine_clamp(opt_ctx& ctx, aco_ptr<Instruction>& instr, aco_opcode min, aco_opcode max,
2929               aco_opcode med)
2930 {
2931    /* TODO: GLSL's clamp(x, minVal, maxVal) and SPIR-V's
2932     * FClamp(x, minVal, maxVal)/NClamp(x, minVal, maxVal) are undefined if
2933     * minVal > maxVal, which means we can always select it to a v_med3_f32 */
2934    aco_opcode other_op;
2935    if (instr->opcode == min)
2936       other_op = max;
2937    else if (instr->opcode == max)
2938       other_op = min;
2939    else
2940       return false;
2941 
2942    for (unsigned swap = 0; swap < 2; swap++) {
2943       Operand operands[3];
2944       bool neg[3], abs[3], clamp, precise;
2945       uint8_t opsel = 0, omod = 0;
2946       if (match_op3_for_vop3(ctx, instr->opcode, other_op, instr.get(), swap, "012", operands, neg,
2947                              abs, &opsel, &clamp, &omod, NULL, NULL, NULL, &precise)) {
2948          /* max(min(src, upper), lower) returns upper if src is NaN, but
2949           * med3(src, lower, upper) returns lower.
2950           */
2951          if (precise && instr->opcode != min &&
2952              (min == aco_opcode::v_min_f16 || min == aco_opcode::v_min_f32))
2953             continue;
2954 
2955          int const0_idx = -1, const1_idx = -1;
2956          uint32_t const0 = 0, const1 = 0;
2957          for (int i = 0; i < 3; i++) {
2958             uint32_t val;
2959             bool hi16 = opsel & (1 << i);
2960             if (operands[i].isConstant()) {
2961                val = hi16 ? operands[i].constantValue16(true) : operands[i].constantValue();
2962             } else if (operands[i].isTemp() &&
2963                        ctx.info[operands[i].tempId()].is_constant_or_literal(32)) {
2964                val = ctx.info[operands[i].tempId()].val >> (hi16 ? 16 : 0);
2965             } else {
2966                continue;
2967             }
2968             if (const0_idx >= 0) {
2969                const1_idx = i;
2970                const1 = val;
2971             } else {
2972                const0_idx = i;
2973                const0 = val;
2974             }
2975          }
2976          if (const0_idx < 0 || const1_idx < 0)
2977             continue;
2978 
2979          int lower_idx = const0_idx;
2980          switch (min) {
2981          case aco_opcode::v_min_f32:
2982          case aco_opcode::v_min_f16: {
2983             float const0_f, const1_f;
2984             if (min == aco_opcode::v_min_f32) {
2985                memcpy(&const0_f, &const0, 4);
2986                memcpy(&const1_f, &const1, 4);
2987             } else {
2988                const0_f = _mesa_half_to_float(const0);
2989                const1_f = _mesa_half_to_float(const1);
2990             }
2991             if (abs[const0_idx])
2992                const0_f = fabsf(const0_f);
2993             if (abs[const1_idx])
2994                const1_f = fabsf(const1_f);
2995             if (neg[const0_idx])
2996                const0_f = -const0_f;
2997             if (neg[const1_idx])
2998                const1_f = -const1_f;
2999             lower_idx = const0_f < const1_f ? const0_idx : const1_idx;
3000             break;
3001          }
3002          case aco_opcode::v_min_u32: {
3003             lower_idx = const0 < const1 ? const0_idx : const1_idx;
3004             break;
3005          }
3006          case aco_opcode::v_min_u16:
3007          case aco_opcode::v_min_u16_e64: {
3008             lower_idx = (uint16_t)const0 < (uint16_t)const1 ? const0_idx : const1_idx;
3009             break;
3010          }
3011          case aco_opcode::v_min_i32: {
3012             int32_t const0_i =
3013                const0 & 0x80000000u ? -2147483648 + (int32_t)(const0 & 0x7fffffffu) : const0;
3014             int32_t const1_i =
3015                const1 & 0x80000000u ? -2147483648 + (int32_t)(const1 & 0x7fffffffu) : const1;
3016             lower_idx = const0_i < const1_i ? const0_idx : const1_idx;
3017             break;
3018          }
3019          case aco_opcode::v_min_i16:
3020          case aco_opcode::v_min_i16_e64: {
3021             int16_t const0_i = const0 & 0x8000u ? -32768 + (int16_t)(const0 & 0x7fffu) : const0;
3022             int16_t const1_i = const1 & 0x8000u ? -32768 + (int16_t)(const1 & 0x7fffu) : const1;
3023             lower_idx = const0_i < const1_i ? const0_idx : const1_idx;
3024             break;
3025          }
3026          default: break;
3027          }
3028          int upper_idx = lower_idx == const0_idx ? const1_idx : const0_idx;
3029 
3030          if (instr->opcode == min) {
3031             if (upper_idx != 0 || lower_idx == 0)
3032                return false;
3033          } else {
3034             if (upper_idx == 0 || lower_idx != 0)
3035                return false;
3036          }
3037 
3038          ctx.uses[instr->operands[swap].tempId()]--;
3039          create_vop3_for_op3(ctx, med, instr, operands, neg, abs, opsel, clamp, omod);
3040 
3041          return true;
3042       }
3043    }
3044 
3045    return false;
3046 }
3047 
3048 void
3049 apply_sgprs(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3050 {
3051    bool is_shift64 = instr->opcode == aco_opcode::v_lshlrev_b64 ||
3052                      instr->opcode == aco_opcode::v_lshrrev_b64 ||
3053                      instr->opcode == aco_opcode::v_ashrrev_i64;
3054 
3055    /* find candidates and create the set of sgprs already read */
3056    unsigned sgpr_ids[2] = {0, 0};
3057    uint32_t operand_mask = 0;
3058    bool has_literal = false;
3059    for (unsigned i = 0; i < instr->operands.size(); i++) {
3060       if (instr->operands[i].isLiteral())
3061          has_literal = true;
3062       if (!instr->operands[i].isTemp())
3063          continue;
3064       if (instr->operands[i].getTemp().type() == RegType::sgpr) {
3065          if (instr->operands[i].tempId() != sgpr_ids[0])
3066             sgpr_ids[!!sgpr_ids[0]] = instr->operands[i].tempId();
3067       }
3068       ssa_info& info = ctx.info[instr->operands[i].tempId()];
3069       if (is_copy_label(ctx, instr, info) && info.temp.type() == RegType::sgpr)
3070          operand_mask |= 1u << i;
3071       if (info.is_extract() && info.instr->operands[0].getTemp().type() == RegType::sgpr)
3072          operand_mask |= 1u << i;
3073    }
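   /* VALU constant-bus limits: one SGPR/constant read before GFX10, two from GFX10
    * onwards (except for 64-bit shifts), and a literal occupies one of these slots. */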
3074    unsigned max_sgprs = 1;
3075    if (ctx.program->gfx_level >= GFX10 && !is_shift64)
3076       max_sgprs = 2;
3077    if (has_literal)
3078       max_sgprs--;
3079 
3080    unsigned num_sgprs = !!sgpr_ids[0] + !!sgpr_ids[1];
3081 
3082    /* keep on applying sgprs until there is nothing left to be done */
3083    while (operand_mask) {
3084       uint32_t sgpr_idx = 0;
3085       uint32_t sgpr_info_id = 0;
3086       uint32_t mask = operand_mask;
3087       /* choose a sgpr */
3088       while (mask) {
3089          unsigned i = u_bit_scan(&mask);
3090          uint16_t uses = ctx.uses[instr->operands[i].tempId()];
3091          if (sgpr_info_id == 0 || uses < ctx.uses[sgpr_info_id]) {
3092             sgpr_idx = i;
3093             sgpr_info_id = instr->operands[i].tempId();
3094          }
3095       }
3096       operand_mask &= ~(1u << sgpr_idx);
3097 
3098       ssa_info& info = ctx.info[sgpr_info_id];
3099 
3100       /* Applying two sgprs requires making it VOP3, so don't do it unless it's
3101        * definitively beneficial.
3102        * TODO: this is too conservative because later the use count could be reduced to 1 */
3103       if (!info.is_extract() && num_sgprs && ctx.uses[sgpr_info_id] > 1 && !instr->isVOP3() &&
3104           !instr->isSDWA() && instr->format != Format::VOP3P)
3105          break;
3106 
3107       Temp sgpr = info.is_extract() ? info.instr->operands[0].getTemp() : info.temp;
3108       bool new_sgpr = sgpr.id() != sgpr_ids[0] && sgpr.id() != sgpr_ids[1];
3109       if (new_sgpr && num_sgprs >= max_sgprs)
3110          continue;
3111 
3112       if (sgpr_idx == 0)
3113          instr->format = withoutDPP(instr->format);
3114 
3115       if (sgpr_idx == 0 || instr->isVOP3() || instr->isSDWA() || instr->isVOP3P() ||
3116           info.is_extract()) {
3117          /* can_apply_extract() checks SGPR encoding restrictions */
3118          if (info.is_extract() && can_apply_extract(ctx, instr, sgpr_idx, info))
3119             apply_extract(ctx, instr, sgpr_idx, info);
3120          else if (info.is_extract())
3121             continue;
3122          instr->operands[sgpr_idx] = Operand(sgpr);
3123       } else if (can_swap_operands(instr, &instr->opcode)) {
3124          instr->operands[sgpr_idx] = instr->operands[0];
3125          instr->operands[0] = Operand(sgpr);
3126          /* swap bits using a 4-entry LUT */
3127          uint32_t swapped = (0x3120 >> (operand_mask & 0x3)) & 0xf;
3128          operand_mask = (operand_mask & ~0x3) | swapped;
3129       } else if (can_use_VOP3(ctx, instr) && !info.is_extract()) {
3130          to_VOP3(ctx, instr);
3131          instr->operands[sgpr_idx] = Operand(sgpr);
3132       } else {
3133          continue;
3134       }
3135 
3136       if (new_sgpr)
3137          sgpr_ids[num_sgprs++] = sgpr.id();
3138       ctx.uses[sgpr_info_id]--;
3139       ctx.uses[sgpr.id()]++;
3140 
3141       /* TODO: handle when it's a VGPR */
3142       if ((ctx.info[sgpr.id()].label & (label_extract | label_temp)) &&
3143           ctx.info[sgpr.id()].temp.type() == RegType::sgpr)
3144          operand_mask |= 1u << sgpr_idx;
3145    }
3146 }
3147 
3148 template <typename T>
3149 bool
3150 apply_omod_clamp_helper(opt_ctx& ctx, T* instr, ssa_info& def_info)
3151 {
3152    if (!def_info.is_clamp() && (instr->clamp || instr->omod))
3153       return false;
3154 
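   /* hardware output-modifier encoding: 1 = multiply by 2, 2 = multiply by 4,
    * 3 = divide by 2 */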
3155    if (def_info.is_omod2())
3156       instr->omod = 1;
3157    else if (def_info.is_omod4())
3158       instr->omod = 2;
3159    else if (def_info.is_omod5())
3160       instr->omod = 3;
3161    else if (def_info.is_clamp())
3162       instr->clamp = true;
3163 
3164    return true;
3165 }
3166 
3167 /* apply omod / clamp modifiers if the def is used only once and the instruction can have modifiers */
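/* e.g. if t = v_add_f32(a, b) is only used by v_mul_f32(2.0, t) (labelled omod2), set the
 * *2 output modifier on the add and take over the multiply's result, leaving the multiply dead. */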
3168 bool
3169 apply_omod_clamp(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3170 {
3171    if (instr->definitions.empty() || ctx.uses[instr->definitions[0].tempId()] != 1 ||
3172        !instr_info.can_use_output_modifiers[(int)instr->opcode])
3173       return false;
3174 
3175    bool can_vop3 = can_use_VOP3(ctx, instr);
3176    bool is_mad_mix =
3177       instr->opcode == aco_opcode::v_fma_mix_f32 || instr->opcode == aco_opcode::v_fma_mixlo_f16;
3178    if (!instr->isSDWA() && !is_mad_mix && !can_vop3)
3179       return false;
3180 
3181    /* omod flushes -0 to +0 and has no effect if denormals are enabled. SDWA omod is GFX9+. */
3182    bool can_use_omod = (can_vop3 || ctx.program->gfx_level >= GFX9) && !instr->isVOP3P();
3183    if (instr->definitions[0].bytes() == 4)
3184       can_use_omod =
3185          can_use_omod && ctx.fp_mode.denorm32 == 0 && !ctx.fp_mode.preserve_signed_zero_inf_nan32;
3186    else
3187       can_use_omod = can_use_omod && ctx.fp_mode.denorm16_64 == 0 &&
3188                      !ctx.fp_mode.preserve_signed_zero_inf_nan16_64;
3189 
3190    ssa_info& def_info = ctx.info[instr->definitions[0].tempId()];
3191 
3192    uint64_t omod_labels = label_omod2 | label_omod4 | label_omod5;
3193    if (!def_info.is_clamp() && !(can_use_omod && (def_info.label & omod_labels)))
3194       return false;
3195    /* if the omod/clamp instruction is dead, then the single user of this
3196     * instruction is a different instruction */
3197    if (!ctx.uses[def_info.instr->definitions[0].tempId()])
3198       return false;
3199 
3200    if (def_info.instr->definitions[0].bytes() != instr->definitions[0].bytes())
3201       return false;
3202 
3203    /* MADs/FMAs are created later, so we don't have to update the original add */
3204    assert(!ctx.info[instr->definitions[0].tempId()].is_mad());
3205 
3206    if (instr->isSDWA()) {
3207       if (!apply_omod_clamp_helper(ctx, &instr->sdwa(), def_info))
3208          return false;
3209    } else if (instr->isVOP3P()) {
3210       assert(def_info.is_clamp());
3211       instr->vop3p().clamp = true;
3212    } else {
3213       to_VOP3(ctx, instr);
3214       if (!apply_omod_clamp_helper(ctx, &instr->vop3(), def_info))
3215          return false;
3216    }
3217 
3218    instr->definitions[0].swapTemp(def_info.instr->definitions[0]);
3219    ctx.info[instr->definitions[0].tempId()].label &= label_clamp | label_insert | label_f2f16;
3220    ctx.uses[def_info.instr->definitions[0].tempId()]--;
3221 
3222    return true;
3223 }
3224 
3225 /* Combine a p_insert (or p_extract, in some cases) instruction with instr.
3226  * p_insert(instr(...)) -> instr_insert().
3227  */
3228 bool
3229 apply_insert(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3230 {
3231    if (instr->definitions.empty() || ctx.uses[instr->definitions[0].tempId()] != 1)
3232       return false;
3233 
3234    ssa_info& def_info = ctx.info[instr->definitions[0].tempId()];
3235    if (!def_info.is_insert())
3236       return false;
3237    /* if the insert instruction is dead, then the single user of this
3238     * instruction is a different instruction */
3239    if (!ctx.uses[def_info.instr->definitions[0].tempId()])
3240       return false;
3241 
3242    /* MADs/FMAs are created later, so we don't have to update the original add */
3243    assert(!ctx.info[instr->definitions[0].tempId()].is_mad());
3244 
3245    SubdwordSel sel = parse_insert(def_info.instr);
3246    assert(sel);
3247 
3248    if (!can_use_SDWA(ctx.program->gfx_level, instr, true))
3249       return false;
3250 
3251    to_SDWA(ctx, instr);
3252    if (instr->sdwa().dst_sel.size() != 4)
3253       return false;
3254    static_cast<SDWA_instruction*>(instr.get())->dst_sel = sel;
3255 
3256    instr->definitions[0].swapTemp(def_info.instr->definitions[0]);
3257    ctx.info[instr->definitions[0].tempId()].label = 0;
3258    ctx.uses[def_info.instr->definitions[0].tempId()]--;
3259 
3260    return true;
3261 }
3262 
3263 /* Remove superfluous extract after ds_read like so:
3264  * p_extract(ds_read_uN(), 0, N, 0) -> ds_read_uN()
3265  */
3266 bool
3267 apply_ds_extract(opt_ctx& ctx, aco_ptr<Instruction>& extract)
3268 {
3269    /* Check if p_extract has a usedef operand and is the only user. */
3270    if (!ctx.info[extract->operands[0].tempId()].is_usedef() ||
3271        ctx.uses[extract->operands[0].tempId()] > 1)
3272       return false;
3273 
3274    /* Check if the usedef is a DS instruction. */
3275    Instruction* ds = ctx.info[extract->operands[0].tempId()].instr;
3276    if (ds->format != Format::DS)
3277       return false;
3278 
3279    unsigned extract_idx = extract->operands[1].constantValue();
3280    unsigned bits_extracted = extract->operands[2].constantValue();
3281    unsigned sign_ext = extract->operands[3].constantValue();
3282    unsigned dst_bitsize = extract->definitions[0].bytes() * 8u;
3283 
3284    /* TODO: These are doable, but probably don't occur too often. */
3285    if (extract_idx || sign_ext || dst_bitsize != 32)
3286       return false;
3287 
3288    unsigned bits_loaded = 0;
3289    if (ds->opcode == aco_opcode::ds_read_u8 || ds->opcode == aco_opcode::ds_read_u8_d16)
3290       bits_loaded = 8;
3291    else if (ds->opcode == aco_opcode::ds_read_u16 || ds->opcode == aco_opcode::ds_read_u16_d16)
3292       bits_loaded = 16;
3293    else
3294       return false;
3295 
3296    /* Shrink the DS load if the extracted bit size is smaller. */
3297    bits_loaded = MIN2(bits_loaded, bits_extracted);
3298 
3299    /* Change the DS opcode so it writes the full register. */
3300    if (bits_loaded == 8)
3301       ds->opcode = aco_opcode::ds_read_u8;
3302    else if (bits_loaded == 16)
3303       ds->opcode = aco_opcode::ds_read_u16;
3304    else
3305       unreachable("Forgot to add DS opcode above.");
3306 
3307    /* The DS now produces the exact same thing as the extract, remove the extract. */
3308    std::swap(ds->definitions[0], extract->definitions[0]);
3309    ctx.uses[extract->definitions[0].tempId()] = 0;
3310    ctx.info[ds->definitions[0].tempId()].label = 0;
3311    return true;
3312 }
3313 
3314 /* v_and(a, v_subbrev_co(0, 0, vcc)) -> v_cndmask(0, a, vcc) */
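/* v_subbrev_co_u32(0, 0, vcc) computes 0 - 0 - vcc_lane, i.e. all-ones in lanes where the
 * condition is set and 0 elsewhere, so the AND is really a per-lane select. */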
3315 bool
3316 combine_and_subbrev(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3317 {
3318    if (instr->usesModifiers())
3319       return false;
3320 
3321    for (unsigned i = 0; i < 2; i++) {
3322       Instruction* op_instr = follow_operand(ctx, instr->operands[i], true);
3323       if (op_instr && op_instr->opcode == aco_opcode::v_subbrev_co_u32 &&
3324           op_instr->operands[0].constantEquals(0) && op_instr->operands[1].constantEquals(0) &&
3325           !op_instr->usesModifiers()) {
3326 
3327          aco_ptr<Instruction> new_instr;
3328          if (instr->operands[!i].isTemp() &&
3329              instr->operands[!i].getTemp().type() == RegType::vgpr) {
3330             new_instr.reset(
3331                create_instruction<VOP2_instruction>(aco_opcode::v_cndmask_b32, Format::VOP2, 3, 1));
3332          } else if (ctx.program->gfx_level >= GFX10 ||
3333                     (instr->operands[!i].isConstant() && !instr->operands[!i].isLiteral())) {
3334             new_instr.reset(create_instruction<VOP3_instruction>(aco_opcode::v_cndmask_b32,
3335                                                                  asVOP3(Format::VOP2), 3, 1));
3336          } else {
3337             return false;
3338          }
3339 
3340          ctx.uses[instr->operands[i].tempId()]--;
3341          if (ctx.uses[instr->operands[i].tempId()])
3342             ctx.uses[op_instr->operands[2].tempId()]++;
3343 
3344          new_instr->operands[0] = Operand::zero();
3345          new_instr->operands[1] = instr->operands[!i];
3346          new_instr->operands[2] = Operand(op_instr->operands[2]);
3347          new_instr->definitions[0] = instr->definitions[0];
3348          instr = std::move(new_instr);
3349          ctx.info[instr->definitions[0].tempId()].label = 0;
3350          return true;
3351       }
3352    }
3353 
3354    return false;
3355 }
3356 
3357 /* v_add_co(c, s_lshl(a, b)) -> v_mad_u32_u24(a, 1<<b, c)
3358  * v_add_co(c, v_lshlrev(a, b)) -> v_mad_u32_u24(b, 1<<a, c)
3359  * v_sub(c, s_lshl(a, b)) -> v_mad_i32_i24(a, -(1<<b), c)
3360  * v_sub(c, v_lshlrev(a, b)) -> v_mad_i32_i24(b, -(1<<a), c)
3361  */
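/* v_mad_u32_u24/v_mad_i32_i24 only multiply 24-bit sources, so the shifted value must be
 * known to fit (is24bit/is16bit checks below) and the constant multiplier 1<<b must itself
 * fit in 24 bits. */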
3362 bool
3363 combine_add_lshl(opt_ctx& ctx, aco_ptr<Instruction>& instr, bool is_sub)
3364 {
3365    if (instr->usesModifiers())
3366       return false;
3367 
3368    /* Subtractions: start at operand 1 to avoid mix-ups such as
3369     * turning v_sub(v_lshlrev(a, b), c) into v_mad_i32_i24(b, -(1<<a), c)
3370     */
3371    unsigned start_op_idx = is_sub ? 1 : 0;
3372 
3373    /* Don't allow 24-bit operands on subtraction because
3374     * v_mad_i32_i24 applies a sign extension.
3375     */
3376    bool allow_24bit = !is_sub;
3377 
3378    for (unsigned i = start_op_idx; i < 2; i++) {
3379       Instruction* op_instr = follow_operand(ctx, instr->operands[i]);
3380       if (!op_instr)
3381          continue;
3382 
3383       if (op_instr->opcode != aco_opcode::s_lshl_b32 &&
3384           op_instr->opcode != aco_opcode::v_lshlrev_b32)
3385          continue;
3386 
3387       int shift_op_idx = op_instr->opcode == aco_opcode::s_lshl_b32 ? 1 : 0;
3388 
3389       if (op_instr->operands[shift_op_idx].isConstant() &&
3390           ((allow_24bit && op_instr->operands[!shift_op_idx].is24bit()) ||
3391            op_instr->operands[!shift_op_idx].is16bit())) {
3392          uint32_t multiplier = 1 << (op_instr->operands[shift_op_idx].constantValue() % 32u);
3393          if (is_sub)
3394             multiplier = -multiplier;
3395          if (is_sub ? (multiplier < 0xff800000) : (multiplier > 0xffffff))
3396             continue;
3397 
3398          Operand ops[3] = {
3399             op_instr->operands[!shift_op_idx],
3400             Operand::c32(multiplier),
3401             instr->operands[!i],
3402          };
3403          if (!check_vop3_operands(ctx, 3, ops))
3404             return false;
3405 
3406          ctx.uses[instr->operands[i].tempId()]--;
3407 
3408          aco_opcode mad_op = is_sub ? aco_opcode::v_mad_i32_i24 : aco_opcode::v_mad_u32_u24;
3409          aco_ptr<VOP3_instruction> new_instr{
3410             create_instruction<VOP3_instruction>(mad_op, Format::VOP3, 3, 1)};
3411          for (unsigned op_idx = 0; op_idx < 3; ++op_idx)
3412             new_instr->operands[op_idx] = ops[op_idx];
3413          new_instr->definitions[0] = instr->definitions[0];
3414          instr = std::move(new_instr);
3415          ctx.info[instr->definitions[0].tempId()].label = 0;
3416          return true;
3417       }
3418    }
3419 
3420    return false;
3421 }
3422 
3423 void
3424 propagate_swizzles(VOP3P_instruction* instr, uint8_t opsel_lo, uint8_t opsel_hi)
3425 {
3426    /* propagate swizzles which apply to a result down to the instruction's operands:
3427     * result = a.xy + b.xx -> result.yx = a.yx + b.xx */
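   /* opsel_lo/opsel_hi tell us which half of this instruction's result the user reads for
    * its low/high half; copy the corresponding per-operand opsel/neg controls into place. */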
3428    assert((opsel_lo & 1) == opsel_lo);
3429    assert((opsel_hi & 1) == opsel_hi);
3430    uint8_t tmp_lo = instr->opsel_lo;
3431    uint8_t tmp_hi = instr->opsel_hi;
3432    bool neg_lo[3] = {instr->neg_lo[0], instr->neg_lo[1], instr->neg_lo[2]};
3433    bool neg_hi[3] = {instr->neg_hi[0], instr->neg_hi[1], instr->neg_hi[2]};
3434    if (opsel_lo == 1) {
3435       instr->opsel_lo = tmp_hi;
3436       for (unsigned i = 0; i < 3; i++)
3437          instr->neg_lo[i] = neg_hi[i];
3438    }
3439    if (opsel_hi == 0) {
3440       instr->opsel_hi = tmp_lo;
3441       for (unsigned i = 0; i < 3; i++)
3442          instr->neg_hi[i] = neg_lo[i];
3443    }
3444 }
3445 
3446 void
3447 combine_vop3p(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3448 {
3449    VOP3P_instruction* vop3p = &instr->vop3p();
3450 
3451    /* apply clamp */
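   /* v_pk_mul_f16(x, 1.0) (0x3C00 is 1.0 in fp16) with clamp set is a pure packed clamp,
    * so try to fold the clamp into the VOP3P instruction producing x instead. */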
3452    if (instr->opcode == aco_opcode::v_pk_mul_f16 && instr->operands[1].constantEquals(0x3C00) &&
3453        vop3p->clamp && instr->operands[0].isTemp() && ctx.uses[instr->operands[0].tempId()] == 1 &&
3454        !((vop3p->opsel_lo | vop3p->opsel_hi) & 2)) {
3455 
3456       ssa_info& info = ctx.info[instr->operands[0].tempId()];
3457       if (info.is_vop3p() && instr_info.can_use_output_modifiers[(int)info.instr->opcode]) {
3458          VOP3P_instruction* candidate = &ctx.info[instr->operands[0].tempId()].instr->vop3p();
3459          candidate->clamp = true;
3460          propagate_swizzles(candidate, vop3p->opsel_lo, vop3p->opsel_hi);
3461          instr->definitions[0].swapTemp(candidate->definitions[0]);
3462          ctx.info[candidate->definitions[0].tempId()].instr = candidate;
3463          ctx.uses[instr->definitions[0].tempId()]--;
3464          return;
3465       }
3466    }
3467 
3468    /* check for fneg modifiers */
3469    if (instr_info.can_use_input_modifiers[(int)instr->opcode]) {
3470       for (unsigned i = 0; i < instr->operands.size(); i++) {
3471          Operand& op = instr->operands[i];
3472          if (!op.isTemp())
3473             continue;
3474 
3475          ssa_info& info = ctx.info[op.tempId()];
3476          if (info.is_vop3p() && info.instr->opcode == aco_opcode::v_pk_mul_f16 &&
3477              info.instr->operands[1].constantEquals(0x3C00)) {
3478 
3479             VOP3P_instruction* fneg = &info.instr->vop3p();
3480 
3481             if ((fneg->opsel_lo | fneg->opsel_hi) & 2)
3482                continue;
3483 
3484             Operand ops[3];
3485             for (unsigned j = 0; j < instr->operands.size(); j++)
3486                ops[j] = instr->operands[j];
3487             ops[i] = info.instr->operands[0];
3488             if (!check_vop3_operands(ctx, instr->operands.size(), ops))
3489                continue;
3490 
3491             if (fneg->clamp)
3492                continue;
3493             instr->operands[i] = fneg->operands[0];
3494 
3495             /* opsel_lo/hi is either 0 or 1:
3496              * if 0 - pick selection from fneg->lo
3497              * if 1 - pick selection from fneg->hi
3498              */
3499             bool opsel_lo = (vop3p->opsel_lo >> i) & 1;
3500             bool opsel_hi = (vop3p->opsel_hi >> i) & 1;
3501             bool neg_lo = fneg->neg_lo[0] ^ fneg->neg_lo[1];
3502             bool neg_hi = fneg->neg_hi[0] ^ fneg->neg_hi[1];
3503             vop3p->neg_lo[i] ^= opsel_lo ? neg_hi : neg_lo;
3504             vop3p->neg_hi[i] ^= opsel_hi ? neg_hi : neg_lo;
3505             vop3p->opsel_lo ^= ((opsel_lo ? ~fneg->opsel_hi : fneg->opsel_lo) & 1) << i;
3506             vop3p->opsel_hi ^= ((opsel_hi ? ~fneg->opsel_hi : fneg->opsel_lo) & 1) << i;
3507 
3508             if (--ctx.uses[fneg->definitions[0].tempId()])
3509                ctx.uses[fneg->operands[0].tempId()]++;
3510          }
3511       }
3512    }
3513 
3514    if (instr->opcode == aco_opcode::v_pk_add_f16 || instr->opcode == aco_opcode::v_pk_add_u16) {
3515       bool fadd = instr->opcode == aco_opcode::v_pk_add_f16;
3516       if (fadd && instr->definitions[0].isPrecise())
3517          return;
3518 
3519       Instruction* mul_instr = nullptr;
3520       unsigned add_op_idx = 0;
3521       uint8_t opsel_lo = 0, opsel_hi = 0;
3522       uint32_t uses = UINT32_MAX;
3523 
3524       /* find the 'best' mul instruction to combine with the add */
3525       for (unsigned i = 0; i < 2; i++) {
3526          if (!instr->operands[i].isTemp() || !ctx.info[instr->operands[i].tempId()].is_vop3p())
3527             continue;
3528          ssa_info& info = ctx.info[instr->operands[i].tempId()];
3529          if (fadd) {
3530             if (info.instr->opcode != aco_opcode::v_pk_mul_f16 ||
3531                 info.instr->definitions[0].isPrecise())
3532                continue;
3533          } else {
3534             if (info.instr->opcode != aco_opcode::v_pk_mul_lo_u16)
3535                continue;
3536          }
3537 
3538          Operand op[3] = {info.instr->operands[0], info.instr->operands[1], instr->operands[1 - i]};
3539          if (ctx.uses[instr->operands[i].tempId()] >= uses || !check_vop3_operands(ctx, 3, op))
3540             continue;
3541 
3542          /* no clamp allowed between mul and add */
3543          if (info.instr->vop3p().clamp)
3544             continue;
3545 
3546          mul_instr = info.instr;
3547          add_op_idx = 1 - i;
3548          opsel_lo = (vop3p->opsel_lo >> i) & 1;
3549          opsel_hi = (vop3p->opsel_hi >> i) & 1;
3550          uses = ctx.uses[instr->operands[i].tempId()];
3551       }
3552 
3553       if (!mul_instr)
3554          return;
3555 
3556       /* convert to mad */
3557       Operand op[3] = {mul_instr->operands[0], mul_instr->operands[1], instr->operands[add_op_idx]};
3558       ctx.uses[mul_instr->definitions[0].tempId()]--;
3559       if (ctx.uses[mul_instr->definitions[0].tempId()]) {
3560          if (op[0].isTemp())
3561             ctx.uses[op[0].tempId()]++;
3562          if (op[1].isTemp())
3563             ctx.uses[op[1].tempId()]++;
3564       }
3565 
3566       /* turn packed mul+add into v_pk_fma_f16 */
3567       assert(mul_instr->isVOP3P());
3568       aco_opcode mad = fadd ? aco_opcode::v_pk_fma_f16 : aco_opcode::v_pk_mad_u16;
3569       aco_ptr<VOP3P_instruction> fma{
3570          create_instruction<VOP3P_instruction>(mad, Format::VOP3P, 3, 1)};
3571       VOP3P_instruction* mul = &mul_instr->vop3p();
3572       for (unsigned i = 0; i < 2; i++) {
3573          fma->operands[i] = op[i];
3574          fma->neg_lo[i] = mul->neg_lo[i];
3575          fma->neg_hi[i] = mul->neg_hi[i];
3576       }
3577       fma->operands[2] = op[2];
3578       fma->clamp = vop3p->clamp;
3579       fma->opsel_lo = mul->opsel_lo;
3580       fma->opsel_hi = mul->opsel_hi;
3581       propagate_swizzles(fma.get(), opsel_lo, opsel_hi);
3582       fma->opsel_lo |= (vop3p->opsel_lo << (2 - add_op_idx)) & 0x4;
3583       fma->opsel_hi |= (vop3p->opsel_hi << (2 - add_op_idx)) & 0x4;
3584       fma->neg_lo[2] = vop3p->neg_lo[add_op_idx];
3585       fma->neg_hi[2] = vop3p->neg_hi[add_op_idx];
3586       fma->neg_lo[1] = fma->neg_lo[1] ^ vop3p->neg_lo[1 - add_op_idx];
3587       fma->neg_hi[1] = fma->neg_hi[1] ^ vop3p->neg_hi[1 - add_op_idx];
3588       fma->definitions[0] = instr->definitions[0];
3589       instr = std::move(fma);
3590       ctx.info[instr->definitions[0].tempId()].set_vop3p(instr.get());
3591       return;
3592    }
3593 }
3594 
3595 bool
3596 can_use_mad_mix(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3597 {
3598    if (ctx.program->gfx_level < GFX9)
3599       return false;
3600 
3601    /* v_mad_mix* on GFX9 always flushes denormals for 16-bit inputs/outputs */
3602    if (ctx.program->gfx_level == GFX9 && ctx.fp_mode.denorm16_64)
3603       return false;
3604 
3605    switch (instr->opcode) {
3606    case aco_opcode::v_add_f32:
3607    case aco_opcode::v_sub_f32:
3608    case aco_opcode::v_subrev_f32:
3609    case aco_opcode::v_mul_f32:
3610    case aco_opcode::v_fma_f32: break;
3611    case aco_opcode::v_fma_mix_f32:
3612    case aco_opcode::v_fma_mixlo_f16: return true;
3613    default: return false;
3614    }
3615 
3616    if (instr->opcode == aco_opcode::v_fma_f32 && !ctx.program->dev.fused_mad_mix &&
3617        instr->definitions[0].isPrecise())
3618       return false;
3619 
3620    if (instr->isVOP3())
3621       return !instr->vop3().omod && !(instr->vop3().opsel & 0x8);
3622 
3623    return instr->format == Format::VOP2;
3624 }
3625 
3626 void
3627 to_mad_mix(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3628 {
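   /* Express the operation as a single mixed-precision FMA:
    *   a * b -> v_fma_mix_f32(a, b, -0.0)
    *   a + b -> v_fma_mix_f32(1.0, a, b)
    * so for adds the original operands move up by one slot (is_add below). */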
3629    bool is_add = instr->opcode != aco_opcode::v_mul_f32 && instr->opcode != aco_opcode::v_fma_f32;
3630 
3631    aco_ptr<VOP3P_instruction> vop3p{
3632       create_instruction<VOP3P_instruction>(aco_opcode::v_fma_mix_f32, Format::VOP3P, 3, 1)};
3633 
3634    vop3p->opsel_lo = instr->isVOP3() ? ((instr->vop3().opsel & 0x7) << (is_add ? 1 : 0)) : 0x0;
3635    vop3p->opsel_hi = 0x0;
3636    for (unsigned i = 0; i < instr->operands.size(); i++) {
3637       vop3p->operands[is_add + i] = instr->operands[i];
3638       vop3p->neg_lo[is_add + i] = instr->isVOP3() && instr->vop3().neg[i];
3639       vop3p->neg_lo[is_add + i] |= instr->isSDWA() && instr->sdwa().neg[i];
3640       vop3p->neg_hi[is_add + i] = instr->isVOP3() && instr->vop3().abs[i];
3641       vop3p->neg_hi[is_add + i] |= instr->isSDWA() && instr->sdwa().abs[i];
3642       vop3p->opsel_lo |= (instr->isSDWA() && instr->sdwa().sel[i].offset()) << (is_add + i);
3643    }
3644    if (instr->opcode == aco_opcode::v_mul_f32) {
3645       vop3p->opsel_hi &= 0x3;
3646       vop3p->operands[2] = Operand::zero();
3647       vop3p->neg_lo[2] = true;
3648    } else if (is_add) {
3649       vop3p->opsel_hi &= 0x6;
3650       vop3p->operands[0] = Operand::c32(0x3f800000);
3651       if (instr->opcode == aco_opcode::v_sub_f32)
3652          vop3p->neg_lo[2] ^= true;
3653       else if (instr->opcode == aco_opcode::v_subrev_f32)
3654          vop3p->neg_lo[1] ^= true;
3655    }
3656    vop3p->definitions[0] = instr->definitions[0];
3657    vop3p->clamp = instr->isVOP3() && instr->vop3().clamp;
3658    instr = std::move(vop3p);
3659 
3660    ctx.info[instr->definitions[0].tempId()].label &= label_f2f16 | label_clamp | label_mul;
3661    if (ctx.info[instr->definitions[0].tempId()].label & label_mul)
3662       ctx.info[instr->definitions[0].tempId()].instr = instr.get();
3663 }
3664 
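/* If the result of this FP32 instruction is only used by a conversion to FP16 (f2f16),
 * fold the conversion by turning the instruction into v_fma_mixlo_f16, which writes its
 * FP16 result to the low half of the destination. */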
3665 bool
3666 combine_output_conversion(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3667 {
3668    ssa_info& def_info = ctx.info[instr->definitions[0].tempId()];
3669    if (!def_info.is_f2f16())
3670       return false;
3671    Instruction* conv = def_info.instr;
3672 
3673    if (!can_use_mad_mix(ctx, instr) || ctx.uses[instr->definitions[0].tempId()] != 1)
3674       return false;
3675 
3676    if (!ctx.uses[conv->definitions[0].tempId()])
3677       return false;
3678 
3679    if (conv->usesModifiers())
3680       return false;
3681 
3682    if (!instr->isVOP3P())
3683       to_mad_mix(ctx, instr);
3684 
3685    instr->opcode = aco_opcode::v_fma_mixlo_f16;
3686    instr->definitions[0].swapTemp(conv->definitions[0]);
3687    if (conv->definitions[0].isPrecise())
3688       instr->definitions[0].setPrecise(true);
3689    ctx.info[instr->definitions[0].tempId()].label &= label_clamp;
3690    ctx.uses[conv->definitions[0].tempId()]--;
3691 
3692    return true;
3693 }
3694 
3695 void
3696 combine_mad_mix(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3697 {
3698    if (!can_use_mad_mix(ctx, instr))
3699       return;
3700 
3701    for (unsigned i = 0; i < instr->operands.size(); i++) {
3702       if (!instr->operands[i].isTemp())
3703          continue;
3704       Temp tmp = instr->operands[i].getTemp();
3705       if (!ctx.info[tmp.id()].is_f2f32())
3706          continue;
3707 
3708       Instruction* conv = ctx.info[tmp.id()].instr;
3709       if (conv->isSDWA() && (conv->sdwa().dst_sel.size() != 4 || conv->sdwa().sel[0].size() != 2 ||
3710                              conv->sdwa().clamp || conv->sdwa().omod)) {
3711          continue;
3712       } else if (conv->isVOP3() && (conv->vop3().clamp || conv->vop3().omod)) {
3713          continue;
3714       } else if (conv->isDPP()) {
3715          continue;
3716       }
3717 
3718       if (get_operand_size(instr, i) != 32)
3719          continue;
3720 
3721       /* Conversion to VOP3P will add inline constant operands, but that shouldn't affect
3722        * check_vop3_operands(). */
3723       Operand op[3];
3724       for (unsigned j = 0; j < instr->operands.size(); j++)
3725          op[j] = instr->operands[j];
3726       op[i] = conv->operands[0];
3727       if (!check_vop3_operands(ctx, instr->operands.size(), op))
3728          continue;
3729 
3730       if (!instr->isVOP3P()) {
3731          bool is_add =
3732             instr->opcode != aco_opcode::v_mul_f32 && instr->opcode != aco_opcode::v_fma_f32;
3733          to_mad_mix(ctx, instr);
3734          i += is_add;
3735       }
3736 
3737       if (--ctx.uses[tmp.id()])
3738          ctx.uses[conv->operands[0].tempId()]++;
3739       instr->operands[i].setTemp(conv->operands[0].getTemp());
3740       if (conv->definitions[0].isPrecise())
3741          instr->definitions[0].setPrecise(true);
3742       instr->vop3p().opsel_hi ^= 1u << i;
3743       if (conv->isSDWA() && conv->sdwa().sel[0].offset() == 2)
3744          instr->vop3p().opsel_lo |= 1u << i;
3745       bool neg = (conv->isVOP3() && conv->vop3().neg[0]) || (conv->isSDWA() && conv->sdwa().neg[0]);
3746       bool abs = (conv->isVOP3() && conv->vop3().abs[0]) || (conv->isSDWA() && conv->sdwa().abs[0]);
3747       if (!instr->vop3p().neg_hi[i]) {
3748          instr->vop3p().neg_lo[i] ^= neg;
3749          instr->vop3p().neg_hi[i] = abs;
3750       }
3751    }
3752 }
3753 
3754 // TODO: we could possibly move the whole label_instruction pass to combine_instruction:
3755 // this would mean that we'd have to fix the instruction uses during value propagation
3756 
3757 /* also returns true for inf */
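/* Powers of two (and infinity) have all-zero mantissa bits; the exponent >= bias checks
 * below additionally restrict the match to values >= 1.0 (see the caller: multiplying by
 * such a constant cannot require rounding). */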
3758 bool
3759 is_pow_of_two(opt_ctx& ctx, Operand op)
3760 {
3761    if (op.isTemp() && ctx.info[op.tempId()].is_constant_or_literal(op.bytes() * 8))
3762       return is_pow_of_two(ctx, get_constant_op(ctx, ctx.info[op.tempId()], op.bytes() * 8));
3763    else if (!op.isConstant())
3764       return false;
3765 
3766    uint64_t val = op.constantValue64();
3767 
3768    if (op.bytes() == 4) {
3769       uint32_t exponent = (val & 0x7f800000) >> 23;
3770       uint32_t fraction = val & 0x007fffff;
3771       return (exponent >= 127) && (fraction == 0);
3772    } else if (op.bytes() == 2) {
3773       uint32_t exponent = (val & 0x7c00) >> 10;
3774       uint32_t fraction = val & 0x03ff;
3775       return (exponent >= 15) && (fraction == 0);
3776    } else {
3777       assert(op.bytes() == 8);
3778       uint64_t exponent = (val & UINT64_C(0x7ff0000000000000)) >> 52;
3779       uint64_t fraction = val & UINT64_C(0x000fffffffffffff);
3780       return (exponent >= 1023) && (fraction == 0);
3781    }
3782 }
3783 
3784 void
3785 combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3786 {
3787    if (instr->definitions.empty() || is_dead(ctx.uses, instr.get()))
3788       return;
3789 
3790    if (instr->isVALU()) {
3791       /* Apply SDWA. Do this after label_instruction() so it can remove
3792        * label_extract if not all instructions can take SDWA. */
3793       for (unsigned i = 0; i < instr->operands.size(); i++) {
3794          Operand& op = instr->operands[i];
3795          if (!op.isTemp())
3796             continue;
3797          ssa_info& info = ctx.info[op.tempId()];
3798          if (!info.is_extract())
3799             continue;
3800          /* if there are that many uses, there are likely better combinations */
3801          // TODO: delay applying extract to a point where we know better
3802          if (ctx.uses[op.tempId()] > 4) {
3803             info.label &= ~label_extract;
3804             continue;
3805          }
3806          if (info.is_extract() &&
3807              (info.instr->operands[0].getTemp().type() == RegType::vgpr ||
3808               instr->operands[i].getTemp().type() == RegType::sgpr) &&
3809              can_apply_extract(ctx, instr, i, info)) {
3810             /* Increase use count of the extract's operand if the extract still has uses. */
3811             apply_extract(ctx, instr, i, info);
3812             if (--ctx.uses[instr->operands[i].tempId()])
3813                ctx.uses[info.instr->operands[0].tempId()]++;
3814             instr->operands[i].setTemp(info.instr->operands[0].getTemp());
3815          }
3816       }
3817 
3818       if (can_apply_sgprs(ctx, instr))
3819          apply_sgprs(ctx, instr);
3820       combine_mad_mix(ctx, instr);
3821       while (apply_omod_clamp(ctx, instr) | combine_output_conversion(ctx, instr))
3822          ;
3823       apply_insert(ctx, instr);
3824    }
3825 
3826    if (instr->isVOP3P() && instr->opcode != aco_opcode::v_fma_mix_f32 &&
3827        instr->opcode != aco_opcode::v_fma_mixlo_f16)
3828       return combine_vop3p(ctx, instr);
3829 
3830    if (instr->isSDWA() || instr->isDPP())
3831       return;
3832 
3833    if (instr->opcode == aco_opcode::p_extract) {
3834       ssa_info& info = ctx.info[instr->operands[0].tempId()];
3835       if (info.is_extract() && can_apply_extract(ctx, instr, 0, info)) {
3836          apply_extract(ctx, instr, 0, info);
3837          if (--ctx.uses[instr->operands[0].tempId()])
3838             ctx.uses[info.instr->operands[0].tempId()]++;
3839          instr->operands[0].setTemp(info.instr->operands[0].getTemp());
3840       }
3841 
3842       apply_ds_extract(ctx, instr);
3843    }
3844 
3845    /* TODO: There are still some peephole optimizations that could be done:
3846     * - abs(a - b) -> s_absdiff_i32
3847     * - various patterns for s_bitcmp{0,1}_b32 and s_bitset{0,1}_b32
3848     * - patterns for v_alignbit_b32 and v_alignbyte_b32
3849     * These probably aren't too interesting, though.
3850     * There are also patterns for v_cmp_class_f{16,32,64}. This is difficult but
3851     * probably more useful than the previously mentioned optimizations.
3852     * The various comparison optimizations also currently only work with 32-bit
3853     * floats. */
3854 
3855    /* neg(mul(a, b)) -> mul(neg(a), b), abs(mul(a, b)) -> mul(abs(a), abs(b)) */
3856    if ((ctx.info[instr->definitions[0].tempId()].label & (label_neg | label_abs)) &&
3857        ctx.uses[instr->operands[1].tempId()] == 1) {
3858       Temp val = ctx.info[instr->definitions[0].tempId()].temp;
3859 
3860       if (!ctx.info[val.id()].is_mul())
3861          return;
3862 
3863       Instruction* mul_instr = ctx.info[val.id()].instr;
3864 
3865       if (mul_instr->operands[0].isLiteral())
3866          return;
3867       if (mul_instr->isVOP3() && mul_instr->vop3().clamp)
3868          return;
3869       if (mul_instr->isSDWA() || mul_instr->isDPP() || mul_instr->isVOP3P())
3870          return;
3871       if (mul_instr->opcode == aco_opcode::v_mul_legacy_f32 &&
3872           ctx.fp_mode.preserve_signed_zero_inf_nan32)
3873          return;
3874       if (mul_instr->definitions[0].bytes() != instr->definitions[0].bytes())
3875          return;
3876 
3877       /* convert to mul(neg(a), b), mul(abs(a), abs(b)) or mul(neg(abs(a)), abs(b)) */
3878       ctx.uses[mul_instr->definitions[0].tempId()]--;
3879       Definition def = instr->definitions[0];
3880       bool is_neg = ctx.info[instr->definitions[0].tempId()].is_neg();
3881       bool is_abs = ctx.info[instr->definitions[0].tempId()].is_abs();
3882       instr.reset(
3883          create_instruction<VOP3_instruction>(mul_instr->opcode, asVOP3(Format::VOP2), 2, 1));
3884       instr->operands[0] = mul_instr->operands[0];
3885       instr->operands[1] = mul_instr->operands[1];
3886       instr->definitions[0] = def;
3887       VOP3_instruction& new_mul = instr->vop3();
3888       if (mul_instr->isVOP3()) {
3889          VOP3_instruction& mul = mul_instr->vop3();
3890          new_mul.neg[0] = mul.neg[0];
3891          new_mul.neg[1] = mul.neg[1];
3892          new_mul.abs[0] = mul.abs[0];
3893          new_mul.abs[1] = mul.abs[1];
3894          new_mul.omod = mul.omod;
3895       }
3896       if (is_abs) {
3897          new_mul.neg[0] = new_mul.neg[1] = false;
3898          new_mul.abs[0] = new_mul.abs[1] = true;
3899       }
3900       new_mul.neg[0] ^= is_neg;
3901       new_mul.clamp = false;
3902 
3903       ctx.info[instr->definitions[0].tempId()].set_mul(instr.get());
3904       return;
3905    }
3906 
3907    /* combine mul+add -> mad */
3908    bool is_add_mix =
3909       (instr->opcode == aco_opcode::v_fma_mix_f32 ||
3910        instr->opcode == aco_opcode::v_fma_mixlo_f16) &&
3911       !instr->vop3p().neg_lo[0] &&
3912       ((instr->operands[0].constantEquals(0x3f800000) && (instr->vop3p().opsel_hi & 0x1) == 0) ||
3913        (instr->operands[0].constantEquals(0x3C00) && (instr->vop3p().opsel_hi & 0x1) &&
3914         !(instr->vop3p().opsel_lo & 0x1)));
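   /* is_add_mix: a v_fma_mix_f32/v_fma_mixlo_f16 whose first source is the constant 1.0
    * (0x3f800000 as f32, or f16 1.0 = 0x3C00 taken from the low half) and not negated
    * (neg_lo[0] clear), i.e. it effectively computes 1.0 * src1 + src2 and can absorb a
    * multiply just like a plain add. */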
3915    bool mad32 = instr->opcode == aco_opcode::v_add_f32 || instr->opcode == aco_opcode::v_sub_f32 ||
3916                 instr->opcode == aco_opcode::v_subrev_f32;
3917    bool mad16 = instr->opcode == aco_opcode::v_add_f16 || instr->opcode == aco_opcode::v_sub_f16 ||
3918                 instr->opcode == aco_opcode::v_subrev_f16;
3919    bool mad64 = instr->opcode == aco_opcode::v_add_f64;
3920    if (is_add_mix || mad16 || mad32 || mad64) {
3921       Instruction* mul_instr = nullptr;
3922       unsigned add_op_idx = 0;
3923       uint32_t uses = UINT32_MAX;
3924       bool emit_fma = false;
3925       /* find the 'best' mul instruction to combine with the add */
3926       for (unsigned i = is_add_mix ? 1 : 0; i < instr->operands.size(); i++) {
3927          if (!instr->operands[i].isTemp() || !ctx.info[instr->operands[i].tempId()].is_mul())
3928             continue;
3929          ssa_info& info = ctx.info[instr->operands[i].tempId()];
3930 
3931          /* no clamp/omod allowed between mul and add */
3932          if (info.instr->isVOP3() && (info.instr->vop3().clamp || info.instr->vop3().omod))
3933             continue;
3934          if (info.instr->isVOP3P() && info.instr->vop3p().clamp)
3935             continue;
3936          /* v_fma_mix_f32/etc can't do omod */
3937          if (info.instr->isVOP3P() && instr->isVOP3() && instr->vop3().omod)
3938             continue;
3939          /* don't promote fp16 to fp32 or remove fp32->fp16->fp32 conversions */
3940          if (is_add_mix && info.instr->definitions[0].bytes() == 2)
3941             continue;
3942 
3943          if (get_operand_size(instr, i) != info.instr->definitions[0].bytes() * 8)
3944             continue;
3945 
3946          bool legacy = info.instr->opcode == aco_opcode::v_mul_legacy_f32;
3947          bool mad_mix = is_add_mix || info.instr->isVOP3P();
3948 
3949          /* Multiplication by a power of two never needs rounding, so fused and unfused results match.
3950           * 1/power-of-two also works, but using fma removes denormal flushing (0xfffffe * 0.5 + 0x810001a2).
3951           */
3952          bool is_fma_precise = is_pow_of_two(ctx, info.instr->operands[0]) ||
3953                                is_pow_of_two(ctx, info.instr->operands[1]);
3954 
3955          bool has_fma = mad16 || mad64 || (legacy && ctx.program->gfx_level >= GFX10_3) ||
3956                         (mad32 && !legacy && !mad_mix && ctx.program->dev.has_fast_fma32) ||
3957                         (mad_mix && ctx.program->dev.fused_mad_mix);
3958          bool has_mad = mad_mix ? !ctx.program->dev.fused_mad_mix
3959                                 : ((mad32 && ctx.program->gfx_level < GFX10_3) ||
3960                                    (mad16 && ctx.program->gfx_level <= GFX9));
3961          bool can_use_fma =
3962             has_fma &&
3963             (!(info.instr->definitions[0].isPrecise() || instr->definitions[0].isPrecise()) ||
3964              is_fma_precise);
3965          bool can_use_mad =
3966             has_mad && (mad_mix || mad32 ? ctx.fp_mode.denorm32 : ctx.fp_mode.denorm16_64) == 0;
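         /* v_mad_* flushes denormals, so it is only used when the relevant denormal mode is
          * already "flush to zero" (denorm32/denorm16_64 == 0). fma fuses the rounding step,
          * so for `precise` defs it is only allowed when the multiply is exact anyway
          * (power-of-two factor, see is_fma_precise above). */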
3967          if (mad_mix && legacy)
3968             continue;
3969          if (!can_use_fma && !can_use_mad)
3970             continue;
3971 
3972          unsigned candidate_add_op_idx = is_add_mix ? (3 - i) : (1 - i);
3973          Operand op[3] = {info.instr->operands[0], info.instr->operands[1],
3974                           instr->operands[candidate_add_op_idx]};
3975          if (info.instr->isSDWA() || info.instr->isDPP() || !check_vop3_operands(ctx, 3, op) ||
3976              ctx.uses[instr->operands[i].tempId()] > uses)
3977             continue;
3978 
3979          if (ctx.uses[instr->operands[i].tempId()] == uses) {
3980             unsigned cur_idx = mul_instr->definitions[0].tempId();
3981             unsigned new_idx = info.instr->definitions[0].tempId();
3982             if (cur_idx > new_idx)
3983                continue;
3984          }
3985 
3986          mul_instr = info.instr;
3987          add_op_idx = candidate_add_op_idx;
3988          uses = ctx.uses[instr->operands[i].tempId()];
3989          emit_fma = !can_use_mad;
3990       }
3991 
3992       if (mul_instr) {
3993          /* turn mul+add into v_mad/v_fma */
3994          Operand op[3] = {mul_instr->operands[0], mul_instr->operands[1],
3995                           instr->operands[add_op_idx]};
3996          ctx.uses[mul_instr->definitions[0].tempId()]--;
3997          if (ctx.uses[mul_instr->definitions[0].tempId()]) {
3998             if (op[0].isTemp())
3999                ctx.uses[op[0].tempId()]++;
4000             if (op[1].isTemp())
4001                ctx.uses[op[1].tempId()]++;
4002          }
4003 
4004          bool neg[3] = {false, false, false};
4005          bool abs[3] = {false, false, false};
4006          unsigned omod = 0;
4007          bool clamp = false;
4008          uint8_t opsel_lo = 0;
4009          uint8_t opsel_hi = 0;
4010 
4011          if (mul_instr->isVOP3()) {
4012             VOP3_instruction& vop3 = mul_instr->vop3();
4013             neg[0] = vop3.neg[0];
4014             neg[1] = vop3.neg[1];
4015             abs[0] = vop3.abs[0];
4016             abs[1] = vop3.abs[1];
4017          } else if (mul_instr->isVOP3P()) {
4018             VOP3P_instruction& vop3p = mul_instr->vop3p();
4019             neg[0] = vop3p.neg_lo[0];
4020             neg[1] = vop3p.neg_lo[1];
4021             abs[0] = vop3p.neg_hi[0];
4022             abs[1] = vop3p.neg_hi[1];
4023             opsel_lo = vop3p.opsel_lo & 0x3;
4024             opsel_hi = vop3p.opsel_hi & 0x3;
4025          }
4026 
4027          if (instr->isVOP3()) {
4028             VOP3_instruction& vop3 = instr->vop3();
4029             neg[2] = vop3.neg[add_op_idx];
4030             abs[2] = vop3.abs[add_op_idx];
4031             omod = vop3.omod;
4032             clamp = vop3.clamp;
4033             /* abs of the multiplication result */
4034             if (vop3.abs[1 - add_op_idx]) {
4035                neg[0] = false;
4036                neg[1] = false;
4037                abs[0] = true;
4038                abs[1] = true;
4039             }
4040             /* neg of the multiplication result */
4041             neg[1] = neg[1] ^ vop3.neg[1 - add_op_idx];
4042          } else if (instr->isVOP3P()) {
4043             VOP3P_instruction& vop3p = instr->vop3p();
4044             neg[2] = vop3p.neg_lo[add_op_idx];
4045             abs[2] = vop3p.neg_hi[add_op_idx];
4046             opsel_lo |= vop3p.opsel_lo & (1 << add_op_idx) ? 0x4 : 0x0;
4047             opsel_hi |= vop3p.opsel_hi & (1 << add_op_idx) ? 0x4 : 0x0;
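            /* opsel bits 0 and 1 were taken from the mul's two sources above; bit 2 (0x4)
             * represents the add operand, which becomes src2 of the v_fma_mix built below. */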
4048             clamp = vop3p.clamp;
4049             /* abs of the multiplication result */
4050             if (vop3p.neg_hi[3 - add_op_idx]) {
4051                neg[0] = false;
4052                neg[1] = false;
4053                abs[0] = true;
4054                abs[1] = true;
4055             }
4056             /* neg of the multiplication result */
4057             neg[1] = neg[1] ^ vop3p.neg_lo[3 - add_op_idx];
4058          }
4059 
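         /* v_sub negates src1 and v_subrev negates src0; fold that sign flip into the product
          * (via one of its factors) or into the addend, depending on which side the mul was on. */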
4060          if (instr->opcode == aco_opcode::v_sub_f32 || instr->opcode == aco_opcode::v_sub_f16)
4061             neg[1 + add_op_idx] = neg[1 + add_op_idx] ^ true;
4062          else if (instr->opcode == aco_opcode::v_subrev_f32 ||
4063                   instr->opcode == aco_opcode::v_subrev_f16)
4064             neg[2 - add_op_idx] = neg[2 - add_op_idx] ^ true;
4065 
4066          aco_ptr<Instruction> add_instr = std::move(instr);
4067          if (add_instr->isVOP3P() || mul_instr->isVOP3P()) {
4068             assert(!omod);
4069 
4070             aco_opcode mad_op = add_instr->definitions[0].bytes() == 2 ? aco_opcode::v_fma_mixlo_f16
4071                                                                        : aco_opcode::v_fma_mix_f32;
4072             aco_ptr<VOP3P_instruction> mad{
4073                create_instruction<VOP3P_instruction>(mad_op, Format::VOP3P, 3, 1)};
4074             for (unsigned i = 0; i < 3; i++) {
4075                mad->operands[i] = op[i];
4076                mad->neg_lo[i] = neg[i];
4077                mad->neg_hi[i] = abs[i];
4078             }
4079             mad->clamp = clamp;
4080             mad->opsel_lo = opsel_lo;
4081             mad->opsel_hi = opsel_hi;
4082 
4083             instr = std::move(mad);
4084          } else {
4085             aco_opcode mad_op = emit_fma ? aco_opcode::v_fma_f32 : aco_opcode::v_mad_f32;
4086             if (mul_instr->opcode == aco_opcode::v_mul_legacy_f32) {
4087                assert(emit_fma == (ctx.program->gfx_level >= GFX10_3));
4088                mad_op = emit_fma ? aco_opcode::v_fma_legacy_f32 : aco_opcode::v_mad_legacy_f32;
4089             } else if (mad16) {
4090                mad_op = emit_fma ? (ctx.program->gfx_level == GFX8 ? aco_opcode::v_fma_legacy_f16
4091                                                                    : aco_opcode::v_fma_f16)
4092                                  : (ctx.program->gfx_level == GFX8 ? aco_opcode::v_mad_legacy_f16
4093                                                                    : aco_opcode::v_mad_f16);
4094             } else if (mad64) {
4095                mad_op = aco_opcode::v_fma_f64;
4096             }
4097 
4098             aco_ptr<VOP3_instruction> mad{
4099                create_instruction<VOP3_instruction>(mad_op, Format::VOP3, 3, 1)};
4100             for (unsigned i = 0; i < 3; i++) {
4101                mad->operands[i] = op[i];
4102                mad->neg[i] = neg[i];
4103                mad->abs[i] = abs[i];
4104             }
4105             mad->omod = omod;
4106             mad->clamp = clamp;
4107 
4108             instr = std::move(mad);
4109          }
4110          instr->definitions[0] = add_instr->definitions[0];
4111 
4112          /* mark this ssa_def to be re-checked for profitability and literals */
4113          ctx.mad_infos.emplace_back(std::move(add_instr), mul_instr->definitions[0].tempId());
4114          ctx.info[instr->definitions[0].tempId()].set_mad(instr.get(), ctx.mad_infos.size() - 1);
4115          return;
4116       }
4117    }
4118    /* v_mul_f32(v_cndmask_b32(0, 1.0, cond), a) -> v_cndmask_b32(0, a, cond) */
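   /* is_b2f marks a select between 0 and 1.0 on a boolean; multiplying a VGPR value by it is
    * the same as selecting between 0 and that value. The fp_mode checks in the condition below
    * ensure this is safe w.r.t. signed zero/inf/nan (not needed for v_mul_legacy_f32) and
    * denormal flushing. */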
4119    else if (((instr->opcode == aco_opcode::v_mul_f32 &&
4120               !ctx.fp_mode.preserve_signed_zero_inf_nan32) ||
4121              instr->opcode == aco_opcode::v_mul_legacy_f32) &&
4122             !instr->usesModifiers() && !ctx.fp_mode.must_flush_denorms32) {
4123       for (unsigned i = 0; i < 2; i++) {
4124          if (instr->operands[i].isTemp() && ctx.info[instr->operands[i].tempId()].is_b2f() &&
4125              ctx.uses[instr->operands[i].tempId()] == 1 && instr->operands[!i].isTemp() &&
4126              instr->operands[!i].getTemp().type() == RegType::vgpr) {
4127             ctx.uses[instr->operands[i].tempId()]--;
4128             ctx.uses[ctx.info[instr->operands[i].tempId()].temp.id()]++;
4129 
4130             aco_ptr<VOP2_instruction> new_instr{
4131                create_instruction<VOP2_instruction>(aco_opcode::v_cndmask_b32, Format::VOP2, 3, 1)};
4132             new_instr->operands[0] = Operand::zero();
4133             new_instr->operands[1] = instr->operands[!i];
4134             new_instr->operands[2] = Operand(ctx.info[instr->operands[i].tempId()].temp);
4135             new_instr->definitions[0] = instr->definitions[0];
4136             instr = std::move(new_instr);
4137             ctx.info[instr->definitions[0].tempId()].label = 0;
4138             return;
4139          }
4140       }
4141    } else if (instr->opcode == aco_opcode::v_or_b32 && ctx.program->gfx_level >= GFX9) {
4142       if (combine_three_valu_op(ctx, instr, aco_opcode::s_or_b32, aco_opcode::v_or3_b32, "012",
4143                                 1 | 2)) {
4144       } else if (combine_three_valu_op(ctx, instr, aco_opcode::v_or_b32, aco_opcode::v_or3_b32,
4145                                        "012", 1 | 2)) {
4146       } else if (combine_add_or_then_and_lshl(ctx, instr)) {
4147       }
4148    } else if (instr->opcode == aco_opcode::v_xor_b32 && ctx.program->gfx_level >= GFX10) {
4149       if (combine_three_valu_op(ctx, instr, aco_opcode::v_xor_b32, aco_opcode::v_xor3_b32, "012",
4150                                 1 | 2)) {
4151       } else if (combine_three_valu_op(ctx, instr, aco_opcode::s_xor_b32, aco_opcode::v_xor3_b32,
4152                                        "012", 1 | 2)) {
4153       }
4154    } else if (instr->opcode == aco_opcode::v_add_u16) {
4155       combine_three_valu_op(
4156          ctx, instr, aco_opcode::v_mul_lo_u16,
4157          ctx.program->gfx_level == GFX8 ? aco_opcode::v_mad_legacy_u16 : aco_opcode::v_mad_u16,
4158          "120", 1 | 2);
4159    } else if (instr->opcode == aco_opcode::v_add_u16_e64) {
4160       combine_three_valu_op(ctx, instr, aco_opcode::v_mul_lo_u16_e64, aco_opcode::v_mad_u16, "120",
4161                             1 | 2);
4162    } else if (instr->opcode == aco_opcode::v_add_u32) {
4163       if (combine_add_sub_b2i(ctx, instr, aco_opcode::v_addc_co_u32, 1 | 2)) {
4164       } else if (combine_add_bcnt(ctx, instr)) {
4165       } else if (combine_three_valu_op(ctx, instr, aco_opcode::v_mul_u32_u24,
4166                                        aco_opcode::v_mad_u32_u24, "120", 1 | 2)) {
4167       } else if (ctx.program->gfx_level >= GFX9 && !instr->usesModifiers()) {
4168          if (combine_three_valu_op(ctx, instr, aco_opcode::s_xor_b32, aco_opcode::v_xad_u32, "120",
4169                                    1 | 2)) {
4170          } else if (combine_three_valu_op(ctx, instr, aco_opcode::v_xor_b32, aco_opcode::v_xad_u32,
4171                                           "120", 1 | 2)) {
4172          } else if (combine_three_valu_op(ctx, instr, aco_opcode::s_add_i32, aco_opcode::v_add3_u32,
4173                                           "012", 1 | 2)) {
4174          } else if (combine_three_valu_op(ctx, instr, aco_opcode::s_add_u32, aco_opcode::v_add3_u32,
4175                                           "012", 1 | 2)) {
4176          } else if (combine_three_valu_op(ctx, instr, aco_opcode::v_add_u32, aco_opcode::v_add3_u32,
4177                                           "012", 1 | 2)) {
4178          } else if (combine_add_or_then_and_lshl(ctx, instr)) {
4179          }
4180       }
4181    } else if (instr->opcode == aco_opcode::v_add_co_u32 ||
4182               instr->opcode == aco_opcode::v_add_co_u32_e64) {
4183       bool carry_out = ctx.uses[instr->definitions[1].tempId()] > 0;
4184       if (combine_add_sub_b2i(ctx, instr, aco_opcode::v_addc_co_u32, 1 | 2)) {
4185       } else if (!carry_out && combine_add_bcnt(ctx, instr)) {
4186       } else if (!carry_out && combine_three_valu_op(ctx, instr, aco_opcode::v_mul_u32_u24,
4187                                                      aco_opcode::v_mad_u32_u24, "120", 1 | 2)) {
4188       } else if (!carry_out && combine_add_lshl(ctx, instr, false)) {
4189       }
4190    } else if (instr->opcode == aco_opcode::v_sub_u32 || instr->opcode == aco_opcode::v_sub_co_u32 ||
4191               instr->opcode == aco_opcode::v_sub_co_u32_e64) {
4192       bool carry_out =
4193          instr->opcode != aco_opcode::v_sub_u32 && ctx.uses[instr->definitions[1].tempId()] > 0;
4194       if (combine_add_sub_b2i(ctx, instr, aco_opcode::v_subbrev_co_u32, 2)) {
4195       } else if (!carry_out && combine_add_lshl(ctx, instr, true)) {
4196       }
4197    } else if (instr->opcode == aco_opcode::v_subrev_u32 ||
4198               instr->opcode == aco_opcode::v_subrev_co_u32 ||
4199               instr->opcode == aco_opcode::v_subrev_co_u32_e64) {
4200       combine_add_sub_b2i(ctx, instr, aco_opcode::v_subbrev_co_u32, 1);
4201    } else if (instr->opcode == aco_opcode::v_lshlrev_b32 && ctx.program->gfx_level >= GFX9) {
4202       combine_three_valu_op(ctx, instr, aco_opcode::v_add_u32, aco_opcode::v_add_lshl_u32, "120",
4203                             2);
4204    } else if ((instr->opcode == aco_opcode::s_add_u32 || instr->opcode == aco_opcode::s_add_i32) &&
4205               ctx.program->gfx_level >= GFX9) {
4206       combine_salu_lshl_add(ctx, instr);
4207    } else if (instr->opcode == aco_opcode::s_not_b32 || instr->opcode == aco_opcode::s_not_b64) {
4208       combine_salu_not_bitwise(ctx, instr);
4209    } else if (instr->opcode == aco_opcode::s_and_b32 || instr->opcode == aco_opcode::s_or_b32 ||
4210               instr->opcode == aco_opcode::s_and_b64 || instr->opcode == aco_opcode::s_or_b64) {
4211       if (combine_ordering_test(ctx, instr)) {
4212       } else if (combine_comparison_ordering(ctx, instr)) {
4213       } else if (combine_constant_comparison_ordering(ctx, instr)) {
4214       } else if (combine_salu_n2(ctx, instr)) {
4215       }
4216    } else if (instr->opcode == aco_opcode::v_and_b32) {
4217       combine_and_subbrev(ctx, instr);
4218    } else if (instr->opcode == aco_opcode::v_fma_f32 || instr->opcode == aco_opcode::v_fma_f16) {
4219       /* Label an existing v_fma_f32/v_fma_f16 with label_mad so we can create v_fmamk_f32/v_fmaak_f32.
4220        * Since ctx.uses[mad_info::mul_temp_id] is always 0, we don't have to worry about
4221        * select_instruction() using mad_info::add_instr.
4222        */
4223       ctx.mad_infos.emplace_back(nullptr, 0);
4224       ctx.info[instr->definitions[0].tempId()].set_mad(instr.get(), ctx.mad_infos.size() - 1);
4225    } else {
4226       aco_opcode min, max, min3, max3, med3;
4227       bool some_gfx9_only;
4228       if (get_minmax_info(instr->opcode, &min, &max, &min3, &max3, &med3, &some_gfx9_only) &&
4229           (!some_gfx9_only || ctx.program->gfx_level >= GFX9)) {
4230          if (combine_minmax(ctx, instr, instr->opcode == min ? max : min,
4231                             instr->opcode == min ? min3 : max3)) {
4232          } else {
4233             combine_clamp(ctx, instr, min, max, med3);
4234          }
4235       }
4236    }
4237 
4238    /* do this after combine_salu_n2() */
4239    if (instr->opcode == aco_opcode::s_andn2_b32 || instr->opcode == aco_opcode::s_andn2_b64)
4240       combine_inverse_comparison(ctx, instr);
4241 }
4242 
4243 bool
4244 to_uniform_bool_instr(opt_ctx& ctx, aco_ptr<Instruction>& instr)
4245 {
4246    /* Check every operand to make sure they are suitable. */
4247    for (Operand& op : instr->operands) {
4248       if (!op.isTemp())
4249          return false;
4250       if (!ctx.info[op.tempId()].is_uniform_bool() && !ctx.info[op.tempId()].is_uniform_bitwise())
4251          return false;
4252    }
4253 
4254    switch (instr->opcode) {
4255    case aco_opcode::s_and_b32:
4256    case aco_opcode::s_and_b64: instr->opcode = aco_opcode::s_and_b32; break;
4257    case aco_opcode::s_or_b32:
4258    case aco_opcode::s_or_b64: instr->opcode = aco_opcode::s_or_b32; break;
4259    case aco_opcode::s_xor_b32:
4260    case aco_opcode::s_xor_b64: instr->opcode = aco_opcode::s_absdiff_i32; break;
4261    default:
4262       /* Don't transform other instructions. They are very unlikely to appear here. */
4263       return false;
4264    }
4265 
4266    for (Operand& op : instr->operands) {
4267       ctx.uses[op.tempId()]--;
4268 
4269       if (ctx.info[op.tempId()].is_uniform_bool()) {
4270          /* Just use the uniform boolean temp. */
4271          op.setTemp(ctx.info[op.tempId()].temp);
4272       } else if (ctx.info[op.tempId()].is_uniform_bitwise()) {
4273          /* Use the SCC definition of the predecessor instruction.
4274           * This allows the predecessor to get picked up by the same optimization (if it has no
4275           * divergent users), and it also makes sure that the current instruction will keep working
4276           * even if the predecessor won't be transformed.
4277           */
4278          Instruction* pred_instr = ctx.info[op.tempId()].instr;
4279          assert(pred_instr->definitions.size() >= 2);
4280          assert(pred_instr->definitions[1].isFixed() &&
4281                 pred_instr->definitions[1].physReg() == scc);
4282          op.setTemp(pred_instr->definitions[1].getTemp());
4283       } else {
4284          unreachable("Invalid operand on uniform bitwise instruction.");
4285       }
4286 
4287       ctx.uses[op.tempId()]++;
4288    }
4289 
4290    instr->definitions[0].setTemp(Temp(instr->definitions[0].tempId(), s1));
4291    assert(instr->operands[0].regClass() == s1);
4292    assert(instr->operands[1].regClass() == s1);
4293    return true;
4294 }
4295 
4296 void
4297 select_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
4298 {
4299    const uint32_t threshold = 4;
4300 
4301    if (is_dead(ctx.uses, instr.get())) {
4302       instr.reset();
4303       return;
4304    }
4305 
4306    /* convert split_vector into a copy or extract_vector if only one definition is ever used */
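   /* e.g. if only the second dword of a p_split_vector result is used, this becomes
    * p_extract_vector(src, 1), or a plain copy of the matching p_create_vector source when
    * that vector is produced solely for this split. */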
4307    if (instr->opcode == aco_opcode::p_split_vector) {
4308       unsigned num_used = 0;
4309       unsigned idx = 0;
4310       unsigned split_offset = 0;
4311       for (unsigned i = 0, offset = 0; i < instr->definitions.size();
4312            offset += instr->definitions[i++].bytes()) {
4313          if (ctx.uses[instr->definitions[i].tempId()]) {
4314             num_used++;
4315             idx = i;
4316             split_offset = offset;
4317          }
4318       }
4319       bool done = false;
4320       if (num_used == 1 && ctx.info[instr->operands[0].tempId()].is_vec() &&
4321           ctx.uses[instr->operands[0].tempId()] == 1) {
4322          Instruction* vec = ctx.info[instr->operands[0].tempId()].instr;
4323 
4324          unsigned off = 0;
4325          Operand op;
4326          for (Operand& vec_op : vec->operands) {
4327             if (off == split_offset) {
4328                op = vec_op;
4329                break;
4330             }
4331             off += vec_op.bytes();
4332          }
4333          if (off != instr->operands[0].bytes() && op.bytes() == instr->definitions[idx].bytes()) {
4334             ctx.uses[instr->operands[0].tempId()]--;
4335             for (Operand& vec_op : vec->operands) {
4336                if (vec_op.isTemp())
4337                   ctx.uses[vec_op.tempId()]--;
4338             }
4339             if (op.isTemp())
4340                ctx.uses[op.tempId()]++;
4341 
4342             aco_ptr<Pseudo_instruction> extract{create_instruction<Pseudo_instruction>(
4343                aco_opcode::p_create_vector, Format::PSEUDO, 1, 1)};
4344             extract->operands[0] = op;
4345             extract->definitions[0] = instr->definitions[idx];
4346             instr = std::move(extract);
4347 
4348             done = true;
4349          }
4350       }
4351 
4352       if (!done && num_used == 1 &&
4353           instr->operands[0].bytes() % instr->definitions[idx].bytes() == 0 &&
4354           split_offset % instr->definitions[idx].bytes() == 0) {
4355          aco_ptr<Pseudo_instruction> extract{create_instruction<Pseudo_instruction>(
4356             aco_opcode::p_extract_vector, Format::PSEUDO, 2, 1)};
4357          extract->operands[0] = instr->operands[0];
4358          extract->operands[1] =
4359             Operand::c32((uint32_t)split_offset / instr->definitions[idx].bytes());
4360          extract->definitions[0] = instr->definitions[idx];
4361          instr = std::move(extract);
4362       }
4363    }
4364 
4365    mad_info* mad_info = NULL;
4366    if (!instr->definitions.empty() && ctx.info[instr->definitions[0].tempId()].is_mad()) {
4367       mad_info = &ctx.mad_infos[ctx.info[instr->definitions[0].tempId()].instr->pass_flags];
4368       /* re-check mad instructions */
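      /* If the multiply result is still used elsewhere, keeping both the mul and the mad would
       * not save anything, so restore the original add instruction stashed in mad_info. */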
4369       if (ctx.uses[mad_info->mul_temp_id] && mad_info->add_instr) {
4370          ctx.uses[mad_info->mul_temp_id]++;
4371          if (instr->operands[0].isTemp())
4372             ctx.uses[instr->operands[0].tempId()]--;
4373          if (instr->operands[1].isTemp())
4374             ctx.uses[instr->operands[1].tempId()]--;
4375          instr.swap(mad_info->add_instr);
4376          mad_info = NULL;
4377       }
4378       /* check literals */
4379       else if (!instr->usesModifiers() && !instr->isVOP3P() &&
4380                instr->opcode != aco_opcode::v_fma_f64 &&
4381                instr->opcode != aco_opcode::v_mad_legacy_f32 &&
4382                instr->opcode != aco_opcode::v_fma_legacy_f32) {
4383          /* FMA can only take literals on GFX10+ */
4384          if ((instr->opcode == aco_opcode::v_fma_f32 || instr->opcode == aco_opcode::v_fma_f16) &&
4385              ctx.program->gfx_level < GFX10)
4386             return;
4387          /* There is no v_fmaak_legacy_f16/v_fmamk_legacy_f16, and on chips where VOP3 can take
4388           * literals (GFX10+), v_fma_legacy_f16 doesn't exist.
4389           */
4390          if (instr->opcode == aco_opcode::v_fma_legacy_f16)
4391             return;
4392 
4393          uint32_t literal_idx = 0;
4394          uint32_t literal_uses = UINT32_MAX;
4395 
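         /* v_madak/v_fmaak encode the literal as the addend (src2), v_madmk/v_fmamk encode it
          * as a multiplicand; both are VOP2 encodings followed by a 32-bit literal. */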
4396          /* Try using v_madak/v_fmaak */
4397          if (instr->operands[2].isTemp() &&
4398              ctx.info[instr->operands[2].tempId()].is_literal(get_operand_size(instr, 2))) {
4399             bool has_sgpr = false;
4400             bool has_vgpr = false;
4401             for (unsigned i = 0; i < 2; i++) {
4402                if (!instr->operands[i].isTemp())
4403                   continue;
4404                has_sgpr |= instr->operands[i].getTemp().type() == RegType::sgpr;
4405                has_vgpr |= instr->operands[i].getTemp().type() == RegType::vgpr;
4406             }
4407             /* Encoding limitations require a VGPR operand. The constant bus limitations before
4408              * GFX10 disallow SGPRs.
4409              */
4410             if ((!has_sgpr || ctx.program->gfx_level >= GFX10) && has_vgpr) {
4411                literal_idx = 2;
4412                literal_uses = ctx.uses[instr->operands[2].tempId()];
4413             }
4414          }
4415 
4416          /* Try using v_madmk/v_fmamk */
4417          /* Encoding limitations require a VGPR operand. */
4418          if (instr->operands[2].isTemp() && instr->operands[2].getTemp().type() == RegType::vgpr) {
4419             for (unsigned i = 0; i < 2; i++) {
4420                if (!instr->operands[i].isTemp())
4421                   continue;
4422 
4423                /* The constant bus limitations before GFX10 disallow SGPRs. */
4424                if (ctx.program->gfx_level < GFX10 && instr->operands[!i].isTemp() &&
4425                    instr->operands[!i].getTemp().type() == RegType::sgpr)
4426                   continue;
4427 
4428                if (ctx.info[instr->operands[i].tempId()].is_literal(get_operand_size(instr, i)) &&
4429                    ctx.uses[instr->operands[i].tempId()] < literal_uses) {
4430                   literal_idx = i;
4431                   literal_uses = ctx.uses[instr->operands[i].tempId()];
4432                }
4433             }
4434          }
4435 
4436          /* Limit the number of literals applied so the code size doesn't
4437           * increase too much, but always apply the literal for v_mad->v_madak
4438           * because both instructions are 64-bit, so this doesn't increase
4439           * code size.
4440           * TODO: try to apply the literals earlier to lower the number of
4441           * uses below threshold
4442           */
4443          if (literal_uses < threshold || literal_idx == 2) {
4444             ctx.uses[instr->operands[literal_idx].tempId()]--;
4445             mad_info->check_literal = true;
4446             mad_info->literal_idx = literal_idx;
4447             return;
4448          }
4449       }
4450    }
4451 
4452    /* Mark SCC needed, so the uniform boolean transformation won't swap the definitions
4453     * when it isn't beneficial */
4454    if (instr->isBranch() && instr->operands.size() && instr->operands[0].isTemp() &&
4455        instr->operands[0].isFixed() && instr->operands[0].physReg() == scc) {
4456       ctx.info[instr->operands[0].tempId()].set_scc_needed();
4457       return;
4458    } else if ((instr->opcode == aco_opcode::s_cselect_b64 ||
4459                instr->opcode == aco_opcode::s_cselect_b32) &&
4460               instr->operands[2].isTemp()) {
4461       ctx.info[instr->operands[2].tempId()].set_scc_needed();
4462    } else if (instr->opcode == aco_opcode::p_wqm && instr->operands[0].isTemp() &&
4463               ctx.info[instr->definitions[0].tempId()].is_scc_needed()) {
4464       /* Propagate label so it is correctly detected by the uniform bool transform */
4465       ctx.info[instr->operands[0].tempId()].set_scc_needed();
4466 
4467       /* Fix the definition to SCC; this prevents RA from adding superfluous moves */
4468       instr->definitions[0].setFixed(scc);
4469    }
4470 
4471    /* check for literals */
4472    if (!instr->isSALU() && !instr->isVALU())
4473       return;
4474 
4475    /* Transform uniform bitwise boolean operations to 32-bit when there are no divergent uses. */
4476    if (instr->definitions.size() && ctx.uses[instr->definitions[0].tempId()] == 0 &&
4477        ctx.info[instr->definitions[0].tempId()].is_uniform_bitwise()) {
4478       bool transform_done = to_uniform_bool_instr(ctx, instr);
4479 
4480       if (transform_done && !ctx.info[instr->definitions[1].tempId()].is_scc_needed()) {
4481          /* Swap the two definition IDs in order to avoid overusing the SCC.
4482           * This reduces extra moves generated by RA. */
4483          uint32_t def0_id = instr->definitions[0].getTemp().id();
4484          uint32_t def1_id = instr->definitions[1].getTemp().id();
4485          instr->definitions[0].setTemp(Temp(def1_id, s1));
4486          instr->definitions[1].setTemp(Temp(def0_id, s1));
4487       }
4488 
4489       return;
4490    }
4491 
4492    /* Combine DPP copies into VALU. This should be done after creating MAD/FMA. */
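   /* A DPP-labelled source is a v_mov_b32 with DPP applied; when the consumer can use DPP
    * itself (and doesn't already), the move is folded by converting the consumer to a DPP
    * encoding, copying dpp_ctrl/lane_sel from the move and swapping operands when needed,
    * since the DPP source must be src0. */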
4493    if (instr->isVALU()) {
4494       for (unsigned i = 0; i < instr->operands.size(); i++) {
4495          if (!instr->operands[i].isTemp())
4496             continue;
4497          ssa_info info = ctx.info[instr->operands[i].tempId()];
4498 
4499          aco_opcode swapped_op;
4500          if (info.is_dpp() && info.instr->pass_flags == instr->pass_flags &&
4501              (i == 0 || can_swap_operands(instr, &swapped_op)) &&
4502              can_use_DPP(instr, true, info.is_dpp8()) && !instr->isDPP()) {
4503             bool dpp8 = info.is_dpp8();
4504             convert_to_DPP(instr, dpp8);
4505             if (dpp8) {
4506                DPP8_instruction* dpp = &instr->dpp8();
4507                for (unsigned j = 0; j < 8; ++j)
4508                   dpp->lane_sel[j] = info.instr->dpp8().lane_sel[j];
4509                if (i) {
4510                   instr->opcode = swapped_op;
4511                   std::swap(instr->operands[0], instr->operands[1]);
4512                }
4513             } else {
4514                DPP16_instruction* dpp = &instr->dpp16();
4515                if (i) {
4516                   instr->opcode = swapped_op;
4517                   std::swap(instr->operands[0], instr->operands[1]);
4518                   std::swap(dpp->neg[0], dpp->neg[1]);
4519                   std::swap(dpp->abs[0], dpp->abs[1]);
4520                }
4521                dpp->dpp_ctrl = info.instr->dpp16().dpp_ctrl;
4522                dpp->bound_ctrl = info.instr->dpp16().bound_ctrl;
4523                dpp->neg[0] ^= info.instr->dpp16().neg[0] && !dpp->abs[0];
4524                dpp->abs[0] |= info.instr->dpp16().abs[0];
4525             }
4526             if (--ctx.uses[info.instr->definitions[0].tempId()])
4527                ctx.uses[info.instr->operands[0].tempId()]++;
4528             instr->operands[0].setTemp(info.instr->operands[0].getTemp());
4529             break;
4530          }
4531       }
4532    }
4533 
4534    if (instr->isSDWA() || (instr->isVOP3() && ctx.program->gfx_level < GFX10) ||
4535        (instr->isVOP3P() && ctx.program->gfx_level < GFX10))
4536       return; /* some encodings can't ever take literals */
4537 
4538    /* we do not apply the literals yet as we don't know if it is profitable */
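   /* Here only the use count of the chosen literal temp is decremented; the actual operand
    * substitution happens later in apply_literals() once the temp's remaining uses reach zero. */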
4539    Operand current_literal(s1);
4540 
4541    unsigned literal_id = 0;
4542    unsigned literal_uses = UINT32_MAX;
4543    Operand literal(s1);
4544    unsigned num_operands = 1;
4545    if (instr->isSALU() ||
4546        (ctx.program->gfx_level >= GFX10 && (can_use_VOP3(ctx, instr) || instr->isVOP3P())))
4547       num_operands = instr->operands.size();
4548    /* catch VOP2 with a 3rd SGPR operand (e.g. v_cndmask_b32, v_addc_co_u32) */
4549    else if (instr->isVALU() && instr->operands.size() >= 3)
4550       return;
4551 
4552    unsigned sgpr_ids[2] = {0, 0};
4553    bool is_literal_sgpr = false;
4554    uint32_t mask = 0;
4555 
4556    /* choose a literal to apply */
4557    for (unsigned i = 0; i < num_operands; i++) {
4558       Operand op = instr->operands[i];
4559       unsigned bits = get_operand_size(instr, i);
4560 
4561       if (instr->isVALU() && op.isTemp() && op.getTemp().type() == RegType::sgpr &&
4562           op.tempId() != sgpr_ids[0])
4563          sgpr_ids[!!sgpr_ids[0]] = op.tempId();
4564 
4565       if (op.isLiteral()) {
4566          current_literal = op;
4567          continue;
4568       } else if (!op.isTemp() || !ctx.info[op.tempId()].is_literal(bits)) {
4569          continue;
4570       }
4571 
4572       if (!alu_can_accept_constant(instr->opcode, i))
4573          continue;
4574 
4575       if (ctx.uses[op.tempId()] < literal_uses) {
4576          is_literal_sgpr = op.getTemp().type() == RegType::sgpr;
4577          mask = 0;
4578          literal = Operand::c32(ctx.info[op.tempId()].val);
4579          literal_uses = ctx.uses[op.tempId()];
4580          literal_id = op.tempId();
4581       }
4582 
4583       mask |= (op.tempId() == literal_id) << i;
4584    }
4585 
4586    /* don't go over the constant bus limit */
4587    bool is_shift64 = instr->opcode == aco_opcode::v_lshlrev_b64 ||
4588                      instr->opcode == aco_opcode::v_lshrrev_b64 ||
4589                      instr->opcode == aco_opcode::v_ashrrev_i64;
4590    unsigned const_bus_limit = instr->isVALU() ? 1 : UINT32_MAX;
4591    if (ctx.program->gfx_level >= GFX10 && !is_shift64)
4592       const_bus_limit = 2;
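   /* A literal competes with SGPR operands for the constant bus: one read per VALU instruction
    * before GFX10 (and for 64-bit shifts), two otherwise; SALU instructions have no such limit. */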
4593 
4594    unsigned num_sgprs = !!sgpr_ids[0] + !!sgpr_ids[1];
4595    if (num_sgprs == const_bus_limit && !is_literal_sgpr)
4596       return;
4597 
4598    if (literal_id && literal_uses < threshold &&
4599        (current_literal.isUndefined() ||
4600         (current_literal.size() == literal.size() &&
4601          current_literal.constantValue() == literal.constantValue()))) {
4602       /* mark the literal to be applied */
4603       while (mask) {
4604          unsigned i = u_bit_scan(&mask);
4605          if (instr->operands[i].isTemp() && instr->operands[i].tempId() == literal_id)
4606             ctx.uses[instr->operands[i].tempId()]--;
4607       }
4608    }
4609 }
4610 
4611 static aco_opcode
4612 sopk_opcode_for_sopc(aco_opcode opcode)
4613 {
4614 #define CTOK(op)                                                                                   \
4615    case aco_opcode::s_cmp_##op##_i32: return aco_opcode::s_cmpk_##op##_i32;                        \
4616    case aco_opcode::s_cmp_##op##_u32: return aco_opcode::s_cmpk_##op##_u32;
4617    switch (opcode) {
4618       CTOK(eq)
4619       CTOK(lg)
4620       CTOK(gt)
4621       CTOK(ge)
4622       CTOK(lt)
4623       CTOK(le)
4624    default: return aco_opcode::num_opcodes;
4625    }
4626 #undef CTOK
4627 }
4628 
4629 static bool
4630 sopc_is_signed(aco_opcode opcode)
4631 {
4632 #define SOPC(op)                                                                                   \
4633    case aco_opcode::s_cmp_##op##_i32: return true;                                                 \
4634    case aco_opcode::s_cmp_##op##_u32: return false;
4635    switch (opcode) {
4636       SOPC(eq)
4637       SOPC(lg)
4638       SOPC(gt)
4639       SOPC(ge)
4640       SOPC(lt)
4641       SOPC(le)
4642    default: unreachable("Not a valid SOPC instruction.");
4643    }
4644 #undef SOPC
4645 }
4646 
4647 static aco_opcode
4648 sopc_32_swapped(aco_opcode opcode)
4649 {
4650 #define SOPC(op1, op2)                                                                             \
4651    case aco_opcode::s_cmp_##op1##_i32: return aco_opcode::s_cmp_##op2##_i32;                       \
4652    case aco_opcode::s_cmp_##op1##_u32: return aco_opcode::s_cmp_##op2##_u32;
4653    switch (opcode) {
4654       SOPC(eq, eq)
4655       SOPC(lg, lg)
4656       SOPC(gt, lt)
4657       SOPC(ge, le)
4658       SOPC(lt, gt)
4659       SOPC(le, ge)
4660    default: return aco_opcode::num_opcodes;
4661    }
4662 #undef SOPC
4663 }
4664 
4665 static void
4666 try_convert_sopc_to_sopk(aco_ptr<Instruction>& instr)
4667 {
4668    if (sopk_opcode_for_sopc(instr->opcode) == aco_opcode::num_opcodes)
4669       return;
4670 
4671    if (instr->operands[0].isLiteral()) {
4672       std::swap(instr->operands[0], instr->operands[1]);
4673       instr->opcode = sopc_32_swapped(instr->opcode);
4674    }
4675 
4676    if (!instr->operands[1].isLiteral())
4677       return;
4678 
4679    if (instr->operands[0].isFixed() && instr->operands[0].physReg() >= 128)
4680       return;
4681 
4682    uint32_t value = instr->operands[1].constantValue();
4683 
4684    const uint32_t i16_mask = 0xffff8000u;
4685 
4686    bool value_is_i16 = (value & i16_mask) == 0 || (value & i16_mask) == i16_mask;
4687    bool value_is_u16 = !(value & 0xffff0000u);
4688 
4689    if (!value_is_i16 && !value_is_u16)
4690       return;
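   /* SOPK immediates are 16 bits and sign-extended, so a value that only fits as u16 (or only
    * as i16) may still be usable if the comparison can be switched between its signed and
    * unsigned form, as done for eq/lg below. */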
4691 
4692    if (!value_is_i16 && sopc_is_signed(instr->opcode)) {
4693       if (instr->opcode == aco_opcode::s_cmp_lg_i32)
4694          instr->opcode = aco_opcode::s_cmp_lg_u32;
4695       else if (instr->opcode == aco_opcode::s_cmp_eq_i32)
4696          instr->opcode = aco_opcode::s_cmp_eq_u32;
4697       else
4698          return;
4699    } else if (!value_is_u16 && !sopc_is_signed(instr->opcode)) {
4700       if (instr->opcode == aco_opcode::s_cmp_lg_u32)
4701          instr->opcode = aco_opcode::s_cmp_lg_i32;
4702       else if (instr->opcode == aco_opcode::s_cmp_eq_u32)
4703          instr->opcode = aco_opcode::s_cmp_eq_i32;
4704       else
4705          return;
4706    }
4707 
4708    static_assert(sizeof(SOPK_instruction) <= sizeof(SOPC_instruction),
4709                  "Invalid direct instruction cast.");
4710    instr->format = Format::SOPK;
4711    SOPK_instruction* instr_sopk = &instr->sopk();
4712 
4713    instr_sopk->imm = instr_sopk->operands[1].constantValue() & 0xffff;
4714    instr_sopk->opcode = sopk_opcode_for_sopc(instr_sopk->opcode);
4715    instr_sopk->operands.pop_back();
4716 }
4717 
4718 void
4719 apply_literals(opt_ctx& ctx, aco_ptr<Instruction>& instr)
4720 {
4721    /* Clean up dead instructions */
4722    if (!instr)
4723       return;
4724 
4725    /* apply literals on MAD */
4726    if (!instr->definitions.empty() && ctx.info[instr->definitions[0].tempId()].is_mad()) {
4727       mad_info* info = &ctx.mad_infos[ctx.info[instr->definitions[0].tempId()].instr->pass_flags];
4728       if (info->check_literal &&
4729           (ctx.uses[instr->operands[info->literal_idx].tempId()] == 0 || info->literal_idx == 2)) {
4730          aco_ptr<Instruction> new_mad;
4731 
4732          aco_opcode new_op =
4733             info->literal_idx == 2 ? aco_opcode::v_madak_f32 : aco_opcode::v_madmk_f32;
4734          if (instr->opcode == aco_opcode::v_fma_f32)
4735             new_op = info->literal_idx == 2 ? aco_opcode::v_fmaak_f32 : aco_opcode::v_fmamk_f32;
4736          else if (instr->opcode == aco_opcode::v_mad_f16 ||
4737                   instr->opcode == aco_opcode::v_mad_legacy_f16)
4738             new_op = info->literal_idx == 2 ? aco_opcode::v_madak_f16 : aco_opcode::v_madmk_f16;
4739          else if (instr->opcode == aco_opcode::v_fma_f16)
4740             new_op = info->literal_idx == 2 ? aco_opcode::v_fmaak_f16 : aco_opcode::v_fmamk_f16;
4741 
4742          new_mad.reset(create_instruction<VOP2_instruction>(new_op, Format::VOP2, 3, 1));
4743          if (info->literal_idx == 2) { /* add literal -> madak */
4744             new_mad->operands[0] = instr->operands[0];
4745             new_mad->operands[1] = instr->operands[1];
4746             if (!new_mad->operands[1].isTemp() ||
4747                 new_mad->operands[1].getTemp().type() == RegType::sgpr)
4748                std::swap(new_mad->operands[0], new_mad->operands[1]);
4749          } else { /* mul literal -> madmk */
4750             new_mad->operands[0] = instr->operands[1 - info->literal_idx];
4751             new_mad->operands[1] = instr->operands[2];
4752          }
4753          new_mad->operands[2] =
4754             Operand::c32(ctx.info[instr->operands[info->literal_idx].tempId()].val);
4755          new_mad->definitions[0] = instr->definitions[0];
4756          ctx.instructions.emplace_back(std::move(new_mad));
4757          return;
4758       }
4759    }
4760 
4761    /* apply literals on other SALU/VALU */
4762    if (instr->isSALU() || instr->isVALU()) {
4763       for (unsigned i = 0; i < instr->operands.size(); i++) {
4764          Operand op = instr->operands[i];
4765          unsigned bits = get_operand_size(instr, i);
4766          if (op.isTemp() && ctx.info[op.tempId()].is_literal(bits) && ctx.uses[op.tempId()] == 0) {
4767             Operand literal = Operand::literal32(ctx.info[op.tempId()].val);
4768             instr->format = withoutDPP(instr->format);
4769             if (instr->isVALU() && i > 0 && instr->format != Format::VOP3P)
4770                to_VOP3(ctx, instr);
4771             instr->operands[i] = literal;
4772          }
4773       }
4774    }
4775 
4776    if (instr->isSOPC())
4777       try_convert_sopc_to_sopk(instr);
4778 
4779    /* allow more s_addk_i32 optimizations if carry isn't used */
4780    if (instr->opcode == aco_opcode::s_add_u32 && ctx.uses[instr->definitions[1].tempId()] == 0 &&
4781        (instr->operands[0].isLiteral() || instr->operands[1].isLiteral()))
4782       instr->opcode = aco_opcode::s_add_i32;
4783 
4784    ctx.instructions.emplace_back(std::move(instr));
4785 }
4786 
4787 void
4788 optimize(Program* program)
4789 {
4790    opt_ctx ctx;
4791    ctx.program = program;
4792    std::vector<ssa_info> info(program->peekAllocationId());
4793    ctx.info = info.data();
4794 
4795    /* 1. Bottom-Up DAG pass (forward) to label all ssa-defs */
4796    for (Block& block : program->blocks) {
4797       ctx.fp_mode = block.fp_mode;
4798       for (aco_ptr<Instruction>& instr : block.instructions)
4799          label_instruction(ctx, instr);
4800    }
4801 
4802    ctx.uses = dead_code_analysis(program);
4803 
4804    /* 2. Combine v_mad, omod, clamp and propagate sgpr on VALU instructions */
4805    for (Block& block : program->blocks) {
4806       ctx.fp_mode = block.fp_mode;
4807       for (aco_ptr<Instruction>& instr : block.instructions)
4808          combine_instruction(ctx, instr);
4809    }
4810 
4811    /* 3. Top-Down DAG pass (backward) to select instructions (includes DCE) */
4812    for (auto block_rit = program->blocks.rbegin(); block_rit != program->blocks.rend();
4813         ++block_rit) {
4814       Block* block = &(*block_rit);
4815       ctx.fp_mode = block->fp_mode;
4816       for (auto instr_rit = block->instructions.rbegin(); instr_rit != block->instructions.rend();
4817            ++instr_rit)
4818          select_instruction(ctx, *instr_rit);
4819    }
4820 
4821    /* 4. Add literals to instructions */
4822    for (Block& block : program->blocks) {
4823       ctx.instructions.clear();
4824       ctx.fp_mode = block.fp_mode;
4825       for (aco_ptr<Instruction>& instr : block.instructions)
4826          apply_literals(ctx, instr);
4827       block.instructions.swap(ctx.instructions);
4828    }
4829 }
4830 
4831 } // namespace aco
4832