1 /*
2 * Copyright 2023 Valve Corporation
3 * SPDX-License-Identifier: MIT
4 */
5
6 #include "compiler/nir/nir.h"
7 #include "compiler/nir/nir_builder.h"
8 #include "agx_nir.h"
9 #include "nir_builder_opcodes.h"
10 #include "nir_intrinsics.h"
11
12 static bool
lower(nir_builder * b,nir_intrinsic_instr * intr,void * data)13 lower(nir_builder *b, nir_intrinsic_instr *intr, void *data)
14 {
15 b->cursor = nir_before_instr(&intr->instr);
16
17 switch (intr->intrinsic) {
18 case nir_intrinsic_vote_any: {
19 /* We don't have vote instructions, but we have efficient ballots */
20 nir_def *ballot = nir_ballot(b, 1, 32, intr->src[0].ssa);
21 nir_def_rewrite_uses(&intr->def, nir_ine_imm(b, ballot, 0));
22 return true;
23 }
24
25 case nir_intrinsic_vote_all: {
26 nir_def *ballot = nir_ballot(b, 1, 32, nir_inot(b, intr->src[0].ssa));
27 nir_def_rewrite_uses(&intr->def, nir_ieq_imm(b, ballot, 0));
28 return true;
29 }
30
31 case nir_intrinsic_first_invocation: {
32 nir_def *active_id = nir_load_active_subgroup_invocation_agx(b);
33 nir_def *is_first = nir_ieq_imm(b, active_id, 0);
34 nir_def *first_bit = nir_ballot(b, 1, 32, is_first);
35 nir_def_rewrite_uses(&intr->def, nir_ufind_msb(b, first_bit));
36 return true;
37 }
38
39 case nir_intrinsic_vote_ieq:
40 case nir_intrinsic_vote_feq: {
41 /* The common lowering does:
42 *
43 * vote_all(x == read_first(x))
44 *
45 * This is not optimal for AGX, since we have ufind_msb but not ctz, so
46 * it's cheaper to read the last invocation than the first. So we do:
47 *
48 * vote_all(x == read_last(x))
49 *
50 * implemented with lowered instructions as
51 *
52 * ballot(x != broadcast(x, ffs(ballot(true)))) == 0
53 */
54 nir_def *active_mask = nir_ballot(b, 1, 32, nir_imm_true(b));
55 nir_def *active_bit = nir_ufind_msb(b, active_mask);
56 nir_def *other = nir_read_invocation(b, intr->src[0].ssa, active_bit);
57 nir_def *is_ne;
58
59 if (intr->intrinsic == nir_intrinsic_vote_feq) {
60 is_ne = nir_fneu(b, other, intr->src[0].ssa);
61 } else {
62 is_ne = nir_ine(b, other, intr->src[0].ssa);
63 }
64
65 nir_def *ballot = nir_ballot(b, 1, 32, is_ne);
66 nir_def_rewrite_uses(&intr->def, nir_ieq_imm(b, ballot, 0));
67 return true;
68 }
69
70 default:
71 return false;
72 }
73 }
74
75 bool
agx_nir_lower_subgroups(nir_shader * s)76 agx_nir_lower_subgroups(nir_shader *s)
77 {
78 /* First, do as much common lowering as we can */
79 nir_lower_subgroups_options opts = {
80 .lower_read_first_invocation = true,
81 .lower_to_scalar = true,
82 .lower_subgroup_masks = true,
83 .ballot_components = 1,
84 .ballot_bit_size = 32,
85 .subgroup_size = 32,
86 };
87
88 bool progress = nir_lower_subgroups(s, &opts);
89
90 /* Then do AGX-only lowerings on top */
91 progress |= nir_shader_intrinsics_pass(
92 s, lower, nir_metadata_block_index | nir_metadata_dominance, NULL);
93
94 return progress;
95 }
96