• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright © Microsoft Corporation
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  */
23 
24 #include "u_math.h"
25 #include "nir.h"
26 #include "glsl_types.h"
27 #include "nir_types.h"
28 #include "nir_builder.h"
29 
30 #include "clc_nir.h"
31 #include "clc_compiler.h"
32 #include "../compiler/dxil_nir.h"
33 
34 static bool
lower_load_base_global_invocation_id(nir_builder * b,nir_intrinsic_instr * intr,nir_variable * var)35 lower_load_base_global_invocation_id(nir_builder *b, nir_intrinsic_instr *intr,
36                                     nir_variable *var)
37 {
38    b->cursor = nir_after_instr(&intr->instr);
39 
40    nir_ssa_def *offset =
41       build_load_ubo_dxil(b, nir_imm_int(b, var->data.binding),
42                           nir_imm_int(b,
43                                       offsetof(struct clc_work_properties_data,
44                                                global_offset_x)),
45                           nir_dest_num_components(intr->dest),
46                           nir_dest_bit_size(intr->dest));
47    nir_ssa_def_rewrite_uses(&intr->dest.ssa, offset);
48    nir_instr_remove(&intr->instr);
49    return true;
50 }
51 
52 static bool
lower_load_work_dim(nir_builder * b,nir_intrinsic_instr * intr,nir_variable * var)53 lower_load_work_dim(nir_builder *b, nir_intrinsic_instr *intr,
54                     nir_variable *var)
55 {
56    b->cursor = nir_after_instr(&intr->instr);
57 
58    nir_ssa_def *dim =
59       build_load_ubo_dxil(b, nir_imm_int(b, var->data.binding),
60                           nir_imm_int(b,
61                                       offsetof(struct clc_work_properties_data,
62                                                work_dim)),
63                           nir_dest_num_components(intr->dest),
64                           nir_dest_bit_size(intr->dest));
65    nir_ssa_def_rewrite_uses(&intr->dest.ssa, dim);
66    nir_instr_remove(&intr->instr);
67    return true;
68 }
69 
70 static bool
lower_load_num_workgroups(nir_builder * b,nir_intrinsic_instr * intr,nir_variable * var)71 lower_load_num_workgroups(nir_builder *b, nir_intrinsic_instr *intr,
72                           nir_variable *var)
73 {
74    b->cursor = nir_after_instr(&intr->instr);
75 
76    nir_ssa_def *count =
77       build_load_ubo_dxil(b, nir_imm_int(b, var->data.binding),
78                          nir_imm_int(b,
79                                      offsetof(struct clc_work_properties_data,
80                                               group_count_total_x)),
81                          nir_dest_num_components(intr->dest),
82                          nir_dest_bit_size(intr->dest));
83    nir_ssa_def_rewrite_uses(&intr->dest.ssa, count);
84    nir_instr_remove(&intr->instr);
85    return true;
86 }
87 
88 static bool
lower_load_base_workgroup_id(nir_builder * b,nir_intrinsic_instr * intr,nir_variable * var)89 lower_load_base_workgroup_id(nir_builder *b, nir_intrinsic_instr *intr,
90                              nir_variable *var)
91 {
92    b->cursor = nir_after_instr(&intr->instr);
93 
94    nir_ssa_def *offset =
95       build_load_ubo_dxil(b, nir_imm_int(b, var->data.binding),
96                          nir_imm_int(b,
97                                      offsetof(struct clc_work_properties_data,
98                                               group_id_offset_x)),
99                          nir_dest_num_components(intr->dest),
100                          nir_dest_bit_size(intr->dest));
101    nir_ssa_def_rewrite_uses(&intr->dest.ssa, offset);
102    nir_instr_remove(&intr->instr);
103    return true;
104 }
105 
106 bool
clc_nir_lower_system_values(nir_shader * nir,nir_variable * var)107 clc_nir_lower_system_values(nir_shader *nir, nir_variable *var)
108 {
109    bool progress = false;
110 
111    foreach_list_typed(nir_function, func, node, &nir->functions) {
112       if (!func->is_entrypoint)
113          continue;
114       assert(func->impl);
115 
116       nir_builder b;
117       nir_builder_init(&b, func->impl);
118 
119       nir_foreach_block(block, func->impl) {
120          nir_foreach_instr_safe(instr, block) {
121             if (instr->type != nir_instr_type_intrinsic)
122                continue;
123 
124             nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
125 
126             switch (intr->intrinsic) {
127             case nir_intrinsic_load_base_global_invocation_id:
128                progress |= lower_load_base_global_invocation_id(&b, intr, var);
129                break;
130             case nir_intrinsic_load_work_dim:
131                progress |= lower_load_work_dim(&b, intr, var);
132                break;
133             case nir_intrinsic_load_num_workgroups:
134                lower_load_num_workgroups(&b, intr, var);
135                break;
136             case nir_intrinsic_load_base_workgroup_id:
137                lower_load_base_workgroup_id(&b, intr, var);
138                break;
139             default: break;
140             }
141          }
142       }
143    }
144 
145    return progress;
146 }
147 
148 static bool
lower_load_kernel_input(nir_builder * b,nir_intrinsic_instr * intr,nir_variable * var)149 lower_load_kernel_input(nir_builder *b, nir_intrinsic_instr *intr,
150                         nir_variable *var)
151 {
152    b->cursor = nir_before_instr(&intr->instr);
153 
154    unsigned bit_size = nir_dest_bit_size(intr->dest);
155    enum glsl_base_type base_type;
156 
157    switch (bit_size) {
158    case 64:
159       base_type = GLSL_TYPE_UINT64;
160       break;
161    case 32:
162       base_type = GLSL_TYPE_UINT;
163       break;
164     case 16:
165       base_type = GLSL_TYPE_UINT16;
166       break;
167     case 8:
168       base_type = GLSL_TYPE_UINT8;
169       break;
170    }
171 
172    const struct glsl_type *type =
173       glsl_vector_type(base_type, nir_dest_num_components(intr->dest));
174    nir_ssa_def *ptr = nir_vec2(b, nir_imm_int(b, var->data.binding),
175                                   nir_u2u(b, intr->src[0].ssa, 32));
176    nir_deref_instr *deref = nir_build_deref_cast(b, ptr, nir_var_mem_ubo, type,
177                                                     bit_size / 8);
178    deref->cast.align_mul = nir_intrinsic_align_mul(intr);
179    deref->cast.align_offset = nir_intrinsic_align_offset(intr);
180 
181    nir_ssa_def *result =
182       nir_load_deref(b, deref);
183    nir_ssa_def_rewrite_uses(&intr->dest.ssa, result);
184    nir_instr_remove(&intr->instr);
185    return true;
186 }
187 
188 bool
clc_nir_lower_kernel_input_loads(nir_shader * nir,nir_variable * var)189 clc_nir_lower_kernel_input_loads(nir_shader *nir, nir_variable *var)
190 {
191    bool progress = false;
192 
193    foreach_list_typed(nir_function, func, node, &nir->functions) {
194       if (!func->is_entrypoint)
195          continue;
196       assert(func->impl);
197 
198       nir_builder b;
199       nir_builder_init(&b, func->impl);
200 
201       nir_foreach_block(block, func->impl) {
202          nir_foreach_instr_safe(instr, block) {
203             if (instr->type != nir_instr_type_intrinsic)
204                continue;
205 
206             nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
207 
208             if (intr->intrinsic == nir_intrinsic_load_kernel_input)
209                progress |= lower_load_kernel_input(&b, intr, var);
210          }
211       }
212    }
213 
214    return progress;
215 }
216 
217 
218 static nir_variable *
add_printf_var(struct nir_shader * nir,unsigned uav_id)219 add_printf_var(struct nir_shader *nir, unsigned uav_id)
220 {
221    /* This size is arbitrary. Minimum required per spec is 1MB */
222    const unsigned max_printf_size = 1 * 1024 * 1024;
223    const unsigned printf_array_size = max_printf_size / sizeof(unsigned);
224    nir_variable *var =
225       nir_variable_create(nir, nir_var_mem_ssbo,
226                           glsl_array_type(glsl_uint_type(), printf_array_size, sizeof(unsigned)),
227                           "printf");
228    var->data.binding = uav_id;
229    return var;
230 }
231 
232 bool
clc_lower_printf_base(nir_shader * nir,unsigned uav_id)233 clc_lower_printf_base(nir_shader *nir, unsigned uav_id)
234 {
235    nir_variable *printf_var = NULL;
236    nir_ssa_def *printf_deref = NULL;
237    nir_foreach_function(func, nir) {
238       nir_builder b;
239       nir_builder_init(&b, func->impl);
240       b.cursor = nir_before_instr(nir_block_first_instr(nir_start_block(func->impl)));
241       bool progress = false;
242 
243       nir_foreach_block(block, func->impl) {
244          nir_foreach_instr_safe(instr, block) {
245             if (instr->type != nir_instr_type_intrinsic)
246                continue;
247             nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
248             if (intrin->intrinsic != nir_intrinsic_load_printf_buffer_address)
249                continue;
250 
251             if (!printf_var) {
252                printf_var = add_printf_var(nir, uav_id);
253                nir_deref_instr *deref = nir_build_deref_var(&b, printf_var);
254                printf_deref = &deref->dest.ssa;
255             }
256             nir_ssa_def_rewrite_uses(&intrin->dest.ssa, printf_deref);
257             progress = true;
258          }
259       }
260 
261       if (progress)
262          nir_metadata_preserve(func->impl, nir_metadata_loop_analysis |
263                                            nir_metadata_block_index |
264                                            nir_metadata_dominance);
265       else
266          nir_metadata_preserve(func->impl, nir_metadata_all);
267    }
268 
269    return printf_var != NULL;
270 }
271 
272 static nir_variable *
find_identical_const_sampler(nir_shader * nir,nir_variable * sampler)273 find_identical_const_sampler(nir_shader *nir, nir_variable *sampler)
274 {
275    nir_foreach_variable_with_modes(uniform, nir, nir_var_uniform) {
276       if (!glsl_type_is_sampler(uniform->type) || !uniform->data.sampler.is_inline_sampler)
277          continue;
278       if (uniform->data.sampler.addressing_mode == sampler->data.sampler.addressing_mode &&
279           uniform->data.sampler.normalized_coordinates == sampler->data.sampler.normalized_coordinates &&
280           uniform->data.sampler.filter_mode == sampler->data.sampler.filter_mode)
281          return uniform;
282    }
283    unreachable("Should have at least found the input sampler");
284 }
285 
286 static bool
clc_nir_dedupe_const_samplers_instr(nir_builder * b,nir_instr * instr,void * cb_data)287 clc_nir_dedupe_const_samplers_instr(nir_builder *b,
288                                     nir_instr *instr,
289                                     void *cb_data)
290 {
291    nir_shader *nir = cb_data;
292    if (instr->type != nir_instr_type_tex)
293       return false;
294 
295    nir_tex_instr *tex = nir_instr_as_tex(instr);
296    int sampler_idx = nir_tex_instr_src_index(tex, nir_tex_src_sampler_deref);
297    if (sampler_idx == -1)
298       return false;
299 
300    nir_deref_instr *deref = nir_src_as_deref(tex->src[sampler_idx].src);
301    nir_variable *sampler = nir_deref_instr_get_variable(deref);
302    if (!sampler)
303       return false;
304 
305    assert(sampler->data.mode == nir_var_uniform);
306 
307    if (!sampler->data.sampler.is_inline_sampler)
308       return false;
309 
310    nir_variable *replacement = find_identical_const_sampler(nir, sampler);
311    if (replacement == sampler)
312       return false;
313 
314    b->cursor = nir_before_instr(&tex->instr);
315    nir_deref_instr *replacement_deref = nir_build_deref_var(b, replacement);
316    nir_instr_rewrite_src(&tex->instr, &tex->src[sampler_idx].src,
317                          nir_src_for_ssa(&replacement_deref->dest.ssa));
318    nir_deref_instr_remove_if_unused(deref);
319 
320    return true;
321 }
322 
323 bool
clc_nir_dedupe_const_samplers(nir_shader * nir)324 clc_nir_dedupe_const_samplers(nir_shader *nir)
325 {
326    return nir_shader_instructions_pass(nir,
327                                        clc_nir_dedupe_const_samplers_instr,
328                                        nir_metadata_block_index |
329                                        nir_metadata_dominance,
330                                        nir);
331 }
332