1 use crate::api::icd::*;
2 use crate::core::device::*;
3 use crate::core::event::*;
4 use crate::core::memory::*;
5 use crate::core::platform::*;
6 use crate::core::program::*;
7 use crate::core::queue::*;
8 use crate::impl_cl_type_trait;
9
10 use mesa_rust::compiler::clc::*;
11 use mesa_rust::compiler::nir::*;
12 use mesa_rust::nir_pass;
13 use mesa_rust::pipe::context::RWFlags;
14 use mesa_rust::pipe::resource::*;
15 use mesa_rust::pipe::screen::ResourceType;
16 use mesa_rust_gen::*;
17 use mesa_rust_util::math::*;
18 use mesa_rust_util::serialize::*;
19 use rusticl_opencl_gen::*;
20 use spirv::SpirvKernelInfo;
21
22 use std::cmp;
23 use std::collections::HashMap;
24 use std::convert::TryInto;
25 use std::ffi::CStr;
26 use std::fmt::Debug;
27 use std::fmt::Display;
28 use std::ops::Index;
29 use std::ops::Not;
30 use std::os::raw::c_void;
31 use std::ptr;
32 use std::slice;
33 use std::sync::Arc;
34 use std::sync::Mutex;
35 use std::sync::MutexGuard;
36 use std::sync::Weak;
37
// According to the CL spec we are not allowed to let any cl_kernel object hold any references on
// its arguments, as this might make it infeasible for applications to free the backing memory of
// memory objects allocated with `CL_MEM_USE_HOST_PTR`.
//
// However, those arguments might temporarily get referenced by event objects, so we'll use Weak in
// order to upgrade the reference when needed. It's also safer to use Weak over raw pointers,
// because it makes it impossible to run into use-after-free issues.
//
// Technically we'd also need to do this for samplers, but there it's kind of pointless to take a
// weak reference, as samplers don't have a host_ptr or any of the similar problems cl_mem objects
// have.
48 #[derive(Clone)]
49 pub enum KernelArgValue {
50 None,
51 Buffer(Weak<Buffer>),
52 Constant(Vec<u8>),
53 Image(Weak<Image>),
54 LocalMem(usize),
55 Sampler(Arc<Sampler>),
56 }
57
58 #[repr(u8)]
59 #[derive(Hash, PartialEq, Eq, Clone, Copy)]
60 pub enum KernelArgType {
61 Constant(/* size */ u16), // for anything passed by value
62 Image,
63 RWImage,
64 Sampler,
65 Texture,
66 MemGlobal,
67 MemConstant,
68 MemLocal,
69 }
70
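// (De)serialization below is used for on-disk shader cache entries. The tag byte written per
// variant (0..=7) has to stay in sync between serialize() and deserialize(); Constant additionally
// stores its size as a u16.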
71 impl KernelArgType {
fn deserialize(blob: &mut blob_reader) -> Option<Self> {
73 // SAFETY: we get 0 on an overrun, but we verify that later and act accordingly.
74 let res = match unsafe { blob_read_uint8(blob) } {
75 0 => {
76 // SAFETY: same here
77 let size = unsafe { blob_read_uint16(blob) };
78 KernelArgType::Constant(size)
79 }
80 1 => KernelArgType::Image,
81 2 => KernelArgType::RWImage,
82 3 => KernelArgType::Sampler,
83 4 => KernelArgType::Texture,
84 5 => KernelArgType::MemGlobal,
85 6 => KernelArgType::MemConstant,
86 7 => KernelArgType::MemLocal,
87 _ => return None,
88 };
89
90 blob.overrun.not().then_some(res)
91 }
92
fn serialize(&self, blob: &mut blob) {
94 unsafe {
95 match self {
96 KernelArgType::Constant(size) => {
97 blob_write_uint8(blob, 0);
98 blob_write_uint16(blob, *size)
99 }
100 KernelArgType::Image => blob_write_uint8(blob, 1),
101 KernelArgType::RWImage => blob_write_uint8(blob, 2),
102 KernelArgType::Sampler => blob_write_uint8(blob, 3),
103 KernelArgType::Texture => blob_write_uint8(blob, 4),
104 KernelArgType::MemGlobal => blob_write_uint8(blob, 5),
105 KernelArgType::MemConstant => blob_write_uint8(blob, 6),
106 KernelArgType::MemLocal => blob_write_uint8(blob, 7),
107 };
108 }
109 }
110
fn is_opaque(&self) -> bool {
112 matches!(
113 self,
114 KernelArgType::Image
115 | KernelArgType::RWImage
116 | KernelArgType::Texture
117 | KernelArgType::Sampler
118 )
119 }
120 }
121
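// APIArg(idx) refers back to an API-visible kernel argument by index; all other variants are
// internal arguments the compiler introduces while lowering (work offsets/sizes, constant and
// printf buffers, image format/order arrays, inline samplers, ...).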
122 #[derive(Hash, PartialEq, Eq, Clone)]
123 enum CompiledKernelArgType {
124 APIArg(u32),
125 ConstantBuffer,
126 GlobalWorkOffsets,
127 GlobalWorkSize,
128 PrintfBuffer,
129 InlineSampler((cl_addressing_mode, cl_filter_mode, bool)),
130 FormatArray,
131 OrderArray,
132 WorkDim,
133 WorkGroupOffsets,
134 NumWorkgroups,
135 }
136
137 #[derive(Hash, PartialEq, Eq, Clone)]
138 pub struct KernelArg {
139 spirv: spirv::SPIRVKernelArg,
140 pub kind: KernelArgType,
141 pub dead: bool,
142 }
143
144 impl KernelArg {
fn from_spirv_nir(spirv: &[spirv::SPIRVKernelArg], nir: &mut NirShader) -> Vec<Self> {
146 let nir_arg_map: HashMap<_, _> = nir
147 .variables_with_mode(
148 nir_variable_mode::nir_var_uniform | nir_variable_mode::nir_var_image,
149 )
150 .map(|v| (v.data.location, v))
151 .collect();
152 let mut res = Vec::new();
153
154 for (i, s) in spirv.iter().enumerate() {
155 let nir = nir_arg_map.get(&(i as i32)).unwrap();
156 let kind = match s.address_qualifier {
157 clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_PRIVATE => {
158 if unsafe { glsl_type_is_sampler(nir.type_) } {
159 KernelArgType::Sampler
160 } else {
161 let size = unsafe { glsl_get_cl_size(nir.type_) } as u16;
// nir types of non-opaque types never have a size of 0
163 KernelArgType::Constant(size)
164 }
165 }
166 clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_CONSTANT => {
167 KernelArgType::MemConstant
168 }
169 clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_LOCAL => {
170 KernelArgType::MemLocal
171 }
172 clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_GLOBAL => {
173 if unsafe { glsl_type_is_image(nir.type_) } {
174 let access = nir.data.access();
175 if access == gl_access_qualifier::ACCESS_NON_WRITEABLE.0 {
176 KernelArgType::Texture
177 } else if access == gl_access_qualifier::ACCESS_NON_READABLE.0 {
178 KernelArgType::Image
179 } else {
180 KernelArgType::RWImage
181 }
182 } else {
183 KernelArgType::MemGlobal
184 }
185 }
186 };
187
188 res.push(Self {
189 spirv: s.clone(),
190 // we'll update it later in the 2nd pass
191 kind: kind,
192 dead: true,
193 });
194 }
195 res
196 }
197
fn serialize(args: &[Self], blob: &mut blob) {
199 unsafe {
200 blob_write_uint16(blob, args.len() as u16);
201
202 for arg in args {
203 arg.spirv.serialize(blob);
204 blob_write_uint8(blob, arg.dead.into());
205 arg.kind.serialize(blob);
206 }
207 }
208 }
209
fn deserialize(blob: &mut blob_reader) -> Option<Vec<Self>> {
211 // SAFETY: we check the overrun status, blob_read returns 0 in such a case.
212 let len = unsafe { blob_read_uint16(blob) } as usize;
213 let mut res = Vec::with_capacity(len);
214
215 for _ in 0..len {
216 let spirv = spirv::SPIRVKernelArg::deserialize(blob)?;
217 // SAFETY: we check the overrun status
218 let dead = unsafe { blob_read_uint8(blob) } != 0;
219 let kind = KernelArgType::deserialize(blob)?;
220
221 res.push(Self {
222 spirv: spirv,
223 kind: kind,
224 dead: dead,
225 });
226 }
227
228 blob.overrun.not().then_some(res)
229 }
230 }
231
232 #[derive(Hash, PartialEq, Eq, Clone)]
233 struct CompiledKernelArg {
234 kind: CompiledKernelArgType,
235 /// The binding for image/sampler args, the offset into the input buffer
236 /// for anything else.
237 offset: u32,
238 dead: bool,
239 }
240
241 impl CompiledKernelArg {
fn assign_locations(compiled_args: &mut [Self], nir: &mut NirShader) {
243 for var in nir.variables_with_mode(
244 nir_variable_mode::nir_var_uniform | nir_variable_mode::nir_var_image,
245 ) {
246 let arg = &mut compiled_args[var.data.location as usize];
247 let t = var.type_;
248
249 arg.dead = false;
250 arg.offset = if unsafe {
251 glsl_type_is_image(t) || glsl_type_is_texture(t) || glsl_type_is_sampler(t)
252 } {
253 var.data.binding
254 } else {
255 var.data.driver_location
256 };
257 }
258 }
259
fn serialize(args: &[Self], blob: &mut blob) {
261 unsafe {
262 blob_write_uint16(blob, args.len() as u16);
263 for arg in args {
264 blob_write_uint32(blob, arg.offset);
265 blob_write_uint8(blob, arg.dead.into());
266 match arg.kind {
267 CompiledKernelArgType::ConstantBuffer => blob_write_uint8(blob, 0),
268 CompiledKernelArgType::GlobalWorkOffsets => blob_write_uint8(blob, 1),
269 CompiledKernelArgType::PrintfBuffer => blob_write_uint8(blob, 2),
270 CompiledKernelArgType::InlineSampler((addr_mode, filter_mode, norm)) => {
271 blob_write_uint8(blob, 3);
272 blob_write_uint8(blob, norm.into());
273 blob_write_uint32(blob, addr_mode);
274 blob_write_uint32(blob, filter_mode)
275 }
276 CompiledKernelArgType::FormatArray => blob_write_uint8(blob, 4),
277 CompiledKernelArgType::OrderArray => blob_write_uint8(blob, 5),
278 CompiledKernelArgType::WorkDim => blob_write_uint8(blob, 6),
279 CompiledKernelArgType::WorkGroupOffsets => blob_write_uint8(blob, 7),
280 CompiledKernelArgType::NumWorkgroups => blob_write_uint8(blob, 8),
281 CompiledKernelArgType::GlobalWorkSize => blob_write_uint8(blob, 9),
282 CompiledKernelArgType::APIArg(idx) => {
283 blob_write_uint8(blob, 10);
284 blob_write_uint32(blob, idx)
285 }
286 };
287 }
288 }
289 }
290
fn deserialize(blob: &mut blob_reader) -> Option<Vec<Self>> {
292 unsafe {
293 let len = blob_read_uint16(blob) as usize;
294 let mut res = Vec::with_capacity(len);
295
296 for _ in 0..len {
297 let offset = blob_read_uint32(blob);
298 let dead = blob_read_uint8(blob) != 0;
299
300 let kind = match blob_read_uint8(blob) {
301 0 => CompiledKernelArgType::ConstantBuffer,
302 1 => CompiledKernelArgType::GlobalWorkOffsets,
303 2 => CompiledKernelArgType::PrintfBuffer,
304 3 => {
305 let norm = blob_read_uint8(blob) != 0;
306 let addr_mode = blob_read_uint32(blob);
307 let filter_mode = blob_read_uint32(blob);
308 CompiledKernelArgType::InlineSampler((addr_mode, filter_mode, norm))
309 }
310 4 => CompiledKernelArgType::FormatArray,
311 5 => CompiledKernelArgType::OrderArray,
312 6 => CompiledKernelArgType::WorkDim,
313 7 => CompiledKernelArgType::WorkGroupOffsets,
314 8 => CompiledKernelArgType::NumWorkgroups,
315 9 => CompiledKernelArgType::GlobalWorkSize,
316 10 => {
317 let idx = blob_read_uint32(blob);
318 CompiledKernelArgType::APIArg(idx)
319 }
320 _ => return None,
321 };
322
323 res.push(Self {
324 kind: kind,
325 offset: offset,
326 dead: dead,
327 });
328 }
329
330 Some(res)
331 }
332 }
333 }
334
335 #[derive(Clone, PartialEq, Eq, Hash)]
336 pub struct KernelInfo {
337 pub args: Vec<KernelArg>,
338 pub attributes_string: String,
339 work_group_size: [usize; 3],
340 work_group_size_hint: [u32; 3],
341 subgroup_size: usize,
342 num_subgroups: usize,
343 }
344
345 struct CSOWrapper {
346 cso_ptr: *mut c_void,
347 dev: &'static Device,
348 }
349
350 impl CSOWrapper {
fn new(dev: &'static Device, nir: &NirShader) -> Self {
352 let cso_ptr = dev
353 .helper_ctx()
354 .create_compute_state(nir, nir.shared_size());
355
356 Self {
357 cso_ptr: cso_ptr,
358 dev: dev,
359 }
360 }
361
fn get_cso_info(&self) -> pipe_compute_state_object_info {
363 self.dev.helper_ctx().compute_state_info(self.cso_ptr)
364 }
365 }
366
367 impl Drop for CSOWrapper {
fn drop(&mut self) {
369 self.dev.helper_ctx().delete_compute_state(self.cso_ptr);
370 }
371 }
372
373 enum KernelDevStateVariant {
374 Cso(CSOWrapper),
375 Nir(NirShader),
376 }
377
378 #[derive(Debug, PartialEq)]
379 enum NirKernelVariant {
380 /// Can be used under any circumstance.
381 Default,
382
383 /// Optimized variant making the following assumptions:
384 /// - global_id_offsets are 0
385 /// - workgroup_offsets are 0
386 /// - local_size is info.local_size_hint
387 Optimized,
388 }
389
390 impl Display for NirKernelVariant {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
392 // this simply prints the enum name, so that's fine
393 Debug::fmt(self, f)
394 }
395 }
396
397 pub struct NirKernelBuilds {
398 default_build: NirKernelBuild,
399 optimized: Option<NirKernelBuild>,
400 /// merged info with worst case values
401 info: pipe_compute_state_object_info,
402 }
403
404 impl Index<NirKernelVariant> for NirKernelBuilds {
405 type Output = NirKernelBuild;
406
fn index(&self, index: NirKernelVariant) -> &Self::Output {
408 match index {
409 NirKernelVariant::Default => &self.default_build,
410 NirKernelVariant::Optimized => self.optimized.as_ref().unwrap_or(&self.default_build),
411 }
412 }
413 }
414
415 impl NirKernelBuilds {
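// Merge the per-variant compute state info into worst-case values: the smaller max_threads, the
// intersection of supported SIMD sizes, and the larger private memory / preferred SIMD size.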
fn new(default_build: NirKernelBuild, optimized: Option<NirKernelBuild>) -> Self {
417 let mut info = default_build.info;
418 if let Some(build) = &optimized {
419 info.max_threads = cmp::min(info.max_threads, build.info.max_threads);
420 info.simd_sizes &= build.info.simd_sizes;
421 info.private_memory = cmp::max(info.private_memory, build.info.private_memory);
422 info.preferred_simd_size =
423 cmp::max(info.preferred_simd_size, build.info.preferred_simd_size);
424 }
425
426 Self {
427 default_build: default_build,
428 optimized: optimized,
429 info: info,
430 }
431 }
432 }
433
434 struct NirKernelBuild {
435 nir_or_cso: KernelDevStateVariant,
436 constant_buffer: Option<Arc<PipeResource>>,
437 info: pipe_compute_state_object_info,
438 shared_size: u64,
439 printf_info: Option<NirPrintfInfo>,
440 compiled_args: Vec<CompiledKernelArg>,
441 }
442
443 // SAFETY: `CSOWrapper` is only safe to use if the device supports `pipe_caps.shareable_shaders` and
444 // we make sure to set `nir_or_cso` to `KernelDevStateVariant::Cso` only if that's the case.
445 unsafe impl Send for NirKernelBuild {}
446 unsafe impl Sync for NirKernelBuild {}
447
448 impl NirKernelBuild {
fn new(dev: &'static Device, mut out: CompilationResult) -> Self {
450 let cso = CSOWrapper::new(dev, &out.nir);
451 let info = cso.get_cso_info();
452 let cb = Self::create_nir_constant_buffer(dev, &out.nir);
453 let shared_size = out.nir.shared_size() as u64;
454 let printf_info = out.nir.take_printf_info();
455
456 let nir_or_cso = if !dev.shareable_shaders() {
457 KernelDevStateVariant::Nir(out.nir)
458 } else {
459 KernelDevStateVariant::Cso(cso)
460 };
461
462 NirKernelBuild {
463 nir_or_cso: nir_or_cso,
464 constant_buffer: cb,
465 info: info,
466 shared_size: shared_size,
467 printf_info: printf_info,
468 compiled_args: out.compiled_args,
469 }
470 }
471
fn create_nir_constant_buffer(dev: &Device, nir: &NirShader) -> Option<Arc<PipeResource>> {
473 let buf = nir.get_constant_buffer();
474 let len = buf.len() as u32;
475
476 if len > 0 {
477 // TODO bind as constant buffer
478 let res = dev
479 .screen()
480 .resource_create_buffer(len, ResourceType::Normal, PIPE_BIND_GLOBAL)
481 .unwrap();
482
483 dev.helper_ctx()
484 .exec(|ctx| ctx.buffer_subdata(&res, 0, buf.as_ptr().cast(), len))
485 .wait();
486
487 Some(Arc::new(res))
488 } else {
489 None
490 }
491 }
492 }
493
494 pub struct Kernel {
495 pub base: CLObjectBase<CL_INVALID_KERNEL>,
496 pub prog: Arc<Program>,
497 pub name: String,
498 values: Mutex<Vec<Option<KernelArgValue>>>,
499 builds: HashMap<&'static Device, Arc<NirKernelBuilds>>,
500 pub kernel_info: Arc<KernelInfo>,
501 }
502
503 impl_cl_type_trait!(cl_kernel, Kernel, CL_INVALID_KERNEL);
504
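// Expands the up-to-3-element work size slices coming in through the API into a fixed [T; 3],
// filling missing dimensions with `val`; a value that doesn't fit into T maps to
// CL_OUT_OF_RESOURCES, e.g. create_kernel_arr::<u32>(&[64, 4], 1) == Ok([64, 4, 1]).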
fn create_kernel_arr<T>(vals: &[usize], val: T) -> CLResult<[T; 3]>
506 where
507 T: std::convert::TryFrom<usize> + Copy,
508 <T as std::convert::TryFrom<usize>>::Error: std::fmt::Debug,
509 {
510 let mut res = [val; 3];
511 for (i, v) in vals.iter().enumerate() {
512 res[i] = (*v).try_into().ok().ok_or(CL_OUT_OF_RESOURCES)?;
513 }
514
515 Ok(res)
516 }
517
518 #[derive(Clone)]
519 struct CompilationResult {
520 nir: NirShader,
521 compiled_args: Vec<CompiledKernelArg>,
522 }
523
524 impl CompilationResult {
fn deserialize(reader: &mut blob_reader, d: &Device) -> Option<Self> {
526 let nir = NirShader::deserialize(
527 reader,
528 d.screen()
529 .nir_shader_compiler_options(pipe_shader_type::PIPE_SHADER_COMPUTE),
530 )?;
531 let compiled_args = CompiledKernelArg::deserialize(reader)?;
532
533 Some(Self {
534 nir: nir,
535 compiled_args,
536 })
537 }
538
fn serialize(&self, blob: &mut blob) {
540 self.nir.serialize(blob);
541 CompiledKernelArg::serialize(&self.compiled_args, blob);
542 }
543 }
544
fn opt_nir(nir: &mut NirShader, dev: &Device, has_explicit_types: bool) {
546 let nir_options = unsafe {
547 &*dev
548 .screen
549 .nir_shader_compiler_options(pipe_shader_type::PIPE_SHADER_COMPUTE)
550 };
551
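// Rust has no do-while, so `while { body } {}` is used here: the block's final `progress` value
// acts as the loop condition and the passes below rerun until none of them reports progress.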
552 while {
553 let mut progress = false;
554
555 progress |= nir_pass!(nir, nir_copy_prop);
556 progress |= nir_pass!(nir, nir_opt_copy_prop_vars);
557 progress |= nir_pass!(nir, nir_opt_dead_write_vars);
558
559 if nir_options.lower_to_scalar {
560 nir_pass!(
561 nir,
562 nir_lower_alu_to_scalar,
563 nir_options.lower_to_scalar_filter,
564 ptr::null(),
565 );
566 nir_pass!(nir, nir_lower_phis_to_scalar, false);
567 }
568
569 progress |= nir_pass!(nir, nir_opt_deref);
570 if has_explicit_types {
571 progress |= nir_pass!(nir, nir_opt_memcpy);
572 }
573 progress |= nir_pass!(nir, nir_opt_dce);
574 progress |= nir_pass!(nir, nir_opt_undef);
575 progress |= nir_pass!(nir, nir_opt_constant_folding);
576 progress |= nir_pass!(nir, nir_opt_cse);
577 nir_pass!(nir, nir_split_var_copies);
578 progress |= nir_pass!(nir, nir_lower_var_copies);
579 progress |= nir_pass!(nir, nir_lower_vars_to_ssa);
580 nir_pass!(nir, nir_lower_alu);
581 progress |= nir_pass!(nir, nir_opt_phi_precision);
582 progress |= nir_pass!(nir, nir_opt_algebraic);
583 progress |= nir_pass!(
584 nir,
585 nir_opt_if,
586 nir_opt_if_options::nir_opt_if_optimize_phi_true_false,
587 );
588 progress |= nir_pass!(nir, nir_opt_dead_cf);
589 progress |= nir_pass!(nir, nir_opt_remove_phis);
590 // we don't want to be too aggressive here, but it kills a bit of CFG
591 progress |= nir_pass!(nir, nir_opt_peephole_select, 8, true, true);
592 progress |= nir_pass!(
593 nir,
594 nir_lower_vec3_to_vec4,
595 nir_variable_mode::nir_var_mem_generic | nir_variable_mode::nir_var_uniform,
596 );
597
598 if nir_options.max_unroll_iterations != 0 {
599 progress |= nir_pass!(nir, nir_opt_loop_unroll);
600 }
601 nir.sweep_mem();
602 progress
603 } {}
604 }
605
606 /// # Safety
607 ///
608 /// Only safe to call when `var` is a valid pointer to a valid [`nir_variable`]
unsafe extern "C" fn can_remove_var(var: *mut nir_variable, _: *mut c_void) -> bool {
610 // SAFETY: It is the caller's responsibility to provide a valid and aligned pointer
611 let var_type = unsafe { (*var).type_ };
612 // SAFETY: `nir_variable`'s type invariant guarantees that the `type_` field is valid and
613 // properly aligned.
614 unsafe {
615 !glsl_type_is_image(var_type)
616 && !glsl_type_is_texture(var_type)
617 && !glsl_type_is_sampler(var_type)
618 }
619 }
620
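// Options for nir_remove_dead_variables: can_remove_var() above keeps image, texture and sampler
// variables alive even if they appear unused, so they are still around when bindings get assigned
// in CompiledKernelArg::assign_locations.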
621 const DV_OPTS: nir_remove_dead_variables_options = nir_remove_dead_variables_options {
622 can_remove_var: Some(can_remove_var),
623 can_remove_var_data: ptr::null_mut(),
624 };
625
fn compile_nir_to_args(
627 dev: &Device,
628 mut nir: NirShader,
629 args: &[spirv::SPIRVKernelArg],
630 lib_clc: &NirShader,
631 ) -> (Vec<KernelArg>, NirShader) {
632 // this is a hack until we support fp16 properly and check for denorms inside vstore/vload_half
633 nir.preserve_fp16_denorms();
634
// Set to rtne for now until drivers are able to report their preferred rounding mode; that also
// matches what we report via the API.
637 nir.set_fp_rounding_mode_rtne();
638
639 nir_pass!(nir, nir_scale_fdiv);
640 nir.set_workgroup_size_variable_if_zero();
641 nir.structurize();
642 nir_pass!(
643 nir,
644 nir_lower_variable_initializers,
645 nir_variable_mode::nir_var_function_temp
646 );
647
648 while {
649 let mut progress = false;
650 nir_pass!(nir, nir_split_var_copies);
651 progress |= nir_pass!(nir, nir_copy_prop);
652 progress |= nir_pass!(nir, nir_opt_copy_prop_vars);
653 progress |= nir_pass!(nir, nir_opt_dead_write_vars);
654 progress |= nir_pass!(nir, nir_opt_deref);
655 progress |= nir_pass!(nir, nir_opt_dce);
656 progress |= nir_pass!(nir, nir_opt_undef);
657 progress |= nir_pass!(nir, nir_opt_constant_folding);
658 progress |= nir_pass!(nir, nir_opt_cse);
659 progress |= nir_pass!(nir, nir_lower_vars_to_ssa);
660 progress |= nir_pass!(nir, nir_opt_algebraic);
661 progress
662 } {}
663 nir.inline(lib_clc);
664 nir.cleanup_functions();
665 // that should free up tons of memory
666 nir.sweep_mem();
667
668 nir_pass!(nir, nir_dedup_inline_samplers);
669
670 let printf_opts = nir_lower_printf_options {
671 ptr_bit_size: 0,
672 use_printf_base_identifier: false,
673 hash_format_strings: false,
674 max_buffer_size: dev.printf_buffer_size() as u32,
675 };
676 nir_pass!(nir, nir_lower_printf, &printf_opts);
677
678 opt_nir(&mut nir, dev, false);
679
680 (KernelArg::from_spirv_nir(args, &mut nir), nir)
681 }
682
fn compile_nir_prepare_for_variants(
684 dev: &Device,
685 nir: &mut NirShader,
686 compiled_args: &mut Vec<CompiledKernelArg>,
687 ) {
688 // assign locations for inline samplers.
689 // IMPORTANT: this needs to happen before nir_remove_dead_variables.
690 let mut last_loc = -1;
691 for v in nir
692 .variables_with_mode(nir_variable_mode::nir_var_uniform | nir_variable_mode::nir_var_image)
693 {
694 if unsafe { !glsl_type_is_sampler(v.type_) } {
695 last_loc = v.data.location;
696 continue;
697 }
698 let s = unsafe { v.data.anon_1.sampler };
699 if s.is_inline_sampler() != 0 {
700 last_loc += 1;
701 v.data.location = last_loc;
702
703 compiled_args.push(CompiledKernelArg {
704 kind: CompiledKernelArgType::InlineSampler(Sampler::nir_to_cl(
705 s.addressing_mode(),
706 s.filter_mode(),
707 s.normalized_coordinates(),
708 )),
709 offset: 0,
710 dead: true,
711 });
712 } else {
713 last_loc = v.data.location;
714 }
715 }
716
717 nir_pass!(
718 nir,
719 nir_remove_dead_variables,
720 nir_variable_mode::nir_var_uniform
721 | nir_variable_mode::nir_var_image
722 | nir_variable_mode::nir_var_mem_constant
723 | nir_variable_mode::nir_var_mem_shared
724 | nir_variable_mode::nir_var_function_temp,
725 &DV_OPTS,
726 );
727
728 nir_pass!(nir, nir_lower_readonly_images_to_tex, true);
729 nir_pass!(
730 nir,
731 nir_lower_cl_images,
732 !dev.images_as_deref(),
733 !dev.samplers_as_deref(),
734 );
735
736 nir_pass!(
737 nir,
738 nir_lower_vars_to_explicit_types,
739 nir_variable_mode::nir_var_mem_constant,
740 Some(glsl_get_cl_type_size_align),
741 );
742
743 // has to run before adding internal kernel arguments
744 nir.extract_constant_initializers();
745
746 // needed to convert variables to load intrinsics
747 nir_pass!(nir, nir_lower_system_values);
748
// Run here so we can decide if it makes sense to compile a variant, e.g. based on which system
// values are read.
750 nir.gather_info();
751 }
752
fn compile_nir_variant(
754 res: &mut CompilationResult,
755 dev: &Device,
756 variant: NirKernelVariant,
757 args: &[KernelArg],
758 name: &str,
759 ) {
760 let mut lower_state = rusticl_lower_state::default();
761 let compiled_args = &mut res.compiled_args;
762 let nir = &mut res.nir;
763
764 let address_bits_ptr_type;
765 let address_bits_base_type;
766 let global_address_format;
767 let shared_address_format;
768
769 if dev.address_bits() == 64 {
770 address_bits_ptr_type = unsafe { glsl_uint64_t_type() };
771 address_bits_base_type = glsl_base_type::GLSL_TYPE_UINT64;
772 global_address_format = nir_address_format::nir_address_format_64bit_global;
773 shared_address_format = nir_address_format::nir_address_format_32bit_offset_as_64bit;
774 } else {
775 address_bits_ptr_type = unsafe { glsl_uint_type() };
776 address_bits_base_type = glsl_base_type::GLSL_TYPE_UINT;
777 global_address_format = nir_address_format::nir_address_format_32bit_global;
778 shared_address_format = nir_address_format::nir_address_format_32bit_offset;
779 }
780
781 let nir_options = unsafe {
782 &*dev
783 .screen
784 .nir_shader_compiler_options(pipe_shader_type::PIPE_SHADER_COMPUTE)
785 };
786
787 if variant == NirKernelVariant::Optimized {
788 let wgsh = nir.workgroup_size_hint();
789 if wgsh != [0; 3] {
790 nir.set_workgroup_size(wgsh);
791 }
792 }
793
794 let mut compute_options = nir_lower_compute_system_values_options::default();
795 compute_options.set_has_global_size(true);
796 if variant != NirKernelVariant::Optimized {
797 compute_options.set_has_base_global_invocation_id(true);
798 compute_options.set_has_base_workgroup_id(true);
799 }
800 nir_pass!(nir, nir_lower_compute_system_values, &compute_options);
801 nir.gather_info();
802
803 let mut add_var = |nir: &mut NirShader,
804 var_loc: &mut usize,
805 kind: CompiledKernelArgType,
806 glsl_type: *const glsl_type,
807 name| {
808 *var_loc = compiled_args.len();
809 compiled_args.push(CompiledKernelArg {
810 kind: kind,
811 offset: 0,
812 dead: true,
813 });
814 nir.add_var(
815 nir_variable_mode::nir_var_uniform,
816 glsl_type,
817 *var_loc,
818 name,
819 );
820 };
821
822 if nir.reads_sysval(gl_system_value::SYSTEM_VALUE_BASE_GLOBAL_INVOCATION_ID) {
823 debug_assert_ne!(variant, NirKernelVariant::Optimized);
824 add_var(
825 nir,
826 &mut lower_state.base_global_invoc_id_loc,
827 CompiledKernelArgType::GlobalWorkOffsets,
828 unsafe { glsl_vector_type(address_bits_base_type, 3) },
829 c"base_global_invocation_id",
830 )
831 }
832
833 if nir.reads_sysval(gl_system_value::SYSTEM_VALUE_GLOBAL_GROUP_SIZE) {
834 add_var(
835 nir,
836 &mut lower_state.global_size_loc,
837 CompiledKernelArgType::GlobalWorkSize,
838 unsafe { glsl_vector_type(address_bits_base_type, 3) },
839 c"global_size",
840 )
841 }
842
843 if nir.reads_sysval(gl_system_value::SYSTEM_VALUE_BASE_WORKGROUP_ID) {
844 debug_assert_ne!(variant, NirKernelVariant::Optimized);
845 add_var(
846 nir,
847 &mut lower_state.base_workgroup_id_loc,
848 CompiledKernelArgType::WorkGroupOffsets,
849 unsafe { glsl_vector_type(address_bits_base_type, 3) },
850 c"base_workgroup_id",
851 );
852 }
853
854 if nir.reads_sysval(gl_system_value::SYSTEM_VALUE_NUM_WORKGROUPS) {
855 add_var(
856 nir,
857 &mut lower_state.num_workgroups_loc,
858 CompiledKernelArgType::NumWorkgroups,
859 unsafe { glsl_vector_type(glsl_base_type::GLSL_TYPE_UINT, 3) },
860 c"num_workgroups",
861 );
862 }
863
864 if nir.has_constant() {
865 add_var(
866 nir,
867 &mut lower_state.const_buf_loc,
868 CompiledKernelArgType::ConstantBuffer,
869 address_bits_ptr_type,
870 c"constant_buffer_addr",
871 );
872 }
873 if nir.has_printf() {
874 add_var(
875 nir,
876 &mut lower_state.printf_buf_loc,
877 CompiledKernelArgType::PrintfBuffer,
878 address_bits_ptr_type,
879 c"printf_buffer_addr",
880 );
881 }
882
883 if nir.num_images() > 0 || nir.num_textures() > 0 {
884 let count = nir.num_images() + nir.num_textures();
885
886 add_var(
887 nir,
888 &mut lower_state.format_arr_loc,
889 CompiledKernelArgType::FormatArray,
890 unsafe { glsl_array_type(glsl_int16_t_type(), count as u32, 2) },
891 c"image_formats",
892 );
893
894 add_var(
895 nir,
896 &mut lower_state.order_arr_loc,
897 CompiledKernelArgType::OrderArray,
898 unsafe { glsl_array_type(glsl_int16_t_type(), count as u32, 2) },
899 c"image_orders",
900 );
901 }
902
903 if nir.reads_sysval(gl_system_value::SYSTEM_VALUE_WORK_DIM) {
904 add_var(
905 nir,
906 &mut lower_state.work_dim_loc,
907 CompiledKernelArgType::WorkDim,
908 unsafe { glsl_uint8_t_type() },
909 c"work_dim",
910 );
911 }
912
// needs to run after the first opt loop and remove_dead_variables to get rid of unnecessary
// scratch memory
915 nir_pass!(
916 nir,
917 nir_lower_vars_to_explicit_types,
918 nir_variable_mode::nir_var_mem_shared
919 | nir_variable_mode::nir_var_function_temp
920 | nir_variable_mode::nir_var_shader_temp
921 | nir_variable_mode::nir_var_uniform
922 | nir_variable_mode::nir_var_mem_global
923 | nir_variable_mode::nir_var_mem_generic,
924 Some(glsl_get_cl_type_size_align),
925 );
926
927 opt_nir(nir, dev, true);
928 nir_pass!(nir, nir_lower_memcpy);
929
930 // we might have got rid of more function_temp or shared memory
931 nir.reset_scratch_size();
932 nir.reset_shared_size();
933 nir_pass!(
934 nir,
935 nir_remove_dead_variables,
936 nir_variable_mode::nir_var_function_temp | nir_variable_mode::nir_var_mem_shared,
937 &DV_OPTS,
938 );
939 nir_pass!(
940 nir,
941 nir_lower_vars_to_explicit_types,
942 nir_variable_mode::nir_var_function_temp
943 | nir_variable_mode::nir_var_mem_shared
944 | nir_variable_mode::nir_var_mem_generic,
945 Some(glsl_get_cl_type_size_align),
946 );
947
948 nir_pass!(
949 nir,
950 nir_lower_explicit_io,
951 nir_variable_mode::nir_var_mem_global | nir_variable_mode::nir_var_mem_constant,
952 global_address_format,
953 );
954
955 nir_pass!(nir, rusticl_lower_intrinsics, &mut lower_state);
956 nir_pass!(
957 nir,
958 nir_lower_explicit_io,
959 nir_variable_mode::nir_var_mem_shared
960 | nir_variable_mode::nir_var_function_temp
961 | nir_variable_mode::nir_var_uniform,
962 shared_address_format,
963 );
964
965 if nir_options.lower_int64_options.0 != 0 && !nir_options.late_lower_int64 {
966 nir_pass!(nir, nir_lower_int64);
967 }
968
969 if nir_options.lower_uniforms_to_ubo {
970 nir_pass!(nir, rusticl_lower_inputs);
971 }
972
973 nir_pass!(nir, nir_lower_convert_alu_types, None);
974
975 opt_nir(nir, dev, true);
976
977 /* before passing it into drivers, assign locations as drivers might remove nir_variables or
978 * other things we depend on
979 */
980 CompiledKernelArg::assign_locations(compiled_args, nir);
981
982 /* update the has_variable_shared_mem info as we might have DCEed all of them */
983 nir.set_has_variable_shared_mem(compiled_args.iter().any(|arg| {
984 if let CompiledKernelArgType::APIArg(idx) = arg.kind {
985 args[idx as usize].kind == KernelArgType::MemLocal && !arg.dead
986 } else {
987 false
988 }
989 }));
990
991 if Platform::dbg().nir {
992 eprintln!("=== Printing nir variant '{variant}' for '{name}' before driver finalization");
993 nir.print();
994 }
995
996 if dev.screen.finalize_nir(nir) {
997 if Platform::dbg().nir {
998 eprintln!(
999 "=== Printing nir variant '{variant}' for '{name}' after driver finalization"
1000 );
1001 nir.print();
1002 }
1003 }
1004
1005 nir_pass!(nir, nir_opt_dce);
1006 nir.sweep_mem();
1007 }
1008
fn compile_nir_remaining(
1010 dev: &Device,
1011 mut nir: NirShader,
1012 args: &[KernelArg],
1013 name: &str,
1014 ) -> (CompilationResult, Option<CompilationResult>) {
1015 // add all API kernel args
1016 let mut compiled_args: Vec<_> = (0..args.len())
1017 .map(|idx| CompiledKernelArg {
1018 kind: CompiledKernelArgType::APIArg(idx as u32),
1019 offset: 0,
1020 dead: true,
1021 })
1022 .collect();
1023
1024 compile_nir_prepare_for_variants(dev, &mut nir, &mut compiled_args);
1025 if Platform::dbg().nir {
1026 eprintln!("=== Printing nir for '{name}' before specialization");
1027 nir.print();
1028 }
1029
1030 let mut default_build = CompilationResult {
1031 nir: nir,
1032 compiled_args: compiled_args,
1033 };
1034
1035 // check if we even want to compile a variant before cloning the compilation state
1036 let has_wgs_hint = default_build.nir.workgroup_size_variable()
1037 && default_build.nir.workgroup_size_hint() != [0; 3];
1038 let has_offsets = default_build
1039 .nir
1040 .reads_sysval(gl_system_value::SYSTEM_VALUE_GLOBAL_INVOCATION_ID);
1041
1042 let mut optimized = (!Platform::dbg().no_variants && (has_offsets || has_wgs_hint))
1043 .then(|| default_build.clone());
1044
1045 compile_nir_variant(
1046 &mut default_build,
1047 dev,
1048 NirKernelVariant::Default,
1049 args,
1050 name,
1051 );
1052 if let Some(optimized) = &mut optimized {
1053 compile_nir_variant(optimized, dev, NirKernelVariant::Optimized, args, name);
1054 }
1055
1056 (default_build, optimized)
1057 }
1058
1059 pub struct SPIRVToNirResult {
1060 pub kernel_info: KernelInfo,
1061 pub nir_kernel_builds: NirKernelBuilds,
1062 }
1063
1064 impl SPIRVToNirResult {
fn new(
1066 dev: &'static Device,
1067 kernel_info: &clc_kernel_info,
1068 args: Vec<KernelArg>,
1069 default_build: CompilationResult,
1070 optimized: Option<CompilationResult>,
1071 ) -> Self {
1072 // TODO: we _should_ be able to parse them out of the SPIR-V, but clc doesn't handle
1073 // indirections yet.
1074 let nir = &default_build.nir;
1075 let wgs = nir.workgroup_size();
1076 let subgroup_size = nir.subgroup_size();
1077 let num_subgroups = nir.num_subgroups();
1078
1079 let default_build = NirKernelBuild::new(dev, default_build);
1080 let optimized = optimized.map(|opt| NirKernelBuild::new(dev, opt));
1081
1082 let kernel_info = KernelInfo {
1083 args: args,
1084 attributes_string: kernel_info.attribute_str(),
1085 work_group_size: [wgs[0] as usize, wgs[1] as usize, wgs[2] as usize],
1086 work_group_size_hint: kernel_info.local_size_hint,
1087 subgroup_size: subgroup_size as usize,
1088 num_subgroups: num_subgroups as usize,
1089 };
1090
1091 Self {
1092 kernel_info: kernel_info,
1093 nir_kernel_builds: NirKernelBuilds::new(default_build, optimized),
1094 }
1095 }
1096
fn deserialize(bin: &[u8], d: &'static Device, kernel_info: &clc_kernel_info) -> Option<Self> {
1098 let mut reader = blob_reader::default();
1099 unsafe {
1100 blob_reader_init(&mut reader, bin.as_ptr().cast(), bin.len());
1101 }
1102
1103 let args = KernelArg::deserialize(&mut reader)?;
1104 let default_build = CompilationResult::deserialize(&mut reader, d)?;
1105
1106 // SAFETY: on overrun this returns 0
1107 let optimized = match unsafe { blob_read_uint8(&mut reader) } {
1108 0 => None,
1109 _ => Some(CompilationResult::deserialize(&mut reader, d)?),
1110 };
1111
1112 reader
1113 .overrun
1114 .not()
1115 .then(|| SPIRVToNirResult::new(d, kernel_info, args, default_build, optimized))
1116 }
1117
1118 // we can't use Self here as the nir shader might be compiled to a cso already and we can't
1119 // cache that.
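// Cache entry layout written below: the serialized KernelArgs, the default CompilationResult, a
// u8 flag and, if that flag is non-zero, the optimized CompilationResult.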
fn serialize(
1121 blob: &mut blob,
1122 args: &[KernelArg],
1123 default_build: &CompilationResult,
1124 optimized: &Option<CompilationResult>,
1125 ) {
1126 KernelArg::serialize(args, blob);
1127 default_build.serialize(blob);
1128 match optimized {
1129 Some(variant) => {
1130 unsafe { blob_write_uint8(blob, 1) };
1131 variant.serialize(blob);
1132 }
1133 None => unsafe {
1134 blob_write_uint8(blob, 0);
1135 },
1136 }
1137 }
1138 }
1139
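// Converts a kernel's SPIR-V into per-device NIR builds: the disk shader cache is tried first and
// on a miss the SPIR-V is lowered via compile_nir_to_args() and compile_nir_remaining(), the dead
// flags of the compiled args are folded back into the API args and the result is cached.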
pub(super) fn convert_spirv_to_nir(
1141 build: &ProgramBuild,
1142 name: &str,
1143 args: &[spirv::SPIRVKernelArg],
1144 dev: &'static Device,
1145 ) -> SPIRVToNirResult {
1146 let cache = dev.screen().shader_cache();
1147 let key = build.hash_key(dev, name);
1148 let spirv_info = build.spirv_info(name, dev).unwrap();
1149
1150 cache
1151 .as_ref()
1152 .and_then(|cache| cache.get(&mut key?))
1153 .and_then(|entry| SPIRVToNirResult::deserialize(&entry, dev, spirv_info))
1154 .unwrap_or_else(|| {
1155 let nir = build.to_nir(name, dev);
1156
1157 if Platform::dbg().nir {
1158 eprintln!("=== Printing nir for '{name}' after spirv_to_nir");
1159 nir.print();
1160 }
1161
1162 let (mut args, nir) = compile_nir_to_args(dev, nir, args, &dev.lib_clc);
1163 let (default_build, optimized) = compile_nir_remaining(dev, nir, &args, name);
1164
1165 for build in [Some(&default_build), optimized.as_ref()].into_iter() {
1166 let Some(build) = build else {
1167 continue;
1168 };
1169
1170 for arg in &build.compiled_args {
1171 if let CompiledKernelArgType::APIArg(idx) = arg.kind {
1172 args[idx as usize].dead &= arg.dead;
1173 }
1174 }
1175 }
1176
1177 if let Some(cache) = cache {
1178 let mut blob = blob::default();
1179 unsafe {
1180 blob_init(&mut blob);
1181 SPIRVToNirResult::serialize(&mut blob, &args, &default_build, &optimized);
1182 let bin = slice::from_raw_parts(blob.data, blob.size);
1183 cache.put(bin, &mut key.unwrap());
1184 blob_finish(&mut blob);
1185 }
1186 }
1187
1188 SPIRVToNirResult::new(dev, spirv_info, args, default_build, optimized)
1189 })
1190 }
1191
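// Splits the first S bytes off the front of `buf` (advancing the slice) and returns them as a
// fixed-size array reference; used below to read the length header out of the printf buffer.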
fn extract<'a, const S: usize>(buf: &'a mut &[u8]) -> &'a [u8; S] {
1193 let val;
1194 (val, *buf) = (*buf).split_at(S);
// we split off S bytes and convert to [u8; S], so this should be safe
1196 // use split_array_ref once it's stable
1197 val.try_into().unwrap()
1198 }
1199
1200 impl Kernel {
pub fn new(name: String, prog: Arc<Program>, prog_build: &ProgramBuild) -> Arc<Kernel> {
1202 let kernel_info = Arc::clone(prog_build.kernel_info.get(&name).unwrap());
1203 let builds = prog_build
1204 .builds
1205 .iter()
1206 .filter_map(|(&dev, b)| b.kernels.get(&name).map(|k| (dev, Arc::clone(k))))
1207 .collect();
1208
1209 let values = vec![None; kernel_info.args.len()];
1210 Arc::new(Self {
1211 base: CLObjectBase::new(RusticlTypes::Kernel),
1212 prog: prog,
1213 name: name,
1214 values: Mutex::new(values),
1215 builds: builds,
1216 kernel_info: kernel_info,
1217 })
1218 }
1219
pub fn suggest_local_size(
1221 &self,
1222 d: &Device,
1223 work_dim: usize,
1224 grid: &mut [usize],
1225 block: &mut [usize],
1226 ) {
1227 let mut threads = self.max_threads_per_block(d);
1228 let dim_threads = d.max_block_sizes();
1229 let subgroups = self.preferred_simd_size(d);
1230
1231 for i in 0..work_dim {
1232 let t = cmp::min(threads, dim_threads[i]);
1233 let gcd = gcd(t, grid[i]);
1234
1235 block[i] = gcd;
1236 grid[i] /= gcd;
1237
1238 // update limits
1239 threads /= block[i];
1240 }
1241
1242 // if we didn't fill the subgroup we can do a bit better if we have threads remaining
1243 let total_threads = block.iter().take(work_dim).product::<usize>();
1244 if threads != 1 && total_threads < subgroups {
1245 for i in 0..work_dim {
1246 if grid[i] * total_threads < threads && grid[i] * block[i] <= dim_threads[i] {
1247 block[i] *= grid[i];
1248 grid[i] = 1;
1249 // can only do it once as nothing is cleanly divisible
1250 break;
1251 }
1252 }
1253 }
1254 }
1255
fn optimize_local_size(&self, d: &Device, grid: &mut [usize; 3], block: &mut [u32; 3]) {
1257 if !block.contains(&0) {
1258 for i in 0..3 {
1259 // we already made sure everything is fine
1260 grid[i] /= block[i] as usize;
1261 }
1262 return;
1263 }
1264
1265 let mut usize_block = [0usize; 3];
1266 for i in 0..3 {
1267 usize_block[i] = block[i] as usize;
1268 }
1269
1270 self.suggest_local_size(d, 3, grid, &mut usize_block);
1271
1272 for i in 0..3 {
1273 block[i] = usize_block[i] as u32;
1274 }
1275 }
1276
// the painful part is that host threads are allowed to modify the kernel object once it has been
// enqueued, so return a closure with all required data included.
pub fn launch(
1280 self: &Arc<Self>,
1281 q: &Arc<Queue>,
1282 work_dim: u32,
1283 block: &[usize],
1284 grid: &[usize],
1285 offsets: &[usize],
1286 ) -> CLResult<EventSig> {
1287 // Clone all the data we need to execute this kernel
1288 let kernel_info = Arc::clone(&self.kernel_info);
1289 let arg_values = self.arg_values().clone();
1290 let nir_kernel_builds = Arc::clone(&self.builds[q.device]);
1291
1292 let mut buffer_arcs = HashMap::new();
1293 let mut image_arcs = HashMap::new();
1294
1295 // need to preprocess buffer and image arguments so we hold a strong reference until the
// event has been processed.
1297 for arg in arg_values.iter() {
1298 match arg {
1299 Some(KernelArgValue::Buffer(buffer)) => {
1300 buffer_arcs.insert(
1301 // we use the ptr as the key, and also cast it to usize so we don't need to
1302 // deal with Send + Sync here.
1303 buffer.as_ptr() as usize,
1304 buffer.upgrade().ok_or(CL_INVALID_KERNEL_ARGS)?,
1305 );
1306 }
1307 Some(KernelArgValue::Image(image)) => {
1308 image_arcs.insert(
1309 image.as_ptr() as usize,
1310 image.upgrade().ok_or(CL_INVALID_KERNEL_ARGS)?,
1311 );
1312 }
1313 _ => {}
1314 }
1315 }
1316
// operations for which we want to report errors back to the client
1318 let mut block = create_kernel_arr::<u32>(block, 1)?;
1319 let mut grid = create_kernel_arr::<usize>(grid, 1)?;
1320 let offsets = create_kernel_arr::<usize>(offsets, 0)?;
1321
1322 let api_grid = grid;
1323
1324 self.optimize_local_size(q.device, &mut grid, &mut block);
1325
1326 Ok(Box::new(move |q, ctx| {
1327 let hw_max_grid: Vec<usize> = q
1328 .device
1329 .max_grid_size()
1330 .into_iter()
1331 .map(|val| val.try_into().unwrap_or(usize::MAX))
1332 // clamped as pipe_launch_grid::grid is only u32
1333 .map(|val| cmp::min(val, u32::MAX as usize))
1334 .collect();
1335
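// Only use the optimized variant when this launch matches its assumptions: no global work
// offsets, a grid that fits the HW limits (so no per-chunk workgroup offsets are needed) and a
// block matching the work group size hint, if one was given.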
1336 let variant = if offsets == [0; 3]
1337 && grid[0] <= hw_max_grid[0]
1338 && grid[1] <= hw_max_grid[1]
1339 && grid[2] <= hw_max_grid[2]
1340 && (kernel_info.work_group_size_hint == [0; 3]
1341 || block == kernel_info.work_group_size_hint)
1342 {
1343 NirKernelVariant::Optimized
1344 } else {
1345 NirKernelVariant::Default
1346 };
1347
1348 let nir_kernel_build = &nir_kernel_builds[variant];
1349 let mut workgroup_id_offset_loc = None;
1350 let mut input = Vec::new();
1351 // Set it once so we get the alignment padding right
1352 let static_local_size: u64 = nir_kernel_build.shared_size;
1353 let mut variable_local_size: u64 = static_local_size;
1354 let printf_size = q.device.printf_buffer_size() as u32;
1355 let mut samplers = Vec::new();
1356 let mut iviews = Vec::new();
1357 let mut sviews = Vec::new();
1358 let mut tex_formats: Vec<u16> = Vec::new();
1359 let mut tex_orders: Vec<u16> = Vec::new();
1360 let mut img_formats: Vec<u16> = Vec::new();
1361 let mut img_orders: Vec<u16> = Vec::new();
1362
1363 let null_ptr;
1364 let null_ptr_v3;
1365 if q.device.address_bits() == 64 {
1366 null_ptr = [0u8; 8].as_slice();
1367 null_ptr_v3 = [0u8; 24].as_slice();
1368 } else {
1369 null_ptr = [0u8; 4].as_slice();
1370 null_ptr_v3 = [0u8; 12].as_slice();
1371 };
1372
1373 let mut resource_info = Vec::new();
1374 fn add_global<'a>(
1375 q: &Queue,
1376 input: &mut Vec<u8>,
1377 resource_info: &mut Vec<(&'a PipeResource, usize)>,
1378 res: &'a PipeResource,
1379 offset: usize,
1380 ) {
1381 resource_info.push((res, input.len()));
1382 if q.device.address_bits() == 64 {
1383 let offset: u64 = offset as u64;
1384 input.extend_from_slice(&offset.to_ne_bytes());
1385 } else {
1386 let offset: u32 = offset as u32;
1387 input.extend_from_slice(&offset.to_ne_bytes());
1388 }
1389 }
1390
1391 fn add_sysval(q: &Queue, input: &mut Vec<u8>, vals: &[usize; 3]) {
1392 if q.device.address_bits() == 64 {
1393 input.extend_from_slice(unsafe { as_byte_slice(&vals.map(|v| v as u64)) });
1394 } else {
1395 input.extend_from_slice(unsafe { as_byte_slice(&vals.map(|v| v as u32)) });
1396 }
1397 }
1398
1399 let mut printf_buf = None;
1400 if nir_kernel_build.printf_info.is_some() {
1401 let buf = q
1402 .device
1403 .screen
1404 .resource_create_buffer(printf_size, ResourceType::Staging, PIPE_BIND_GLOBAL)
1405 .unwrap();
1406
1407 let init_data: [u8; 1] = [4];
1408 ctx.buffer_subdata(&buf, 0, init_data.as_ptr().cast(), init_data.len() as u32);
1409
1410 printf_buf = Some(buf);
1411 }
1412
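// Walk the compiled args to build the flat input buffer: for non-opaque args `offset` is a byte
// offset into `input`, while images and samplers use it as a binding index and get collected into
// the view/sampler vectors instead.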
1413 for arg in &nir_kernel_build.compiled_args {
1414 let is_opaque = if let CompiledKernelArgType::APIArg(idx) = arg.kind {
1415 kernel_info.args[idx as usize].kind.is_opaque()
1416 } else {
1417 false
1418 };
1419
1420 if !is_opaque && arg.offset as usize > input.len() {
1421 input.resize(arg.offset as usize, 0);
1422 }
1423
1424 match arg.kind {
1425 CompiledKernelArgType::APIArg(idx) => {
1426 let api_arg = &kernel_info.args[idx as usize];
1427 if api_arg.dead {
1428 continue;
1429 }
1430
1431 let Some(value) = &arg_values[idx as usize] else {
1432 continue;
1433 };
1434
1435 match value {
1436 KernelArgValue::Constant(c) => input.extend_from_slice(c),
1437 KernelArgValue::Buffer(buffer) => {
1438 let buffer = &buffer_arcs[&(buffer.as_ptr() as usize)];
1439 let rw = if api_arg.spirv.address_qualifier
1440 == clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_CONSTANT
1441 {
1442 RWFlags::RD
1443 } else {
1444 RWFlags::RW
1445 };
1446
1447 let res = buffer.get_res_for_access(ctx, rw)?;
1448 add_global(q, &mut input, &mut resource_info, res, buffer.offset());
1449 }
1450 KernelArgValue::Image(image) => {
1451 let image = &image_arcs[&(image.as_ptr() as usize)];
1452 let (formats, orders) = if api_arg.kind == KernelArgType::Image {
1453 iviews.push(image.image_view(ctx, false)?);
1454 (&mut img_formats, &mut img_orders)
1455 } else if api_arg.kind == KernelArgType::RWImage {
1456 iviews.push(image.image_view(ctx, true)?);
1457 (&mut img_formats, &mut img_orders)
1458 } else {
1459 sviews.push(image.sampler_view(ctx)?);
1460 (&mut tex_formats, &mut tex_orders)
1461 };
1462
1463 let binding = arg.offset as usize;
1464 assert!(binding >= formats.len());
1465
1466 formats.resize(binding, 0);
1467 orders.resize(binding, 0);
1468
1469 formats.push(image.image_format.image_channel_data_type as u16);
1470 orders.push(image.image_format.image_channel_order as u16);
1471 }
1472 KernelArgValue::LocalMem(size) => {
1473 // TODO 32 bit
1474 let pot = cmp::min(*size, 0x80);
1475 variable_local_size = variable_local_size
1476 .next_multiple_of(pot.next_power_of_two() as u64);
1477 if q.device.address_bits() == 64 {
1478 let variable_local_size: [u8; 8] =
1479 variable_local_size.to_ne_bytes();
1480 input.extend_from_slice(&variable_local_size);
1481 } else {
1482 let variable_local_size: [u8; 4] =
1483 (variable_local_size as u32).to_ne_bytes();
1484 input.extend_from_slice(&variable_local_size);
1485 }
1486 variable_local_size += *size as u64;
1487 }
1488 KernelArgValue::Sampler(sampler) => {
1489 samplers.push(sampler.pipe());
1490 }
1491 KernelArgValue::None => {
1492 assert!(
1493 api_arg.kind == KernelArgType::MemGlobal
1494 || api_arg.kind == KernelArgType::MemConstant
1495 );
1496 input.extend_from_slice(null_ptr);
1497 }
1498 }
1499 }
1500 CompiledKernelArgType::ConstantBuffer => {
1501 assert!(nir_kernel_build.constant_buffer.is_some());
1502 let res = nir_kernel_build.constant_buffer.as_ref().unwrap();
1503 add_global(q, &mut input, &mut resource_info, res, 0);
1504 }
1505 CompiledKernelArgType::GlobalWorkOffsets => {
1506 add_sysval(q, &mut input, &offsets);
1507 }
1508 CompiledKernelArgType::WorkGroupOffsets => {
1509 workgroup_id_offset_loc = Some(input.len());
1510 input.extend_from_slice(null_ptr_v3);
1511 }
1512 CompiledKernelArgType::GlobalWorkSize => {
1513 add_sysval(q, &mut input, &api_grid);
1514 }
1515 CompiledKernelArgType::PrintfBuffer => {
1516 let res = printf_buf.as_ref().unwrap();
1517 add_global(q, &mut input, &mut resource_info, res, 0);
1518 }
1519 CompiledKernelArgType::InlineSampler(cl) => {
1520 samplers.push(Sampler::cl_to_pipe(cl));
1521 }
1522 CompiledKernelArgType::FormatArray => {
1523 input.extend_from_slice(unsafe { as_byte_slice(&tex_formats) });
1524 input.extend_from_slice(unsafe { as_byte_slice(&img_formats) });
1525 }
1526 CompiledKernelArgType::OrderArray => {
1527 input.extend_from_slice(unsafe { as_byte_slice(&tex_orders) });
1528 input.extend_from_slice(unsafe { as_byte_slice(&img_orders) });
1529 }
1530 CompiledKernelArgType::WorkDim => {
1531 input.extend_from_slice(&[work_dim as u8; 1]);
1532 }
1533 CompiledKernelArgType::NumWorkgroups => {
1534 input.extend_from_slice(unsafe {
1535 as_byte_slice(&[grid[0] as u32, grid[1] as u32, grid[2] as u32])
1536 });
1537 }
1538 }
1539 }
1540
1541 // subtract the shader local_size as we only request something on top of that.
1542 variable_local_size -= static_local_size;
1543
1544 let samplers: Vec<_> = samplers
1545 .iter()
1546 .map(|s| ctx.create_sampler_state(s))
1547 .collect();
1548
1549 let mut resources = Vec::with_capacity(resource_info.len());
1550 let mut globals: Vec<*mut u32> = Vec::with_capacity(resource_info.len());
1551 for (res, offset) in resource_info {
1552 resources.push(res);
1553 globals.push(unsafe { input.as_mut_ptr().byte_add(offset) }.cast());
1554 }
1555
1556 let temp_cso;
1557 let cso = match &nir_kernel_build.nir_or_cso {
1558 KernelDevStateVariant::Cso(cso) => cso,
1559 KernelDevStateVariant::Nir(nir) => {
1560 temp_cso = CSOWrapper::new(q.device, nir);
1561 &temp_cso
1562 }
1563 };
1564
1565 let sviews_len = sviews.len();
1566 ctx.bind_compute_state(cso.cso_ptr);
1567 ctx.bind_sampler_states(&samplers);
1568 ctx.set_sampler_views(sviews);
1569 ctx.set_shader_images(&iviews);
1570 ctx.set_global_binding(resources.as_slice(), &mut globals);
1571
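// Grids larger than the HW limit get split into multiple launches; each chunk patches the
// workgroup id offsets (if used) into the input buffer and clamps the remaining grid size.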
1572 for z in 0..grid[2].div_ceil(hw_max_grid[2]) {
1573 for y in 0..grid[1].div_ceil(hw_max_grid[1]) {
1574 for x in 0..grid[0].div_ceil(hw_max_grid[0]) {
1575 if let Some(workgroup_id_offset_loc) = workgroup_id_offset_loc {
1576 let this_offsets =
1577 [x * hw_max_grid[0], y * hw_max_grid[1], z * hw_max_grid[2]];
1578
1579 if q.device.address_bits() == 64 {
1580 let val = this_offsets.map(|v| v as u64);
1581 input[workgroup_id_offset_loc..workgroup_id_offset_loc + 24]
1582 .copy_from_slice(unsafe { as_byte_slice(&val) });
1583 } else {
1584 let val = this_offsets.map(|v| v as u32);
1585 input[workgroup_id_offset_loc..workgroup_id_offset_loc + 12]
1586 .copy_from_slice(unsafe { as_byte_slice(&val) });
1587 }
1588 }
1589
1590 let this_grid = [
1591 cmp::min(hw_max_grid[0], grid[0] - hw_max_grid[0] * x) as u32,
1592 cmp::min(hw_max_grid[1], grid[1] - hw_max_grid[1] * y) as u32,
1593 cmp::min(hw_max_grid[2], grid[2] - hw_max_grid[2] * z) as u32,
1594 ];
1595
1596 ctx.update_cb0(&input)?;
1597 ctx.launch_grid(work_dim, block, this_grid, variable_local_size as u32);
1598
1599 if Platform::dbg().sync_every_event {
1600 ctx.flush().wait();
1601 }
1602 }
1603 }
1604 }
1605
1606 ctx.clear_global_binding(globals.len() as u32);
1607 ctx.clear_shader_images(iviews.len() as u32);
1608 ctx.clear_sampler_views(sviews_len as u32);
1609 ctx.clear_sampler_states(samplers.len() as u32);
1610
1611 ctx.bind_compute_state(ptr::null_mut());
1612
1613 ctx.memory_barrier(PIPE_BARRIER_GLOBAL_BUFFER);
1614
1615 samplers.iter().for_each(|s| ctx.delete_sampler_state(*s));
1616
1617 if let Some(printf_buf) = &printf_buf {
1618 let tx = ctx
1619 .buffer_map(printf_buf, 0, printf_size as i32, RWFlags::RD)
1620 .ok_or(CL_OUT_OF_RESOURCES)?;
1621 let mut buf: &[u8] =
1622 unsafe { slice::from_raw_parts(tx.ptr().cast(), printf_size as usize) };
1623 let length = u32::from_ne_bytes(*extract(&mut buf));
1624
1625 // update our slice to make sure we don't go out of bounds
1626 buf = &buf[0..(length - 4) as usize];
1627 if let Some(pf) = &nir_kernel_build.printf_info {
1628 pf.u_printf(buf)
1629 }
1630 }
1631
1632 Ok(())
1633 }))
1634 }
1635
pub fn arg_values(&self) -> MutexGuard<Vec<Option<KernelArgValue>>> {
1637 self.values.lock().unwrap()
1638 }
1639
pub fn set_kernel_arg(&self, idx: usize, arg: KernelArgValue) -> CLResult<()> {
1641 self.values
1642 .lock()
1643 .unwrap()
1644 .get_mut(idx)
1645 .ok_or(CL_INVALID_ARG_INDEX)?
1646 .replace(arg);
1647 Ok(())
1648 }
1649
pub fn access_qualifier(&self, idx: cl_uint) -> cl_kernel_arg_access_qualifier {
1651 let aq = self.kernel_info.args[idx as usize].spirv.access_qualifier;
1652
1653 if aq
1654 == clc_kernel_arg_access_qualifier::CLC_KERNEL_ARG_ACCESS_READ
1655 | clc_kernel_arg_access_qualifier::CLC_KERNEL_ARG_ACCESS_WRITE
1656 {
1657 CL_KERNEL_ARG_ACCESS_READ_WRITE
1658 } else if aq == clc_kernel_arg_access_qualifier::CLC_KERNEL_ARG_ACCESS_READ {
1659 CL_KERNEL_ARG_ACCESS_READ_ONLY
1660 } else if aq == clc_kernel_arg_access_qualifier::CLC_KERNEL_ARG_ACCESS_WRITE {
1661 CL_KERNEL_ARG_ACCESS_WRITE_ONLY
1662 } else {
1663 CL_KERNEL_ARG_ACCESS_NONE
1664 }
1665 }
1666
pub fn address_qualifier(&self, idx: cl_uint) -> cl_kernel_arg_address_qualifier {
1668 match self.kernel_info.args[idx as usize].spirv.address_qualifier {
1669 clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_PRIVATE => {
1670 CL_KERNEL_ARG_ADDRESS_PRIVATE
1671 }
1672 clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_CONSTANT => {
1673 CL_KERNEL_ARG_ADDRESS_CONSTANT
1674 }
1675 clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_LOCAL => {
1676 CL_KERNEL_ARG_ADDRESS_LOCAL
1677 }
1678 clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_GLOBAL => {
1679 CL_KERNEL_ARG_ADDRESS_GLOBAL
1680 }
1681 }
1682 }
1683
pub fn type_qualifier(&self, idx: cl_uint) -> cl_kernel_arg_type_qualifier {
1685 let tq = self.kernel_info.args[idx as usize].spirv.type_qualifier;
1686 let zero = clc_kernel_arg_type_qualifier(0);
1687 let mut res = CL_KERNEL_ARG_TYPE_NONE;
1688
1689 if tq & clc_kernel_arg_type_qualifier::CLC_KERNEL_ARG_TYPE_CONST != zero {
1690 res |= CL_KERNEL_ARG_TYPE_CONST;
1691 }
1692
1693 if tq & clc_kernel_arg_type_qualifier::CLC_KERNEL_ARG_TYPE_RESTRICT != zero {
1694 res |= CL_KERNEL_ARG_TYPE_RESTRICT;
1695 }
1696
1697 if tq & clc_kernel_arg_type_qualifier::CLC_KERNEL_ARG_TYPE_VOLATILE != zero {
1698 res |= CL_KERNEL_ARG_TYPE_VOLATILE;
1699 }
1700
1701 res.into()
1702 }
1703
pub fn work_group_size(&self) -> [usize; 3] {
1705 self.kernel_info.work_group_size
1706 }
1707
pub fn num_subgroups(&self) -> usize {
1709 self.kernel_info.num_subgroups
1710 }
1711
pub fn subgroup_size(&self) -> usize {
1713 self.kernel_info.subgroup_size
1714 }
1715
pub fn arg_name(&self, idx: cl_uint) -> Option<&CStr> {
1717 let name = &self.kernel_info.args[idx as usize].spirv.name;
1718 name.is_empty().not().then_some(name)
1719 }
1720
pub fn arg_type_name(&self, idx: cl_uint) -> Option<&CStr> {
1722 let type_name = &self.kernel_info.args[idx as usize].spirv.type_name;
1723 type_name.is_empty().not().then_some(type_name)
1724 }
1725
pub fn priv_mem_size(&self, dev: &Device) -> cl_ulong {
1727 self.builds.get(dev).unwrap().info.private_memory as cl_ulong
1728 }
1729
pub fn max_threads_per_block(&self, dev: &Device) -> usize {
1731 self.builds.get(dev).unwrap().info.max_threads as usize
1732 }
1733
pub fn preferred_simd_size(&self, dev: &Device) -> usize {
1735 self.builds.get(dev).unwrap().info.preferred_simd_size as usize
1736 }
1737
pub fn local_mem_size(&self, dev: &Device) -> cl_ulong {
1739 // TODO: take alignment into account?
1740 // this is purely informational so it shouldn't even matter
1741 let local =
1742 self.builds.get(dev).unwrap()[NirKernelVariant::Default].shared_size as cl_ulong;
1743 let args: cl_ulong = self
1744 .arg_values()
1745 .iter()
1746 .map(|arg| match arg {
1747 Some(KernelArgValue::LocalMem(val)) => *val as cl_ulong,
1748 // If the local memory size, for any pointer argument to the kernel declared with
1749 // the __local address qualifier, is not specified, its size is assumed to be 0.
1750 _ => 0,
1751 })
1752 .sum();
1753
1754 local + args
1755 }
1756
pub fn has_svm_devs(&self) -> bool {
1758 self.prog.devs.iter().any(|dev| dev.svm_supported())
1759 }
1760
pub fn subgroup_sizes(&self, dev: &Device) -> Vec<usize> {
1762 SetBitIndices::from_msb(self.builds.get(dev).unwrap().info.simd_sizes)
1763 .map(|bit| 1 << bit)
1764 .collect()
1765 }
1766
pub fn subgroups_for_block(&self, dev: &Device, block: &[usize]) -> usize {
1768 let subgroup_size = self.subgroup_size_for_block(dev, block);
1769 if subgroup_size == 0 {
1770 return 0;
1771 }
1772
1773 let threads: usize = block.iter().product();
1774 threads.div_ceil(subgroup_size)
1775 }
1776
pub fn subgroup_size_for_block(&self, dev: &Device, block: &[usize]) -> usize {
1778 let subgroup_sizes = self.subgroup_sizes(dev);
1779 if subgroup_sizes.is_empty() {
1780 return 0;
1781 }
1782
1783 if subgroup_sizes.len() == 1 {
1784 return subgroup_sizes[0];
1785 }
1786
1787 let block = [
1788 *block.first().unwrap_or(&1) as u32,
1789 *block.get(1).unwrap_or(&1) as u32,
1790 *block.get(2).unwrap_or(&1) as u32,
1791 ];
1792
1793 // TODO: this _might_ bite us somewhere, but I think it probably doesn't matter
1794 match &self.builds.get(dev).unwrap()[NirKernelVariant::Default].nir_or_cso {
1795 KernelDevStateVariant::Cso(cso) => {
1796 dev.helper_ctx()
1797 .compute_state_subgroup_size(cso.cso_ptr, &block) as usize
1798 }
1799 _ => {
1800 panic!()
1801 }
1802 }
1803 }
1804 }
1805
1806 impl Clone for Kernel {
fn clone(&self) -> Self {
1808 Self {
1809 base: CLObjectBase::new(RusticlTypes::Kernel),
1810 prog: Arc::clone(&self.prog),
1811 name: self.name.clone(),
1812 values: Mutex::new(self.arg_values().clone()),
1813 builds: self.builds.clone(),
1814 kernel_info: Arc::clone(&self.kernel_info),
1815 }
1816 }
1817 }
1818