1 use crate::api::icd::*;
2 use crate::core::device::*;
3 use crate::core::event::*;
4 use crate::core::memory::*;
5 use crate::core::platform::*;
6 use crate::core::program::*;
7 use crate::core::queue::*;
8 use crate::impl_cl_type_trait;
9
10 use mesa_rust::compiler::clc::*;
11 use mesa_rust::compiler::nir::*;
12 use mesa_rust::nir_pass;
13 use mesa_rust::pipe::context::RWFlags;
14 use mesa_rust::pipe::resource::*;
15 use mesa_rust::pipe::screen::ResourceType;
16 use mesa_rust_gen::*;
17 use mesa_rust_util::math::*;
18 use mesa_rust_util::serialize::*;
19 use rusticl_opencl_gen::*;
20 use spirv::SpirvKernelInfo;
21
22 use std::cmp;
23 use std::collections::HashMap;
24 use std::convert::TryInto;
25 use std::ffi::CStr;
26 use std::fmt::Debug;
27 use std::fmt::Display;
28 use std::ops::Index;
29 use std::ops::Not;
30 use std::os::raw::c_void;
31 use std::ptr;
32 use std::slice;
33 use std::sync::Arc;
34 use std::sync::Mutex;
35 use std::sync::MutexGuard;
36 use std::sync::Weak;
37
// According to the CL spec we are not allowed to let any cl_kernel object hold references to
// its arguments, as this might make it infeasible for applications to free the backing memory
// of memory objects allocated with `CL_MEM_USE_HOST_PTR`.
//
// However, those arguments might temporarily get referenced by event objects, so we use Weak in
// order to upgrade the reference when needed. It's also safer to use Weak over raw pointers,
// because it makes it impossible to run into use-after-free issues.
//
// Technically we'd also need to do this for samplers, but there it's rather pointless to take a
// weak reference, as samplers don't have a host_ptr or any similar problem cl_mem objects have.
48 #[derive(Clone)]
49 pub enum KernelArgValue {
50 None,
51 Buffer(Weak<Buffer>),
52 Constant(Vec<u8>),
53 Image(Weak<Image>),
54 LocalMem(usize),
55 Sampler(Arc<Sampler>),
56 }
57
58 #[repr(u8)]
59 #[derive(Hash, PartialEq, Eq, Clone, Copy)]
60 pub enum KernelArgType {
61 Constant(/* size */ u16), // for anything passed by value
62 Image,
63 RWImage,
64 Sampler,
65 Texture,
66 MemGlobal,
67 MemConstant,
68 MemLocal,
69 }
70
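// Serialization format used for the shader cache (matching serialize()/deserialize() below):
// a single u8 tag, followed by a u16 payload only for Constant. The blob_read_* helpers return
// 0 on overrun, which is why deserialize() re-checks blob.overrun before trusting the result.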
71 impl KernelArgType {
    fn deserialize(blob: &mut blob_reader) -> Option<Self> {
73 // SAFETY: we get 0 on an overrun, but we verify that later and act accordingly.
74 let res = match unsafe { blob_read_uint8(blob) } {
75 0 => {
76 // SAFETY: same here
77 let size = unsafe { blob_read_uint16(blob) };
78 KernelArgType::Constant(size)
79 }
80 1 => KernelArgType::Image,
81 2 => KernelArgType::RWImage,
82 3 => KernelArgType::Sampler,
83 4 => KernelArgType::Texture,
84 5 => KernelArgType::MemGlobal,
85 6 => KernelArgType::MemConstant,
86 7 => KernelArgType::MemLocal,
87 _ => return None,
88 };
89
90 blob.overrun.not().then_some(res)
91 }
92
    fn serialize(&self, blob: &mut blob) {
94 unsafe {
95 match self {
96 KernelArgType::Constant(size) => {
97 blob_write_uint8(blob, 0);
98 blob_write_uint16(blob, *size)
99 }
100 KernelArgType::Image => blob_write_uint8(blob, 1),
101 KernelArgType::RWImage => blob_write_uint8(blob, 2),
102 KernelArgType::Sampler => blob_write_uint8(blob, 3),
103 KernelArgType::Texture => blob_write_uint8(blob, 4),
104 KernelArgType::MemGlobal => blob_write_uint8(blob, 5),
105 KernelArgType::MemConstant => blob_write_uint8(blob, 6),
106 KernelArgType::MemLocal => blob_write_uint8(blob, 7),
107 };
108 }
109 }
110
    fn is_opaque(&self) -> bool {
112 matches!(
113 self,
114 KernelArgType::Image
115 | KernelArgType::RWImage
116 | KernelArgType::Texture
117 | KernelArgType::Sampler
118 )
119 }
120 }
121
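// While KernelArgType above describes the API-visible kernel arguments, CompiledKernelArgType
// describes every input the compiled NIR actually consumes: the user's arguments (referenced by
// index via APIArg) plus internal arguments rusticl appends itself, such as the printf buffer,
// the constant buffer, work offsets and the image format/order arrays.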
122 #[derive(Hash, PartialEq, Eq, Clone)]
123 enum CompiledKernelArgType {
124 APIArg(u32),
125 ConstantBuffer,
126 GlobalWorkOffsets,
127 GlobalWorkSize,
128 PrintfBuffer,
129 InlineSampler((cl_addressing_mode, cl_filter_mode, bool)),
130 FormatArray,
131 OrderArray,
132 WorkDim,
133 WorkGroupOffsets,
134 NumWorkgroups,
135 }
136
137 #[derive(Hash, PartialEq, Eq, Clone)]
138 pub struct KernelArg {
139 spirv: spirv::SPIRVKernelArg,
140 pub kind: KernelArgType,
141 pub dead: bool,
142 }
143
144 impl KernelArg {
    fn from_spirv_nir(spirv: &[spirv::SPIRVKernelArg], nir: &mut NirShader) -> Vec<Self> {
146 let nir_arg_map: HashMap<_, _> = nir
147 .variables_with_mode(
148 nir_variable_mode::nir_var_uniform | nir_variable_mode::nir_var_image,
149 )
150 .map(|v| (v.data.location, v))
151 .collect();
152 let mut res = Vec::new();
153
154 for (i, s) in spirv.iter().enumerate() {
155 let nir = nir_arg_map.get(&(i as i32)).unwrap();
156 let kind = match s.address_qualifier {
157 clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_PRIVATE => {
158 if unsafe { glsl_type_is_sampler(nir.type_) } {
159 KernelArgType::Sampler
160 } else {
161 let size = unsafe { glsl_get_cl_size(nir.type_) } as u16;
                        // NIR types of non-opaque arguments never have a size of 0
163 KernelArgType::Constant(size)
164 }
165 }
166 clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_CONSTANT => {
167 KernelArgType::MemConstant
168 }
169 clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_LOCAL => {
170 KernelArgType::MemLocal
171 }
172 clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_GLOBAL => {
173 if unsafe { glsl_type_is_image(nir.type_) } {
174 let access = nir.data.access();
175 if access == gl_access_qualifier::ACCESS_NON_WRITEABLE.0 {
176 KernelArgType::Texture
177 } else if access == gl_access_qualifier::ACCESS_NON_READABLE.0 {
178 KernelArgType::Image
179 } else {
180 KernelArgType::RWImage
181 }
182 } else {
183 KernelArgType::MemGlobal
184 }
185 }
186 };
187
188 res.push(Self {
189 spirv: s.clone(),
190 // we'll update it later in the 2nd pass
191 kind: kind,
192 dead: true,
193 });
194 }
195 res
196 }
197
    fn serialize(args: &[Self], blob: &mut blob) {
199 unsafe {
200 blob_write_uint16(blob, args.len() as u16);
201
202 for arg in args {
203 arg.spirv.serialize(blob);
204 blob_write_uint8(blob, arg.dead.into());
205 arg.kind.serialize(blob);
206 }
207 }
208 }
209
    fn deserialize(blob: &mut blob_reader) -> Option<Vec<Self>> {
211 // SAFETY: we check the overrun status, blob_read returns 0 in such a case.
212 let len = unsafe { blob_read_uint16(blob) } as usize;
213 let mut res = Vec::with_capacity(len);
214
215 for _ in 0..len {
216 let spirv = spirv::SPIRVKernelArg::deserialize(blob)?;
217 // SAFETY: we check the overrun status
218 let dead = unsafe { blob_read_uint8(blob) } != 0;
219 let kind = KernelArgType::deserialize(blob)?;
220
221 res.push(Self {
222 spirv: spirv,
223 kind: kind,
224 dead: dead,
225 });
226 }
227
228 blob.overrun.not().then_some(res)
229 }
230 }
231
232 #[derive(Hash, PartialEq, Eq, Clone)]
233 struct CompiledKernelArg {
234 kind: CompiledKernelArgType,
    /// The binding for image/sampler args, or the offset into the input buffer
    /// for anything else.
237 offset: u32,
238 dead: bool,
239 }
240
241 impl CompiledKernelArg {
    fn assign_locations(compiled_args: &mut [Self], nir: &mut NirShader) {
243 for var in nir.variables_with_mode(
244 nir_variable_mode::nir_var_uniform | nir_variable_mode::nir_var_image,
245 ) {
246 let arg = &mut compiled_args[var.data.location as usize];
247 let t = var.type_;
248
249 arg.dead = false;
250 arg.offset = if unsafe {
251 glsl_type_is_image(t) || glsl_type_is_texture(t) || glsl_type_is_sampler(t)
252 } {
253 var.data.binding
254 } else {
255 var.data.driver_location
256 };
257 }
258 }
259
    fn serialize(args: &[Self], blob: &mut blob) {
261 unsafe {
262 blob_write_uint16(blob, args.len() as u16);
263 for arg in args {
264 blob_write_uint32(blob, arg.offset);
265 blob_write_uint8(blob, arg.dead.into());
266 match arg.kind {
267 CompiledKernelArgType::ConstantBuffer => blob_write_uint8(blob, 0),
268 CompiledKernelArgType::GlobalWorkOffsets => blob_write_uint8(blob, 1),
269 CompiledKernelArgType::PrintfBuffer => blob_write_uint8(blob, 2),
270 CompiledKernelArgType::InlineSampler((addr_mode, filter_mode, norm)) => {
271 blob_write_uint8(blob, 3);
272 blob_write_uint8(blob, norm.into());
273 blob_write_uint32(blob, addr_mode);
274 blob_write_uint32(blob, filter_mode)
275 }
276 CompiledKernelArgType::FormatArray => blob_write_uint8(blob, 4),
277 CompiledKernelArgType::OrderArray => blob_write_uint8(blob, 5),
278 CompiledKernelArgType::WorkDim => blob_write_uint8(blob, 6),
279 CompiledKernelArgType::WorkGroupOffsets => blob_write_uint8(blob, 7),
280 CompiledKernelArgType::NumWorkgroups => blob_write_uint8(blob, 8),
281 CompiledKernelArgType::GlobalWorkSize => blob_write_uint8(blob, 9),
282 CompiledKernelArgType::APIArg(idx) => {
283 blob_write_uint8(blob, 10);
284 blob_write_uint32(blob, idx)
285 }
286 };
287 }
288 }
289 }
290
    fn deserialize(blob: &mut blob_reader) -> Option<Vec<Self>> {
292 unsafe {
293 let len = blob_read_uint16(blob) as usize;
294 let mut res = Vec::with_capacity(len);
295
296 for _ in 0..len {
297 let offset = blob_read_uint32(blob);
298 let dead = blob_read_uint8(blob) != 0;
299
300 let kind = match blob_read_uint8(blob) {
301 0 => CompiledKernelArgType::ConstantBuffer,
302 1 => CompiledKernelArgType::GlobalWorkOffsets,
303 2 => CompiledKernelArgType::PrintfBuffer,
304 3 => {
305 let norm = blob_read_uint8(blob) != 0;
306 let addr_mode = blob_read_uint32(blob);
307 let filter_mode = blob_read_uint32(blob);
308 CompiledKernelArgType::InlineSampler((addr_mode, filter_mode, norm))
309 }
310 4 => CompiledKernelArgType::FormatArray,
311 5 => CompiledKernelArgType::OrderArray,
312 6 => CompiledKernelArgType::WorkDim,
313 7 => CompiledKernelArgType::WorkGroupOffsets,
314 8 => CompiledKernelArgType::NumWorkgroups,
315 9 => CompiledKernelArgType::GlobalWorkSize,
316 10 => {
317 let idx = blob_read_uint32(blob);
318 CompiledKernelArgType::APIArg(idx)
319 }
320 _ => return None,
321 };
322
323 res.push(Self {
324 kind: kind,
325 offset: offset,
326 dead: dead,
327 });
328 }
329
330 Some(res)
331 }
332 }
333 }
334
335 #[derive(Clone, PartialEq, Eq, Hash)]
336 pub struct KernelInfo {
337 pub args: Vec<KernelArg>,
338 pub attributes_string: String,
339 work_group_size: [usize; 3],
340 work_group_size_hint: [u32; 3],
341 subgroup_size: usize,
342 num_subgroups: usize,
343 }
344
345 struct CSOWrapper {
346 cso_ptr: *mut c_void,
347 dev: &'static Device,
348 }
349
350 impl CSOWrapper {
    fn new(dev: &'static Device, nir: &NirShader) -> Self {
352 let cso_ptr = dev
353 .helper_ctx()
354 .create_compute_state(nir, nir.shared_size());
355
356 Self {
357 cso_ptr: cso_ptr,
358 dev: dev,
359 }
360 }
361
    fn get_cso_info(&self) -> pipe_compute_state_object_info {
363 self.dev.helper_ctx().compute_state_info(self.cso_ptr)
364 }
365 }
366
367 impl Drop for CSOWrapper {
    fn drop(&mut self) {
369 self.dev.helper_ctx().delete_compute_state(self.cso_ptr);
370 }
371 }
372
373 enum KernelDevStateVariant {
374 Cso(CSOWrapper),
375 Nir(NirShader),
376 }
377
378 #[derive(Debug, PartialEq)]
379 enum NirKernelVariant {
380 /// Can be used under any circumstance.
381 Default,
382
383 /// Optimized variant making the following assumptions:
384 /// - global_id_offsets are 0
385 /// - workgroup_offsets are 0
386 /// - local_size is info.local_size_hint
387 Optimized,
388 }
389
390 impl Display for NirKernelVariant {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
392 // this simply prints the enum name, so that's fine
393 Debug::fmt(self, f)
394 }
395 }
396
397 pub struct NirKernelBuilds {
398 default_build: NirKernelBuild,
399 optimized: Option<NirKernelBuild>,
400 /// merged info with worst case values
401 info: pipe_compute_state_object_info,
402 }
403
404 impl Index<NirKernelVariant> for NirKernelBuilds {
405 type Output = NirKernelBuild;
406
    fn index(&self, index: NirKernelVariant) -> &Self::Output {
408 match index {
409 NirKernelVariant::Default => &self.default_build,
410 NirKernelVariant::Optimized => self.optimized.as_ref().unwrap_or(&self.default_build),
411 }
412 }
413 }
414
415 impl NirKernelBuilds {
    fn new(default_build: NirKernelBuild, optimized: Option<NirKernelBuild>) -> Self {
417 let mut info = default_build.info;
418 if let Some(build) = &optimized {
419 info.max_threads = cmp::min(info.max_threads, build.info.max_threads);
420 info.simd_sizes &= build.info.simd_sizes;
421 info.private_memory = cmp::max(info.private_memory, build.info.private_memory);
422 info.preferred_simd_size =
423 cmp::max(info.preferred_simd_size, build.info.preferred_simd_size);
424 }
425
426 Self {
427 default_build: default_build,
428 optimized: optimized,
429 info: info,
430 }
431 }
432 }
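// The merged `info` computed in new() above is what the per-kernel queries further down report:
// it combines both variants (min of max_threads, intersection of simd_sizes, max of
// private_memory and preferred_simd_size) so the reported values hold no matter which variant
// ends up being launched.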
433
434 struct NirKernelBuild {
435 nir_or_cso: KernelDevStateVariant,
436 constant_buffer: Option<Arc<PipeResource>>,
437 info: pipe_compute_state_object_info,
438 shared_size: u64,
439 printf_info: Option<NirPrintfInfo>,
440 compiled_args: Vec<CompiledKernelArg>,
441 }
442
443 // SAFETY: `CSOWrapper` is only safe to use if the device supports `pipe_caps.shareable_shaders` and
444 // we make sure to set `nir_or_cso` to `KernelDevStateVariant::Cso` only if that's the case.
445 unsafe impl Send for NirKernelBuild {}
446 unsafe impl Sync for NirKernelBuild {}
447
448 impl NirKernelBuild {
    fn new(dev: &'static Device, mut out: CompilationResult) -> Self {
450 let cso = CSOWrapper::new(dev, &out.nir);
451 let info = cso.get_cso_info();
452 let cb = Self::create_nir_constant_buffer(dev, &out.nir);
453 let shared_size = out.nir.shared_size() as u64;
454 let printf_info = out.nir.take_printf_info();
455
456 let nir_or_cso = if !dev.shareable_shaders() {
457 KernelDevStateVariant::Nir(out.nir)
458 } else {
459 KernelDevStateVariant::Cso(cso)
460 };
461
462 NirKernelBuild {
463 nir_or_cso: nir_or_cso,
464 constant_buffer: cb,
465 info: info,
466 shared_size: shared_size,
467 printf_info: printf_info,
468 compiled_args: out.compiled_args,
469 }
470 }
471
    fn create_nir_constant_buffer(dev: &Device, nir: &NirShader) -> Option<Arc<PipeResource>> {
473 let buf = nir.get_constant_buffer();
474 let len = buf.len() as u32;
475
476 if len > 0 {
477 // TODO bind as constant buffer
478 let res = dev
479 .screen()
480 .resource_create_buffer(len, ResourceType::Normal, PIPE_BIND_GLOBAL)
481 .unwrap();
482
483 dev.helper_ctx()
484 .exec(|ctx| ctx.buffer_subdata(&res, 0, buf.as_ptr().cast(), len))
485 .wait();
486
487 Some(Arc::new(res))
488 } else {
489 None
490 }
491 }
492 }
493
494 pub struct Kernel {
495 pub base: CLObjectBase<CL_INVALID_KERNEL>,
496 pub prog: Arc<Program>,
497 pub name: String,
498 values: Mutex<Vec<Option<KernelArgValue>>>,
499 builds: HashMap<&'static Device, Arc<NirKernelBuilds>>,
500 pub kernel_info: Arc<KernelInfo>,
501 }
502
503 impl_cl_type_trait!(cl_kernel, Kernel, CL_INVALID_KERNEL);
504
fn create_kernel_arr<T>(vals: &[usize], val: T) -> CLResult<[T; 3]>
506 where
507 T: std::convert::TryFrom<usize> + Copy,
508 <T as std::convert::TryFrom<usize>>::Error: std::fmt::Debug,
509 {
510 let mut res = [val; 3];
511 for (i, v) in vals.iter().enumerate() {
512 res[i] = (*v).try_into().ok().ok_or(CL_OUT_OF_RESOURCES)?;
513 }
514
515 Ok(res)
516 }
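// Illustrative example with made-up values: create_kernel_arr::<u32>(&[64, 4], 1) pads the
// missing dimension with the given default and yields Ok([64, 4, 1]); a value that doesn't fit
// into T fails with CL_OUT_OF_RESOURCES.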
517
518 #[derive(Clone)]
519 struct CompilationResult {
520 nir: NirShader,
521 compiled_args: Vec<CompiledKernelArg>,
522 }
523
524 impl CompilationResult {
    fn deserialize(reader: &mut blob_reader, d: &Device) -> Option<Self> {
526 let nir = NirShader::deserialize(
527 reader,
528 d.screen()
529 .nir_shader_compiler_options(pipe_shader_type::PIPE_SHADER_COMPUTE),
530 )?;
531 let compiled_args = CompiledKernelArg::deserialize(reader)?;
532
533 Some(Self {
534 nir: nir,
535 compiled_args,
536 })
537 }
538
    fn serialize(&self, blob: &mut blob) {
540 self.nir.serialize(blob);
541 CompiledKernelArg::serialize(&self.compiled_args, blob);
542 }
543 }
544
fn opt_nir(nir: &mut NirShader, dev: &Device, has_explicit_types: bool) {
546 let nir_options = unsafe {
547 &*dev
548 .screen
549 .nir_shader_compiler_options(pipe_shader_type::PIPE_SHADER_COMPUTE)
550 };
551
552 while {
553 let mut progress = false;
554
555 progress |= nir_pass!(nir, nir_copy_prop);
556 progress |= nir_pass!(nir, nir_opt_copy_prop_vars);
557 progress |= nir_pass!(nir, nir_opt_dead_write_vars);
558
559 if nir_options.lower_to_scalar {
560 nir_pass!(
561 nir,
562 nir_lower_alu_to_scalar,
563 nir_options.lower_to_scalar_filter,
564 ptr::null(),
565 );
566 nir_pass!(nir, nir_lower_phis_to_scalar, false);
567 }
568
569 progress |= nir_pass!(nir, nir_opt_deref);
570 if has_explicit_types {
571 progress |= nir_pass!(nir, nir_opt_memcpy);
572 }
573 progress |= nir_pass!(nir, nir_opt_dce);
574 progress |= nir_pass!(nir, nir_opt_undef);
575 progress |= nir_pass!(nir, nir_opt_constant_folding);
576 progress |= nir_pass!(nir, nir_opt_cse);
577 nir_pass!(nir, nir_split_var_copies);
578 progress |= nir_pass!(nir, nir_lower_var_copies);
579 progress |= nir_pass!(nir, nir_lower_vars_to_ssa);
580 nir_pass!(nir, nir_lower_alu);
581 progress |= nir_pass!(nir, nir_opt_phi_precision);
582 progress |= nir_pass!(nir, nir_opt_algebraic);
583 progress |= nir_pass!(
584 nir,
585 nir_opt_if,
586 nir_opt_if_options::nir_opt_if_optimize_phi_true_false,
587 );
588 progress |= nir_pass!(nir, nir_opt_dead_cf);
589 progress |= nir_pass!(nir, nir_opt_remove_phis);
590 // we don't want to be too aggressive here, but it kills a bit of CFG
591 progress |= nir_pass!(nir, nir_opt_peephole_select, 8, true, true);
592 progress |= nir_pass!(
593 nir,
594 nir_lower_vec3_to_vec4,
595 nir_variable_mode::nir_var_mem_generic | nir_variable_mode::nir_var_uniform,
596 );
597
598 if nir_options.max_unroll_iterations != 0 {
599 progress |= nir_pass!(nir, nir_opt_loop_unroll);
600 }
601 nir.sweep_mem();
602 progress
603 } {}
604 }
605
606 /// # Safety
607 ///
608 /// Only safe to call when `var` is a valid pointer to a valid [`nir_variable`]
unsafe extern "C" fn can_remove_var(var: *mut nir_variable, _: *mut c_void) -> bool {
610 // SAFETY: It is the caller's responsibility to provide a valid and aligned pointer
611 let var_type = unsafe { (*var).type_ };
612 // SAFETY: `nir_variable`'s type invariant guarantees that the `type_` field is valid and
613 // properly aligned.
614 unsafe {
615 !glsl_type_is_image(var_type)
616 && !glsl_type_is_texture(var_type)
617 && !glsl_type_is_sampler(var_type)
618 }
619 }
620
621 const DV_OPTS: nir_remove_dead_variables_options = nir_remove_dead_variables_options {
622 can_remove_var: Some(can_remove_var),
623 can_remove_var_data: ptr::null_mut(),
624 };
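// DV_OPTS deliberately keeps image, texture and sampler variables alive across
// nir_remove_dead_variables: their nir_variables are still needed later, e.g. when
// CompiledKernelArg::assign_locations() reads the binding off them.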
625
fn compile_nir_to_args(
627 dev: &Device,
628 mut nir: NirShader,
629 args: &[spirv::SPIRVKernelArg],
630 lib_clc: &NirShader,
631 ) -> (Vec<KernelArg>, NirShader) {
632 // this is a hack until we support fp16 properly and check for denorms inside vstore/vload_half
633 nir.preserve_fp16_denorms();
634
    // Set to rtne for now until drivers are able to report their preferred rounding mode; that
    // also matches what we report via the API.
637 nir.set_fp_rounding_mode_rtne();
638
639 nir_pass!(nir, nir_scale_fdiv);
640 nir.set_workgroup_size_variable_if_zero();
641 nir.structurize();
642 while {
643 let mut progress = false;
644 nir_pass!(nir, nir_split_var_copies);
645 progress |= nir_pass!(nir, nir_copy_prop);
646 progress |= nir_pass!(nir, nir_opt_copy_prop_vars);
647 progress |= nir_pass!(nir, nir_opt_dead_write_vars);
648 progress |= nir_pass!(nir, nir_opt_deref);
649 progress |= nir_pass!(nir, nir_opt_dce);
650 progress |= nir_pass!(nir, nir_opt_undef);
651 progress |= nir_pass!(nir, nir_opt_constant_folding);
652 progress |= nir_pass!(nir, nir_opt_cse);
653 progress |= nir_pass!(nir, nir_lower_vars_to_ssa);
654 progress |= nir_pass!(nir, nir_opt_algebraic);
655 progress
656 } {}
657 nir.inline(lib_clc);
658 nir.cleanup_functions();
659 // that should free up tons of memory
660 nir.sweep_mem();
661
662 nir_pass!(nir, nir_dedup_inline_samplers);
663
664 let printf_opts = nir_lower_printf_options {
665 ptr_bit_size: 0,
666 use_printf_base_identifier: false,
667 hash_format_strings: false,
668 max_buffer_size: dev.printf_buffer_size() as u32,
669 };
670 nir_pass!(nir, nir_lower_printf, &printf_opts);
671
672 opt_nir(&mut nir, dev, false);
673
674 (KernelArg::from_spirv_nir(args, &mut nir), nir)
675 }
676
fn compile_nir_prepare_for_variants(
678 dev: &Device,
679 nir: &mut NirShader,
680 compiled_args: &mut Vec<CompiledKernelArg>,
681 ) {
682 // assign locations for inline samplers.
683 // IMPORTANT: this needs to happen before nir_remove_dead_variables.
684 let mut last_loc = -1;
685 for v in nir
686 .variables_with_mode(nir_variable_mode::nir_var_uniform | nir_variable_mode::nir_var_image)
687 {
688 if unsafe { !glsl_type_is_sampler(v.type_) } {
689 last_loc = v.data.location;
690 continue;
691 }
692 let s = unsafe { v.data.anon_1.sampler };
693 if s.is_inline_sampler() != 0 {
694 last_loc += 1;
695 v.data.location = last_loc;
696
697 compiled_args.push(CompiledKernelArg {
698 kind: CompiledKernelArgType::InlineSampler(Sampler::nir_to_cl(
699 s.addressing_mode(),
700 s.filter_mode(),
701 s.normalized_coordinates(),
702 )),
703 offset: 0,
704 dead: true,
705 });
706 } else {
707 last_loc = v.data.location;
708 }
709 }
710
711 nir_pass!(
712 nir,
713 nir_remove_dead_variables,
714 nir_variable_mode::nir_var_uniform
715 | nir_variable_mode::nir_var_image
716 | nir_variable_mode::nir_var_mem_constant
717 | nir_variable_mode::nir_var_mem_shared
718 | nir_variable_mode::nir_var_function_temp,
719 &DV_OPTS,
720 );
721
722 nir_pass!(nir, nir_lower_readonly_images_to_tex, true);
723 nir_pass!(
724 nir,
725 nir_lower_cl_images,
726 !dev.images_as_deref(),
727 !dev.samplers_as_deref(),
728 );
729
730 nir_pass!(
731 nir,
732 nir_lower_vars_to_explicit_types,
733 nir_variable_mode::nir_var_mem_constant,
734 Some(glsl_get_cl_type_size_align),
735 );
736
737 // has to run before adding internal kernel arguments
738 nir.extract_constant_initializers();
739
740 // needed to convert variables to load intrinsics
741 nir_pass!(nir, nir_lower_system_values);
742
743 // Run here so we can decide if it makes sense to compile a variant, e.g. read system values.
744 nir.gather_info();
745 }
746
fn compile_nir_variant(
748 res: &mut CompilationResult,
749 dev: &Device,
750 variant: NirKernelVariant,
751 args: &[KernelArg],
752 name: &str,
753 ) {
754 let mut lower_state = rusticl_lower_state::default();
755 let compiled_args = &mut res.compiled_args;
756 let nir = &mut res.nir;
757
758 let address_bits_ptr_type;
759 let address_bits_base_type;
760 let global_address_format;
761 let shared_address_format;
762
763 if dev.address_bits() == 64 {
764 address_bits_ptr_type = unsafe { glsl_uint64_t_type() };
765 address_bits_base_type = glsl_base_type::GLSL_TYPE_UINT64;
766 global_address_format = nir_address_format::nir_address_format_64bit_global;
767 shared_address_format = nir_address_format::nir_address_format_32bit_offset_as_64bit;
768 } else {
769 address_bits_ptr_type = unsafe { glsl_uint_type() };
770 address_bits_base_type = glsl_base_type::GLSL_TYPE_UINT;
771 global_address_format = nir_address_format::nir_address_format_32bit_global;
772 shared_address_format = nir_address_format::nir_address_format_32bit_offset;
773 }
774
775 let nir_options = unsafe {
776 &*dev
777 .screen
778 .nir_shader_compiler_options(pipe_shader_type::PIPE_SHADER_COMPUTE)
779 };
780
781 if variant == NirKernelVariant::Optimized {
782 let wgsh = nir.workgroup_size_hint();
783 if wgsh != [0; 3] {
784 nir.set_workgroup_size(wgsh);
785 }
786 }
787
788 let mut compute_options = nir_lower_compute_system_values_options::default();
789 compute_options.set_has_global_size(true);
790 if variant != NirKernelVariant::Optimized {
791 compute_options.set_has_base_global_invocation_id(true);
792 compute_options.set_has_base_workgroup_id(true);
793 }
794 nir_pass!(nir, nir_lower_compute_system_values, &compute_options);
795 nir.gather_info();
796
797 let mut add_var = |nir: &mut NirShader,
798 var_loc: &mut usize,
799 kind: CompiledKernelArgType,
800 glsl_type: *const glsl_type,
801 name| {
802 *var_loc = compiled_args.len();
803 compiled_args.push(CompiledKernelArg {
804 kind: kind,
805 offset: 0,
806 dead: true,
807 });
808 nir.add_var(
809 nir_variable_mode::nir_var_uniform,
810 glsl_type,
811 *var_loc,
812 name,
813 );
814 };
815
816 if nir.reads_sysval(gl_system_value::SYSTEM_VALUE_BASE_GLOBAL_INVOCATION_ID) {
817 debug_assert_ne!(variant, NirKernelVariant::Optimized);
818 add_var(
819 nir,
820 &mut lower_state.base_global_invoc_id_loc,
821 CompiledKernelArgType::GlobalWorkOffsets,
822 unsafe { glsl_vector_type(address_bits_base_type, 3) },
823 c"base_global_invocation_id",
824 )
825 }
826
827 if nir.reads_sysval(gl_system_value::SYSTEM_VALUE_GLOBAL_GROUP_SIZE) {
828 add_var(
829 nir,
830 &mut lower_state.global_size_loc,
831 CompiledKernelArgType::GlobalWorkSize,
832 unsafe { glsl_vector_type(address_bits_base_type, 3) },
833 c"global_size",
834 )
835 }
836
837 if nir.reads_sysval(gl_system_value::SYSTEM_VALUE_BASE_WORKGROUP_ID) {
838 debug_assert_ne!(variant, NirKernelVariant::Optimized);
839 add_var(
840 nir,
841 &mut lower_state.base_workgroup_id_loc,
842 CompiledKernelArgType::WorkGroupOffsets,
843 unsafe { glsl_vector_type(address_bits_base_type, 3) },
844 c"base_workgroup_id",
845 );
846 }
847
848 if nir.reads_sysval(gl_system_value::SYSTEM_VALUE_NUM_WORKGROUPS) {
849 add_var(
850 nir,
851 &mut lower_state.num_workgroups_loc,
852 CompiledKernelArgType::NumWorkgroups,
853 unsafe { glsl_vector_type(glsl_base_type::GLSL_TYPE_UINT, 3) },
854 c"num_workgroups",
855 );
856 }
857
858 if nir.has_constant() {
859 add_var(
860 nir,
861 &mut lower_state.const_buf_loc,
862 CompiledKernelArgType::ConstantBuffer,
863 address_bits_ptr_type,
864 c"constant_buffer_addr",
865 );
866 }
867 if nir.has_printf() {
868 add_var(
869 nir,
870 &mut lower_state.printf_buf_loc,
871 CompiledKernelArgType::PrintfBuffer,
872 address_bits_ptr_type,
873 c"printf_buffer_addr",
874 );
875 }
876
877 if nir.num_images() > 0 || nir.num_textures() > 0 {
878 let count = nir.num_images() + nir.num_textures();
879
880 add_var(
881 nir,
882 &mut lower_state.format_arr_loc,
883 CompiledKernelArgType::FormatArray,
884 unsafe { glsl_array_type(glsl_int16_t_type(), count as u32, 2) },
885 c"image_formats",
886 );
887
888 add_var(
889 nir,
890 &mut lower_state.order_arr_loc,
891 CompiledKernelArgType::OrderArray,
892 unsafe { glsl_array_type(glsl_int16_t_type(), count as u32, 2) },
893 c"image_orders",
894 );
895 }
896
897 if nir.reads_sysval(gl_system_value::SYSTEM_VALUE_WORK_DIM) {
898 add_var(
899 nir,
900 &mut lower_state.work_dim_loc,
901 CompiledKernelArgType::WorkDim,
902 unsafe { glsl_uint8_t_type() },
903 c"work_dim",
904 );
905 }
906
    // needs to run after the first opt loop and remove_dead_variables to get rid of unnecessary
    // scratch memory
909 nir_pass!(
910 nir,
911 nir_lower_vars_to_explicit_types,
912 nir_variable_mode::nir_var_mem_shared
913 | nir_variable_mode::nir_var_function_temp
914 | nir_variable_mode::nir_var_shader_temp
915 | nir_variable_mode::nir_var_uniform
916 | nir_variable_mode::nir_var_mem_global
917 | nir_variable_mode::nir_var_mem_generic,
918 Some(glsl_get_cl_type_size_align),
919 );
920
921 opt_nir(nir, dev, true);
922 nir_pass!(nir, nir_lower_memcpy);
923
924 // we might have got rid of more function_temp or shared memory
925 nir.reset_scratch_size();
926 nir.reset_shared_size();
927 nir_pass!(
928 nir,
929 nir_remove_dead_variables,
930 nir_variable_mode::nir_var_function_temp | nir_variable_mode::nir_var_mem_shared,
931 &DV_OPTS,
932 );
933 nir_pass!(
934 nir,
935 nir_lower_vars_to_explicit_types,
936 nir_variable_mode::nir_var_function_temp
937 | nir_variable_mode::nir_var_mem_shared
938 | nir_variable_mode::nir_var_mem_generic,
939 Some(glsl_get_cl_type_size_align),
940 );
941
942 nir_pass!(
943 nir,
944 nir_lower_explicit_io,
945 nir_variable_mode::nir_var_mem_global | nir_variable_mode::nir_var_mem_constant,
946 global_address_format,
947 );
948
949 nir_pass!(nir, rusticl_lower_intrinsics, &mut lower_state);
950 nir_pass!(
951 nir,
952 nir_lower_explicit_io,
953 nir_variable_mode::nir_var_mem_shared
954 | nir_variable_mode::nir_var_function_temp
955 | nir_variable_mode::nir_var_uniform,
956 shared_address_format,
957 );
958
959 if nir_options.lower_int64_options.0 != 0 && !nir_options.late_lower_int64 {
960 nir_pass!(nir, nir_lower_int64);
961 }
962
963 if nir_options.lower_uniforms_to_ubo {
964 nir_pass!(nir, rusticl_lower_inputs);
965 }
966
967 nir_pass!(nir, nir_lower_convert_alu_types, None);
968
969 opt_nir(nir, dev, true);
970
971 /* before passing it into drivers, assign locations as drivers might remove nir_variables or
972 * other things we depend on
973 */
974 CompiledKernelArg::assign_locations(compiled_args, nir);
975
976 /* update the has_variable_shared_mem info as we might have DCEed all of them */
977 nir.set_has_variable_shared_mem(compiled_args.iter().any(|arg| {
978 if let CompiledKernelArgType::APIArg(idx) = arg.kind {
979 args[idx as usize].kind == KernelArgType::MemLocal && !arg.dead
980 } else {
981 false
982 }
983 }));
984
985 if Platform::dbg().nir {
986 eprintln!("=== Printing nir variant '{variant}' for '{name}' before driver finalization");
987 nir.print();
988 }
989
990 if dev.screen.finalize_nir(nir) {
991 if Platform::dbg().nir {
992 eprintln!(
993 "=== Printing nir variant '{variant}' for '{name}' after driver finalization"
994 );
995 nir.print();
996 }
997 }
998
999 nir_pass!(nir, nir_opt_dce);
1000 nir.sweep_mem();
1001 }
1002
fn compile_nir_remaining(
1004 dev: &Device,
1005 mut nir: NirShader,
1006 args: &[KernelArg],
1007 name: &str,
1008 ) -> (CompilationResult, Option<CompilationResult>) {
1009 // add all API kernel args
1010 let mut compiled_args: Vec<_> = (0..args.len())
1011 .map(|idx| CompiledKernelArg {
1012 kind: CompiledKernelArgType::APIArg(idx as u32),
1013 offset: 0,
1014 dead: true,
1015 })
1016 .collect();
1017
1018 compile_nir_prepare_for_variants(dev, &mut nir, &mut compiled_args);
1019 if Platform::dbg().nir {
1020 eprintln!("=== Printing nir for '{name}' before specialization");
1021 nir.print();
1022 }
1023
1024 let mut default_build = CompilationResult {
1025 nir: nir,
1026 compiled_args: compiled_args,
1027 };
1028
1029 // check if we even want to compile a variant before cloning the compilation state
1030 let has_wgs_hint = default_build.nir.workgroup_size_variable()
1031 && default_build.nir.workgroup_size_hint() != [0; 3];
1032 let has_offsets = default_build
1033 .nir
1034 .reads_sysval(gl_system_value::SYSTEM_VALUE_GLOBAL_INVOCATION_ID);
1035
1036 let mut optimized = (!Platform::dbg().no_variants && (has_offsets || has_wgs_hint))
1037 .then(|| default_build.clone());
1038
1039 compile_nir_variant(
1040 &mut default_build,
1041 dev,
1042 NirKernelVariant::Default,
1043 args,
1044 name,
1045 );
1046 if let Some(optimized) = &mut optimized {
1047 compile_nir_variant(optimized, dev, NirKernelVariant::Optimized, args, name);
1048 }
1049
1050 (default_build, optimized)
1051 }
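// The default variant is always compiled; the optimized one only when it can actually pay off
// (a workgroup size hint is present or the kernel reads the global invocation id) and variants
// aren't disabled via Platform::dbg().no_variants.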
1052
1053 pub struct SPIRVToNirResult {
1054 pub kernel_info: KernelInfo,
1055 pub nir_kernel_builds: NirKernelBuilds,
1056 }
1057
1058 impl SPIRVToNirResult {
    fn new(
1060 dev: &'static Device,
1061 kernel_info: &clc_kernel_info,
1062 args: Vec<KernelArg>,
1063 default_build: CompilationResult,
1064 optimized: Option<CompilationResult>,
1065 ) -> Self {
1066 // TODO: we _should_ be able to parse them out of the SPIR-V, but clc doesn't handle
1067 // indirections yet.
1068 let nir = &default_build.nir;
1069 let wgs = nir.workgroup_size();
1070 let subgroup_size = nir.subgroup_size();
1071 let num_subgroups = nir.num_subgroups();
1072
1073 let default_build = NirKernelBuild::new(dev, default_build);
1074 let optimized = optimized.map(|opt| NirKernelBuild::new(dev, opt));
1075
1076 let kernel_info = KernelInfo {
1077 args: args,
1078 attributes_string: kernel_info.attribute_str(),
1079 work_group_size: [wgs[0] as usize, wgs[1] as usize, wgs[2] as usize],
1080 work_group_size_hint: kernel_info.local_size_hint,
1081 subgroup_size: subgroup_size as usize,
1082 num_subgroups: num_subgroups as usize,
1083 };
1084
1085 Self {
1086 kernel_info: kernel_info,
1087 nir_kernel_builds: NirKernelBuilds::new(default_build, optimized),
1088 }
1089 }
1090
    fn deserialize(bin: &[u8], d: &'static Device, kernel_info: &clc_kernel_info) -> Option<Self> {
1092 let mut reader = blob_reader::default();
1093 unsafe {
1094 blob_reader_init(&mut reader, bin.as_ptr().cast(), bin.len());
1095 }
1096
1097 let args = KernelArg::deserialize(&mut reader)?;
1098 let default_build = CompilationResult::deserialize(&mut reader, d)?;
1099
1100 // SAFETY: on overrun this returns 0
1101 let optimized = match unsafe { blob_read_uint8(&mut reader) } {
1102 0 => None,
1103 _ => Some(CompilationResult::deserialize(&mut reader, d)?),
1104 };
1105
1106 reader
1107 .overrun
1108 .not()
1109 .then(|| SPIRVToNirResult::new(d, kernel_info, args, default_build, optimized))
1110 }
1111
1112 // we can't use Self here as the nir shader might be compiled to a cso already and we can't
1113 // cache that.
    fn serialize(
1115 blob: &mut blob,
1116 args: &[KernelArg],
1117 default_build: &CompilationResult,
1118 optimized: &Option<CompilationResult>,
1119 ) {
1120 KernelArg::serialize(args, blob);
1121 default_build.serialize(blob);
1122 match optimized {
1123 Some(variant) => {
1124 unsafe { blob_write_uint8(blob, 1) };
1125 variant.serialize(blob);
1126 }
1127 None => unsafe {
1128 blob_write_uint8(blob, 0);
1129 },
1130 }
1131 }
1132 }
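// Cache entry layout produced by serialize() above: the KernelArg list, the default
// CompilationResult, then a u8 flag telling whether an optimized CompilationResult follows.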
1133
pub(super) fn convert_spirv_to_nir(
1135 build: &ProgramBuild,
1136 name: &str,
1137 args: &[spirv::SPIRVKernelArg],
1138 dev: &'static Device,
1139 ) -> SPIRVToNirResult {
1140 let cache = dev.screen().shader_cache();
1141 let key = build.hash_key(dev, name);
1142 let spirv_info = build.spirv_info(name, dev).unwrap();
1143
1144 cache
1145 .as_ref()
1146 .and_then(|cache| cache.get(&mut key?))
1147 .and_then(|entry| SPIRVToNirResult::deserialize(&entry, dev, spirv_info))
1148 .unwrap_or_else(|| {
1149 let nir = build.to_nir(name, dev);
1150
1151 if Platform::dbg().nir {
1152 eprintln!("=== Printing nir for '{name}' after spirv_to_nir");
1153 nir.print();
1154 }
1155
1156 let (mut args, nir) = compile_nir_to_args(dev, nir, args, &dev.lib_clc);
1157 let (default_build, optimized) = compile_nir_remaining(dev, nir, &args, name);
1158
1159 for build in [Some(&default_build), optimized.as_ref()].into_iter() {
1160 let Some(build) = build else {
1161 continue;
1162 };
1163
1164 for arg in &build.compiled_args {
1165 if let CompiledKernelArgType::APIArg(idx) = arg.kind {
1166 args[idx as usize].dead &= arg.dead;
1167 }
1168 }
1169 }
1170
1171 if let Some(cache) = cache {
1172 let mut blob = blob::default();
1173 unsafe {
1174 blob_init(&mut blob);
1175 SPIRVToNirResult::serialize(&mut blob, &args, &default_build, &optimized);
1176 let bin = slice::from_raw_parts(blob.data, blob.size);
1177 cache.put(bin, &mut key.unwrap());
1178 blob_finish(&mut blob);
1179 }
1180 }
1181
1182 SPIRVToNirResult::new(dev, spirv_info, args, default_build, optimized)
1183 })
1184 }
1185
fn extract<'a, const S: usize>(buf: &'a mut &[u8]) -> &'a [u8; S] {
1187 let val;
1188 (val, *buf) = (*buf).split_at(S);
    // we split off S bytes and convert them to [u8; S], so the try_into cannot fail
    // use split_array_ref once it's stable
1191 val.try_into().unwrap()
1192 }
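// Used below when reading back the printf buffer, e.g.:
//   let length = u32::from_ne_bytes(*extract(&mut buf));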
1193
1194 impl Kernel {
    pub fn new(name: String, prog: Arc<Program>, prog_build: &ProgramBuild) -> Arc<Kernel> {
1196 let kernel_info = Arc::clone(prog_build.kernel_info.get(&name).unwrap());
1197 let builds = prog_build
1198 .builds
1199 .iter()
1200 .filter_map(|(&dev, b)| b.kernels.get(&name).map(|k| (dev, k.clone())))
1201 .collect();
1202
1203 let values = vec![None; kernel_info.args.len()];
1204 Arc::new(Self {
1205 base: CLObjectBase::new(RusticlTypes::Kernel),
1206 prog: prog,
1207 name: name,
1208 values: Mutex::new(values),
1209 builds: builds,
1210 kernel_info: kernel_info,
1211 })
1212 }
1213
    pub fn suggest_local_size(
1215 &self,
1216 d: &Device,
1217 work_dim: usize,
1218 grid: &mut [usize],
1219 block: &mut [usize],
1220 ) {
1221 let mut threads = self.max_threads_per_block(d);
1222 let dim_threads = d.max_block_sizes();
1223 let subgroups = self.preferred_simd_size(d);
1224
1225 for i in 0..work_dim {
1226 let t = cmp::min(threads, dim_threads[i]);
1227 let gcd = gcd(t, grid[i]);
1228
1229 block[i] = gcd;
1230 grid[i] /= gcd;
1231
1232 // update limits
1233 threads /= block[i];
1234 }
1235
1236 // if we didn't fill the subgroup we can do a bit better if we have threads remaining
1237 let total_threads = block.iter().take(work_dim).product::<usize>();
1238 if threads != 1 && total_threads < subgroups {
1239 for i in 0..work_dim {
1240 if grid[i] * total_threads < threads && grid[i] * block[i] <= dim_threads[i] {
1241 block[i] *= grid[i];
1242 grid[i] = 1;
1243 // can only do it once as nothing is cleanly divisible
1244 break;
1245 }
1246 }
1247 }
1248 }
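    // Worked example (assuming a device with max_threads_per_block() == 1024 and
    // max_block_sizes() == [1024, 1024, 64]): for work_dim == 2 and grid == [1920, 1080] the
    // loop picks block[0] = gcd(1024, 1920) = 128 (grid[0] becomes 15, 8 threads remain), then
    // block[1] = gcd(8, 1080) = 8 (grid[1] becomes 135), i.e. block == [128, 8].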
1249
    fn optimize_local_size(&self, d: &Device, grid: &mut [usize; 3], block: &mut [u32; 3]) {
1251 if !block.contains(&0) {
1252 for i in 0..3 {
1253 // we already made sure everything is fine
1254 grid[i] /= block[i] as usize;
1255 }
1256 return;
1257 }
1258
1259 let mut usize_block = [0usize; 3];
1260 for i in 0..3 {
1261 usize_block[i] = block[i] as usize;
1262 }
1263
1264 self.suggest_local_size(d, 3, grid, &mut usize_block);
1265
1266 for i in 0..3 {
1267 block[i] = usize_block[i] as u32;
1268 }
1269 }
1270
    // The painful part is that host threads are allowed to modify the kernel object once it has
    // been enqueued, so we return a closure with all required data included.
    pub fn launch(
1274 self: &Arc<Self>,
1275 q: &Arc<Queue>,
1276 work_dim: u32,
1277 block: &[usize],
1278 grid: &[usize],
1279 offsets: &[usize],
1280 ) -> CLResult<EventSig> {
1281 // Clone all the data we need to execute this kernel
1282 let kernel_info = Arc::clone(&self.kernel_info);
1283 let arg_values = self.arg_values().clone();
1284 let nir_kernel_builds = Arc::clone(&self.builds[q.device]);
1285
1286 let mut buffer_arcs = HashMap::new();
1287 let mut image_arcs = HashMap::new();
1288
        // We need to preprocess buffer and image arguments so we hold a strong reference until
        // the event has been processed.
1291 for arg in arg_values.iter() {
1292 match arg {
1293 Some(KernelArgValue::Buffer(buffer)) => {
1294 buffer_arcs.insert(
1295 // we use the ptr as the key, and also cast it to usize so we don't need to
1296 // deal with Send + Sync here.
1297 buffer.as_ptr() as usize,
1298 buffer.upgrade().ok_or(CL_INVALID_KERNEL_ARGS)?,
1299 );
1300 }
1301 Some(KernelArgValue::Image(image)) => {
1302 image_arcs.insert(
1303 image.as_ptr() as usize,
1304 image.upgrade().ok_or(CL_INVALID_KERNEL_ARGS)?,
1305 );
1306 }
1307 _ => {}
1308 }
1309 }
1310
        // operations for which we want to report errors back to the client
1312 let mut block = create_kernel_arr::<u32>(block, 1)?;
1313 let mut grid = create_kernel_arr::<usize>(grid, 1)?;
1314 let offsets = create_kernel_arr::<usize>(offsets, 0)?;
1315
1316 let api_grid = grid;
1317
1318 self.optimize_local_size(q.device, &mut grid, &mut block);
1319
1320 Ok(Box::new(move |q, ctx| {
1321 let hw_max_grid: Vec<usize> = q
1322 .device
1323 .max_grid_size()
1324 .into_iter()
1325 .map(|val| val.try_into().unwrap_or(usize::MAX))
1326 // clamped as pipe_launch_grid::grid is only u32
1327 .map(|val| cmp::min(val, u32::MAX as usize))
1328 .collect();
1329
1330 let variant = if offsets == [0; 3]
1331 && grid[0] <= hw_max_grid[0]
1332 && grid[1] <= hw_max_grid[1]
1333 && grid[2] <= hw_max_grid[2]
1334 && (kernel_info.work_group_size_hint == [0; 3]
1335 || block == kernel_info.work_group_size_hint)
1336 {
1337 NirKernelVariant::Optimized
1338 } else {
1339 NirKernelVariant::Default
1340 };
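            // This matches the assumptions documented on NirKernelVariant::Optimized: offsets
            // are zero, the grid fits into the HW limits (so the launch below isn't split and
            // needs no workgroup offsets), and the block matches the compile-time work group
            // size hint (if any).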
1341
1342 let nir_kernel_build = &nir_kernel_builds[variant];
1343 let mut workgroup_id_offset_loc = None;
1344 let mut input = Vec::new();
1345 // Set it once so we get the alignment padding right
1346 let static_local_size: u64 = nir_kernel_build.shared_size;
1347 let mut variable_local_size: u64 = static_local_size;
1348 let printf_size = q.device.printf_buffer_size() as u32;
1349 let mut samplers = Vec::new();
1350 let mut iviews = Vec::new();
1351 let mut sviews = Vec::new();
1352 let mut tex_formats: Vec<u16> = Vec::new();
1353 let mut tex_orders: Vec<u16> = Vec::new();
1354 let mut img_formats: Vec<u16> = Vec::new();
1355 let mut img_orders: Vec<u16> = Vec::new();
1356
1357 let null_ptr;
1358 let null_ptr_v3;
1359 if q.device.address_bits() == 64 {
1360 null_ptr = [0u8; 8].as_slice();
1361 null_ptr_v3 = [0u8; 24].as_slice();
1362 } else {
1363 null_ptr = [0u8; 4].as_slice();
1364 null_ptr_v3 = [0u8; 12].as_slice();
1365 };
1366
1367 let mut resource_info = Vec::new();
1368 fn add_global<'a>(
1369 q: &Queue,
1370 input: &mut Vec<u8>,
1371 resource_info: &mut Vec<(&'a PipeResource, usize)>,
1372 res: &'a PipeResource,
1373 offset: usize,
1374 ) {
1375 resource_info.push((res, input.len()));
1376 if q.device.address_bits() == 64 {
1377 let offset: u64 = offset as u64;
1378 input.extend_from_slice(&offset.to_ne_bytes());
1379 } else {
1380 let offset: u32 = offset as u32;
1381 input.extend_from_slice(&offset.to_ne_bytes());
1382 }
1383 }
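            // add_global() remembers which input-buffer slot belongs to which PipeResource and
            // stores the offset into that slot; the set_global_binding() call further down is
            // expected to patch the resource's GPU address into those same slots (hence the
            // pointer-sized entries).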
1384
1385 fn add_sysval(q: &Queue, input: &mut Vec<u8>, vals: &[usize; 3]) {
1386 if q.device.address_bits() == 64 {
1387 input.extend_from_slice(unsafe { as_byte_slice(&vals.map(|v| v as u64)) });
1388 } else {
1389 input.extend_from_slice(unsafe { as_byte_slice(&vals.map(|v| v as u32)) });
1390 }
1391 }
1392
1393 let mut printf_buf = None;
1394 if nir_kernel_build.printf_info.is_some() {
1395 let buf = q
1396 .device
1397 .screen
1398 .resource_create_buffer(printf_size, ResourceType::Staging, PIPE_BIND_GLOBAL)
1399 .unwrap();
1400
1401 let init_data: [u8; 1] = [4];
1402 ctx.buffer_subdata(&buf, 0, init_data.as_ptr().cast(), init_data.len() as u32);
1403
1404 printf_buf = Some(buf);
1405 }
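            // The first 4 bytes of the printf buffer serve as a write cursor, initialized to 4
            // so that formatted data starts right behind it; after the launch it is read back
            // below to know how many bytes to feed into u_printf().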
1406
1407 for arg in &nir_kernel_build.compiled_args {
1408 let is_opaque = if let CompiledKernelArgType::APIArg(idx) = arg.kind {
1409 kernel_info.args[idx as usize].kind.is_opaque()
1410 } else {
1411 false
1412 };
1413
1414 if !is_opaque && arg.offset as usize > input.len() {
1415 input.resize(arg.offset as usize, 0);
1416 }
1417
1418 match arg.kind {
1419 CompiledKernelArgType::APIArg(idx) => {
1420 let api_arg = &kernel_info.args[idx as usize];
1421 if api_arg.dead {
1422 continue;
1423 }
1424
1425 let Some(value) = &arg_values[idx as usize] else {
1426 continue;
1427 };
1428
1429 match value {
1430 KernelArgValue::Constant(c) => input.extend_from_slice(c),
1431 KernelArgValue::Buffer(buffer) => {
1432 let buffer = &buffer_arcs[&(buffer.as_ptr() as usize)];
1433 let rw = if api_arg.spirv.address_qualifier
1434 == clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_CONSTANT
1435 {
1436 RWFlags::RD
1437 } else {
1438 RWFlags::RW
1439 };
1440
1441 let res = buffer.get_res_for_access(ctx, rw)?;
1442 add_global(q, &mut input, &mut resource_info, res, buffer.offset());
1443 }
1444 KernelArgValue::Image(image) => {
1445 let image = &image_arcs[&(image.as_ptr() as usize)];
1446 let (formats, orders) = if api_arg.kind == KernelArgType::Image {
1447 iviews.push(image.image_view(ctx, false)?);
1448 (&mut img_formats, &mut img_orders)
1449 } else if api_arg.kind == KernelArgType::RWImage {
1450 iviews.push(image.image_view(ctx, true)?);
1451 (&mut img_formats, &mut img_orders)
1452 } else {
1453 sviews.push(image.sampler_view(ctx)?);
1454 (&mut tex_formats, &mut tex_orders)
1455 };
1456
1457 let binding = arg.offset as usize;
1458 assert!(binding >= formats.len());
1459
1460 formats.resize(binding, 0);
1461 orders.resize(binding, 0);
1462
1463 formats.push(image.image_format.image_channel_data_type as u16);
1464 orders.push(image.image_format.image_channel_order as u16);
1465 }
1466 KernelArgValue::LocalMem(size) => {
1467 // TODO 32 bit
1468 let pot = cmp::min(*size, 0x80);
1469 variable_local_size = variable_local_size
1470 .next_multiple_of(pot.next_power_of_two() as u64);
1471 if q.device.address_bits() == 64 {
1472 let variable_local_size: [u8; 8] =
1473 variable_local_size.to_ne_bytes();
1474 input.extend_from_slice(&variable_local_size);
1475 } else {
1476 let variable_local_size: [u8; 4] =
1477 (variable_local_size as u32).to_ne_bytes();
1478 input.extend_from_slice(&variable_local_size);
1479 }
1480 variable_local_size += *size as u64;
1481 }
1482 KernelArgValue::Sampler(sampler) => {
1483 samplers.push(sampler.pipe());
1484 }
1485 KernelArgValue::None => {
1486 assert!(
1487 api_arg.kind == KernelArgType::MemGlobal
1488 || api_arg.kind == KernelArgType::MemConstant
1489 );
1490 input.extend_from_slice(null_ptr);
1491 }
1492 }
1493 }
1494 CompiledKernelArgType::ConstantBuffer => {
1495 assert!(nir_kernel_build.constant_buffer.is_some());
1496 let res = nir_kernel_build.constant_buffer.as_ref().unwrap();
1497 add_global(q, &mut input, &mut resource_info, res, 0);
1498 }
1499 CompiledKernelArgType::GlobalWorkOffsets => {
1500 add_sysval(q, &mut input, &offsets);
1501 }
1502 CompiledKernelArgType::WorkGroupOffsets => {
1503 workgroup_id_offset_loc = Some(input.len());
1504 input.extend_from_slice(null_ptr_v3);
1505 }
1506 CompiledKernelArgType::GlobalWorkSize => {
1507 add_sysval(q, &mut input, &api_grid);
1508 }
1509 CompiledKernelArgType::PrintfBuffer => {
1510 let res = printf_buf.as_ref().unwrap();
1511 add_global(q, &mut input, &mut resource_info, res, 0);
1512 }
1513 CompiledKernelArgType::InlineSampler(cl) => {
1514 samplers.push(Sampler::cl_to_pipe(cl));
1515 }
1516 CompiledKernelArgType::FormatArray => {
1517 input.extend_from_slice(unsafe { as_byte_slice(&tex_formats) });
1518 input.extend_from_slice(unsafe { as_byte_slice(&img_formats) });
1519 }
1520 CompiledKernelArgType::OrderArray => {
1521 input.extend_from_slice(unsafe { as_byte_slice(&tex_orders) });
1522 input.extend_from_slice(unsafe { as_byte_slice(&img_orders) });
1523 }
1524 CompiledKernelArgType::WorkDim => {
1525 input.extend_from_slice(&[work_dim as u8; 1]);
1526 }
1527 CompiledKernelArgType::NumWorkgroups => {
1528 input.extend_from_slice(unsafe {
1529 as_byte_slice(&[grid[0] as u32, grid[1] as u32, grid[2] as u32])
1530 });
1531 }
1532 }
1533 }
1534
1535 // subtract the shader local_size as we only request something on top of that.
1536 variable_local_size -= static_local_size;
1537
1538 let samplers: Vec<_> = samplers
1539 .iter()
1540 .map(|s| ctx.create_sampler_state(s))
1541 .collect();
1542
1543 let mut resources = Vec::with_capacity(resource_info.len());
1544 let mut globals: Vec<*mut u32> = Vec::with_capacity(resource_info.len());
1545 for (res, offset) in resource_info {
1546 resources.push(res);
1547 globals.push(unsafe { input.as_mut_ptr().byte_add(offset) }.cast());
1548 }
1549
1550 let temp_cso;
1551 let cso = match &nir_kernel_build.nir_or_cso {
1552 KernelDevStateVariant::Cso(cso) => cso,
1553 KernelDevStateVariant::Nir(nir) => {
1554 temp_cso = CSOWrapper::new(q.device, nir);
1555 &temp_cso
1556 }
1557 };
1558
1559 let sviews_len = sviews.len();
1560 ctx.bind_compute_state(cso.cso_ptr);
1561 ctx.bind_sampler_states(&samplers);
1562 ctx.set_sampler_views(sviews);
1563 ctx.set_shader_images(&iviews);
1564 ctx.set_global_binding(resources.as_slice(), &mut globals);
1565
1566 for z in 0..grid[2].div_ceil(hw_max_grid[2]) {
1567 for y in 0..grid[1].div_ceil(hw_max_grid[1]) {
1568 for x in 0..grid[0].div_ceil(hw_max_grid[0]) {
1569 if let Some(workgroup_id_offset_loc) = workgroup_id_offset_loc {
1570 let this_offsets =
1571 [x * hw_max_grid[0], y * hw_max_grid[1], z * hw_max_grid[2]];
1572
1573 if q.device.address_bits() == 64 {
1574 let val = this_offsets.map(|v| v as u64);
1575 input[workgroup_id_offset_loc..workgroup_id_offset_loc + 24]
1576 .copy_from_slice(unsafe { as_byte_slice(&val) });
1577 } else {
1578 let val = this_offsets.map(|v| v as u32);
1579 input[workgroup_id_offset_loc..workgroup_id_offset_loc + 12]
1580 .copy_from_slice(unsafe { as_byte_slice(&val) });
1581 }
1582 }
1583
1584 let this_grid = [
1585 cmp::min(hw_max_grid[0], grid[0] - hw_max_grid[0] * x) as u32,
1586 cmp::min(hw_max_grid[1], grid[1] - hw_max_grid[1] * y) as u32,
1587 cmp::min(hw_max_grid[2], grid[2] - hw_max_grid[2] * z) as u32,
1588 ];
1589
1590 ctx.update_cb0(&input)?;
1591 ctx.launch_grid(work_dim, block, this_grid, variable_local_size as u32);
1592
1593 if Platform::dbg().sync_every_event {
1594 ctx.flush().wait();
1595 }
1596 }
1597 }
1598 }
1599
1600 ctx.clear_global_binding(globals.len() as u32);
1601 ctx.clear_shader_images(iviews.len() as u32);
1602 ctx.clear_sampler_views(sviews_len as u32);
1603 ctx.clear_sampler_states(samplers.len() as u32);
1604
1605 ctx.bind_compute_state(ptr::null_mut());
1606
1607 ctx.memory_barrier(PIPE_BARRIER_GLOBAL_BUFFER);
1608
1609 samplers.iter().for_each(|s| ctx.delete_sampler_state(*s));
1610
1611 if let Some(printf_buf) = &printf_buf {
1612 let tx = ctx
1613 .buffer_map(printf_buf, 0, printf_size as i32, RWFlags::RD)
1614 .ok_or(CL_OUT_OF_RESOURCES)?;
1615 let mut buf: &[u8] =
1616 unsafe { slice::from_raw_parts(tx.ptr().cast(), printf_size as usize) };
1617 let length = u32::from_ne_bytes(*extract(&mut buf));
1618
1619 // update our slice to make sure we don't go out of bounds
1620 buf = &buf[0..(length - 4) as usize];
1621 if let Some(pf) = &nir_kernel_build.printf_info {
1622 pf.u_printf(buf)
1623 }
1624 }
1625
1626 Ok(())
1627 }))
1628 }
1629
    pub fn arg_values(&self) -> MutexGuard<Vec<Option<KernelArgValue>>> {
1631 self.values.lock().unwrap()
1632 }
1633
    pub fn set_kernel_arg(&self, idx: usize, arg: KernelArgValue) -> CLResult<()> {
1635 self.values
1636 .lock()
1637 .unwrap()
1638 .get_mut(idx)
1639 .ok_or(CL_INVALID_ARG_INDEX)?
1640 .replace(arg);
1641 Ok(())
1642 }
1643
    pub fn access_qualifier(&self, idx: cl_uint) -> cl_kernel_arg_access_qualifier {
1645 let aq = self.kernel_info.args[idx as usize].spirv.access_qualifier;
1646
1647 if aq
1648 == clc_kernel_arg_access_qualifier::CLC_KERNEL_ARG_ACCESS_READ
1649 | clc_kernel_arg_access_qualifier::CLC_KERNEL_ARG_ACCESS_WRITE
1650 {
1651 CL_KERNEL_ARG_ACCESS_READ_WRITE
1652 } else if aq == clc_kernel_arg_access_qualifier::CLC_KERNEL_ARG_ACCESS_READ {
1653 CL_KERNEL_ARG_ACCESS_READ_ONLY
1654 } else if aq == clc_kernel_arg_access_qualifier::CLC_KERNEL_ARG_ACCESS_WRITE {
1655 CL_KERNEL_ARG_ACCESS_WRITE_ONLY
1656 } else {
1657 CL_KERNEL_ARG_ACCESS_NONE
1658 }
1659 }
1660
    pub fn address_qualifier(&self, idx: cl_uint) -> cl_kernel_arg_address_qualifier {
1662 match self.kernel_info.args[idx as usize].spirv.address_qualifier {
1663 clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_PRIVATE => {
1664 CL_KERNEL_ARG_ADDRESS_PRIVATE
1665 }
1666 clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_CONSTANT => {
1667 CL_KERNEL_ARG_ADDRESS_CONSTANT
1668 }
1669 clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_LOCAL => {
1670 CL_KERNEL_ARG_ADDRESS_LOCAL
1671 }
1672 clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_GLOBAL => {
1673 CL_KERNEL_ARG_ADDRESS_GLOBAL
1674 }
1675 }
1676 }
1677
    pub fn type_qualifier(&self, idx: cl_uint) -> cl_kernel_arg_type_qualifier {
1679 let tq = self.kernel_info.args[idx as usize].spirv.type_qualifier;
1680 let zero = clc_kernel_arg_type_qualifier(0);
1681 let mut res = CL_KERNEL_ARG_TYPE_NONE;
1682
1683 if tq & clc_kernel_arg_type_qualifier::CLC_KERNEL_ARG_TYPE_CONST != zero {
1684 res |= CL_KERNEL_ARG_TYPE_CONST;
1685 }
1686
1687 if tq & clc_kernel_arg_type_qualifier::CLC_KERNEL_ARG_TYPE_RESTRICT != zero {
1688 res |= CL_KERNEL_ARG_TYPE_RESTRICT;
1689 }
1690
1691 if tq & clc_kernel_arg_type_qualifier::CLC_KERNEL_ARG_TYPE_VOLATILE != zero {
1692 res |= CL_KERNEL_ARG_TYPE_VOLATILE;
1693 }
1694
1695 res.into()
1696 }
1697
    pub fn work_group_size(&self) -> [usize; 3] {
1699 self.kernel_info.work_group_size
1700 }
1701
    pub fn num_subgroups(&self) -> usize {
1703 self.kernel_info.num_subgroups
1704 }
1705
    pub fn subgroup_size(&self) -> usize {
1707 self.kernel_info.subgroup_size
1708 }
1709
    pub fn arg_name(&self, idx: cl_uint) -> Option<&CStr> {
1711 let name = &self.kernel_info.args[idx as usize].spirv.name;
1712 name.is_empty().not().then_some(name)
1713 }
1714
    pub fn arg_type_name(&self, idx: cl_uint) -> Option<&CStr> {
1716 let type_name = &self.kernel_info.args[idx as usize].spirv.type_name;
1717 type_name.is_empty().not().then_some(type_name)
1718 }
1719
    pub fn priv_mem_size(&self, dev: &Device) -> cl_ulong {
1721 self.builds.get(dev).unwrap().info.private_memory as cl_ulong
1722 }
1723
    pub fn max_threads_per_block(&self, dev: &Device) -> usize {
1725 self.builds.get(dev).unwrap().info.max_threads as usize
1726 }
1727
    pub fn preferred_simd_size(&self, dev: &Device) -> usize {
1729 self.builds.get(dev).unwrap().info.preferred_simd_size as usize
1730 }
1731
    pub fn local_mem_size(&self, dev: &Device) -> cl_ulong {
1733 // TODO: take alignment into account?
1734 // this is purely informational so it shouldn't even matter
1735 let local =
1736 self.builds.get(dev).unwrap()[NirKernelVariant::Default].shared_size as cl_ulong;
1737 let args: cl_ulong = self
1738 .arg_values()
1739 .iter()
1740 .map(|arg| match arg {
1741 Some(KernelArgValue::LocalMem(val)) => *val as cl_ulong,
1742 // If the local memory size, for any pointer argument to the kernel declared with
1743 // the __local address qualifier, is not specified, its size is assumed to be 0.
1744 _ => 0,
1745 })
1746 .sum();
1747
1748 local + args
1749 }
1750
    pub fn has_svm_devs(&self) -> bool {
1752 self.prog.devs.iter().any(|dev| dev.svm_supported())
1753 }
1754
    pub fn subgroup_sizes(&self, dev: &Device) -> Vec<usize> {
1756 SetBitIndices::from_msb(self.builds.get(dev).unwrap().info.simd_sizes)
1757 .map(|bit| 1 << bit)
1758 .collect()
1759 }
1760
    pub fn subgroups_for_block(&self, dev: &Device, block: &[usize]) -> usize {
1762 let subgroup_size = self.subgroup_size_for_block(dev, block);
1763 if subgroup_size == 0 {
1764 return 0;
1765 }
1766
1767 let threads: usize = block.iter().product();
1768 threads.div_ceil(subgroup_size)
1769 }
1770
    pub fn subgroup_size_for_block(&self, dev: &Device, block: &[usize]) -> usize {
1772 let subgroup_sizes = self.subgroup_sizes(dev);
1773 if subgroup_sizes.is_empty() {
1774 return 0;
1775 }
1776
1777 if subgroup_sizes.len() == 1 {
1778 return subgroup_sizes[0];
1779 }
1780
1781 let block = [
1782 *block.first().unwrap_or(&1) as u32,
1783 *block.get(1).unwrap_or(&1) as u32,
1784 *block.get(2).unwrap_or(&1) as u32,
1785 ];
1786
1787 // TODO: this _might_ bite us somewhere, but I think it probably doesn't matter
1788 match &self.builds.get(dev).unwrap()[NirKernelVariant::Default].nir_or_cso {
1789 KernelDevStateVariant::Cso(cso) => {
1790 dev.helper_ctx()
1791 .compute_state_subgroup_size(cso.cso_ptr, &block) as usize
1792 }
1793 _ => {
1794 panic!()
1795 }
1796 }
1797 }
1798 }
1799
1800 impl Clone for Kernel {
    fn clone(&self) -> Self {
1802 Self {
1803 base: CLObjectBase::new(RusticlTypes::Kernel),
1804 prog: self.prog.clone(),
1805 name: self.name.clone(),
1806 values: Mutex::new(self.arg_values().clone()),
1807 builds: self.builds.clone(),
1808 kernel_info: self.kernel_info.clone(),
1809 }
1810 }
1811 }
1812