• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2023 Alyssa Rosenzweig
3  * Copyright 2020 Intel Corporation
4  * SPDX-License-Identifier: MIT
5  */
6 
#include "asahi/compiler/agx_compile.h"
#include "compiler/clc/clc.h"
#include "compiler/glsl_types.h"
#include "compiler/spirv/nir_spirv.h"
#include "util/build_id.h"
#include "util/disk_cache.h"
#include "util/macros.h"
#include "util/mesa-sha1.h"
#include "util/u_dynarray.h"
#include "nir.h"
#include "nir_builder.h"
#include "nir_serialize.h"

#include <errno.h>
#include <fcntl.h>
#include <getopt.h>
#include <inttypes.h>
#include <stdio.h>
#include <string.h>
#include <unistd.h>
#include <sys/mman.h>
27 
28 struct spirv_to_nir_options spirv_options = {
29    .environment = NIR_SPIRV_OPENCL,
30    .caps =
31       {
32          .address = true,
33          .float16 = true,
34          .float64 = true,
35          .groups = true,
36          .image_write_without_format = true,
37          .int8 = true,
38          .int16 = true,
39          .int64 = true,
40          .int64_atomics = false,
41          .kernel = true,
42          .linkage = true,
43          .float_controls = true,
44          .generic_pointers = true,
45          .storage_8bit = true,
46          .storage_16bit = true,
47          .subgroup_arithmetic = true,
48          .subgroup_basic = true,
49          .subgroup_ballot = true,
50          .subgroup_dispatch = true,
51          .subgroup_quad = true,
52          .subgroup_shuffle = true,
53          .subgroup_vote = true,
54 
55          .intel_subgroup_shuffle = true,
56          .intel_subgroup_buffer_block_io = true,
57       },
58    .shared_addr_format = nir_address_format_62bit_generic,
59    .global_addr_format = nir_address_format_62bit_generic,
60    .temp_addr_format = nir_address_format_62bit_generic,
61    .constant_addr_format = nir_address_format_64bit_global,
62    .create_library = true,
63 };
64 
65 static bool
lower_builtins(nir_builder * b,nir_instr * instr,void * data)66 lower_builtins(nir_builder *b, nir_instr *instr, void *data)
67 {
68    if (instr->type != nir_instr_type_call)
69       return false;
70 
71    nir_call_instr *call = nir_instr_as_call(instr);
72    nir_function *func = call->callee;
73 
74    if (strcmp(func->name, "nir_interleave_agx") == 0) {
75       b->cursor = nir_instr_remove(&call->instr);
76       nir_store_deref(
77          b, nir_src_as_deref(call->params[0]),
78          nir_interleave_agx(b, call->params[1].ssa, call->params[2].ssa), 1);
79 
80       return true;
81    } else if (strcmp(func->name, "nir_doorbell_agx") == 0) {
82       b->cursor = nir_instr_remove(&call->instr);
83       nir_doorbell_agx(b, call->params[0].ssa);
84       return true;
85    } else if (strcmp(func->name, "nir_stack_map_agx") == 0) {
86       b->cursor = nir_instr_remove(&call->instr);
87       nir_stack_map_agx(b, call->params[0].ssa, call->params[1].ssa);
88       return true;
89    } else if (strcmp(func->name, "nir_stack_unmap_agx") == 0) {
90       b->cursor = nir_instr_remove(&call->instr);
91       nir_store_deref(b, nir_src_as_deref(call->params[0]),
92                       nir_stack_unmap_agx(b, call->params[1].ssa), 1);
93       return true;
94    } else if (strcmp(func->name, "nir_load_core_id_agx") == 0) {
95       b->cursor = nir_instr_remove(&call->instr);
96       nir_store_deref(b, nir_src_as_deref(call->params[0]),
97                       nir_load_core_id_agx(b), 1);
98       return true;
99    } else if (strcmp(func->name, "nir_load_helper_op_id_agx") == 0) {
100       b->cursor = nir_instr_remove(&call->instr);
101       nir_store_deref(b, nir_src_as_deref(call->params[0]),
102                       nir_load_helper_op_id_agx(b, 1, 32), 1);
103       return true;
104    } else if (strcmp(func->name, "nir_load_helper_arg_lo_agx") == 0) {
105       b->cursor = nir_instr_remove(&call->instr);
106       nir_store_deref(b, nir_src_as_deref(call->params[0]),
107                       nir_load_helper_arg_lo_agx(b, 1, 32), 1);
108       return true;
109    } else if (strcmp(func->name, "nir_load_helper_arg_hi_agx") == 0) {
110       b->cursor = nir_instr_remove(&call->instr);
111       nir_store_deref(b, nir_src_as_deref(call->params[0]),
112                       nir_load_helper_arg_hi_agx(b, 1, 32), 1);
113       return true;
114    } else if (strcmp(func->name, "nir_fence_helper_exit_agx") == 0) {
115       b->cursor = nir_instr_remove(&call->instr);
116       nir_fence_helper_exit_agx(b);
117       return true;
118    }
119 
120    return false;
121 }
122 
123 /* Standard optimization loop */
124 static void
optimize(nir_shader * nir)125 optimize(nir_shader *nir)
126 {
127    bool progress;
128    do {
129       progress = false;
130 
131       NIR_PASS(progress, nir, nir_lower_var_copies);
132       NIR_PASS(progress, nir, nir_lower_vars_to_ssa);
133 
134       NIR_PASS(progress, nir, nir_copy_prop);
135       NIR_PASS(progress, nir, nir_opt_remove_phis);
136       NIR_PASS(progress, nir, nir_lower_phis_to_scalar, true);
137       NIR_PASS(progress, nir, nir_opt_dce);
138       NIR_PASS(progress, nir, nir_opt_dead_cf);
139       NIR_PASS(progress, nir, nir_opt_cse);
140       NIR_PASS(progress, nir, nir_opt_peephole_select, 64, false, true);
141       NIR_PASS(progress, nir, nir_opt_phi_precision);
142       NIR_PASS(progress, nir, nir_opt_algebraic);
143       NIR_PASS(progress, nir, nir_opt_constant_folding);
144 
145       NIR_PASS(progress, nir, nir_opt_deref);
146       NIR_PASS(progress, nir, nir_opt_copy_prop_vars);
147       NIR_PASS(progress, nir, nir_opt_undef);
148       NIR_PASS(progress, nir, nir_lower_undef_to_zero);
149 
150       NIR_PASS(progress, nir, nir_opt_shrink_vectors);
151       NIR_PASS(progress, nir, nir_opt_loop_unroll);
152 
153       NIR_PASS(progress, nir, nir_split_var_copies);
154       NIR_PASS(progress, nir, nir_split_struct_vars, nir_var_function_temp);
155    } while (progress);
156 }
157 
158 static nir_shader *
compile(void * memctx,const uint32_t * spirv,size_t spirv_size)159 compile(void *memctx, const uint32_t *spirv, size_t spirv_size)
160 {
161    const nir_shader_compiler_options *nir_options = &agx_nir_options;
162 
163    assert(spirv_size % 4 == 0);
164    nir_shader *nir =
165       spirv_to_nir(spirv, spirv_size / 4, NULL, 0, MESA_SHADER_KERNEL,
166                    "library", &spirv_options, nir_options);
167    nir_validate_shader(nir, "after spirv_to_nir");
168    nir_validate_ssa_dominance(nir, "after spirv_to_nir");
169    ralloc_steal(memctx, nir);
170 
171    NIR_PASS(_, nir, nir_lower_system_values);
172    nir_shader_instructions_pass(nir, lower_builtins, nir_metadata_none, NULL);
173 
174    /* We have to lower away local constant initializers right before we
175     * inline functions.  That way they get properly initialized at the top
176     * of the function and not at the top of its caller.
177     */
178    NIR_PASS(_, nir, nir_lower_variable_initializers, nir_var_function_temp);
179    NIR_PASS(_, nir, nir_lower_returns);
180    NIR_PASS(_, nir, nir_inline_functions);
181    nir_remove_non_exported(nir);
182    NIR_PASS(_, nir, nir_copy_prop);
183    NIR_PASS(_, nir, nir_opt_deref);
184 
185    /* We can go ahead and lower the rest of the constant initializers.  We do
186     * this here so that nir_remove_dead_variables and split_per_member_structs
187     * below see the corresponding stores.
188     */
189    NIR_PASS(_, nir, nir_lower_variable_initializers, ~0);
190 
191    /* LLVM loves take advantage of the fact that vec3s in OpenCL are 16B
192     * aligned and so it can just read/write them as vec4s.  This results in a
193     * LOT of vec4->vec3 casts on loads and stores.  One solution to this
194     * problem is to get rid of all vec3 variables.
195     */
196    NIR_PASS(_, nir, nir_lower_vec3_to_vec4,
197             nir_var_shader_temp | nir_var_function_temp | nir_var_mem_shared |
198                nir_var_mem_global | nir_var_mem_constant);
199 
200    /* We assign explicit types early so that the optimizer can take advantage
201     * of that information and hopefully get rid of some of our memcpys.
202     */
203    NIR_PASS(_, nir, nir_lower_vars_to_explicit_types,
204             nir_var_uniform | nir_var_shader_temp | nir_var_function_temp |
205                nir_var_mem_shared | nir_var_mem_global,
206             glsl_get_cl_type_size_align);
207 
208    optimize(nir);
209 
210    NIR_PASS(_, nir, nir_remove_dead_variables, nir_var_all, NULL);
211 
212    /* Lower again, this time after dead-variables to get more compact variable
213     * layouts.
214     */
215    NIR_PASS(_, nir, nir_lower_vars_to_explicit_types,
216             nir_var_shader_temp | nir_var_function_temp | nir_var_mem_shared |
217                nir_var_mem_global | nir_var_mem_constant,
218             glsl_get_cl_type_size_align);
219    if (nir->constant_data_size > 0) {
220       assert(nir->constant_data == NULL);
221       nir->constant_data = rzalloc_size(nir, nir->constant_data_size);
222       nir_gather_explicit_io_initializers(nir, nir->constant_data,
223                                           nir->constant_data_size,
224                                           nir_var_mem_constant);
225    }
226 
227    NIR_PASS(_, nir, nir_lower_memcpy);
228 
229    NIR_PASS(_, nir, nir_lower_explicit_io, nir_var_mem_constant,
230             nir_address_format_64bit_global);
231 
232    NIR_PASS(_, nir, nir_lower_explicit_io, nir_var_uniform,
233             nir_address_format_32bit_offset_as_64bit);
234 
235    /* Note: we cannot lower explicit I/O here, because we need derefs in tact
236     * for function calls into the library to work.
237     */
238 
239    NIR_PASS(_, nir, nir_lower_convert_alu_types, NULL);
240    NIR_PASS(_, nir, nir_opt_if, 0);
241    NIR_PASS(_, nir, nir_opt_idiv_const, 16);
242 
243    optimize(nir);
244 
245    return nir;
246 }
247 
/* Value expected in word 0 of a SPIR-V module; the module version helpers
 * assert on it before touching word 1.
 */
#define SPIR_V_MAGIC_NUMBER 0x07230203
250 
/*
 * Logger callback used for both errors and warnings: writes the message to
 * stderr exactly as given. The priv pointer is ignored.
 */
static void
msg_callback(void *priv, const char *msg)
{
   (void)priv;
   fputs(msg, stderr);
}
257 
/*
 * Emit `data` as a C array definition named "<prefix>_<arr_name>".
 *
 * fp        stream to write to.
 * prefix    first half of the array identifier.
 * arr_name  second half of the array identifier.
 * data      words to print.
 * len       size of `data` in BYTES; must be a multiple of 4.
 *
 * Words are printed four per line as zero-padded 8-digit hex literals, each
 * followed by a comma.
 */
static void
print_u32_data(FILE *fp, const char *prefix, const char *arr_name,
               const uint32_t *data, size_t len)
{
   assert(len % 4 == 0);
   const size_t num_words = len / 4;

   fprintf(fp, "static const uint32_t %s_%s[] = {", prefix, arr_name);

   /* size_t index: the count is derived from a size_t, so a narrower counter
    * could wrap and never reach it.
    */
   for (size_t i = 0; i < num_words; i++) {
      if (i % 4 == 0)
         fprintf(fp, "\n   ");

      fprintf(fp, " 0x%08" PRIx32 ",", data[i]);
   }

   fprintf(fp, "\n};\n");
}
272 
/*
 * Write the command-line help to `f`.
 *
 * exec_name is substituted into the "Usage:" line; the option table that
 * follows is fixed text.
 */
static void
print_usage(char *exec_name, FILE *f)
{
   static const char option_help[] =
      "Options:\n"
      "  -h  --help              Print this help.\n"
      "      --prefix <prefix>   Prefix for variable names in generated C code.\n"
      "  -o, --out <filename>    Specify the output filename.\n"
      "  -i, --in <filename>     Specify one input filename. Accepted multiple times.\n"
      "  -s, --spv <filename>    Specify the output filename for spirv.\n"
      "  -v, --verbose           Print more information during compilation.\n";

   fprintf(f, "Usage: %s [options] -- [clang args]\n", exec_name);
   fputs(option_help, f);
}
288 
/* getopt_long() return value for --prefix, which has no short option. The
 * value sits above the character range so it cannot clash with a short option.
 */
#define OPT_PREFIX 1000
290 
291 static uint32_t
get_module_spirv_version(const uint32_t * spirv,size_t size)292 get_module_spirv_version(const uint32_t *spirv, size_t size)
293 {
294    assert(size >= 8);
295    assert(spirv[0] == SPIR_V_MAGIC_NUMBER);
296    return spirv[1];
297 }
298 
299 static void
set_module_spirv_version(uint32_t * spirv,size_t size,uint32_t version)300 set_module_spirv_version(uint32_t *spirv, size_t size, uint32_t version)
301 {
302    assert(size >= 8);
303    assert(spirv[0] == SPIR_V_MAGIC_NUMBER);
304    spirv[1] = version;
305 }
306 
307 int
main(int argc,char ** argv)308 main(int argc, char **argv)
309 {
310    static struct option long_options[] = {
311       {"help", no_argument, 0, 'h'},
312       {"prefix", required_argument, 0, OPT_PREFIX},
313       {"in", required_argument, 0, 'i'},
314       {"out", required_argument, 0, 'o'},
315       {"spv", required_argument, 0, 's'},
316       {"verbose", no_argument, 0, 'v'},
317       {0, 0, 0, 0},
318    };
319 
320    char *outfile = NULL, *spv_outfile = NULL, *prefix = NULL;
321    struct util_dynarray clang_args;
322    struct util_dynarray input_files;
323    struct util_dynarray spirv_objs;
324    struct util_dynarray spirv_ptr_objs;
325 
326    void *mem_ctx = ralloc_context(NULL);
327 
328    util_dynarray_init(&clang_args, mem_ctx);
329    util_dynarray_init(&input_files, mem_ctx);
330    util_dynarray_init(&spirv_objs, mem_ctx);
331    util_dynarray_init(&spirv_ptr_objs, mem_ctx);
332 
333    int ch;
334    while ((ch = getopt_long(argc, argv, "he:p:s:i:o:v", long_options, NULL)) !=
335           -1) {
336       switch (ch) {
337       case 'h':
338          print_usage(argv[0], stdout);
339          return 0;
340       case 'o':
341          outfile = optarg;
342          break;
343       case 'i':
344          util_dynarray_append(&input_files, char *, optarg);
345          break;
346       case 's':
347          spv_outfile = optarg;
348          break;
349       case OPT_PREFIX:
350          prefix = optarg;
351          break;
352       default:
353          fprintf(stderr, "Unrecognized option \"%s\".\n", optarg);
354          print_usage(argv[0], stderr);
355          return 1;
356       }
357    }
358 
359    for (int i = optind; i < argc; i++) {
360       util_dynarray_append(&clang_args, char *, argv[i]);
361    }
362 
363    if (util_dynarray_num_elements(&input_files, char *) == 0) {
364       fprintf(stderr, "No input file(s).\n");
365       print_usage(argv[0], stderr);
366       return -1;
367    }
368 
369    if (prefix == NULL) {
370       fprintf(stderr, "No prefix specified.\n");
371       print_usage(argv[0], stderr);
372       return -1;
373    }
374 
375    struct clc_logger logger = {
376       .error = msg_callback,
377       .warning = msg_callback,
378    };
379 
380    util_dynarray_foreach(&input_files, char *, infile) {
381       int fd = open(*infile, O_RDONLY);
382       if (fd < 0) {
383          fprintf(stderr, "Failed to open %s\n", *infile);
384          ralloc_free(mem_ctx);
385          return 1;
386       }
387 
388       off_t len = lseek(fd, 0, SEEK_END);
389       const void *map = mmap(NULL, len, PROT_READ, MAP_PRIVATE, fd, 0);
390       close(fd);
391       if (map == MAP_FAILED) {
392          fprintf(stderr, "Failed to mmap the file: errno=%d, %s\n", errno,
393                  strerror(errno));
394          ralloc_free(mem_ctx);
395          return 1;
396       }
397 
398       const char *allowed_spirv_extensions[] = {
399          "SPV_EXT_shader_atomic_float_add",
400          "SPV_EXT_shader_atomic_float_min_max",
401          "SPV_KHR_float_controls",
402          "SPV_INTEL_subgroups",
403          NULL,
404       };
405 
406       struct clc_compile_args clc_args = {
407          .source =
408             {
409                .name = *infile,
410                .value = map,
411             },
412          .features =
413             {
414                .fp16 = true,
415                .intel_subgroups = true,
416                .subgroups = true,
417                .subgroups_ifp = true,
418             },
419          .args = util_dynarray_begin(&clang_args),
420          .num_args = util_dynarray_num_elements(&clang_args, char *),
421          .allowed_spirv_extensions = allowed_spirv_extensions,
422       };
423 
424       struct clc_binary *spirv_out =
425          util_dynarray_grow(&spirv_objs, struct clc_binary, 1);
426 
427       if (!clc_compile_c_to_spirv(&clc_args, &logger, spirv_out)) {
428          ralloc_free(mem_ctx);
429          return 1;
430       }
431    }
432 
433    util_dynarray_foreach(&spirv_objs, struct clc_binary, p) {
434       util_dynarray_append(&spirv_ptr_objs, struct clc_binary *, p);
435    }
436 
437    /* The SPIRV-Tools linker started checking that all modules have the same
438     * version. But SPIRV-LLVM-Translator picks the lower required version for
439     * each module it compiles. So we have to iterate over all of them and set
440     * the max found to make SPIRV-Tools link our modules.
441     *
442     * TODO: This is not the correct thing to do. We need SPIRV-LLVM-Translator
443     *       to pick a given SPIRV version given to it and have all the modules
444     *       at that version. We should remove this hack when this issue is
445     *       fixed :
446     *       https://github.com/KhronosGroup/SPIRV-LLVM-Translator/issues/1445
447     */
448    uint32_t max_spirv_version = 0;
449    util_dynarray_foreach(&spirv_ptr_objs, struct clc_binary *, module) {
450       max_spirv_version =
451          MAX2(max_spirv_version,
452               get_module_spirv_version((*module)->data, (*module)->size));
453    }
454 
455    assert(max_spirv_version > 0);
456    util_dynarray_foreach(&spirv_ptr_objs, struct clc_binary *, module) {
457       set_module_spirv_version((*module)->data, (*module)->size,
458                                max_spirv_version);
459    }
460 
461    struct clc_linker_args link_args = {
462       .in_objs = util_dynarray_begin(&spirv_ptr_objs),
463       .num_in_objs =
464          util_dynarray_num_elements(&spirv_ptr_objs, struct clc_binary *),
465       .create_library = true,
466    };
467    struct clc_binary final_spirv;
468    if (!clc_link_spirv(&link_args, &logger, &final_spirv)) {
469       ralloc_free(mem_ctx);
470       return 1;
471    }
472 
473    if (spv_outfile) {
474       FILE *fp = fopen(spv_outfile, "w");
475       fwrite(final_spirv.data, final_spirv.size, 1, fp);
476       fclose(fp);
477    }
478 
479    FILE *fp = stdout;
480    if (outfile != NULL)
481       fp = fopen(outfile, "w");
482 
483    glsl_type_singleton_init_or_ref();
484 
485    fprintf(fp, "/*\n");
486    fprintf(fp, " * Copyright The Asahi Linux Contributors\n");
487    fprintf(fp, " * SPDX-License-Identifier: MIT\n");
488    fprintf(fp, " *\n");
489    fprintf(fp, " * Autogenerated file, do not edit\n");
490    fprintf(fp, " */\n");
491    fprintf(fp, " #include <stdint.h>\n");
492 
493    /* Compile SPIR-V to NIR */
494    nir_shader *nir = compile(NULL, final_spirv.data, final_spirv.size);
495 
496    {
497       struct util_dynarray binary;
498       util_dynarray_init(&binary, NULL);
499 
500       nir_builder b = nir_builder_init_simple_shader(
501          MESA_SHADER_COMPUTE, &agx_nir_options, "Helper shader");
502 
503       nir_function *func =
504          nir_shader_get_function_for_name(nir, "libagx_helper");
505 
506       nir_call(&b, nir_function_clone(b.shader, func));
507 
508       UNUSED struct agx_uncompiled_shader_info info;
509       UNUSED struct agx_shader_info compiled_info;
510       struct agx_shader_key key = {
511          .libagx = nir,
512          .is_helper = true,
513       };
514 
515       agx_preprocess_nir(b.shader, nir, false, &info);
516       agx_compile_shader_nir(b.shader, &key, NULL, &binary, &compiled_info);
517 
518       /* Pad out */
519       uint8_t zero = 0;
520       while (binary.size % 4) {
521          util_dynarray_append(&binary, uint8_t, zero);
522       }
523 
524       print_u32_data(fp, "libagx_g13", "helper", binary.data, binary.size);
525       util_dynarray_fini(&binary);
526       ralloc_free(b.shader);
527 
528       /* Remove the NIR function, it's compiled, we don't need it at runtime */
529       exec_node_remove(&func->node);
530    }
531 
532    spirv_library_to_nir_builder(fp, final_spirv.data, final_spirv.size / 4,
533                                 &spirv_options);
534 
535    /* Serialize NIR for embedding */
536    struct blob blob;
537    blob_init(&blob);
538    nir_serialize(&blob, nir, false /* strip */);
539    print_u32_data(fp, prefix, "nir", (const uint32_t *)blob.data, blob.size);
540    blob_finish(&blob);
541 
542    glsl_type_singleton_decref();
543 
544    if (fp != stdout)
545       fclose(fp);
546 
547    ralloc_free(mem_ctx);
548 
549    return 0;
550 }
551