/*
 * Copyright © 2020 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "brw_kernel.h"
#include "brw_nir.h"

#include "compiler/nir/nir_builder.h"
#include "compiler/spirv/nir_spirv.h"
#include "dev/intel_debug.h"
#include "util/u_atomic.h"

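/* OpenCL builtins with no direct NIR or hardware mapping are provided by a
 * libclc shader that gets linked into kernels as needed.  Loading that
 * library is not free, so a single copy is cached on the brw_compiler and
 * shared by all kernel compiles; the cmpxchg below makes the caching
 * thread-safe.
 */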
static const nir_shader *
load_clc_shader(struct brw_compiler *compiler, struct disk_cache *disk_cache,
                const nir_shader_compiler_options *nir_options,
                const struct spirv_to_nir_options *spirv_options)
{
   if (compiler->clc_shader)
      return compiler->clc_shader;

   nir_shader *nir = nir_load_libclc_shader(64, disk_cache,
                                            spirv_options, nir_options);
   if (nir == NULL)
      return NULL;

   const nir_shader *old_nir =
      p_atomic_cmpxchg(&compiler->clc_shader, NULL, nir);
   if (old_nir == NULL) {
      /* We won the race */
      return nir;
   } else {
      /* Someone else built the shader first */
      ralloc_free(nir);
      return old_nir;
   }
}

static void
builder_init_new_impl(nir_builder *b, nir_function *func)
{
   nir_function_impl *impl = nir_function_impl_create(func);
   nir_builder_init(b, impl);
   b->cursor = nir_before_cf_list(&impl->body);
}

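/* Gives an empty builtin declaration a body that simply wraps the given
 * atomic intrinsic.  In the calling convention used here, a function's
 * return value is passed as a dereferenceable out-parameter in param 0,
 * followed by the actual arguments.  Sketching the generated NIR for
 * float atom_min(global float *p, float v), roughly:
 *
 *    void _Z10atomic_minPU3AS1Vff(float *ret, global float *p, float v)
 *    {
 *       *ret = deref_atomic_fmin(p, v);
 *    }
 */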
static void
implement_atomic_builtin(nir_function *func, nir_intrinsic_op op,
                         enum glsl_base_type data_base_type,
                         nir_variable_mode mode)
{
   nir_builder b;
   builder_init_new_impl(&b, func);

   const struct glsl_type *data_type = glsl_scalar_type(data_base_type);

   unsigned p = 0;

   nir_deref_instr *ret = NULL;
   if (nir_intrinsic_infos[op].has_dest) {
      ret = nir_build_deref_cast(&b, nir_load_param(&b, p++),
                                 nir_var_function_temp, data_type, 0);
   }

   nir_intrinsic_instr *atomic = nir_intrinsic_instr_create(b.shader, op);

   for (unsigned i = 0; i < nir_intrinsic_infos[op].num_srcs; i++) {
      nir_ssa_def *src = nir_load_param(&b, p++);
      if (i == 0) {
         /* The first source is our deref */
         assert(nir_intrinsic_infos[op].src_components[i] == -1);
         src = &nir_build_deref_cast(&b, src, mode, data_type, 0)->dest.ssa;
      }
      atomic->src[i] = nir_src_for_ssa(src);
   }

   if (nir_intrinsic_infos[op].has_dest) {
      nir_ssa_dest_init_for_type(&atomic->instr, &atomic->dest,
                                 data_type, NULL);
   }

   nir_builder_instr_insert(&b, &atomic->instr);

   if (nir_intrinsic_infos[op].has_dest)
      nir_store_deref(&b, ret, &atomic->dest.ssa, ~0);
}

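/* Implements intel_sub_group_ballot(bool) by emitting a NIR ballot
 * intrinsic: param 0 is the out-parameter for the returned uint mask and
 * param 1 is the predicate.  Each bit of the result corresponds to one
 * subgroup invocation for which the predicate was true.
 */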
static void
implement_sub_group_ballot_builtin(nir_function *func)
{
   nir_builder b;
   builder_init_new_impl(&b, func);

   nir_deref_instr *ret =
      nir_build_deref_cast(&b, nir_load_param(&b, 0),
                           nir_var_function_temp, glsl_uint_type(), 0);
   nir_ssa_def *cond = nir_load_param(&b, 1);

   nir_intrinsic_instr *ballot =
      nir_intrinsic_instr_create(b.shader, nir_intrinsic_ballot);
   ballot->src[0] = nir_src_for_ssa(cond);
   ballot->num_components = 1;
   nir_ssa_dest_init(&ballot->instr, &ballot->dest, 1, 32, NULL);
   nir_builder_instr_insert(&b, &ballot->instr);

   nir_store_deref(&b, ret, &ballot->dest.ssa, ~0);
}

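/* Builtins are matched by their Itanium-mangled names.  Decoding the first
 * one as an example: _Z10atomic_min is a 10-character function name
 * "atomic_min"; PU3AS1Vf is a Pointer to address-space-1 (OpenCL __global)
 * Volatile float; the trailing f is a float argument.  AS3 in the later
 * entries is OpenCL __local (shared) memory.
 */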
static bool
implement_intel_builtins(nir_shader *nir)
{
   bool progress = false;

   nir_foreach_function(func, nir) {
      if (strcmp(func->name, "_Z10atomic_minPU3AS1Vff") == 0) {
         /* float atom_min(__global float volatile *p, float val) */
         implement_atomic_builtin(func, nir_intrinsic_deref_atomic_fmin,
                                  GLSL_TYPE_FLOAT, nir_var_mem_global);
         progress = true;
      } else if (strcmp(func->name, "_Z10atomic_maxPU3AS1Vff") == 0) {
         /* float atom_max(__global float volatile *p, float val) */
         implement_atomic_builtin(func, nir_intrinsic_deref_atomic_fmax,
                                  GLSL_TYPE_FLOAT, nir_var_mem_global);
         progress = true;
      } else if (strcmp(func->name, "_Z10atomic_minPU3AS3Vff") == 0) {
         /* float atomic_min(__shared float volatile *, float) */
         implement_atomic_builtin(func, nir_intrinsic_deref_atomic_fmin,
                                  GLSL_TYPE_FLOAT, nir_var_mem_shared);
         progress = true;
      } else if (strcmp(func->name, "_Z10atomic_maxPU3AS3Vff") == 0) {
         /* float atomic_max(__shared float volatile *, float) */
         implement_atomic_builtin(func, nir_intrinsic_deref_atomic_fmax,
                                  GLSL_TYPE_FLOAT, nir_var_mem_shared);
         progress = true;
      } else if (strcmp(func->name, "intel_sub_group_ballot") == 0) {
         implement_sub_group_ballot_builtin(func);
         progress = true;
      }
   }

   nir_shader_preserve_all_metadata(nir);

   return progress;
}

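/* Rewrites kernel-specific intrinsics in terms of the uniform (push
 * constant) space.  The layout assumed here is brw_kernel_sysvals at
 * offset 0 followed immediately by the kernel arguments, so e.g. a
 * load_kernel_input at offset N becomes a load_uniform at
 * sizeof(struct brw_kernel_sysvals) + N, and load_num_workgroups becomes
 * a load_uniform from the num_work_groups field of the sysvals struct.
 */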
static bool
lower_kernel_intrinsics(nir_shader *nir)
{
   nir_function_impl *impl = nir_shader_get_entrypoint(nir);

   bool progress = false;

   unsigned kernel_sysvals_start = 0;
   unsigned kernel_arg_start = sizeof(struct brw_kernel_sysvals);
   nir->num_uniforms += kernel_arg_start;

   nir_builder b;
   nir_builder_init(&b, impl);

   nir_foreach_block(block, impl) {
      nir_foreach_instr_safe(instr, block) {
         if (instr->type != nir_instr_type_intrinsic)
            continue;

         nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
         switch (intrin->intrinsic) {
         case nir_intrinsic_load_kernel_input: {
            b.cursor = nir_instr_remove(&intrin->instr);

            nir_intrinsic_instr *load =
               nir_intrinsic_instr_create(nir, nir_intrinsic_load_uniform);
            load->num_components = intrin->num_components;
            load->src[0] = nir_src_for_ssa(nir_u2u32(&b, intrin->src[0].ssa));
            nir_intrinsic_set_base(load, kernel_arg_start);
            nir_intrinsic_set_range(load, nir->num_uniforms);
            nir_ssa_dest_init(&load->instr, &load->dest,
                              intrin->dest.ssa.num_components,
                              intrin->dest.ssa.bit_size, NULL);
            nir_builder_instr_insert(&b, &load->instr);

            nir_ssa_def_rewrite_uses(&intrin->dest.ssa, &load->dest.ssa);
            progress = true;
            break;
         }

         case nir_intrinsic_load_constant_base_ptr: {
            b.cursor = nir_instr_remove(&intrin->instr);
            nir_ssa_def *const_data_base_addr = nir_pack_64_2x32_split(&b,
               nir_load_reloc_const_intel(&b, BRW_SHADER_RELOC_CONST_DATA_ADDR_LOW),
               nir_load_reloc_const_intel(&b, BRW_SHADER_RELOC_CONST_DATA_ADDR_HIGH));
            nir_ssa_def_rewrite_uses(&intrin->dest.ssa, const_data_base_addr);
            progress = true;
            break;
         }

         case nir_intrinsic_load_num_workgroups: {
            b.cursor = nir_instr_remove(&intrin->instr);

            nir_intrinsic_instr *load =
               nir_intrinsic_instr_create(nir, nir_intrinsic_load_uniform);
            load->num_components = 3;
            load->src[0] = nir_src_for_ssa(nir_imm_int(&b, 0));
            nir_intrinsic_set_base(load, kernel_sysvals_start +
               offsetof(struct brw_kernel_sysvals, num_work_groups));
            nir_intrinsic_set_range(load, 3 * 4);
            nir_ssa_dest_init(&load->instr, &load->dest, 3, 32, NULL);
            nir_builder_instr_insert(&b, &load->instr);

            /* We may need to do a bit-size cast here */
            nir_ssa_def *num_work_groups =
               nir_u2u(&b, &load->dest.ssa, intrin->dest.ssa.bit_size);

            nir_ssa_def_rewrite_uses(&intrin->dest.ssa, num_work_groups);
            progress = true;
            break;
         }

         default:
            break;
         }
      }
   }

   if (progress) {
      nir_metadata_preserve(impl, nir_metadata_block_index |
                                  nir_metadata_dominance);
   } else {
      nir_metadata_preserve(impl, nir_metadata_all);
   }

   return progress;
}

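/* Compiles an OpenCL kernel, delivered as linked SPIR-V, into Intel GPU
 * code.  Roughly: translate SPIR-V to NIR, resolve Intel builtins and
 * libclc calls, inline everything down to the single entrypoint, lower
 * variables and pointers to explicit layouts, gather the kernel argument
 * descriptors, lower the remaining I/O and kernel intrinsics, and finally
 * hand the result to the compute shader backend.
 */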
bool
brw_kernel_from_spirv(struct brw_compiler *compiler,
                      struct disk_cache *disk_cache,
                      struct brw_kernel *kernel,
                      void *log_data, void *mem_ctx,
                      const uint32_t *spirv, size_t spirv_size,
                      const char *entrypoint_name,
                      char **error_str)
{
   const struct intel_device_info *devinfo = compiler->devinfo;
   const nir_shader_compiler_options *nir_options =
      compiler->nir_options[MESA_SHADER_KERNEL];

   struct spirv_to_nir_options spirv_options = {
      .environment = NIR_SPIRV_OPENCL,
      .caps = {
         .address = true,
         .float16 = devinfo->ver >= 8,
         .float64 = devinfo->ver >= 8,
         .groups = true,
         .image_write_without_format = true,
         .int8 = devinfo->ver >= 8,
         .int16 = devinfo->ver >= 8,
         .int64 = devinfo->ver >= 8,
         .int64_atomics = devinfo->ver >= 9,
         .kernel = true,
         .linkage = true, /* We receive linked kernel from clc */
         .float_controls = devinfo->ver >= 8,
         .generic_pointers = true,
         .storage_8bit = devinfo->ver >= 8,
         .storage_16bit = devinfo->ver >= 8,
         .subgroup_arithmetic = true,
         .subgroup_basic = true,
         .subgroup_ballot = true,
         .subgroup_dispatch = true,
         .subgroup_quad = true,
         .subgroup_shuffle = true,
         .subgroup_vote = true,

         .intel_subgroup_shuffle = true,
         .intel_subgroup_buffer_block_io = true,
      },
      .shared_addr_format = nir_address_format_62bit_generic,
      .global_addr_format = nir_address_format_62bit_generic,
      .temp_addr_format = nir_address_format_62bit_generic,
      .constant_addr_format = nir_address_format_64bit_global,
   };

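   /* Shared, global, and temp all use the same 62-bit generic format so
    * that OpenCL generic pointers can be freely cast between address
    * spaces; if my reading of the format is right, the remaining two bits
    * tag which space a pointer actually refers to.  Constants instead use
    * plain 64-bit global addresses, which is what lets the
    * load_constant_base_ptr lowering in lower_kernel_intrinsics resolve to
    * a relocated 64-bit GPU address.
    */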
   spirv_options.clc_shader = load_clc_shader(compiler, disk_cache,
                                              nir_options, &spirv_options);

   assert(spirv_size % 4 == 0);
   nir_shader *nir =
      spirv_to_nir(spirv, spirv_size / 4, NULL, 0, MESA_SHADER_KERNEL,
                   entrypoint_name, &spirv_options, nir_options);
   nir_validate_shader(nir, "after spirv_to_nir");
   nir_validate_ssa_dominance(nir, "after spirv_to_nir");
   ralloc_steal(mem_ctx, nir);
   nir->info.name = ralloc_strdup(nir, entrypoint_name);

   if (INTEL_DEBUG(DEBUG_CS)) {
      /* Re-index SSA defs so we print more sensible numbers. */
      nir_foreach_function(function, nir) {
         if (function->impl)
            nir_index_ssa_defs(function->impl);
      }

      fprintf(stderr, "NIR (from SPIR-V) for kernel\n");
      nir_print_shader(nir, stderr);
   }

   NIR_PASS_V(nir, implement_intel_builtins);
   NIR_PASS_V(nir, nir_lower_libclc, spirv_options.clc_shader);

   /* We have to lower away local constant initializers right before we
    * inline functions.  That way they get properly initialized at the top
    * of the function and not at the top of its caller.
    */
   NIR_PASS_V(nir, nir_lower_variable_initializers, nir_var_function_temp);
   NIR_PASS_V(nir, nir_lower_returns);
   NIR_PASS_V(nir, nir_inline_functions);
   NIR_PASS_V(nir, nir_copy_prop);
   NIR_PASS_V(nir, nir_opt_deref);

   /* Pick off the single entrypoint that we want */
   nir_remove_non_entrypoints(nir);

   /* Now that we've deleted all but the main function, we can go ahead and
    * lower the rest of the constant initializers.  We do this here so that
    * nir_remove_dead_variables and split_per_member_structs below see the
    * corresponding stores.
    */
   NIR_PASS_V(nir, nir_lower_variable_initializers, ~0);

   /* LLVM loves to take advantage of the fact that vec3s in OpenCL are 16B
    * aligned and so it can just read/write them as vec4s.  This results in a
    * LOT of vec4->vec3 casts on loads and stores.  One solution to this
    * problem is to get rid of all vec3 variables.
    */
   NIR_PASS_V(nir, nir_lower_vec3_to_vec4,
              nir_var_shader_temp | nir_var_function_temp |
              nir_var_mem_shared | nir_var_mem_global |
              nir_var_mem_constant);

   /* We assign explicit types early so that the optimizer can take advantage
    * of that information and hopefully get rid of some of our memcpys.
    */
   NIR_PASS_V(nir, nir_lower_vars_to_explicit_types,
              nir_var_uniform |
              nir_var_shader_temp | nir_var_function_temp |
              nir_var_mem_shared | nir_var_mem_global,
              glsl_get_cl_type_size_align);

   brw_preprocess_nir(compiler, nir, NULL);

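   /* brw_preprocess_nir above runs the backend's shared lowering and
    * optimization bundle.  Next, gather per-argument metadata: after
    * explicit-type assignment, each uniform variable's driver_location and
    * explicit size give the argument's byte range, and data.location is
    * its argument index.
    */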
   int max_arg_idx = -1;
   nir_foreach_uniform_variable(var, nir) {
      assert(var->data.location < 256);
      max_arg_idx = MAX2(max_arg_idx, var->data.location);
   }

   kernel->args_size = nir->num_uniforms;
   kernel->arg_count = max_arg_idx + 1;

   /* No bindings */
   struct brw_kernel_arg_desc *args =
      rzalloc_array(mem_ctx, struct brw_kernel_arg_desc, kernel->arg_count);
   kernel->args = args;

   nir_foreach_uniform_variable(var, nir) {
      struct brw_kernel_arg_desc arg_desc = {
         .offset = var->data.driver_location,
         .size = glsl_get_explicit_size(var->type, false),
      };
      assert(arg_desc.offset + arg_desc.size <= nir->num_uniforms);

      assert(var->data.location >= 0);
      args[var->data.location] = arg_desc;
   }

   NIR_PASS_V(nir, nir_remove_dead_variables, nir_var_all, NULL);

   /* Lower again, this time after dead-variables to get more compact variable
    * layouts.
    */
   nir->global_mem_size = 0;
   nir->scratch_size = 0;
   nir->info.shared_size = 0;
   NIR_PASS_V(nir, nir_lower_vars_to_explicit_types,
              nir_var_shader_temp | nir_var_function_temp |
              nir_var_mem_shared | nir_var_mem_global | nir_var_mem_constant,
              glsl_get_cl_type_size_align);
   if (nir->constant_data_size > 0) {
      assert(nir->constant_data == NULL);
      nir->constant_data = rzalloc_size(nir, nir->constant_data_size);
      nir_gather_explicit_io_initializers(nir, nir->constant_data,
                                          nir->constant_data_size,
                                          nir_var_mem_constant);
   }

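   /* nir->constant_data now holds a flat blob of every __constant
    * initializer.  Presumably the driver uploads this blob and patches the
    * BRW_SHADER_RELOC_CONST_DATA_ADDR_{LOW,HIGH} relocations (see
    * lower_kernel_intrinsics above) with its GPU address.
    */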
   if (INTEL_DEBUG(DEBUG_CS)) {
      /* Re-index SSA defs so we print more sensible numbers. */
      nir_foreach_function(function, nir) {
         if (function->impl)
            nir_index_ssa_defs(function->impl);
      }

      fprintf(stderr, "NIR (before I/O lowering) for kernel\n");
      nir_print_shader(nir, stderr);
   }

   NIR_PASS_V(nir, nir_lower_memcpy);

   NIR_PASS_V(nir, nir_lower_explicit_io, nir_var_mem_constant,
              nir_address_format_64bit_global);

   NIR_PASS_V(nir, nir_lower_explicit_io, nir_var_uniform,
              nir_address_format_32bit_offset_as_64bit);

   NIR_PASS_V(nir, nir_lower_explicit_io,
              nir_var_shader_temp | nir_var_function_temp |
              nir_var_mem_shared | nir_var_mem_global,
              nir_address_format_62bit_generic);

   NIR_PASS_V(nir, nir_lower_frexp);
   NIR_PASS_V(nir, nir_lower_convert_alu_types, NULL);

   NIR_PASS_V(nir, brw_nir_lower_cs_intrinsics);
   NIR_PASS_V(nir, lower_kernel_intrinsics);

   struct brw_cs_prog_key key = { };

   memset(&kernel->prog_data, 0, sizeof(kernel->prog_data));
   kernel->prog_data.base.nr_params = DIV_ROUND_UP(nir->num_uniforms, 4);

   struct brw_compile_cs_params params = {
      .nir = nir,
      .key = &key,
      .prog_data = &kernel->prog_data,
      .stats = kernel->stats,
      .log_data = log_data,
   };

   kernel->code = brw_compile_cs(compiler, mem_ctx, &params);

   if (error_str)
      *error_str = params.error_str;

   return kernel->code != NULL;
}
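
/* A minimal usage sketch (hypothetical driver code; my_compiler, my_cache,
 * spirv_words, and spirv_size_bytes are illustrative placeholders):
 *
 *    struct brw_kernel kernel = {0};
 *    char *error_str = NULL;
 *    void *mem_ctx = ralloc_context(NULL);
 *    if (!brw_kernel_from_spirv(my_compiler, my_cache, &kernel, NULL,
 *                               mem_ctx, spirv_words, spirv_size_bytes,
 *                               "my_kernel", &error_str))
 *       fprintf(stderr, "Kernel compile failed: %s\n", error_str);
 */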