1 /*
2 * Copyright © 2021 Google
3 * Copyright © 2023 Valve Corporation
4 * SPDX-License-Identifier: MIT
5 */
6
7 #include "nir.h"
8 #include "nir_builder.h"
9 #include "nir_deref.h"
10 #include "radv_constants.h"
11 #include "radv_nir.h"
12
13 struct lower_hit_attrib_deref_args {
14 nir_variable_mode mode;
15 uint32_t base_offset;
16 };
17
18 static bool
lower_hit_attrib_deref(nir_builder * b,nir_instr * instr,void * data)19 lower_hit_attrib_deref(nir_builder *b, nir_instr *instr, void *data)
20 {
21 if (instr->type != nir_instr_type_intrinsic)
22 return false;
23
24 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
25 if (intrin->intrinsic != nir_intrinsic_load_deref && intrin->intrinsic != nir_intrinsic_store_deref)
26 return false;
27
28 struct lower_hit_attrib_deref_args *args = data;
29 nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]);
30 if (!nir_deref_mode_is(deref, args->mode))
31 return false;
32
33 b->cursor = nir_after_instr(instr);
34
35 nir_variable *var = nir_deref_instr_get_variable(deref);
36 uint32_t location = args->base_offset + var->data.driver_location +
37 nir_deref_instr_get_const_offset(deref, glsl_get_natural_size_align_bytes);
38
39 if (intrin->intrinsic == nir_intrinsic_load_deref) {
40 uint32_t num_components = intrin->def.num_components;
41 uint32_t bit_size = intrin->def.bit_size;
42
43 nir_def *components[NIR_MAX_VEC_COMPONENTS];
44
45 for (uint32_t comp = 0; comp < num_components; comp++) {
46 uint32_t offset = location + comp * DIV_ROUND_UP(bit_size, 8);
47 uint32_t base = offset / 4;
48 uint32_t comp_offset = offset % 4;
49
50 if (bit_size == 64) {
51 components[comp] = nir_pack_64_2x32_split(b, nir_load_hit_attrib_amd(b, .base = base),
52 nir_load_hit_attrib_amd(b, .base = base + 1));
53 } else if (bit_size == 32) {
54 components[comp] = nir_load_hit_attrib_amd(b, .base = base);
55 } else if (bit_size == 16) {
56 components[comp] =
57 nir_channel(b, nir_unpack_32_2x16(b, nir_load_hit_attrib_amd(b, .base = base)), comp_offset / 2);
58 } else if (bit_size == 8) {
59 components[comp] =
60 nir_channel(b, nir_unpack_bits(b, nir_load_hit_attrib_amd(b, .base = base), 8), comp_offset);
61 } else {
62 assert(bit_size == 1);
63 components[comp] = nir_i2b(b, nir_load_hit_attrib_amd(b, .base = base));
64 }
65 }
66
67 nir_def_rewrite_uses(&intrin->def, nir_vec(b, components, num_components));
68 } else {
69 nir_def *value = intrin->src[1].ssa;
70 uint32_t num_components = value->num_components;
71 uint32_t bit_size = value->bit_size;
72
73 for (uint32_t comp = 0; comp < num_components; comp++) {
74 uint32_t offset = location + comp * DIV_ROUND_UP(bit_size, 8);
75 uint32_t base = offset / 4;
76 uint32_t comp_offset = offset % 4;
77
78 nir_def *component = nir_channel(b, value, comp);
79
80 if (bit_size == 64) {
81 nir_store_hit_attrib_amd(b, nir_unpack_64_2x32_split_x(b, component), .base = base);
82 nir_store_hit_attrib_amd(b, nir_unpack_64_2x32_split_y(b, component), .base = base + 1);
83 } else if (bit_size == 32) {
84 nir_store_hit_attrib_amd(b, component, .base = base);
85 } else if (bit_size == 16) {
86 nir_def *prev = nir_unpack_32_2x16(b, nir_load_hit_attrib_amd(b, .base = base));
87 nir_def *components[2];
88 for (uint32_t word = 0; word < 2; word++)
89 components[word] = (word == comp_offset / 2) ? nir_channel(b, value, comp) : nir_channel(b, prev, word);
90 nir_store_hit_attrib_amd(b, nir_pack_32_2x16(b, nir_vec(b, components, 2)), .base = base);
91 } else if (bit_size == 8) {
92 nir_def *prev = nir_unpack_bits(b, nir_load_hit_attrib_amd(b, .base = base), 8);
93 nir_def *components[4];
94 for (uint32_t byte = 0; byte < 4; byte++)
95 components[byte] = (byte == comp_offset) ? nir_channel(b, value, comp) : nir_channel(b, prev, byte);
96 nir_store_hit_attrib_amd(b, nir_pack_32_4x8(b, nir_vec(b, components, 4)), .base = base);
97 } else {
98 assert(bit_size == 1);
99 nir_store_hit_attrib_amd(b, nir_b2i32(b, component), .base = base);
100 }
101 }
102 }
103
104 nir_instr_remove(instr);
105 return true;
106 }
107
108 static bool
radv_lower_payload_arg_to_offset(nir_builder * b,nir_intrinsic_instr * instr,void * data)109 radv_lower_payload_arg_to_offset(nir_builder *b, nir_intrinsic_instr *instr, void *data)
110 {
111 if (instr->intrinsic != nir_intrinsic_trace_ray)
112 return false;
113
114 nir_deref_instr *payload = nir_src_as_deref(instr->src[10]);
115 assert(payload->deref_type == nir_deref_type_var);
116
117 b->cursor = nir_before_instr(&instr->instr);
118 nir_def *offset = nir_imm_int(b, payload->var->data.driver_location);
119
120 nir_src_rewrite(&instr->src[10], offset);
121
122 return true;
123 }
124
125 static bool
radv_nir_lower_rt_vars(nir_shader * shader,nir_variable_mode mode,uint32_t base_offset)126 radv_nir_lower_rt_vars(nir_shader *shader, nir_variable_mode mode, uint32_t base_offset)
127 {
128 bool progress = false;
129
130 progress |= nir_lower_indirect_derefs(shader, mode, UINT32_MAX);
131
132 progress |= nir_lower_vars_to_explicit_types(shader, mode, glsl_get_natural_size_align_bytes);
133
134 if (shader->info.stage == MESA_SHADER_RAYGEN && mode == nir_var_function_temp)
135 progress |= nir_shader_intrinsics_pass(shader, radv_lower_payload_arg_to_offset, nir_metadata_control_flow, NULL);
136
137 struct lower_hit_attrib_deref_args args = {
138 .mode = mode,
139 .base_offset = base_offset,
140 };
141
142 progress |= nir_shader_instructions_pass(shader, lower_hit_attrib_deref, nir_metadata_control_flow, &args);
143
144 if (progress) {
145 nir_remove_dead_derefs(shader);
146 nir_remove_dead_variables(shader, mode, NULL);
147 }
148
149 return progress;
150 }
151
152 bool
radv_nir_lower_hit_attrib_derefs(nir_shader * shader)153 radv_nir_lower_hit_attrib_derefs(nir_shader *shader)
154 {
155 return radv_nir_lower_rt_vars(shader, nir_var_ray_hit_attrib, 0);
156 }
157
158 bool
radv_nir_lower_ray_payload_derefs(nir_shader * shader,uint32_t offset)159 radv_nir_lower_ray_payload_derefs(nir_shader *shader, uint32_t offset)
160 {
161 bool progress = radv_nir_lower_rt_vars(shader, nir_var_function_temp, RADV_MAX_HIT_ATTRIB_SIZE + offset);
162 progress |= radv_nir_lower_rt_vars(shader, nir_var_shader_call_data, RADV_MAX_HIT_ATTRIB_SIZE + offset);
163 return progress;
164 }
165