1 /*
2 * Copyright © 2018 Valve Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 *
23 */
24
25 #include "aco_builder.h"
26 #include "aco_ir.h"
27
28 #include "util/half_float.h"
29 #include "util/memstream.h"
30
31 #include <algorithm>
32 #include <array>
33 #include <vector>
34
35 namespace aco {
36
37 #ifndef NDEBUG
38 void
perfwarn(Program * program,bool cond,const char * msg,Instruction * instr)39 perfwarn(Program* program, bool cond, const char* msg, Instruction* instr)
40 {
41 if (cond) {
42 char* out;
43 size_t outsize;
44 struct u_memstream mem;
45 u_memstream_open(&mem, &out, &outsize);
46 FILE* const memf = u_memstream_get(&mem);
47
48 fprintf(memf, "%s: ", msg);
49 aco_print_instr(program->gfx_level, instr, memf);
50 u_memstream_close(&mem);
51
52 aco_perfwarn(program, out);
53 free(out);
54
55 if (debug_flags & DEBUG_PERFWARN)
56 exit(1);
57 }
58 }
59 #endif
60
/**
 * The optimizer works in 4 phases:
 * (1) The first pass collects information for each SSA definition and
 *     propagates reg->reg operands of the same type, inline constants
 *     and neg/abs input modifiers.
 * (2) The second pass combines instructions like mad, omod and clamp, and
 *     propagates SGPRs into VALU instructions.
 *     This pass depends on information collected in the first pass.
 * (3) The third pass goes backwards and selects instructions,
 *     i.e. decides whether a mad instruction is profitable and eliminates dead code.
 * (4) The fourth pass cleans up the sequence: literals get applied and dead
 *     instructions are removed from the sequence.
 */
74
75 struct mad_info {
76 aco_ptr<Instruction> add_instr;
77 uint32_t mul_temp_id;
78 uint16_t literal_mask;
79 uint16_t fp16_mask;
80
mad_infoaco::mad_info81 mad_info(aco_ptr<Instruction> instr, uint32_t id)
82 : add_instr(std::move(instr)), mul_temp_id(id), literal_mask(0), fp16_mask(0)
83 {}
84 };
85
/* Bit flags stored in ssa_info::label.
 *
 * Every value is built from 1ull so that no shift ever touches the sign bit of an int.
 * Bits 11, 13 and 25 are currently unused.
 *
 * label_{abs,neg,mul,omod2,omod4,omod5,clamp} are used for both 16-bit and 32-bit
 * operations. This shouldn't cause any issues because the optimizer doesn't look
 * through any conversions.
 */
enum Label {
   /* bits 0-10 */
   label_vec = 1ull << 0,
   label_constant_32bit = 1ull << 1,
   label_abs = 1ull << 2,
   label_neg = 1ull << 3,
   label_mul = 1ull << 4,
   label_temp = 1ull << 5,
   label_literal = 1ull << 6,
   label_mad = 1ull << 7,
   label_omod2 = 1ull << 8,
   label_omod4 = 1ull << 9,
   label_omod5 = 1ull << 10,

   /* bits 12-24 (11 is unused) */
   label_clamp = 1ull << 12,
   label_undefined = 1ull << 14,
   label_vcc = 1ull << 15,
   label_b2f = 1ull << 16,
   label_add_sub = 1ull << 17,
   label_bitwise = 1ull << 18,
   label_minmax = 1ull << 19,
   label_vopc = 1ull << 20,
   label_uniform_bool = 1ull << 21,
   label_constant_64bit = 1ull << 22,
   label_uniform_bitwise = 1ull << 23,
   label_scc_invert = 1ull << 24,

   /* bits 26-31 (25 is unused) */
   label_scc_needed = 1ull << 26,
   label_b2i = 1ull << 27,
   label_fcanonicalize = 1ull << 28,
   label_constant_16bit = 1ull << 29,
   label_usedef = 1ull << 30, /* generic label */
   label_vop3p = 1ull << 31,

   /* bits 32 and above */
   label_canonicalized = 1ull << 32,
   label_extract = 1ull << 33,
   label_insert = 1ull << 34,
   label_dpp16 = 1ull << 35,
   label_dpp8 = 1ull << 36,
   label_f2f32 = 1ull << 37,
   label_f2f16 = 1ull << 38,
   label_split = 1ull << 39,
   label_subgroup_invocation = 1ull << 40,
};
129
130 static constexpr uint64_t instr_usedef_labels =
131 label_vec | label_mul | label_add_sub | label_vop3p | label_bitwise | label_uniform_bitwise |
132 label_minmax | label_vopc | label_usedef | label_extract | label_dpp16 | label_dpp8 |
133 label_f2f32 | label_subgroup_invocation;
134 static constexpr uint64_t instr_mod_labels =
135 label_omod2 | label_omod4 | label_omod5 | label_clamp | label_insert | label_f2f16;
136
137 static constexpr uint64_t instr_labels = instr_usedef_labels | instr_mod_labels | label_split;
138 static constexpr uint64_t temp_labels = label_abs | label_neg | label_temp | label_vcc | label_b2f |
139 label_uniform_bool | label_scc_invert | label_b2i |
140 label_fcanonicalize;
141 static constexpr uint32_t val_labels =
142 label_constant_32bit | label_constant_64bit | label_constant_16bit | label_literal | label_mad;
143
144 static_assert((instr_labels & temp_labels) == 0, "labels cannot intersect");
145 static_assert((instr_labels & val_labels) == 0, "labels cannot intersect");
146 static_assert((temp_labels & val_labels) == 0, "labels cannot intersect");
147
148 struct ssa_info {
149 uint64_t label;
150 union {
151 uint32_t val;
152 Temp temp;
153 Instruction* instr;
154 };
155
ssa_infoaco::ssa_info156 ssa_info() : label(0) {}
157
add_labelaco::ssa_info158 void add_label(Label new_label)
159 {
160 /* Since all the instr_usedef_labels use instr for the same thing
161 * (indicating the defining instruction), there is usually no need to
162 * clear any other instr labels. */
163 if (new_label & instr_usedef_labels)
164 label &= ~(instr_mod_labels | temp_labels | val_labels); /* instr, temp and val alias */
165
166 if (new_label & instr_mod_labels) {
167 label &= ~instr_labels;
168 label &= ~(temp_labels | val_labels); /* instr, temp and val alias */
169 }
170
171 if (new_label & temp_labels) {
172 label &= ~temp_labels;
173 label &= ~(instr_labels | val_labels); /* instr, temp and val alias */
174 }
175
176 uint32_t const_labels =
177 label_literal | label_constant_32bit | label_constant_64bit | label_constant_16bit;
178 if (new_label & const_labels) {
179 label &= ~val_labels | const_labels;
180 label &= ~(instr_labels | temp_labels); /* instr, temp and val alias */
181 } else if (new_label & val_labels) {
182 label &= ~val_labels;
183 label &= ~(instr_labels | temp_labels); /* instr, temp and val alias */
184 }
185
186 label |= new_label;
187 }
188
set_vecaco::ssa_info189 void set_vec(Instruction* vec)
190 {
191 add_label(label_vec);
192 instr = vec;
193 }
194
is_vecaco::ssa_info195 bool is_vec() { return label & label_vec; }
196
set_constantaco::ssa_info197 void set_constant(amd_gfx_level gfx_level, uint64_t constant)
198 {
199 Operand op16 = Operand::c16(constant);
200 Operand op32 = Operand::get_const(gfx_level, constant, 4);
201 add_label(label_literal);
202 val = constant;
203
204 /* check that no upper bits are lost in case of packed 16bit constants */
205 if (gfx_level >= GFX8 && !op16.isLiteral() &&
206 op16.constantValue16(true) == ((constant >> 16) & 0xffff))
207 add_label(label_constant_16bit);
208
209 if (!op32.isLiteral())
210 add_label(label_constant_32bit);
211
212 if (Operand::is_constant_representable(constant, 8))
213 add_label(label_constant_64bit);
214
215 if (label & label_constant_64bit) {
216 val = Operand::c64(constant).constantValue();
217 if (val != constant)
218 label &= ~(label_literal | label_constant_16bit | label_constant_32bit);
219 }
220 }
221
is_constantaco::ssa_info222 bool is_constant(unsigned bits)
223 {
224 switch (bits) {
225 case 8: return label & label_literal;
226 case 16: return label & label_constant_16bit;
227 case 32: return label & label_constant_32bit;
228 case 64: return label & label_constant_64bit;
229 }
230 return false;
231 }
232
is_literalaco::ssa_info233 bool is_literal(unsigned bits)
234 {
235 bool is_lit = label & label_literal;
236 switch (bits) {
237 case 8: return false;
238 case 16: return is_lit && ~(label & label_constant_16bit);
239 case 32: return is_lit && ~(label & label_constant_32bit);
240 case 64: return false;
241 }
242 return false;
243 }
244
is_constant_or_literalaco::ssa_info245 bool is_constant_or_literal(unsigned bits)
246 {
247 if (bits == 64)
248 return label & label_constant_64bit;
249 else
250 return label & label_literal;
251 }
252
set_absaco::ssa_info253 void set_abs(Temp abs_temp)
254 {
255 add_label(label_abs);
256 temp = abs_temp;
257 }
258
is_absaco::ssa_info259 bool is_abs() { return label & label_abs; }
260
set_negaco::ssa_info261 void set_neg(Temp neg_temp)
262 {
263 add_label(label_neg);
264 temp = neg_temp;
265 }
266
is_negaco::ssa_info267 bool is_neg() { return label & label_neg; }
268
set_neg_absaco::ssa_info269 void set_neg_abs(Temp neg_abs_temp)
270 {
271 add_label((Label)((uint32_t)label_abs | (uint32_t)label_neg));
272 temp = neg_abs_temp;
273 }
274
set_mulaco::ssa_info275 void set_mul(Instruction* mul)
276 {
277 add_label(label_mul);
278 instr = mul;
279 }
280
is_mulaco::ssa_info281 bool is_mul() { return label & label_mul; }
282
set_tempaco::ssa_info283 void set_temp(Temp tmp)
284 {
285 add_label(label_temp);
286 temp = tmp;
287 }
288
is_tempaco::ssa_info289 bool is_temp() { return label & label_temp; }
290
set_madaco::ssa_info291 void set_mad(uint32_t mad_info_idx)
292 {
293 add_label(label_mad);
294 val = mad_info_idx;
295 }
296
is_madaco::ssa_info297 bool is_mad() { return label & label_mad; }
298
set_omod2aco::ssa_info299 void set_omod2(Instruction* mul)
300 {
301 if (label & temp_labels)
302 return;
303 add_label(label_omod2);
304 instr = mul;
305 }
306
is_omod2aco::ssa_info307 bool is_omod2() { return label & label_omod2; }
308
set_omod4aco::ssa_info309 void set_omod4(Instruction* mul)
310 {
311 if (label & temp_labels)
312 return;
313 add_label(label_omod4);
314 instr = mul;
315 }
316
is_omod4aco::ssa_info317 bool is_omod4() { return label & label_omod4; }
318
set_omod5aco::ssa_info319 void set_omod5(Instruction* mul)
320 {
321 if (label & temp_labels)
322 return;
323 add_label(label_omod5);
324 instr = mul;
325 }
326
is_omod5aco::ssa_info327 bool is_omod5() { return label & label_omod5; }
328
set_clampaco::ssa_info329 void set_clamp(Instruction* med3)
330 {
331 if (label & temp_labels)
332 return;
333 add_label(label_clamp);
334 instr = med3;
335 }
336
is_clampaco::ssa_info337 bool is_clamp() { return label & label_clamp; }
338
set_f2f16aco::ssa_info339 void set_f2f16(Instruction* conv)
340 {
341 if (label & temp_labels)
342 return;
343 add_label(label_f2f16);
344 instr = conv;
345 }
346
is_f2f16aco::ssa_info347 bool is_f2f16() { return label & label_f2f16; }
348
set_undefinedaco::ssa_info349 void set_undefined() { add_label(label_undefined); }
350
is_undefinedaco::ssa_info351 bool is_undefined() { return label & label_undefined; }
352
set_vccaco::ssa_info353 void set_vcc(Temp vcc_val)
354 {
355 add_label(label_vcc);
356 temp = vcc_val;
357 }
358
is_vccaco::ssa_info359 bool is_vcc() { return label & label_vcc; }
360
set_b2faco::ssa_info361 void set_b2f(Temp b2f_val)
362 {
363 add_label(label_b2f);
364 temp = b2f_val;
365 }
366
is_b2faco::ssa_info367 bool is_b2f() { return label & label_b2f; }
368
set_add_subaco::ssa_info369 void set_add_sub(Instruction* add_sub_instr)
370 {
371 add_label(label_add_sub);
372 instr = add_sub_instr;
373 }
374
is_add_subaco::ssa_info375 bool is_add_sub() { return label & label_add_sub; }
376
set_bitwiseaco::ssa_info377 void set_bitwise(Instruction* bitwise_instr)
378 {
379 add_label(label_bitwise);
380 instr = bitwise_instr;
381 }
382
is_bitwiseaco::ssa_info383 bool is_bitwise() { return label & label_bitwise; }
384
set_uniform_bitwiseaco::ssa_info385 void set_uniform_bitwise() { add_label(label_uniform_bitwise); }
386
is_uniform_bitwiseaco::ssa_info387 bool is_uniform_bitwise() { return label & label_uniform_bitwise; }
388
set_minmaxaco::ssa_info389 void set_minmax(Instruction* minmax_instr)
390 {
391 add_label(label_minmax);
392 instr = minmax_instr;
393 }
394
is_minmaxaco::ssa_info395 bool is_minmax() { return label & label_minmax; }
396
set_vopcaco::ssa_info397 void set_vopc(Instruction* vopc_instr)
398 {
399 add_label(label_vopc);
400 instr = vopc_instr;
401 }
402
is_vopcaco::ssa_info403 bool is_vopc() { return label & label_vopc; }
404
set_scc_neededaco::ssa_info405 void set_scc_needed() { add_label(label_scc_needed); }
406
is_scc_neededaco::ssa_info407 bool is_scc_needed() { return label & label_scc_needed; }
408
set_scc_invertaco::ssa_info409 void set_scc_invert(Temp scc_inv)
410 {
411 add_label(label_scc_invert);
412 temp = scc_inv;
413 }
414
is_scc_invertaco::ssa_info415 bool is_scc_invert() { return label & label_scc_invert; }
416
set_uniform_boolaco::ssa_info417 void set_uniform_bool(Temp uniform_bool)
418 {
419 add_label(label_uniform_bool);
420 temp = uniform_bool;
421 }
422
is_uniform_boolaco::ssa_info423 bool is_uniform_bool() { return label & label_uniform_bool; }
424
set_b2iaco::ssa_info425 void set_b2i(Temp b2i_val)
426 {
427 add_label(label_b2i);
428 temp = b2i_val;
429 }
430
is_b2iaco::ssa_info431 bool is_b2i() { return label & label_b2i; }
432
set_usedefaco::ssa_info433 void set_usedef(Instruction* label_instr)
434 {
435 add_label(label_usedef);
436 instr = label_instr;
437 }
438
is_usedefaco::ssa_info439 bool is_usedef() { return label & label_usedef; }
440
set_vop3paco::ssa_info441 void set_vop3p(Instruction* vop3p_instr)
442 {
443 add_label(label_vop3p);
444 instr = vop3p_instr;
445 }
446
is_vop3paco::ssa_info447 bool is_vop3p() { return label & label_vop3p; }
448
set_fcanonicalizeaco::ssa_info449 void set_fcanonicalize(Temp tmp)
450 {
451 add_label(label_fcanonicalize);
452 temp = tmp;
453 }
454
is_fcanonicalizeaco::ssa_info455 bool is_fcanonicalize() { return label & label_fcanonicalize; }
456
set_canonicalizedaco::ssa_info457 void set_canonicalized() { add_label(label_canonicalized); }
458
is_canonicalizedaco::ssa_info459 bool is_canonicalized() { return label & label_canonicalized; }
460
set_f2f32aco::ssa_info461 void set_f2f32(Instruction* cvt)
462 {
463 add_label(label_f2f32);
464 instr = cvt;
465 }
466
is_f2f32aco::ssa_info467 bool is_f2f32() { return label & label_f2f32; }
468
set_extractaco::ssa_info469 void set_extract(Instruction* extract)
470 {
471 add_label(label_extract);
472 instr = extract;
473 }
474
is_extractaco::ssa_info475 bool is_extract() { return label & label_extract; }
476
set_insertaco::ssa_info477 void set_insert(Instruction* insert)
478 {
479 if (label & temp_labels)
480 return;
481 add_label(label_insert);
482 instr = insert;
483 }
484
is_insertaco::ssa_info485 bool is_insert() { return label & label_insert; }
486
set_dpp16aco::ssa_info487 void set_dpp16(Instruction* mov)
488 {
489 add_label(label_dpp16);
490 instr = mov;
491 }
492
set_dpp8aco::ssa_info493 void set_dpp8(Instruction* mov)
494 {
495 add_label(label_dpp8);
496 instr = mov;
497 }
498
is_dppaco::ssa_info499 bool is_dpp() { return label & (label_dpp16 | label_dpp8); }
is_dpp16aco::ssa_info500 bool is_dpp16() { return label & label_dpp16; }
is_dpp8aco::ssa_info501 bool is_dpp8() { return label & label_dpp8; }
502
set_splitaco::ssa_info503 void set_split(Instruction* split)
504 {
505 add_label(label_split);
506 instr = split;
507 }
508
is_splitaco::ssa_info509 bool is_split() { return label & label_split; }
510
set_subgroup_invocationaco::ssa_info511 void set_subgroup_invocation(Instruction* label_instr)
512 {
513 add_label(label_subgroup_invocation);
514 instr = label_instr;
515 }
516
is_subgroup_invocationaco::ssa_info517 bool is_subgroup_invocation() { return label & label_subgroup_invocation; }
518 };
519
520 struct opt_ctx {
521 Program* program;
522 float_mode fp_mode;
523 std::vector<aco_ptr<Instruction>> instructions;
524 ssa_info* info;
525 std::pair<uint32_t, Temp> last_literal;
526 std::vector<mad_info> mad_infos;
527 std::vector<uint16_t> uses;
528 };
529
530 bool
can_use_VOP3(opt_ctx & ctx,const aco_ptr<Instruction> & instr)531 can_use_VOP3(opt_ctx& ctx, const aco_ptr<Instruction>& instr)
532 {
533 if (instr->isVOP3())
534 return true;
535
536 if (instr->isVOP3P())
537 return false;
538
539 if (instr->operands.size() && instr->operands[0].isLiteral() && ctx.program->gfx_level < GFX10)
540 return false;
541
542 if (instr->isSDWA())
543 return false;
544
545 if (instr->isDPP() && ctx.program->gfx_level < GFX11)
546 return false;
547
548 return instr->opcode != aco_opcode::v_madmk_f32 && instr->opcode != aco_opcode::v_madak_f32 &&
549 instr->opcode != aco_opcode::v_madmk_f16 && instr->opcode != aco_opcode::v_madak_f16 &&
550 instr->opcode != aco_opcode::v_fmamk_f32 && instr->opcode != aco_opcode::v_fmaak_f32 &&
551 instr->opcode != aco_opcode::v_fmamk_f16 && instr->opcode != aco_opcode::v_fmaak_f16 &&
552 instr->opcode != aco_opcode::v_permlane64_b32 &&
553 instr->opcode != aco_opcode::v_readlane_b32 &&
554 instr->opcode != aco_opcode::v_writelane_b32 &&
555 instr->opcode != aco_opcode::v_readfirstlane_b32;
556 }
557
558 bool
pseudo_propagate_temp(opt_ctx & ctx,aco_ptr<Instruction> & instr,Temp temp,unsigned index)559 pseudo_propagate_temp(opt_ctx& ctx, aco_ptr<Instruction>& instr, Temp temp, unsigned index)
560 {
561 if (instr->definitions.empty())
562 return false;
563
564 const bool vgpr =
565 instr->opcode == aco_opcode::p_as_uniform ||
566 std::all_of(instr->definitions.begin(), instr->definitions.end(),
567 [](const Definition& def) { return def.regClass().type() == RegType::vgpr; });
568
569 /* don't propagate VGPRs into SGPR instructions */
570 if (temp.type() == RegType::vgpr && !vgpr)
571 return false;
572
573 bool can_accept_sgpr =
574 ctx.program->gfx_level >= GFX9 ||
575 std::none_of(instr->definitions.begin(), instr->definitions.end(),
576 [](const Definition& def) { return def.regClass().is_subdword(); });
577
578 switch (instr->opcode) {
579 case aco_opcode::p_phi:
580 case aco_opcode::p_linear_phi:
581 case aco_opcode::p_parallelcopy:
582 case aco_opcode::p_create_vector:
583 if (temp.bytes() != instr->operands[index].bytes())
584 return false;
585 break;
586 case aco_opcode::p_extract_vector:
587 case aco_opcode::p_extract:
588 if (temp.type() == RegType::sgpr && !can_accept_sgpr)
589 return false;
590 break;
591 case aco_opcode::p_split_vector: {
592 if (temp.type() == RegType::sgpr && !can_accept_sgpr)
593 return false;
594 /* don't increase the vector size */
595 if (temp.bytes() > instr->operands[index].bytes())
596 return false;
597 /* We can decrease the vector size as smaller temporaries are only
598 * propagated by p_as_uniform instructions.
599 * If this propagation leads to invalid IR or hits the assertion below,
600 * it means that some undefined bytes within a dword are begin accessed
601 * and a bug in instruction_selection is likely. */
602 int decrease = instr->operands[index].bytes() - temp.bytes();
603 while (decrease > 0) {
604 decrease -= instr->definitions.back().bytes();
605 instr->definitions.pop_back();
606 }
607 assert(decrease == 0);
608 break;
609 }
610 case aco_opcode::p_as_uniform:
611 if (temp.regClass() == instr->definitions[0].regClass())
612 instr->opcode = aco_opcode::p_parallelcopy;
613 break;
614 default: return false;
615 }
616
617 instr->operands[index].setTemp(temp);
618 return true;
619 }
620
621 /* This expects the DPP modifier to be removed. */
622 bool
can_apply_sgprs(opt_ctx & ctx,aco_ptr<Instruction> & instr)623 can_apply_sgprs(opt_ctx& ctx, aco_ptr<Instruction>& instr)
624 {
625 assert(instr->isVALU());
626 if (instr->isSDWA() && ctx.program->gfx_level < GFX9)
627 return false;
628 return instr->opcode != aco_opcode::v_readfirstlane_b32 &&
629 instr->opcode != aco_opcode::v_readlane_b32 &&
630 instr->opcode != aco_opcode::v_readlane_b32_e64 &&
631 instr->opcode != aco_opcode::v_writelane_b32 &&
632 instr->opcode != aco_opcode::v_writelane_b32_e64 &&
633 instr->opcode != aco_opcode::v_permlane16_b32 &&
634 instr->opcode != aco_opcode::v_permlanex16_b32 &&
635 instr->opcode != aco_opcode::v_permlane64_b32 &&
636 instr->opcode != aco_opcode::v_interp_p1_f32 &&
637 instr->opcode != aco_opcode::v_interp_p2_f32 &&
638 instr->opcode != aco_opcode::v_interp_mov_f32 &&
639 instr->opcode != aco_opcode::v_interp_p1ll_f16 &&
640 instr->opcode != aco_opcode::v_interp_p1lv_f16 &&
641 instr->opcode != aco_opcode::v_interp_p2_legacy_f16 &&
642 instr->opcode != aco_opcode::v_interp_p2_f16 &&
643 instr->opcode != aco_opcode::v_interp_p10_f32_inreg &&
644 instr->opcode != aco_opcode::v_interp_p2_f32_inreg &&
645 instr->opcode != aco_opcode::v_interp_p10_f16_f32_inreg &&
646 instr->opcode != aco_opcode::v_interp_p2_f16_f32_inreg &&
647 instr->opcode != aco_opcode::v_interp_p10_rtz_f16_f32_inreg &&
648 instr->opcode != aco_opcode::v_interp_p2_rtz_f16_f32_inreg &&
649 instr->opcode != aco_opcode::v_wmma_f32_16x16x16_f16 &&
650 instr->opcode != aco_opcode::v_wmma_f32_16x16x16_bf16 &&
651 instr->opcode != aco_opcode::v_wmma_f16_16x16x16_f16 &&
652 instr->opcode != aco_opcode::v_wmma_bf16_16x16x16_bf16 &&
653 instr->opcode != aco_opcode::v_wmma_i32_16x16x16_iu8 &&
654 instr->opcode != aco_opcode::v_wmma_i32_16x16x16_iu4;
655 }
656
657 bool
is_operand_vgpr(Operand op)658 is_operand_vgpr(Operand op)
659 {
660 return op.isTemp() && op.getTemp().type() == RegType::vgpr;
661 }
662
663 /* only covers special cases */
664 bool
alu_can_accept_constant(const aco_ptr<Instruction> & instr,unsigned operand)665 alu_can_accept_constant(const aco_ptr<Instruction>& instr, unsigned operand)
666 {
667 /* Fixed operands can't accept constants because we need them
668 * to be in their fixed register.
669 */
670 assert(instr->operands.size() > operand);
671 if (instr->operands[operand].isFixed())
672 return false;
673
674 /* SOPP instructions can't use constants. */
675 if (instr->isSOPP())
676 return false;
677
678 switch (instr->opcode) {
679 case aco_opcode::v_mac_f32:
680 case aco_opcode::v_writelane_b32:
681 case aco_opcode::v_writelane_b32_e64:
682 case aco_opcode::v_cndmask_b32: return operand != 2;
683 case aco_opcode::s_addk_i32:
684 case aco_opcode::s_mulk_i32:
685 case aco_opcode::p_extract_vector:
686 case aco_opcode::p_split_vector:
687 case aco_opcode::v_readlane_b32:
688 case aco_opcode::v_readlane_b32_e64:
689 case aco_opcode::v_readfirstlane_b32:
690 case aco_opcode::p_extract:
691 case aco_opcode::p_insert: return operand != 0;
692 case aco_opcode::p_bpermute_readlane:
693 case aco_opcode::p_bpermute_shared_vgpr:
694 case aco_opcode::p_bpermute_permlane:
695 case aco_opcode::p_interp_gfx11:
696 case aco_opcode::p_dual_src_export_gfx11:
697 case aco_opcode::v_interp_p1_f32:
698 case aco_opcode::v_interp_p2_f32:
699 case aco_opcode::v_interp_mov_f32:
700 case aco_opcode::v_interp_p1ll_f16:
701 case aco_opcode::v_interp_p1lv_f16:
702 case aco_opcode::v_interp_p2_legacy_f16:
703 case aco_opcode::v_interp_p10_f32_inreg:
704 case aco_opcode::v_interp_p2_f32_inreg:
705 case aco_opcode::v_interp_p10_f16_f32_inreg:
706 case aco_opcode::v_interp_p2_f16_f32_inreg:
707 case aco_opcode::v_interp_p10_rtz_f16_f32_inreg:
708 case aco_opcode::v_interp_p2_rtz_f16_f32_inreg:
709 case aco_opcode::v_wmma_f32_16x16x16_f16:
710 case aco_opcode::v_wmma_f32_16x16x16_bf16:
711 case aco_opcode::v_wmma_f16_16x16x16_f16:
712 case aco_opcode::v_wmma_bf16_16x16x16_bf16:
713 case aco_opcode::v_wmma_i32_16x16x16_iu8:
714 case aco_opcode::v_wmma_i32_16x16x16_iu4: return false;
715 default: return true;
716 }
717 }
718
719 bool
valu_can_accept_vgpr(aco_ptr<Instruction> & instr,unsigned operand)720 valu_can_accept_vgpr(aco_ptr<Instruction>& instr, unsigned operand)
721 {
722 if (instr->opcode == aco_opcode::v_readlane_b32 ||
723 instr->opcode == aco_opcode::v_readlane_b32_e64 ||
724 instr->opcode == aco_opcode::v_writelane_b32 ||
725 instr->opcode == aco_opcode::v_writelane_b32_e64)
726 return operand != 1;
727 if (instr->opcode == aco_opcode::v_permlane16_b32 ||
728 instr->opcode == aco_opcode::v_permlanex16_b32)
729 return operand == 0;
730 return true;
731 }
732
733 /* check constant bus and literal limitations */
734 bool
check_vop3_operands(opt_ctx & ctx,unsigned num_operands,Operand * operands)735 check_vop3_operands(opt_ctx& ctx, unsigned num_operands, Operand* operands)
736 {
737 int limit = ctx.program->gfx_level >= GFX10 ? 2 : 1;
738 Operand literal32(s1);
739 Operand literal64(s2);
740 unsigned num_sgprs = 0;
741 unsigned sgpr[] = {0, 0};
742
743 for (unsigned i = 0; i < num_operands; i++) {
744 Operand op = operands[i];
745
746 if (op.hasRegClass() && op.regClass().type() == RegType::sgpr) {
747 /* two reads of the same SGPR count as 1 to the limit */
748 if (op.tempId() != sgpr[0] && op.tempId() != sgpr[1]) {
749 if (num_sgprs < 2)
750 sgpr[num_sgprs++] = op.tempId();
751 limit--;
752 if (limit < 0)
753 return false;
754 }
755 } else if (op.isLiteral()) {
756 if (ctx.program->gfx_level < GFX10)
757 return false;
758
759 if (!literal32.isUndefined() && literal32.constantValue() != op.constantValue())
760 return false;
761 if (!literal64.isUndefined() && literal64.constantValue() != op.constantValue())
762 return false;
763
764 /* Any number of 32-bit literals counts as only 1 to the limit. Same
765 * (but separately) for 64-bit literals. */
766 if (op.size() == 1 && literal32.isUndefined()) {
767 limit--;
768 literal32 = op;
769 } else if (op.size() == 2 && literal64.isUndefined()) {
770 limit--;
771 literal64 = op;
772 }
773
774 if (limit < 0)
775 return false;
776 }
777 }
778
779 return true;
780 }
781
782 bool
parse_base_offset(opt_ctx & ctx,Instruction * instr,unsigned op_index,Temp * base,uint32_t * offset,bool prevent_overflow)783 parse_base_offset(opt_ctx& ctx, Instruction* instr, unsigned op_index, Temp* base, uint32_t* offset,
784 bool prevent_overflow)
785 {
786 Operand op = instr->operands[op_index];
787
788 if (!op.isTemp())
789 return false;
790 Temp tmp = op.getTemp();
791 if (!ctx.info[tmp.id()].is_add_sub())
792 return false;
793
794 Instruction* add_instr = ctx.info[tmp.id()].instr;
795
796 unsigned mask = 0x3;
797 bool is_sub = false;
798 switch (add_instr->opcode) {
799 case aco_opcode::v_add_u32:
800 case aco_opcode::v_add_co_u32:
801 case aco_opcode::v_add_co_u32_e64:
802 case aco_opcode::s_add_i32:
803 case aco_opcode::s_add_u32: break;
804 case aco_opcode::v_sub_u32:
805 case aco_opcode::v_sub_i32:
806 case aco_opcode::v_sub_co_u32:
807 case aco_opcode::v_sub_co_u32_e64:
808 case aco_opcode::s_sub_u32:
809 case aco_opcode::s_sub_i32:
810 mask = 0x2;
811 is_sub = true;
812 break;
813 case aco_opcode::v_subrev_u32:
814 case aco_opcode::v_subrev_co_u32:
815 case aco_opcode::v_subrev_co_u32_e64:
816 mask = 0x1;
817 is_sub = true;
818 break;
819 default: return false;
820 }
821 if (prevent_overflow && !add_instr->definitions[0].isNUW())
822 return false;
823
824 if (add_instr->usesModifiers())
825 return false;
826
827 u_foreach_bit (i, mask) {
828 if (add_instr->operands[i].isConstant()) {
829 *offset = add_instr->operands[i].constantValue() * (uint32_t)(is_sub ? -1 : 1);
830 } else if (add_instr->operands[i].isTemp() &&
831 ctx.info[add_instr->operands[i].tempId()].is_constant_or_literal(32)) {
832 *offset = ctx.info[add_instr->operands[i].tempId()].val * (uint32_t)(is_sub ? -1 : 1);
833 } else {
834 continue;
835 }
836 if (!add_instr->operands[!i].isTemp())
837 continue;
838
839 uint32_t offset2 = 0;
840 if (parse_base_offset(ctx, add_instr, !i, base, &offset2, prevent_overflow)) {
841 *offset += offset2;
842 } else {
843 *base = add_instr->operands[!i].getTemp();
844 }
845 return true;
846 }
847
848 return false;
849 }
850
851 void
skip_smem_offset_align(opt_ctx & ctx,SMEM_instruction * smem)852 skip_smem_offset_align(opt_ctx& ctx, SMEM_instruction* smem)
853 {
854 bool soe = smem->operands.size() >= (!smem->definitions.empty() ? 3 : 4);
855 if (soe && !smem->operands[1].isConstant())
856 return;
857 /* We don't need to check the constant offset because the address seems to be calculated with
858 * (offset&-4 + const_offset&-4), not (offset+const_offset)&-4.
859 */
860
861 Operand& op = smem->operands[soe ? smem->operands.size() - 1 : 1];
862 if (!op.isTemp() || !ctx.info[op.tempId()].is_bitwise())
863 return;
864
865 Instruction* bitwise_instr = ctx.info[op.tempId()].instr;
866 if (bitwise_instr->opcode != aco_opcode::s_and_b32)
867 return;
868
869 if (bitwise_instr->operands[0].constantEquals(-4) &&
870 bitwise_instr->operands[1].isOfType(op.regClass().type()))
871 op.setTemp(bitwise_instr->operands[1].getTemp());
872 else if (bitwise_instr->operands[1].constantEquals(-4) &&
873 bitwise_instr->operands[0].isOfType(op.regClass().type()))
874 op.setTemp(bitwise_instr->operands[0].getTemp());
875 }
876
877 void
smem_combine(opt_ctx & ctx,aco_ptr<Instruction> & instr)878 smem_combine(opt_ctx& ctx, aco_ptr<Instruction>& instr)
879 {
880 /* skip &-4 before offset additions: load((a + 16) & -4, 0) */
881 if (!instr->operands.empty())
882 skip_smem_offset_align(ctx, &instr->smem());
883
884 /* propagate constants and combine additions */
885 if (!instr->operands.empty() && instr->operands[1].isTemp()) {
886 SMEM_instruction& smem = instr->smem();
887 ssa_info info = ctx.info[instr->operands[1].tempId()];
888
889 Temp base;
890 uint32_t offset;
891 if (info.is_constant_or_literal(32) &&
892 ((ctx.program->gfx_level == GFX6 && info.val <= 0x3FF) ||
893 (ctx.program->gfx_level == GFX7 && info.val <= 0xFFFFFFFF) ||
894 (ctx.program->gfx_level >= GFX8 && info.val <= 0xFFFFF))) {
895 instr->operands[1] = Operand::c32(info.val);
896 } else if (parse_base_offset(ctx, instr.get(), 1, &base, &offset, true) &&
897 base.regClass() == s1 && offset <= 0xFFFFF && ctx.program->gfx_level >= GFX9 &&
898 offset % 4u == 0) {
899 bool soe = smem.operands.size() >= (!smem.definitions.empty() ? 3 : 4);
900 if (soe) {
901 if (ctx.info[smem.operands.back().tempId()].is_constant_or_literal(32) &&
902 ctx.info[smem.operands.back().tempId()].val == 0) {
903 smem.operands[1] = Operand::c32(offset);
904 smem.operands.back() = Operand(base);
905 }
906 } else {
907 SMEM_instruction* new_instr = create_instruction<SMEM_instruction>(
908 smem.opcode, Format::SMEM, smem.operands.size() + 1, smem.definitions.size());
909 new_instr->operands[0] = smem.operands[0];
910 new_instr->operands[1] = Operand::c32(offset);
911 if (smem.definitions.empty())
912 new_instr->operands[2] = smem.operands[2];
913 new_instr->operands.back() = Operand(base);
914 if (!smem.definitions.empty())
915 new_instr->definitions[0] = smem.definitions[0];
916 new_instr->sync = smem.sync;
917 new_instr->glc = smem.glc;
918 new_instr->dlc = smem.dlc;
919 new_instr->nv = smem.nv;
920 new_instr->disable_wqm = smem.disable_wqm;
921 instr.reset(new_instr);
922 }
923 }
924 }
925
926 /* skip &-4 after offset additions: load(a & -4, 16) */
927 if (!instr->operands.empty())
928 skip_smem_offset_align(ctx, &instr->smem());
929 }
930
931 Operand
get_constant_op(opt_ctx & ctx,ssa_info info,uint32_t bits)932 get_constant_op(opt_ctx& ctx, ssa_info info, uint32_t bits)
933 {
934 if (bits == 64)
935 return Operand::c32_or_c64(info.val, true);
936 return Operand::get_const(ctx.program->gfx_level, info.val, bits / 8u);
937 }
938
939 void
propagate_constants_vop3p(opt_ctx & ctx,aco_ptr<Instruction> & instr,ssa_info & info,unsigned i)940 propagate_constants_vop3p(opt_ctx& ctx, aco_ptr<Instruction>& instr, ssa_info& info, unsigned i)
941 {
942 if (!info.is_constant_or_literal(32))
943 return;
944
945 assert(instr->operands[i].isTemp());
946 unsigned bits = get_operand_size(instr, i);
947 if (info.is_constant(bits)) {
948 instr->operands[i] = get_constant_op(ctx, info, bits);
949 return;
950 }
951
952 /* The accumulation operand of dot product instructions ignores opsel. */
953 bool cannot_use_opsel =
954 (instr->opcode == aco_opcode::v_dot4_i32_i8 || instr->opcode == aco_opcode::v_dot2_i32_i16 ||
955 instr->opcode == aco_opcode::v_dot4_i32_iu8 || instr->opcode == aco_opcode::v_dot4_u32_u8 ||
956 instr->opcode == aco_opcode::v_dot2_u32_u16) &&
957 i == 2;
958 if (cannot_use_opsel)
959 return;
960
961 /* try to fold inline constants */
962 VALU_instruction* vop3p = &instr->valu();
963 bool opsel_lo = vop3p->opsel_lo[i];
964 bool opsel_hi = vop3p->opsel_hi[i];
965
966 Operand const_op[2];
967 bool const_opsel[2] = {false, false};
968 for (unsigned j = 0; j < 2; j++) {
969 if ((unsigned)opsel_lo != j && (unsigned)opsel_hi != j)
970 continue; /* this half is unused */
971
972 uint16_t val = info.val >> (j ? 16 : 0);
973 Operand op = Operand::get_const(ctx.program->gfx_level, val, bits / 8u);
974 if (bits == 32 && op.isLiteral()) /* try sign extension */
975 op = Operand::get_const(ctx.program->gfx_level, val | 0xffff0000, 4);
976 if (bits == 32 && op.isLiteral()) { /* try shifting left */
977 op = Operand::get_const(ctx.program->gfx_level, val << 16, 4);
978 const_opsel[j] = true;
979 }
980 if (op.isLiteral())
981 return;
982 const_op[j] = op;
983 }
984
985 Operand const_lo = const_op[0];
986 Operand const_hi = const_op[1];
987 bool const_lo_opsel = const_opsel[0];
988 bool const_hi_opsel = const_opsel[1];
989
990 if (opsel_lo == opsel_hi) {
991 /* use the single 16bit value */
992 instr->operands[i] = opsel_lo ? const_hi : const_lo;
993
994 /* opsel must point the same for both halves */
995 opsel_lo = opsel_lo ? const_hi_opsel : const_lo_opsel;
996 opsel_hi = opsel_lo;
997 } else if (const_lo == const_hi) {
998 /* both constants are the same */
999 instr->operands[i] = const_lo;
1000
1001 /* opsel must point the same for both halves */
1002 opsel_lo = const_lo_opsel;
1003 opsel_hi = const_lo_opsel;
1004 } else if (const_lo.constantValue16(const_lo_opsel) ==
1005 const_hi.constantValue16(!const_hi_opsel)) {
1006 instr->operands[i] = const_hi;
1007
1008 /* redirect opsel selection */
1009 opsel_lo = opsel_lo ? const_hi_opsel : !const_hi_opsel;
1010 opsel_hi = opsel_hi ? const_hi_opsel : !const_hi_opsel;
1011 } else if (const_hi.constantValue16(const_hi_opsel) ==
1012 const_lo.constantValue16(!const_lo_opsel)) {
1013 instr->operands[i] = const_lo;
1014
1015 /* redirect opsel selection */
1016 opsel_lo = opsel_lo ? !const_lo_opsel : const_lo_opsel;
1017 opsel_hi = opsel_hi ? !const_lo_opsel : const_lo_opsel;
1018 } else if (bits == 16 && const_lo.constantValue() == (const_hi.constantValue() ^ (1 << 15))) {
1019 assert(const_lo_opsel == false && const_hi_opsel == false);
1020
1021 /* const_lo == -const_hi */
1022 if (!can_use_input_modifiers(ctx.program->gfx_level, instr->opcode, i))
1023 return;
1024
1025 instr->operands[i] = Operand::c16(const_lo.constantValue() & 0x7FFF);
1026 bool neg_lo = const_lo.constantValue() & (1 << 15);
1027 vop3p->neg_lo[i] ^= opsel_lo ^ neg_lo;
1028 vop3p->neg_hi[i] ^= opsel_hi ^ neg_lo;
1029
1030 /* opsel must point to lo for both operands */
1031 opsel_lo = false;
1032 opsel_hi = false;
1033 }
1034
1035 vop3p->opsel_lo[i] = opsel_lo;
1036 vop3p->opsel_hi[i] = opsel_hi;
1037 }
1038
1039 bool
fixed_to_exec(Operand op)1040 fixed_to_exec(Operand op)
1041 {
1042 return op.isFixed() && op.physReg() == exec;
1043 }
1044
1045 SubdwordSel
parse_extract(Instruction * instr)1046 parse_extract(Instruction* instr)
1047 {
1048 if (instr->opcode == aco_opcode::p_extract) {
1049 unsigned size = instr->operands[2].constantValue() / 8;
1050 unsigned offset = instr->operands[1].constantValue() * size;
1051 bool sext = instr->operands[3].constantEquals(1);
1052 return SubdwordSel(size, offset, sext);
1053 } else if (instr->opcode == aco_opcode::p_insert && instr->operands[1].constantEquals(0)) {
1054 return instr->operands[2].constantEquals(8) ? SubdwordSel::ubyte : SubdwordSel::uword;
1055 } else if (instr->opcode == aco_opcode::p_extract_vector) {
1056 unsigned size = instr->definitions[0].bytes();
1057 unsigned offset = instr->operands[1].constantValue() * size;
1058 if (size <= 2)
1059 return SubdwordSel(size, offset, false);
1060 } else if (instr->opcode == aco_opcode::p_split_vector) {
1061 assert(instr->operands[0].bytes() == 4 && instr->definitions[1].bytes() == 2);
1062 return SubdwordSel(2, 2, false);
1063 }
1064
1065 return SubdwordSel();
1066 }
1067
1068 SubdwordSel
parse_insert(Instruction * instr)1069 parse_insert(Instruction* instr)
1070 {
1071 if (instr->opcode == aco_opcode::p_extract && instr->operands[3].constantEquals(0) &&
1072 instr->operands[1].constantEquals(0)) {
1073 return instr->operands[2].constantEquals(8) ? SubdwordSel::ubyte : SubdwordSel::uword;
1074 } else if (instr->opcode == aco_opcode::p_insert) {
1075 unsigned size = instr->operands[2].constantValue() / 8;
1076 unsigned offset = instr->operands[1].constantValue() * size;
1077 return SubdwordSel(size, offset, false);
1078 } else {
1079 return SubdwordSel();
1080 }
1081 }
1082
1083 bool
can_apply_extract(opt_ctx & ctx,aco_ptr<Instruction> & instr,unsigned idx,ssa_info & info)1084 can_apply_extract(opt_ctx& ctx, aco_ptr<Instruction>& instr, unsigned idx, ssa_info& info)
1085 {
1086 Temp tmp = info.instr->operands[0].getTemp();
1087 SubdwordSel sel = parse_extract(info.instr);
1088
1089 if (!sel) {
1090 return false;
1091 } else if (sel.size() == 4) {
1092 return true;
1093 } else if ((instr->opcode == aco_opcode::v_cvt_f32_u32 ||
1094 instr->opcode == aco_opcode::v_cvt_f32_i32) &&
1095 sel.size() == 1 && !sel.sign_extend()) {
1096 return true;
1097 } else if (instr->opcode == aco_opcode::v_lshlrev_b32 && instr->operands[0].isConstant() &&
1098 sel.offset() == 0 &&
1099 ((sel.size() == 2 && instr->operands[0].constantValue() >= 16u) ||
1100 (sel.size() == 1 && instr->operands[0].constantValue() >= 24u))) {
1101 return true;
1102 } else if (instr->opcode == aco_opcode::v_mul_u32_u24 && ctx.program->gfx_level >= GFX10 &&
1103 !instr->usesModifiers() && sel.size() == 2 && !sel.sign_extend() &&
1104 (instr->operands[!idx].is16bit() ||
1105 instr->operands[!idx].constantValue() <= UINT16_MAX)) {
1106 return true;
1107 } else if (idx < 2 && can_use_SDWA(ctx.program->gfx_level, instr, true) &&
1108 (tmp.type() == RegType::vgpr || ctx.program->gfx_level >= GFX9)) {
1109 if (instr->isSDWA() && instr->sdwa().sel[idx] != SubdwordSel::dword)
1110 return false;
1111 return true;
1112 } else if (instr->isVALU() && sel.size() == 2 && !instr->valu().opsel[idx] &&
1113 can_use_opsel(ctx.program->gfx_level, instr->opcode, idx)) {
1114 return true;
1115 } else if (instr->opcode == aco_opcode::p_extract) {
1116 SubdwordSel instrSel = parse_extract(instr.get());
1117
1118 /* the outer offset must be within extracted range */
1119 if (instrSel.offset() >= sel.size())
1120 return false;
1121
1122 /* don't remove the sign-extension when increasing the size further */
1123 if (instrSel.size() > sel.size() && !instrSel.sign_extend() && sel.sign_extend())
1124 return false;
1125
1126 return true;
1127 }
1128
1129 return false;
1130 }
1131
1132 /* Combine an p_extract (or p_insert, in some cases) instruction with instr.
1133 * instr(p_extract(...)) -> instr()
1134 */
1135 void
apply_extract(opt_ctx & ctx,aco_ptr<Instruction> & instr,unsigned idx,ssa_info & info)1136 apply_extract(opt_ctx& ctx, aco_ptr<Instruction>& instr, unsigned idx, ssa_info& info)
1137 {
1138 Temp tmp = info.instr->operands[0].getTemp();
1139 SubdwordSel sel = parse_extract(info.instr);
1140 assert(sel);
1141
1142 instr->operands[idx].set16bit(false);
1143 instr->operands[idx].set24bit(false);
1144
1145 ctx.info[tmp.id()].label &= ~label_insert;
1146
1147 if (sel.size() == 4) {
1148 /* full dword selection */
1149 } else if ((instr->opcode == aco_opcode::v_cvt_f32_u32 ||
1150 instr->opcode == aco_opcode::v_cvt_f32_i32) &&
1151 sel.size() == 1 && !sel.sign_extend()) {
1152 switch (sel.offset()) {
1153 case 0: instr->opcode = aco_opcode::v_cvt_f32_ubyte0; break;
1154 case 1: instr->opcode = aco_opcode::v_cvt_f32_ubyte1; break;
1155 case 2: instr->opcode = aco_opcode::v_cvt_f32_ubyte2; break;
1156 case 3: instr->opcode = aco_opcode::v_cvt_f32_ubyte3; break;
1157 }
1158 } else if (instr->opcode == aco_opcode::v_lshlrev_b32 && instr->operands[0].isConstant() &&
1159 sel.offset() == 0 &&
1160 ((sel.size() == 2 && instr->operands[0].constantValue() >= 16u) ||
1161 (sel.size() == 1 && instr->operands[0].constantValue() >= 24u))) {
1162 /* The undesirable upper bits are already shifted out. */
1163 return;
1164 } else if (instr->opcode == aco_opcode::v_mul_u32_u24 && ctx.program->gfx_level >= GFX10 &&
1165 !instr->usesModifiers() && sel.size() == 2 && !sel.sign_extend() &&
1166 (instr->operands[!idx].is16bit() ||
1167 instr->operands[!idx].constantValue() <= UINT16_MAX)) {
1168 Instruction* mad =
1169 create_instruction<VALU_instruction>(aco_opcode::v_mad_u32_u16, Format::VOP3, 3, 1);
1170 mad->definitions[0] = instr->definitions[0];
1171 mad->operands[0] = instr->operands[0];
1172 mad->operands[1] = instr->operands[1];
1173 mad->operands[2] = Operand::zero();
1174 mad->valu().opsel[idx] = sel.offset();
1175 mad->pass_flags = instr->pass_flags;
1176 instr.reset(mad);
1177 } else if (can_use_SDWA(ctx.program->gfx_level, instr, true) &&
1178 (tmp.type() == RegType::vgpr || ctx.program->gfx_level >= GFX9)) {
1179 convert_to_SDWA(ctx.program->gfx_level, instr);
1180 instr->sdwa().sel[idx] = sel;
1181 } else if (instr->isVALU()) {
1182 if (sel.offset()) {
1183 instr->valu().opsel[idx] = true;
1184
1185 /* VOP12C cannot use opsel with SGPRs. */
1186 if (!instr->isVOP3() && !instr->isVINTERP_INREG() &&
1187 !info.instr->operands[0].isOfType(RegType::vgpr))
1188 instr->format = asVOP3(instr->format);
1189 }
1190 } else if (instr->opcode == aco_opcode::p_extract) {
1191 SubdwordSel instrSel = parse_extract(instr.get());
1192
1193 unsigned size = std::min(sel.size(), instrSel.size());
1194 unsigned offset = sel.offset() + instrSel.offset();
1195 unsigned sign_extend =
1196 instrSel.sign_extend() && (sel.sign_extend() || instrSel.size() <= sel.size());
1197
1198 instr->operands[1] = Operand::c32(offset / size);
1199 instr->operands[2] = Operand::c32(size * 8u);
1200 instr->operands[3] = Operand::c32(sign_extend);
1201 return;
1202 }
1203
1204 /* These are the only labels worth keeping at the moment. */
1205 for (Definition& def : instr->definitions) {
1206 ctx.info[def.tempId()].label &=
1207 (label_mul | label_minmax | label_usedef | label_vopc | label_f2f32 | instr_mod_labels);
1208 if (ctx.info[def.tempId()].label & instr_usedef_labels)
1209 ctx.info[def.tempId()].instr = instr.get();
1210 }
1211 }
1212
1213 void
check_sdwa_extract(opt_ctx & ctx,aco_ptr<Instruction> & instr)1214 check_sdwa_extract(opt_ctx& ctx, aco_ptr<Instruction>& instr)
1215 {
1216 for (unsigned i = 0; i < instr->operands.size(); i++) {
1217 Operand op = instr->operands[i];
1218 if (!op.isTemp())
1219 continue;
1220 ssa_info& info = ctx.info[op.tempId()];
1221 if (info.is_extract() && (info.instr->operands[0].getTemp().type() == RegType::vgpr ||
1222 op.getTemp().type() == RegType::sgpr)) {
1223 if (!can_apply_extract(ctx, instr, i, info))
1224 info.label &= ~label_extract;
1225 }
1226 }
1227 }
1228
1229 bool
does_fp_op_flush_denorms(opt_ctx & ctx,aco_opcode op)1230 does_fp_op_flush_denorms(opt_ctx& ctx, aco_opcode op)
1231 {
1232 switch (op) {
1233 case aco_opcode::v_min_f32:
1234 case aco_opcode::v_max_f32:
1235 case aco_opcode::v_med3_f32:
1236 case aco_opcode::v_min3_f32:
1237 case aco_opcode::v_max3_f32:
1238 case aco_opcode::v_min_f16:
1239 case aco_opcode::v_max_f16: return ctx.program->gfx_level > GFX8;
1240 case aco_opcode::v_cndmask_b32:
1241 case aco_opcode::v_cndmask_b16:
1242 case aco_opcode::v_mov_b32:
1243 case aco_opcode::v_mov_b16: return false;
1244 default: return true;
1245 }
1246 }
1247
1248 bool
can_eliminate_fcanonicalize(opt_ctx & ctx,aco_ptr<Instruction> & instr,Temp tmp,unsigned idx)1249 can_eliminate_fcanonicalize(opt_ctx& ctx, aco_ptr<Instruction>& instr, Temp tmp, unsigned idx)
1250 {
1251 float_mode* fp = &ctx.fp_mode;
1252 if (ctx.info[tmp.id()].is_canonicalized() ||
1253 (tmp.bytes() == 4 ? fp->denorm32 : fp->denorm16_64) == fp_denorm_keep)
1254 return true;
1255
1256 aco_opcode op = instr->opcode;
1257 return can_use_input_modifiers(ctx.program->gfx_level, instr->opcode, idx) &&
1258 does_fp_op_flush_denorms(ctx, op);
1259 }
1260
1261 bool
can_eliminate_and_exec(opt_ctx & ctx,Temp tmp,unsigned pass_flags)1262 can_eliminate_and_exec(opt_ctx& ctx, Temp tmp, unsigned pass_flags)
1263 {
1264 if (ctx.info[tmp.id()].is_vopc()) {
1265 Instruction* vopc_instr = ctx.info[tmp.id()].instr;
1266 /* Remove superfluous s_and when the VOPC instruction uses the same exec and thus
1267 * already produces the same result */
1268 return vopc_instr->pass_flags == pass_flags;
1269 }
1270 if (ctx.info[tmp.id()].is_bitwise()) {
1271 Instruction* instr = ctx.info[tmp.id()].instr;
1272 if (instr->operands.size() != 2 || instr->pass_flags != pass_flags)
1273 return false;
1274 if (!(instr->operands[0].isTemp() && instr->operands[1].isTemp()))
1275 return false;
1276 if (instr->opcode == aco_opcode::s_and_b32 || instr->opcode == aco_opcode::s_and_b64) {
1277 return can_eliminate_and_exec(ctx, instr->operands[0].getTemp(), pass_flags) ||
1278 can_eliminate_and_exec(ctx, instr->operands[1].getTemp(), pass_flags);
1279 } else {
1280 return can_eliminate_and_exec(ctx, instr->operands[0].getTemp(), pass_flags) &&
1281 can_eliminate_and_exec(ctx, instr->operands[1].getTemp(), pass_flags);
1282 }
1283 }
1284 return false;
1285 }
1286
1287 bool
is_copy_label(opt_ctx & ctx,aco_ptr<Instruction> & instr,ssa_info & info,unsigned idx)1288 is_copy_label(opt_ctx& ctx, aco_ptr<Instruction>& instr, ssa_info& info, unsigned idx)
1289 {
1290 return info.is_temp() ||
1291 (info.is_fcanonicalize() && can_eliminate_fcanonicalize(ctx, instr, info.temp, idx));
1292 }
1293
1294 bool
is_op_canonicalized(opt_ctx & ctx,Operand op)1295 is_op_canonicalized(opt_ctx& ctx, Operand op)
1296 {
1297 float_mode* fp = &ctx.fp_mode;
1298 if ((op.isTemp() && ctx.info[op.tempId()].is_canonicalized()) ||
1299 (op.bytes() == 4 ? fp->denorm32 : fp->denorm16_64) == fp_denorm_keep)
1300 return true;
1301
1302 if (op.isConstant() || (op.isTemp() && ctx.info[op.tempId()].is_constant_or_literal(32))) {
1303 uint32_t val = op.isTemp() ? ctx.info[op.tempId()].val : op.constantValue();
1304 if (op.bytes() == 2)
1305 return (val & 0x7fff) == 0 || (val & 0x7fff) > 0x3ff;
1306 else if (op.bytes() == 4)
1307 return (val & 0x7fffffff) == 0 || (val & 0x7fffffff) > 0x7fffff;
1308 }
1309 return false;
1310 }
1311
1312 bool
is_scratch_offset_valid(opt_ctx & ctx,Instruction * instr,int64_t offset0,int64_t offset1)1313 is_scratch_offset_valid(opt_ctx& ctx, Instruction* instr, int64_t offset0, int64_t offset1)
1314 {
1315 bool negative_unaligned_scratch_offset_bug = ctx.program->gfx_level == GFX10;
1316 int32_t min = ctx.program->dev.scratch_global_offset_min;
1317 int32_t max = ctx.program->dev.scratch_global_offset_max;
1318
1319 int64_t offset = offset0 + offset1;
1320
1321 bool has_vgpr_offset = instr && !instr->operands[0].isUndefined();
1322 if (negative_unaligned_scratch_offset_bug && has_vgpr_offset && offset < 0 && offset % 4)
1323 return false;
1324
1325 return offset >= min && offset <= max;
1326 }
1327
1328 bool
detect_clamp(Instruction * instr,unsigned * clamped_idx)1329 detect_clamp(Instruction* instr, unsigned* clamped_idx)
1330 {
1331 VALU_instruction& valu = instr->valu();
1332 if (valu.omod != 0 || valu.opsel != 0)
1333 return false;
1334
1335 unsigned idx = 0;
1336 bool found_zero = false, found_one = false;
1337 bool is_fp16 = instr->opcode == aco_opcode::v_med3_f16;
1338 for (unsigned i = 0; i < 3; i++) {
1339 if (!valu.neg[i] && instr->operands[i].constantEquals(0))
1340 found_zero = true;
1341 else if (!valu.neg[i] &&
1342 instr->operands[i].constantEquals(is_fp16 ? 0x3c00 : 0x3f800000)) /* 1.0 */
1343 found_one = true;
1344 else
1345 idx = i;
1346 }
1347 if (found_zero && found_one && instr->operands[idx].isTemp()) {
1348 *clamped_idx = idx;
1349 return true;
1350 } else {
1351 return false;
1352 }
1353 }
1354
1355 void
label_instruction(opt_ctx & ctx,aco_ptr<Instruction> & instr)1356 label_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
1357 {
1358 if (instr->isSALU() || instr->isVALU() || instr->isPseudo()) {
1359 ASSERTED bool all_const = false;
1360 for (Operand& op : instr->operands)
1361 all_const =
1362 all_const && (!op.isTemp() || ctx.info[op.tempId()].is_constant_or_literal(32));
1363 perfwarn(ctx.program, all_const, "All instruction operands are constant", instr.get());
1364
1365 ASSERTED bool is_copy = instr->opcode == aco_opcode::s_mov_b32 ||
1366 instr->opcode == aco_opcode::s_mov_b64 ||
1367 instr->opcode == aco_opcode::v_mov_b32;
1368 perfwarn(ctx.program, is_copy && !instr->usesModifiers(), "Use p_parallelcopy instead",
1369 instr.get());
1370 }
1371
1372 if (instr->isSMEM())
1373 smem_combine(ctx, instr);
1374
1375 for (unsigned i = 0; i < instr->operands.size(); i++) {
1376 if (!instr->operands[i].isTemp())
1377 continue;
1378
1379 ssa_info info = ctx.info[instr->operands[i].tempId()];
1380 /* propagate undef */
1381 if (info.is_undefined() && is_phi(instr))
1382 instr->operands[i] = Operand(instr->operands[i].regClass());
1383 /* propagate reg->reg of same type */
1384 while (info.is_temp() && info.temp.regClass() == instr->operands[i].getTemp().regClass()) {
1385 instr->operands[i].setTemp(ctx.info[instr->operands[i].tempId()].temp);
1386 info = ctx.info[info.temp.id()];
1387 }
1388
1389 /* PSEUDO: propagate temporaries */
1390 if (instr->isPseudo()) {
1391 while (info.is_temp()) {
1392 pseudo_propagate_temp(ctx, instr, info.temp, i);
1393 info = ctx.info[info.temp.id()];
1394 }
1395 }
1396
1397 /* SALU / PSEUDO: propagate inline constants */
1398 if (instr->isSALU() || instr->isPseudo()) {
1399 unsigned bits = get_operand_size(instr, i);
1400 if ((info.is_constant(bits) || (info.is_literal(bits) && instr->isPseudo())) &&
1401 alu_can_accept_constant(instr, i)) {
1402 instr->operands[i] = get_constant_op(ctx, info, bits);
1403 continue;
1404 }
1405 }
1406
1407 /* VALU: propagate neg, abs & inline constants */
1408 else if (instr->isVALU()) {
1409 if (is_copy_label(ctx, instr, info, i) && info.temp.type() == RegType::vgpr &&
1410 valu_can_accept_vgpr(instr, i)) {
1411 instr->operands[i].setTemp(info.temp);
1412 info = ctx.info[info.temp.id()];
1413 }
1414 /* applying SGPRs to VOP1 doesn't increase code size and DCE is helped by doing it earlier */
1415 if (info.is_temp() && info.temp.type() == RegType::sgpr && can_apply_sgprs(ctx, instr) &&
1416 instr->operands.size() == 1) {
1417 instr->format = withoutDPP(instr->format);
1418 instr->operands[i].setTemp(info.temp);
1419 info = ctx.info[info.temp.id()];
1420 }
1421
1422 /* for instructions other than v_cndmask_b32, the size of the instruction should match the
1423 * operand size */
1424 bool can_use_mod =
1425 instr->opcode != aco_opcode::v_cndmask_b32 || instr->operands[i].getTemp().bytes() == 4;
1426 can_use_mod &= can_use_input_modifiers(ctx.program->gfx_level, instr->opcode, i);
1427
1428 bool packed_math = instr->isVOP3P() && instr->opcode != aco_opcode::v_fma_mix_f32 &&
1429 instr->opcode != aco_opcode::v_fma_mixlo_f16 &&
1430 instr->opcode != aco_opcode::v_fma_mixhi_f16;
1431
1432 if (instr->isSDWA())
1433 can_use_mod &= instr->sdwa().sel[i].size() == 4;
1434 else if (instr->isVOP3P())
1435 can_use_mod &= !packed_math || !info.is_abs();
1436 else
1437 can_use_mod &= instr->isDPP16() || can_use_VOP3(ctx, instr);
1438
1439 unsigned bits = get_operand_size(instr, i);
1440 can_use_mod &= instr->operands[i].bytes() * 8 == bits;
1441
1442 if (info.is_neg() && can_use_mod &&
1443 can_eliminate_fcanonicalize(ctx, instr, info.temp, i)) {
1444 instr->operands[i].setTemp(info.temp);
1445 if (!packed_math && instr->valu().abs[i]) {
1446 /* fabs(fneg(a)) -> fabs(a) */
1447 } else if (instr->opcode == aco_opcode::v_add_f32) {
1448 instr->opcode = i ? aco_opcode::v_sub_f32 : aco_opcode::v_subrev_f32;
1449 } else if (instr->opcode == aco_opcode::v_add_f16) {
1450 instr->opcode = i ? aco_opcode::v_sub_f16 : aco_opcode::v_subrev_f16;
1451 } else if (packed_math) {
1452 /* Bit size compat should ensure this. */
1453 assert(!instr->valu().opsel_lo[i] && !instr->valu().opsel_hi[i]);
1454 instr->valu().neg_lo[i] ^= true;
1455 instr->valu().neg_hi[i] ^= true;
1456 } else {
1457 if (!instr->isDPP16() && can_use_VOP3(ctx, instr))
1458 instr->format = asVOP3(instr->format);
1459 instr->valu().neg[i] ^= true;
1460 }
1461 }
1462 if (info.is_abs() && can_use_mod &&
1463 can_eliminate_fcanonicalize(ctx, instr, info.temp, i)) {
1464 if (!instr->isDPP16() && can_use_VOP3(ctx, instr))
1465 instr->format = asVOP3(instr->format);
1466 instr->operands[i] = Operand(info.temp);
1467 instr->valu().abs[i] = true;
1468 continue;
1469 }
1470
1471 if (instr->isVOP3P()) {
1472 propagate_constants_vop3p(ctx, instr, info, i);
1473 continue;
1474 }
1475
1476 if (info.is_constant(bits) && alu_can_accept_constant(instr, i) &&
1477 (!instr->isSDWA() || ctx.program->gfx_level >= GFX9) && (!instr->isDPP() || i != 1)) {
1478 Operand op = get_constant_op(ctx, info, bits);
1479 perfwarn(ctx.program, instr->opcode == aco_opcode::v_cndmask_b32 && i == 2,
1480 "v_cndmask_b32 with a constant selector", instr.get());
1481 if (i == 0 || instr->isSDWA() || instr->opcode == aco_opcode::v_readlane_b32 ||
1482 instr->opcode == aco_opcode::v_writelane_b32) {
1483 instr->format = withoutDPP(instr->format);
1484 instr->operands[i] = op;
1485 continue;
1486 } else if (!instr->isVOP3() && can_swap_operands(instr, &instr->opcode)) {
1487 instr->operands[i] = op;
1488 instr->valu().swapOperands(0, i);
1489 continue;
1490 } else if (can_use_VOP3(ctx, instr)) {
1491 instr->format = asVOP3(instr->format);
1492 instr->operands[i] = op;
1493 continue;
1494 }
1495 }
1496 }
1497
1498 /* MUBUF: propagate constants and combine additions */
1499 else if (instr->isMUBUF()) {
1500 MUBUF_instruction& mubuf = instr->mubuf();
1501 Temp base;
1502 uint32_t offset;
1503 while (info.is_temp())
1504 info = ctx.info[info.temp.id()];
1505
1506 /* According to AMDGPUDAGToDAGISel::SelectMUBUFScratchOffen(), vaddr
1507 * overflow for scratch accesses works only on GFX9+ and saddr overflow
1508 * never works. Since swizzling is the only thing that separates
1509 * scratch accesses and other accesses and swizzling changing how
1510 * addressing works significantly, this probably applies to swizzled
1511 * MUBUF accesses. */
1512 bool vaddr_prevent_overflow = mubuf.swizzled && ctx.program->gfx_level < GFX9;
1513
1514 if (mubuf.offen && mubuf.idxen && i == 1 && info.is_vec() &&
1515 info.instr->operands.size() == 2 && info.instr->operands[0].isTemp() &&
1516 info.instr->operands[0].regClass() == v1 && info.instr->operands[1].isConstant() &&
1517 mubuf.offset + info.instr->operands[1].constantValue() < 4096) {
1518 instr->operands[1] = info.instr->operands[0];
1519 mubuf.offset += info.instr->operands[1].constantValue();
1520 mubuf.offen = false;
1521 continue;
1522 } else if (mubuf.offen && i == 1 && info.is_constant_or_literal(32) &&
1523 mubuf.offset + info.val < 4096) {
1524 assert(!mubuf.idxen);
1525 instr->operands[1] = Operand(v1);
1526 mubuf.offset += info.val;
1527 mubuf.offen = false;
1528 continue;
1529 } else if (i == 2 && info.is_constant_or_literal(32) && mubuf.offset + info.val < 4096) {
1530 instr->operands[2] = Operand::c32(0);
1531 mubuf.offset += info.val;
1532 continue;
1533 } else if (mubuf.offen && i == 1 &&
1534 parse_base_offset(ctx, instr.get(), i, &base, &offset,
1535 vaddr_prevent_overflow) &&
1536 base.regClass() == v1 && mubuf.offset + offset < 4096) {
1537 assert(!mubuf.idxen);
1538 instr->operands[1].setTemp(base);
1539 mubuf.offset += offset;
1540 continue;
1541 } else if (i == 2 && parse_base_offset(ctx, instr.get(), i, &base, &offset, true) &&
1542 base.regClass() == s1 && mubuf.offset + offset < 4096 && !mubuf.swizzled) {
1543 instr->operands[i].setTemp(base);
1544 mubuf.offset += offset;
1545 continue;
1546 }
1547 }
1548
1549 else if (instr->isMTBUF()) {
1550 MTBUF_instruction& mtbuf = instr->mtbuf();
1551 while (info.is_temp())
1552 info = ctx.info[info.temp.id()];
1553
1554 if (mtbuf.offen && mtbuf.idxen && i == 1 && info.is_vec() &&
1555 info.instr->operands.size() == 2 && info.instr->operands[0].isTemp() &&
1556 info.instr->operands[0].regClass() == v1 && info.instr->operands[1].isConstant() &&
1557 mtbuf.offset + info.instr->operands[1].constantValue() < 4096) {
1558 instr->operands[1] = info.instr->operands[0];
1559 mtbuf.offset += info.instr->operands[1].constantValue();
1560 mtbuf.offen = false;
1561 continue;
1562 }
1563 }
1564
1565 /* SCRATCH: propagate constants and combine additions */
1566 else if (instr->isScratch()) {
1567 FLAT_instruction& scratch = instr->scratch();
1568 Temp base;
1569 uint32_t offset;
1570 while (info.is_temp())
1571 info = ctx.info[info.temp.id()];
1572
1573 /* The hardware probably does: 'scratch_base + u2u64(saddr) + i2i64(offset)'. This means
1574 * we can't combine the addition if the unsigned addition overflows and offset is
1575 * positive. In theory, there is also issues if
1576 * 'ilt(offset, 0) && ige(saddr, 0) && ilt(saddr + offset, 0)', but that just
1577 * replaces an already out-of-bounds access with a larger one since 'saddr + offset'
1578 * would be larger than INT32_MAX.
1579 */
1580 if (i <= 1 && parse_base_offset(ctx, instr.get(), i, &base, &offset, true) &&
1581 base.regClass() == instr->operands[i].regClass() &&
1582 is_scratch_offset_valid(ctx, instr.get(), scratch.offset, (int32_t)offset)) {
1583 instr->operands[i].setTemp(base);
1584 scratch.offset += (int32_t)offset;
1585 continue;
1586 } else if (i <= 1 && parse_base_offset(ctx, instr.get(), i, &base, &offset, false) &&
1587 base.regClass() == instr->operands[i].regClass() && (int32_t)offset < 0 &&
1588 is_scratch_offset_valid(ctx, instr.get(), scratch.offset, (int32_t)offset)) {
1589 instr->operands[i].setTemp(base);
1590 scratch.offset += (int32_t)offset;
1591 continue;
1592 } else if (i <= 1 && info.is_constant_or_literal(32) &&
1593 ctx.program->gfx_level >= GFX10_3 &&
1594 is_scratch_offset_valid(ctx, NULL, scratch.offset, (int32_t)info.val)) {
1595 /* GFX10.3+ can disable both SADDR and ADDR. */
1596 instr->operands[i] = Operand(instr->operands[i].regClass());
1597 scratch.offset += (int32_t)info.val;
1598 continue;
1599 }
1600 }
1601
1602 /* DS: combine additions */
1603 else if (instr->isDS()) {
1604
1605 DS_instruction& ds = instr->ds();
1606 Temp base;
1607 uint32_t offset;
1608 bool has_usable_ds_offset = ctx.program->gfx_level >= GFX7;
1609 if (has_usable_ds_offset && i == 0 &&
1610 parse_base_offset(ctx, instr.get(), i, &base, &offset, false) &&
1611 base.regClass() == instr->operands[i].regClass() &&
1612 instr->opcode != aco_opcode::ds_swizzle_b32) {
1613 if (instr->opcode == aco_opcode::ds_write2_b32 ||
1614 instr->opcode == aco_opcode::ds_read2_b32 ||
1615 instr->opcode == aco_opcode::ds_write2_b64 ||
1616 instr->opcode == aco_opcode::ds_read2_b64 ||
1617 instr->opcode == aco_opcode::ds_write2st64_b32 ||
1618 instr->opcode == aco_opcode::ds_read2st64_b32 ||
1619 instr->opcode == aco_opcode::ds_write2st64_b64 ||
1620 instr->opcode == aco_opcode::ds_read2st64_b64) {
1621 bool is64bit = instr->opcode == aco_opcode::ds_write2_b64 ||
1622 instr->opcode == aco_opcode::ds_read2_b64 ||
1623 instr->opcode == aco_opcode::ds_write2st64_b64 ||
1624 instr->opcode == aco_opcode::ds_read2st64_b64;
1625 bool st64 = instr->opcode == aco_opcode::ds_write2st64_b32 ||
1626 instr->opcode == aco_opcode::ds_read2st64_b32 ||
1627 instr->opcode == aco_opcode::ds_write2st64_b64 ||
1628 instr->opcode == aco_opcode::ds_read2st64_b64;
1629 unsigned shifts = (is64bit ? 3 : 2) + (st64 ? 6 : 0);
1630 unsigned mask = BITFIELD_MASK(shifts);
1631
1632 if ((offset & mask) == 0 && ds.offset0 + (offset >> shifts) <= 255 &&
1633 ds.offset1 + (offset >> shifts) <= 255) {
1634 instr->operands[i].setTemp(base);
1635 ds.offset0 += offset >> shifts;
1636 ds.offset1 += offset >> shifts;
1637 }
1638 } else {
1639 if (ds.offset0 + offset <= 65535) {
1640 instr->operands[i].setTemp(base);
1641 ds.offset0 += offset;
1642 }
1643 }
1644 }
1645 }
1646
1647 else if (instr->isBranch()) {
1648 if (ctx.info[instr->operands[0].tempId()].is_scc_invert()) {
1649 /* Flip the branch instruction to get rid of the scc_invert instruction */
1650 instr->opcode = instr->opcode == aco_opcode::p_cbranch_z ? aco_opcode::p_cbranch_nz
1651 : aco_opcode::p_cbranch_z;
1652 instr->operands[0].setTemp(ctx.info[instr->operands[0].tempId()].temp);
1653 }
1654 }
1655 }
1656
1657 /* if this instruction doesn't define anything, return */
1658 if (instr->definitions.empty()) {
1659 check_sdwa_extract(ctx, instr);
1660 return;
1661 }
1662
1663 if (instr->isVALU() || instr->isVINTRP()) {
1664 if (instr_info.can_use_output_modifiers[(int)instr->opcode] || instr->isVINTRP() ||
1665 instr->opcode == aco_opcode::v_cndmask_b32) {
1666 bool canonicalized = true;
1667 if (!does_fp_op_flush_denorms(ctx, instr->opcode)) {
1668 unsigned ops = instr->opcode == aco_opcode::v_cndmask_b32 ? 2 : instr->operands.size();
1669 for (unsigned i = 0; canonicalized && (i < ops); i++)
1670 canonicalized = is_op_canonicalized(ctx, instr->operands[i]);
1671 }
1672 if (canonicalized)
1673 ctx.info[instr->definitions[0].tempId()].set_canonicalized();
1674 }
1675
1676 if (instr->isVOPC()) {
1677 ctx.info[instr->definitions[0].tempId()].set_vopc(instr.get());
1678 check_sdwa_extract(ctx, instr);
1679 return;
1680 }
1681 if (instr->isVOP3P()) {
1682 ctx.info[instr->definitions[0].tempId()].set_vop3p(instr.get());
1683 return;
1684 }
1685 }
1686
1687 switch (instr->opcode) {
1688 case aco_opcode::p_create_vector: {
1689 bool copy_prop = instr->operands.size() == 1 && instr->operands[0].isTemp() &&
1690 instr->operands[0].regClass() == instr->definitions[0].regClass();
1691 if (copy_prop) {
1692 ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
1693 break;
1694 }
1695
1696 /* expand vector operands */
1697 std::vector<Operand> ops;
1698 unsigned offset = 0;
1699 for (const Operand& op : instr->operands) {
1700 /* ensure that any expanded operands are properly aligned */
1701 bool aligned = offset % 4 == 0 || op.bytes() < 4;
1702 offset += op.bytes();
1703 if (aligned && op.isTemp() && ctx.info[op.tempId()].is_vec()) {
1704 Instruction* vec = ctx.info[op.tempId()].instr;
1705 for (const Operand& vec_op : vec->operands)
1706 ops.emplace_back(vec_op);
1707 } else {
1708 ops.emplace_back(op);
1709 }
1710 }
1711
1712 /* combine expanded operands to new vector */
1713 if (ops.size() != instr->operands.size()) {
1714 assert(ops.size() > instr->operands.size());
1715 Definition def = instr->definitions[0];
1716 instr.reset(create_instruction<Pseudo_instruction>(aco_opcode::p_create_vector,
1717 Format::PSEUDO, ops.size(), 1));
1718 for (unsigned i = 0; i < ops.size(); i++) {
1719 if (ops[i].isTemp() && ctx.info[ops[i].tempId()].is_temp() &&
1720 ops[i].regClass() == ctx.info[ops[i].tempId()].temp.regClass())
1721 ops[i].setTemp(ctx.info[ops[i].tempId()].temp);
1722 instr->operands[i] = ops[i];
1723 }
1724 instr->definitions[0] = def;
1725 } else {
1726 for (unsigned i = 0; i < ops.size(); i++) {
1727 assert(instr->operands[i] == ops[i]);
1728 }
1729 }
1730 ctx.info[instr->definitions[0].tempId()].set_vec(instr.get());
1731
1732 if (instr->operands.size() == 2) {
1733 /* check if this is created from split_vector */
1734 if (instr->operands[1].isTemp() && ctx.info[instr->operands[1].tempId()].is_split()) {
1735 Instruction* split = ctx.info[instr->operands[1].tempId()].instr;
1736 if (instr->operands[0].isTemp() &&
1737 instr->operands[0].getTemp() == split->definitions[0].getTemp())
1738 ctx.info[instr->definitions[0].tempId()].set_temp(split->operands[0].getTemp());
1739 }
1740 }
1741 break;
1742 }
1743 case aco_opcode::p_split_vector: {
1744 ssa_info& info = ctx.info[instr->operands[0].tempId()];
1745
1746 if (info.is_constant_or_literal(32)) {
1747 uint64_t val = info.val;
1748 for (Definition def : instr->definitions) {
1749 uint32_t mask = u_bit_consecutive(0, def.bytes() * 8u);
1750 ctx.info[def.tempId()].set_constant(ctx.program->gfx_level, val & mask);
1751 val >>= def.bytes() * 8u;
1752 }
1753 break;
1754 } else if (!info.is_vec()) {
1755 if (instr->definitions.size() == 2 && instr->operands[0].isTemp() &&
1756 instr->definitions[0].bytes() == instr->definitions[1].bytes()) {
1757 ctx.info[instr->definitions[1].tempId()].set_split(instr.get());
1758 if (instr->operands[0].bytes() == 4) {
1759 /* D16 subdword split */
1760 ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
1761 ctx.info[instr->definitions[1].tempId()].set_extract(instr.get());
1762 }
1763 }
1764 break;
1765 }
1766
1767 Instruction* vec = ctx.info[instr->operands[0].tempId()].instr;
1768 unsigned split_offset = 0;
1769 unsigned vec_offset = 0;
1770 unsigned vec_index = 0;
1771 for (unsigned i = 0; i < instr->definitions.size();
1772 split_offset += instr->definitions[i++].bytes()) {
1773 while (vec_offset < split_offset && vec_index < vec->operands.size())
1774 vec_offset += vec->operands[vec_index++].bytes();
1775
1776 if (vec_offset != split_offset ||
1777 vec->operands[vec_index].bytes() != instr->definitions[i].bytes())
1778 continue;
1779
1780 Operand vec_op = vec->operands[vec_index];
1781 if (vec_op.isConstant()) {
1782 ctx.info[instr->definitions[i].tempId()].set_constant(ctx.program->gfx_level,
1783 vec_op.constantValue64());
1784 } else if (vec_op.isUndefined()) {
1785 ctx.info[instr->definitions[i].tempId()].set_undefined();
1786 } else {
1787 assert(vec_op.isTemp());
1788 ctx.info[instr->definitions[i].tempId()].set_temp(vec_op.getTemp());
1789 }
1790 }
1791 break;
1792 }
1793 case aco_opcode::p_extract_vector: { /* mov */
1794 ssa_info& info = ctx.info[instr->operands[0].tempId()];
1795 const unsigned index = instr->operands[1].constantValue();
1796 const unsigned dst_offset = index * instr->definitions[0].bytes();
1797
1798 if (info.is_vec()) {
1799 /* check if we index directly into a vector element */
1800 Instruction* vec = info.instr;
1801 unsigned offset = 0;
1802
1803 for (const Operand& op : vec->operands) {
1804 if (offset < dst_offset) {
1805 offset += op.bytes();
1806 continue;
1807 } else if (offset != dst_offset || op.bytes() != instr->definitions[0].bytes()) {
1808 break;
1809 }
1810 instr->operands[0] = op;
1811 break;
1812 }
1813 } else if (info.is_constant_or_literal(32)) {
1814 /* propagate constants */
1815 uint32_t mask = u_bit_consecutive(0, instr->definitions[0].bytes() * 8u);
1816 uint32_t val = (info.val >> (dst_offset * 8u)) & mask;
1817 instr->operands[0] =
1818 Operand::get_const(ctx.program->gfx_level, val, instr->definitions[0].bytes());
1819 ;
1820 }
1821
1822 if (instr->operands[0].bytes() != instr->definitions[0].bytes()) {
1823 if (instr->operands[0].size() != 1)
1824 break;
1825
1826 if (index == 0)
1827 ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
1828 else
1829 ctx.info[instr->definitions[0].tempId()].set_extract(instr.get());
1830 break;
1831 }
1832
1833 /* convert this extract into a copy instruction */
1834 instr->opcode = aco_opcode::p_parallelcopy;
1835 instr->operands.pop_back();
1836 FALLTHROUGH;
1837 }
1838 case aco_opcode::p_parallelcopy: /* propagate */
1839 if (instr->operands[0].isTemp() && ctx.info[instr->operands[0].tempId()].is_vec() &&
1840 instr->operands[0].regClass() != instr->definitions[0].regClass()) {
1841 /* We might not be able to copy-propagate if it's a SGPR->VGPR copy, so
1842 * duplicate the vector instead.
1843 */
1844 Instruction* vec = ctx.info[instr->operands[0].tempId()].instr;
1845 aco_ptr<Instruction> old_copy = std::move(instr);
1846
1847 instr.reset(create_instruction<Pseudo_instruction>(
1848 aco_opcode::p_create_vector, Format::PSEUDO, vec->operands.size(), 1));
1849 instr->definitions[0] = old_copy->definitions[0];
1850 std::copy(vec->operands.begin(), vec->operands.end(), instr->operands.begin());
1851 for (unsigned i = 0; i < vec->operands.size(); i++) {
1852 Operand& op = instr->operands[i];
1853 if (op.isTemp() && ctx.info[op.tempId()].is_temp() &&
1854 ctx.info[op.tempId()].temp.type() == instr->definitions[0].regClass().type())
1855 op.setTemp(ctx.info[op.tempId()].temp);
1856 }
1857 ctx.info[instr->definitions[0].tempId()].set_vec(instr.get());
1858 break;
1859 }
1860 FALLTHROUGH;
1861 case aco_opcode::p_as_uniform:
1862 if (instr->definitions[0].isFixed()) {
1863 /* don't copy-propagate copies into fixed registers */
1864 } else if (instr->operands[0].isConstant()) {
1865 ctx.info[instr->definitions[0].tempId()].set_constant(
1866 ctx.program->gfx_level, instr->operands[0].constantValue64());
1867 } else if (instr->operands[0].isTemp()) {
1868 ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
1869 if (ctx.info[instr->operands[0].tempId()].is_canonicalized())
1870 ctx.info[instr->definitions[0].tempId()].set_canonicalized();
1871 } else {
1872 assert(instr->operands[0].isFixed());
1873 }
1874 break;
1875 case aco_opcode::v_mov_b32:
1876 if (instr->isDPP16()) {
1877 /* anything else doesn't make sense in SSA */
1878 assert(instr->dpp16().row_mask == 0xf && instr->dpp16().bank_mask == 0xf);
1879 ctx.info[instr->definitions[0].tempId()].set_dpp16(instr.get());
1880 } else if (instr->isDPP8()) {
1881 ctx.info[instr->definitions[0].tempId()].set_dpp8(instr.get());
1882 }
1883 break;
1884 case aco_opcode::p_is_helper:
1885 if (!ctx.program->needs_wqm)
1886 ctx.info[instr->definitions[0].tempId()].set_constant(ctx.program->gfx_level, 0u);
1887 break;
1888 case aco_opcode::v_mul_f64: ctx.info[instr->definitions[0].tempId()].set_mul(instr.get()); break;
1889 case aco_opcode::v_mul_f16:
1890 case aco_opcode::v_mul_f32:
1891 case aco_opcode::v_mul_legacy_f32: { /* omod */
1892 ctx.info[instr->definitions[0].tempId()].set_mul(instr.get());
1893
1894 /* TODO: try to move the negate/abs modifier to the consumer instead */
1895 bool uses_mods = instr->usesModifiers();
1896 bool fp16 = instr->opcode == aco_opcode::v_mul_f16;
1897
1898 for (unsigned i = 0; i < 2; i++) {
1899 if (instr->operands[!i].isConstant() && instr->operands[i].isTemp()) {
1900 if (!instr->isDPP() && !instr->isSDWA() && !instr->valu().opsel &&
1901 (instr->operands[!i].constantEquals(fp16 ? 0x3c00 : 0x3f800000) || /* 1.0 */
1902 instr->operands[!i].constantEquals(fp16 ? 0xbc00 : 0xbf800000u))) { /* -1.0 */
1903 bool neg1 = instr->operands[!i].constantEquals(fp16 ? 0xbc00 : 0xbf800000u);
1904
1905 VALU_instruction* valu = &instr->valu();
1906 if (valu->abs[!i] || valu->neg[!i] || valu->omod)
1907 continue;
1908
1909 bool abs = valu->abs[i];
1910 bool neg = neg1 ^ valu->neg[i];
1911 Temp other = instr->operands[i].getTemp();
1912
1913 if (valu->clamp) {
1914 if (!abs && !neg && other.type() == RegType::vgpr)
1915 ctx.info[other.id()].set_clamp(instr.get());
1916 continue;
1917 }
1918
1919 if (abs && neg && other.type() == RegType::vgpr)
1920 ctx.info[instr->definitions[0].tempId()].set_neg_abs(other);
1921 else if (abs && !neg && other.type() == RegType::vgpr)
1922 ctx.info[instr->definitions[0].tempId()].set_abs(other);
1923 else if (!abs && neg && other.type() == RegType::vgpr)
1924 ctx.info[instr->definitions[0].tempId()].set_neg(other);
1925 else if (!abs && !neg)
1926 ctx.info[instr->definitions[0].tempId()].set_fcanonicalize(other);
1927 } else if (uses_mods || ((fp16 ? ctx.fp_mode.preserve_signed_zero_inf_nan16_64
1928 : ctx.fp_mode.preserve_signed_zero_inf_nan32) &&
1929 instr->opcode != aco_opcode::v_mul_legacy_f32)) {
1930 continue; /* omod uses a legacy multiplication. */
1931 } else if (instr->operands[!i].constantValue() == 0u) { /* 0.0 */
1932 ctx.info[instr->definitions[0].tempId()].set_constant(ctx.program->gfx_level, 0u);
1933 } else if ((fp16 ? ctx.fp_mode.denorm16_64 : ctx.fp_mode.denorm32) != fp_denorm_flush) {
1934 /* omod has no effect if denormals are enabled. */
1935 continue;
1936 } else if (instr->operands[!i].constantValue() ==
1937 (fp16 ? 0x4000 : 0x40000000)) { /* 2.0 */
1938 ctx.info[instr->operands[i].tempId()].set_omod2(instr.get());
1939 } else if (instr->operands[!i].constantValue() ==
1940 (fp16 ? 0x4400 : 0x40800000)) { /* 4.0 */
1941 ctx.info[instr->operands[i].tempId()].set_omod4(instr.get());
1942 } else if (instr->operands[!i].constantValue() ==
1943 (fp16 ? 0x3800 : 0x3f000000)) { /* 0.5 */
1944 ctx.info[instr->operands[i].tempId()].set_omod5(instr.get());
1945 } else {
1946 continue;
1947 }
1948 break;
1949 }
1950 }
1951 break;
1952 }
1953 case aco_opcode::v_mul_lo_u16:
1954 case aco_opcode::v_mul_lo_u16_e64:
1955 case aco_opcode::v_mul_u32_u24:
1956 ctx.info[instr->definitions[0].tempId()].set_usedef(instr.get());
1957 break;
1958 case aco_opcode::v_med3_f16:
1959 case aco_opcode::v_med3_f32: { /* clamp */
1960 unsigned idx;
1961 if (detect_clamp(instr.get(), &idx) && !instr->valu().abs && !instr->valu().neg)
1962 ctx.info[instr->operands[idx].tempId()].set_clamp(instr.get());
1963 break;
1964 }
1965 case aco_opcode::v_cndmask_b32:
1966 if (instr->operands[0].constantEquals(0) && instr->operands[1].constantEquals(0xFFFFFFFF))
1967 ctx.info[instr->definitions[0].tempId()].set_vcc(instr->operands[2].getTemp());
1968 else if (instr->operands[0].constantEquals(0) &&
1969 instr->operands[1].constantEquals(0x3f800000u))
1970 ctx.info[instr->definitions[0].tempId()].set_b2f(instr->operands[2].getTemp());
1971 else if (instr->operands[0].constantEquals(0) && instr->operands[1].constantEquals(1))
1972 ctx.info[instr->definitions[0].tempId()].set_b2i(instr->operands[2].getTemp());
1973
1974 break;
1975 case aco_opcode::v_cmp_lg_u32:
1976 if (instr->format == Format::VOPC && /* don't optimize VOP3 / SDWA / DPP */
1977 instr->operands[0].constantEquals(0) && instr->operands[1].isTemp() &&
1978 ctx.info[instr->operands[1].tempId()].is_vcc())
1979 ctx.info[instr->definitions[0].tempId()].set_temp(
1980 ctx.info[instr->operands[1].tempId()].temp);
1981 break;
1982 case aco_opcode::p_linear_phi: {
1983 /* lower_bool_phis() can create phis like this */
1984 bool all_same_temp = instr->operands[0].isTemp();
1985 /* this check is needed when moving uniform loop counters out of a divergent loop */
1986 if (all_same_temp)
1987 all_same_temp = instr->definitions[0].regClass() == instr->operands[0].regClass();
1988 for (unsigned i = 1; all_same_temp && (i < instr->operands.size()); i++) {
1989 if (!instr->operands[i].isTemp() ||
1990 instr->operands[i].tempId() != instr->operands[0].tempId())
1991 all_same_temp = false;
1992 }
1993 if (all_same_temp) {
1994 ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
1995 } else {
1996 bool all_undef = instr->operands[0].isUndefined();
1997 for (unsigned i = 1; all_undef && (i < instr->operands.size()); i++) {
1998 if (!instr->operands[i].isUndefined())
1999 all_undef = false;
2000 }
2001 if (all_undef)
2002 ctx.info[instr->definitions[0].tempId()].set_undefined();
2003 }
2004 break;
2005 }
2006 case aco_opcode::v_add_u32:
2007 case aco_opcode::v_add_co_u32:
2008 case aco_opcode::v_add_co_u32_e64:
2009 case aco_opcode::s_add_i32:
2010 case aco_opcode::s_add_u32:
2011 case aco_opcode::v_subbrev_co_u32:
2012 case aco_opcode::v_sub_u32:
2013 case aco_opcode::v_sub_i32:
2014 case aco_opcode::v_sub_co_u32:
2015 case aco_opcode::v_sub_co_u32_e64:
2016 case aco_opcode::s_sub_u32:
2017 case aco_opcode::s_sub_i32:
2018 case aco_opcode::v_subrev_u32:
2019 case aco_opcode::v_subrev_co_u32:
2020 case aco_opcode::v_subrev_co_u32_e64:
2021 ctx.info[instr->definitions[0].tempId()].set_add_sub(instr.get());
2022 break;
2023 case aco_opcode::s_not_b32:
2024 case aco_opcode::s_not_b64:
2025 if (!instr->operands[0].isTemp()) {
2026 } else if (ctx.info[instr->operands[0].tempId()].is_uniform_bool()) {
2027 ctx.info[instr->definitions[0].tempId()].set_uniform_bitwise();
2028 ctx.info[instr->definitions[1].tempId()].set_scc_invert(
2029 ctx.info[instr->operands[0].tempId()].temp);
2030 } else if (ctx.info[instr->operands[0].tempId()].is_uniform_bitwise()) {
2031 ctx.info[instr->definitions[0].tempId()].set_uniform_bitwise();
2032 ctx.info[instr->definitions[1].tempId()].set_scc_invert(
2033 ctx.info[instr->operands[0].tempId()].instr->definitions[1].getTemp());
2034 }
2035 ctx.info[instr->definitions[0].tempId()].set_bitwise(instr.get());
2036 break;
2037 case aco_opcode::s_and_b32:
2038 case aco_opcode::s_and_b64:
2039 if (fixed_to_exec(instr->operands[1]) && instr->operands[0].isTemp()) {
2040 if (ctx.info[instr->operands[0].tempId()].is_uniform_bool()) {
2041 /* Try to get rid of the superfluous s_cselect + s_and_b64 that comes from turning a
2042 * uniform bool into divergent */
2043 ctx.info[instr->definitions[1].tempId()].set_temp(
2044 ctx.info[instr->operands[0].tempId()].temp);
2045 ctx.info[instr->definitions[0].tempId()].set_uniform_bool(
2046 ctx.info[instr->operands[0].tempId()].temp);
2047 break;
2048 } else if (ctx.info[instr->operands[0].tempId()].is_uniform_bitwise()) {
2049 /* Try to get rid of the superfluous s_and_b64, since the uniform bitwise instruction
2050 * already produces the same SCC */
2051 ctx.info[instr->definitions[1].tempId()].set_temp(
2052 ctx.info[instr->operands[0].tempId()].instr->definitions[1].getTemp());
2053 ctx.info[instr->definitions[0].tempId()].set_uniform_bool(
2054 ctx.info[instr->operands[0].tempId()].instr->definitions[1].getTemp());
2055 break;
2056 } else if ((ctx.program->stage.num_sw_stages() > 1 ||
2057 ctx.program->stage.hw == AC_HW_NEXT_GEN_GEOMETRY_SHADER) &&
2058 instr->pass_flags == 1) {
2059 /* In case of merged shaders, pass_flags=1 means that all lanes are active (exec=-1), so
2060 * s_and is unnecessary. */
2061 ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
2062 break;
2063 }
2064 }
2065 FALLTHROUGH;
2066 case aco_opcode::s_or_b32:
2067 case aco_opcode::s_or_b64:
2068 case aco_opcode::s_xor_b32:
2069 case aco_opcode::s_xor_b64:
2070 if (std::all_of(instr->operands.begin(), instr->operands.end(),
2071 [&ctx](const Operand& op)
2072 {
2073 return op.isTemp() && (ctx.info[op.tempId()].is_uniform_bool() ||
2074 ctx.info[op.tempId()].is_uniform_bitwise());
2075 })) {
2076 ctx.info[instr->definitions[0].tempId()].set_uniform_bitwise();
2077 }
2078 ctx.info[instr->definitions[0].tempId()].set_bitwise(instr.get());
2079 break;
2080 case aco_opcode::s_lshl_b32:
2081 case aco_opcode::v_or_b32:
2082 case aco_opcode::v_lshlrev_b32:
2083 case aco_opcode::v_bcnt_u32_b32:
2084 case aco_opcode::v_and_b32:
2085 case aco_opcode::v_xor_b32:
2086 case aco_opcode::v_not_b32:
2087 ctx.info[instr->definitions[0].tempId()].set_usedef(instr.get());
2088 break;
2089 case aco_opcode::v_min_f32:
2090 case aco_opcode::v_min_f16:
2091 case aco_opcode::v_min_u32:
2092 case aco_opcode::v_min_i32:
2093 case aco_opcode::v_min_u16:
2094 case aco_opcode::v_min_i16:
2095 case aco_opcode::v_min_u16_e64:
2096 case aco_opcode::v_min_i16_e64:
2097 case aco_opcode::v_max_f32:
2098 case aco_opcode::v_max_f16:
2099 case aco_opcode::v_max_u32:
2100 case aco_opcode::v_max_i32:
2101 case aco_opcode::v_max_u16:
2102 case aco_opcode::v_max_i16:
2103 case aco_opcode::v_max_u16_e64:
2104 case aco_opcode::v_max_i16_e64:
2105 ctx.info[instr->definitions[0].tempId()].set_minmax(instr.get());
2106 break;
2107 case aco_opcode::s_cselect_b64:
2108 case aco_opcode::s_cselect_b32:
2109 if (instr->operands[0].constantEquals((unsigned)-1) && instr->operands[1].constantEquals(0)) {
2110 /* Found a cselect that operates on a uniform bool that comes from eg. s_cmp */
2111 ctx.info[instr->definitions[0].tempId()].set_uniform_bool(instr->operands[2].getTemp());
2112 }
2113 if (instr->operands[2].isTemp() && ctx.info[instr->operands[2].tempId()].is_scc_invert()) {
2114 /* Flip the operands to get rid of the scc_invert instruction */
2115 std::swap(instr->operands[0], instr->operands[1]);
2116 instr->operands[2].setTemp(ctx.info[instr->operands[2].tempId()].temp);
2117 }
2118 break;
2119 case aco_opcode::s_mul_i32:
2120 /* Testing every uint32_t shows that 0x3f800000*n is never a denormal.
2121 * This pattern is created from a uniform nir_op_b2f. */
2122 if (instr->operands[0].constantEquals(0x3f800000u))
2123 ctx.info[instr->definitions[0].tempId()].set_canonicalized();
2124 break;
2125 case aco_opcode::p_extract: {
2126 if (instr->definitions[0].bytes() == 4) {
2127 ctx.info[instr->definitions[0].tempId()].set_extract(instr.get());
2128 if (instr->operands[0].regClass() == v1 && parse_insert(instr.get()))
2129 ctx.info[instr->operands[0].tempId()].set_insert(instr.get());
2130 }
2131 break;
2132 }
2133 case aco_opcode::p_insert: {
2134 if (instr->operands[0].bytes() == 4) {
2135 if (instr->operands[0].regClass() == v1)
2136 ctx.info[instr->operands[0].tempId()].set_insert(instr.get());
2137 if (parse_extract(instr.get()))
2138 ctx.info[instr->definitions[0].tempId()].set_extract(instr.get());
2139 ctx.info[instr->definitions[0].tempId()].set_bitwise(instr.get());
2140 }
2141 break;
2142 }
2143 case aco_opcode::ds_read_u8:
2144 case aco_opcode::ds_read_u8_d16:
2145 case aco_opcode::ds_read_u16:
2146 case aco_opcode::ds_read_u16_d16: {
2147 ctx.info[instr->definitions[0].tempId()].set_usedef(instr.get());
2148 break;
2149 }
2150 case aco_opcode::v_mbcnt_lo_u32_b32: {
2151 if (instr->operands[0].constantEquals(-1) && instr->operands[1].constantEquals(0)) {
2152 if (ctx.program->wave_size == 32)
2153 ctx.info[instr->definitions[0].tempId()].set_subgroup_invocation(instr.get());
2154 else
2155 ctx.info[instr->definitions[0].tempId()].set_usedef(instr.get());
2156 }
2157 break;
2158 }
2159 case aco_opcode::v_mbcnt_hi_u32_b32:
2160 case aco_opcode::v_mbcnt_hi_u32_b32_e64: {
2161 if (instr->operands[0].constantEquals(-1) && instr->operands[1].isTemp() &&
2162 ctx.info[instr->operands[1].tempId()].is_usedef()) {
2163 Instruction* usedef_instr = ctx.info[instr->operands[1].tempId()].instr;
2164 if (usedef_instr->opcode == aco_opcode::v_mbcnt_lo_u32_b32 &&
2165 usedef_instr->operands[0].constantEquals(-1) &&
2166 usedef_instr->operands[1].constantEquals(0))
2167 ctx.info[instr->definitions[0].tempId()].set_subgroup_invocation(instr.get());
2168 }
2169 break;
2170 }
2171 case aco_opcode::v_cvt_f16_f32: {
2172 if (instr->operands[0].isTemp())
2173 ctx.info[instr->operands[0].tempId()].set_f2f16(instr.get());
2174 break;
2175 }
2176 case aco_opcode::v_cvt_f32_f16: {
2177 if (instr->operands[0].isTemp())
2178 ctx.info[instr->definitions[0].tempId()].set_f2f32(instr.get());
2179 break;
2180 }
2181 default: break;
2182 }
2183
2184 /* Don't remove label_extract if we can't apply the extract to
2185 * neg/abs instructions because we'll likely combine it into another valu. */
2186 if (!(ctx.info[instr->definitions[0].tempId()].label & (label_neg | label_abs)))
2187 check_sdwa_extract(ctx, instr);
2188 }
2189
2190 unsigned
original_temp_id(opt_ctx & ctx,Temp tmp)2191 original_temp_id(opt_ctx& ctx, Temp tmp)
2192 {
2193 if (ctx.info[tmp.id()].is_temp())
2194 return ctx.info[tmp.id()].temp.id();
2195 else
2196 return tmp.id();
2197 }
2198
2199 void
decrease_op_uses_if_dead(opt_ctx & ctx,Instruction * instr)2200 decrease_op_uses_if_dead(opt_ctx& ctx, Instruction* instr)
2201 {
2202 if (is_dead(ctx.uses, instr)) {
2203 for (const Operand& op : instr->operands) {
2204 if (op.isTemp())
2205 ctx.uses[op.tempId()]--;
2206 }
2207 }
2208 }
2209
2210 void
decrease_uses(opt_ctx & ctx,Instruction * instr)2211 decrease_uses(opt_ctx& ctx, Instruction* instr)
2212 {
2213 ctx.uses[instr->definitions[0].tempId()]--;
2214 decrease_op_uses_if_dead(ctx, instr);
2215 }
2216
2217 Operand
copy_operand(opt_ctx & ctx,Operand op)2218 copy_operand(opt_ctx& ctx, Operand op)
2219 {
2220 if (op.isTemp())
2221 ctx.uses[op.tempId()]++;
2222 return op;
2223 }
2224
2225 Instruction*
follow_operand(opt_ctx & ctx,Operand op,bool ignore_uses=false)2226 follow_operand(opt_ctx& ctx, Operand op, bool ignore_uses = false)
2227 {
2228 if (!op.isTemp() || !(ctx.info[op.tempId()].label & instr_usedef_labels))
2229 return nullptr;
2230 if (!ignore_uses && ctx.uses[op.tempId()] > 1)
2231 return nullptr;
2232
2233 Instruction* instr = ctx.info[op.tempId()].instr;
2234
2235 if (instr->definitions.size() == 2) {
2236 assert(instr->definitions[0].isTemp() && instr->definitions[0].tempId() == op.tempId());
2237 if (instr->definitions[1].isTemp() && ctx.uses[instr->definitions[1].tempId()])
2238 return nullptr;
2239 }
2240
2241 for (Operand& operand : instr->operands) {
2242 if (fixed_to_exec(operand))
2243 return nullptr;
2244 }
2245
2246 return instr;
2247 }
2248
2249 /* s_or_b64(neq(a, a), neq(b, b)) -> v_cmp_u_f32(a, b)
2250 * s_and_b64(eq(a, a), eq(b, b)) -> v_cmp_o_f32(a, b) */
2251 bool
combine_ordering_test(opt_ctx & ctx,aco_ptr<Instruction> & instr)2252 combine_ordering_test(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2253 {
2254 if (instr->definitions[0].regClass() != ctx.program->lane_mask)
2255 return false;
2256 if (instr->definitions[1].isTemp() && ctx.uses[instr->definitions[1].tempId()])
2257 return false;
2258
2259 bool is_or = instr->opcode == aco_opcode::s_or_b64 || instr->opcode == aco_opcode::s_or_b32;
2260
2261 bitarray8 opsel = 0;
2262 Instruction* op_instr[2];
2263 Temp op[2];
2264
2265 unsigned bitsize = 0;
2266 for (unsigned i = 0; i < 2; i++) {
2267 op_instr[i] = follow_operand(ctx, instr->operands[i], true);
2268 if (!op_instr[i])
2269 return false;
2270
2271 aco_opcode expected_cmp = is_or ? aco_opcode::v_cmp_neq_f32 : aco_opcode::v_cmp_eq_f32;
2272 unsigned op_bitsize = get_cmp_bitsize(op_instr[i]->opcode);
2273
2274 if (get_f32_cmp(op_instr[i]->opcode) != expected_cmp)
2275 return false;
2276 if (bitsize && op_bitsize != bitsize)
2277 return false;
2278 if (!op_instr[i]->operands[0].isTemp() || !op_instr[i]->operands[1].isTemp())
2279 return false;
2280
2281 if (op_instr[i]->isSDWA() || op_instr[i]->isDPP())
2282 return false;
2283
2284 VALU_instruction& valu = op_instr[i]->valu();
2285 if (valu.neg[0] != valu.neg[1] || valu.abs[0] != valu.abs[1] ||
2286 valu.opsel[0] != valu.opsel[1])
2287 return false;
2288 opsel[i] = valu.opsel[0];
2289
2290 Temp op0 = op_instr[i]->operands[0].getTemp();
2291 Temp op1 = op_instr[i]->operands[1].getTemp();
2292 if (original_temp_id(ctx, op0) != original_temp_id(ctx, op1))
2293 return false;
2294
2295 op[i] = op1;
2296 bitsize = op_bitsize;
2297 }
2298
2299 if (op[1].type() == RegType::sgpr) {
2300 std::swap(op[0], op[1]);
2301 opsel[0].swap(opsel[1]);
2302 }
2303 unsigned num_sgprs = (op[0].type() == RegType::sgpr) + (op[1].type() == RegType::sgpr);
2304 if (num_sgprs > (ctx.program->gfx_level >= GFX10 ? 2 : 1))
2305 return false;
2306
2307 aco_opcode new_op = aco_opcode::num_opcodes;
2308 switch (bitsize) {
2309 case 16: new_op = is_or ? aco_opcode::v_cmp_u_f16 : aco_opcode::v_cmp_o_f16; break;
2310 case 32: new_op = is_or ? aco_opcode::v_cmp_u_f32 : aco_opcode::v_cmp_o_f32; break;
2311 case 64: new_op = is_or ? aco_opcode::v_cmp_u_f64 : aco_opcode::v_cmp_o_f64; break;
2312 }
2313 bool needs_vop3 = num_sgprs > 1 || (opsel[0] && op[0].type() != RegType::vgpr);
2314 VALU_instruction* new_instr = create_instruction<VALU_instruction>(
2315 new_op, needs_vop3 ? asVOP3(Format::VOPC) : Format::VOPC, 2, 1);
2316
2317 new_instr->opsel = opsel;
2318 new_instr->operands[0] = copy_operand(ctx, Operand(op[0]));
2319 new_instr->operands[1] = copy_operand(ctx, Operand(op[1]));
2320 new_instr->definitions[0] = instr->definitions[0];
2321 new_instr->pass_flags = instr->pass_flags;
2322
2323 decrease_uses(ctx, op_instr[0]);
2324 decrease_uses(ctx, op_instr[1]);
2325
2326 ctx.info[instr->definitions[0].tempId()].label = 0;
2327 ctx.info[instr->definitions[0].tempId()].set_vopc(new_instr);
2328
2329 instr.reset(new_instr);
2330
2331 return true;
2332 }
2333
2334 /* s_or_b64(v_cmp_u_f32(a, b), cmp(a, b)) -> get_unordered(cmp)(a, b)
2335 * s_and_b64(v_cmp_o_f32(a, b), cmp(a, b)) -> get_ordered(cmp)(a, b) */
2336 bool
combine_comparison_ordering(opt_ctx & ctx,aco_ptr<Instruction> & instr)2337 combine_comparison_ordering(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2338 {
2339 if (instr->definitions[0].regClass() != ctx.program->lane_mask)
2340 return false;
2341 if (instr->definitions[1].isTemp() && ctx.uses[instr->definitions[1].tempId()])
2342 return false;
2343
2344 bool is_or = instr->opcode == aco_opcode::s_or_b64 || instr->opcode == aco_opcode::s_or_b32;
2345 aco_opcode expected_nan_test = is_or ? aco_opcode::v_cmp_u_f32 : aco_opcode::v_cmp_o_f32;
2346
2347 Instruction* nan_test = follow_operand(ctx, instr->operands[0], true);
2348 Instruction* cmp = follow_operand(ctx, instr->operands[1], true);
2349 if (!nan_test || !cmp)
2350 return false;
2351 if (nan_test->isSDWA() || cmp->isSDWA())
2352 return false;
2353
2354 if (get_f32_cmp(cmp->opcode) == expected_nan_test)
2355 std::swap(nan_test, cmp);
2356 else if (get_f32_cmp(nan_test->opcode) != expected_nan_test)
2357 return false;
2358
2359 if (!is_fp_cmp(cmp->opcode) || get_cmp_bitsize(cmp->opcode) != get_cmp_bitsize(nan_test->opcode))
2360 return false;
2361
2362 if (!nan_test->operands[0].isTemp() || !nan_test->operands[1].isTemp())
2363 return false;
2364 if (!cmp->operands[0].isTemp() || !cmp->operands[1].isTemp())
2365 return false;
2366
2367 unsigned prop_cmp0 = original_temp_id(ctx, cmp->operands[0].getTemp());
2368 unsigned prop_cmp1 = original_temp_id(ctx, cmp->operands[1].getTemp());
2369 unsigned prop_nan0 = original_temp_id(ctx, nan_test->operands[0].getTemp());
2370 unsigned prop_nan1 = original_temp_id(ctx, nan_test->operands[1].getTemp());
2371 VALU_instruction& cmp_valu = cmp->valu();
2372 VALU_instruction& nan_valu = nan_test->valu();
2373 if ((prop_cmp0 != prop_nan0 || cmp_valu.opsel[0] != nan_valu.opsel[0]) &&
2374 (prop_cmp0 != prop_nan1 || cmp_valu.opsel[0] != nan_valu.opsel[1]))
2375 return false;
2376 if ((prop_cmp1 != prop_nan0 || cmp_valu.opsel[1] != nan_valu.opsel[0]) &&
2377 (prop_cmp1 != prop_nan1 || cmp_valu.opsel[1] != nan_valu.opsel[1]))
2378 return false;
2379 if (prop_cmp0 == prop_cmp1 && cmp_valu.opsel[0] == cmp_valu.opsel[1])
2380 return false;
2381
2382 aco_opcode new_op = is_or ? get_unordered(cmp->opcode) : get_ordered(cmp->opcode);
2383 VALU_instruction* new_instr = create_instruction<VALU_instruction>(
2384 new_op, cmp->isVOP3() ? asVOP3(Format::VOPC) : Format::VOPC, 2, 1);
2385 new_instr->neg = cmp_valu.neg;
2386 new_instr->abs = cmp_valu.abs;
2387 new_instr->clamp = cmp_valu.clamp;
2388 new_instr->omod = cmp_valu.omod;
2389 new_instr->opsel = cmp_valu.opsel;
2390 new_instr->operands[0] = copy_operand(ctx, cmp->operands[0]);
2391 new_instr->operands[1] = copy_operand(ctx, cmp->operands[1]);
2392 new_instr->definitions[0] = instr->definitions[0];
2393 new_instr->pass_flags = instr->pass_flags;
2394
2395 decrease_uses(ctx, nan_test);
2396 decrease_uses(ctx, cmp);
2397
2398 ctx.info[instr->definitions[0].tempId()].label = 0;
2399 ctx.info[instr->definitions[0].tempId()].set_vopc(new_instr);
2400
2401 instr.reset(new_instr);
2402
2403 return true;
2404 }
2405
2406 /* Optimize v_cmp of constant with subgroup invocation to a constant mask.
2407 * Ideally, we can trade v_cmp for a constant (or literal).
2408 * In a less ideal case, we trade v_cmp for a SALU instruction, which is still a win.
2409 */
2410 bool
optimize_cmp_subgroup_invocation(opt_ctx & ctx,aco_ptr<Instruction> & instr)2411 optimize_cmp_subgroup_invocation(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2412 {
2413 /* This optimization only applies to VOPC with 2 operands. */
2414 if (instr->operands.size() != 2)
2415 return false;
2416
2417 /* Find the constant operand or return early if there isn't one. */
2418 const int const_op_idx = instr->operands[0].isConstant() ? 0
2419 : instr->operands[1].isConstant() ? 1
2420 : -1;
2421 if (const_op_idx == -1)
2422 return false;
2423
2424 /* Find the operand that has the subgroup invocation. */
2425 const int mbcnt_op_idx = 1 - const_op_idx;
2426 const Operand mbcnt_op = instr->operands[mbcnt_op_idx];
2427 if (!mbcnt_op.isTemp() || !ctx.info[mbcnt_op.tempId()].is_subgroup_invocation())
2428 return false;
2429
2430 /* Adjust opcode so we don't have to care about const_op_idx below. */
2431 const aco_opcode op = const_op_idx == 0 ? get_swapped(instr->opcode) : instr->opcode;
2432 const unsigned wave_size = ctx.program->wave_size;
2433 const unsigned val = instr->operands[const_op_idx].constantValue();
2434
2435 /* Find suitable constant bitmask corresponding to the value. */
2436 unsigned first_bit = 0, num_bits = 0;
2437 switch (op) {
2438 case aco_opcode::v_cmp_eq_u32:
2439 case aco_opcode::v_cmp_eq_i32:
2440 first_bit = val;
2441 num_bits = val >= wave_size ? 0 : 1;
2442 break;
2443 case aco_opcode::v_cmp_le_u32:
2444 case aco_opcode::v_cmp_le_i32:
2445 first_bit = 0;
2446 num_bits = val >= wave_size ? wave_size : (val + 1);
2447 break;
2448 case aco_opcode::v_cmp_lt_u32:
2449 case aco_opcode::v_cmp_lt_i32:
2450 first_bit = 0;
2451 num_bits = val >= wave_size ? wave_size : val;
2452 break;
2453 case aco_opcode::v_cmp_ge_u32:
2454 case aco_opcode::v_cmp_ge_i32:
2455 first_bit = val;
2456 num_bits = val >= wave_size ? 0 : (wave_size - val);
2457 break;
2458 case aco_opcode::v_cmp_gt_u32:
2459 case aco_opcode::v_cmp_gt_i32:
2460 first_bit = val + 1;
2461 num_bits = val >= wave_size ? 0 : (wave_size - val - 1);
2462 break;
2463 default: return false;
2464 }
2465
2466 Instruction* cpy = NULL;
2467 const uint64_t mask = BITFIELD64_RANGE(first_bit, num_bits);
2468 if (wave_size == 64 && mask > 0x7fffffff && mask != -1ull) {
2469 /* Mask can't be represented as a 64-bit constant or literal, use s_bfm_b64. */
2470 cpy = create_instruction<SOP2_instruction>(aco_opcode::s_bfm_b64, Format::SOP2, 2, 1);
2471 cpy->operands[0] = Operand::c32(num_bits);
2472 cpy->operands[1] = Operand::c32(first_bit);
2473 } else {
2474 /* Copy mask as a literal constant. */
2475 cpy =
2476 create_instruction<Pseudo_instruction>(aco_opcode::p_parallelcopy, Format::PSEUDO, 1, 1);
2477 cpy->operands[0] = wave_size == 32 ? Operand::c32((uint32_t)mask) : Operand::c64(mask);
2478 }
2479
2480 cpy->definitions[0] = instr->definitions[0];
2481 ctx.info[instr->definitions[0].tempId()].label = 0;
2482 decrease_uses(ctx, ctx.info[mbcnt_op.tempId()].instr);
2483 instr.reset(cpy);
2484
2485 return true;
2486 }
2487
2488 bool
is_operand_constant(opt_ctx & ctx,Operand op,unsigned bit_size,uint64_t * value)2489 is_operand_constant(opt_ctx& ctx, Operand op, unsigned bit_size, uint64_t* value)
2490 {
2491 if (op.isConstant()) {
2492 *value = op.constantValue64();
2493 return true;
2494 } else if (op.isTemp()) {
2495 unsigned id = original_temp_id(ctx, op.getTemp());
2496 if (!ctx.info[id].is_constant_or_literal(bit_size))
2497 return false;
2498 *value = get_constant_op(ctx, ctx.info[id], bit_size).constantValue64();
2499 return true;
2500 }
2501 return false;
2502 }
2503
/* Return whether `value`, read as an IEEE-754 float of `bit_size` bits, encodes a NaN: the
 * exponent field is all ones and the mantissa is non-zero. bit_size 16 and 32 are handled
 * explicitly; any other size is treated as 64-bit. The sign bit and any bits above the
 * exponent are ignored. */
bool
is_constant_nan(uint64_t value, unsigned bit_size)
{
   unsigned mantissa_bits;
   unsigned exponent_bits;
   switch (bit_size) {
   case 16:
      mantissa_bits = 10;
      exponent_bits = 5;
      break;
   case 32:
      mantissa_bits = 23;
      exponent_bits = 8;
      break;
   default:
      mantissa_bits = 52;
      exponent_bits = 11;
      break;
   }

   const uint64_t mantissa_mask = (UINT64_C(1) << mantissa_bits) - 1;
   const uint64_t exponent_mask = (UINT64_C(1) << exponent_bits) - 1;
   const uint64_t exponent = (value >> mantissa_bits) & exponent_mask;
   return exponent == exponent_mask && (value & mantissa_mask) != 0;
}
2514
2515 /* s_or_b64(v_cmp_neq_f32(a, a), cmp(a, #b)) and b is not NaN -> get_unordered(cmp)(a, b)
2516 * s_and_b64(v_cmp_eq_f32(a, a), cmp(a, #b)) and b is not NaN -> get_ordered(cmp)(a, b) */
2517 bool
combine_constant_comparison_ordering(opt_ctx & ctx,aco_ptr<Instruction> & instr)2518 combine_constant_comparison_ordering(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2519 {
2520 if (instr->definitions[0].regClass() != ctx.program->lane_mask)
2521 return false;
2522 if (instr->definitions[1].isTemp() && ctx.uses[instr->definitions[1].tempId()])
2523 return false;
2524
2525 bool is_or = instr->opcode == aco_opcode::s_or_b64 || instr->opcode == aco_opcode::s_or_b32;
2526
2527 Instruction* nan_test = follow_operand(ctx, instr->operands[0], true);
2528 Instruction* cmp = follow_operand(ctx, instr->operands[1], true);
2529
2530 if (!nan_test || !cmp || nan_test->isSDWA() || cmp->isSDWA() || nan_test->isDPP() ||
2531 cmp->isDPP())
2532 return false;
2533
2534 aco_opcode expected_nan_test = is_or ? aco_opcode::v_cmp_neq_f32 : aco_opcode::v_cmp_eq_f32;
2535 if (get_f32_cmp(cmp->opcode) == expected_nan_test)
2536 std::swap(nan_test, cmp);
2537 else if (get_f32_cmp(nan_test->opcode) != expected_nan_test)
2538 return false;
2539
2540 unsigned bit_size = get_cmp_bitsize(cmp->opcode);
2541 if (!is_fp_cmp(cmp->opcode) || get_cmp_bitsize(nan_test->opcode) != bit_size)
2542 return false;
2543
2544 if (!nan_test->operands[0].isTemp() || !nan_test->operands[1].isTemp())
2545 return false;
2546 if (!cmp->operands[0].isTemp() && !cmp->operands[1].isTemp())
2547 return false;
2548
2549 unsigned prop_nan0 = original_temp_id(ctx, nan_test->operands[0].getTemp());
2550 unsigned prop_nan1 = original_temp_id(ctx, nan_test->operands[1].getTemp());
2551 if (prop_nan0 != prop_nan1)
2552 return false;
2553
2554 VALU_instruction& vop3 = nan_test->valu();
2555 if (vop3.neg[0] != vop3.neg[1] || vop3.abs[0] != vop3.abs[1] || vop3.opsel[0] != vop3.opsel[1])
2556 return false;
2557
2558 int constant_operand = -1;
2559 for (unsigned i = 0; i < 2; i++) {
2560 if (cmp->operands[i].isTemp() &&
2561 original_temp_id(ctx, cmp->operands[i].getTemp()) == prop_nan0 &&
2562 cmp->valu().opsel[i] == nan_test->valu().opsel[0]) {
2563 constant_operand = !i;
2564 break;
2565 }
2566 }
2567 if (constant_operand == -1)
2568 return false;
2569
2570 uint64_t constant_value;
2571 if (!is_operand_constant(ctx, cmp->operands[constant_operand], bit_size, &constant_value))
2572 return false;
2573 if (is_constant_nan(constant_value >> (cmp->valu().opsel[constant_operand] * 16), bit_size))
2574 return false;
2575
2576 aco_opcode new_op = is_or ? get_unordered(cmp->opcode) : get_ordered(cmp->opcode);
2577 Instruction* new_instr = create_instruction<VALU_instruction>(new_op, cmp->format, 2, 1);
2578 new_instr->valu().neg = cmp->valu().neg;
2579 new_instr->valu().abs = cmp->valu().abs;
2580 new_instr->valu().clamp = cmp->valu().clamp;
2581 new_instr->valu().omod = cmp->valu().omod;
2582 new_instr->valu().opsel = cmp->valu().opsel;
2583 new_instr->operands[0] = copy_operand(ctx, cmp->operands[0]);
2584 new_instr->operands[1] = copy_operand(ctx, cmp->operands[1]);
2585 new_instr->definitions[0] = instr->definitions[0];
2586 new_instr->pass_flags = instr->pass_flags;
2587
2588 decrease_uses(ctx, nan_test);
2589 decrease_uses(ctx, cmp);
2590
2591 ctx.info[instr->definitions[0].tempId()].label = 0;
2592 ctx.info[instr->definitions[0].tempId()].set_vopc(new_instr);
2593
2594 instr.reset(new_instr);
2595
2596 return true;
2597 }
2598
2599 /* s_not(cmp(a, b)) -> get_inverse(cmp)(a, b) */
2600 bool
combine_inverse_comparison(opt_ctx & ctx,aco_ptr<Instruction> & instr)2601 combine_inverse_comparison(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2602 {
2603 if (ctx.uses[instr->definitions[1].tempId()])
2604 return false;
2605 if (!instr->operands[0].isTemp() || ctx.uses[instr->operands[0].tempId()] != 1)
2606 return false;
2607
2608 Instruction* cmp = follow_operand(ctx, instr->operands[0]);
2609 if (!cmp)
2610 return false;
2611
2612 aco_opcode new_opcode = get_inverse(cmp->opcode);
2613 if (new_opcode == aco_opcode::num_opcodes)
2614 return false;
2615
2616 /* Invert compare instruction and assign this instruction's definition */
2617 cmp->opcode = new_opcode;
2618 ctx.info[instr->definitions[0].tempId()] = ctx.info[cmp->definitions[0].tempId()];
2619 std::swap(instr->definitions[0], cmp->definitions[0]);
2620
2621 ctx.uses[instr->operands[0].tempId()]--;
2622 return true;
2623 }
2624
2625 /* op1(op2(1, 2), 0) if swap = false
2626 * op1(0, op2(1, 2)) if swap = true */
2627 bool
match_op3_for_vop3(opt_ctx & ctx,aco_opcode op1,aco_opcode op2,Instruction * op1_instr,bool swap,const char * shuffle_str,Operand operands[3],bitarray8 & neg,bitarray8 & abs,bitarray8 & opsel,bool * op1_clamp,uint8_t * op1_omod,bool * inbetween_neg,bool * inbetween_abs,bool * inbetween_opsel,bool * precise)2628 match_op3_for_vop3(opt_ctx& ctx, aco_opcode op1, aco_opcode op2, Instruction* op1_instr, bool swap,
2629 const char* shuffle_str, Operand operands[3], bitarray8& neg, bitarray8& abs,
2630 bitarray8& opsel, bool* op1_clamp, uint8_t* op1_omod, bool* inbetween_neg,
2631 bool* inbetween_abs, bool* inbetween_opsel, bool* precise)
2632 {
2633 /* checks */
2634 if (op1_instr->opcode != op1)
2635 return false;
2636
2637 Instruction* op2_instr = follow_operand(ctx, op1_instr->operands[swap]);
2638 if (!op2_instr || op2_instr->opcode != op2)
2639 return false;
2640
2641 VALU_instruction* op1_valu = op1_instr->isVALU() ? &op1_instr->valu() : NULL;
2642 VALU_instruction* op2_valu = op2_instr->isVALU() ? &op2_instr->valu() : NULL;
2643
2644 if (op1_instr->isSDWA() || op2_instr->isSDWA())
2645 return false;
2646 if (op1_instr->isDPP() || op2_instr->isDPP())
2647 return false;
2648
2649 /* don't support inbetween clamp/omod */
2650 if (op2_valu && (op2_valu->clamp || op2_valu->omod))
2651 return false;
2652
2653 /* get operands and modifiers and check inbetween modifiers */
2654 *op1_clamp = op1_valu ? (bool)op1_valu->clamp : false;
2655 *op1_omod = op1_valu ? (unsigned)op1_valu->omod : 0u;
2656
2657 if (inbetween_neg)
2658 *inbetween_neg = op1_valu ? op1_valu->neg[swap] : false;
2659 else if (op1_valu && op1_valu->neg[swap])
2660 return false;
2661
2662 if (inbetween_abs)
2663 *inbetween_abs = op1_valu ? op1_valu->abs[swap] : false;
2664 else if (op1_valu && op1_valu->abs[swap])
2665 return false;
2666
2667 if (inbetween_opsel)
2668 *inbetween_opsel = op1_valu ? op1_valu->opsel[swap] : false;
2669 else if (op1_valu && op1_valu->opsel[swap])
2670 return false;
2671
2672 *precise = op1_instr->definitions[0].isPrecise() || op2_instr->definitions[0].isPrecise();
2673
2674 int shuffle[3];
2675 shuffle[shuffle_str[0] - '0'] = 0;
2676 shuffle[shuffle_str[1] - '0'] = 1;
2677 shuffle[shuffle_str[2] - '0'] = 2;
2678
2679 operands[shuffle[0]] = op1_instr->operands[!swap];
2680 neg[shuffle[0]] = op1_valu ? op1_valu->neg[!swap] : false;
2681 abs[shuffle[0]] = op1_valu ? op1_valu->abs[!swap] : false;
2682 opsel[shuffle[0]] = op1_valu ? op1_valu->opsel[!swap] : false;
2683
2684 for (unsigned i = 0; i < 2; i++) {
2685 operands[shuffle[i + 1]] = op2_instr->operands[i];
2686 neg[shuffle[i + 1]] = op2_valu ? op2_valu->neg[i] : false;
2687 abs[shuffle[i + 1]] = op2_valu ? op2_valu->abs[i] : false;
2688 opsel[shuffle[i + 1]] = op2_valu ? op2_valu->opsel[i] : false;
2689 }
2690
2691 /* check operands */
2692 if (!check_vop3_operands(ctx, 3, operands))
2693 return false;
2694
2695 return true;
2696 }
2697
2698 void
create_vop3_for_op3(opt_ctx & ctx,aco_opcode opcode,aco_ptr<Instruction> & instr,Operand operands[3],uint8_t neg,uint8_t abs,uint8_t opsel,bool clamp,unsigned omod)2699 create_vop3_for_op3(opt_ctx& ctx, aco_opcode opcode, aco_ptr<Instruction>& instr,
2700 Operand operands[3], uint8_t neg, uint8_t abs, uint8_t opsel, bool clamp,
2701 unsigned omod)
2702 {
2703 VALU_instruction* new_instr = create_instruction<VALU_instruction>(opcode, Format::VOP3, 3, 1);
2704 new_instr->neg = neg;
2705 new_instr->abs = abs;
2706 new_instr->clamp = clamp;
2707 new_instr->omod = omod;
2708 new_instr->opsel = opsel;
2709 new_instr->operands[0] = operands[0];
2710 new_instr->operands[1] = operands[1];
2711 new_instr->operands[2] = operands[2];
2712 new_instr->definitions[0] = instr->definitions[0];
2713 new_instr->pass_flags = instr->pass_flags;
2714 ctx.info[instr->definitions[0].tempId()].label = 0;
2715
2716 instr.reset(new_instr);
2717 }
2718
2719 bool
combine_three_valu_op(opt_ctx & ctx,aco_ptr<Instruction> & instr,aco_opcode op2,aco_opcode new_op,const char * shuffle,uint8_t ops)2720 combine_three_valu_op(opt_ctx& ctx, aco_ptr<Instruction>& instr, aco_opcode op2, aco_opcode new_op,
2721 const char* shuffle, uint8_t ops)
2722 {
2723 for (unsigned swap = 0; swap < 2; swap++) {
2724 if (!((1 << swap) & ops))
2725 continue;
2726
2727 Operand operands[3];
2728 bool clamp, precise;
2729 bitarray8 neg = 0, abs = 0, opsel = 0;
2730 uint8_t omod = 0;
2731 if (match_op3_for_vop3(ctx, instr->opcode, op2, instr.get(), swap, shuffle, operands, neg,
2732 abs, opsel, &clamp, &omod, NULL, NULL, NULL, &precise)) {
2733 ctx.uses[instr->operands[swap].tempId()]--;
2734 create_vop3_for_op3(ctx, new_op, instr, operands, neg, abs, opsel, clamp, omod);
2735 return true;
2736 }
2737 }
2738 return false;
2739 }
2740
2741 /* creates v_lshl_add_u32, v_lshl_or_b32 or v_and_or_b32 */
2742 bool
combine_add_or_then_and_lshl(opt_ctx & ctx,aco_ptr<Instruction> & instr)2743 combine_add_or_then_and_lshl(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2744 {
2745 bool is_or = instr->opcode == aco_opcode::v_or_b32;
2746 aco_opcode new_op_lshl = is_or ? aco_opcode::v_lshl_or_b32 : aco_opcode::v_lshl_add_u32;
2747
2748 if (is_or && combine_three_valu_op(ctx, instr, aco_opcode::s_and_b32, aco_opcode::v_and_or_b32,
2749 "120", 1 | 2))
2750 return true;
2751 if (is_or && combine_three_valu_op(ctx, instr, aco_opcode::v_and_b32, aco_opcode::v_and_or_b32,
2752 "120", 1 | 2))
2753 return true;
2754 if (combine_three_valu_op(ctx, instr, aco_opcode::s_lshl_b32, new_op_lshl, "120", 1 | 2))
2755 return true;
2756 if (combine_three_valu_op(ctx, instr, aco_opcode::v_lshlrev_b32, new_op_lshl, "210", 1 | 2))
2757 return true;
2758
2759 if (instr->isSDWA() || instr->isDPP())
2760 return false;
2761
2762 /* v_or_b32(p_extract(a, 0, 8/16, 0), b) -> v_and_or_b32(a, 0xff/0xffff, b)
2763 * v_or_b32(p_insert(a, 0, 8/16), b) -> v_and_or_b32(a, 0xff/0xffff, b)
2764 * v_or_b32(p_insert(a, 24/16, 8/16), b) -> v_lshl_or_b32(a, 24/16, b)
2765 * v_add_u32(p_insert(a, 24/16, 8/16), b) -> v_lshl_add_b32(a, 24/16, b)
2766 */
2767 for (unsigned i = 0; i < 2; i++) {
2768 Instruction* extins = follow_operand(ctx, instr->operands[i]);
2769 if (!extins)
2770 continue;
2771
2772 aco_opcode op;
2773 Operand operands[3];
2774
2775 if (extins->opcode == aco_opcode::p_insert &&
2776 (extins->operands[1].constantValue() + 1) * extins->operands[2].constantValue() == 32) {
2777 op = new_op_lshl;
2778 operands[1] =
2779 Operand::c32(extins->operands[1].constantValue() * extins->operands[2].constantValue());
2780 } else if (is_or &&
2781 (extins->opcode == aco_opcode::p_insert ||
2782 (extins->opcode == aco_opcode::p_extract &&
2783 extins->operands[3].constantEquals(0))) &&
2784 extins->operands[1].constantEquals(0)) {
2785 op = aco_opcode::v_and_or_b32;
2786 operands[1] = Operand::c32(extins->operands[2].constantEquals(8) ? 0xffu : 0xffffu);
2787 } else {
2788 continue;
2789 }
2790
2791 operands[0] = extins->operands[0];
2792 operands[2] = instr->operands[!i];
2793
2794 if (!check_vop3_operands(ctx, 3, operands))
2795 continue;
2796
2797 uint8_t neg = 0, abs = 0, opsel = 0, omod = 0;
2798 bool clamp = false;
2799 if (instr->isVOP3())
2800 clamp = instr->valu().clamp;
2801
2802 ctx.uses[instr->operands[i].tempId()]--;
2803 create_vop3_for_op3(ctx, op, instr, operands, neg, abs, opsel, clamp, omod);
2804 return true;
2805 }
2806
2807 return false;
2808 }
2809
2810 /* v_xor(a, s_not(b)) -> v_xnor(a, b)
2811 * v_xor(a, v_not(b)) -> v_xnor(a, b)
2812 */
2813 bool
combine_xor_not(opt_ctx & ctx,aco_ptr<Instruction> & instr)2814 combine_xor_not(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2815 {
2816 if (instr->usesModifiers())
2817 return false;
2818
2819 for (unsigned i = 0; i < 2; i++) {
2820 Instruction* op_instr = follow_operand(ctx, instr->operands[i], true);
2821 if (!op_instr ||
2822 (op_instr->opcode != aco_opcode::v_not_b32 &&
2823 op_instr->opcode != aco_opcode::s_not_b32) ||
2824 op_instr->usesModifiers() || op_instr->operands[0].isLiteral())
2825 continue;
2826
2827 instr->opcode = aco_opcode::v_xnor_b32;
2828 instr->operands[i] = copy_operand(ctx, op_instr->operands[0]);
2829 decrease_uses(ctx, op_instr);
2830 if (instr->operands[0].isOfType(RegType::vgpr))
2831 std::swap(instr->operands[0], instr->operands[1]);
2832 if (!instr->operands[1].isOfType(RegType::vgpr))
2833 instr->format = asVOP3(instr->format);
2834
2835 return true;
2836 }
2837
2838 return false;
2839 }
2840
2841 /* v_not(v_xor(a, b)) -> v_xnor(a, b) */
2842 bool
combine_not_xor(opt_ctx & ctx,aco_ptr<Instruction> & instr)2843 combine_not_xor(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2844 {
2845 if (instr->usesModifiers())
2846 return false;
2847
2848 Instruction* op_instr = follow_operand(ctx, instr->operands[0]);
2849 if (!op_instr || op_instr->opcode != aco_opcode::v_xor_b32 || op_instr->isSDWA())
2850 return false;
2851
2852 ctx.uses[instr->operands[0].tempId()]--;
2853 std::swap(instr->definitions[0], op_instr->definitions[0]);
2854 op_instr->opcode = aco_opcode::v_xnor_b32;
2855
2856 return true;
2857 }
2858
2859 bool
combine_minmax(opt_ctx & ctx,aco_ptr<Instruction> & instr,aco_opcode opposite,aco_opcode op3src,aco_opcode minmax)2860 combine_minmax(opt_ctx& ctx, aco_ptr<Instruction>& instr, aco_opcode opposite, aco_opcode op3src,
2861 aco_opcode minmax)
2862 {
2863 /* TODO: this can handle SDWA min/max instructions by using opsel */
2864
2865 /* min(min(a, b), c) -> min3(a, b, c)
2866 * max(max(a, b), c) -> max3(a, b, c)
2867 * gfx11: min(-min(a, b), c) -> maxmin(-a, -b, c)
2868 * gfx11: max(-max(a, b), c) -> minmax(-a, -b, c)
2869 */
2870 for (unsigned swap = 0; swap < 2; swap++) {
2871 Operand operands[3];
2872 bool clamp, precise;
2873 bitarray8 opsel = 0, neg = 0, abs = 0;
2874 uint8_t omod = 0;
2875 bool inbetween_neg;
2876 if (match_op3_for_vop3(ctx, instr->opcode, instr->opcode, instr.get(), swap, "120", operands,
2877 neg, abs, opsel, &clamp, &omod, &inbetween_neg, NULL, NULL,
2878 &precise) &&
2879 (!inbetween_neg ||
2880 (minmax != aco_opcode::num_opcodes && ctx.program->gfx_level >= GFX11))) {
2881 ctx.uses[instr->operands[swap].tempId()]--;
2882 if (inbetween_neg) {
2883 neg[0] = !neg[0];
2884 neg[1] = !neg[1];
2885 create_vop3_for_op3(ctx, minmax, instr, operands, neg, abs, opsel, clamp, omod);
2886 } else {
2887 create_vop3_for_op3(ctx, op3src, instr, operands, neg, abs, opsel, clamp, omod);
2888 }
2889 return true;
2890 }
2891 }
2892
2893 /* min(-max(a, b), c) -> min3(-a, -b, c)
2894 * max(-min(a, b), c) -> max3(-a, -b, c)
2895 * gfx11: min(max(a, b), c) -> maxmin(a, b, c)
2896 * gfx11: max(min(a, b), c) -> minmax(a, b, c)
2897 */
2898 for (unsigned swap = 0; swap < 2; swap++) {
2899 Operand operands[3];
2900 bool clamp, precise;
2901 bitarray8 opsel = 0, neg = 0, abs = 0;
2902 uint8_t omod = 0;
2903 bool inbetween_neg;
2904 if (match_op3_for_vop3(ctx, instr->opcode, opposite, instr.get(), swap, "120", operands, neg,
2905 abs, opsel, &clamp, &omod, &inbetween_neg, NULL, NULL, &precise) &&
2906 (inbetween_neg ||
2907 (minmax != aco_opcode::num_opcodes && ctx.program->gfx_level >= GFX11))) {
2908 ctx.uses[instr->operands[swap].tempId()]--;
2909 if (inbetween_neg) {
2910 neg[0] = !neg[0];
2911 neg[1] = !neg[1];
2912 create_vop3_for_op3(ctx, op3src, instr, operands, neg, abs, opsel, clamp, omod);
2913 } else {
2914 create_vop3_for_op3(ctx, minmax, instr, operands, neg, abs, opsel, clamp, omod);
2915 }
2916 return true;
2917 }
2918 }
2919 return false;
2920 }
2921
2922 /* s_not_b32(s_and_b32(a, b)) -> s_nand_b32(a, b)
2923 * s_not_b32(s_or_b32(a, b)) -> s_nor_b32(a, b)
2924 * s_not_b32(s_xor_b32(a, b)) -> s_xnor_b32(a, b)
2925 * s_not_b64(s_and_b64(a, b)) -> s_nand_b64(a, b)
2926 * s_not_b64(s_or_b64(a, b)) -> s_nor_b64(a, b)
2927 * s_not_b64(s_xor_b64(a, b)) -> s_xnor_b64(a, b) */
2928 bool
combine_salu_not_bitwise(opt_ctx & ctx,aco_ptr<Instruction> & instr)2929 combine_salu_not_bitwise(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2930 {
2931 /* checks */
2932 if (!instr->operands[0].isTemp())
2933 return false;
2934 if (instr->definitions[1].isTemp() && ctx.uses[instr->definitions[1].tempId()])
2935 return false;
2936
2937 Instruction* op2_instr = follow_operand(ctx, instr->operands[0]);
2938 if (!op2_instr)
2939 return false;
2940 switch (op2_instr->opcode) {
2941 case aco_opcode::s_and_b32:
2942 case aco_opcode::s_or_b32:
2943 case aco_opcode::s_xor_b32:
2944 case aco_opcode::s_and_b64:
2945 case aco_opcode::s_or_b64:
2946 case aco_opcode::s_xor_b64: break;
2947 default: return false;
2948 }
2949
2950 /* create instruction */
2951 std::swap(instr->definitions[0], op2_instr->definitions[0]);
2952 std::swap(instr->definitions[1], op2_instr->definitions[1]);
2953 ctx.uses[instr->operands[0].tempId()]--;
2954 ctx.info[op2_instr->definitions[0].tempId()].label = 0;
2955
2956 switch (op2_instr->opcode) {
2957 case aco_opcode::s_and_b32: op2_instr->opcode = aco_opcode::s_nand_b32; break;
2958 case aco_opcode::s_or_b32: op2_instr->opcode = aco_opcode::s_nor_b32; break;
2959 case aco_opcode::s_xor_b32: op2_instr->opcode = aco_opcode::s_xnor_b32; break;
2960 case aco_opcode::s_and_b64: op2_instr->opcode = aco_opcode::s_nand_b64; break;
2961 case aco_opcode::s_or_b64: op2_instr->opcode = aco_opcode::s_nor_b64; break;
2962 case aco_opcode::s_xor_b64: op2_instr->opcode = aco_opcode::s_xnor_b64; break;
2963 default: break;
2964 }
2965
2966 return true;
2967 }
2968
2969 /* s_and_b32(a, s_not_b32(b)) -> s_andn2_b32(a, b)
2970 * s_or_b32(a, s_not_b32(b)) -> s_orn2_b32(a, b)
2971 * s_and_b64(a, s_not_b64(b)) -> s_andn2_b64(a, b)
2972 * s_or_b64(a, s_not_b64(b)) -> s_orn2_b64(a, b) */
2973 bool
combine_salu_n2(opt_ctx & ctx,aco_ptr<Instruction> & instr)2974 combine_salu_n2(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2975 {
2976 if (instr->definitions[0].isTemp() && ctx.info[instr->definitions[0].tempId()].is_uniform_bool())
2977 return false;
2978
2979 for (unsigned i = 0; i < 2; i++) {
2980 Instruction* op2_instr = follow_operand(ctx, instr->operands[i]);
2981 if (!op2_instr || (op2_instr->opcode != aco_opcode::s_not_b32 &&
2982 op2_instr->opcode != aco_opcode::s_not_b64))
2983 continue;
2984 if (ctx.uses[op2_instr->definitions[1].tempId()])
2985 continue;
2986
2987 if (instr->operands[!i].isLiteral() && op2_instr->operands[0].isLiteral() &&
2988 instr->operands[!i].constantValue() != op2_instr->operands[0].constantValue())
2989 continue;
2990
2991 ctx.uses[instr->operands[i].tempId()]--;
2992 instr->operands[0] = instr->operands[!i];
2993 instr->operands[1] = op2_instr->operands[0];
2994 ctx.info[instr->definitions[0].tempId()].label = 0;
2995
2996 switch (instr->opcode) {
2997 case aco_opcode::s_and_b32: instr->opcode = aco_opcode::s_andn2_b32; break;
2998 case aco_opcode::s_or_b32: instr->opcode = aco_opcode::s_orn2_b32; break;
2999 case aco_opcode::s_and_b64: instr->opcode = aco_opcode::s_andn2_b64; break;
3000 case aco_opcode::s_or_b64: instr->opcode = aco_opcode::s_orn2_b64; break;
3001 default: break;
3002 }
3003
3004 return true;
3005 }
3006 return false;
3007 }
3008
3009 /* s_add_{i32,u32}(a, s_lshl_b32(b, <n>)) -> s_lshl<n>_add_u32(a, b) */
3010 bool
combine_salu_lshl_add(opt_ctx & ctx,aco_ptr<Instruction> & instr)3011 combine_salu_lshl_add(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3012 {
3013 if (instr->opcode == aco_opcode::s_add_i32 && ctx.uses[instr->definitions[1].tempId()])
3014 return false;
3015
3016 for (unsigned i = 0; i < 2; i++) {
3017 Instruction* op2_instr = follow_operand(ctx, instr->operands[i], true);
3018 if (!op2_instr || op2_instr->opcode != aco_opcode::s_lshl_b32 ||
3019 ctx.uses[op2_instr->definitions[1].tempId()])
3020 continue;
3021 if (!op2_instr->operands[1].isConstant())
3022 continue;
3023
3024 uint32_t shift = op2_instr->operands[1].constantValue();
3025 if (shift < 1 || shift > 4)
3026 continue;
3027
3028 if (instr->operands[!i].isLiteral() && op2_instr->operands[0].isLiteral() &&
3029 instr->operands[!i].constantValue() != op2_instr->operands[0].constantValue())
3030 continue;
3031
3032 instr->operands[1] = instr->operands[!i];
3033 instr->operands[0] = copy_operand(ctx, op2_instr->operands[0]);
3034 decrease_uses(ctx, op2_instr);
3035 ctx.info[instr->definitions[0].tempId()].label = 0;
3036
3037 instr->opcode = std::array<aco_opcode, 4>{
3038 aco_opcode::s_lshl1_add_u32, aco_opcode::s_lshl2_add_u32, aco_opcode::s_lshl3_add_u32,
3039 aco_opcode::s_lshl4_add_u32}[shift - 1];
3040
3041 return true;
3042 }
3043 return false;
3044 }
3045
3046 /* s_abs_i32(s_sub_[iu]32(a, b)) -> s_absdiff_i32(a, b)
3047 * s_abs_i32(s_add_[iu]32(a, #b)) -> s_absdiff_i32(a, -b)
3048 */
3049 bool
combine_sabsdiff(opt_ctx & ctx,aco_ptr<Instruction> & instr)3050 combine_sabsdiff(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3051 {
3052 if (!instr->operands[0].isTemp() || !ctx.info[instr->operands[0].tempId()].is_add_sub())
3053 return false;
3054
3055 Instruction* op_instr = follow_operand(ctx, instr->operands[0], false);
3056 if (!op_instr)
3057 return false;
3058
3059 if (op_instr->opcode == aco_opcode::s_add_i32 || op_instr->opcode == aco_opcode::s_add_u32) {
3060 for (unsigned i = 0; i < 2; i++) {
3061 uint64_t constant;
3062 if (op_instr->operands[!i].isLiteral() ||
3063 !is_operand_constant(ctx, op_instr->operands[i], 32, &constant))
3064 continue;
3065
3066 if (op_instr->operands[i].isTemp())
3067 ctx.uses[op_instr->operands[i].tempId()]--;
3068 op_instr->operands[0] = op_instr->operands[!i];
3069 op_instr->operands[1] = Operand::c32(-int32_t(constant));
3070 goto use_absdiff;
3071 }
3072 return false;
3073 }
3074
3075 use_absdiff:
3076 op_instr->opcode = aco_opcode::s_absdiff_i32;
3077 std::swap(instr->definitions[0], op_instr->definitions[0]);
3078 std::swap(instr->definitions[1], op_instr->definitions[1]);
3079 ctx.uses[instr->operands[0].tempId()]--;
3080
3081 return true;
3082 }
3083
3084 bool
combine_add_sub_b2i(opt_ctx & ctx,aco_ptr<Instruction> & instr,aco_opcode new_op,uint8_t ops)3085 combine_add_sub_b2i(opt_ctx& ctx, aco_ptr<Instruction>& instr, aco_opcode new_op, uint8_t ops)
3086 {
3087 if (instr->usesModifiers())
3088 return false;
3089
3090 for (unsigned i = 0; i < 2; i++) {
3091 if (!((1 << i) & ops))
3092 continue;
3093 if (instr->operands[i].isTemp() && ctx.info[instr->operands[i].tempId()].is_b2i() &&
3094 ctx.uses[instr->operands[i].tempId()] == 1) {
3095
3096 aco_ptr<Instruction> new_instr;
3097 if (instr->operands[!i].isTemp() &&
3098 instr->operands[!i].getTemp().type() == RegType::vgpr) {
3099 new_instr.reset(create_instruction<VALU_instruction>(new_op, Format::VOP2, 3, 2));
3100 } else if (ctx.program->gfx_level >= GFX10 ||
3101 (instr->operands[!i].isConstant() && !instr->operands[!i].isLiteral())) {
3102 new_instr.reset(
3103 create_instruction<VALU_instruction>(new_op, asVOP3(Format::VOP2), 3, 2));
3104 } else {
3105 return false;
3106 }
3107 ctx.uses[instr->operands[i].tempId()]--;
3108 new_instr->definitions[0] = instr->definitions[0];
3109 if (instr->definitions.size() == 2) {
3110 new_instr->definitions[1] = instr->definitions[1];
3111 } else {
3112 new_instr->definitions[1] =
3113 Definition(ctx.program->allocateTmp(ctx.program->lane_mask));
3114 /* Make sure the uses vector is large enough and the number of
3115 * uses properly initialized to 0.
3116 */
3117 ctx.uses.push_back(0);
3118 }
3119 new_instr->operands[0] = Operand::zero();
3120 new_instr->operands[1] = instr->operands[!i];
3121 new_instr->operands[2] = Operand(ctx.info[instr->operands[i].tempId()].temp);
3122 new_instr->pass_flags = instr->pass_flags;
3123 instr = std::move(new_instr);
3124 ctx.info[instr->definitions[0].tempId()].set_add_sub(instr.get());
3125 return true;
3126 }
3127 }
3128
3129 return false;
3130 }
3131
3132 bool
combine_add_bcnt(opt_ctx & ctx,aco_ptr<Instruction> & instr)3133 combine_add_bcnt(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3134 {
3135 if (instr->usesModifiers())
3136 return false;
3137
3138 for (unsigned i = 0; i < 2; i++) {
3139 Instruction* op_instr = follow_operand(ctx, instr->operands[i]);
3140 if (op_instr && op_instr->opcode == aco_opcode::v_bcnt_u32_b32 &&
3141 !op_instr->usesModifiers() && op_instr->operands[0].isTemp() &&
3142 op_instr->operands[0].getTemp().type() == RegType::vgpr &&
3143 op_instr->operands[1].constantEquals(0)) {
3144 aco_ptr<Instruction> new_instr{
3145 create_instruction<VALU_instruction>(aco_opcode::v_bcnt_u32_b32, Format::VOP3, 2, 1)};
3146 ctx.uses[instr->operands[i].tempId()]--;
3147 new_instr->operands[0] = op_instr->operands[0];
3148 new_instr->operands[1] = instr->operands[!i];
3149 new_instr->definitions[0] = instr->definitions[0];
3150 new_instr->pass_flags = instr->pass_flags;
3151 instr = std::move(new_instr);
3152 ctx.info[instr->definitions[0].tempId()].label = 0;
3153
3154 return true;
3155 }
3156 }
3157
3158 return false;
3159 }
3160
3161 bool
get_minmax_info(aco_opcode op,aco_opcode * min,aco_opcode * max,aco_opcode * min3,aco_opcode * max3,aco_opcode * med3,aco_opcode * minmax,bool * some_gfx9_only)3162 get_minmax_info(aco_opcode op, aco_opcode* min, aco_opcode* max, aco_opcode* min3, aco_opcode* max3,
3163 aco_opcode* med3, aco_opcode* minmax, bool* some_gfx9_only)
3164 {
3165 switch (op) {
3166 #define MINMAX(type, gfx9) \
3167 case aco_opcode::v_min_##type: \
3168 case aco_opcode::v_max_##type: \
3169 *min = aco_opcode::v_min_##type; \
3170 *max = aco_opcode::v_max_##type; \
3171 *med3 = aco_opcode::v_med3_##type; \
3172 *min3 = aco_opcode::v_min3_##type; \
3173 *max3 = aco_opcode::v_max3_##type; \
3174 *minmax = op == *min ? aco_opcode::v_maxmin_##type : aco_opcode::v_minmax_##type; \
3175 *some_gfx9_only = gfx9; \
3176 return true;
3177 #define MINMAX_INT16(type, gfx9) \
3178 case aco_opcode::v_min_##type: \
3179 case aco_opcode::v_max_##type: \
3180 *min = aco_opcode::v_min_##type; \
3181 *max = aco_opcode::v_max_##type; \
3182 *med3 = aco_opcode::v_med3_##type; \
3183 *min3 = aco_opcode::v_min3_##type; \
3184 *max3 = aco_opcode::v_max3_##type; \
3185 *minmax = aco_opcode::num_opcodes; \
3186 *some_gfx9_only = gfx9; \
3187 return true;
3188 #define MINMAX_INT16_E64(type, gfx9) \
3189 case aco_opcode::v_min_##type##_e64: \
3190 case aco_opcode::v_max_##type##_e64: \
3191 *min = aco_opcode::v_min_##type##_e64; \
3192 *max = aco_opcode::v_max_##type##_e64; \
3193 *med3 = aco_opcode::v_med3_##type; \
3194 *min3 = aco_opcode::v_min3_##type; \
3195 *max3 = aco_opcode::v_max3_##type; \
3196 *minmax = aco_opcode::num_opcodes; \
3197 *some_gfx9_only = gfx9; \
3198 return true;
3199 MINMAX(f32, false)
3200 MINMAX(u32, false)
3201 MINMAX(i32, false)
3202 MINMAX(f16, true)
3203 MINMAX_INT16(u16, true)
3204 MINMAX_INT16(i16, true)
3205 MINMAX_INT16_E64(u16, true)
3206 MINMAX_INT16_E64(i16, true)
3207 #undef MINMAX_INT16_E64
3208 #undef MINMAX_INT16
3209 #undef MINMAX
3210 default: return false;
3211 }
3212 }
3213
3214 /* when ub > lb:
3215 * v_min_{f,u,i}{16,32}(v_max_{f,u,i}{16,32}(a, lb), ub) -> v_med3_{f,u,i}{16,32}(a, lb, ub)
3216 * v_max_{f,u,i}{16,32}(v_min_{f,u,i}{16,32}(a, ub), lb) -> v_med3_{f,u,i}{16,32}(a, lb, ub)
3217 */
3218 bool
combine_clamp(opt_ctx & ctx,aco_ptr<Instruction> & instr,aco_opcode min,aco_opcode max,aco_opcode med)3219 combine_clamp(opt_ctx& ctx, aco_ptr<Instruction>& instr, aco_opcode min, aco_opcode max,
3220 aco_opcode med)
3221 {
3222 /* TODO: GLSL's clamp(x, minVal, maxVal) and SPIR-V's
3223 * FClamp(x, minVal, maxVal)/NClamp(x, minVal, maxVal) are undefined if
3224 * minVal > maxVal, which means we can always select it to a v_med3_f32 */
3225 aco_opcode other_op;
3226 if (instr->opcode == min)
3227 other_op = max;
3228 else if (instr->opcode == max)
3229 other_op = min;
3230 else
3231 return false;
3232
3233 for (unsigned swap = 0; swap < 2; swap++) {
3234 Operand operands[3];
3235 bool clamp, precise;
3236 bitarray8 opsel = 0, neg = 0, abs = 0;
3237 uint8_t omod = 0;
3238 if (match_op3_for_vop3(ctx, instr->opcode, other_op, instr.get(), swap, "012", operands, neg,
3239 abs, opsel, &clamp, &omod, NULL, NULL, NULL, &precise)) {
3240 /* max(min(src, upper), lower) returns upper if src is NaN, but
3241 * med3(src, lower, upper) returns lower.
3242 */
3243 if (precise && instr->opcode != min &&
3244 (min == aco_opcode::v_min_f16 || min == aco_opcode::v_min_f32))
3245 continue;
3246
3247 int const0_idx = -1, const1_idx = -1;
3248 uint32_t const0 = 0, const1 = 0;
3249 for (int i = 0; i < 3; i++) {
3250 uint32_t val;
3251 bool hi16 = opsel & (1 << i);
3252 if (operands[i].isConstant()) {
3253 val = hi16 ? operands[i].constantValue16(true) : operands[i].constantValue();
3254 } else if (operands[i].isTemp() &&
3255 ctx.info[operands[i].tempId()].is_constant_or_literal(32)) {
3256 val = ctx.info[operands[i].tempId()].val >> (hi16 ? 16 : 0);
3257 } else {
3258 continue;
3259 }
3260 if (const0_idx >= 0) {
3261 const1_idx = i;
3262 const1 = val;
3263 } else {
3264 const0_idx = i;
3265 const0 = val;
3266 }
3267 }
3268 if (const0_idx < 0 || const1_idx < 0)
3269 continue;
3270
3271 int lower_idx = const0_idx;
3272 switch (min) {
3273 case aco_opcode::v_min_f32:
3274 case aco_opcode::v_min_f16: {
3275 float const0_f, const1_f;
3276 if (min == aco_opcode::v_min_f32) {
3277 memcpy(&const0_f, &const0, 4);
3278 memcpy(&const1_f, &const1, 4);
3279 } else {
3280 const0_f = _mesa_half_to_float(const0);
3281 const1_f = _mesa_half_to_float(const1);
3282 }
3283 if (abs[const0_idx])
3284 const0_f = fabsf(const0_f);
3285 if (abs[const1_idx])
3286 const1_f = fabsf(const1_f);
3287 if (neg[const0_idx])
3288 const0_f = -const0_f;
3289 if (neg[const1_idx])
3290 const1_f = -const1_f;
3291 lower_idx = const0_f < const1_f ? const0_idx : const1_idx;
3292 break;
3293 }
3294 case aco_opcode::v_min_u32: {
3295 lower_idx = const0 < const1 ? const0_idx : const1_idx;
3296 break;
3297 }
3298 case aco_opcode::v_min_u16:
3299 case aco_opcode::v_min_u16_e64: {
3300 lower_idx = (uint16_t)const0 < (uint16_t)const1 ? const0_idx : const1_idx;
3301 break;
3302 }
3303 case aco_opcode::v_min_i32: {
3304 int32_t const0_i =
3305 const0 & 0x80000000u ? -2147483648 + (int32_t)(const0 & 0x7fffffffu) : const0;
3306 int32_t const1_i =
3307 const1 & 0x80000000u ? -2147483648 + (int32_t)(const1 & 0x7fffffffu) : const1;
3308 lower_idx = const0_i < const1_i ? const0_idx : const1_idx;
3309 break;
3310 }
3311 case aco_opcode::v_min_i16:
3312 case aco_opcode::v_min_i16_e64: {
3313 int16_t const0_i = const0 & 0x8000u ? -32768 + (int16_t)(const0 & 0x7fffu) : const0;
3314 int16_t const1_i = const1 & 0x8000u ? -32768 + (int16_t)(const1 & 0x7fffu) : const1;
3315 lower_idx = const0_i < const1_i ? const0_idx : const1_idx;
3316 break;
3317 }
3318 default: break;
3319 }
3320 int upper_idx = lower_idx == const0_idx ? const1_idx : const0_idx;
3321
3322 if (instr->opcode == min) {
3323 if (upper_idx != 0 || lower_idx == 0)
3324 return false;
3325 } else {
3326 if (upper_idx == 0 || lower_idx != 0)
3327 return false;
3328 }
3329
3330 ctx.uses[instr->operands[swap].tempId()]--;
3331 create_vop3_for_op3(ctx, med, instr, operands, neg, abs, opsel, clamp, omod);
3332
3333 return true;
3334 }
3335 }
3336
3337 return false;
3338 }
3339
3340 void
apply_sgprs(opt_ctx & ctx,aco_ptr<Instruction> & instr)3341 apply_sgprs(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3342 {
3343 bool is_shift64 = instr->opcode == aco_opcode::v_lshlrev_b64 ||
3344 instr->opcode == aco_opcode::v_lshrrev_b64 ||
3345 instr->opcode == aco_opcode::v_ashrrev_i64;
3346
3347 /* find candidates and create the set of sgprs already read */
3348 unsigned sgpr_ids[2] = {0, 0};
3349 uint32_t operand_mask = 0;
3350 bool has_literal = false;
3351 for (unsigned i = 0; i < instr->operands.size(); i++) {
3352 if (instr->operands[i].isLiteral())
3353 has_literal = true;
3354 if (!instr->operands[i].isTemp())
3355 continue;
3356 if (instr->operands[i].getTemp().type() == RegType::sgpr) {
3357 if (instr->operands[i].tempId() != sgpr_ids[0])
3358 sgpr_ids[!!sgpr_ids[0]] = instr->operands[i].tempId();
3359 }
3360 ssa_info& info = ctx.info[instr->operands[i].tempId()];
3361 if (is_copy_label(ctx, instr, info, i) && info.temp.type() == RegType::sgpr)
3362 operand_mask |= 1u << i;
3363 if (info.is_extract() && info.instr->operands[0].getTemp().type() == RegType::sgpr)
3364 operand_mask |= 1u << i;
3365 }
3366 unsigned max_sgprs = 1;
3367 if (ctx.program->gfx_level >= GFX10 && !is_shift64)
3368 max_sgprs = 2;
3369 if (has_literal)
3370 max_sgprs--;
3371
3372 unsigned num_sgprs = !!sgpr_ids[0] + !!sgpr_ids[1];
3373
3374 /* keep on applying sgprs until there is nothing left to be done */
3375 while (operand_mask) {
3376 uint32_t sgpr_idx = 0;
3377 uint32_t sgpr_info_id = 0;
3378 uint32_t mask = operand_mask;
3379 /* choose a sgpr */
3380 while (mask) {
3381 unsigned i = u_bit_scan(&mask);
3382 uint16_t uses = ctx.uses[instr->operands[i].tempId()];
3383 if (sgpr_info_id == 0 || uses < ctx.uses[sgpr_info_id]) {
3384 sgpr_idx = i;
3385 sgpr_info_id = instr->operands[i].tempId();
3386 }
3387 }
3388 operand_mask &= ~(1u << sgpr_idx);
3389
3390 ssa_info& info = ctx.info[sgpr_info_id];
3391
3392 /* Applying two sgprs require making it VOP3, so don't do it unless it's
3393 * definitively beneficial.
3394 * TODO: this is too conservative because later the use count could be reduced to 1 */
3395 if (!info.is_extract() && num_sgprs && ctx.uses[sgpr_info_id] > 1 && !instr->isVOP3() &&
3396 !instr->isSDWA() && instr->format != Format::VOP3P)
3397 break;
3398
3399 Temp sgpr = info.is_extract() ? info.instr->operands[0].getTemp() : info.temp;
3400 bool new_sgpr = sgpr.id() != sgpr_ids[0] && sgpr.id() != sgpr_ids[1];
3401 if (new_sgpr && num_sgprs >= max_sgprs)
3402 continue;
3403
3404 if (sgpr_idx == 0)
3405 instr->format = withoutDPP(instr->format);
3406
3407 if (sgpr_idx == 1 && instr->isDPP())
3408 continue;
3409
3410 if (sgpr_idx == 0 || instr->isVOP3() || instr->isSDWA() || instr->isVOP3P() ||
3411 info.is_extract()) {
3412 /* can_apply_extract() checks SGPR encoding restrictions */
3413 if (info.is_extract() && can_apply_extract(ctx, instr, sgpr_idx, info))
3414 apply_extract(ctx, instr, sgpr_idx, info);
3415 else if (info.is_extract())
3416 continue;
3417 instr->operands[sgpr_idx] = Operand(sgpr);
3418 } else if (can_swap_operands(instr, &instr->opcode) && !instr->valu().opsel[sgpr_idx]) {
3419 instr->operands[sgpr_idx] = instr->operands[0];
3420 instr->operands[0] = Operand(sgpr);
3421 instr->valu().opsel[0].swap(instr->valu().opsel[sgpr_idx]);
3422 /* swap bits using a 4-entry LUT */
3423 uint32_t swapped = (0x3120 >> (operand_mask & 0x3)) & 0xf;
3424 operand_mask = (operand_mask & ~0x3) | swapped;
3425 } else if (can_use_VOP3(ctx, instr) && !info.is_extract()) {
3426 instr->format = asVOP3(instr->format);
3427 instr->operands[sgpr_idx] = Operand(sgpr);
3428 } else {
3429 continue;
3430 }
3431
3432 if (new_sgpr)
3433 sgpr_ids[num_sgprs++] = sgpr.id();
3434 ctx.uses[sgpr_info_id]--;
3435 ctx.uses[sgpr.id()]++;
3436
3437 /* TODO: handle when it's a VGPR */
3438 if ((ctx.info[sgpr.id()].label & (label_extract | label_temp)) &&
3439 ctx.info[sgpr.id()].temp.type() == RegType::sgpr)
3440 operand_mask |= 1u << sgpr_idx;
3441 }
3442 }
3443
3444 void
interp_p2_f32_inreg_to_fma_dpp(aco_ptr<Instruction> & instr)3445 interp_p2_f32_inreg_to_fma_dpp(aco_ptr<Instruction>& instr)
3446 {
3447 static_assert(sizeof(DPP16_instruction) == sizeof(VINTERP_inreg_instruction),
3448 "Invalid instr cast.");
3449 instr->format = asVOP3(Format::DPP16);
3450 instr->opcode = aco_opcode::v_fma_f32;
3451 instr->dpp16().dpp_ctrl = dpp_quad_perm(2, 2, 2, 2);
3452 instr->dpp16().row_mask = 0xf;
3453 instr->dpp16().bank_mask = 0xf;
3454 instr->dpp16().bound_ctrl = 0;
3455 instr->dpp16().fetch_inactive = 1;
3456 }
3457
3458 /* apply omod / clamp modifiers if the def is used only once and the instruction can have modifiers */
3459 bool
apply_omod_clamp(opt_ctx & ctx,aco_ptr<Instruction> & instr)3460 apply_omod_clamp(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3461 {
3462 if (instr->definitions.empty() || ctx.uses[instr->definitions[0].tempId()] != 1 ||
3463 !instr_info.can_use_output_modifiers[(int)instr->opcode])
3464 return false;
3465
3466 bool can_vop3 = can_use_VOP3(ctx, instr);
3467 bool is_mad_mix =
3468 instr->opcode == aco_opcode::v_fma_mix_f32 || instr->opcode == aco_opcode::v_fma_mixlo_f16;
3469 bool needs_vop3 = !instr->isSDWA() && !instr->isVINTERP_INREG() && !is_mad_mix;
3470 if (needs_vop3 && !can_vop3)
3471 return false;
3472
3473 /* SDWA omod is GFX9+. */
3474 bool can_use_omod =
3475 (can_vop3 || ctx.program->gfx_level >= GFX9) && !instr->isVOP3P() &&
3476 (!instr->isVINTERP_INREG() || instr->opcode == aco_opcode::v_interp_p2_f32_inreg);
3477
3478 ssa_info& def_info = ctx.info[instr->definitions[0].tempId()];
3479
3480 uint64_t omod_labels = label_omod2 | label_omod4 | label_omod5;
3481 if (!def_info.is_clamp() && !(can_use_omod && (def_info.label & omod_labels)))
3482 return false;
3483 /* if the omod/clamp instruction is dead, then the single user of this
3484 * instruction is a different instruction */
3485 if (!ctx.uses[def_info.instr->definitions[0].tempId()])
3486 return false;
3487
3488 if (def_info.instr->definitions[0].bytes() != instr->definitions[0].bytes())
3489 return false;
3490
3491 /* MADs/FMAs are created later, so we don't have to update the original add */
3492 assert(!ctx.info[instr->definitions[0].tempId()].is_mad());
3493
3494 if (!def_info.is_clamp() && (instr->valu().clamp || instr->valu().omod))
3495 return false;
3496
3497 if (needs_vop3)
3498 instr->format = asVOP3(instr->format);
3499
3500 if (!def_info.is_clamp() && instr->opcode == aco_opcode::v_interp_p2_f32_inreg)
3501 interp_p2_f32_inreg_to_fma_dpp(instr);
3502
3503 if (def_info.is_omod2())
3504 instr->valu().omod = 1;
3505 else if (def_info.is_omod4())
3506 instr->valu().omod = 2;
3507 else if (def_info.is_omod5())
3508 instr->valu().omod = 3;
3509 else if (def_info.is_clamp())
3510 instr->valu().clamp = true;
3511
3512 instr->definitions[0].swapTemp(def_info.instr->definitions[0]);
3513 ctx.info[instr->definitions[0].tempId()].label &= label_clamp | label_insert | label_f2f16;
3514 ctx.uses[def_info.instr->definitions[0].tempId()]--;
3515
3516 return true;
3517 }
3518
3519 /* Combine an p_insert (or p_extract, in some cases) instruction with instr.
3520 * p_insert(instr(...)) -> instr_insert().
3521 */
3522 bool
apply_insert(opt_ctx & ctx,aco_ptr<Instruction> & instr)3523 apply_insert(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3524 {
3525 if (instr->definitions.empty() || ctx.uses[instr->definitions[0].tempId()] != 1)
3526 return false;
3527
3528 ssa_info& def_info = ctx.info[instr->definitions[0].tempId()];
3529 if (!def_info.is_insert())
3530 return false;
3531 /* if the insert instruction is dead, then the single user of this
3532 * instruction is a different instruction */
3533 if (!ctx.uses[def_info.instr->definitions[0].tempId()])
3534 return false;
3535
3536 /* MADs/FMAs are created later, so we don't have to update the original add */
3537 assert(!ctx.info[instr->definitions[0].tempId()].is_mad());
3538
3539 SubdwordSel sel = parse_insert(def_info.instr);
3540 assert(sel);
3541
3542 if (!can_use_SDWA(ctx.program->gfx_level, instr, true))
3543 return false;
3544
3545 convert_to_SDWA(ctx.program->gfx_level, instr);
3546 if (instr->sdwa().dst_sel.size() != 4)
3547 return false;
3548 instr->sdwa().dst_sel = sel;
3549
3550 instr->definitions[0].swapTemp(def_info.instr->definitions[0]);
3551 ctx.info[instr->definitions[0].tempId()].label = 0;
3552 ctx.uses[def_info.instr->definitions[0].tempId()]--;
3553
3554 return true;
3555 }
3556
3557 /* Remove superfluous extract after ds_read like so:
3558 * p_extract(ds_read_uN(), 0, N, 0) -> ds_read_uN()
3559 */
3560 bool
apply_ds_extract(opt_ctx & ctx,aco_ptr<Instruction> & extract)3561 apply_ds_extract(opt_ctx& ctx, aco_ptr<Instruction>& extract)
3562 {
3563 /* Check if p_extract has a usedef operand and is the only user. */
3564 if (!ctx.info[extract->operands[0].tempId()].is_usedef() ||
3565 ctx.uses[extract->operands[0].tempId()] > 1)
3566 return false;
3567
3568 /* Check if the usedef is a DS instruction. */
3569 Instruction* ds = ctx.info[extract->operands[0].tempId()].instr;
3570 if (ds->format != Format::DS)
3571 return false;
3572
3573 unsigned extract_idx = extract->operands[1].constantValue();
3574 unsigned bits_extracted = extract->operands[2].constantValue();
3575 unsigned sign_ext = extract->operands[3].constantValue();
3576 unsigned dst_bitsize = extract->definitions[0].bytes() * 8u;
3577
3578 /* TODO: These are doable, but probably don't occur too often. */
3579 if (extract_idx || sign_ext || dst_bitsize != 32)
3580 return false;
3581
3582 unsigned bits_loaded = 0;
3583 if (ds->opcode == aco_opcode::ds_read_u8 || ds->opcode == aco_opcode::ds_read_u8_d16)
3584 bits_loaded = 8;
3585 else if (ds->opcode == aco_opcode::ds_read_u16 || ds->opcode == aco_opcode::ds_read_u16_d16)
3586 bits_loaded = 16;
3587 else
3588 return false;
3589
3590 /* Shrink the DS load if the extracted bit size is smaller. */
3591 bits_loaded = MIN2(bits_loaded, bits_extracted);
3592
3593 /* Change the DS opcode so it writes the full register. */
3594 if (bits_loaded == 8)
3595 ds->opcode = aco_opcode::ds_read_u8;
3596 else if (bits_loaded == 16)
3597 ds->opcode = aco_opcode::ds_read_u16;
3598 else
3599 unreachable("Forgot to add DS opcode above.");
3600
3601 /* The DS now produces the exact same thing as the extract, remove the extract. */
3602 std::swap(ds->definitions[0], extract->definitions[0]);
3603 ctx.uses[extract->definitions[0].tempId()] = 0;
3604 ctx.info[ds->definitions[0].tempId()].label = 0;
3605 return true;
3606 }
3607
3608 /* v_and(a, v_subbrev_co(0, 0, vcc)) -> v_cndmask(0, a, vcc) */
3609 bool
combine_and_subbrev(opt_ctx & ctx,aco_ptr<Instruction> & instr)3610 combine_and_subbrev(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3611 {
3612 if (instr->usesModifiers())
3613 return false;
3614
3615 for (unsigned i = 0; i < 2; i++) {
3616 Instruction* op_instr = follow_operand(ctx, instr->operands[i], true);
3617 if (op_instr && op_instr->opcode == aco_opcode::v_subbrev_co_u32 &&
3618 op_instr->operands[0].constantEquals(0) && op_instr->operands[1].constantEquals(0) &&
3619 !op_instr->usesModifiers()) {
3620
3621 aco_ptr<Instruction> new_instr;
3622 if (instr->operands[!i].isTemp() &&
3623 instr->operands[!i].getTemp().type() == RegType::vgpr) {
3624 new_instr.reset(
3625 create_instruction<VALU_instruction>(aco_opcode::v_cndmask_b32, Format::VOP2, 3, 1));
3626 } else if (ctx.program->gfx_level >= GFX10 ||
3627 (instr->operands[!i].isConstant() && !instr->operands[!i].isLiteral())) {
3628 new_instr.reset(create_instruction<VALU_instruction>(aco_opcode::v_cndmask_b32,
3629 asVOP3(Format::VOP2), 3, 1));
3630 } else {
3631 return false;
3632 }
3633
3634 new_instr->operands[0] = Operand::zero();
3635 new_instr->operands[1] = instr->operands[!i];
3636 new_instr->operands[2] = copy_operand(ctx, op_instr->operands[2]);
3637 new_instr->definitions[0] = instr->definitions[0];
3638 new_instr->pass_flags = instr->pass_flags;
3639 instr = std::move(new_instr);
3640 decrease_uses(ctx, op_instr);
3641 ctx.info[instr->definitions[0].tempId()].label = 0;
3642 return true;
3643 }
3644 }
3645
3646 return false;
3647 }
3648
3649 /* v_and(a, not(b)) -> v_bfi_b32(b, 0, a)
3650 * v_or(a, not(b)) -> v_bfi_b32(b, a, -1)
3651 */
3652 bool
combine_v_andor_not(opt_ctx & ctx,aco_ptr<Instruction> & instr)3653 combine_v_andor_not(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3654 {
3655 if (instr->usesModifiers())
3656 return false;
3657
3658 for (unsigned i = 0; i < 2; i++) {
3659 Instruction* op_instr = follow_operand(ctx, instr->operands[i], true);
3660 if (op_instr && !op_instr->usesModifiers() &&
3661 (op_instr->opcode == aco_opcode::v_not_b32 ||
3662 op_instr->opcode == aco_opcode::s_not_b32)) {
3663
3664 Operand ops[3] = {
3665 op_instr->operands[0],
3666 Operand::zero(),
3667 instr->operands[!i],
3668 };
3669 if (instr->opcode == aco_opcode::v_or_b32) {
3670 ops[1] = instr->operands[!i];
3671 ops[2] = Operand::c32(-1);
3672 }
3673 if (!check_vop3_operands(ctx, 3, ops))
3674 continue;
3675
3676 Instruction* new_instr =
3677 create_instruction<VALU_instruction>(aco_opcode::v_bfi_b32, Format::VOP3, 3, 1);
3678
3679 if (op_instr->operands[0].isTemp())
3680 ctx.uses[op_instr->operands[0].tempId()]++;
3681 for (unsigned j = 0; j < 3; j++)
3682 new_instr->operands[j] = ops[j];
3683 new_instr->definitions[0] = instr->definitions[0];
3684 new_instr->pass_flags = instr->pass_flags;
3685 instr.reset(new_instr);
3686 decrease_uses(ctx, op_instr);
3687 ctx.info[instr->definitions[0].tempId()].label = 0;
3688 return true;
3689 }
3690 }
3691
3692 return false;
3693 }
3694
3695 /* v_add_co(c, s_lshl(a, b)) -> v_mad_u32_u24(a, 1<<b, c)
3696 * v_add_co(c, v_lshlrev(a, b)) -> v_mad_u32_u24(b, 1<<a, c)
3697 * v_sub(c, s_lshl(a, b)) -> v_mad_i32_i24(a, -(1<<b), c)
3698 * v_sub(c, v_lshlrev(a, b)) -> v_mad_i32_i24(b, -(1<<a), c)
3699 */
3700 bool
combine_add_lshl(opt_ctx & ctx,aco_ptr<Instruction> & instr,bool is_sub)3701 combine_add_lshl(opt_ctx& ctx, aco_ptr<Instruction>& instr, bool is_sub)
3702 {
3703 if (instr->usesModifiers())
3704 return false;
3705
3706 /* Substractions: start at operand 1 to avoid mixup such as
3707 * turning v_sub(v_lshlrev(a, b), c) into v_mad_i32_i24(b, -(1<<a), c)
3708 */
3709 unsigned start_op_idx = is_sub ? 1 : 0;
3710
3711 /* Don't allow 24-bit operands on subtraction because
3712 * v_mad_i32_i24 applies a sign extension.
3713 */
3714 bool allow_24bit = !is_sub;
3715
3716 for (unsigned i = start_op_idx; i < 2; i++) {
3717 Instruction* op_instr = follow_operand(ctx, instr->operands[i]);
3718 if (!op_instr)
3719 continue;
3720
3721 if (op_instr->opcode != aco_opcode::s_lshl_b32 &&
3722 op_instr->opcode != aco_opcode::v_lshlrev_b32)
3723 continue;
3724
3725 int shift_op_idx = op_instr->opcode == aco_opcode::s_lshl_b32 ? 1 : 0;
3726
3727 if (op_instr->operands[shift_op_idx].isConstant() &&
3728 ((allow_24bit && op_instr->operands[!shift_op_idx].is24bit()) ||
3729 op_instr->operands[!shift_op_idx].is16bit())) {
3730 uint32_t multiplier = 1 << (op_instr->operands[shift_op_idx].constantValue() % 32u);
3731 if (is_sub)
3732 multiplier = -multiplier;
3733 if (is_sub ? (multiplier < 0xff800000) : (multiplier > 0xffffff))
3734 continue;
3735
3736 Operand ops[3] = {
3737 op_instr->operands[!shift_op_idx],
3738 Operand::c32(multiplier),
3739 instr->operands[!i],
3740 };
3741 if (!check_vop3_operands(ctx, 3, ops))
3742 return false;
3743
3744 ctx.uses[instr->operands[i].tempId()]--;
3745
3746 aco_opcode mad_op = is_sub ? aco_opcode::v_mad_i32_i24 : aco_opcode::v_mad_u32_u24;
3747 aco_ptr<VALU_instruction> new_instr{
3748 create_instruction<VALU_instruction>(mad_op, Format::VOP3, 3, 1)};
3749 for (unsigned op_idx = 0; op_idx < 3; ++op_idx)
3750 new_instr->operands[op_idx] = ops[op_idx];
3751 new_instr->definitions[0] = instr->definitions[0];
3752 new_instr->pass_flags = instr->pass_flags;
3753 instr = std::move(new_instr);
3754 ctx.info[instr->definitions[0].tempId()].label = 0;
3755 return true;
3756 }
3757 }
3758
3759 return false;
3760 }
3761
3762 void
propagate_swizzles(VALU_instruction * instr,bool opsel_lo,bool opsel_hi)3763 propagate_swizzles(VALU_instruction* instr, bool opsel_lo, bool opsel_hi)
3764 {
3765 /* propagate swizzles which apply to a result down to the instruction's operands:
3766 * result = a.xy + b.xx -> result.yx = a.yx + b.xx */
3767 uint8_t tmp_lo = instr->opsel_lo;
3768 uint8_t tmp_hi = instr->opsel_hi;
3769 uint8_t neg_lo = instr->neg_lo;
3770 uint8_t neg_hi = instr->neg_hi;
3771 if (opsel_lo == 1) {
3772 instr->opsel_lo = tmp_hi;
3773 instr->neg_lo = neg_hi;
3774 }
3775 if (opsel_hi == 0) {
3776 instr->opsel_hi = tmp_lo;
3777 instr->neg_hi = neg_lo;
3778 }
3779 }
3780
/* Peephole combines for packed (VOP3P) instructions other than v_fma_mix*:
 * 1. v_pk_mul_f16(a, 1.0) with clamp, where a is the single-use result of a
 *    VOP3P instruction that supports output modifiers: set clamp on a's producer
 *    and drop the multiply.
 * 2. An operand produced by v_pk_mul_f16(x, 1.0) (0x3C00 is half-float 1.0), i.e.
 *    a mul that only negates/swizzles x: read x directly and fold the mul's
 *    neg/opsel bits into this operand.
 * 3. v_pk_add_f16 / v_pk_add_u16 with a multiply operand: fuse into
 *    v_pk_fma_f16 / v_pk_mad_u16.
 */
void
combine_vop3p(opt_ctx& ctx, aco_ptr<Instruction>& instr)
{
   VALU_instruction* vop3p = &instr->valu();

   /* apply clamp */
   if (instr->opcode == aco_opcode::v_pk_mul_f16 && instr->operands[1].constantEquals(0x3C00) &&
       vop3p->clamp && instr->operands[0].isTemp() && ctx.uses[instr->operands[0].tempId()] == 1 &&
       !vop3p->opsel_lo[1] && !vop3p->opsel_hi[1]) {

      ssa_info& info = ctx.info[instr->operands[0].tempId()];
      if (info.is_vop3p() && instr_info.can_use_output_modifiers[(int)info.instr->opcode]) {
         VALU_instruction* candidate = &ctx.info[instr->operands[0].tempId()].instr->valu();
         candidate->clamp = true;
         /* the mul may swizzle its input: apply that swizzle to the producer instead */
         propagate_swizzles(candidate, vop3p->opsel_lo[0], vop3p->opsel_hi[0]);
         /* the producer now defines the mul's result; the mul's definition takes
          * the producer's old temp, whose only use (the mul) is removed below */
         instr->definitions[0].swapTemp(candidate->definitions[0]);
         ctx.info[candidate->definitions[0].tempId()].instr = candidate;
         ctx.uses[instr->definitions[0].tempId()]--;
         return;
      }
   }

   /* check for fneg modifiers */
   for (unsigned i = 0; i < instr->operands.size(); i++) {
      if (!can_use_input_modifiers(ctx.program->gfx_level, instr->opcode, i))
         continue;
      Operand& op = instr->operands[i];
      if (!op.isTemp())
         continue;

      ssa_info& info = ctx.info[op.tempId()];
      if (info.is_vop3p() && info.instr->opcode == aco_opcode::v_pk_mul_f16 &&
          (info.instr->operands[0].constantEquals(0x3C00) ||
           info.instr->operands[1].constantEquals(0x3C00))) {

         VALU_instruction* fneg = &info.instr->valu();

         /* index of the non-1.0 source of the mul */
         unsigned fneg_src = fneg->operands[0].constantEquals(0x3C00);

         /* bail out unless both opsel bits of the 1.0 operand are clear */
         if (fneg->opsel_lo[1 - fneg_src] || fneg->opsel_hi[1 - fneg_src])
            continue;

         Operand ops[3];
         for (unsigned j = 0; j < instr->operands.size(); j++)
            ops[j] = instr->operands[j];
         ops[i] = fneg->operands[fneg_src];
         if (!check_vop3_operands(ctx, instr->operands.size(), ops))
            continue;

         if (fneg->clamp)
            continue;
         instr->operands[i] = fneg->operands[fneg_src];

         /* opsel_lo/hi is either 0 or 1:
          * if 0 - pick selection from fneg->lo
          * if 1 - pick selection from fneg->hi
          */
         bool opsel_lo = vop3p->opsel_lo[i];
         bool opsel_hi = vop3p->opsel_hi[i];
         /* negations of both mul operands combine into the sign of the product */
         bool neg_lo = fneg->neg_lo[0] ^ fneg->neg_lo[1];
         bool neg_hi = fneg->neg_hi[0] ^ fneg->neg_hi[1];
         vop3p->neg_lo[i] ^= opsel_lo ? neg_hi : neg_lo;
         vop3p->neg_hi[i] ^= opsel_hi ? neg_hi : neg_lo;
         vop3p->opsel_lo[i] ^= opsel_lo ? !fneg->opsel_hi[fneg_src] : fneg->opsel_lo[fneg_src];
         vop3p->opsel_hi[i] ^= opsel_hi ? !fneg->opsel_hi[fneg_src] : fneg->opsel_lo[fneg_src];

         /* Drop this use of the mul. If the mul stays alive, its source gains a
          * use here; otherwise the dead mul's use of the source moves here. */
         if (--ctx.uses[fneg->definitions[0].tempId()])
            ctx.uses[fneg->operands[fneg_src].tempId()]++;
      }
   }

   if (instr->opcode == aco_opcode::v_pk_add_f16 || instr->opcode == aco_opcode::v_pk_add_u16) {
      bool fadd = instr->opcode == aco_opcode::v_pk_add_f16;
      /* fusing changes float rounding, so precise float adds are left alone */
      if (fadd && instr->definitions[0].isPrecise())
         return;

      Instruction* mul_instr = nullptr;
      unsigned add_op_idx = 0;
      bitarray8 mul_neg_lo = 0, mul_neg_hi = 0, mul_opsel_lo = 0, mul_opsel_hi = 0;
      uint32_t uses = UINT32_MAX;

      /* find the 'best' mul instruction to combine with the add */
      for (unsigned i = 0; i < 2; i++) {
         Instruction* op_instr = follow_operand(ctx, instr->operands[i], true);
         if (!op_instr)
            continue;

         if (ctx.info[instr->operands[i].tempId()].is_vop3p()) {
            /* packed multiply: take its modifiers as-is */
            if (fadd) {
               if (op_instr->opcode != aco_opcode::v_pk_mul_f16 ||
                   op_instr->definitions[0].isPrecise())
                  continue;
            } else {
               if (op_instr->opcode != aco_opcode::v_pk_mul_lo_u16)
                  continue;
            }

            Operand op[3] = {op_instr->operands[0], op_instr->operands[1], instr->operands[1 - i]};
            if (ctx.uses[instr->operands[i].tempId()] >= uses || !check_vop3_operands(ctx, 3, op))
               continue;

            /* no clamp allowed between mul and add */
            if (op_instr->valu().clamp)
               continue;

            mul_instr = op_instr;
            add_op_idx = 1 - i;
            uses = ctx.uses[instr->operands[i].tempId()];
            mul_neg_lo = mul_instr->valu().neg_lo;
            mul_neg_hi = mul_instr->valu().neg_hi;
            mul_opsel_lo = mul_instr->valu().opsel_lo;
            mul_opsel_hi = mul_instr->valu().opsel_hi;
         } else if (instr->operands[i].bytes() == 2) {
            /* scalar 16-bit multiply: its result is used for both halves, so the
             * same neg/selection is applied to lo and hi */
            if ((fadd && (op_instr->opcode != aco_opcode::v_mul_f16 ||
                          op_instr->definitions[0].isPrecise())) ||
                (!fadd && op_instr->opcode != aco_opcode::v_mul_lo_u16 &&
                 op_instr->opcode != aco_opcode::v_mul_lo_u16_e64))
               continue;

            if (op_instr->valu().clamp || op_instr->valu().omod || op_instr->valu().abs)
               continue;

            if (op_instr->isDPP() || (op_instr->isSDWA() && (op_instr->sdwa().sel[0].size() < 2 ||
                                                             op_instr->sdwa().sel[1].size() < 2)))
               continue;

            Operand op[3] = {op_instr->operands[0], op_instr->operands[1], instr->operands[1 - i]};
            if (ctx.uses[instr->operands[i].tempId()] >= uses || !check_vop3_operands(ctx, 3, op))
               continue;

            mul_instr = op_instr;
            add_op_idx = 1 - i;
            uses = ctx.uses[instr->operands[i].tempId()];
            mul_neg_lo = mul_instr->valu().neg;
            mul_neg_hi = mul_instr->valu().neg;
            if (mul_instr->isSDWA()) {
               for (unsigned j = 0; j < 2; j++)
                  mul_opsel_lo[j] = mul_instr->sdwa().sel[j].offset();
            } else {
               mul_opsel_lo = mul_instr->valu().opsel;
            }
            mul_opsel_hi = mul_opsel_lo;
         }
      }

      if (!mul_instr)
         return;

      /* turn mul + packed add into v_pk_fma_f16 */
      aco_opcode mad = fadd ? aco_opcode::v_pk_fma_f16 : aco_opcode::v_pk_mad_u16;
      aco_ptr<VALU_instruction> fma{create_instruction<VALU_instruction>(mad, Format::VOP3P, 3, 1)};
      fma->operands[0] = copy_operand(ctx, mul_instr->operands[0]);
      fma->operands[1] = copy_operand(ctx, mul_instr->operands[1]);
      fma->operands[2] = instr->operands[add_op_idx];
      fma->clamp = vop3p->clamp;
      fma->neg_lo = mul_neg_lo;
      fma->neg_hi = mul_neg_hi;
      fma->opsel_lo = mul_opsel_lo;
      fma->opsel_hi = mul_opsel_hi;
      /* the add's swizzle of the mul result becomes a swizzle of the mul operands */
      propagate_swizzles(fma.get(), vop3p->opsel_lo[1 - add_op_idx],
                         vop3p->opsel_hi[1 - add_op_idx]);
      fma->opsel_lo[2] = vop3p->opsel_lo[add_op_idx];
      fma->opsel_hi[2] = vop3p->opsel_hi[add_op_idx];
      fma->neg_lo[2] = vop3p->neg_lo[add_op_idx];
      fma->neg_hi[2] = vop3p->neg_hi[add_op_idx];
      /* a negation of the mul result on the add is folded into the second factor */
      fma->neg_lo[1] = fma->neg_lo[1] ^ vop3p->neg_lo[1 - add_op_idx];
      fma->neg_hi[1] = fma->neg_hi[1] ^ vop3p->neg_hi[1 - add_op_idx];
      fma->definitions[0] = instr->definitions[0];
      fma->pass_flags = instr->pass_flags;
      instr = std::move(fma);
      ctx.info[instr->definitions[0].tempId()].set_vop3p(instr.get());
      decrease_uses(ctx, mul_instr);
      return;
   }
}
3956
3957 bool
can_use_mad_mix(opt_ctx & ctx,aco_ptr<Instruction> & instr)3958 can_use_mad_mix(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3959 {
3960 if (ctx.program->gfx_level < GFX9)
3961 return false;
3962
3963 /* v_mad_mix* on GFX9 always flushes denormals for 16-bit inputs/outputs */
3964 if (ctx.program->gfx_level == GFX9 && ctx.fp_mode.denorm16_64)
3965 return false;
3966
3967 if (instr->valu().omod)
3968 return false;
3969
3970 switch (instr->opcode) {
3971 case aco_opcode::v_add_f32:
3972 case aco_opcode::v_sub_f32:
3973 case aco_opcode::v_subrev_f32:
3974 case aco_opcode::v_mul_f32: return !instr->isSDWA() && !instr->isDPP();
3975 case aco_opcode::v_fma_f32:
3976 return ctx.program->dev.fused_mad_mix || !instr->definitions[0].isPrecise();
3977 case aco_opcode::v_fma_mix_f32:
3978 case aco_opcode::v_fma_mixlo_f16: return true;
3979 default: return false;
3980 }
3981 }
3982
3983 void
to_mad_mix(opt_ctx & ctx,aco_ptr<Instruction> & instr)3984 to_mad_mix(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3985 {
3986 ctx.info[instr->definitions[0].tempId()].label &= label_f2f16 | label_clamp | label_mul;
3987
3988 if (instr->opcode == aco_opcode::v_fma_f32) {
3989 instr->format = (Format)((uint32_t)withoutVOP3(instr->format) | (uint32_t)(Format::VOP3P));
3990 instr->opcode = aco_opcode::v_fma_mix_f32;
3991 return;
3992 }
3993
3994 bool is_add = instr->opcode != aco_opcode::v_mul_f32;
3995
3996 aco_ptr<VALU_instruction> vop3p{
3997 create_instruction<VALU_instruction>(aco_opcode::v_fma_mix_f32, Format::VOP3P, 3, 1)};
3998
3999 for (unsigned i = 0; i < instr->operands.size(); i++) {
4000 vop3p->operands[is_add + i] = instr->operands[i];
4001 vop3p->neg_lo[is_add + i] = instr->valu().neg[i];
4002 vop3p->neg_hi[is_add + i] = instr->valu().abs[i];
4003 }
4004 if (instr->opcode == aco_opcode::v_mul_f32) {
4005 vop3p->operands[2] = Operand::zero();
4006 vop3p->neg_lo[2] = true;
4007 } else if (is_add) {
4008 vop3p->operands[0] = Operand::c32(0x3f800000);
4009 if (instr->opcode == aco_opcode::v_sub_f32)
4010 vop3p->neg_lo[2] ^= true;
4011 else if (instr->opcode == aco_opcode::v_subrev_f32)
4012 vop3p->neg_lo[1] ^= true;
4013 }
4014 vop3p->definitions[0] = instr->definitions[0];
4015 vop3p->clamp = instr->valu().clamp;
4016 vop3p->pass_flags = instr->pass_flags;
4017 instr = std::move(vop3p);
4018
4019 if (ctx.info[instr->definitions[0].tempId()].label & label_mul)
4020 ctx.info[instr->definitions[0].tempId()].instr = instr.get();
4021 }
4022
4023 bool
combine_output_conversion(opt_ctx & ctx,aco_ptr<Instruction> & instr)4024 combine_output_conversion(opt_ctx& ctx, aco_ptr<Instruction>& instr)
4025 {
4026 ssa_info& def_info = ctx.info[instr->definitions[0].tempId()];
4027 if (!def_info.is_f2f16())
4028 return false;
4029 Instruction* conv = def_info.instr;
4030
4031 if (!ctx.uses[conv->definitions[0].tempId()] || ctx.uses[instr->definitions[0].tempId()] != 1)
4032 return false;
4033
4034 if (conv->usesModifiers())
4035 return false;
4036
4037 if (instr->opcode == aco_opcode::v_interp_p2_f32_inreg)
4038 interp_p2_f32_inreg_to_fma_dpp(instr);
4039
4040 if (!can_use_mad_mix(ctx, instr))
4041 return false;
4042
4043 if (!instr->isVOP3P())
4044 to_mad_mix(ctx, instr);
4045
4046 instr->opcode = aco_opcode::v_fma_mixlo_f16;
4047 instr->definitions[0].swapTemp(conv->definitions[0]);
4048 if (conv->definitions[0].isPrecise())
4049 instr->definitions[0].setPrecise(true);
4050 ctx.info[instr->definitions[0].tempId()].label &= label_clamp;
4051 ctx.uses[conv->definitions[0].tempId()]--;
4052
4053 return true;
4054 }
4055
/* Fold f2f32-labelled conversions that feed 32-bit operands of instr into instr
 * itself. instr is converted to v_fma_mix_f32 first if it isn't VOP3P already;
 * the conversion's source is then read directly, with opsel_hi/opsel_lo set on
 * that operand and the conversion's neg/abs merged in.
 */
void
combine_mad_mix(opt_ctx& ctx, aco_ptr<Instruction>& instr)
{
   if (!can_use_mad_mix(ctx, instr))
      return;

   for (unsigned i = 0; i < instr->operands.size(); i++) {
      if (!instr->operands[i].isTemp())
         continue;
      Temp tmp = instr->operands[i].getTemp();
      if (!ctx.info[tmp.id()].is_f2f32())
         continue;

      Instruction* conv = ctx.info[tmp.id()].instr;
      /* reject conversions with output modifiers, non-trivial SDWA selections or DPP */
      if (conv->valu().clamp || conv->valu().omod) {
         continue;
      } else if (conv->isSDWA() &&
                 (conv->sdwa().dst_sel.size() != 4 || conv->sdwa().sel[0].size() != 2)) {
         continue;
      } else if (conv->isDPP()) {
         continue;
      }

      if (get_operand_size(instr, i) != 32)
         continue;

      /* Conversion to VOP3P will add inline constant operands, but that shouldn't affect
       * check_vop3_operands(). */
      Operand op[3];
      for (unsigned j = 0; j < instr->operands.size(); j++)
         op[j] = instr->operands[j];
      op[i] = conv->operands[0];
      if (!check_vop3_operands(ctx, instr->operands.size(), op))
         continue;
      if (!conv->operands[0].isOfType(RegType::vgpr) && instr->isDPP())
         continue;

      if (!instr->isVOP3P()) {
         bool is_add =
            instr->opcode != aco_opcode::v_mul_f32 && instr->opcode != aco_opcode::v_fma_f32;
         to_mad_mix(ctx, instr);
         /* to_mad_mix() moves add/sub operands up by one slot (src0 becomes 1.0);
          * follow this operand to its new slot. The loop increment then lands on
          * the next original operand. */
         i += is_add;
      }

      /* Drop this use of the conversion. If the conversion stays alive, its
       * source gains a use here; otherwise the dead conversion's use moves here. */
      if (--ctx.uses[tmp.id()])
         ctx.uses[conv->operands[0].tempId()]++;
      instr->operands[i].setTemp(conv->operands[0].getTemp());
      if (conv->definitions[0].isPrecise())
         instr->definitions[0].setPrecise(true);
      /* mad-mix: opsel_hi marks a 16-bit source, opsel_lo selects its half */
      instr->valu().opsel_hi[i] = true;
      if (conv->isSDWA() && conv->sdwa().sel[0].offset() == 2)
         instr->valu().opsel_lo[i] = true;
      else
         instr->valu().opsel_lo[i] = conv->valu().opsel[0];
      bool neg = conv->valu().neg[0];
      bool abs = conv->valu().abs[0];
      /* an existing abs on this operand already discards the conversion's neg/abs */
      if (!instr->valu().abs[i]) {
         instr->valu().neg[i] ^= neg;
         instr->valu().abs[i] = abs;
      }
   }
}
4118
// TODO: we could possibly move the whole label_instruction pass to combine_instruction:
// this would mean that we'd have to fix the instruction uses during value propagation
4121
4122 /* also returns true for inf */
4123 bool
is_pow_of_two(opt_ctx & ctx,Operand op)4124 is_pow_of_two(opt_ctx& ctx, Operand op)
4125 {
4126 if (op.isTemp() && ctx.info[op.tempId()].is_constant_or_literal(op.bytes() * 8))
4127 return is_pow_of_two(ctx, get_constant_op(ctx, ctx.info[op.tempId()], op.bytes() * 8));
4128 else if (!op.isConstant())
4129 return false;
4130
4131 uint64_t val = op.constantValue64();
4132
4133 if (op.bytes() == 4) {
4134 uint32_t exponent = (val & 0x7f800000) >> 23;
4135 uint32_t fraction = val & 0x007fffff;
4136 return (exponent >= 127) && (fraction == 0);
4137 } else if (op.bytes() == 2) {
4138 uint32_t exponent = (val & 0x7c00) >> 10;
4139 uint32_t fraction = val & 0x03ff;
4140 return (exponent >= 15) && (fraction == 0);
4141 } else {
4142 assert(op.bytes() == 8);
4143 uint64_t exponent = (val & UINT64_C(0x7ff0000000000000)) >> 52;
4144 uint64_t fraction = val & UINT64_C(0x000fffffffffffff);
4145 return (exponent >= 1023) && (fraction == 0);
4146 }
4147 }
4148
4149 void
combine_instruction(opt_ctx & ctx,aco_ptr<Instruction> & instr)4150 combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
4151 {
4152 if (instr->definitions.empty() || is_dead(ctx.uses, instr.get()))
4153 return;
4154
4155 if (instr->isVALU()) {
4156 /* Apply SDWA. Do this after label_instruction() so it can remove
4157 * label_extract if not all instructions can take SDWA. */
4158 for (unsigned i = 0; i < instr->operands.size(); i++) {
4159 Operand& op = instr->operands[i];
4160 if (!op.isTemp())
4161 continue;
4162 ssa_info& info = ctx.info[op.tempId()];
4163 if (!info.is_extract())
4164 continue;
4165 /* if there are that many uses, there are likely better combinations */
4166 // TODO: delay applying extract to a point where we know better
4167 if (ctx.uses[op.tempId()] > 4) {
4168 info.label &= ~label_extract;
4169 continue;
4170 }
4171 if (info.is_extract() &&
4172 (info.instr->operands[0].getTemp().type() == RegType::vgpr ||
4173 instr->operands[i].getTemp().type() == RegType::sgpr) &&
4174 can_apply_extract(ctx, instr, i, info)) {
4175 /* Increase use count of the extract's operand if the extract still has uses. */
4176 apply_extract(ctx, instr, i, info);
4177 if (--ctx.uses[instr->operands[i].tempId()])
4178 ctx.uses[info.instr->operands[0].tempId()]++;
4179 instr->operands[i].setTemp(info.instr->operands[0].getTemp());
4180 }
4181 }
4182
4183 if (can_apply_sgprs(ctx, instr))
4184 apply_sgprs(ctx, instr);
4185 combine_mad_mix(ctx, instr);
4186 while (apply_omod_clamp(ctx, instr) || combine_output_conversion(ctx, instr))
4187 ;
4188 apply_insert(ctx, instr);
4189 }
4190
4191 if (instr->isVOP3P() && instr->opcode != aco_opcode::v_fma_mix_f32 &&
4192 instr->opcode != aco_opcode::v_fma_mixlo_f16)
4193 return combine_vop3p(ctx, instr);
4194
4195 if (instr->isSDWA() || instr->isDPP())
4196 return;
4197
4198 if (instr->opcode == aco_opcode::p_extract) {
4199 ssa_info& info = ctx.info[instr->operands[0].tempId()];
4200 if (info.is_extract() && can_apply_extract(ctx, instr, 0, info)) {
4201 apply_extract(ctx, instr, 0, info);
4202 if (--ctx.uses[instr->operands[0].tempId()])
4203 ctx.uses[info.instr->operands[0].tempId()]++;
4204 instr->operands[0].setTemp(info.instr->operands[0].getTemp());
4205 }
4206
4207 apply_ds_extract(ctx, instr);
4208 }
4209
4210 if (instr->isVOPC()) {
4211 if (optimize_cmp_subgroup_invocation(ctx, instr))
4212 return;
4213 }
4214
4215 /* TODO: There are still some peephole optimizations that could be done:
4216 * - abs(a - b) -> s_absdiff_i32
4217 * - various patterns for s_bitcmp{0,1}_b32 and s_bitset{0,1}_b32
4218 * - patterns for v_alignbit_b32 and v_alignbyte_b32
4219 * These aren't probably too interesting though.
4220 * There are also patterns for v_cmp_class_f{16,32,64}. This is difficult but
4221 * probably more useful than the previously mentioned optimizations.
4222 * The various comparison optimizations also currently only work with 32-bit
4223 * floats. */
4224
4225 /* neg(mul(a, b)) -> mul(neg(a), b), abs(mul(a, b)) -> mul(abs(a), abs(b)) */
4226 if ((ctx.info[instr->definitions[0].tempId()].label & (label_neg | label_abs)) &&
4227 ctx.uses[instr->operands[1].tempId()] == 1) {
4228 Temp val = ctx.info[instr->definitions[0].tempId()].temp;
4229
4230 if (!ctx.info[val.id()].is_mul())
4231 return;
4232
4233 Instruction* mul_instr = ctx.info[val.id()].instr;
4234
4235 if (mul_instr->operands[0].isLiteral())
4236 return;
4237 if (mul_instr->valu().clamp)
4238 return;
4239 if (mul_instr->isSDWA() || mul_instr->isDPP())
4240 return;
4241 if (mul_instr->opcode == aco_opcode::v_mul_legacy_f32 &&
4242 ctx.fp_mode.preserve_signed_zero_inf_nan32)
4243 return;
4244 if (mul_instr->definitions[0].bytes() != instr->definitions[0].bytes())
4245 return;
4246
4247 /* convert to mul(neg(a), b), mul(abs(a), abs(b)) or mul(neg(abs(a)), abs(b)) */
4248 ctx.uses[mul_instr->definitions[0].tempId()]--;
4249 Definition def = instr->definitions[0];
4250 bool is_neg = ctx.info[instr->definitions[0].tempId()].is_neg();
4251 bool is_abs = ctx.info[instr->definitions[0].tempId()].is_abs();
4252 uint32_t pass_flags = instr->pass_flags;
4253 Format format = mul_instr->format == Format::VOP2 ? asVOP3(Format::VOP2) : mul_instr->format;
4254 instr.reset(create_instruction<VALU_instruction>(mul_instr->opcode, format,
4255 mul_instr->operands.size(), 1));
4256 std::copy(mul_instr->operands.cbegin(), mul_instr->operands.cend(), instr->operands.begin());
4257 instr->pass_flags = pass_flags;
4258 instr->definitions[0] = def;
4259 VALU_instruction& new_mul = instr->valu();
4260 VALU_instruction& mul = mul_instr->valu();
4261 new_mul.neg = mul.neg;
4262 new_mul.abs = mul.abs;
4263 new_mul.omod = mul.omod;
4264 new_mul.opsel = mul.opsel;
4265 new_mul.opsel_lo = mul.opsel_lo;
4266 new_mul.opsel_hi = mul.opsel_hi;
4267 if (is_abs) {
4268 new_mul.neg[0] = new_mul.neg[1] = false;
4269 new_mul.abs[0] = new_mul.abs[1] = true;
4270 }
4271 new_mul.neg[0] ^= is_neg;
4272 new_mul.clamp = false;
4273
4274 ctx.info[instr->definitions[0].tempId()].set_mul(instr.get());
4275 return;
4276 }
4277
4278 /* combine mul+add -> mad */
4279 bool is_add_mix =
4280 (instr->opcode == aco_opcode::v_fma_mix_f32 ||
4281 instr->opcode == aco_opcode::v_fma_mixlo_f16) &&
4282 !instr->valu().neg_lo[0] &&
4283 ((instr->operands[0].constantEquals(0x3f800000) && !instr->valu().opsel_hi[0]) ||
4284 (instr->operands[0].constantEquals(0x3C00) && instr->valu().opsel_hi[0] &&
4285 !instr->valu().opsel_lo[0]));
4286 bool mad32 = instr->opcode == aco_opcode::v_add_f32 || instr->opcode == aco_opcode::v_sub_f32 ||
4287 instr->opcode == aco_opcode::v_subrev_f32;
4288 bool mad16 = instr->opcode == aco_opcode::v_add_f16 || instr->opcode == aco_opcode::v_sub_f16 ||
4289 instr->opcode == aco_opcode::v_subrev_f16;
4290 bool mad64 = instr->opcode == aco_opcode::v_add_f64;
4291 if (is_add_mix || mad16 || mad32 || mad64) {
4292 Instruction* mul_instr = nullptr;
4293 unsigned add_op_idx = 0;
4294 uint32_t uses = UINT32_MAX;
4295 bool emit_fma = false;
4296 /* find the 'best' mul instruction to combine with the add */
4297 for (unsigned i = is_add_mix ? 1 : 0; i < instr->operands.size(); i++) {
4298 if (!instr->operands[i].isTemp() || !ctx.info[instr->operands[i].tempId()].is_mul())
4299 continue;
4300 ssa_info& info = ctx.info[instr->operands[i].tempId()];
4301
4302 /* no clamp/omod allowed between mul and add */
4303 if (info.instr->isVOP3() && (info.instr->valu().clamp || info.instr->valu().omod))
4304 continue;
4305 if (info.instr->isVOP3P() && info.instr->valu().clamp)
4306 continue;
4307 /* v_fma_mix_f32/etc can't do omod */
4308 if (info.instr->isVOP3P() && instr->isVOP3() && instr->valu().omod)
4309 continue;
4310 /* don't promote fp16 to fp32 or remove fp32->fp16->fp32 conversions */
4311 if (is_add_mix && info.instr->definitions[0].bytes() == 2)
4312 continue;
4313
4314 if (get_operand_size(instr, i) != info.instr->definitions[0].bytes() * 8)
4315 continue;
4316
4317 bool legacy = info.instr->opcode == aco_opcode::v_mul_legacy_f32;
4318 bool mad_mix = is_add_mix || info.instr->isVOP3P();
4319
4320 /* Multiplication by power-of-two should never need rounding. 1/power-of-two also works,
4321 * but using fma removes denormal flushing (0xfffffe * 0.5 + 0x810001a2).
4322 */
4323 bool is_fma_precise = is_pow_of_two(ctx, info.instr->operands[0]) ||
4324 is_pow_of_two(ctx, info.instr->operands[1]);
4325
4326 bool has_fma = mad16 || mad64 || (legacy && ctx.program->gfx_level >= GFX10_3) ||
4327 (mad32 && !legacy && !mad_mix && ctx.program->dev.has_fast_fma32) ||
4328 (mad_mix && ctx.program->dev.fused_mad_mix);
4329 bool has_mad = mad_mix ? !ctx.program->dev.fused_mad_mix
4330 : ((mad32 && ctx.program->gfx_level < GFX10_3) ||
4331 (mad16 && ctx.program->gfx_level <= GFX9));
4332 bool can_use_fma =
4333 has_fma &&
4334 (!(info.instr->definitions[0].isPrecise() || instr->definitions[0].isPrecise()) ||
4335 is_fma_precise);
4336 bool can_use_mad =
4337 has_mad && (mad_mix || mad32 ? ctx.fp_mode.denorm32 : ctx.fp_mode.denorm16_64) == 0;
4338 if (mad_mix && legacy)
4339 continue;
4340 if (!can_use_fma && !can_use_mad)
4341 continue;
4342
4343 unsigned candidate_add_op_idx = is_add_mix ? (3 - i) : (1 - i);
4344 Operand op[3] = {info.instr->operands[0], info.instr->operands[1],
4345 instr->operands[candidate_add_op_idx]};
4346 if (info.instr->isSDWA() || info.instr->isDPP() || !check_vop3_operands(ctx, 3, op) ||
4347 ctx.uses[instr->operands[i].tempId()] > uses)
4348 continue;
4349
4350 if (ctx.uses[instr->operands[i].tempId()] == uses) {
4351 unsigned cur_idx = mul_instr->definitions[0].tempId();
4352 unsigned new_idx = info.instr->definitions[0].tempId();
4353 if (cur_idx > new_idx)
4354 continue;
4355 }
4356
4357 mul_instr = info.instr;
4358 add_op_idx = candidate_add_op_idx;
4359 uses = ctx.uses[instr->operands[i].tempId()];
4360 emit_fma = !can_use_mad;
4361 }
4362
4363 if (mul_instr) {
4364 /* turn mul+add into v_mad/v_fma */
4365 Operand op[3] = {mul_instr->operands[0], mul_instr->operands[1],
4366 instr->operands[add_op_idx]};
4367 ctx.uses[mul_instr->definitions[0].tempId()]--;
4368 if (ctx.uses[mul_instr->definitions[0].tempId()]) {
4369 if (op[0].isTemp())
4370 ctx.uses[op[0].tempId()]++;
4371 if (op[1].isTemp())
4372 ctx.uses[op[1].tempId()]++;
4373 }
4374
4375 bool neg[3] = {false, false, false};
4376 bool abs[3] = {false, false, false};
4377 unsigned omod = 0;
4378 bool clamp = false;
4379 bitarray8 opsel_lo = 0;
4380 bitarray8 opsel_hi = 0;
4381 bitarray8 opsel = 0;
4382 unsigned mul_op_idx = (instr->isVOP3P() ? 3 : 1) - add_op_idx;
4383
4384 VALU_instruction& valu_mul = mul_instr->valu();
4385 neg[0] = valu_mul.neg[0];
4386 neg[1] = valu_mul.neg[1];
4387 abs[0] = valu_mul.abs[0];
4388 abs[1] = valu_mul.abs[1];
4389 opsel_lo = valu_mul.opsel_lo & 0x3;
4390 opsel_hi = valu_mul.opsel_hi & 0x3;
4391 opsel = valu_mul.opsel & 0x3;
4392
4393 VALU_instruction& valu = instr->valu();
4394 neg[2] = valu.neg[add_op_idx];
4395 abs[2] = valu.abs[add_op_idx];
4396 opsel_lo[2] = valu.opsel_lo[add_op_idx];
4397 opsel_hi[2] = valu.opsel_hi[add_op_idx];
4398 opsel[2] = valu.opsel[add_op_idx];
4399 opsel[3] = valu.opsel[3];
4400 omod = valu.omod;
4401 clamp = valu.clamp;
4402 /* abs of the multiplication result */
4403 if (valu.abs[mul_op_idx]) {
4404 neg[0] = false;
4405 neg[1] = false;
4406 abs[0] = true;
4407 abs[1] = true;
4408 }
4409 /* neg of the multiplication result */
4410 neg[1] ^= valu.neg[mul_op_idx];
4411
4412 if (instr->opcode == aco_opcode::v_sub_f32 || instr->opcode == aco_opcode::v_sub_f16)
4413 neg[1 + add_op_idx] = neg[1 + add_op_idx] ^ true;
4414 else if (instr->opcode == aco_opcode::v_subrev_f32 ||
4415 instr->opcode == aco_opcode::v_subrev_f16)
4416 neg[2 - add_op_idx] = neg[2 - add_op_idx] ^ true;
4417
4418 aco_ptr<Instruction> add_instr = std::move(instr);
4419 aco_ptr<VALU_instruction> mad;
4420 if (add_instr->isVOP3P() || mul_instr->isVOP3P()) {
4421 assert(!omod);
4422 assert(!opsel);
4423
4424 aco_opcode mad_op = add_instr->definitions[0].bytes() == 2 ? aco_opcode::v_fma_mixlo_f16
4425 : aco_opcode::v_fma_mix_f32;
4426 mad.reset(create_instruction<VALU_instruction>(mad_op, Format::VOP3P, 3, 1));
4427 } else {
4428 assert(!opsel_lo);
4429 assert(!opsel_hi);
4430
4431 aco_opcode mad_op = emit_fma ? aco_opcode::v_fma_f32 : aco_opcode::v_mad_f32;
4432 if (mul_instr->opcode == aco_opcode::v_mul_legacy_f32) {
4433 assert(emit_fma == (ctx.program->gfx_level >= GFX10_3));
4434 mad_op = emit_fma ? aco_opcode::v_fma_legacy_f32 : aco_opcode::v_mad_legacy_f32;
4435 } else if (mad16) {
4436 mad_op = emit_fma ? (ctx.program->gfx_level == GFX8 ? aco_opcode::v_fma_legacy_f16
4437 : aco_opcode::v_fma_f16)
4438 : (ctx.program->gfx_level == GFX8 ? aco_opcode::v_mad_legacy_f16
4439 : aco_opcode::v_mad_f16);
4440 } else if (mad64) {
4441 mad_op = aco_opcode::v_fma_f64;
4442 }
4443
4444 mad.reset(create_instruction<VALU_instruction>(mad_op, Format::VOP3, 3, 1));
4445 }
4446
4447 for (unsigned i = 0; i < 3; i++) {
4448 mad->operands[i] = op[i];
4449 mad->neg[i] = neg[i];
4450 mad->abs[i] = abs[i];
4451 }
4452 mad->omod = omod;
4453 mad->clamp = clamp;
4454 mad->opsel_lo = opsel_lo;
4455 mad->opsel_hi = opsel_hi;
4456 mad->opsel = opsel;
4457 mad->definitions[0] = add_instr->definitions[0];
4458 mad->definitions[0].setPrecise(add_instr->definitions[0].isPrecise() ||
4459 mul_instr->definitions[0].isPrecise());
4460 mad->pass_flags = add_instr->pass_flags;
4461
4462 instr = std::move(mad);
4463
4464 /* mark this ssa_def to be re-checked for profitability and literals */
4465 ctx.mad_infos.emplace_back(std::move(add_instr), mul_instr->definitions[0].tempId());
4466 ctx.info[instr->definitions[0].tempId()].set_mad(ctx.mad_infos.size() - 1);
4467 return;
4468 }
4469 }
4470 /* v_mul_f32(v_cndmask_b32(0, 1.0, cond), a) -> v_cndmask_b32(0, a, cond) */
4471 else if (((instr->opcode == aco_opcode::v_mul_f32 &&
4472 !ctx.fp_mode.preserve_signed_zero_inf_nan32) ||
4473 instr->opcode == aco_opcode::v_mul_legacy_f32) &&
4474 !instr->usesModifiers() && !ctx.fp_mode.must_flush_denorms32) {
4475 for (unsigned i = 0; i < 2; i++) {
4476 if (instr->operands[i].isTemp() && ctx.info[instr->operands[i].tempId()].is_b2f() &&
4477 ctx.uses[instr->operands[i].tempId()] == 1 && instr->operands[!i].isTemp() &&
4478 instr->operands[!i].getTemp().type() == RegType::vgpr) {
4479 ctx.uses[instr->operands[i].tempId()]--;
4480 ctx.uses[ctx.info[instr->operands[i].tempId()].temp.id()]++;
4481
4482 aco_ptr<VALU_instruction> new_instr{
4483 create_instruction<VALU_instruction>(aco_opcode::v_cndmask_b32, Format::VOP2, 3, 1)};
4484 new_instr->operands[0] = Operand::zero();
4485 new_instr->operands[1] = instr->operands[!i];
4486 new_instr->operands[2] = Operand(ctx.info[instr->operands[i].tempId()].temp);
4487 new_instr->definitions[0] = instr->definitions[0];
4488 new_instr->pass_flags = instr->pass_flags;
4489 instr = std::move(new_instr);
4490 ctx.info[instr->definitions[0].tempId()].label = 0;
4491 return;
4492 }
4493 }
4494 } else if (instr->opcode == aco_opcode::v_or_b32 && ctx.program->gfx_level >= GFX9) {
4495 if (combine_three_valu_op(ctx, instr, aco_opcode::s_or_b32, aco_opcode::v_or3_b32, "012",
4496 1 | 2)) {
4497 } else if (combine_three_valu_op(ctx, instr, aco_opcode::v_or_b32, aco_opcode::v_or3_b32,
4498 "012", 1 | 2)) {
4499 } else if (combine_add_or_then_and_lshl(ctx, instr)) {
4500 } else if (combine_v_andor_not(ctx, instr)) {
4501 }
4502 } else if (instr->opcode == aco_opcode::v_xor_b32 && ctx.program->gfx_level >= GFX10) {
4503 if (combine_three_valu_op(ctx, instr, aco_opcode::v_xor_b32, aco_opcode::v_xor3_b32, "012",
4504 1 | 2)) {
4505 } else if (combine_three_valu_op(ctx, instr, aco_opcode::s_xor_b32, aco_opcode::v_xor3_b32,
4506 "012", 1 | 2)) {
4507 } else if (combine_xor_not(ctx, instr)) {
4508 }
4509 } else if (instr->opcode == aco_opcode::v_not_b32 && ctx.program->gfx_level >= GFX10) {
4510 combine_not_xor(ctx, instr);
4511 } else if (instr->opcode == aco_opcode::v_add_u16) {
4512 combine_three_valu_op(
4513 ctx, instr, aco_opcode::v_mul_lo_u16,
4514 ctx.program->gfx_level == GFX8 ? aco_opcode::v_mad_legacy_u16 : aco_opcode::v_mad_u16,
4515 "120", 1 | 2);
4516 } else if (instr->opcode == aco_opcode::v_add_u16_e64) {
4517 combine_three_valu_op(ctx, instr, aco_opcode::v_mul_lo_u16_e64, aco_opcode::v_mad_u16, "120",
4518 1 | 2);
4519 } else if (instr->opcode == aco_opcode::v_add_u32) {
4520 if (combine_add_sub_b2i(ctx, instr, aco_opcode::v_addc_co_u32, 1 | 2)) {
4521 } else if (combine_add_bcnt(ctx, instr)) {
4522 } else if (combine_three_valu_op(ctx, instr, aco_opcode::v_mul_u32_u24,
4523 aco_opcode::v_mad_u32_u24, "120", 1 | 2)) {
4524 } else if (ctx.program->gfx_level >= GFX9 && !instr->usesModifiers()) {
4525 if (combine_three_valu_op(ctx, instr, aco_opcode::s_xor_b32, aco_opcode::v_xad_u32, "120",
4526 1 | 2)) {
4527 } else if (combine_three_valu_op(ctx, instr, aco_opcode::v_xor_b32, aco_opcode::v_xad_u32,
4528 "120", 1 | 2)) {
4529 } else if (combine_three_valu_op(ctx, instr, aco_opcode::s_add_i32, aco_opcode::v_add3_u32,
4530 "012", 1 | 2)) {
4531 } else if (combine_three_valu_op(ctx, instr, aco_opcode::s_add_u32, aco_opcode::v_add3_u32,
4532 "012", 1 | 2)) {
4533 } else if (combine_three_valu_op(ctx, instr, aco_opcode::v_add_u32, aco_opcode::v_add3_u32,
4534 "012", 1 | 2)) {
4535 } else if (combine_add_or_then_and_lshl(ctx, instr)) {
4536 }
4537 }
4538 } else if (instr->opcode == aco_opcode::v_add_co_u32 ||
4539 instr->opcode == aco_opcode::v_add_co_u32_e64) {
4540 bool carry_out = ctx.uses[instr->definitions[1].tempId()] > 0;
4541 if (combine_add_sub_b2i(ctx, instr, aco_opcode::v_addc_co_u32, 1 | 2)) {
4542 } else if (!carry_out && combine_add_bcnt(ctx, instr)) {
4543 } else if (!carry_out && combine_three_valu_op(ctx, instr, aco_opcode::v_mul_u32_u24,
4544 aco_opcode::v_mad_u32_u24, "120", 1 | 2)) {
4545 } else if (!carry_out && combine_add_lshl(ctx, instr, false)) {
4546 }
4547 } else if (instr->opcode == aco_opcode::v_sub_u32 || instr->opcode == aco_opcode::v_sub_co_u32 ||
4548 instr->opcode == aco_opcode::v_sub_co_u32_e64) {
4549 bool carry_out =
4550 instr->opcode != aco_opcode::v_sub_u32 && ctx.uses[instr->definitions[1].tempId()] > 0;
4551 if (combine_add_sub_b2i(ctx, instr, aco_opcode::v_subbrev_co_u32, 2)) {
4552 } else if (!carry_out && combine_add_lshl(ctx, instr, true)) {
4553 }
4554 } else if (instr->opcode == aco_opcode::v_subrev_u32 ||
4555 instr->opcode == aco_opcode::v_subrev_co_u32 ||
4556 instr->opcode == aco_opcode::v_subrev_co_u32_e64) {
4557 combine_add_sub_b2i(ctx, instr, aco_opcode::v_subbrev_co_u32, 1);
4558 } else if (instr->opcode == aco_opcode::v_lshlrev_b32 && ctx.program->gfx_level >= GFX9) {
4559 combine_three_valu_op(ctx, instr, aco_opcode::v_add_u32, aco_opcode::v_add_lshl_u32, "120",
4560 2);
4561 } else if ((instr->opcode == aco_opcode::s_add_u32 || instr->opcode == aco_opcode::s_add_i32) &&
4562 ctx.program->gfx_level >= GFX9) {
4563 combine_salu_lshl_add(ctx, instr);
4564 } else if (instr->opcode == aco_opcode::s_not_b32 || instr->opcode == aco_opcode::s_not_b64) {
4565 if (!combine_salu_not_bitwise(ctx, instr))
4566 combine_inverse_comparison(ctx, instr);
4567 } else if (instr->opcode == aco_opcode::s_and_b32 || instr->opcode == aco_opcode::s_or_b32 ||
4568 instr->opcode == aco_opcode::s_and_b64 || instr->opcode == aco_opcode::s_or_b64) {
4569 if (combine_ordering_test(ctx, instr)) {
4570 } else if (combine_comparison_ordering(ctx, instr)) {
4571 } else if (combine_constant_comparison_ordering(ctx, instr)) {
4572 } else if (combine_salu_n2(ctx, instr)) {
4573 }
4574 } else if (instr->opcode == aco_opcode::s_abs_i32) {
4575 combine_sabsdiff(ctx, instr);
4576 } else if (instr->opcode == aco_opcode::v_and_b32) {
4577 if (combine_and_subbrev(ctx, instr)) {
4578 } else if (combine_v_andor_not(ctx, instr)) {
4579 }
4580 } else if (instr->opcode == aco_opcode::v_fma_f32 || instr->opcode == aco_opcode::v_fma_f16) {
4581 /* set existing v_fma_f32 with label_mad so we can create v_fmamk_f32/v_fmaak_f32.
4582 * since ctx.uses[mad_info::mul_temp_id] is always 0, we don't have to worry about
4583 * select_instruction() using mad_info::add_instr.
4584 */
4585 ctx.mad_infos.emplace_back(nullptr, 0);
4586 ctx.info[instr->definitions[0].tempId()].set_mad(ctx.mad_infos.size() - 1);
4587 } else if (instr->opcode == aco_opcode::v_med3_f32 || instr->opcode == aco_opcode::v_med3_f16) {
4588 /* Optimize v_med3 to v_add so that it can be dual issued on GFX11. We start with v_med3 in
4589 * case omod can be applied.
4590 */
4591 unsigned idx;
4592 if (detect_clamp(instr.get(), &idx)) {
4593 instr->format = asVOP3(Format::VOP2);
4594 instr->operands[0] = instr->operands[idx];
4595 instr->operands[1] = Operand::zero();
4596 instr->opcode =
4597 instr->opcode == aco_opcode::v_med3_f32 ? aco_opcode::v_add_f32 : aco_opcode::v_add_f16;
4598 instr->valu().clamp = true;
4599 instr->valu().abs = (uint8_t)instr->valu().abs[idx];
4600 instr->valu().neg = (uint8_t)instr->valu().neg[idx];
4601 instr->operands.pop_back();
4602 }
4603 } else {
4604 aco_opcode min, max, min3, max3, med3, minmax;
4605 bool some_gfx9_only;
4606 if (get_minmax_info(instr->opcode, &min, &max, &min3, &max3, &med3, &minmax,
4607 &some_gfx9_only) &&
4608 (!some_gfx9_only || ctx.program->gfx_level >= GFX9)) {
4609 if (combine_minmax(ctx, instr, instr->opcode == min ? max : min,
4610 instr->opcode == min ? min3 : max3, minmax)) {
4611 } else {
4612 combine_clamp(ctx, instr, min, max, med3);
4613 }
4614 }
4615 }
4616 }
4617
4618 bool
to_uniform_bool_instr(opt_ctx & ctx,aco_ptr<Instruction> & instr)4619 to_uniform_bool_instr(opt_ctx& ctx, aco_ptr<Instruction>& instr)
4620 {
4621 /* Check every operand to make sure they are suitable. */
4622 for (Operand& op : instr->operands) {
4623 if (!op.isTemp())
4624 return false;
4625 if (!ctx.info[op.tempId()].is_uniform_bool() && !ctx.info[op.tempId()].is_uniform_bitwise())
4626 return false;
4627 }
4628
4629 switch (instr->opcode) {
4630 case aco_opcode::s_and_b32:
4631 case aco_opcode::s_and_b64: instr->opcode = aco_opcode::s_and_b32; break;
4632 case aco_opcode::s_or_b32:
4633 case aco_opcode::s_or_b64: instr->opcode = aco_opcode::s_or_b32; break;
4634 case aco_opcode::s_xor_b32:
4635 case aco_opcode::s_xor_b64: instr->opcode = aco_opcode::s_absdiff_i32; break;
4636 default:
4637 /* Don't transform other instructions. They are very unlikely to appear here. */
4638 return false;
4639 }
4640
4641 for (Operand& op : instr->operands) {
4642 ctx.uses[op.tempId()]--;
4643
4644 if (ctx.info[op.tempId()].is_uniform_bool()) {
4645 /* Just use the uniform boolean temp. */
4646 op.setTemp(ctx.info[op.tempId()].temp);
4647 } else if (ctx.info[op.tempId()].is_uniform_bitwise()) {
4648 /* Use the SCC definition of the predecessor instruction.
4649 * This allows the predecessor to get picked up by the same optimization (if it has no
4650 * divergent users), and it also makes sure that the current instruction will keep working
4651 * even if the predecessor won't be transformed.
4652 */
4653 Instruction* pred_instr = ctx.info[op.tempId()].instr;
4654 assert(pred_instr->definitions.size() >= 2);
4655 assert(pred_instr->definitions[1].isFixed() &&
4656 pred_instr->definitions[1].physReg() == scc);
4657 op.setTemp(pred_instr->definitions[1].getTemp());
4658 } else {
4659 unreachable("Invalid operand on uniform bitwise instruction.");
4660 }
4661
4662 ctx.uses[op.tempId()]++;
4663 }
4664
4665 instr->definitions[0].setTemp(Temp(instr->definitions[0].tempId(), s1));
4666 assert(instr->operands[0].regClass() == s1);
4667 assert(instr->operands[1].regClass() == s1);
4668 return true;
4669 }
4670
4671 void
select_instruction(opt_ctx & ctx,aco_ptr<Instruction> & instr)4672 select_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
4673 {
4674 const uint32_t threshold = 4;
4675
4676 if (is_dead(ctx.uses, instr.get())) {
4677 instr.reset();
4678 return;
4679 }
4680
4681 /* convert split_vector into a copy or extract_vector if only one definition is ever used */
4682 if (instr->opcode == aco_opcode::p_split_vector) {
4683 unsigned num_used = 0;
4684 unsigned idx = 0;
4685 unsigned split_offset = 0;
4686 for (unsigned i = 0, offset = 0; i < instr->definitions.size();
4687 offset += instr->definitions[i++].bytes()) {
4688 if (ctx.uses[instr->definitions[i].tempId()]) {
4689 num_used++;
4690 idx = i;
4691 split_offset = offset;
4692 }
4693 }
4694 bool done = false;
4695 if (num_used == 1 && ctx.info[instr->operands[0].tempId()].is_vec() &&
4696 ctx.uses[instr->operands[0].tempId()] == 1) {
4697 Instruction* vec = ctx.info[instr->operands[0].tempId()].instr;
4698
4699 unsigned off = 0;
4700 Operand op;
4701 for (Operand& vec_op : vec->operands) {
4702 if (off == split_offset) {
4703 op = vec_op;
4704 break;
4705 }
4706 off += vec_op.bytes();
4707 }
4708 if (off != instr->operands[0].bytes() && op.bytes() == instr->definitions[idx].bytes()) {
4709 ctx.uses[instr->operands[0].tempId()]--;
4710 for (Operand& vec_op : vec->operands) {
4711 if (vec_op.isTemp())
4712 ctx.uses[vec_op.tempId()]--;
4713 }
4714 if (op.isTemp())
4715 ctx.uses[op.tempId()]++;
4716
4717 aco_ptr<Pseudo_instruction> extract{create_instruction<Pseudo_instruction>(
4718 aco_opcode::p_create_vector, Format::PSEUDO, 1, 1)};
4719 extract->operands[0] = op;
4720 extract->definitions[0] = instr->definitions[idx];
4721 instr = std::move(extract);
4722
4723 done = true;
4724 }
4725 }
4726
4727 if (!done && num_used == 1 &&
4728 instr->operands[0].bytes() % instr->definitions[idx].bytes() == 0 &&
4729 split_offset % instr->definitions[idx].bytes() == 0) {
4730 aco_ptr<Pseudo_instruction> extract{create_instruction<Pseudo_instruction>(
4731 aco_opcode::p_extract_vector, Format::PSEUDO, 2, 1)};
4732 extract->operands[0] = instr->operands[0];
4733 extract->operands[1] =
4734 Operand::c32((uint32_t)split_offset / instr->definitions[idx].bytes());
4735 extract->definitions[0] = instr->definitions[idx];
4736 instr = std::move(extract);
4737 }
4738 }
4739
4740 mad_info* mad_info = NULL;
4741 if (!instr->definitions.empty() && ctx.info[instr->definitions[0].tempId()].is_mad()) {
4742 mad_info = &ctx.mad_infos[ctx.info[instr->definitions[0].tempId()].val];
4743 /* re-check mad instructions */
4744 if (ctx.uses[mad_info->mul_temp_id] && mad_info->add_instr) {
4745 ctx.uses[mad_info->mul_temp_id]++;
4746 if (instr->operands[0].isTemp())
4747 ctx.uses[instr->operands[0].tempId()]--;
4748 if (instr->operands[1].isTemp())
4749 ctx.uses[instr->operands[1].tempId()]--;
4750 instr.swap(mad_info->add_instr);
4751 mad_info = NULL;
4752 }
4753 /* check literals */
4754 else if (!instr->isDPP() && !instr->isVOP3P() && instr->opcode != aco_opcode::v_fma_f64 &&
4755 instr->opcode != aco_opcode::v_mad_legacy_f32 &&
4756 instr->opcode != aco_opcode::v_fma_legacy_f32) {
4757 /* FMA can only take literals on GFX10+ */
4758 if ((instr->opcode == aco_opcode::v_fma_f32 || instr->opcode == aco_opcode::v_fma_f16) &&
4759 ctx.program->gfx_level < GFX10)
4760 return;
4761 /* There are no v_fmaak_legacy_f16/v_fmamk_legacy_f16 and on chips where VOP3 can take
4762 * literals (GFX10+), these instructions don't exist.
4763 */
4764 if (instr->opcode == aco_opcode::v_fma_legacy_f16)
4765 return;
4766
4767 uint32_t literal_mask = 0;
4768 uint32_t fp16_mask = 0;
4769 uint32_t sgpr_mask = 0;
4770 uint32_t vgpr_mask = 0;
4771 uint32_t literal_uses = UINT32_MAX;
4772 uint32_t literal_value = 0;
4773
4774 /* Iterate in reverse to prefer v_madak/v_fmaak. */
4775 for (int i = 2; i >= 0; i--) {
4776 Operand& op = instr->operands[i];
4777 if (!op.isTemp())
4778 continue;
4779 if (ctx.info[op.tempId()].is_literal(get_operand_size(instr, i))) {
4780 uint32_t new_literal = ctx.info[op.tempId()].val;
4781 float value = uif(new_literal);
4782 uint16_t fp16_val = _mesa_float_to_half(value);
4783 bool is_denorm = (fp16_val & 0x7fff) != 0 && (fp16_val & 0x7fff) <= 0x3ff;
4784 if (_mesa_half_to_float(fp16_val) == value &&
4785 (!is_denorm || (ctx.fp_mode.denorm16_64 & fp_denorm_keep_in)))
4786 fp16_mask |= 1 << i;
4787
4788 if (!literal_mask || literal_value == new_literal) {
4789 literal_value = new_literal;
4790 literal_uses = MIN2(literal_uses, ctx.uses[op.tempId()]);
4791 literal_mask |= 1 << i;
4792 continue;
4793 }
4794 }
4795 sgpr_mask |= op.isOfType(RegType::sgpr) << i;
4796 vgpr_mask |= op.isOfType(RegType::vgpr) << i;
4797 }
4798
4799 /* The constant bus limitations before GFX10 disallows SGPRs. */
4800 if (sgpr_mask && ctx.program->gfx_level < GFX10)
4801 literal_mask = 0;
4802
4803 /* Encoding needs a vgpr. */
4804 if (!vgpr_mask)
4805 literal_mask = 0;
4806
4807 /* v_madmk/v_fmamk needs a vgpr in the third source. */
4808 if (!(literal_mask & 0b100) && !(vgpr_mask & 0b100))
4809 literal_mask = 0;
4810
4811 /* opsel with GFX11+ is the only modifier supported by fmamk/fmaak*/
4812 if (instr->valu().abs || instr->valu().neg || instr->valu().omod || instr->valu().clamp ||
4813 (instr->valu().opsel && ctx.program->gfx_level < GFX11))
4814 literal_mask = 0;
4815
4816 if (instr->valu().opsel & ~vgpr_mask)
4817 literal_mask = 0;
4818
4819 /* We can't use three unique fp16 literals */
4820 if (fp16_mask == 0b111)
4821 fp16_mask = 0b11;
4822
4823 if ((instr->opcode == aco_opcode::v_fma_f32 ||
4824 (instr->opcode == aco_opcode::v_mad_f32 && !instr->definitions[0].isPrecise())) &&
4825 !instr->valu().omod && ctx.program->gfx_level >= GFX10 &&
4826 util_bitcount(fp16_mask) > std::max<uint32_t>(util_bitcount(literal_mask), 1)) {
4827 assert(ctx.program->dev.fused_mad_mix);
4828 u_foreach_bit (i, fp16_mask)
4829 ctx.uses[instr->operands[i].tempId()]--;
4830 mad_info->fp16_mask = fp16_mask;
4831 return;
4832 }
4833
4834 /* Limit the number of literals to apply to not increase the code
4835 * size too much, but always apply literals for v_mad->v_madak
4836 * because both instructions are 64-bit and this doesn't increase
4837 * code size.
4838 * TODO: try to apply the literals earlier to lower the number of
4839 * uses below threshold
4840 */
4841 if (literal_mask && (literal_uses < threshold || (literal_mask & 0b100))) {
4842 u_foreach_bit (i, literal_mask)
4843 ctx.uses[instr->operands[i].tempId()]--;
4844 mad_info->literal_mask = literal_mask;
4845 return;
4846 }
4847 }
4848 }
4849
4850 /* Mark SCC needed, so the uniform boolean transformation won't swap the definitions
4851 * when it isn't beneficial */
4852 if (instr->isBranch() && instr->operands.size() && instr->operands[0].isTemp() &&
4853 instr->operands[0].isFixed() && instr->operands[0].physReg() == scc) {
4854 ctx.info[instr->operands[0].tempId()].set_scc_needed();
4855 return;
4856 } else if ((instr->opcode == aco_opcode::s_cselect_b64 ||
4857 instr->opcode == aco_opcode::s_cselect_b32) &&
4858 instr->operands[2].isTemp()) {
4859 ctx.info[instr->operands[2].tempId()].set_scc_needed();
4860 }
4861
4862 /* check for literals */
4863 if (!instr->isSALU() && !instr->isVALU())
4864 return;
4865
4866 /* Transform uniform bitwise boolean operations to 32-bit when there are no divergent uses. */
4867 if (instr->definitions.size() && ctx.uses[instr->definitions[0].tempId()] == 0 &&
4868 ctx.info[instr->definitions[0].tempId()].is_uniform_bitwise()) {
4869 bool transform_done = to_uniform_bool_instr(ctx, instr);
4870
4871 if (transform_done && !ctx.info[instr->definitions[1].tempId()].is_scc_needed()) {
4872 /* Swap the two definition IDs in order to avoid overusing the SCC.
4873 * This reduces extra moves generated by RA. */
4874 uint32_t def0_id = instr->definitions[0].getTemp().id();
4875 uint32_t def1_id = instr->definitions[1].getTemp().id();
4876 instr->definitions[0].setTemp(Temp(def1_id, s1));
4877 instr->definitions[1].setTemp(Temp(def0_id, s1));
4878 }
4879
4880 return;
4881 }
4882
4883 /* This optimization is done late in order to be able to apply otherwise
4884 * unsafe optimizations such as the inverse comparison optimization.
4885 */
4886 if (instr->opcode == aco_opcode::s_and_b32 || instr->opcode == aco_opcode::s_and_b64) {
4887 if (instr->operands[0].isTemp() && fixed_to_exec(instr->operands[1]) &&
4888 ctx.uses[instr->operands[0].tempId()] == 1 &&
4889 ctx.uses[instr->definitions[1].tempId()] == 0 &&
4890 can_eliminate_and_exec(ctx, instr->operands[0].getTemp(), instr->pass_flags)) {
4891 ctx.uses[instr->operands[0].tempId()]--;
4892 ctx.info[instr->operands[0].tempId()].instr->definitions[0].setTemp(
4893 instr->definitions[0].getTemp());
4894 instr.reset();
4895 return;
4896 }
4897 }
4898
4899 /* Combine DPP copies into VALU. This should be done after creating MAD/FMA. */
4900 if (instr->isVALU() && !instr->isDPP()) {
4901 for (unsigned i = 0; i < instr->operands.size(); i++) {
4902 if (!instr->operands[i].isTemp())
4903 continue;
4904 ssa_info info = ctx.info[instr->operands[i].tempId()];
4905
4906 if (!info.is_dpp() || info.instr->pass_flags != instr->pass_flags)
4907 continue;
4908
4909 /* We won't eliminate the DPP mov if the operand is used twice */
4910 bool op_used_twice = false;
4911 for (unsigned j = 0; j < instr->operands.size(); j++)
4912 op_used_twice |= i != j && instr->operands[i] == instr->operands[j];
4913 if (op_used_twice)
4914 continue;
4915
4916 if (i != 0) {
4917 if (!can_swap_operands(instr, &instr->opcode, 0, i))
4918 continue;
4919 instr->valu().swapOperands(0, i);
4920 }
4921
4922 if (!can_use_DPP(ctx.program->gfx_level, instr, info.is_dpp8()))
4923 continue;
4924
4925 bool dpp8 = info.is_dpp8();
4926 bool input_mods = can_use_input_modifiers(ctx.program->gfx_level, instr->opcode, 0) &&
4927 get_operand_size(instr, 0) == 32;
4928 bool mov_uses_mods = info.instr->valu().neg[0] || info.instr->valu().abs[0];
4929 if (((dpp8 && ctx.program->gfx_level < GFX11) || !input_mods) && mov_uses_mods)
4930 continue;
4931
4932 convert_to_DPP(ctx.program->gfx_level, instr, dpp8);
4933
4934 if (dpp8) {
4935 DPP8_instruction* dpp = &instr->dpp8();
4936 dpp->lane_sel = info.instr->dpp8().lane_sel;
4937 dpp->fetch_inactive = info.instr->dpp8().fetch_inactive;
4938 if (mov_uses_mods)
4939 instr->format = asVOP3(instr->format);
4940 } else {
4941 DPP16_instruction* dpp = &instr->dpp16();
4942 dpp->dpp_ctrl = info.instr->dpp16().dpp_ctrl;
4943 dpp->bound_ctrl = info.instr->dpp16().bound_ctrl;
4944 dpp->fetch_inactive = info.instr->dpp16().fetch_inactive;
4945 }
4946
4947 instr->valu().neg[0] ^= info.instr->valu().neg[0] && !instr->valu().abs[0];
4948 instr->valu().abs[0] |= info.instr->valu().abs[0];
4949
4950 if (--ctx.uses[info.instr->definitions[0].tempId()])
4951 ctx.uses[info.instr->operands[0].tempId()]++;
4952 instr->operands[0].setTemp(info.instr->operands[0].getTemp());
4953 break;
4954 }
4955 }
4956
4957 /* Use v_fma_mix for f2f32/f2f16 if it has higher throughput.
4958 * Do this late to not disturb other optimizations.
4959 */
4960 if ((instr->opcode == aco_opcode::v_cvt_f32_f16 || instr->opcode == aco_opcode::v_cvt_f16_f32) &&
4961 ctx.program->gfx_level >= GFX11 && ctx.program->wave_size == 64 && !instr->valu().omod &&
4962 !instr->isDPP()) {
4963 bool is_f2f16 = instr->opcode == aco_opcode::v_cvt_f16_f32;
4964 Instruction* fma = create_instruction<VALU_instruction>(
4965 is_f2f16 ? aco_opcode::v_fma_mixlo_f16 : aco_opcode::v_fma_mix_f32, Format::VOP3P, 3, 1);
4966 fma->definitions[0] = instr->definitions[0];
4967 fma->operands[0] = instr->operands[0];
4968 fma->valu().opsel_hi[0] = !is_f2f16;
4969 fma->valu().opsel_lo[0] = instr->valu().opsel[0];
4970 fma->valu().clamp = instr->valu().clamp;
4971 fma->valu().abs[0] = instr->valu().abs[0];
4972 fma->valu().neg[0] = instr->valu().neg[0];
4973 fma->operands[1] = Operand::c32(fui(1.0f));
4974 fma->operands[2] = Operand::zero();
4975 /* fma_mix is only dual issued if dst and acc type match */
4976 fma->valu().opsel_hi[2] = is_f2f16;
4977 fma->valu().neg[2] = true;
4978 instr.reset(fma);
4979 ctx.info[instr->definitions[0].tempId()].label = 0;
4980 }
4981
4982 if (instr->isSDWA() || (instr->isVOP3() && ctx.program->gfx_level < GFX10) ||
4983 (instr->isVOP3P() && ctx.program->gfx_level < GFX10))
4984 return; /* some encodings can't ever take literals */
4985
4986 /* we do not apply the literals yet as we don't know if it is profitable */
4987 Operand current_literal(s1);
4988
4989 unsigned literal_id = 0;
4990 unsigned literal_uses = UINT32_MAX;
4991 Operand literal(s1);
4992 unsigned num_operands = 1;
4993 if (instr->isSALU() || (ctx.program->gfx_level >= GFX10 &&
4994 (can_use_VOP3(ctx, instr) || instr->isVOP3P()) && !instr->isDPP()))
4995 num_operands = instr->operands.size();
4996 /* catch VOP2 with a 3rd SGPR operand (e.g. v_cndmask_b32, v_addc_co_u32) */
4997 else if (instr->isVALU() && instr->operands.size() >= 3)
4998 return;
4999
5000 unsigned sgpr_ids[2] = {0, 0};
5001 bool is_literal_sgpr = false;
5002 uint32_t mask = 0;
5003
5004 /* choose a literal to apply */
5005 for (unsigned i = 0; i < num_operands; i++) {
5006 Operand op = instr->operands[i];
5007 unsigned bits = get_operand_size(instr, i);
5008
5009 if (instr->isVALU() && op.isTemp() && op.getTemp().type() == RegType::sgpr &&
5010 op.tempId() != sgpr_ids[0])
5011 sgpr_ids[!!sgpr_ids[0]] = op.tempId();
5012
5013 if (op.isLiteral()) {
5014 current_literal = op;
5015 continue;
5016 } else if (!op.isTemp() || !ctx.info[op.tempId()].is_literal(bits)) {
5017 continue;
5018 }
5019
5020 if (!alu_can_accept_constant(instr, i))
5021 continue;
5022
5023 if (ctx.uses[op.tempId()] < literal_uses) {
5024 is_literal_sgpr = op.getTemp().type() == RegType::sgpr;
5025 mask = 0;
5026 literal = Operand::c32(ctx.info[op.tempId()].val);
5027 literal_uses = ctx.uses[op.tempId()];
5028 literal_id = op.tempId();
5029 }
5030
5031 mask |= (op.tempId() == literal_id) << i;
5032 }
5033
5034 /* don't go over the constant bus limit */
5035 bool is_shift64 = instr->opcode == aco_opcode::v_lshlrev_b64 ||
5036 instr->opcode == aco_opcode::v_lshrrev_b64 ||
5037 instr->opcode == aco_opcode::v_ashrrev_i64;
5038 unsigned const_bus_limit = instr->isVALU() ? 1 : UINT32_MAX;
5039 if (ctx.program->gfx_level >= GFX10 && !is_shift64)
5040 const_bus_limit = 2;
5041
5042 unsigned num_sgprs = !!sgpr_ids[0] + !!sgpr_ids[1];
5043 if (num_sgprs == const_bus_limit && !is_literal_sgpr)
5044 return;
5045
5046 if (literal_id && literal_uses < threshold &&
5047 (current_literal.isUndefined() ||
5048 (current_literal.size() == literal.size() &&
5049 current_literal.constantValue() == literal.constantValue()))) {
5050 /* mark the literal to be applied */
5051 while (mask) {
5052 unsigned i = u_bit_scan(&mask);
5053 if (instr->operands[i].isTemp() && instr->operands[i].tempId() == literal_id)
5054 ctx.uses[instr->operands[i].tempId()]--;
5055 }
5056 }
5057 }
5058
5059 static aco_opcode
sopk_opcode_for_sopc(aco_opcode opcode)5060 sopk_opcode_for_sopc(aco_opcode opcode)
5061 {
5062 #define CTOK(op) \
5063 case aco_opcode::s_cmp_##op##_i32: return aco_opcode::s_cmpk_##op##_i32; \
5064 case aco_opcode::s_cmp_##op##_u32: return aco_opcode::s_cmpk_##op##_u32;
5065 switch (opcode) {
5066 CTOK(eq)
5067 CTOK(lg)
5068 CTOK(gt)
5069 CTOK(ge)
5070 CTOK(lt)
5071 CTOK(le)
5072 default: return aco_opcode::num_opcodes;
5073 }
5074 #undef CTOK
5075 }
5076
5077 static bool
sopc_is_signed(aco_opcode opcode)5078 sopc_is_signed(aco_opcode opcode)
5079 {
5080 #define SOPC(op) \
5081 case aco_opcode::s_cmp_##op##_i32: return true; \
5082 case aco_opcode::s_cmp_##op##_u32: return false;
5083 switch (opcode) {
5084 SOPC(eq)
5085 SOPC(lg)
5086 SOPC(gt)
5087 SOPC(ge)
5088 SOPC(lt)
5089 SOPC(le)
5090 default: unreachable("Not a valid SOPC instruction.");
5091 }
5092 #undef SOPC
5093 }
5094
5095 static aco_opcode
sopc_32_swapped(aco_opcode opcode)5096 sopc_32_swapped(aco_opcode opcode)
5097 {
5098 #define SOPC(op1, op2) \
5099 case aco_opcode::s_cmp_##op1##_i32: return aco_opcode::s_cmp_##op2##_i32; \
5100 case aco_opcode::s_cmp_##op1##_u32: return aco_opcode::s_cmp_##op2##_u32;
5101 switch (opcode) {
5102 SOPC(eq, eq)
5103 SOPC(lg, lg)
5104 SOPC(gt, lt)
5105 SOPC(ge, le)
5106 SOPC(lt, gt)
5107 SOPC(le, ge)
5108 default: return aco_opcode::num_opcodes;
5109 }
5110 #undef SOPC
5111 }
5112
5113 static void
try_convert_sopc_to_sopk(aco_ptr<Instruction> & instr)5114 try_convert_sopc_to_sopk(aco_ptr<Instruction>& instr)
5115 {
5116 if (sopk_opcode_for_sopc(instr->opcode) == aco_opcode::num_opcodes)
5117 return;
5118
5119 if (instr->operands[0].isLiteral()) {
5120 std::swap(instr->operands[0], instr->operands[1]);
5121 instr->opcode = sopc_32_swapped(instr->opcode);
5122 }
5123
5124 if (!instr->operands[1].isLiteral())
5125 return;
5126
5127 if (instr->operands[0].isFixed() && instr->operands[0].physReg() >= 128)
5128 return;
5129
5130 uint32_t value = instr->operands[1].constantValue();
5131
5132 const uint32_t i16_mask = 0xffff8000u;
5133
5134 bool value_is_i16 = (value & i16_mask) == 0 || (value & i16_mask) == i16_mask;
5135 bool value_is_u16 = !(value & 0xffff0000u);
5136
5137 if (!value_is_i16 && !value_is_u16)
5138 return;
5139
5140 if (!value_is_i16 && sopc_is_signed(instr->opcode)) {
5141 if (instr->opcode == aco_opcode::s_cmp_lg_i32)
5142 instr->opcode = aco_opcode::s_cmp_lg_u32;
5143 else if (instr->opcode == aco_opcode::s_cmp_eq_i32)
5144 instr->opcode = aco_opcode::s_cmp_eq_u32;
5145 else
5146 return;
5147 } else if (!value_is_u16 && !sopc_is_signed(instr->opcode)) {
5148 if (instr->opcode == aco_opcode::s_cmp_lg_u32)
5149 instr->opcode = aco_opcode::s_cmp_lg_i32;
5150 else if (instr->opcode == aco_opcode::s_cmp_eq_u32)
5151 instr->opcode = aco_opcode::s_cmp_eq_i32;
5152 else
5153 return;
5154 }
5155
5156 static_assert(sizeof(SOPK_instruction) <= sizeof(SOPC_instruction),
5157 "Invalid direct instruction cast.");
5158 instr->format = Format::SOPK;
5159 SOPK_instruction* instr_sopk = &instr->sopk();
5160
5161 instr_sopk->imm = instr_sopk->operands[1].constantValue() & 0xffff;
5162 instr_sopk->opcode = sopk_opcode_for_sopc(instr_sopk->opcode);
5163 instr_sopk->operands.pop_back();
5164 }
5165
5166 static void
unswizzle_vop3p_literals(opt_ctx & ctx,aco_ptr<Instruction> & instr)5167 unswizzle_vop3p_literals(opt_ctx& ctx, aco_ptr<Instruction>& instr)
5168 {
5169 /* This opt is only beneficial for v_pk_fma_f16 because we can use v_pk_fmac_f16 if the
5170 * instruction doesn't use swizzles. */
5171 if (instr->opcode != aco_opcode::v_pk_fma_f16)
5172 return;
5173
5174 VALU_instruction& vop3p = instr->valu();
5175
5176 unsigned literal_swizzle = ~0u;
5177 for (unsigned i = 0; i < instr->operands.size(); i++) {
5178 if (!instr->operands[i].isLiteral())
5179 continue;
5180 unsigned new_swizzle = vop3p.opsel_lo[i] | (vop3p.opsel_hi[i] << 1);
5181 if (literal_swizzle != ~0u && new_swizzle != literal_swizzle)
5182 return; /* Literal swizzles conflict. */
5183 literal_swizzle = new_swizzle;
5184 }
5185
5186 if (literal_swizzle == 0b10 || literal_swizzle == ~0u)
5187 return; /* already unswizzled */
5188
5189 for (unsigned i = 0; i < instr->operands.size(); i++) {
5190 if (!instr->operands[i].isLiteral())
5191 continue;
5192 uint32_t literal = instr->operands[i].constantValue();
5193 literal = (literal >> (16 * (literal_swizzle & 0x1)) & 0xffff) |
5194 (literal >> (8 * (literal_swizzle & 0x2)) << 16);
5195 instr->operands[i] = Operand::literal32(literal);
5196 vop3p.opsel_lo[i] = false;
5197 vop3p.opsel_hi[i] = true;
5198 }
5199 }
5200
5201 void
apply_literals(opt_ctx & ctx,aco_ptr<Instruction> & instr)5202 apply_literals(opt_ctx& ctx, aco_ptr<Instruction>& instr)
5203 {
5204 /* Cleanup Dead Instructions */
5205 if (!instr)
5206 return;
5207
5208 /* apply literals on MAD */
5209 if (!instr->definitions.empty() && ctx.info[instr->definitions[0].tempId()].is_mad()) {
5210 mad_info* info = &ctx.mad_infos[ctx.info[instr->definitions[0].tempId()].val];
5211 const bool madak = (info->literal_mask & 0b100);
5212 bool has_dead_literal = false;
5213 u_foreach_bit (i, info->literal_mask | info->fp16_mask)
5214 has_dead_literal |= ctx.uses[instr->operands[i].tempId()] == 0;
5215
5216 if (has_dead_literal && info->fp16_mask) {
5217 instr->format = Format::VOP3P;
5218 instr->opcode = aco_opcode::v_fma_mix_f32;
5219
5220 uint32_t literal = 0;
5221 bool second = false;
5222 u_foreach_bit (i, info->fp16_mask) {
5223 float value = uif(ctx.info[instr->operands[i].tempId()].val);
5224 literal |= _mesa_float_to_half(value) << (second * 16);
5225 instr->valu().opsel_lo[i] = second;
5226 instr->valu().opsel_hi[i] = true;
5227 second = true;
5228 }
5229
5230 for (unsigned i = 0; i < 3; i++) {
5231 if (info->fp16_mask & (1 << i))
5232 instr->operands[i] = Operand::literal32(literal);
5233 }
5234
5235 ctx.instructions.emplace_back(std::move(instr));
5236 return;
5237 }
5238
5239 if (has_dead_literal || madak) {
5240 aco_opcode new_op = madak ? aco_opcode::v_madak_f32 : aco_opcode::v_madmk_f32;
5241 if (instr->opcode == aco_opcode::v_fma_f32)
5242 new_op = madak ? aco_opcode::v_fmaak_f32 : aco_opcode::v_fmamk_f32;
5243 else if (instr->opcode == aco_opcode::v_mad_f16 ||
5244 instr->opcode == aco_opcode::v_mad_legacy_f16)
5245 new_op = madak ? aco_opcode::v_madak_f16 : aco_opcode::v_madmk_f16;
5246 else if (instr->opcode == aco_opcode::v_fma_f16)
5247 new_op = madak ? aco_opcode::v_fmaak_f16 : aco_opcode::v_fmamk_f16;
5248
5249 uint32_t literal = ctx.info[instr->operands[ffs(info->literal_mask) - 1].tempId()].val;
5250 instr->format = Format::VOP2;
5251 instr->opcode = new_op;
5252 for (unsigned i = 0; i < 3; i++) {
5253 if (info->literal_mask & (1 << i))
5254 instr->operands[i] = Operand::literal32(literal);
5255 }
5256 if (madak) { /* add literal -> madak */
5257 if (!instr->operands[1].isOfType(RegType::vgpr))
5258 instr->valu().swapOperands(0, 1);
5259 } else { /* mul literal -> madmk */
5260 if (!(info->literal_mask & 0b10))
5261 instr->valu().swapOperands(0, 1);
5262 instr->valu().swapOperands(1, 2);
5263 }
5264 ctx.instructions.emplace_back(std::move(instr));
5265 return;
5266 }
5267 }
5268
5269 /* apply literals on other SALU/VALU */
5270 if (instr->isSALU() || instr->isVALU()) {
5271 for (unsigned i = 0; i < instr->operands.size(); i++) {
5272 Operand op = instr->operands[i];
5273 unsigned bits = get_operand_size(instr, i);
5274 if (op.isTemp() && ctx.info[op.tempId()].is_literal(bits) && ctx.uses[op.tempId()] == 0) {
5275 Operand literal = Operand::literal32(ctx.info[op.tempId()].val);
5276 instr->format = withoutDPP(instr->format);
5277 if (instr->isVALU() && i > 0 && instr->format != Format::VOP3P)
5278 instr->format = asVOP3(instr->format);
5279 instr->operands[i] = literal;
5280 }
5281 }
5282 }
5283
5284 if (instr->isSOPC())
5285 try_convert_sopc_to_sopk(instr);
5286
5287 /* allow more s_addk_i32 optimizations if carry isn't used */
5288 if (instr->opcode == aco_opcode::s_add_u32 && ctx.uses[instr->definitions[1].tempId()] == 0 &&
5289 (instr->operands[0].isLiteral() || instr->operands[1].isLiteral()))
5290 instr->opcode = aco_opcode::s_add_i32;
5291
5292 if (instr->isVOP3P())
5293 unswizzle_vop3p_literals(ctx, instr);
5294
5295 ctx.instructions.emplace_back(std::move(instr));
5296 }
5297
5298 void
optimize(Program * program)5299 optimize(Program* program)
5300 {
5301 opt_ctx ctx;
5302 ctx.program = program;
5303 std::vector<ssa_info> info(program->peekAllocationId());
5304 ctx.info = info.data();
5305
5306 /* 1. Bottom-Up DAG pass (forward) to label all ssa-defs */
5307 for (Block& block : program->blocks) {
5308 ctx.fp_mode = block.fp_mode;
5309 for (aco_ptr<Instruction>& instr : block.instructions)
5310 label_instruction(ctx, instr);
5311 }
5312
5313 ctx.uses = dead_code_analysis(program);
5314
5315 /* 2. Combine v_mad, omod, clamp and propagate sgpr on VALU instructions */
5316 for (Block& block : program->blocks) {
5317 ctx.fp_mode = block.fp_mode;
5318 for (aco_ptr<Instruction>& instr : block.instructions)
5319 combine_instruction(ctx, instr);
5320 }
5321
5322 /* 3. Top-Down DAG pass (backward) to select instructions (includes DCE) */
5323 for (auto block_rit = program->blocks.rbegin(); block_rit != program->blocks.rend();
5324 ++block_rit) {
5325 Block* block = &(*block_rit);
5326 ctx.fp_mode = block->fp_mode;
5327 for (auto instr_rit = block->instructions.rbegin(); instr_rit != block->instructions.rend();
5328 ++instr_rit)
5329 select_instruction(ctx, *instr_rit);
5330 }
5331
5332 /* 4. Add literals to instructions */
5333 for (Block& block : program->blocks) {
5334 ctx.instructions.reserve(block.instructions.size());
5335 ctx.fp_mode = block.fp_mode;
5336 for (aco_ptr<Instruction>& instr : block.instructions)
5337 apply_literals(ctx, instr);
5338 block.instructions = std::move(ctx.instructions);
5339 }
5340 }
5341
5342 } // namespace aco
5343