/*
 * Copyright © Microsoft Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "nir.h"
#include "nir_serialize.h"
#include "glsl_types.h"
#include "nir_types.h"
#include "clc_compiler.h"
#include "clc_helpers.h"
#include "clc_nir.h"
#include "../compiler/dxil_nir.h"
#include "../compiler/dxil_nir_lower_int_samplers.h"
#include "../compiler/nir_to_dxil.h"

#include "util/u_debug.h"
#include "util/u_math.h"
#include "spirv/nir_spirv.h"
#include "nir_builder.h"
#include "nir_builtin_builder.h"

#include "git_sha1.h"

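/* State shared by the image-lowering helpers below: the kernel metadata
 * being filled out, running SRV/UAV binding counters, and the image deref
 * currently being rewritten.
 */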
struct clc_image_lower_context
{
   struct clc_dxil_metadata *metadata;
   unsigned *num_srvs;
   unsigned *num_uavs;
   nir_deref_instr *deref;
   unsigned num_buf_ids;
   int metadata_index;
};

static int
lower_image_deref_impl(nir_builder *b, struct clc_image_lower_context *context,
                       const struct glsl_type *new_var_type,
                       unsigned *num_bindings)
{
   nir_variable *in_var = nir_deref_instr_get_variable(context->deref);
   nir_variable *uniform = nir_variable_create(b->shader, nir_var_uniform, new_var_type, NULL);
   uniform->data.access = in_var->data.access;
   uniform->data.binding = in_var->data.binding;
   if (context->num_buf_ids > 0) {
      // Need to assign a new binding
      context->metadata->args[context->metadata_index].
         image.buf_ids[context->num_buf_ids] = uniform->data.binding = (*num_bindings)++;
   }
   context->num_buf_ids++;
   return uniform->data.binding;
}

static int
lower_read_only_image_deref(nir_builder *b, struct clc_image_lower_context *context,
                            nir_alu_type image_type)
{
   nir_variable *in_var = nir_deref_instr_get_variable(context->deref);

   // Non-writeable images should be converted to samplers,
   // since they may have texture operations done on them
   const struct glsl_type *new_var_type =
      glsl_sampler_type(glsl_get_sampler_dim(in_var->type),
            false, glsl_sampler_type_is_array(in_var->type),
            nir_get_glsl_base_type_for_nir_type(image_type | 32));
   return lower_image_deref_impl(b, context, new_var_type, context->num_srvs);
}

static int
lower_read_write_image_deref(nir_builder *b, struct clc_image_lower_context *context,
                             nir_alu_type image_type)
{
   nir_variable *in_var = nir_deref_instr_get_variable(context->deref);
   const struct glsl_type *new_var_type =
      glsl_image_type(glsl_get_sampler_dim(in_var->type),
         glsl_sampler_type_is_array(in_var->type),
         nir_get_glsl_base_type_for_nir_type(image_type | 32));
   return lower_image_deref_impl(b, context, new_var_type, context->num_uavs);
}

static void
clc_lower_input_image_deref(nir_builder *b, struct clc_image_lower_context *context)
{
   // The input variable here isn't actually an image, it's just the
   // image format data.
   //
   // For every use of an image in a different way, we'll add an
   // appropriate uniform to match it. That can result in up to
   // 3 uniforms (float4, int4, uint4) for each image. Only one of these
   // formats will actually produce correct data, but a single kernel
   // could use runtime conditionals to potentially access any of them.
   //
   // If the image is used in a query that doesn't have a corresponding
   // DXIL intrinsic (CL image channel order or channel format), then
   // we'll add a kernel input for that data that'll be lowered by the
   // explicit IO pass later on.
   //
   // After all that, we can remove the image input variable and deref.

   enum image_uniform_type {
      FLOAT4,
      INT4,
      UINT4,
      IMAGE_UNIFORM_TYPE_COUNT
   };

   int image_bindings[IMAGE_UNIFORM_TYPE_COUNT] = {-1, -1, -1};
   nir_ssa_def *format_deref_dest = NULL, *order_deref_dest = NULL;

   nir_variable *in_var = nir_deref_instr_get_variable(context->deref);
   enum gl_access_qualifier access = in_var->data.access;

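   // Find the kernel argument this image variable came from by matching its
   // binding against the first buffer ID recorded for each argument.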
   context->metadata_index = 0;
   while (context->metadata->args[context->metadata_index].image.buf_ids[0] != in_var->data.binding)
      context->metadata_index++;

   context->num_buf_ids = 0;

   /* Do this in 2 passes:
    * 1. When encountering a strongly-typed access (load/store), replace the deref
    *    with one that references an appropriately typed variable. When encountering
    *    an untyped access (size query), if we have a strongly-typed variable already,
    *    replace the deref to point to it.
    * 2. If there are any references left, they should all be untyped. If we found
    *    a strongly-typed access later in the 1st pass, then just replace the reference.
    *    If we didn't, e.g. because the resource is only used for a size query, then
    *    pick an arbitrary type for it.
    */
   for (int pass = 0; pass < 2; ++pass) {
      nir_foreach_use_safe(src, &context->deref->dest.ssa) {
         enum image_uniform_type type;

         if (src->parent_instr->type == nir_instr_type_intrinsic) {
            nir_intrinsic_instr *intrinsic = nir_instr_as_intrinsic(src->parent_instr);
            enum nir_alu_type dest_type;

            b->cursor = nir_before_instr(&intrinsic->instr);

            switch (intrinsic->intrinsic) {
            case nir_intrinsic_image_deref_load:
            case nir_intrinsic_image_deref_store: {
               dest_type = intrinsic->intrinsic == nir_intrinsic_image_deref_load ?
                  nir_intrinsic_dest_type(intrinsic) : nir_intrinsic_src_type(intrinsic);

               switch (nir_alu_type_get_base_type(dest_type)) {
               case nir_type_float: type = FLOAT4; break;
               case nir_type_int: type = INT4; break;
               case nir_type_uint: type = UINT4; break;
               default: unreachable("Unsupported image type for load.");
               }

               int image_binding = image_bindings[type];
               if (image_binding < 0) {
                  image_binding = image_bindings[type] =
                     lower_read_write_image_deref(b, context, dest_type);
               }

               assert((in_var->data.access & ACCESS_NON_WRITEABLE) == 0);
               nir_rewrite_image_intrinsic(intrinsic, nir_imm_int(b, image_binding), false);
               break;
            }

            case nir_intrinsic_image_deref_size: {
               int image_binding = -1;
               for (unsigned i = 0; i < IMAGE_UNIFORM_TYPE_COUNT; ++i) {
                  if (image_bindings[i] >= 0) {
                     image_binding = image_bindings[i];
                     break;
                  }
               }
               if (image_binding < 0) {
                  // Skip for now and come back to it
                  if (pass == 0)
                     break;

                  type = FLOAT4;
                  image_binding = image_bindings[type] =
                     lower_read_write_image_deref(b, context, nir_type_float32);
               }

               assert((in_var->data.access & ACCESS_NON_WRITEABLE) == 0);
               nir_rewrite_image_intrinsic(intrinsic, nir_imm_int(b, image_binding), false);
               break;
            }

            case nir_intrinsic_image_deref_format:
            case nir_intrinsic_image_deref_order: {
               nir_ssa_def **cached_deref = intrinsic->intrinsic == nir_intrinsic_image_deref_format ?
                  &format_deref_dest : &order_deref_dest;
               if (!*cached_deref) {
                  nir_variable *new_input = nir_variable_create(b->shader, nir_var_uniform, glsl_uint_type(), NULL);
                  new_input->data.driver_location = in_var->data.driver_location;
                  if (intrinsic->intrinsic == nir_intrinsic_image_deref_format) {
                     /* Match cl_image_format { image_channel_order, image_channel_data_type }; */
                     new_input->data.driver_location += glsl_get_cl_size(new_input->type);
                  }

                  b->cursor = nir_after_instr(&context->deref->instr);
                  *cached_deref = nir_load_var(b, new_input);
               }

               /* No actual intrinsic needed here, just reference the loaded variable */
               nir_ssa_def_rewrite_uses(&intrinsic->dest.ssa, *cached_deref);
               nir_instr_remove(&intrinsic->instr);
               break;
            }

            default:
               unreachable("Unsupported image intrinsic");
            }
         } else if (src->parent_instr->type == nir_instr_type_tex) {
            assert(in_var->data.access & ACCESS_NON_WRITEABLE);
            nir_tex_instr *tex = nir_instr_as_tex(src->parent_instr);

            switch (nir_alu_type_get_base_type(tex->dest_type)) {
            case nir_type_float: type = FLOAT4; break;
            case nir_type_int: type = INT4; break;
            case nir_type_uint: type = UINT4; break;
            default: unreachable("Unsupported image format for sample.");
            }

            int image_binding = image_bindings[type];
            if (image_binding < 0) {
               image_binding = image_bindings[type] =
                  lower_read_only_image_deref(b, context, tex->dest_type);
            }

            nir_tex_instr_remove_src(tex, nir_tex_instr_src_index(tex, nir_tex_src_texture_deref));
            tex->texture_index = image_binding;
         }
      }
   }

   context->metadata->args[context->metadata_index].image.num_buf_ids = context->num_buf_ids;

   nir_instr_remove(&context->deref->instr);
   exec_node_remove(&in_var->node);
}

static void
clc_lower_images(nir_shader *nir, struct clc_image_lower_context *context)
{
   nir_foreach_function(func, nir) {
      if (!func->is_entrypoint)
         continue;
      assert(func->impl);

      nir_builder b;
      nir_builder_init(&b, func->impl);

      nir_foreach_block(block, func->impl) {
         nir_foreach_instr_safe(instr, block) {
            if (instr->type == nir_instr_type_deref) {
               context->deref = nir_instr_as_deref(instr);

               if (glsl_type_is_image(context->deref->type)) {
                  assert(context->deref->deref_type == nir_deref_type_var);
                  clc_lower_input_image_deref(&b, context);
               }
            }
         }
      }
   }
}

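/* CL lets kernels consume these built-ins at 64 bits, but the corresponding
 * DXIL values are 32-bit, so shrink the intrinsic's destination to 32 bits
 * and zero-extend the result back to 64 bits for the existing uses.
 */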
static void
clc_lower_64bit_semantics(nir_shader *nir)
{
   nir_foreach_function(func, nir) {
      nir_builder b;
      nir_builder_init(&b, func->impl);

      nir_foreach_block(block, func->impl) {
         nir_foreach_instr_safe(instr, block) {
            if (instr->type == nir_instr_type_intrinsic) {
               nir_intrinsic_instr *intrinsic = nir_instr_as_intrinsic(instr);
               switch (intrinsic->intrinsic) {
               case nir_intrinsic_load_global_invocation_id:
               case nir_intrinsic_load_global_invocation_id_zero_base:
               case nir_intrinsic_load_base_global_invocation_id:
               case nir_intrinsic_load_local_invocation_id:
               case nir_intrinsic_load_workgroup_id:
               case nir_intrinsic_load_workgroup_id_zero_base:
               case nir_intrinsic_load_base_workgroup_id:
               case nir_intrinsic_load_num_workgroups:
                  break;
               default:
                  continue;
               }

               if (nir_instr_ssa_def(instr)->bit_size != 64)
                  continue;

               intrinsic->dest.ssa.bit_size = 32;
               b.cursor = nir_after_instr(instr);

               nir_ssa_def *i64 = nir_u2u64(&b, &intrinsic->dest.ssa);
               nir_ssa_def_rewrite_uses_after(
                  &intrinsic->dest.ssa,
                  i64,
                  i64->parent_instr);
            }
         }
      }
   }
}

static void
clc_lower_nonnormalized_samplers(nir_shader *nir,
                                 const dxil_wrap_sampler_state *states)
{
   nir_foreach_function(func, nir) {
      if (!func->is_entrypoint)
         continue;
      assert(func->impl);

      nir_builder b;
      nir_builder_init(&b, func->impl);

      nir_foreach_block(block, func->impl) {
         nir_foreach_instr_safe(instr, block) {
            if (instr->type != nir_instr_type_tex)
               continue;
            nir_tex_instr *tex = nir_instr_as_tex(instr);

            int sampler_src_idx = nir_tex_instr_src_index(tex, nir_tex_src_sampler_deref);
            if (sampler_src_idx == -1)
               continue;

            nir_src *sampler_src = &tex->src[sampler_src_idx].src;
            assert(sampler_src->is_ssa && sampler_src->ssa->parent_instr->type == nir_instr_type_deref);
            nir_variable *sampler = nir_deref_instr_get_variable(
               nir_instr_as_deref(sampler_src->ssa->parent_instr));

            // If the sampler returns ints, we'll handle this in the int lowering pass
            if (nir_alu_type_get_base_type(tex->dest_type) != nir_type_float)
               continue;

            // If sampler uses normalized coords, nothing to do
            if (!states[sampler->data.binding].is_nonnormalized_coords)
               continue;

            b.cursor = nir_before_instr(&tex->instr);

            int coords_idx = nir_tex_instr_src_index(tex, nir_tex_src_coord);
            assert(coords_idx != -1);
            nir_ssa_def *coords =
               nir_ssa_for_src(&b, tex->src[coords_idx].src, tex->coord_components);

            nir_ssa_def *txs = nir_i2f32(&b, nir_get_texture_size(&b, tex));

            // Normalize coords for tex
            nir_ssa_def *scale = nir_frcp(&b, txs);
            nir_ssa_def *comps[4];
            for (unsigned i = 0; i < coords->num_components; ++i) {
               comps[i] = nir_channel(&b, coords, i);
               if (tex->is_array && i == coords->num_components - 1) {
                  // Don't scale the array index, but do clamp it
                  comps[i] = nir_fround_even(&b, comps[i]);
                  comps[i] = nir_fmax(&b, comps[i], nir_imm_float(&b, 0.0f));
                  comps[i] = nir_fmin(&b, comps[i], nir_fsub(&b, nir_channel(&b, txs, i), nir_imm_float(&b, 1.0f)));
                  break;
               }

               // The CTS is pretty clear that this value has to be floored for nearest sampling
               // but must not be for linear sampling.
               if (!states[sampler->data.binding].is_linear_filtering)
                  comps[i] = nir_fadd_imm(&b, nir_ffloor(&b, comps[i]), 0.5f);
               comps[i] = nir_fmul(&b, comps[i], nir_channel(&b, scale, i));
            }
            nir_ssa_def *normalized_coords = nir_vec(&b, comps, coords->num_components);
            nir_instr_rewrite_src(&tex->instr,
                                  &tex->src[coords_idx].src,
                                  nir_src_for_ssa(normalized_coords));
         }
      }
   }
}

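/* All kernel arguments are passed in a single constant buffer, viewed as an
 * array of uints large enough to cover every argument's driver_location
 * plus its CL size, rounded up to a multiple of 4 bytes.
 */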
static nir_variable *
add_kernel_inputs_var(struct clc_dxil_object *dxil, nir_shader *nir,
                      unsigned *cbv_id)
{
   if (!dxil->kernel->num_args)
      return NULL;

   struct clc_dxil_metadata *metadata = &dxil->metadata;
   unsigned size = 0;

   nir_foreach_variable_with_modes(var, nir, nir_var_uniform)
      size = MAX2(size,
                  var->data.driver_location +
                  glsl_get_cl_size(var->type));

   size = align(size, 4);

   const struct glsl_type *array_type = glsl_array_type(glsl_uint_type(), size / 4, 4);
   const struct glsl_struct_field field = { array_type, "arr" };
   nir_variable *var =
      nir_variable_create(nir, nir_var_mem_ubo,
         glsl_struct_type(&field, 1, "kernel_inputs", false),
         "kernel_inputs");
   var->data.binding = (*cbv_id)++;
   var->data.how_declared = nir_var_hidden;
   return var;
}

static nir_variable *
add_work_properties_var(struct clc_dxil_object *dxil,
                           struct nir_shader *nir, unsigned *cbv_id)
{
   struct clc_dxil_metadata *metadata = &dxil->metadata;
   const struct glsl_type *array_type =
      glsl_array_type(glsl_uint_type(),
         sizeof(struct clc_work_properties_data) / sizeof(unsigned),
         sizeof(unsigned));
   const struct glsl_struct_field field = { array_type, "arr" };
   nir_variable *var =
      nir_variable_create(nir, nir_var_mem_ubo,
         glsl_struct_type(&field, 1, "kernel_work_properties", false),
         "kernel_work_properties");
   var->data.binding = (*cbv_id)++;
   var->data.how_declared = nir_var_hidden;
   return var;
}

static void
clc_lower_constant_to_ssbo(nir_shader *nir,
                      const struct clc_kernel_info *kerninfo, unsigned *uav_id)
{
   /* Update constant vars and assign them a binding. */
   nir_foreach_variable_with_modes(var, nir, nir_var_mem_constant) {
      var->data.mode = nir_var_mem_ssbo;
      var->data.binding = (*uav_id)++;
   }

   /* And finally patch all the derefs referencing the constant
    * variables/pointers.
    */
   nir_foreach_function(func, nir) {
      if (!func->is_entrypoint)
         continue;

      assert(func->impl);

      nir_builder b;
      nir_builder_init(&b, func->impl);

      nir_foreach_block(block, func->impl) {
         nir_foreach_instr(instr, block) {
            if (instr->type != nir_instr_type_deref)
               continue;

            nir_deref_instr *deref = nir_instr_as_deref(instr);

            if (deref->modes != nir_var_mem_constant)
               continue;

            deref->modes = nir_var_mem_ssbo;
         }
      }
   }
}

static void
clc_lower_global_to_ssbo(nir_shader *nir)
{
   nir_foreach_function(func, nir) {
      if (!func->is_entrypoint)
         continue;

      assert(func->impl);

      nir_foreach_block(block, func->impl) {
         nir_foreach_instr(instr, block) {
            if (instr->type != nir_instr_type_deref)
               continue;

            nir_deref_instr *deref = nir_instr_as_deref(instr);

            if (deref->modes != nir_var_mem_global)
               continue;

            deref->modes = nir_var_mem_ssbo;
         }
      }
   }
}

static void
copy_const_initializer(const nir_constant *constant, const struct glsl_type *type,
                       uint8_t *data)
{
   if (glsl_type_is_array(type)) {
      const struct glsl_type *elm_type = glsl_get_array_element(type);
      unsigned step_size = glsl_get_explicit_stride(type);

      for (unsigned i = 0; i < constant->num_elements; i++) {
         copy_const_initializer(constant->elements[i], elm_type,
                                data + (i * step_size));
      }
   } else if (glsl_type_is_struct(type)) {
      for (unsigned i = 0; i < constant->num_elements; i++) {
         const struct glsl_type *elm_type = glsl_get_struct_field(type, i);
         int offset = glsl_get_struct_field_offset(type, i);
         copy_const_initializer(constant->elements[i], elm_type, data + offset);
      }
   } else {
      assert(glsl_type_is_vector_or_scalar(type));

      for (unsigned i = 0; i < glsl_get_components(type); i++) {
         switch (glsl_get_bit_size(type)) {
         case 64:
            *((uint64_t *)data) = constant->values[i].u64;
            break;
         case 32:
            *((uint32_t *)data) = constant->values[i].u32;
            break;
         case 16:
            *((uint16_t *)data) = constant->values[i].u16;
            break;
         case 8:
            *((uint8_t *)data) = constant->values[i].u8;
            break;
         default:
            unreachable("Invalid base type");
         }

         data += glsl_get_bit_size(type) / 8;
      }
   }
}

static const struct glsl_type *
get_cast_type(unsigned bit_size)
{
   switch (bit_size) {
   case 64:
      return glsl_int64_t_type();
   case 32:
      return glsl_int_type();
   case 16:
      return glsl_int16_t_type();
   case 8:
      return glsl_int8_t_type();
   }
   unreachable("Invalid bit_size");
}

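/* Replace a load whose alignment is smaller than required with a series of
 * alignment-sized integer loads through a casted deref, then reassemble the
 * original value with nir_extract_bits.
 */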
static void
split_unaligned_load(nir_builder *b, nir_intrinsic_instr *intrin, unsigned alignment)
{
   enum gl_access_qualifier access = nir_intrinsic_access(intrin);
   nir_ssa_def *srcs[NIR_MAX_VEC_COMPONENTS * NIR_MAX_VEC_COMPONENTS * sizeof(int64_t) / 8];
   unsigned comp_size = intrin->dest.ssa.bit_size / 8;
   unsigned num_comps = intrin->dest.ssa.num_components;

   b->cursor = nir_before_instr(&intrin->instr);

   nir_deref_instr *ptr = nir_src_as_deref(intrin->src[0]);

   const struct glsl_type *cast_type = get_cast_type(alignment * 8);
   nir_deref_instr *cast = nir_build_deref_cast(b, &ptr->dest.ssa, ptr->modes, cast_type, alignment);

   unsigned num_loads = DIV_ROUND_UP(comp_size * num_comps, alignment);
   for (unsigned i = 0; i < num_loads; ++i) {
      nir_deref_instr *elem = nir_build_deref_ptr_as_array(b, cast, nir_imm_intN_t(b, i, cast->dest.ssa.bit_size));
      srcs[i] = nir_load_deref_with_access(b, elem, access);
   }

   nir_ssa_def *new_dest = nir_extract_bits(b, srcs, num_loads, 0, num_comps, intrin->dest.ssa.bit_size);
   nir_ssa_def_rewrite_uses(&intrin->dest.ssa, new_dest);
   nir_instr_remove(&intrin->instr);
}

static void
split_unaligned_store(nir_builder *b, nir_intrinsic_instr *intrin, unsigned alignment)
{
   enum gl_access_qualifier access = nir_intrinsic_access(intrin);

   assert(intrin->src[1].is_ssa);
   nir_ssa_def *value = intrin->src[1].ssa;
   unsigned comp_size = value->bit_size / 8;
   unsigned num_comps = value->num_components;

   b->cursor = nir_before_instr(&intrin->instr);

   nir_deref_instr *ptr = nir_src_as_deref(intrin->src[0]);

   const struct glsl_type *cast_type = get_cast_type(alignment * 8);
   nir_deref_instr *cast = nir_build_deref_cast(b, &ptr->dest.ssa, ptr->modes, cast_type, alignment);

   unsigned num_stores = DIV_ROUND_UP(comp_size * num_comps, alignment);
   for (unsigned i = 0; i < num_stores; ++i) {
      nir_ssa_def *substore_val = nir_extract_bits(b, &value, 1, i * alignment * 8, 1, alignment * 8);
      nir_deref_instr *elem = nir_build_deref_ptr_as_array(b, cast, nir_imm_intN_t(b, i, cast->dest.ssa.bit_size));
      nir_store_deref_with_access(b, elem, substore_val, ~0, access);
   }

   nir_instr_remove(&intrin->instr);
}

static bool
split_unaligned_loads_stores(nir_shader *shader)
{
   bool progress = false;

   nir_foreach_function(function, shader) {
      if (!function->impl)
         continue;

      nir_builder b;
      nir_builder_init(&b, function->impl);

      nir_foreach_block(block, function->impl) {
         nir_foreach_instr_safe(instr, block) {
            if (instr->type != nir_instr_type_intrinsic)
               continue;
            nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
            if (intrin->intrinsic != nir_intrinsic_load_deref &&
                intrin->intrinsic != nir_intrinsic_store_deref)
               continue;
            nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]);

            unsigned align_mul = 0, align_offset = 0;
            nir_get_explicit_deref_align(deref, true, &align_mul, &align_offset);

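            // The effective alignment is the largest power of two dividing
            // align_offset, or align_mul itself when the offset is zero.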
            unsigned alignment = align_offset ? 1 << (ffs(align_offset) - 1) : align_mul;

            /* We can load anything at 4-byte alignment, except for
             * UBOs (AKA CBs where the granularity is 16 bytes).
             */
            if (alignment >= (deref->modes == nir_var_mem_ubo ? 16 : 4))
               continue;

            nir_ssa_def *val;
            if (intrin->intrinsic == nir_intrinsic_load_deref) {
               assert(intrin->dest.is_ssa);
               val = &intrin->dest.ssa;
            } else {
               assert(intrin->src[1].is_ssa);
               val = intrin->src[1].ssa;
            }

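            // A value's natural alignment is its component size times the
            // component count, with vec3 padded out to vec4.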
            unsigned natural_alignment =
               val->bit_size / 8 *
               (val->num_components == 3 ? 4 : val->num_components);

            if (alignment >= natural_alignment)
               continue;

            if (intrin->intrinsic == nir_intrinsic_load_deref)
               split_unaligned_load(&b, intrin, alignment);
            else
               split_unaligned_store(&b, intrin, alignment);
            progress = true;
         }
      }
   }

   return progress;
}

static enum pipe_tex_wrap
wrap_from_cl_addressing(unsigned addressing_mode)
{
   switch (addressing_mode)
   {
   default:
   case SAMPLER_ADDRESSING_MODE_NONE:
   case SAMPLER_ADDRESSING_MODE_CLAMP:
      // Since OpenCL's only border color is 0's and D3D specs out-of-bounds loads to return 0, don't apply any wrap mode
      return (enum pipe_tex_wrap)-1;
   case SAMPLER_ADDRESSING_MODE_CLAMP_TO_EDGE: return PIPE_TEX_WRAP_CLAMP_TO_EDGE;
   case SAMPLER_ADDRESSING_MODE_REPEAT: return PIPE_TEX_WRAP_REPEAT;
   case SAMPLER_ADDRESSING_MODE_REPEAT_MIRRORED: return PIPE_TEX_WRAP_MIRROR_REPEAT;
   }
}

static bool shader_has_double(nir_shader *nir)
{
   foreach_list_typed(nir_function, func, node, &nir->functions) {
      if (!func->is_entrypoint)
         continue;

      assert(func->impl);

      nir_foreach_block(block, func->impl) {
         nir_foreach_instr_safe(instr, block) {
            if (instr->type != nir_instr_type_alu)
               continue;

            nir_alu_instr *alu = nir_instr_as_alu(instr);
            const nir_op_info *info = &nir_op_infos[alu->op];

            if (info->output_type & nir_type_float &&
                nir_dest_bit_size(alu->dest.dest) == 64)
               return true;
         }
      }
   }

   return false;
}

static bool
scale_fdiv(nir_shader *nir)
{
   bool progress = false;
   nir_foreach_function(func, nir) {
      if (!func->impl)
         continue;
      nir_builder b;
      nir_builder_init(&b, func->impl);
      nir_foreach_block(block, func->impl) {
         nir_foreach_instr(instr, block) {
            if (instr->type != nir_instr_type_alu)
               continue;
            nir_alu_instr *alu = nir_instr_as_alu(instr);
            if (alu->op != nir_op_fdiv || alu->src[0].src.ssa->bit_size != 32)
               continue;

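            // These immediates are IEEE-754 bit patterns: 0x7e800000 is 2^126
            // and 0x00800000 is 2^-126 (the smallest normal float). Scaling
            // both operands by the same factor (0.25 for huge divisors, 2^24
            // for tiny ones) keeps the division in range without changing the
            // quotient.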
            b.cursor = nir_before_instr(instr);
            nir_ssa_def *fabs = nir_fabs(&b, alu->src[1].src.ssa);
            nir_ssa_def *big = nir_flt(&b, nir_imm_int(&b, 0x7e800000), fabs);
            nir_ssa_def *small = nir_flt(&b, fabs, nir_imm_int(&b, 0x00800000));

            nir_ssa_def *scaled_down_a = nir_fmul_imm(&b, alu->src[0].src.ssa, 0.25);
            nir_ssa_def *scaled_down_b = nir_fmul_imm(&b, alu->src[1].src.ssa, 0.25);
            nir_ssa_def *scaled_up_a = nir_fmul_imm(&b, alu->src[0].src.ssa, 16777216.0);
            nir_ssa_def *scaled_up_b = nir_fmul_imm(&b, alu->src[1].src.ssa, 16777216.0);

            nir_ssa_def *final_a =
               nir_bcsel(&b, big, scaled_down_a,
              (nir_bcsel(&b, small, scaled_up_a, alu->src[0].src.ssa)));
            nir_ssa_def *final_b =
               nir_bcsel(&b, big, scaled_down_b,
              (nir_bcsel(&b, small, scaled_up_b, alu->src[1].src.ssa)));

            nir_instr_rewrite_src(instr, &alu->src[0].src, nir_src_for_ssa(final_a));
            nir_instr_rewrite_src(instr, &alu->src[1].src, nir_src_for_ssa(final_b));
            progress = true;
         }
      }
   }
   return progress;
}

struct clc_libclc *
clc_libclc_new_dxil(const struct clc_logger *logger,
                    const struct clc_libclc_dxil_options *options)
{
   struct clc_libclc_options clc_options = {
      .optimize = options->optimize,
      .nir_options = dxil_get_nir_compiler_options(),
   };

   return clc_libclc_new(logger, &clc_options);
}

bool
clc_spirv_to_dxil(struct clc_libclc *lib,
                  const struct clc_binary *linked_spirv,
                  const struct clc_parsed_spirv *parsed_data,
                  const char *entrypoint,
                  const struct clc_runtime_kernel_conf *conf,
                  const struct clc_spirv_specialization_consts *consts,
                  const struct clc_logger *logger,
                  struct clc_dxil_object *out_dxil)
{
   struct nir_shader *nir;

   for (unsigned i = 0; i < parsed_data->num_kernels; i++) {
      if (!strcmp(parsed_data->kernels[i].name, entrypoint)) {
         out_dxil->kernel = &parsed_data->kernels[i];
         break;
      }
   }

   if (!out_dxil->kernel) {
      clc_error(logger, "no '%s' kernel found", entrypoint);
      return false;
   }

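   // Buffer pointers become a 32-bit index and a 32-bit offset packed into
   // 64 bits, while shared/temp pointers are 32-bit offsets carried in
   // 64-bit values, preserving the 64-bit pointer size CL kernels expect.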
   const struct spirv_to_nir_options spirv_options = {
      .environment = NIR_SPIRV_OPENCL,
      .clc_shader = clc_libclc_get_clc_shader(lib),
      .constant_addr_format = nir_address_format_32bit_index_offset_pack64,
      .global_addr_format = nir_address_format_32bit_index_offset_pack64,
      .shared_addr_format = nir_address_format_32bit_offset_as_64bit,
      .temp_addr_format = nir_address_format_32bit_offset_as_64bit,
      .float_controls_execution_mode = FLOAT_CONTROLS_DENORM_FLUSH_TO_ZERO_FP32,
      .caps = {
         .address = true,
         .float64 = true,
         .int8 = true,
         .int16 = true,
         .int64 = true,
         .kernel = true,
         .kernel_image = true,
         .kernel_image_read_write = true,
         .literal_sampler = true,
         .printf = true,
      },
   };
   nir_shader_compiler_options nir_options =
      *dxil_get_nir_compiler_options();

   if (conf && conf->lower_bit_size & 64) {
      nir_options.lower_pack_64_2x32_split = false;
      nir_options.lower_unpack_64_2x32_split = false;
      nir_options.lower_int64_options = ~0;
   }

   if (conf && conf->lower_bit_size & 16)
      nir_options.support_16bit_alu = true;

   glsl_type_singleton_init_or_ref();

   nir = spirv_to_nir(linked_spirv->data, linked_spirv->size / 4,
                      consts ? (struct nir_spirv_specialization *)consts->specializations : NULL,
                      consts ? consts->num_specializations : 0,
                      MESA_SHADER_KERNEL, entrypoint,
                      &spirv_options,
                      &nir_options);
   if (!nir) {
      clc_error(logger, "spirv_to_nir() failed");
      goto err_free_dxil;
   }
   nir->info.workgroup_size_variable = true;

   NIR_PASS_V(nir, nir_lower_goto_ifs);
   NIR_PASS_V(nir, nir_opt_dead_cf);

   struct clc_dxil_metadata *metadata = &out_dxil->metadata;

   metadata->args = calloc(out_dxil->kernel->num_args,
                           sizeof(*metadata->args));
   if (!metadata->args) {
      clc_error(logger, "failed to allocate arg positions");
      goto err_free_dxil;
   }

   {
      bool progress;
      do
      {
         progress = false;
         NIR_PASS(progress, nir, nir_copy_prop);
         NIR_PASS(progress, nir, nir_opt_copy_prop_vars);
         NIR_PASS(progress, nir, nir_opt_deref);
         NIR_PASS(progress, nir, nir_opt_dce);
         NIR_PASS(progress, nir, nir_opt_undef);
         NIR_PASS(progress, nir, nir_opt_constant_folding);
         NIR_PASS(progress, nir, nir_opt_cse);
         NIR_PASS(progress, nir, nir_lower_vars_to_ssa);
         NIR_PASS(progress, nir, nir_opt_algebraic);
      } while (progress);
   }

   // Inline all functions first, according to the comment on
   // nir_inline_functions.
   NIR_PASS_V(nir, nir_lower_variable_initializers, nir_var_function_temp);
   NIR_PASS_V(nir, nir_lower_returns);
   NIR_PASS_V(nir, nir_lower_libclc, clc_libclc_get_clc_shader(lib));
   NIR_PASS_V(nir, nir_inline_functions);

   // Pick off the single entrypoint that we want.
   foreach_list_typed_safe(nir_function, func, node, &nir->functions) {
      if (!func->is_entrypoint)
         exec_node_remove(&func->node);
   }
   assert(exec_list_length(&nir->functions) == 1);

   {
      bool progress;
      do
      {
         progress = false;
         NIR_PASS(progress, nir, nir_copy_prop);
         NIR_PASS(progress, nir, nir_opt_copy_prop_vars);
         NIR_PASS(progress, nir, nir_opt_deref);
         NIR_PASS(progress, nir, nir_opt_dce);
         NIR_PASS(progress, nir, nir_opt_undef);
         NIR_PASS(progress, nir, nir_opt_constant_folding);
         NIR_PASS(progress, nir, nir_opt_cse);
         NIR_PASS(progress, nir, nir_split_var_copies);
         NIR_PASS(progress, nir, nir_lower_var_copies);
         NIR_PASS(progress, nir, nir_lower_vars_to_ssa);
         NIR_PASS(progress, nir, nir_opt_algebraic);
         NIR_PASS(progress, nir, nir_opt_if, true);
         NIR_PASS(progress, nir, nir_opt_dead_cf);
         NIR_PASS(progress, nir, nir_opt_remove_phis);
         NIR_PASS(progress, nir, nir_opt_peephole_select, 8, true, true);
         NIR_PASS(progress, nir, nir_lower_vec3_to_vec4, nir_var_mem_generic | nir_var_uniform);
      } while (progress);
   }

   NIR_PASS_V(nir, scale_fdiv);

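   // Gather per-sampler wrap/normalization state up front; it feeds both the
   // non-normalized-coordinate lowering and the integer-texture
   // sample-to-txf lowering further down.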
   dxil_wrap_sampler_state int_sampler_states[PIPE_MAX_SHADER_SAMPLER_VIEWS] = { {{0}} };
   unsigned sampler_id = 0;

   struct exec_list inline_samplers_list;
   exec_list_make_empty(&inline_samplers_list);

   // Move inline samplers to the end of the uniforms list
   nir_foreach_variable_with_modes_safe(var, nir, nir_var_uniform) {
      if (glsl_type_is_sampler(var->type) && var->data.sampler.is_inline_sampler) {
         exec_node_remove(&var->node);
         exec_list_push_tail(&inline_samplers_list, &var->node);
      }
   }
   exec_node_insert_list_after(exec_list_get_tail(&nir->variables), &inline_samplers_list);

   NIR_PASS_V(nir, nir_lower_variable_initializers, ~(nir_var_function_temp | nir_var_shader_temp));

   // Lower memcpy
   NIR_PASS_V(nir, dxil_nir_lower_memcpy_deref);

   // Ensure the printf struct has explicit types, but we'll throw away the scratch size, because we haven't
   // necessarily removed all temp variables (e.g. the printf struct itself) at this point, so we'll rerun this later
   assert(nir->scratch_size == 0);
   NIR_PASS_V(nir, nir_lower_vars_to_explicit_types, nir_var_function_temp, glsl_get_cl_type_size_align);

   nir_lower_printf_options printf_options = {
      .treat_doubles_as_floats = true,
      .max_buffer_size = 1024 * 1024
   };
   NIR_PASS_V(nir, nir_lower_printf, &printf_options);

   metadata->printf.info_count = nir->printf_info_count;
   metadata->printf.infos = calloc(nir->printf_info_count, sizeof(struct clc_printf_info));
   for (unsigned i = 0; i < nir->printf_info_count; i++) {
      metadata->printf.infos[i].str = malloc(nir->printf_info[i].string_size);
      memcpy(metadata->printf.infos[i].str, nir->printf_info[i].strings, nir->printf_info[i].string_size);
      metadata->printf.infos[i].num_args = nir->printf_info[i].num_args;
      metadata->printf.infos[i].arg_sizes = malloc(nir->printf_info[i].num_args * sizeof(unsigned));
      memcpy(metadata->printf.infos[i].arg_sizes, nir->printf_info[i].arg_sizes, nir->printf_info[i].num_args * sizeof(unsigned));
   }

   // copy propagate to prepare for lower_explicit_io
   NIR_PASS_V(nir, nir_split_var_copies);
   NIR_PASS_V(nir, nir_opt_copy_prop_vars);
   NIR_PASS_V(nir, nir_lower_var_copies);
   NIR_PASS_V(nir, nir_lower_vars_to_ssa);
   NIR_PASS_V(nir, nir_lower_alu);
   NIR_PASS_V(nir, nir_opt_dce);
   NIR_PASS_V(nir, nir_opt_deref);

   // For uniforms (kernel inputs), run this before adjusting variable list via image/sampler lowering
   NIR_PASS_V(nir, nir_lower_vars_to_explicit_types, nir_var_uniform, glsl_get_cl_type_size_align);

   // Calculate input offsets/metadata.
   unsigned uav_id = 0;
   nir_foreach_variable_with_modes(var, nir, nir_var_uniform) {
      int i = var->data.location;
      if (i < 0)
         continue;

      unsigned size = glsl_get_cl_size(var->type);

      metadata->args[i].offset = var->data.driver_location;
      metadata->args[i].size = size;
      metadata->kernel_inputs_buf_size = MAX2(metadata->kernel_inputs_buf_size,
         var->data.driver_location + size);
      if ((out_dxil->kernel->args[i].address_qualifier == CLC_KERNEL_ARG_ADDRESS_GLOBAL ||
         out_dxil->kernel->args[i].address_qualifier == CLC_KERNEL_ARG_ADDRESS_CONSTANT) &&
         // Ignore images during this pass - global memory buffers need to have contiguous bindings
         !glsl_type_is_image(var->type)) {
         metadata->args[i].globconstptr.buf_id = uav_id++;
      } else if (glsl_type_is_sampler(var->type)) {
         unsigned address_mode = conf ? conf->args[i].sampler.addressing_mode : 0u;
         int_sampler_states[sampler_id].wrap[0] =
            int_sampler_states[sampler_id].wrap[1] =
            int_sampler_states[sampler_id].wrap[2] = wrap_from_cl_addressing(address_mode);
         int_sampler_states[sampler_id].is_nonnormalized_coords =
            conf ? !conf->args[i].sampler.normalized_coords : 0;
         int_sampler_states[sampler_id].is_linear_filtering =
            conf ? conf->args[i].sampler.linear_filtering : 0;
         metadata->args[i].sampler.sampler_id = var->data.binding = sampler_id++;
      }
   }

   unsigned num_global_inputs = uav_id;

   // Second pass over inputs to calculate image bindings
   unsigned srv_id = 0;
   nir_foreach_variable_with_modes(var, nir, nir_var_uniform) {
      int i = var->data.location;
      if (i < 0)
         continue;

      if (glsl_type_is_image(var->type)) {
         if (var->data.access == ACCESS_NON_WRITEABLE) {
            metadata->args[i].image.buf_ids[0] = srv_id++;
         } else {
            // Write or read-write are UAVs
            metadata->args[i].image.buf_ids[0] = uav_id++;
         }

         metadata->args[i].image.num_buf_ids = 1;
         var->data.binding = metadata->args[i].image.buf_ids[0];
      }
   }

   // Before removing dead uniforms, dedupe constant samplers to make more dead uniforms
   NIR_PASS_V(nir, clc_nir_dedupe_const_samplers);
   NIR_PASS_V(nir, nir_remove_dead_variables, nir_var_uniform | nir_var_mem_ubo | nir_var_mem_constant | nir_var_function_temp, NULL);

   // Fill out inline sampler metadata, now that they've been deduped and dead ones removed
   nir_foreach_variable_with_modes(var, nir, nir_var_uniform) {
      if (glsl_type_is_sampler(var->type) && var->data.sampler.is_inline_sampler) {
         int_sampler_states[sampler_id].wrap[0] =
            int_sampler_states[sampler_id].wrap[1] =
            int_sampler_states[sampler_id].wrap[2] =
            wrap_from_cl_addressing(var->data.sampler.addressing_mode);
         int_sampler_states[sampler_id].is_nonnormalized_coords =
            !var->data.sampler.normalized_coordinates;
         int_sampler_states[sampler_id].is_linear_filtering =
            var->data.sampler.filter_mode == SAMPLER_FILTER_MODE_LINEAR;
         var->data.binding = sampler_id++;

         assert(metadata->num_const_samplers < CLC_MAX_SAMPLERS);
         metadata->const_samplers[metadata->num_const_samplers].sampler_id = var->data.binding;
         metadata->const_samplers[metadata->num_const_samplers].addressing_mode = var->data.sampler.addressing_mode;
         metadata->const_samplers[metadata->num_const_samplers].normalized_coords = var->data.sampler.normalized_coordinates;
         metadata->const_samplers[metadata->num_const_samplers].filter_mode = var->data.sampler.filter_mode;
         metadata->num_const_samplers++;
      }
   }

   // Needs to come before lower_explicit_io
   NIR_PASS_V(nir, nir_lower_readonly_images_to_tex, false);
   struct clc_image_lower_context image_lower_context = { metadata, &srv_id, &uav_id };
   NIR_PASS_V(nir, clc_lower_images, &image_lower_context);
   NIR_PASS_V(nir, clc_lower_nonnormalized_samplers, int_sampler_states);
   NIR_PASS_V(nir, nir_lower_samplers);
   NIR_PASS_V(nir, dxil_lower_sample_to_txf_for_integer_tex,
              int_sampler_states, NULL, 14.0f);

   NIR_PASS_V(nir, nir_remove_dead_variables, nir_var_mem_shared | nir_var_function_temp, NULL);

   nir->scratch_size = 0;
   NIR_PASS_V(nir, nir_lower_vars_to_explicit_types,
              nir_var_mem_shared | nir_var_function_temp | nir_var_mem_global | nir_var_mem_constant,
              glsl_get_cl_type_size_align);

   NIR_PASS_V(nir, dxil_nir_lower_ubo_to_temp);
   NIR_PASS_V(nir, clc_lower_constant_to_ssbo, out_dxil->kernel, &uav_id);
   NIR_PASS_V(nir, clc_lower_global_to_ssbo);

   bool has_printf = false;
   NIR_PASS(has_printf, nir, clc_lower_printf_base, uav_id);
   metadata->printf.uav_id = has_printf ? uav_id++ : -1;

   NIR_PASS_V(nir, dxil_nir_lower_deref_ssbo);

   NIR_PASS_V(nir, split_unaligned_loads_stores);

   assert(nir->info.cs.ptr_size == 64);
   NIR_PASS_V(nir, nir_lower_explicit_io, nir_var_mem_ssbo,
              nir_address_format_32bit_index_offset_pack64);
   NIR_PASS_V(nir, nir_lower_explicit_io,
              nir_var_mem_shared | nir_var_function_temp | nir_var_uniform,
              nir_address_format_32bit_offset_as_64bit);

   NIR_PASS_V(nir, nir_lower_system_values);

   nir_lower_compute_system_values_options compute_options = {
      .has_base_global_invocation_id = (conf && conf->support_global_work_id_offsets),
      .has_base_workgroup_id = (conf && conf->support_workgroup_id_offsets),
   };
   NIR_PASS_V(nir, nir_lower_compute_system_values, &compute_options);

   NIR_PASS_V(nir, clc_lower_64bit_semantics);

   NIR_PASS_V(nir, nir_opt_deref);
   NIR_PASS_V(nir, nir_lower_vars_to_ssa);

   unsigned cbv_id = 0;

   nir_variable *inputs_var =
      add_kernel_inputs_var(out_dxil, nir, &cbv_id);
   nir_variable *work_properties_var =
      add_work_properties_var(out_dxil, nir, &cbv_id);

   memcpy(metadata->local_size, nir->info.workgroup_size,
          sizeof(metadata->local_size));
   memcpy(metadata->local_size_hint, nir->info.cs.workgroup_size_hint,
          sizeof(metadata->local_size_hint));

   // Patch the local size before calling clc_nir_lower_system_values().
   if (conf) {
      for (unsigned i = 0; i < ARRAY_SIZE(nir->info.workgroup_size); i++) {
         if (!conf->local_size[i] ||
             conf->local_size[i] == nir->info.workgroup_size[i])
            continue;

         if (nir->info.workgroup_size[i] &&
             nir->info.workgroup_size[i] != conf->local_size[i]) {
            debug_printf("D3D12: runtime local size does not match reqd_work_group_size() values\n");
            goto err_free_dxil;
         }

         nir->info.workgroup_size[i] = conf->local_size[i];
      }
      memcpy(metadata->local_size, nir->info.workgroup_size,
            sizeof(metadata->local_size));
   } else {
      /* Make sure there's at least one thread that's set to run */
      for (unsigned i = 0; i < ARRAY_SIZE(nir->info.workgroup_size); i++) {
         if (nir->info.workgroup_size[i] == 0)
            nir->info.workgroup_size[i] = 1;
      }
   }

   NIR_PASS_V(nir, clc_nir_lower_kernel_input_loads, inputs_var);
   NIR_PASS_V(nir, split_unaligned_loads_stores);
   NIR_PASS_V(nir, nir_lower_explicit_io, nir_var_mem_ubo,
              nir_address_format_32bit_index_offset);
   NIR_PASS_V(nir, clc_nir_lower_system_values, work_properties_var);
   NIR_PASS_V(nir, dxil_nir_lower_loads_stores_to_dxil);
   NIR_PASS_V(nir, dxil_nir_opt_alu_deref_srcs);
   NIR_PASS_V(nir, dxil_nir_lower_atomics_to_dxil);
   NIR_PASS_V(nir, nir_lower_fp16_casts);
   NIR_PASS_V(nir, nir_lower_convert_alu_types, NULL);

   // Convert pack to pack_split
   NIR_PASS_V(nir, nir_lower_pack);
   // Lower pack_split to bit math
   NIR_PASS_V(nir, nir_opt_algebraic);

   NIR_PASS_V(nir, nir_opt_dce);

   nir_validate_shader(nir, "Validate before feeding NIR to the DXIL compiler");
   struct nir_to_dxil_options opts = {
      .interpolate_at_vertex = false,
      .lower_int16 = (conf && (conf->lower_bit_size & 16) != 0),
      .ubo_binding_offset = 0,
      .disable_math_refactoring = true,
      .num_kernel_globals = num_global_inputs,
   };

   for (unsigned i = 0; i < out_dxil->kernel->num_args; i++) {
      if (out_dxil->kernel->args[i].address_qualifier != CLC_KERNEL_ARG_ADDRESS_LOCAL)
         continue;

      /* If we don't have the runtime conf yet, we just create a dummy variable.
       * This will be adjusted when clc_spirv_to_dxil() is called with a conf
       * argument.
       */
      unsigned size = 4;
      if (conf && conf->args)
         size = conf->args[i].localptr.size;

      /* The alignment required for the pointee type is not easy to get from
       * here, so let's base our logic on the size itself. Anything bigger than
       * the maximum alignment constraint (which is 128 bytes, since ulong16 or
       * double16 size are the biggest base types) should be aligned on this
       * maximum alignment constraint. For smaller types, we use the size
       * itself to calculate the alignment.
       */
      unsigned alignment = size < 128 ? (1 << (ffs(size) - 1)) : 128;

      nir->info.shared_size = align(nir->info.shared_size, alignment);
      metadata->args[i].localptr.sharedmem_offset = nir->info.shared_size;
      nir->info.shared_size += size;
   }

   metadata->local_mem_size = nir->info.shared_size;
   metadata->priv_mem_size = nir->scratch_size;

   /* DXIL double math is too limited compared to what NIR expects. Let's refuse
    * to compile a shader when it contains double operations until we have
    * double lowering hooked up.
    */
   if (shader_has_double(nir)) {
      clc_error(logger, "NIR shader contains doubles, which we don't support yet");
      goto err_free_dxil;
   }

   struct blob tmp;
   if (!nir_to_dxil(nir, &opts, &tmp)) {
      debug_printf("D3D12: nir_to_dxil failed\n");
      goto err_free_dxil;
   }

   nir_foreach_variable_with_modes(var, nir, nir_var_mem_ssbo) {
      if (var->constant_initializer) {
         if (glsl_type_is_array(var->type)) {
            int size = align(glsl_get_cl_size(var->type), 4);
            uint8_t *data = malloc(size);
            if (!data)
               goto err_free_dxil;

            copy_const_initializer(var->constant_initializer, var->type, data);
            metadata->consts[metadata->num_consts].data = data;
            metadata->consts[metadata->num_consts].size = size;
            metadata->consts[metadata->num_consts].uav_id = var->data.binding;
            metadata->num_consts++;
         } else
            unreachable("unexpected constant initializer");
      }
   }

   metadata->kernel_inputs_cbv_id = inputs_var ? inputs_var->data.binding : 0;
   metadata->work_properties_cbv_id = work_properties_var->data.binding;
   metadata->num_uavs = uav_id;
   metadata->num_srvs = srv_id;
   metadata->num_samplers = sampler_id;

   ralloc_free(nir);
   glsl_type_singleton_decref();

   blob_finish_get_buffer(&tmp, &out_dxil->binary.data,
                          &out_dxil->binary.size);
   return true;

err_free_dxil:
   clc_free_dxil_object(out_dxil);
   return false;
}

void clc_free_dxil_object(struct clc_dxil_object *dxil)
{
   for (unsigned i = 0; i < dxil->metadata.num_consts; i++)
      free(dxil->metadata.consts[i].data);

   for (unsigned i = 0; i < dxil->metadata.printf.info_count; i++) {
      free(dxil->metadata.printf.infos[i].arg_sizes);
      free(dxil->metadata.printf.infos[i].str);
   }
   free(dxil->metadata.printf.infos);

   free(dxil->binary.data);
}

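/* MESA_GIT_SHA1 is the empty string or ends in a "git-<sha1>" suffix;
 * interpret the hex digits after the dash as the version, or return 0 when
 * they're unavailable.
 */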
uint64_t clc_compiler_get_version()
{
   const char sha1[] = MESA_GIT_SHA1;
   const char* dash = strchr(sha1, '-');
   if (dash) {
      return strtoull(dash + 1, NULL, 16);
   }
   return 0;
}
1269