• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2023 Valve Corporation
3  * SPDX-License-Identifier: MIT
4  */
5 
6 #include "compiler/nir/nir.h"
7 #include "compiler/nir/nir_builder.h"
8 #include "agx_nir.h"
9 #include "nir_builder_opcodes.h"
10 #include "nir_intrinsics.h"
11 
12 static bool
lower(nir_builder * b,nir_intrinsic_instr * intr,void * data)13 lower(nir_builder *b, nir_intrinsic_instr *intr, void *data)
14 {
15    b->cursor = nir_before_instr(&intr->instr);
16 
17    switch (intr->intrinsic) {
18    case nir_intrinsic_vote_any: {
19       /* We don't have vote instructions, but we have efficient ballots */
20       nir_def *ballot = nir_ballot(b, 1, 32, intr->src[0].ssa);
21       nir_def_rewrite_uses(&intr->def, nir_ine_imm(b, ballot, 0));
22       return true;
23    }
24 
25    case nir_intrinsic_vote_all: {
26       nir_def *ballot = nir_ballot(b, 1, 32, nir_inot(b, intr->src[0].ssa));
27       nir_def_rewrite_uses(&intr->def, nir_ieq_imm(b, ballot, 0));
28       return true;
29    }
30 
31    case nir_intrinsic_first_invocation: {
32       nir_def *active_id = nir_load_active_subgroup_invocation_agx(b);
33       nir_def *is_first = nir_ieq_imm(b, active_id, 0);
34       nir_def *first_bit = nir_ballot(b, 1, 32, is_first);
35       nir_def_rewrite_uses(&intr->def, nir_ufind_msb(b, first_bit));
36       return true;
37    }
38 
39    case nir_intrinsic_vote_ieq:
40    case nir_intrinsic_vote_feq: {
41       /* The common lowering does:
42        *
43        *    vote_all(x == read_first(x))
44        *
45        * This is not optimal for AGX, since we have ufind_msb but not ctz, so
46        * it's cheaper to read the last invocation than the first. So we do:
47        *
48        *    vote_all(x == read_last(x))
49        *
50        * implemented with lowered instructions as
51        *
52        *    ballot(x != broadcast(x, ffs(ballot(true)))) == 0
53        */
54       nir_def *active_mask = nir_ballot(b, 1, 32, nir_imm_true(b));
55       nir_def *active_bit = nir_ufind_msb(b, active_mask);
56       nir_def *other = nir_read_invocation(b, intr->src[0].ssa, active_bit);
57       nir_def *is_ne;
58 
59       if (intr->intrinsic == nir_intrinsic_vote_feq) {
60          is_ne = nir_fneu(b, other, intr->src[0].ssa);
61       } else {
62          is_ne = nir_ine(b, other, intr->src[0].ssa);
63       }
64 
65       nir_def *ballot = nir_ballot(b, 1, 32, is_ne);
66       nir_def_rewrite_uses(&intr->def, nir_ieq_imm(b, ballot, 0));
67       return true;
68    }
69 
70    default:
71       return false;
72    }
73 }
74 
75 bool
agx_nir_lower_subgroups(nir_shader * s)76 agx_nir_lower_subgroups(nir_shader *s)
77 {
78    /* First, do as much common lowering as we can */
79    nir_lower_subgroups_options opts = {
80       .lower_read_first_invocation = true,
81       .lower_to_scalar = true,
82       .lower_subgroup_masks = true,
83       .ballot_components = 1,
84       .ballot_bit_size = 32,
85       .subgroup_size = 32,
86    };
87 
88    bool progress = nir_lower_subgroups(s, &opts);
89 
90    /* Then do AGX-only lowerings on top */
91    progress |= nir_shader_intrinsics_pass(
92       s, lower, nir_metadata_block_index | nir_metadata_dominance, NULL);
93 
94    return progress;
95 }
96