• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright © 2017 Intel Corporation
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  */
23 
24 #include "nir.h"
25 #include "nir_builder.h"
26 
27 /**
28  * \file nir_opt_intrinsics.c
29  */
30 
31 /* Converts a uint32_t or uint64_t value to uint64_t or uvec4 */
32 static nir_ssa_def *
uint_to_ballot_type(nir_builder * b,nir_ssa_def * value,unsigned num_components,unsigned bit_size)33 uint_to_ballot_type(nir_builder *b, nir_ssa_def *value,
34                     unsigned num_components, unsigned bit_size)
35 {
36    assert(value->num_components == 1);
37    assert(value->bit_size == 32 || value->bit_size == 64);
38 
39    nir_ssa_def *zero = nir_imm_int(b, 0);
40    if (num_components > 1) {
41       /* SPIR-V uses a uvec4 for ballot values */
42       assert(num_components == 4);
43       assert(bit_size == 32);
44 
45       if (value->bit_size == 32) {
46          return nir_vec4(b, value, zero, zero, zero);
47       } else {
48          assert(value->bit_size == 64);
49          return nir_vec4(b, nir_unpack_64_2x32_split_x(b, value),
50                             nir_unpack_64_2x32_split_y(b, value),
51                             zero, zero);
52       }
53    } else {
54       /* GLSL uses a uint64_t for ballot values */
55       assert(num_components == 1);
56       assert(bit_size == 64);
57 
58       if (value->bit_size == 32) {
59          return nir_pack_64_2x32_split(b, value, zero);
60       } else {
61          assert(value->bit_size == 64);
62          return value;
63       }
64    }
65 }
66 
67 static nir_ssa_def *
lower_read_invocation_to_scalar(nir_builder * b,nir_intrinsic_instr * intrin)68 lower_read_invocation_to_scalar(nir_builder *b, nir_intrinsic_instr *intrin)
69 {
70    /* This is safe to call on scalar things but it would be silly */
71    assert(intrin->dest.ssa.num_components > 1);
72 
73    nir_ssa_def *value = nir_ssa_for_src(b, intrin->src[0],
74                                            intrin->num_components);
75    nir_ssa_def *reads[4];
76 
77    for (unsigned i = 0; i < intrin->num_components; i++) {
78       nir_intrinsic_instr *chan_intrin =
79          nir_intrinsic_instr_create(b->shader, intrin->intrinsic);
80       nir_ssa_dest_init(&chan_intrin->instr, &chan_intrin->dest,
81                         1, intrin->dest.ssa.bit_size, NULL);
82       chan_intrin->num_components = 1;
83 
84       /* value */
85       chan_intrin->src[0] = nir_src_for_ssa(nir_channel(b, value, i));
86       /* invocation */
87       if (intrin->intrinsic == nir_intrinsic_read_invocation)
88          nir_src_copy(&chan_intrin->src[1], &intrin->src[1], chan_intrin);
89 
90       nir_builder_instr_insert(b, &chan_intrin->instr);
91 
92       reads[i] = &chan_intrin->dest.ssa;
93    }
94 
95    return nir_vec(b, reads, intrin->num_components);
96 }
97 
98 static nir_ssa_def *
lower_subgroups_intrin(nir_builder * b,nir_intrinsic_instr * intrin,const nir_lower_subgroups_options * options)99 lower_subgroups_intrin(nir_builder *b, nir_intrinsic_instr *intrin,
100                        const nir_lower_subgroups_options *options)
101 {
102    switch (intrin->intrinsic) {
103    case nir_intrinsic_vote_any:
104    case nir_intrinsic_vote_all:
105       if (options->lower_vote_trivial)
106          return nir_ssa_for_src(b, intrin->src[0], 1);
107       break;
108 
109    case nir_intrinsic_vote_eq:
110       if (options->lower_vote_trivial)
111          return nir_imm_int(b, NIR_TRUE);
112       break;
113 
114    case nir_intrinsic_load_subgroup_size:
115       if (options->subgroup_size)
116          return nir_imm_int(b, options->subgroup_size);
117       break;
118 
119    case nir_intrinsic_read_invocation:
120    case nir_intrinsic_read_first_invocation:
121       if (options->lower_to_scalar && intrin->num_components > 1)
122          return lower_read_invocation_to_scalar(b, intrin);
123       break;
124 
125    case nir_intrinsic_load_subgroup_eq_mask:
126    case nir_intrinsic_load_subgroup_ge_mask:
127    case nir_intrinsic_load_subgroup_gt_mask:
128    case nir_intrinsic_load_subgroup_le_mask:
129    case nir_intrinsic_load_subgroup_lt_mask: {
130       if (!options->lower_subgroup_masks)
131          return NULL;
132 
133       /* If either the result or the requested bit size is 64-bits then we
134        * know that we have 64-bit types and using them will probably be more
135        * efficient than messing around with 32-bit shifts and packing.
136        */
137       const unsigned bit_size = MAX2(options->ballot_bit_size,
138                                      intrin->dest.ssa.bit_size);
139 
140       assert(options->subgroup_size <= 64);
141       uint64_t group_mask = ~0ull >> (64 - options->subgroup_size);
142 
143       nir_ssa_def *count = nir_load_subgroup_invocation(b);
144       nir_ssa_def *val;
145       switch (intrin->intrinsic) {
146       case nir_intrinsic_load_subgroup_eq_mask:
147          val = nir_ishl(b, nir_imm_intN_t(b, 1ull, bit_size), count);
148          break;
149       case nir_intrinsic_load_subgroup_ge_mask:
150          val = nir_iand(b, nir_ishl(b, nir_imm_intN_t(b, ~0ull, bit_size), count),
151                            nir_imm_intN_t(b, group_mask, bit_size));
152          break;
153       case nir_intrinsic_load_subgroup_gt_mask:
154          val = nir_iand(b, nir_ishl(b, nir_imm_intN_t(b, ~1ull, bit_size), count),
155                            nir_imm_intN_t(b, group_mask, bit_size));
156          break;
157       case nir_intrinsic_load_subgroup_le_mask:
158          val = nir_inot(b, nir_ishl(b, nir_imm_intN_t(b, ~1ull, bit_size), count));
159          break;
160       case nir_intrinsic_load_subgroup_lt_mask:
161          val = nir_inot(b, nir_ishl(b, nir_imm_intN_t(b, ~0ull, bit_size), count));
162          break;
163       default:
164          unreachable("you seriously can't tell this is unreachable?");
165       }
166 
167       return uint_to_ballot_type(b, val,
168                                  intrin->dest.ssa.num_components,
169                                  intrin->dest.ssa.bit_size);
170    }
171 
172    case nir_intrinsic_ballot: {
173       if (intrin->dest.ssa.num_components == 1 &&
174           intrin->dest.ssa.bit_size == options->ballot_bit_size)
175          return NULL;
176 
177       nir_intrinsic_instr *ballot =
178          nir_intrinsic_instr_create(b->shader, nir_intrinsic_ballot);
179       ballot->num_components = 1;
180       nir_ssa_dest_init(&ballot->instr, &ballot->dest,
181                         1, options->ballot_bit_size, NULL);
182       nir_src_copy(&ballot->src[0], &intrin->src[0], ballot);
183       nir_builder_instr_insert(b, &ballot->instr);
184 
185       return uint_to_ballot_type(b, &ballot->dest.ssa,
186                                  intrin->dest.ssa.num_components,
187                                  intrin->dest.ssa.bit_size);
188    }
189 
190    default:
191       break;
192    }
193 
194    return NULL;
195 }
196 
197 static bool
lower_subgroups_impl(nir_function_impl * impl,const nir_lower_subgroups_options * options)198 lower_subgroups_impl(nir_function_impl *impl,
199                      const nir_lower_subgroups_options *options)
200 {
201    nir_builder b;
202    nir_builder_init(&b, impl);
203    bool progress = false;
204 
205    nir_foreach_block(block, impl) {
206       nir_foreach_instr_safe(instr, block) {
207          if (instr->type != nir_instr_type_intrinsic)
208             continue;
209 
210          nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
211          b.cursor = nir_before_instr(instr);
212 
213          nir_ssa_def *lower = lower_subgroups_intrin(&b, intrin, options);
214          if (!lower)
215             continue;
216 
217          nir_ssa_def_rewrite_uses(&intrin->dest.ssa, nir_src_for_ssa(lower));
218          nir_instr_remove(instr);
219          progress = true;
220       }
221    }
222 
223    return progress;
224 }
225 
226 bool
nir_lower_subgroups(nir_shader * shader,const nir_lower_subgroups_options * options)227 nir_lower_subgroups(nir_shader *shader,
228                     const nir_lower_subgroups_options *options)
229 {
230    bool progress = false;
231 
232    nir_foreach_function(function, shader) {
233       if (!function->impl)
234          continue;
235 
236       if (lower_subgroups_impl(function->impl, options)) {
237          progress = true;
238          nir_metadata_preserve(function->impl, nir_metadata_block_index |
239                                                nir_metadata_dominance);
240       }
241    }
242 
243    return progress;
244 }
245