1 /*
2 * Copyright © 2017 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 */
23
24 #include "nir.h"
25 #include "nir_builder.h"
26
/**
 * \file nir_lower_subgroups.c
 *
 * Lowers subgroup intrinsics as configured by nir_lower_subgroups_options.
 */
30
31 /* Converts a uint32_t or uint64_t value to uint64_t or uvec4 */
32 static nir_ssa_def *
uint_to_ballot_type(nir_builder * b,nir_ssa_def * value,unsigned num_components,unsigned bit_size)33 uint_to_ballot_type(nir_builder *b, nir_ssa_def *value,
34 unsigned num_components, unsigned bit_size)
35 {
36 assert(value->num_components == 1);
37 assert(value->bit_size == 32 || value->bit_size == 64);
38
39 nir_ssa_def *zero = nir_imm_int(b, 0);
40 if (num_components > 1) {
41 /* SPIR-V uses a uvec4 for ballot values */
42 assert(num_components == 4);
43 assert(bit_size == 32);
44
45 if (value->bit_size == 32) {
46 return nir_vec4(b, value, zero, zero, zero);
47 } else {
48 assert(value->bit_size == 64);
49 return nir_vec4(b, nir_unpack_64_2x32_split_x(b, value),
50 nir_unpack_64_2x32_split_y(b, value),
51 zero, zero);
52 }
53 } else {
54 /* GLSL uses a uint64_t for ballot values */
55 assert(num_components == 1);
56 assert(bit_size == 64);
57
58 if (value->bit_size == 32) {
59 return nir_pack_64_2x32_split(b, value, zero);
60 } else {
61 assert(value->bit_size == 64);
62 return value;
63 }
64 }
65 }
66
67 static nir_ssa_def *
lower_read_invocation_to_scalar(nir_builder * b,nir_intrinsic_instr * intrin)68 lower_read_invocation_to_scalar(nir_builder *b, nir_intrinsic_instr *intrin)
69 {
70 /* This is safe to call on scalar things but it would be silly */
71 assert(intrin->dest.ssa.num_components > 1);
72
73 nir_ssa_def *value = nir_ssa_for_src(b, intrin->src[0],
74 intrin->num_components);
75 nir_ssa_def *reads[4];
76
77 for (unsigned i = 0; i < intrin->num_components; i++) {
78 nir_intrinsic_instr *chan_intrin =
79 nir_intrinsic_instr_create(b->shader, intrin->intrinsic);
80 nir_ssa_dest_init(&chan_intrin->instr, &chan_intrin->dest,
81 1, intrin->dest.ssa.bit_size, NULL);
82 chan_intrin->num_components = 1;
83
84 /* value */
85 chan_intrin->src[0] = nir_src_for_ssa(nir_channel(b, value, i));
86 /* invocation */
87 if (intrin->intrinsic == nir_intrinsic_read_invocation)
88 nir_src_copy(&chan_intrin->src[1], &intrin->src[1], chan_intrin);
89
90 nir_builder_instr_insert(b, &chan_intrin->instr);
91
92 reads[i] = &chan_intrin->dest.ssa;
93 }
94
95 return nir_vec(b, reads, intrin->num_components);
96 }
97
98 static nir_ssa_def *
lower_subgroups_intrin(nir_builder * b,nir_intrinsic_instr * intrin,const nir_lower_subgroups_options * options)99 lower_subgroups_intrin(nir_builder *b, nir_intrinsic_instr *intrin,
100 const nir_lower_subgroups_options *options)
101 {
102 switch (intrin->intrinsic) {
103 case nir_intrinsic_vote_any:
104 case nir_intrinsic_vote_all:
105 if (options->lower_vote_trivial)
106 return nir_ssa_for_src(b, intrin->src[0], 1);
107 break;
108
109 case nir_intrinsic_vote_eq:
110 if (options->lower_vote_trivial)
111 return nir_imm_int(b, NIR_TRUE);
112 break;
113
114 case nir_intrinsic_load_subgroup_size:
115 if (options->subgroup_size)
116 return nir_imm_int(b, options->subgroup_size);
117 break;
118
119 case nir_intrinsic_read_invocation:
120 case nir_intrinsic_read_first_invocation:
121 if (options->lower_to_scalar && intrin->num_components > 1)
122 return lower_read_invocation_to_scalar(b, intrin);
123 break;
124
125 case nir_intrinsic_load_subgroup_eq_mask:
126 case nir_intrinsic_load_subgroup_ge_mask:
127 case nir_intrinsic_load_subgroup_gt_mask:
128 case nir_intrinsic_load_subgroup_le_mask:
129 case nir_intrinsic_load_subgroup_lt_mask: {
130 if (!options->lower_subgroup_masks)
131 return NULL;
132
133 /* If either the result or the requested bit size is 64-bits then we
134 * know that we have 64-bit types and using them will probably be more
135 * efficient than messing around with 32-bit shifts and packing.
136 */
137 const unsigned bit_size = MAX2(options->ballot_bit_size,
138 intrin->dest.ssa.bit_size);
139
140 assert(options->subgroup_size <= 64);
141 uint64_t group_mask = ~0ull >> (64 - options->subgroup_size);
142
143 nir_ssa_def *count = nir_load_subgroup_invocation(b);
144 nir_ssa_def *val;
145 switch (intrin->intrinsic) {
146 case nir_intrinsic_load_subgroup_eq_mask:
147 val = nir_ishl(b, nir_imm_intN_t(b, 1ull, bit_size), count);
148 break;
149 case nir_intrinsic_load_subgroup_ge_mask:
150 val = nir_iand(b, nir_ishl(b, nir_imm_intN_t(b, ~0ull, bit_size), count),
151 nir_imm_intN_t(b, group_mask, bit_size));
152 break;
153 case nir_intrinsic_load_subgroup_gt_mask:
154 val = nir_iand(b, nir_ishl(b, nir_imm_intN_t(b, ~1ull, bit_size), count),
155 nir_imm_intN_t(b, group_mask, bit_size));
156 break;
157 case nir_intrinsic_load_subgroup_le_mask:
158 val = nir_inot(b, nir_ishl(b, nir_imm_intN_t(b, ~1ull, bit_size), count));
159 break;
160 case nir_intrinsic_load_subgroup_lt_mask:
161 val = nir_inot(b, nir_ishl(b, nir_imm_intN_t(b, ~0ull, bit_size), count));
162 break;
163 default:
164 unreachable("you seriously can't tell this is unreachable?");
165 }
166
167 return uint_to_ballot_type(b, val,
168 intrin->dest.ssa.num_components,
169 intrin->dest.ssa.bit_size);
170 }
171
172 case nir_intrinsic_ballot: {
173 if (intrin->dest.ssa.num_components == 1 &&
174 intrin->dest.ssa.bit_size == options->ballot_bit_size)
175 return NULL;
176
177 nir_intrinsic_instr *ballot =
178 nir_intrinsic_instr_create(b->shader, nir_intrinsic_ballot);
179 ballot->num_components = 1;
180 nir_ssa_dest_init(&ballot->instr, &ballot->dest,
181 1, options->ballot_bit_size, NULL);
182 nir_src_copy(&ballot->src[0], &intrin->src[0], ballot);
183 nir_builder_instr_insert(b, &ballot->instr);
184
185 return uint_to_ballot_type(b, &ballot->dest.ssa,
186 intrin->dest.ssa.num_components,
187 intrin->dest.ssa.bit_size);
188 }
189
190 default:
191 break;
192 }
193
194 return NULL;
195 }
196
197 static bool
lower_subgroups_impl(nir_function_impl * impl,const nir_lower_subgroups_options * options)198 lower_subgroups_impl(nir_function_impl *impl,
199 const nir_lower_subgroups_options *options)
200 {
201 nir_builder b;
202 nir_builder_init(&b, impl);
203 bool progress = false;
204
205 nir_foreach_block(block, impl) {
206 nir_foreach_instr_safe(instr, block) {
207 if (instr->type != nir_instr_type_intrinsic)
208 continue;
209
210 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
211 b.cursor = nir_before_instr(instr);
212
213 nir_ssa_def *lower = lower_subgroups_intrin(&b, intrin, options);
214 if (!lower)
215 continue;
216
217 nir_ssa_def_rewrite_uses(&intrin->dest.ssa, nir_src_for_ssa(lower));
218 nir_instr_remove(instr);
219 progress = true;
220 }
221 }
222
223 return progress;
224 }
225
226 bool
nir_lower_subgroups(nir_shader * shader,const nir_lower_subgroups_options * options)227 nir_lower_subgroups(nir_shader *shader,
228 const nir_lower_subgroups_options *options)
229 {
230 bool progress = false;
231
232 nir_foreach_function(function, shader) {
233 if (!function->impl)
234 continue;
235
236 if (lower_subgroups_impl(function->impl, options)) {
237 progress = true;
238 nir_metadata_preserve(function->impl, nir_metadata_block_index |
239 nir_metadata_dominance);
240 }
241 }
242
243 return progress;
244 }
245