1 use crate::api::icd::*;
2 use crate::core::device::*;
3 use crate::core::event::*;
4 use crate::core::memory::*;
5 use crate::core::platform::*;
6 use crate::core::program::*;
7 use crate::core::queue::*;
8 use crate::impl_cl_type_trait;
9 
10 use mesa_rust::compiler::clc::*;
11 use mesa_rust::compiler::nir::*;
12 use mesa_rust::nir_pass;
13 use mesa_rust::pipe::context::RWFlags;
14 use mesa_rust::pipe::resource::*;
15 use mesa_rust::pipe::screen::ResourceType;
16 use mesa_rust_gen::*;
17 use mesa_rust_util::math::*;
18 use mesa_rust_util::serialize::*;
19 use rusticl_opencl_gen::*;
20 use spirv::SpirvKernelInfo;
21 
22 use std::cmp;
23 use std::collections::HashMap;
24 use std::convert::TryInto;
25 use std::ffi::CStr;
26 use std::fmt::Debug;
27 use std::fmt::Display;
28 use std::ops::Index;
29 use std::ops::Not;
30 use std::os::raw::c_void;
31 use std::ptr;
32 use std::slice;
33 use std::sync::Arc;
34 use std::sync::Mutex;
35 use std::sync::MutexGuard;
36 use std::sync::Weak;
37 
38 // According to the CL spec we are not allowed to let any cl_kernel object hold references to
39 // its arguments, as this might make it infeasible for applications to free the backing memory of
40 // memory objects allocated with `CL_MEM_USE_HOST_PTR`.
41 //
42 // However, those arguments might temporarily get referenced by event objects, so we use Weak in
43 // order to upgrade the reference when needed. Weak is also safer than raw pointers, because it
44 // makes it impossible to run into use-after-free issues.
45 //
46 // Technically we would need to do the same for samplers, but there it's pointless to take a weak
47 // reference, as samplers don't have a host_ptr or any similar problem that cl_mem objects have.
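//
// As a hedged illustration only (hypothetical local names, not the exact call sites below): an
// event that needs the memory object upgrades the weak reference and bails out if the
// application already released it, e.g.
//     let strong: Arc<Buffer> = weak_buffer.upgrade().ok_or(CL_INVALID_KERNEL_ARGS)?;
// and drops the Arc again once the work is flushed, so the object's lifetime stays under the
// application's control.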
48 #[derive(Clone)]
49 pub enum KernelArgValue {
50     None,
51     Buffer(Weak<Buffer>),
52     Constant(Vec<u8>),
53     Image(Weak<Image>),
54     LocalMem(usize),
55     Sampler(Arc<Sampler>),
56 }
57 
58 #[repr(u8)]
59 #[derive(Hash, PartialEq, Eq, Clone, Copy)]
60 pub enum KernelArgType {
61     Constant(/* size */ u16), // for anything passed by value
62     Image,
63     RWImage,
64     Sampler,
65     Texture,
66     MemGlobal,
67     MemConstant,
68     MemLocal,
69 }
70 
71 impl KernelArgType {
72     fn deserialize(blob: &mut blob_reader) -> Option<Self> {
73         // SAFETY: we get 0 on an overrun, but we verify that later and act accordingly.
74         let res = match unsafe { blob_read_uint8(blob) } {
75             0 => {
76                 // SAFETY: same here
77                 let size = unsafe { blob_read_uint16(blob) };
78                 KernelArgType::Constant(size)
79             }
80             1 => KernelArgType::Image,
81             2 => KernelArgType::RWImage,
82             3 => KernelArgType::Sampler,
83             4 => KernelArgType::Texture,
84             5 => KernelArgType::MemGlobal,
85             6 => KernelArgType::MemConstant,
86             7 => KernelArgType::MemLocal,
87             _ => return None,
88         };
89 
90         blob.overrun.not().then_some(res)
91     }
92 
93     fn serialize(&self, blob: &mut blob) {
94         unsafe {
95             match self {
96                 KernelArgType::Constant(size) => {
97                     blob_write_uint8(blob, 0);
98                     blob_write_uint16(blob, *size)
99                 }
100                 KernelArgType::Image => blob_write_uint8(blob, 1),
101                 KernelArgType::RWImage => blob_write_uint8(blob, 2),
102                 KernelArgType::Sampler => blob_write_uint8(blob, 3),
103                 KernelArgType::Texture => blob_write_uint8(blob, 4),
104                 KernelArgType::MemGlobal => blob_write_uint8(blob, 5),
105                 KernelArgType::MemConstant => blob_write_uint8(blob, 6),
106                 KernelArgType::MemLocal => blob_write_uint8(blob, 7),
107             };
108         }
109     }
110 
111     fn is_opaque(&self) -> bool {
112         matches!(
113             self,
114             KernelArgType::Image
115                 | KernelArgType::RWImage
116                 | KernelArgType::Texture
117                 | KernelArgType::Sampler
118         )
119     }
120 }
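// Rough sketch of the layout produced by serialize() above: one tag byte per variant (0..=7 in
// declaration order), with `Constant(size)` additionally followed by the u16 size; this is
// exactly what deserialize() reads back.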
121 
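// While `KernelArgType` describes the API-visible kernel arguments, `CompiledKernelArgType`
// below describes the slots of the compiled kernel's inputs, which also include internal
// arguments injected by the compiler (work offsets, printf buffer, inline samplers, ...).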
122 #[derive(Hash, PartialEq, Eq, Clone)]
123 enum CompiledKernelArgType {
124     APIArg(u32),
125     ConstantBuffer,
126     GlobalWorkOffsets,
127     GlobalWorkSize,
128     PrintfBuffer,
129     InlineSampler((cl_addressing_mode, cl_filter_mode, bool)),
130     FormatArray,
131     OrderArray,
132     WorkDim,
133     WorkGroupOffsets,
134     NumWorkgroups,
135 }
136 
137 #[derive(Hash, PartialEq, Eq, Clone)]
138 pub struct KernelArg {
139     spirv: spirv::SPIRVKernelArg,
140     pub kind: KernelArgType,
141     pub dead: bool,
142 }
143 
144 impl KernelArg {
145     fn from_spirv_nir(spirv: &[spirv::SPIRVKernelArg], nir: &mut NirShader) -> Vec<Self> {
146         let nir_arg_map: HashMap<_, _> = nir
147             .variables_with_mode(
148                 nir_variable_mode::nir_var_uniform | nir_variable_mode::nir_var_image,
149             )
150             .map(|v| (v.data.location, v))
151             .collect();
152         let mut res = Vec::new();
153 
154         for (i, s) in spirv.iter().enumerate() {
155             let nir = nir_arg_map.get(&(i as i32)).unwrap();
156             let kind = match s.address_qualifier {
157                 clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_PRIVATE => {
158                     if unsafe { glsl_type_is_sampler(nir.type_) } {
159                         KernelArgType::Sampler
160                     } else {
161                         let size = unsafe { glsl_get_cl_size(nir.type_) } as u16;
162                         // NIR types of non-opaque arguments are never zero-sized
163                         KernelArgType::Constant(size)
164                     }
165                 }
166                 clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_CONSTANT => {
167                     KernelArgType::MemConstant
168                 }
169                 clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_LOCAL => {
170                     KernelArgType::MemLocal
171                 }
172                 clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_GLOBAL => {
173                     if unsafe { glsl_type_is_image(nir.type_) } {
174                         let access = nir.data.access();
175                         if access == gl_access_qualifier::ACCESS_NON_WRITEABLE.0 {
176                             KernelArgType::Texture
177                         } else if access == gl_access_qualifier::ACCESS_NON_READABLE.0 {
178                             KernelArgType::Image
179                         } else {
180                             KernelArgType::RWImage
181                         }
182                     } else {
183                         KernelArgType::MemGlobal
184                     }
185                 }
186             };
187 
188             res.push(Self {
189                 spirv: s.clone(),
190                 // we'll update it later in the 2nd pass
191                 kind: kind,
192                 dead: true,
193             });
194         }
195         res
196     }
197 
198     fn serialize(args: &[Self], blob: &mut blob) {
199         unsafe {
200             blob_write_uint16(blob, args.len() as u16);
201 
202             for arg in args {
203                 arg.spirv.serialize(blob);
204                 blob_write_uint8(blob, arg.dead.into());
205                 arg.kind.serialize(blob);
206             }
207         }
208     }
209 
210     fn deserialize(blob: &mut blob_reader) -> Option<Vec<Self>> {
211         // SAFETY: we check the overrun status, blob_read returns 0 in such a case.
212         let len = unsafe { blob_read_uint16(blob) } as usize;
213         let mut res = Vec::with_capacity(len);
214 
215         for _ in 0..len {
216             let spirv = spirv::SPIRVKernelArg::deserialize(blob)?;
217             // SAFETY: we check the overrun status
218             let dead = unsafe { blob_read_uint8(blob) } != 0;
219             let kind = KernelArgType::deserialize(blob)?;
220 
221             res.push(Self {
222                 spirv: spirv,
223                 kind: kind,
224                 dead: dead,
225             });
226         }
227 
228         blob.overrun.not().then_some(res)
229     }
230 }
231 
232 #[derive(Hash, PartialEq, Eq, Clone)]
233 struct CompiledKernelArg {
234     kind: CompiledKernelArgType,
235     /// The binding for image/sampler args, or the offset into the input buffer
236     /// for anything else.
237     offset: u32,
238     dead: bool,
239 }
240 
241 impl CompiledKernelArg {
242     fn assign_locations(compiled_args: &mut [Self], nir: &mut NirShader) {
243         for var in nir.variables_with_mode(
244             nir_variable_mode::nir_var_uniform | nir_variable_mode::nir_var_image,
245         ) {
246             let arg = &mut compiled_args[var.data.location as usize];
247             let t = var.type_;
248 
249             arg.dead = false;
250             arg.offset = if unsafe {
251                 glsl_type_is_image(t) || glsl_type_is_texture(t) || glsl_type_is_sampler(t)
252             } {
253                 var.data.binding
254             } else {
255                 var.data.driver_location
256             };
257         }
258     }
259 
260     fn serialize(args: &[Self], blob: &mut blob) {
261         unsafe {
262             blob_write_uint16(blob, args.len() as u16);
263             for arg in args {
264                 blob_write_uint32(blob, arg.offset);
265                 blob_write_uint8(blob, arg.dead.into());
266                 match arg.kind {
267                     CompiledKernelArgType::ConstantBuffer => blob_write_uint8(blob, 0),
268                     CompiledKernelArgType::GlobalWorkOffsets => blob_write_uint8(blob, 1),
269                     CompiledKernelArgType::PrintfBuffer => blob_write_uint8(blob, 2),
270                     CompiledKernelArgType::InlineSampler((addr_mode, filter_mode, norm)) => {
271                         blob_write_uint8(blob, 3);
272                         blob_write_uint8(blob, norm.into());
273                         blob_write_uint32(blob, addr_mode);
274                         blob_write_uint32(blob, filter_mode)
275                     }
276                     CompiledKernelArgType::FormatArray => blob_write_uint8(blob, 4),
277                     CompiledKernelArgType::OrderArray => blob_write_uint8(blob, 5),
278                     CompiledKernelArgType::WorkDim => blob_write_uint8(blob, 6),
279                     CompiledKernelArgType::WorkGroupOffsets => blob_write_uint8(blob, 7),
280                     CompiledKernelArgType::NumWorkgroups => blob_write_uint8(blob, 8),
281                     CompiledKernelArgType::GlobalWorkSize => blob_write_uint8(blob, 9),
282                     CompiledKernelArgType::APIArg(idx) => {
283                         blob_write_uint8(blob, 10);
284                         blob_write_uint32(blob, idx)
285                     }
286                 };
287             }
288         }
289     }
290 
291     fn deserialize(blob: &mut blob_reader) -> Option<Vec<Self>> {
292         unsafe {
293             let len = blob_read_uint16(blob) as usize;
294             let mut res = Vec::with_capacity(len);
295 
296             for _ in 0..len {
297                 let offset = blob_read_uint32(blob);
298                 let dead = blob_read_uint8(blob) != 0;
299 
300                 let kind = match blob_read_uint8(blob) {
301                     0 => CompiledKernelArgType::ConstantBuffer,
302                     1 => CompiledKernelArgType::GlobalWorkOffsets,
303                     2 => CompiledKernelArgType::PrintfBuffer,
304                     3 => {
305                         let norm = blob_read_uint8(blob) != 0;
306                         let addr_mode = blob_read_uint32(blob);
307                         let filter_mode = blob_read_uint32(blob);
308                         CompiledKernelArgType::InlineSampler((addr_mode, filter_mode, norm))
309                     }
310                     4 => CompiledKernelArgType::FormatArray,
311                     5 => CompiledKernelArgType::OrderArray,
312                     6 => CompiledKernelArgType::WorkDim,
313                     7 => CompiledKernelArgType::WorkGroupOffsets,
314                     8 => CompiledKernelArgType::NumWorkgroups,
315                     9 => CompiledKernelArgType::GlobalWorkSize,
316                     10 => {
317                         let idx = blob_read_uint32(blob);
318                         CompiledKernelArgType::APIArg(idx)
319                     }
320                     _ => return None,
321                 };
322 
323                 res.push(Self {
324                     kind: kind,
325                     offset: offset,
326                     dead: dead,
327                 });
328             }
329 
330             Some(res)
331         }
332     }
333 }
334 
335 #[derive(Clone, PartialEq, Eq, Hash)]
336 pub struct KernelInfo {
337     pub args: Vec<KernelArg>,
338     pub attributes_string: String,
339     work_group_size: [usize; 3],
340     work_group_size_hint: [u32; 3],
341     subgroup_size: usize,
342     num_subgroups: usize,
343 }
344 
345 struct CSOWrapper {
346     cso_ptr: *mut c_void,
347     dev: &'static Device,
348 }
349 
350 impl CSOWrapper {
351     fn new(dev: &'static Device, nir: &NirShader) -> Self {
352         let cso_ptr = dev
353             .helper_ctx()
354             .create_compute_state(nir, nir.shared_size());
355 
356         Self {
357             cso_ptr: cso_ptr,
358             dev: dev,
359         }
360     }
361 
362     fn get_cso_info(&self) -> pipe_compute_state_object_info {
363         self.dev.helper_ctx().compute_state_info(self.cso_ptr)
364     }
365 }
366 
367 impl Drop for CSOWrapper {
368     fn drop(&mut self) {
369         self.dev.helper_ctx().delete_compute_state(self.cso_ptr);
370     }
371 }
372 
373 enum KernelDevStateVariant {
374     Cso(CSOWrapper),
375     Nir(NirShader),
376 }
377 
378 #[derive(Debug, PartialEq)]
379 enum NirKernelVariant {
380     /// Can be used under any circumstance.
381     Default,
382 
383     /// Optimized variant making the following assumptions:
384     ///  - global_id_offsets are 0
385     ///  - workgroup_offsets are 0
386     ///  - local_size is info.local_size_hint
387     Optimized,
388 }
389 
390 impl Display for NirKernelVariant {
391     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
392         // this simply prints the enum name, so that's fine
393         Debug::fmt(self, f)
394     }
395 }
396 
397 pub struct NirKernelBuilds {
398     default_build: NirKernelBuild,
399     optimized: Option<NirKernelBuild>,
400     /// merged info with worst case values
401     info: pipe_compute_state_object_info,
402 }
403 
404 impl Index<NirKernelVariant> for NirKernelBuilds {
405     type Output = NirKernelBuild;
406 
407     fn index(&self, index: NirKernelVariant) -> &Self::Output {
408         match index {
409             NirKernelVariant::Default => &self.default_build,
410             NirKernelVariant::Optimized => self.optimized.as_ref().unwrap_or(&self.default_build),
411         }
412     }
413 }
414 
415 impl NirKernelBuilds {
416     fn new(default_build: NirKernelBuild, optimized: Option<NirKernelBuild>) -> Self {
417         let mut info = default_build.info;
418         if let Some(build) = &optimized {
419             info.max_threads = cmp::min(info.max_threads, build.info.max_threads);
420             info.simd_sizes &= build.info.simd_sizes;
421             info.private_memory = cmp::max(info.private_memory, build.info.private_memory);
422             info.preferred_simd_size =
423                 cmp::max(info.preferred_simd_size, build.info.preferred_simd_size);
424         }
425 
426         Self {
427             default_build: default_build,
428             optimized: optimized,
429             info: info,
430         }
431     }
432 }
433 
434 struct NirKernelBuild {
435     nir_or_cso: KernelDevStateVariant,
436     constant_buffer: Option<Arc<PipeResource>>,
437     info: pipe_compute_state_object_info,
438     shared_size: u64,
439     printf_info: Option<NirPrintfInfo>,
440     compiled_args: Vec<CompiledKernelArg>,
441 }
442 
443 // SAFETY: `CSOWrapper` is only safe to use if the device supports `pipe_caps.shareable_shaders` and
444 //         we make sure to set `nir_or_cso` to `KernelDevStateVariant::Cso` only if that's the case.
445 unsafe impl Send for NirKernelBuild {}
446 unsafe impl Sync for NirKernelBuild {}
447 
448 impl NirKernelBuild {
449     fn new(dev: &'static Device, mut out: CompilationResult) -> Self {
450         let cso = CSOWrapper::new(dev, &out.nir);
451         let info = cso.get_cso_info();
452         let cb = Self::create_nir_constant_buffer(dev, &out.nir);
453         let shared_size = out.nir.shared_size() as u64;
454         let printf_info = out.nir.take_printf_info();
455 
456         let nir_or_cso = if !dev.shareable_shaders() {
457             KernelDevStateVariant::Nir(out.nir)
458         } else {
459             KernelDevStateVariant::Cso(cso)
460         };
461 
462         NirKernelBuild {
463             nir_or_cso: nir_or_cso,
464             constant_buffer: cb,
465             info: info,
466             shared_size: shared_size,
467             printf_info: printf_info,
468             compiled_args: out.compiled_args,
469         }
470     }
471 
472     fn create_nir_constant_buffer(dev: &Device, nir: &NirShader) -> Option<Arc<PipeResource>> {
473         let buf = nir.get_constant_buffer();
474         let len = buf.len() as u32;
475 
476         if len > 0 {
477             // TODO bind as constant buffer
478             let res = dev
479                 .screen()
480                 .resource_create_buffer(len, ResourceType::Normal, PIPE_BIND_GLOBAL)
481                 .unwrap();
482 
483             dev.helper_ctx()
484                 .exec(|ctx| ctx.buffer_subdata(&res, 0, buf.as_ptr().cast(), len))
485                 .wait();
486 
487             Some(Arc::new(res))
488         } else {
489             None
490         }
491     }
492 }
493 
494 pub struct Kernel {
495     pub base: CLObjectBase<CL_INVALID_KERNEL>,
496     pub prog: Arc<Program>,
497     pub name: String,
498     values: Mutex<Vec<Option<KernelArgValue>>>,
499     builds: HashMap<&'static Device, Arc<NirKernelBuilds>>,
500     pub kernel_info: Arc<KernelInfo>,
501 }
502 
503 impl_cl_type_trait!(cl_kernel, Kernel, CL_INVALID_KERNEL);
504 
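// Hedged usage sketch: `create_kernel_arr::<u32>(&[64, 8], 1)` fills the missing dimension with
// the default value and yields `Ok([64, 8, 1])`, while a value that doesn't fit into `T` maps to
// `Err(CL_OUT_OF_RESOURCES)`.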
505 fn create_kernel_arr<T>(vals: &[usize], val: T) -> CLResult<[T; 3]>
506 where
507     T: std::convert::TryFrom<usize> + Copy,
508     <T as std::convert::TryFrom<usize>>::Error: std::fmt::Debug,
509 {
510     let mut res = [val; 3];
511     for (i, v) in vals.iter().enumerate() {
512         res[i] = (*v).try_into().ok().ok_or(CL_OUT_OF_RESOURCES)?;
513     }
514 
515     Ok(res)
516 }
517 
518 #[derive(Clone)]
519 struct CompilationResult {
520     nir: NirShader,
521     compiled_args: Vec<CompiledKernelArg>,
522 }
523 
524 impl CompilationResult {
525     fn deserialize(reader: &mut blob_reader, d: &Device) -> Option<Self> {
526         let nir = NirShader::deserialize(
527             reader,
528             d.screen()
529                 .nir_shader_compiler_options(pipe_shader_type::PIPE_SHADER_COMPUTE),
530         )?;
531         let compiled_args = CompiledKernelArg::deserialize(reader)?;
532 
533         Some(Self {
534             nir: nir,
535             compiled_args,
536         })
537     }
538 
539     fn serialize(&self, blob: &mut blob) {
540         self.nir.serialize(blob);
541         CompiledKernelArg::serialize(&self.compiled_args, blob);
542     }
543 }
544 
545 fn opt_nir(nir: &mut NirShader, dev: &Device, has_explicit_types: bool) {
546     let nir_options = unsafe {
547         &*dev
548             .screen
549             .nir_shader_compiler_options(pipe_shader_type::PIPE_SHADER_COMPUTE)
550     };
551 
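    // Note: `while { ... } {}` uses the whole block as the loop condition, so the pass pipeline
    // below keeps re-running until no pass reports progress anymore.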
552     while {
553         let mut progress = false;
554 
555         progress |= nir_pass!(nir, nir_copy_prop);
556         progress |= nir_pass!(nir, nir_opt_copy_prop_vars);
557         progress |= nir_pass!(nir, nir_opt_dead_write_vars);
558 
559         if nir_options.lower_to_scalar {
560             nir_pass!(
561                 nir,
562                 nir_lower_alu_to_scalar,
563                 nir_options.lower_to_scalar_filter,
564                 ptr::null(),
565             );
566             nir_pass!(nir, nir_lower_phis_to_scalar, false);
567         }
568 
569         progress |= nir_pass!(nir, nir_opt_deref);
570         if has_explicit_types {
571             progress |= nir_pass!(nir, nir_opt_memcpy);
572         }
573         progress |= nir_pass!(nir, nir_opt_dce);
574         progress |= nir_pass!(nir, nir_opt_undef);
575         progress |= nir_pass!(nir, nir_opt_constant_folding);
576         progress |= nir_pass!(nir, nir_opt_cse);
577         nir_pass!(nir, nir_split_var_copies);
578         progress |= nir_pass!(nir, nir_lower_var_copies);
579         progress |= nir_pass!(nir, nir_lower_vars_to_ssa);
580         nir_pass!(nir, nir_lower_alu);
581         progress |= nir_pass!(nir, nir_opt_phi_precision);
582         progress |= nir_pass!(nir, nir_opt_algebraic);
583         progress |= nir_pass!(
584             nir,
585             nir_opt_if,
586             nir_opt_if_options::nir_opt_if_optimize_phi_true_false,
587         );
588         progress |= nir_pass!(nir, nir_opt_dead_cf);
589         progress |= nir_pass!(nir, nir_opt_remove_phis);
590         // we don't want to be too aggressive here, but it kills a bit of CFG
591         progress |= nir_pass!(nir, nir_opt_peephole_select, 8, true, true);
592         progress |= nir_pass!(
593             nir,
594             nir_lower_vec3_to_vec4,
595             nir_variable_mode::nir_var_mem_generic | nir_variable_mode::nir_var_uniform,
596         );
597 
598         if nir_options.max_unroll_iterations != 0 {
599             progress |= nir_pass!(nir, nir_opt_loop_unroll);
600         }
601         nir.sweep_mem();
602         progress
603     } {}
604 }
605 
606 /// # Safety
607 ///
608 /// Only safe to call when `var` is a valid pointer to a valid [`nir_variable`]
609 unsafe extern "C" fn can_remove_var(var: *mut nir_variable, _: *mut c_void) -> bool {
610     // SAFETY: It is the caller's responsibility to provide a valid and aligned pointer
611     let var_type = unsafe { (*var).type_ };
612     // SAFETY: `nir_variable`'s type invariant guarantees that the `type_` field is valid and
613     // properly aligned.
614     unsafe {
615         !glsl_type_is_image(var_type)
616             && !glsl_type_is_texture(var_type)
617             && !glsl_type_is_sampler(var_type)
618     }
619 }
620 
621 const DV_OPTS: nir_remove_dead_variables_options = nir_remove_dead_variables_options {
622     can_remove_var: Some(can_remove_var),
623     can_remove_var_data: ptr::null_mut(),
624 };
625 
626 fn compile_nir_to_args(
627     dev: &Device,
628     mut nir: NirShader,
629     args: &[spirv::SPIRVKernelArg],
630     lib_clc: &NirShader,
631 ) -> (Vec<KernelArg>, NirShader) {
632     // this is a hack until we support fp16 properly and check for denorms inside vstore/vload_half
633     nir.preserve_fp16_denorms();
634 
635     // Set to rtne for now until drivers are able to report their preferred rounding mode; this also
636     // matches what we report via the API.
637     nir.set_fp_rounding_mode_rtne();
638 
639     nir_pass!(nir, nir_scale_fdiv);
640     nir.set_workgroup_size_variable_if_zero();
641     nir.structurize();
642     while {
643         let mut progress = false;
644         nir_pass!(nir, nir_split_var_copies);
645         progress |= nir_pass!(nir, nir_copy_prop);
646         progress |= nir_pass!(nir, nir_opt_copy_prop_vars);
647         progress |= nir_pass!(nir, nir_opt_dead_write_vars);
648         progress |= nir_pass!(nir, nir_opt_deref);
649         progress |= nir_pass!(nir, nir_opt_dce);
650         progress |= nir_pass!(nir, nir_opt_undef);
651         progress |= nir_pass!(nir, nir_opt_constant_folding);
652         progress |= nir_pass!(nir, nir_opt_cse);
653         progress |= nir_pass!(nir, nir_lower_vars_to_ssa);
654         progress |= nir_pass!(nir, nir_opt_algebraic);
655         progress
656     } {}
657     nir.inline(lib_clc);
658     nir.cleanup_functions();
659     // that should free up tons of memory
660     nir.sweep_mem();
661 
662     nir_pass!(nir, nir_dedup_inline_samplers);
663 
664     let printf_opts = nir_lower_printf_options {
665         ptr_bit_size: 0,
666         use_printf_base_identifier: false,
667         hash_format_strings: false,
668         max_buffer_size: dev.printf_buffer_size() as u32,
669     };
670     nir_pass!(nir, nir_lower_printf, &printf_opts);
671 
672     opt_nir(&mut nir, dev, false);
673 
674     (KernelArg::from_spirv_nir(args, &mut nir), nir)
675 }
676 
677 fn compile_nir_prepare_for_variants(
678     dev: &Device,
679     nir: &mut NirShader,
680     compiled_args: &mut Vec<CompiledKernelArg>,
681 ) {
682     // assign locations for inline samplers.
683     // IMPORTANT: this needs to happen before nir_remove_dead_variables.
684     let mut last_loc = -1;
685     for v in nir
686         .variables_with_mode(nir_variable_mode::nir_var_uniform | nir_variable_mode::nir_var_image)
687     {
688         if unsafe { !glsl_type_is_sampler(v.type_) } {
689             last_loc = v.data.location;
690             continue;
691         }
692         let s = unsafe { v.data.anon_1.sampler };
693         if s.is_inline_sampler() != 0 {
694             last_loc += 1;
695             v.data.location = last_loc;
696 
697             compiled_args.push(CompiledKernelArg {
698                 kind: CompiledKernelArgType::InlineSampler(Sampler::nir_to_cl(
699                     s.addressing_mode(),
700                     s.filter_mode(),
701                     s.normalized_coordinates(),
702                 )),
703                 offset: 0,
704                 dead: true,
705             });
706         } else {
707             last_loc = v.data.location;
708         }
709     }
710 
711     nir_pass!(
712         nir,
713         nir_remove_dead_variables,
714         nir_variable_mode::nir_var_uniform
715             | nir_variable_mode::nir_var_image
716             | nir_variable_mode::nir_var_mem_constant
717             | nir_variable_mode::nir_var_mem_shared
718             | nir_variable_mode::nir_var_function_temp,
719         &DV_OPTS,
720     );
721 
722     nir_pass!(nir, nir_lower_readonly_images_to_tex, true);
723     nir_pass!(
724         nir,
725         nir_lower_cl_images,
726         !dev.images_as_deref(),
727         !dev.samplers_as_deref(),
728     );
729 
730     nir_pass!(
731         nir,
732         nir_lower_vars_to_explicit_types,
733         nir_variable_mode::nir_var_mem_constant,
734         Some(glsl_get_cl_type_size_align),
735     );
736 
737     // has to run before adding internal kernel arguments
738     nir.extract_constant_initializers();
739 
740     // needed to convert variables to load intrinsics
741     nir_pass!(nir, nir_lower_system_values);
742 
743     // Run here so we can decide if it makes sense to compile a variant, e.g. read system values.
744     nir.gather_info();
745 }
746 
747 fn compile_nir_variant(
748     res: &mut CompilationResult,
749     dev: &Device,
750     variant: NirKernelVariant,
751     args: &[KernelArg],
752     name: &str,
753 ) {
754     let mut lower_state = rusticl_lower_state::default();
755     let compiled_args = &mut res.compiled_args;
756     let nir = &mut res.nir;
757 
758     let address_bits_ptr_type;
759     let address_bits_base_type;
760     let global_address_format;
761     let shared_address_format;
762 
763     if dev.address_bits() == 64 {
764         address_bits_ptr_type = unsafe { glsl_uint64_t_type() };
765         address_bits_base_type = glsl_base_type::GLSL_TYPE_UINT64;
766         global_address_format = nir_address_format::nir_address_format_64bit_global;
767         shared_address_format = nir_address_format::nir_address_format_32bit_offset_as_64bit;
768     } else {
769         address_bits_ptr_type = unsafe { glsl_uint_type() };
770         address_bits_base_type = glsl_base_type::GLSL_TYPE_UINT;
771         global_address_format = nir_address_format::nir_address_format_32bit_global;
772         shared_address_format = nir_address_format::nir_address_format_32bit_offset;
773     }
774 
775     let nir_options = unsafe {
776         &*dev
777             .screen
778             .nir_shader_compiler_options(pipe_shader_type::PIPE_SHADER_COMPUTE)
779     };
780 
781     if variant == NirKernelVariant::Optimized {
782         let wgsh = nir.workgroup_size_hint();
783         if wgsh != [0; 3] {
784             nir.set_workgroup_size(wgsh);
785         }
786     }
787 
788     let mut compute_options = nir_lower_compute_system_values_options::default();
789     compute_options.set_has_global_size(true);
790     if variant != NirKernelVariant::Optimized {
791         compute_options.set_has_base_global_invocation_id(true);
792         compute_options.set_has_base_workgroup_id(true);
793     }
794     nir_pass!(nir, nir_lower_compute_system_values, &compute_options);
795     nir.gather_info();
796 
797     let mut add_var = |nir: &mut NirShader,
798                        var_loc: &mut usize,
799                        kind: CompiledKernelArgType,
800                        glsl_type: *const glsl_type,
801                        name| {
802         *var_loc = compiled_args.len();
803         compiled_args.push(CompiledKernelArg {
804             kind: kind,
805             offset: 0,
806             dead: true,
807         });
808         nir.add_var(
809             nir_variable_mode::nir_var_uniform,
810             glsl_type,
811             *var_loc,
812             name,
813         );
814     };
815 
816     if nir.reads_sysval(gl_system_value::SYSTEM_VALUE_BASE_GLOBAL_INVOCATION_ID) {
817         debug_assert_ne!(variant, NirKernelVariant::Optimized);
818         add_var(
819             nir,
820             &mut lower_state.base_global_invoc_id_loc,
821             CompiledKernelArgType::GlobalWorkOffsets,
822             unsafe { glsl_vector_type(address_bits_base_type, 3) },
823             c"base_global_invocation_id",
824         )
825     }
826 
827     if nir.reads_sysval(gl_system_value::SYSTEM_VALUE_GLOBAL_GROUP_SIZE) {
828         add_var(
829             nir,
830             &mut lower_state.global_size_loc,
831             CompiledKernelArgType::GlobalWorkSize,
832             unsafe { glsl_vector_type(address_bits_base_type, 3) },
833             c"global_size",
834         )
835     }
836 
837     if nir.reads_sysval(gl_system_value::SYSTEM_VALUE_BASE_WORKGROUP_ID) {
838         debug_assert_ne!(variant, NirKernelVariant::Optimized);
839         add_var(
840             nir,
841             &mut lower_state.base_workgroup_id_loc,
842             CompiledKernelArgType::WorkGroupOffsets,
843             unsafe { glsl_vector_type(address_bits_base_type, 3) },
844             c"base_workgroup_id",
845         );
846     }
847 
848     if nir.reads_sysval(gl_system_value::SYSTEM_VALUE_NUM_WORKGROUPS) {
849         add_var(
850             nir,
851             &mut lower_state.num_workgroups_loc,
852             CompiledKernelArgType::NumWorkgroups,
853             unsafe { glsl_vector_type(glsl_base_type::GLSL_TYPE_UINT, 3) },
854             c"num_workgroups",
855         );
856     }
857 
858     if nir.has_constant() {
859         add_var(
860             nir,
861             &mut lower_state.const_buf_loc,
862             CompiledKernelArgType::ConstantBuffer,
863             address_bits_ptr_type,
864             c"constant_buffer_addr",
865         );
866     }
867     if nir.has_printf() {
868         add_var(
869             nir,
870             &mut lower_state.printf_buf_loc,
871             CompiledKernelArgType::PrintfBuffer,
872             address_bits_ptr_type,
873             c"printf_buffer_addr",
874         );
875     }
876 
877     if nir.num_images() > 0 || nir.num_textures() > 0 {
878         let count = nir.num_images() + nir.num_textures();
879 
880         add_var(
881             nir,
882             &mut lower_state.format_arr_loc,
883             CompiledKernelArgType::FormatArray,
884             unsafe { glsl_array_type(glsl_int16_t_type(), count as u32, 2) },
885             c"image_formats",
886         );
887 
888         add_var(
889             nir,
890             &mut lower_state.order_arr_loc,
891             CompiledKernelArgType::OrderArray,
892             unsafe { glsl_array_type(glsl_int16_t_type(), count as u32, 2) },
893             c"image_orders",
894         );
895     }
896 
897     if nir.reads_sysval(gl_system_value::SYSTEM_VALUE_WORK_DIM) {
898         add_var(
899             nir,
900             &mut lower_state.work_dim_loc,
901             CompiledKernelArgType::WorkDim,
902             unsafe { glsl_uint8_t_type() },
903             c"work_dim",
904         );
905     }
906 
907     // needs to run after the first opt loop and remove_dead_variables to get rid of unnecessary scratch
908     // memory
909     nir_pass!(
910         nir,
911         nir_lower_vars_to_explicit_types,
912         nir_variable_mode::nir_var_mem_shared
913             | nir_variable_mode::nir_var_function_temp
914             | nir_variable_mode::nir_var_shader_temp
915             | nir_variable_mode::nir_var_uniform
916             | nir_variable_mode::nir_var_mem_global
917             | nir_variable_mode::nir_var_mem_generic,
918         Some(glsl_get_cl_type_size_align),
919     );
920 
921     opt_nir(nir, dev, true);
922     nir_pass!(nir, nir_lower_memcpy);
923 
924     // we might have got rid of more function_temp or shared memory
925     nir.reset_scratch_size();
926     nir.reset_shared_size();
927     nir_pass!(
928         nir,
929         nir_remove_dead_variables,
930         nir_variable_mode::nir_var_function_temp | nir_variable_mode::nir_var_mem_shared,
931         &DV_OPTS,
932     );
933     nir_pass!(
934         nir,
935         nir_lower_vars_to_explicit_types,
936         nir_variable_mode::nir_var_function_temp
937             | nir_variable_mode::nir_var_mem_shared
938             | nir_variable_mode::nir_var_mem_generic,
939         Some(glsl_get_cl_type_size_align),
940     );
941 
942     nir_pass!(
943         nir,
944         nir_lower_explicit_io,
945         nir_variable_mode::nir_var_mem_global | nir_variable_mode::nir_var_mem_constant,
946         global_address_format,
947     );
948 
949     nir_pass!(nir, rusticl_lower_intrinsics, &mut lower_state);
950     nir_pass!(
951         nir,
952         nir_lower_explicit_io,
953         nir_variable_mode::nir_var_mem_shared
954             | nir_variable_mode::nir_var_function_temp
955             | nir_variable_mode::nir_var_uniform,
956         shared_address_format,
957     );
958 
959     if nir_options.lower_int64_options.0 != 0 && !nir_options.late_lower_int64 {
960         nir_pass!(nir, nir_lower_int64);
961     }
962 
963     if nir_options.lower_uniforms_to_ubo {
964         nir_pass!(nir, rusticl_lower_inputs);
965     }
966 
967     nir_pass!(nir, nir_lower_convert_alu_types, None);
968 
969     opt_nir(nir, dev, true);
970 
971     /* before passing it into drivers, assign locations as drivers might remove nir_variables or
972      * other things we depend on
973      */
974     CompiledKernelArg::assign_locations(compiled_args, nir);
975 
976     /* update the has_variable_shared_mem info as we might have DCEed all of them */
977     nir.set_has_variable_shared_mem(compiled_args.iter().any(|arg| {
978         if let CompiledKernelArgType::APIArg(idx) = arg.kind {
979             args[idx as usize].kind == KernelArgType::MemLocal && !arg.dead
980         } else {
981             false
982         }
983     }));
984 
985     if Platform::dbg().nir {
986         eprintln!("=== Printing nir variant '{variant}' for '{name}' before driver finalization");
987         nir.print();
988     }
989 
990     if dev.screen.finalize_nir(nir) {
991         if Platform::dbg().nir {
992             eprintln!(
993                 "=== Printing nir variant '{variant}' for '{name}' after driver finalization"
994             );
995             nir.print();
996         }
997     }
998 
999     nir_pass!(nir, nir_opt_dce);
1000     nir.sweep_mem();
1001 }
1002 
1003 fn compile_nir_remaining(
1004     dev: &Device,
1005     mut nir: NirShader,
1006     args: &[KernelArg],
1007     name: &str,
1008 ) -> (CompilationResult, Option<CompilationResult>) {
1009     // add all API kernel args
1010     let mut compiled_args: Vec<_> = (0..args.len())
1011         .map(|idx| CompiledKernelArg {
1012             kind: CompiledKernelArgType::APIArg(idx as u32),
1013             offset: 0,
1014             dead: true,
1015         })
1016         .collect();
1017 
1018     compile_nir_prepare_for_variants(dev, &mut nir, &mut compiled_args);
1019     if Platform::dbg().nir {
1020         eprintln!("=== Printing nir for '{name}' before specialization");
1021         nir.print();
1022     }
1023 
1024     let mut default_build = CompilationResult {
1025         nir: nir,
1026         compiled_args: compiled_args,
1027     };
1028 
1029     // check if we even want to compile a variant before cloning the compilation state
1030     let has_wgs_hint = default_build.nir.workgroup_size_variable()
1031         && default_build.nir.workgroup_size_hint() != [0; 3];
1032     let has_offsets = default_build
1033         .nir
1034         .reads_sysval(gl_system_value::SYSTEM_VALUE_GLOBAL_INVOCATION_ID);
1035 
1036     let mut optimized = (!Platform::dbg().no_variants && (has_offsets || has_wgs_hint))
1037         .then(|| default_build.clone());
1038 
1039     compile_nir_variant(
1040         &mut default_build,
1041         dev,
1042         NirKernelVariant::Default,
1043         args,
1044         name,
1045     );
1046     if let Some(optimized) = &mut optimized {
1047         compile_nir_variant(optimized, dev, NirKernelVariant::Optimized, args, name);
1048     }
1049 
1050     (default_build, optimized)
1051 }
1052 
1053 pub struct SPIRVToNirResult {
1054     pub kernel_info: KernelInfo,
1055     pub nir_kernel_builds: NirKernelBuilds,
1056 }
1057 
1058 impl SPIRVToNirResult {
1059     fn new(
1060         dev: &'static Device,
1061         kernel_info: &clc_kernel_info,
1062         args: Vec<KernelArg>,
1063         default_build: CompilationResult,
1064         optimized: Option<CompilationResult>,
1065     ) -> Self {
1066         // TODO: we _should_ be able to parse them out of the SPIR-V, but clc doesn't handle
1067         //       indirections yet.
1068         let nir = &default_build.nir;
1069         let wgs = nir.workgroup_size();
1070         let subgroup_size = nir.subgroup_size();
1071         let num_subgroups = nir.num_subgroups();
1072 
1073         let default_build = NirKernelBuild::new(dev, default_build);
1074         let optimized = optimized.map(|opt| NirKernelBuild::new(dev, opt));
1075 
1076         let kernel_info = KernelInfo {
1077             args: args,
1078             attributes_string: kernel_info.attribute_str(),
1079             work_group_size: [wgs[0] as usize, wgs[1] as usize, wgs[2] as usize],
1080             work_group_size_hint: kernel_info.local_size_hint,
1081             subgroup_size: subgroup_size as usize,
1082             num_subgroups: num_subgroups as usize,
1083         };
1084 
1085         Self {
1086             kernel_info: kernel_info,
1087             nir_kernel_builds: NirKernelBuilds::new(default_build, optimized),
1088         }
1089     }
1090 
1091     fn deserialize(bin: &[u8], d: &'static Device, kernel_info: &clc_kernel_info) -> Option<Self> {
1092         let mut reader = blob_reader::default();
1093         unsafe {
1094             blob_reader_init(&mut reader, bin.as_ptr().cast(), bin.len());
1095         }
1096 
1097         let args = KernelArg::deserialize(&mut reader)?;
1098         let default_build = CompilationResult::deserialize(&mut reader, d)?;
1099 
1100         // SAFETY: on overrun this returns 0
1101         let optimized = match unsafe { blob_read_uint8(&mut reader) } {
1102             0 => None,
1103             _ => Some(CompilationResult::deserialize(&mut reader, d)?),
1104         };
1105 
1106         reader
1107             .overrun
1108             .not()
1109             .then(|| SPIRVToNirResult::new(d, kernel_info, args, default_build, optimized))
1110     }
1111 
1112     // we can't use Self here as the nir shader might be compiled to a cso already and we can't
1113     // cache that.
1114     fn serialize(
1115         blob: &mut blob,
1116         args: &[KernelArg],
1117         default_build: &CompilationResult,
1118         optimized: &Option<CompilationResult>,
1119     ) {
1120         KernelArg::serialize(args, blob);
1121         default_build.serialize(blob);
1122         match optimized {
1123             Some(variant) => {
1124                 unsafe { blob_write_uint8(blob, 1) };
1125                 variant.serialize(blob);
1126             }
1127             None => unsafe {
1128                 blob_write_uint8(blob, 0);
1129             },
1130         }
1131     }
1132 }
1133 
1134 pub(super) fn convert_spirv_to_nir(
1135     build: &ProgramBuild,
1136     name: &str,
1137     args: &[spirv::SPIRVKernelArg],
1138     dev: &'static Device,
1139 ) -> SPIRVToNirResult {
1140     let cache = dev.screen().shader_cache();
1141     let key = build.hash_key(dev, name);
1142     let spirv_info = build.spirv_info(name, dev).unwrap();
1143 
1144     cache
1145         .as_ref()
1146         .and_then(|cache| cache.get(&mut key?))
1147         .and_then(|entry| SPIRVToNirResult::deserialize(&entry, dev, spirv_info))
1148         .unwrap_or_else(|| {
1149             let nir = build.to_nir(name, dev);
1150 
1151             if Platform::dbg().nir {
1152                 eprintln!("=== Printing nir for '{name}' after spirv_to_nir");
1153                 nir.print();
1154             }
1155 
1156             let (mut args, nir) = compile_nir_to_args(dev, nir, args, &dev.lib_clc);
1157             let (default_build, optimized) = compile_nir_remaining(dev, nir, &args, name);
1158 
1159             for build in [Some(&default_build), optimized.as_ref()].into_iter() {
1160                 let Some(build) = build else {
1161                     continue;
1162                 };
1163 
1164                 for arg in &build.compiled_args {
1165                     if let CompiledKernelArgType::APIArg(idx) = arg.kind {
1166                         args[idx as usize].dead &= arg.dead;
1167                     }
1168                 }
1169             }
1170 
1171             if let Some(cache) = cache {
1172                 let mut blob = blob::default();
1173                 unsafe {
1174                     blob_init(&mut blob);
1175                     SPIRVToNirResult::serialize(&mut blob, &args, &default_build, &optimized);
1176                     let bin = slice::from_raw_parts(blob.data, blob.size);
1177                     cache.put(bin, &mut key.unwrap());
1178                     blob_finish(&mut blob);
1179                 }
1180             }
1181 
1182             SPIRVToNirResult::new(dev, spirv_info, args, default_build, optimized)
1183         })
1184 }
1185 
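// Hedged usage sketch: with `let mut buf: &[u8] = &bytes;`, `let chunk: &[u8; 4] = extract(&mut buf);`
// returns the first four bytes and advances `buf` past them; `split_at` panics if fewer than `S`
// bytes remain.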
1186 fn extract<'a, const S: usize>(buf: &'a mut &[u8]) -> &'a [u8; S] {
1187     let val;
1188     (val, *buf) = (*buf).split_at(S);
1189     // we split off S bytes and convert to [u8; S], so this should be safe
1190     // use split_array_ref once it's stable
1191     val.try_into().unwrap()
1192 }
1193 
1194 impl Kernel {
1195     pub fn new(name: String, prog: Arc<Program>, prog_build: &ProgramBuild) -> Arc<Kernel> {
1196         let kernel_info = Arc::clone(prog_build.kernel_info.get(&name).unwrap());
1197         let builds = prog_build
1198             .builds
1199             .iter()
1200             .filter_map(|(&dev, b)| b.kernels.get(&name).map(|k| (dev, k.clone())))
1201             .collect();
1202 
1203         let values = vec![None; kernel_info.args.len()];
1204         Arc::new(Self {
1205             base: CLObjectBase::new(RusticlTypes::Kernel),
1206             prog: prog,
1207             name: name,
1208             values: Mutex::new(values),
1209             builds: builds,
1210             kernel_info: kernel_info,
1211         })
1212     }
1213 
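    // Hedged worked example with hypothetical device limits: for max_threads_per_block = 256,
    // max_block_sizes = [1024, 1024, 64] and grid = [1024, 1024], the first dimension gets
    // block[0] = gcd(256, 1024) = 256 and grid[0] = 4; only one thread remains for the second
    // dimension, so block[1] = 1 and grid[1] = 1024.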
1214     pub fn suggest_local_size(
1215         &self,
1216         d: &Device,
1217         work_dim: usize,
1218         grid: &mut [usize],
1219         block: &mut [usize],
1220     ) {
1221         let mut threads = self.max_threads_per_block(d);
1222         let dim_threads = d.max_block_sizes();
1223         let subgroups = self.preferred_simd_size(d);
1224 
1225         for i in 0..work_dim {
1226             let t = cmp::min(threads, dim_threads[i]);
1227             let gcd = gcd(t, grid[i]);
1228 
1229             block[i] = gcd;
1230             grid[i] /= gcd;
1231 
1232             // update limits
1233             threads /= block[i];
1234         }
1235 
1236         // if we didn't fill the subgroup we can do a bit better if we have threads remaining
1237         let total_threads = block.iter().take(work_dim).product::<usize>();
1238         if threads != 1 && total_threads < subgroups {
1239             for i in 0..work_dim {
1240                 if grid[i] * total_threads < threads && grid[i] * block[i] <= dim_threads[i] {
1241                     block[i] *= grid[i];
1242                     grid[i] = 1;
1243                     // can only do it once as nothing is cleanly divisible
1244                     break;
1245                 }
1246             }
1247         }
1248     }
1249 
1250     fn optimize_local_size(&self, d: &Device, grid: &mut [usize; 3], block: &mut [u32; 3]) {
1251         if !block.contains(&0) {
1252             for i in 0..3 {
1253                 // we already made sure everything is fine
1254                 grid[i] /= block[i] as usize;
1255             }
1256             return;
1257         }
1258 
1259         let mut usize_block = [0usize; 3];
1260         for i in 0..3 {
1261             usize_block[i] = block[i] as usize;
1262         }
1263 
1264         self.suggest_local_size(d, 3, grid, &mut usize_block);
1265 
1266         for i in 0..3 {
1267             block[i] = usize_block[i] as u32;
1268         }
1269     }
1270 
1271     // The painful part is that host threads are allowed to modify the kernel object after it has been
1272     // enqueued, so we return a closure with all required data included.
1273     pub fn launch(
1274         self: &Arc<Self>,
1275         q: &Arc<Queue>,
1276         work_dim: u32,
1277         block: &[usize],
1278         grid: &[usize],
1279         offsets: &[usize],
1280     ) -> CLResult<EventSig> {
1281         // Clone all the data we need to execute this kernel
1282         let kernel_info = Arc::clone(&self.kernel_info);
1283         let arg_values = self.arg_values().clone();
1284         let nir_kernel_builds = Arc::clone(&self.builds[q.device]);
1285 
1286         let mut buffer_arcs = HashMap::new();
1287         let mut image_arcs = HashMap::new();
1288 
1289         // need to preprocess buffer and image arguments so we hold a strong reference until the
1290         // event has been processed.
1291         for arg in arg_values.iter() {
1292             match arg {
1293                 Some(KernelArgValue::Buffer(buffer)) => {
1294                     buffer_arcs.insert(
1295                         // we use the ptr as the key, and also cast it to usize so we don't need to
1296                         // deal with Send + Sync here.
1297                         buffer.as_ptr() as usize,
1298                         buffer.upgrade().ok_or(CL_INVALID_KERNEL_ARGS)?,
1299                     );
1300                 }
1301                 Some(KernelArgValue::Image(image)) => {
1302                     image_arcs.insert(
1303                         image.as_ptr() as usize,
1304                         image.upgrade().ok_or(CL_INVALID_KERNEL_ARGS)?,
1305                     );
1306                 }
1307                 _ => {}
1308             }
1309         }
1310 
1311         // operations we want to report errors to the clients
1312         let mut block = create_kernel_arr::<u32>(block, 1)?;
1313         let mut grid = create_kernel_arr::<usize>(grid, 1)?;
1314         let offsets = create_kernel_arr::<usize>(offsets, 0)?;
1315 
1316         let api_grid = grid;
1317 
1318         self.optimize_local_size(q.device, &mut grid, &mut block);
1319 
1320         Ok(Box::new(move |q, ctx| {
1321             let hw_max_grid: Vec<usize> = q
1322                 .device
1323                 .max_grid_size()
1324                 .into_iter()
1325                 .map(|val| val.try_into().unwrap_or(usize::MAX))
1326                 // clamped as pipe_launch_grid::grid is only u32
1327                 .map(|val| cmp::min(val, u32::MAX as usize))
1328                 .collect();
1329 
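            // The optimized variant may only be used when the assumptions baked in at compile
            // time hold: no global work offsets, a grid within the hardware limits (so no
            // workgroup offsets are needed) and a block matching the work group size hint.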
1330             let variant = if offsets == [0; 3]
1331                 && grid[0] <= hw_max_grid[0]
1332                 && grid[1] <= hw_max_grid[1]
1333                 && grid[2] <= hw_max_grid[2]
1334                 && (kernel_info.work_group_size_hint == [0; 3]
1335                     || block == kernel_info.work_group_size_hint)
1336             {
1337                 NirKernelVariant::Optimized
1338             } else {
1339                 NirKernelVariant::Default
1340             };
1341 
1342             let nir_kernel_build = &nir_kernel_builds[variant];
1343             let mut workgroup_id_offset_loc = None;
1344             let mut input = Vec::new();
1345             // Set it once so we get the alignment padding right
1346             let static_local_size: u64 = nir_kernel_build.shared_size;
1347             let mut variable_local_size: u64 = static_local_size;
1348             let printf_size = q.device.printf_buffer_size() as u32;
1349             let mut samplers = Vec::new();
1350             let mut iviews = Vec::new();
1351             let mut sviews = Vec::new();
1352             let mut tex_formats: Vec<u16> = Vec::new();
1353             let mut tex_orders: Vec<u16> = Vec::new();
1354             let mut img_formats: Vec<u16> = Vec::new();
1355             let mut img_orders: Vec<u16> = Vec::new();
1356 
1357             let null_ptr;
1358             let null_ptr_v3;
1359             if q.device.address_bits() == 64 {
1360                 null_ptr = [0u8; 8].as_slice();
1361                 null_ptr_v3 = [0u8; 24].as_slice();
1362             } else {
1363                 null_ptr = [0u8; 4].as_slice();
1364                 null_ptr_v3 = [0u8; 12].as_slice();
1365             };
1366 
1367             let mut resource_info = Vec::new();
1368             fn add_global<'a>(
1369                 q: &Queue,
1370                 input: &mut Vec<u8>,
1371                 resource_info: &mut Vec<(&'a PipeResource, usize)>,
1372                 res: &'a PipeResource,
1373                 offset: usize,
1374             ) {
1375                 resource_info.push((res, input.len()));
1376                 if q.device.address_bits() == 64 {
1377                     let offset: u64 = offset as u64;
1378                     input.extend_from_slice(&offset.to_ne_bytes());
1379                 } else {
1380                     let offset: u32 = offset as u32;
1381                     input.extend_from_slice(&offset.to_ne_bytes());
1382                 }
1383             }
1384 
1385             fn add_sysval(q: &Queue, input: &mut Vec<u8>, vals: &[usize; 3]) {
1386                 if q.device.address_bits() == 64 {
1387                     input.extend_from_slice(unsafe { as_byte_slice(&vals.map(|v| v as u64)) });
1388                 } else {
1389                     input.extend_from_slice(unsafe { as_byte_slice(&vals.map(|v| v as u32)) });
1390                 }
1391             }
1392 
1393             let mut printf_buf = None;
1394             if nir_kernel_build.printf_info.is_some() {
1395                 let buf = q
1396                     .device
1397                     .screen
1398                     .resource_create_buffer(printf_size, ResourceType::Staging, PIPE_BIND_GLOBAL)
1399                     .unwrap();
1400 
1401                 let init_data: [u8; 1] = [4];
1402                 ctx.buffer_subdata(&buf, 0, init_data.as_ptr().cast(), init_data.len() as u32);
1403 
1404                 printf_buf = Some(buf);
1405             }
1406 
            for arg in &nir_kernel_build.compiled_args {
                let is_opaque = if let CompiledKernelArgType::APIArg(idx) = arg.kind {
                    kernel_info.args[idx as usize].kind.is_opaque()
                } else {
                    false
                };

                if !is_opaque && arg.offset as usize > input.len() {
                    input.resize(arg.offset as usize, 0);
                }

                match arg.kind {
                    CompiledKernelArgType::APIArg(idx) => {
                        let api_arg = &kernel_info.args[idx as usize];
                        if api_arg.dead {
                            continue;
                        }

                        let Some(value) = &arg_values[idx as usize] else {
                            continue;
                        };

                        match value {
                            KernelArgValue::Constant(c) => input.extend_from_slice(c),
                            KernelArgValue::Buffer(buffer) => {
                                let buffer = &buffer_arcs[&(buffer.as_ptr() as usize)];
                                let rw = if api_arg.spirv.address_qualifier
                                    == clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_CONSTANT
                                {
                                    RWFlags::RD
                                } else {
                                    RWFlags::RW
                                };

                                let res = buffer.get_res_for_access(ctx, rw)?;
                                add_global(q, &mut input, &mut resource_info, res, buffer.offset());
                            }
                            KernelArgValue::Image(image) => {
                                let image = &image_arcs[&(image.as_ptr() as usize)];
                                let (formats, orders) = if api_arg.kind == KernelArgType::Image {
                                    iviews.push(image.image_view(ctx, false)?);
                                    (&mut img_formats, &mut img_orders)
                                } else if api_arg.kind == KernelArgType::RWImage {
                                    iviews.push(image.image_view(ctx, true)?);
                                    (&mut img_formats, &mut img_orders)
                                } else {
                                    sviews.push(image.sampler_view(ctx)?);
                                    (&mut tex_formats, &mut tex_orders)
                                };

                                let binding = arg.offset as usize;
                                assert!(binding >= formats.len());

                                formats.resize(binding, 0);
                                orders.resize(binding, 0);

                                formats.push(image.image_format.image_channel_data_type as u16);
                                orders.push(image.image_format.image_channel_order as u16);
                            }
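                            // Each variably sized __local argument receives the current offset
                            // into the kernel's shared memory area, aligned to the next power of
                            // two of its size (capped at 128 bytes); the running total is then
                            // bumped so later arguments land behind it.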
                            KernelArgValue::LocalMem(size) => {
                                // TODO 32 bit
                                let pot = cmp::min(*size, 0x80);
                                variable_local_size = variable_local_size
                                    .next_multiple_of(pot.next_power_of_two() as u64);
                                if q.device.address_bits() == 64 {
                                    let variable_local_size: [u8; 8] =
                                        variable_local_size.to_ne_bytes();
                                    input.extend_from_slice(&variable_local_size);
                                } else {
                                    let variable_local_size: [u8; 4] =
                                        (variable_local_size as u32).to_ne_bytes();
                                    input.extend_from_slice(&variable_local_size);
                                }
                                variable_local_size += *size as u64;
                            }
                            KernelArgValue::Sampler(sampler) => {
                                samplers.push(sampler.pipe());
                            }
                            KernelArgValue::None => {
                                assert!(
                                    api_arg.kind == KernelArgType::MemGlobal
                                        || api_arg.kind == KernelArgType::MemConstant
                                );
                                input.extend_from_slice(null_ptr);
                            }
                        }
                    }
                    CompiledKernelArgType::ConstantBuffer => {
                        assert!(nir_kernel_build.constant_buffer.is_some());
                        let res = nir_kernel_build.constant_buffer.as_ref().unwrap();
                        add_global(q, &mut input, &mut resource_info, res, 0);
                    }
                    CompiledKernelArgType::GlobalWorkOffsets => {
                        add_sysval(q, &mut input, &offsets);
                    }
                    CompiledKernelArgType::WorkGroupOffsets => {
                        workgroup_id_offset_loc = Some(input.len());
                        input.extend_from_slice(null_ptr_v3);
                    }
                    CompiledKernelArgType::GlobalWorkSize => {
                        add_sysval(q, &mut input, &api_grid);
                    }
                    CompiledKernelArgType::PrintfBuffer => {
                        let res = printf_buf.as_ref().unwrap();
                        add_global(q, &mut input, &mut resource_info, res, 0);
                    }
                    CompiledKernelArgType::InlineSampler(cl) => {
                        samplers.push(Sampler::cl_to_pipe(cl));
                    }
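                    // The format and order arrays mirror the texture/image binding slots filled
                    // in above, exposing each binding's channel data type and channel order to
                    // the kernel.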
                    CompiledKernelArgType::FormatArray => {
                        input.extend_from_slice(unsafe { as_byte_slice(&tex_formats) });
                        input.extend_from_slice(unsafe { as_byte_slice(&img_formats) });
                    }
                    CompiledKernelArgType::OrderArray => {
                        input.extend_from_slice(unsafe { as_byte_slice(&tex_orders) });
                        input.extend_from_slice(unsafe { as_byte_slice(&img_orders) });
                    }
                    CompiledKernelArgType::WorkDim => {
                        input.extend_from_slice(&[work_dim as u8; 1]);
                    }
                    CompiledKernelArgType::NumWorkgroups => {
                        input.extend_from_slice(unsafe {
                            as_byte_slice(&[grid[0] as u32, grid[1] as u32, grid[2] as u32])
                        });
                    }
                }
            }

            // subtract the shader local_size as we only request something on top of that.
            variable_local_size -= static_local_size;

            let samplers: Vec<_> = samplers
                .iter()
                .map(|s| ctx.create_sampler_state(s))
                .collect();

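            // `globals` collects raw pointers to the offset fields inside `input`; the driver
            // patches the actual buffer addresses through set_global_binding below.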
            let mut resources = Vec::with_capacity(resource_info.len());
            let mut globals: Vec<*mut u32> = Vec::with_capacity(resource_info.len());
            for (res, offset) in resource_info {
                resources.push(res);
                globals.push(unsafe { input.as_mut_ptr().byte_add(offset) }.cast());
            }

            let temp_cso;
            let cso = match &nir_kernel_build.nir_or_cso {
                KernelDevStateVariant::Cso(cso) => cso,
                KernelDevStateVariant::Nir(nir) => {
                    temp_cso = CSOWrapper::new(q.device, nir);
                    &temp_cso
                }
            };

            let sviews_len = sviews.len();
            ctx.bind_compute_state(cso.cso_ptr);
            ctx.bind_sampler_states(&samplers);
            ctx.set_sampler_views(sviews);
            ctx.set_shader_images(&iviews);
            ctx.set_global_binding(resources.as_slice(), &mut globals);

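            // Grids larger than what the hardware can launch at once are split into chunks along
            // each dimension; the workgroup-id offset inside the input buffer is rewritten before
            // every partial launch so global IDs stay correct.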
            for z in 0..grid[2].div_ceil(hw_max_grid[2]) {
                for y in 0..grid[1].div_ceil(hw_max_grid[1]) {
                    for x in 0..grid[0].div_ceil(hw_max_grid[0]) {
                        if let Some(workgroup_id_offset_loc) = workgroup_id_offset_loc {
                            let this_offsets =
                                [x * hw_max_grid[0], y * hw_max_grid[1], z * hw_max_grid[2]];

                            if q.device.address_bits() == 64 {
                                let val = this_offsets.map(|v| v as u64);
                                input[workgroup_id_offset_loc..workgroup_id_offset_loc + 24]
                                    .copy_from_slice(unsafe { as_byte_slice(&val) });
                            } else {
                                let val = this_offsets.map(|v| v as u32);
                                input[workgroup_id_offset_loc..workgroup_id_offset_loc + 12]
                                    .copy_from_slice(unsafe { as_byte_slice(&val) });
                            }
                        }

                        let this_grid = [
                            cmp::min(hw_max_grid[0], grid[0] - hw_max_grid[0] * x) as u32,
                            cmp::min(hw_max_grid[1], grid[1] - hw_max_grid[1] * y) as u32,
                            cmp::min(hw_max_grid[2], grid[2] - hw_max_grid[2] * z) as u32,
                        ];

                        ctx.update_cb0(&input)?;
                        ctx.launch_grid(work_dim, block, this_grid, variable_local_size as u32);

                        if Platform::dbg().sync_every_event {
                            ctx.flush().wait();
                        }
                    }
                }
            }

            ctx.clear_global_binding(globals.len() as u32);
            ctx.clear_shader_images(iviews.len() as u32);
            ctx.clear_sampler_views(sviews_len as u32);
            ctx.clear_sampler_states(samplers.len() as u32);

            ctx.bind_compute_state(ptr::null_mut());

            ctx.memory_barrier(PIPE_BARRIER_GLOBAL_BUFFER);

            samplers.iter().for_each(|s| ctx.delete_sampler_state(*s));

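            // Read the printf buffer back: the first u32 is the total number of bytes written
            // (counter included), so the payload spans the following `length - 4` bytes and is
            // formatted via the stored printf info.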
            if let Some(printf_buf) = &printf_buf {
                let tx = ctx
                    .buffer_map(printf_buf, 0, printf_size as i32, RWFlags::RD)
                    .ok_or(CL_OUT_OF_RESOURCES)?;
                let mut buf: &[u8] =
                    unsafe { slice::from_raw_parts(tx.ptr().cast(), printf_size as usize) };
                let length = u32::from_ne_bytes(*extract(&mut buf));

                // update our slice to make sure we don't go out of bounds
                buf = &buf[0..(length - 4) as usize];
                if let Some(pf) = &nir_kernel_build.printf_info {
                    pf.u_printf(buf)
                }
            }

            Ok(())
        }))
    }

    pub fn arg_values(&self) -> MutexGuard<Vec<Option<KernelArgValue>>> {
        self.values.lock().unwrap()
    }

    pub fn set_kernel_arg(&self, idx: usize, arg: KernelArgValue) -> CLResult<()> {
        self.values
            .lock()
            .unwrap()
            .get_mut(idx)
            .ok_or(CL_INVALID_ARG_INDEX)?
            .replace(arg);
        Ok(())
    }

    pub fn access_qualifier(&self, idx: cl_uint) -> cl_kernel_arg_access_qualifier {
        let aq = self.kernel_info.args[idx as usize].spirv.access_qualifier;

        if aq
            == clc_kernel_arg_access_qualifier::CLC_KERNEL_ARG_ACCESS_READ
                | clc_kernel_arg_access_qualifier::CLC_KERNEL_ARG_ACCESS_WRITE
        {
            CL_KERNEL_ARG_ACCESS_READ_WRITE
        } else if aq == clc_kernel_arg_access_qualifier::CLC_KERNEL_ARG_ACCESS_READ {
            CL_KERNEL_ARG_ACCESS_READ_ONLY
        } else if aq == clc_kernel_arg_access_qualifier::CLC_KERNEL_ARG_ACCESS_WRITE {
            CL_KERNEL_ARG_ACCESS_WRITE_ONLY
        } else {
            CL_KERNEL_ARG_ACCESS_NONE
        }
    }

    pub fn address_qualifier(&self, idx: cl_uint) -> cl_kernel_arg_address_qualifier {
        match self.kernel_info.args[idx as usize].spirv.address_qualifier {
            clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_PRIVATE => {
                CL_KERNEL_ARG_ADDRESS_PRIVATE
            }
            clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_CONSTANT => {
                CL_KERNEL_ARG_ADDRESS_CONSTANT
            }
            clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_LOCAL => {
                CL_KERNEL_ARG_ADDRESS_LOCAL
            }
            clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_GLOBAL => {
                CL_KERNEL_ARG_ADDRESS_GLOBAL
            }
        }
    }

    pub fn type_qualifier(&self, idx: cl_uint) -> cl_kernel_arg_type_qualifier {
        let tq = self.kernel_info.args[idx as usize].spirv.type_qualifier;
        let zero = clc_kernel_arg_type_qualifier(0);
        let mut res = CL_KERNEL_ARG_TYPE_NONE;

        if tq & clc_kernel_arg_type_qualifier::CLC_KERNEL_ARG_TYPE_CONST != zero {
            res |= CL_KERNEL_ARG_TYPE_CONST;
        }

        if tq & clc_kernel_arg_type_qualifier::CLC_KERNEL_ARG_TYPE_RESTRICT != zero {
            res |= CL_KERNEL_ARG_TYPE_RESTRICT;
        }

        if tq & clc_kernel_arg_type_qualifier::CLC_KERNEL_ARG_TYPE_VOLATILE != zero {
            res |= CL_KERNEL_ARG_TYPE_VOLATILE;
        }

        res.into()
    }

    pub fn work_group_size(&self) -> [usize; 3] {
        self.kernel_info.work_group_size
    }

    pub fn num_subgroups(&self) -> usize {
        self.kernel_info.num_subgroups
    }

    pub fn subgroup_size(&self) -> usize {
        self.kernel_info.subgroup_size
    }

    pub fn arg_name(&self, idx: cl_uint) -> Option<&CStr> {
        let name = &self.kernel_info.args[idx as usize].spirv.name;
        name.is_empty().not().then_some(name)
    }

    pub fn arg_type_name(&self, idx: cl_uint) -> Option<&CStr> {
        let type_name = &self.kernel_info.args[idx as usize].spirv.type_name;
        type_name.is_empty().not().then_some(type_name)
    }

    pub fn priv_mem_size(&self, dev: &Device) -> cl_ulong {
        self.builds.get(dev).unwrap().info.private_memory as cl_ulong
    }

    pub fn max_threads_per_block(&self, dev: &Device) -> usize {
        self.builds.get(dev).unwrap().info.max_threads as usize
    }

    pub fn preferred_simd_size(&self, dev: &Device) -> usize {
        self.builds.get(dev).unwrap().info.preferred_simd_size as usize
    }

    pub fn local_mem_size(&self, dev: &Device) -> cl_ulong {
        // TODO: take alignment into account?
        // this is purely informational so it shouldn't even matter
        let local =
            self.builds.get(dev).unwrap()[NirKernelVariant::Default].shared_size as cl_ulong;
        let args: cl_ulong = self
            .arg_values()
            .iter()
            .map(|arg| match arg {
                Some(KernelArgValue::LocalMem(val)) => *val as cl_ulong,
                // If the local memory size, for any pointer argument to the kernel declared with
                // the __local address qualifier, is not specified, its size is assumed to be 0.
                _ => 0,
            })
            .sum();

        local + args
    }

    pub fn has_svm_devs(&self) -> bool {
        self.prog.devs.iter().any(|dev| dev.svm_supported())
    }

    pub fn subgroup_sizes(&self, dev: &Device) -> Vec<usize> {
        SetBitIndices::from_msb(self.builds.get(dev).unwrap().info.simd_sizes)
            .map(|bit| 1 << bit)
            .collect()
    }

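    // Number of subgroups needed for a block: total threads divided by the subgroup size,
    // rounded up. E.g. a block of [32, 4, 1] with a subgroup size of 32 covers 128 threads and
    // therefore needs 128 / 32 = 4 subgroups.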
    pub fn subgroups_for_block(&self, dev: &Device, block: &[usize]) -> usize {
        let subgroup_size = self.subgroup_size_for_block(dev, block);
        if subgroup_size == 0 {
            return 0;
        }

        let threads: usize = block.iter().product();
        threads.div_ceil(subgroup_size)
    }

    pub fn subgroup_size_for_block(&self, dev: &Device, block: &[usize]) -> usize {
        let subgroup_sizes = self.subgroup_sizes(dev);
        if subgroup_sizes.is_empty() {
            return 0;
        }

        if subgroup_sizes.len() == 1 {
            return subgroup_sizes[0];
        }

        let block = [
            *block.first().unwrap_or(&1) as u32,
            *block.get(1).unwrap_or(&1) as u32,
            *block.get(2).unwrap_or(&1) as u32,
        ];

        // TODO: this _might_ bite us somewhere, but I think it probably doesn't matter
        match &self.builds.get(dev).unwrap()[NirKernelVariant::Default].nir_or_cso {
            KernelDevStateVariant::Cso(cso) => {
                dev.helper_ctx()
                    .compute_state_subgroup_size(cso.cso_ptr, &block) as usize
            }
            _ => {
                panic!()
            }
        }
    }
}

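// Cloning a kernel gives the copy its own CL object identity and snapshots the currently set
// argument values under a fresh Mutex, matching clCloneKernel's copy-the-arguments behavior.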
impl Clone for Kernel {
    fn clone(&self) -> Self {
        Self {
            base: CLObjectBase::new(RusticlTypes::Kernel),
            prog: self.prog.clone(),
            name: self.name.clone(),
            values: Mutex::new(self.arg_values().clone()),
            builds: self.builds.clone(),
            kernel_info: self.kernel_info.clone(),
        }
    }
}