/*
 * Copyright © 2020 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "brw_nir_rt.h"
#include "brw_nir_rt_builder.h"
#include "nir_phi_builder.h"

/** Insert the appropriate return instruction at the end of the shader */
void
brw_nir_lower_shader_returns(nir_shader *shader)
{
   nir_function_impl *impl = nir_shader_get_entrypoint(shader);

   /* Reserve scratch space at the start of the shader's per-thread scratch
    * space for the return BINDLESS_SHADER_RECORD address and data payload.
    * When a shader is called, the calling shader will write the return BSR
    * address in this region of the callee's scratch space.
    *
    * We could also put it at the end of the caller's scratch space. However,
    * doing it this way means that a shader never accesses its caller's
    * scratch space unless given an explicit pointer (such as for ray
    * payloads). It also makes computing the address easier given that we
    * want to apply an alignment to the scratch offset to ensure we can make
    * alignment assumptions in the called shader.
    *
    * This isn't needed for ray-gen shaders because they end the thread and
    * never return to the calling trampoline shader.
    */
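   /* A sketch of that callee data, as store_resume_addr() below lays it
    * out: a 64-bit return BSR record address at offset 0 of the callee's
    * stack, immediately followed by the 64-bit payload pointer.
    */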
   assert(shader->scratch_size == 0);
   if (shader->info.stage != MESA_SHADER_RAYGEN)
      shader->scratch_size = BRW_BTD_STACK_CALLEE_DATA_SIZE;

   nir_builder b;
   nir_builder_init(&b, impl);

   set_foreach(impl->end_block->predecessors, block_entry) {
      struct nir_block *block = (void *)block_entry->key;
      b.cursor = nir_after_block_before_jump(block);

      switch (shader->info.stage) {
      case MESA_SHADER_RAYGEN:
         /* A raygen shader is always the root of the shader call tree. When
          * it ends, we retire the bindless stack ID and no further shaders
          * will be executed.
          */
         brw_nir_btd_retire(&b);
         break;

      case MESA_SHADER_ANY_HIT:
         /* The default action of an any-hit shader is to accept the ray
          * intersection.
          */
         nir_accept_ray_intersection(&b);
         break;

      case MESA_SHADER_CALLABLE:
      case MESA_SHADER_MISS:
      case MESA_SHADER_CLOSEST_HIT:
         /* Callable, miss, and closest-hit shaders don't take any special
          * action at the end. They simply return to the previous shader in
          * the call stack.
          */
         brw_nir_btd_return(&b);
         break;

      case MESA_SHADER_INTERSECTION:
         /* This will be handled by brw_nir_lower_intersection_shader */
         break;

      default:
         unreachable("Invalid callable shader stage");
      }

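      /* At this point these shaders are expected to have a single exit
       * block, so the end block should have exactly one predecessor.
       */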
      assert(impl->end_block->predecessors->entries == 1);
      break;
   }

   nir_metadata_preserve(impl, nir_metadata_block_index |
                               nir_metadata_dominance);
}

static void
store_resume_addr(nir_builder *b, nir_intrinsic_instr *call)
{
   uint32_t call_idx = nir_intrinsic_call_idx(call);
   uint32_t offset = nir_intrinsic_stack_size(call);
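   /* call_idx identifies this call site's entry in the resume SBT;
    * stack_size is the caller's scratch usage in bytes, i.e. where the
    * callee's stack will begin after the push below.
    */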

   /* First thing on the called shader's stack is the resume address
    * followed by a pointer to the payload.
    */
   nir_ssa_def *resume_record_addr =
      nir_iadd_imm(b, nir_load_btd_resume_sbt_addr_intel(b),
                   call_idx * BRW_BTD_RESUME_SBT_STRIDE);
   /* By the time we get here, any remaining shader/function memory
    * pointers have been lowered to SSA values.
    */
   assert(nir_get_shader_call_payload_src(call)->is_ssa);
   nir_ssa_def *payload_addr =
      nir_get_shader_call_payload_src(call)->ssa;
   brw_nir_rt_store_scratch(b, offset, BRW_BTD_STACK_ALIGN,
                            nir_vec2(b, resume_record_addr, payload_addr),
                            0xf /* write_mask */);

   nir_btd_stack_push_intel(b, offset);
}

static bool
lower_shader_calls_instr(struct nir_builder *b, nir_instr *instr, void *data)
{
   if (instr->type != nir_instr_type_intrinsic)
      return false;

   /* Leave nir_intrinsic_rt_resume to be lowered by
    * brw_nir_lower_rt_intrinsics()
    */
   nir_intrinsic_instr *call = nir_instr_as_intrinsic(instr);

   switch (call->intrinsic) {
   case nir_intrinsic_rt_trace_ray: {
      store_resume_addr(b, call);

      nir_ssa_def *as_addr = call->src[0].ssa;
      nir_ssa_def *ray_flags = call->src[1].ssa;
      /* From the SPIR-V spec:
       *
       *    "Only the 8 least-significant bits of Cull Mask are used by this
       *    instruction - other bits are ignored.
       *
       *    Only the 4 least-significant bits of SBT Offset and SBT Stride
       *    are used by this instruction - other bits are ignored.
       *
       *    Only the 16 least-significant bits of Miss Index are used by this
       *    instruction - other bits are ignored."
       */
      nir_ssa_def *cull_mask = nir_iand_imm(b, call->src[2].ssa, 0xff);
      nir_ssa_def *sbt_offset = nir_iand_imm(b, call->src[3].ssa, 0xf);
      nir_ssa_def *sbt_stride = nir_iand_imm(b, call->src[4].ssa, 0xf);
      nir_ssa_def *miss_index = nir_iand_imm(b, call->src[5].ssa, 0xffff);
      nir_ssa_def *ray_orig = call->src[6].ssa;
      nir_ssa_def *ray_t_min = call->src[7].ssa;
      nir_ssa_def *ray_dir = call->src[8].ssa;
      nir_ssa_def *ray_t_max = call->src[9].ssa;

      /* The hardware packet takes the address of the root node in the
       * acceleration structure, not the acceleration structure itself. To
       * find that, we have to read the root node offset from the
       * acceleration structure, which is stored in its first QWord.
       */
      nir_ssa_def *root_node_ptr =
         nir_iadd(b, as_addr, nir_load_global(b, as_addr, 256, 1, 64));

      /* The hardware packet requires an address to the first element of the
       * hit SBT.
       *
       * In order to calculate this, we must multiply the "SBT Offset"
       * provided to OpTraceRay by the SBT stride provided for the hit SBT in
       * the call to vkCmdTraceRaysKHR() and add that to the base address of
       * the hit SBT. This stride is not to be confused with the "SBT Stride"
       * provided to OpTraceRay, which is in units of this stride. It's a
       * rather terrible overload of the word "stride". The hardware docs
       * call the SPIR-V stride value the "shader index multiplier", which is
       * a much saner name.
       */
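      /* For example (hypothetical numbers): an "SBT Offset" of 2 with a hit
       * SBT stride of 32 bytes yields hit_sbt_addr = hit SBT base + 64.
       */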
      nir_ssa_def *hit_sbt_stride_B =
         nir_load_ray_hit_sbt_stride_intel(b);
      nir_ssa_def *hit_sbt_offset_B =
         nir_umul_32x16(b, sbt_offset, nir_u2u32(b, hit_sbt_stride_B));
      nir_ssa_def *hit_sbt_addr =
         nir_iadd(b, nir_load_ray_hit_sbt_addr_intel(b),
                  nir_u2u64(b, hit_sbt_offset_B));

      /* The hardware packet takes an address to the miss BSR. */
      nir_ssa_def *miss_sbt_stride_B =
         nir_load_ray_miss_sbt_stride_intel(b);
      nir_ssa_def *miss_sbt_offset_B =
         nir_umul_32x16(b, miss_index, nir_u2u32(b, miss_sbt_stride_B));
      nir_ssa_def *miss_sbt_addr =
         nir_iadd(b, nir_load_ray_miss_sbt_addr_intel(b),
                  nir_u2u64(b, miss_sbt_offset_B));

      struct brw_nir_rt_mem_ray_defs ray_defs = {
         .root_node_ptr = root_node_ptr,
         .ray_flags = nir_u2u16(b, ray_flags),
         .ray_mask = cull_mask,
         .hit_group_sr_base_ptr = hit_sbt_addr,
         .hit_group_sr_stride = nir_u2u16(b, hit_sbt_stride_B),
         .miss_sr_ptr = miss_sbt_addr,
         .orig = ray_orig,
         .t_near = ray_t_min,
         .dir = ray_dir,
         .t_far = ray_t_max,
         .shader_index_multiplier = sbt_stride,
      };
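      /* Write the assembled ray into the world-level MemRay slot and emit
       * the initial trace intrinsic, which is later turned into the actual
       * hardware trace message.
       */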
      brw_nir_rt_store_mem_ray(b, &ray_defs, BRW_RT_BVH_LEVEL_WORLD);
      nir_trace_ray_initial_intel(b);
      return true;
   }

   case nir_intrinsic_rt_execute_callable: {
      store_resume_addr(b, call);

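      /* Same arithmetic as the hit and miss SBTs above: scale the callable
       * index by the callable SBT stride supplied at dispatch and add it to
       * the callable SBT base address.
       */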
      nir_ssa_def *sbt_offset32 =
         nir_imul(b, call->src[0].ssa,
                  nir_u2u32(b, nir_load_callable_sbt_stride_intel(b)));
      nir_ssa_def *sbt_addr =
         nir_iadd(b, nir_load_callable_sbt_addr_intel(b),
                  nir_u2u64(b, sbt_offset32));
      brw_nir_btd_spawn(b, sbt_addr);
      return true;
   }

   default:
      return false;
   }
}

bool
brw_nir_lower_shader_calls(nir_shader *shader)
{
   return nir_shader_instructions_pass(shader,
                                       lower_shader_calls_instr,
                                       nir_metadata_block_index |
                                       nir_metadata_dominance,
                                       NULL);
}

/** Creates a trivial return shader
 *
 * This is a callable shader that doesn't really do anything. It just loads
 * the resume address from the stack and does a return.
 */
nir_shader *
brw_nir_create_trivial_return_shader(const struct brw_compiler *compiler,
                                     void *mem_ctx)
{
   const nir_shader_compiler_options *nir_options =
      compiler->glsl_compiler_options[MESA_SHADER_CALLABLE].NirOptions;

   nir_builder b = nir_builder_init_simple_shader(MESA_SHADER_CALLABLE,
                                                  nir_options,
                                                  "RT Trivial Return");
   ralloc_steal(mem_ctx, b.shader);
   nir_shader *nir = b.shader;
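
   /* The body is empty, so return lowering just appends the btd_return
    * sequence at the end of this callable shader's entrypoint.
    */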
   NIR_PASS_V(nir, brw_nir_lower_shader_returns);

   return nir;
}