1 use crate::api::icd::*;
2 use crate::core::device::*;
3 use crate::core::event::*;
4 use crate::core::memory::*;
5 use crate::core::platform::*;
6 use crate::core::program::*;
7 use crate::core::queue::*;
8 use crate::impl_cl_type_trait;
9 
10 use mesa_rust::compiler::clc::*;
11 use mesa_rust::compiler::nir::*;
12 use mesa_rust::nir_pass;
13 use mesa_rust::pipe::context::RWFlags;
14 use mesa_rust::pipe::resource::*;
15 use mesa_rust::pipe::screen::ResourceType;
16 use mesa_rust_gen::*;
17 use mesa_rust_util::math::*;
18 use mesa_rust_util::serialize::*;
19 use rusticl_opencl_gen::*;
20 use spirv::SpirvKernelInfo;
21 
22 use std::cmp;
23 use std::collections::HashMap;
24 use std::convert::TryInto;
25 use std::ffi::CStr;
26 use std::fmt::Debug;
27 use std::fmt::Display;
28 use std::ops::Index;
29 use std::ops::Not;
30 use std::os::raw::c_void;
31 use std::ptr;
32 use std::slice;
33 use std::sync::Arc;
34 use std::sync::Mutex;
35 use std::sync::MutexGuard;
36 use std::sync::Weak;
37 
38 // According to the CL spec we are not allowed to let any cl_kernel object hold references to
39 // its arguments, as this might make it infeasible for applications to free the backing memory of
40 // memory objects allocated with `CL_USE_HOST_PTR`.
41 //
42 // However those arguments might temporarily get referenced by event objects, so we'll use Weak in
43 // order to upgrade the reference when needed. It's also safer to use Weak over raw pointers,
44 // because it makes it impossible to run into use-after-free issues.
45 //
46 // Technically we also need to do it for samplers, but there it's kinda pointless to take a weak
47 // reference, as samplers don't have host_ptr or similar lifetime problems like cl_mem objects do.
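//
// At launch time these Weak references are upgraded again (see `Kernel::launch` below), so a
// dangling argument surfaces as CL_INVALID_KERNEL_ARGS instead of a use-after-free.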
48 #[derive(Clone)]
49 pub enum KernelArgValue {
50     None,
51     Buffer(Weak<Buffer>),
52     Constant(Vec<u8>),
53     Image(Weak<Image>),
54     LocalMem(usize),
55     Sampler(Arc<Sampler>),
56 }
57 
58 #[repr(u8)]
59 #[derive(Hash, PartialEq, Eq, Clone, Copy)]
60 pub enum KernelArgType {
61     Constant(/* size */ u16), // for anything passed by value
62     Image,
63     RWImage,
64     Sampler,
65     Texture,
66     MemGlobal,
67     MemConstant,
68     MemLocal,
69 }
70 
71 impl KernelArgType {
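    // The numeric tags written and read below must stay in sync between serialize() and deserialize().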
72     fn deserialize(blob: &mut blob_reader) -> Option<Self> {
73         // SAFETY: we get 0 on an overrun, but we verify that later and act accordingly.
74         let res = match unsafe { blob_read_uint8(blob) } {
75             0 => {
76                 // SAFETY: same here
77                 let size = unsafe { blob_read_uint16(blob) };
78                 KernelArgType::Constant(size)
79             }
80             1 => KernelArgType::Image,
81             2 => KernelArgType::RWImage,
82             3 => KernelArgType::Sampler,
83             4 => KernelArgType::Texture,
84             5 => KernelArgType::MemGlobal,
85             6 => KernelArgType::MemConstant,
86             7 => KernelArgType::MemLocal,
87             _ => return None,
88         };
89 
90         blob.overrun.not().then_some(res)
91     }
92 
93     fn serialize(&self, blob: &mut blob) {
94         unsafe {
95             match self {
96                 KernelArgType::Constant(size) => {
97                     blob_write_uint8(blob, 0);
98                     blob_write_uint16(blob, *size)
99                 }
100                 KernelArgType::Image => blob_write_uint8(blob, 1),
101                 KernelArgType::RWImage => blob_write_uint8(blob, 2),
102                 KernelArgType::Sampler => blob_write_uint8(blob, 3),
103                 KernelArgType::Texture => blob_write_uint8(blob, 4),
104                 KernelArgType::MemGlobal => blob_write_uint8(blob, 5),
105                 KernelArgType::MemConstant => blob_write_uint8(blob, 6),
106                 KernelArgType::MemLocal => blob_write_uint8(blob, 7),
107             };
108         }
109     }
110 
111     fn is_opaque(&self) -> bool {
112         matches!(
113             self,
114             KernelArgType::Image
115                 | KernelArgType::RWImage
116                 | KernelArgType::Texture
117                 | KernelArgType::Sampler
118         )
119     }
120 }
121 
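/// Inputs of the compiled NIR: the application-visible arguments (`APIArg`, referenced by index)
/// plus internal arguments (constant buffer, printf buffer, work-group data, inline samplers and
/// the image format/order arrays) that get appended during compilation.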
122 #[derive(Hash, PartialEq, Eq, Clone)]
123 enum CompiledKernelArgType {
124     APIArg(u32),
125     ConstantBuffer,
126     GlobalWorkOffsets,
127     GlobalWorkSize,
128     PrintfBuffer,
129     InlineSampler((cl_addressing_mode, cl_filter_mode, bool)),
130     FormatArray,
131     OrderArray,
132     WorkDim,
133     WorkGroupOffsets,
134     NumWorkgroups,
135 }
136 
137 #[derive(Hash, PartialEq, Eq, Clone)]
138 pub struct KernelArg {
139     spirv: spirv::SPIRVKernelArg,
140     pub kind: KernelArgType,
141     pub dead: bool,
142 }
143 
144 impl KernelArg {
145     fn from_spirv_nir(spirv: &[spirv::SPIRVKernelArg], nir: &mut NirShader) -> Vec<Self> {
146         let nir_arg_map: HashMap<_, _> = nir
147             .variables_with_mode(
148                 nir_variable_mode::nir_var_uniform | nir_variable_mode::nir_var_image,
149             )
150             .map(|v| (v.data.location, v))
151             .collect();
152         let mut res = Vec::new();
153 
154         for (i, s) in spirv.iter().enumerate() {
155             let nir = nir_arg_map.get(&(i as i32)).unwrap();
156             let kind = match s.address_qualifier {
157                 clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_PRIVATE => {
158                     if unsafe { glsl_type_is_sampler(nir.type_) } {
159                         KernelArgType::Sampler
160                     } else {
161                         let size = unsafe { glsl_get_cl_size(nir.type_) } as u16;
162                         // nir types of non-opaque arguments are never zero-sized
163                         KernelArgType::Constant(size)
164                     }
165                 }
166                 clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_CONSTANT => {
167                     KernelArgType::MemConstant
168                 }
169                 clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_LOCAL => {
170                     KernelArgType::MemLocal
171                 }
172                 clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_GLOBAL => {
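                    // Read-only images are bound as textures (sampler views), write-only ones as
                    // (writeable) images, anything else as read-write images.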
173                     if unsafe { glsl_type_is_image(nir.type_) } {
174                         let access = nir.data.access();
175                         if access == gl_access_qualifier::ACCESS_NON_WRITEABLE.0 {
176                             KernelArgType::Texture
177                         } else if access == gl_access_qualifier::ACCESS_NON_READABLE.0 {
178                             KernelArgType::Image
179                         } else {
180                             KernelArgType::RWImage
181                         }
182                     } else {
183                         KernelArgType::MemGlobal
184                     }
185                 }
186             };
187 
188             res.push(Self {
189                 spirv: s.clone(),
190                 // we'll update it later in the 2nd pass
191                 kind: kind,
192                 dead: true,
193             });
194         }
195         res
196     }
197 
198     fn serialize(args: &[Self], blob: &mut blob) {
199         unsafe {
200             blob_write_uint16(blob, args.len() as u16);
201 
202             for arg in args {
203                 arg.spirv.serialize(blob);
204                 blob_write_uint8(blob, arg.dead.into());
205                 arg.kind.serialize(blob);
206             }
207         }
208     }
209 
210     fn deserialize(blob: &mut blob_reader) -> Option<Vec<Self>> {
211         // SAFETY: we check the overrun status, blob_read returns 0 in such a case.
212         let len = unsafe { blob_read_uint16(blob) } as usize;
213         let mut res = Vec::with_capacity(len);
214 
215         for _ in 0..len {
216             let spirv = spirv::SPIRVKernelArg::deserialize(blob)?;
217             // SAFETY: we check the overrun status
218             let dead = unsafe { blob_read_uint8(blob) } != 0;
219             let kind = KernelArgType::deserialize(blob)?;
220 
221             res.push(Self {
222                 spirv: spirv,
223                 kind: kind,
224                 dead: dead,
225             });
226         }
227 
228         blob.overrun.not().then_some(res)
229     }
230 }
231 
232 #[derive(Hash, PartialEq, Eq, Clone)]
233 struct CompiledKernelArg {
234     kind: CompiledKernelArgType,
235     /// The binding for image/sampler args; the offset into the input buffer
236     /// for anything else.
237     offset: u32,
238     dead: bool,
239 }
240 
241 impl CompiledKernelArg {
242     fn assign_locations(compiled_args: &mut [Self], nir: &mut NirShader) {
243         for var in nir.variables_with_mode(
244             nir_variable_mode::nir_var_uniform | nir_variable_mode::nir_var_image,
245         ) {
246             let arg = &mut compiled_args[var.data.location as usize];
247             let t = var.type_;
248 
249             arg.dead = false;
250             arg.offset = if unsafe {
251                 glsl_type_is_image(t) || glsl_type_is_texture(t) || glsl_type_is_sampler(t)
252             } {
253                 var.data.binding
254             } else {
255                 var.data.driver_location
256             };
257         }
258     }
259 
260     fn serialize(args: &[Self], blob: &mut blob) {
261         unsafe {
262             blob_write_uint16(blob, args.len() as u16);
263             for arg in args {
264                 blob_write_uint32(blob, arg.offset);
265                 blob_write_uint8(blob, arg.dead.into());
266                 match arg.kind {
267                     CompiledKernelArgType::ConstantBuffer => blob_write_uint8(blob, 0),
268                     CompiledKernelArgType::GlobalWorkOffsets => blob_write_uint8(blob, 1),
269                     CompiledKernelArgType::PrintfBuffer => blob_write_uint8(blob, 2),
270                     CompiledKernelArgType::InlineSampler((addr_mode, filter_mode, norm)) => {
271                         blob_write_uint8(blob, 3);
272                         blob_write_uint8(blob, norm.into());
273                         blob_write_uint32(blob, addr_mode);
274                         blob_write_uint32(blob, filter_mode)
275                     }
276                     CompiledKernelArgType::FormatArray => blob_write_uint8(blob, 4),
277                     CompiledKernelArgType::OrderArray => blob_write_uint8(blob, 5),
278                     CompiledKernelArgType::WorkDim => blob_write_uint8(blob, 6),
279                     CompiledKernelArgType::WorkGroupOffsets => blob_write_uint8(blob, 7),
280                     CompiledKernelArgType::NumWorkgroups => blob_write_uint8(blob, 8),
281                     CompiledKernelArgType::GlobalWorkSize => blob_write_uint8(blob, 9),
282                     CompiledKernelArgType::APIArg(idx) => {
283                         blob_write_uint8(blob, 10);
284                         blob_write_uint32(blob, idx)
285                     }
286                 };
287             }
288         }
289     }
290 
291     fn deserialize(blob: &mut blob_reader) -> Option<Vec<Self>> {
292         unsafe {
293             let len = blob_read_uint16(blob) as usize;
294             let mut res = Vec::with_capacity(len);
295 
296             for _ in 0..len {
297                 let offset = blob_read_uint32(blob);
298                 let dead = blob_read_uint8(blob) != 0;
299 
300                 let kind = match blob_read_uint8(blob) {
301                     0 => CompiledKernelArgType::ConstantBuffer,
302                     1 => CompiledKernelArgType::GlobalWorkOffsets,
303                     2 => CompiledKernelArgType::PrintfBuffer,
304                     3 => {
305                         let norm = blob_read_uint8(blob) != 0;
306                         let addr_mode = blob_read_uint32(blob);
307                         let filter_mode = blob_read_uint32(blob);
308                         CompiledKernelArgType::InlineSampler((addr_mode, filter_mode, norm))
309                     }
310                     4 => CompiledKernelArgType::FormatArray,
311                     5 => CompiledKernelArgType::OrderArray,
312                     6 => CompiledKernelArgType::WorkDim,
313                     7 => CompiledKernelArgType::WorkGroupOffsets,
314                     8 => CompiledKernelArgType::NumWorkgroups,
315                     9 => CompiledKernelArgType::GlobalWorkSize,
316                     10 => {
317                         let idx = blob_read_uint32(blob);
318                         CompiledKernelArgType::APIArg(idx)
319                     }
320                     _ => return None,
321                 };
322 
323                 res.push(Self {
324                     kind: kind,
325                     offset: offset,
326                     dead: dead,
327                 });
328             }
329 
330             Some(res)
331         }
332     }
333 }
334 
335 #[derive(Clone, PartialEq, Eq, Hash)]
336 pub struct KernelInfo {
337     pub args: Vec<KernelArg>,
338     pub attributes_string: String,
339     work_group_size: [usize; 3],
340     work_group_size_hint: [u32; 3],
341     subgroup_size: usize,
342     num_subgroups: usize,
343 }
344 
345 struct CSOWrapper {
346     cso_ptr: *mut c_void,
347     dev: &'static Device,
348 }
349 
350 impl CSOWrapper {
351     fn new(dev: &'static Device, nir: &NirShader) -> Self {
352         let cso_ptr = dev
353             .helper_ctx()
354             .create_compute_state(nir, nir.shared_size());
355 
356         Self {
357             cso_ptr: cso_ptr,
358             dev: dev,
359         }
360     }
361 
362     fn get_cso_info(&self) -> pipe_compute_state_object_info {
363         self.dev.helper_ctx().compute_state_info(self.cso_ptr)
364     }
365 }
366 
367 impl Drop for CSOWrapper {
368     fn drop(&mut self) {
369         self.dev.helper_ctx().delete_compute_state(self.cso_ptr);
370     }
371 }
372 
373 enum KernelDevStateVariant {
374     Cso(CSOWrapper),
375     Nir(NirShader),
376 }
377 
378 #[derive(Debug, PartialEq)]
379 enum NirKernelVariant {
380     /// Can be used under any circumstance.
381     Default,
382 
383     /// Optimized variant making the following assumptions:
384     ///  - global_id_offsets are 0
385     ///  - workgroup_offsets are 0
386     ///  - local_size is info.local_size_hint
387     Optimized,
388 }
389 
390 impl Display for NirKernelVariant {
391     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
392         // this simply prints the enum name, so that's fine
393         Debug::fmt(self, f)
394     }
395 }
396 
397 pub struct NirKernelBuilds {
398     default_build: NirKernelBuild,
399     optimized: Option<NirKernelBuild>,
400     /// merged info with worst case values
401     info: pipe_compute_state_object_info,
402 }
403 
404 impl Index<NirKernelVariant> for NirKernelBuilds {
405     type Output = NirKernelBuild;
406 
407     fn index(&self, index: NirKernelVariant) -> &Self::Output {
408         match index {
409             NirKernelVariant::Default => &self.default_build,
410             NirKernelVariant::Optimized => self.optimized.as_ref().unwrap_or(&self.default_build),
411         }
412     }
413 }
414 
415 impl NirKernelBuilds {
416     fn new(default_build: NirKernelBuild, optimized: Option<NirKernelBuild>) -> Self {
417         let mut info = default_build.info;
418         if let Some(build) = &optimized {
419             info.max_threads = cmp::min(info.max_threads, build.info.max_threads);
420             info.simd_sizes &= build.info.simd_sizes;
421             info.private_memory = cmp::max(info.private_memory, build.info.private_memory);
422             info.preferred_simd_size =
423                 cmp::max(info.preferred_simd_size, build.info.preferred_simd_size);
424         }
425 
426         Self {
427             default_build: default_build,
428             optimized: optimized,
429             info: info,
430         }
431     }
432 }
433 
434 struct NirKernelBuild {
435     nir_or_cso: KernelDevStateVariant,
436     constant_buffer: Option<Arc<PipeResource>>,
437     info: pipe_compute_state_object_info,
438     shared_size: u64,
439     printf_info: Option<NirPrintfInfo>,
440     compiled_args: Vec<CompiledKernelArg>,
441 }
442 
443 // SAFETY: `CSOWrapper` is only safe to use if the device supports `pipe_caps.shareable_shaders` and
444 //         we make sure to set `nir_or_cso` to `KernelDevStateVariant::Cso` only if that's the case.
445 unsafe impl Send for NirKernelBuild {}
446 unsafe impl Sync for NirKernelBuild {}
447 
448 impl NirKernelBuild {
449     fn new(dev: &'static Device, mut out: CompilationResult) -> Self {
450         let cso = CSOWrapper::new(dev, &out.nir);
451         let info = cso.get_cso_info();
452         let cb = Self::create_nir_constant_buffer(dev, &out.nir);
453         let shared_size = out.nir.shared_size() as u64;
454         let printf_info = out.nir.take_printf_info();
455 
456         let nir_or_cso = if !dev.shareable_shaders() {
457             KernelDevStateVariant::Nir(out.nir)
458         } else {
459             KernelDevStateVariant::Cso(cso)
460         };
461 
462         NirKernelBuild {
463             nir_or_cso: nir_or_cso,
464             constant_buffer: cb,
465             info: info,
466             shared_size: shared_size,
467             printf_info: printf_info,
468             compiled_args: out.compiled_args,
469         }
470     }
471 
472     fn create_nir_constant_buffer(dev: &Device, nir: &NirShader) -> Option<Arc<PipeResource>> {
473         let buf = nir.get_constant_buffer();
474         let len = buf.len() as u32;
475 
476         if len > 0 {
477             // TODO bind as constant buffer
478             let res = dev
479                 .screen()
480                 .resource_create_buffer(len, ResourceType::Normal, PIPE_BIND_GLOBAL)
481                 .unwrap();
482 
483             dev.helper_ctx()
484                 .exec(|ctx| ctx.buffer_subdata(&res, 0, buf.as_ptr().cast(), len))
485                 .wait();
486 
487             Some(Arc::new(res))
488         } else {
489             None
490         }
491     }
492 }
493 
494 pub struct Kernel {
495     pub base: CLObjectBase<CL_INVALID_KERNEL>,
496     pub prog: Arc<Program>,
497     pub name: String,
498     values: Mutex<Vec<Option<KernelArgValue>>>,
499     builds: HashMap<&'static Device, Arc<NirKernelBuilds>>,
500     pub kernel_info: Arc<KernelInfo>,
501 }
502 
503 impl_cl_type_trait!(cl_kernel, Kernel, CL_INVALID_KERNEL);
504 
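/// Expands `vals` into a fixed-size 3-element array, filling unused dimensions with `val`.
/// Fails with CL_OUT_OF_RESOURCES if a value doesn't fit into `T`.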
505 fn create_kernel_arr<T>(vals: &[usize], val: T) -> CLResult<[T; 3]>
506 where
507     T: std::convert::TryFrom<usize> + Copy,
508     <T as std::convert::TryFrom<usize>>::Error: std::fmt::Debug,
509 {
510     let mut res = [val; 3];
511     for (i, v) in vals.iter().enumerate() {
512         res[i] = (*v).try_into().ok().ok_or(CL_OUT_OF_RESOURCES)?;
513     }
514 
515     Ok(res)
516 }
517 
518 #[derive(Clone)]
519 struct CompilationResult {
520     nir: NirShader,
521     compiled_args: Vec<CompiledKernelArg>,
522 }
523 
524 impl CompilationResult {
525     fn deserialize(reader: &mut blob_reader, d: &Device) -> Option<Self> {
526         let nir = NirShader::deserialize(
527             reader,
528             d.screen()
529                 .nir_shader_compiler_options(pipe_shader_type::PIPE_SHADER_COMPUTE),
530         )?;
531         let compiled_args = CompiledKernelArg::deserialize(reader)?;
532 
533         Some(Self {
534             nir: nir,
535             compiled_args,
536         })
537     }
538 
539     fn serialize(&self, blob: &mut blob) {
540         self.nir.serialize(blob);
541         CompiledKernelArg::serialize(&self.compiled_args, blob);
542     }
543 }
544 
545 fn opt_nir(nir: &mut NirShader, dev: &Device, has_explicit_types: bool) {
546     let nir_options = unsafe {
547         &*dev
548             .screen
549             .nir_shader_compiler_options(pipe_shader_type::PIPE_SHADER_COMPUTE)
550     };
551 
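    // `while { ... } {}` is a do-while loop: rerun the pass pipeline until no pass reports progress.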
552     while {
553         let mut progress = false;
554 
555         progress |= nir_pass!(nir, nir_copy_prop);
556         progress |= nir_pass!(nir, nir_opt_copy_prop_vars);
557         progress |= nir_pass!(nir, nir_opt_dead_write_vars);
558 
559         if nir_options.lower_to_scalar {
560             nir_pass!(
561                 nir,
562                 nir_lower_alu_to_scalar,
563                 nir_options.lower_to_scalar_filter,
564                 ptr::null(),
565             );
566             nir_pass!(nir, nir_lower_phis_to_scalar, false);
567         }
568 
569         progress |= nir_pass!(nir, nir_opt_deref);
570         if has_explicit_types {
571             progress |= nir_pass!(nir, nir_opt_memcpy);
572         }
573         progress |= nir_pass!(nir, nir_opt_dce);
574         progress |= nir_pass!(nir, nir_opt_undef);
575         progress |= nir_pass!(nir, nir_opt_constant_folding);
576         progress |= nir_pass!(nir, nir_opt_cse);
577         nir_pass!(nir, nir_split_var_copies);
578         progress |= nir_pass!(nir, nir_lower_var_copies);
579         progress |= nir_pass!(nir, nir_lower_vars_to_ssa);
580         nir_pass!(nir, nir_lower_alu);
581         progress |= nir_pass!(nir, nir_opt_phi_precision);
582         progress |= nir_pass!(nir, nir_opt_algebraic);
583         progress |= nir_pass!(
584             nir,
585             nir_opt_if,
586             nir_opt_if_options::nir_opt_if_optimize_phi_true_false,
587         );
588         progress |= nir_pass!(nir, nir_opt_dead_cf);
589         progress |= nir_pass!(nir, nir_opt_remove_phis);
590         // we don't want to be too aggressive here, but it kills a bit of CFG
591         progress |= nir_pass!(nir, nir_opt_peephole_select, 8, true, true);
592         progress |= nir_pass!(
593             nir,
594             nir_lower_vec3_to_vec4,
595             nir_variable_mode::nir_var_mem_generic | nir_variable_mode::nir_var_uniform,
596         );
597 
598         if nir_options.max_unroll_iterations != 0 {
599             progress |= nir_pass!(nir, nir_opt_loop_unroll);
600         }
601         nir.sweep_mem();
602         progress
603     } {}
604 }
605 
606 /// # Safety
607 ///
608 /// Only safe to call when `var` is a valid pointer to a valid [`nir_variable`]
609 unsafe extern "C" fn can_remove_var(var: *mut nir_variable, _: *mut c_void) -> bool {
610     // SAFETY: It is the caller's responsibility to provide a valid and aligned pointer
611     let var_type = unsafe { (*var).type_ };
612     // SAFETY: `nir_variable`'s type invariant guarantees that the `type_` field is valid and
613     // properly aligned.
614     unsafe {
615         !glsl_type_is_image(var_type)
616             && !glsl_type_is_texture(var_type)
617             && !glsl_type_is_sampler(var_type)
618     }
619 }
620 
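// Options for nir_remove_dead_variables: image, texture and sampler variables are never removed
// (can_remove_var above returns false for them).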
621 const DV_OPTS: nir_remove_dead_variables_options = nir_remove_dead_variables_options {
622     can_remove_var: Some(can_remove_var),
623     can_remove_var_data: ptr::null_mut(),
624 };
625 
626 fn compile_nir_to_args(
627     dev: &Device,
628     mut nir: NirShader,
629     args: &[spirv::SPIRVKernelArg],
630     lib_clc: &NirShader,
631 ) -> (Vec<KernelArg>, NirShader) {
632     // this is a hack until we support fp16 properly and check for denorms inside vstore/vload_half
633     nir.preserve_fp16_denorms();
634 
635     // Set to rtne for now until drivers are able to report their preferred rounding mode; this also
636     // matches what we report via the API.
637     nir.set_fp_rounding_mode_rtne();
638 
639     nir_pass!(nir, nir_scale_fdiv);
640     nir.set_workgroup_size_variable_if_zero();
641     nir.structurize();
642     nir_pass!(
643         nir,
644         nir_lower_variable_initializers,
645         nir_variable_mode::nir_var_function_temp
646     );
647 
648     while {
649         let mut progress = false;
650         nir_pass!(nir, nir_split_var_copies);
651         progress |= nir_pass!(nir, nir_copy_prop);
652         progress |= nir_pass!(nir, nir_opt_copy_prop_vars);
653         progress |= nir_pass!(nir, nir_opt_dead_write_vars);
654         progress |= nir_pass!(nir, nir_opt_deref);
655         progress |= nir_pass!(nir, nir_opt_dce);
656         progress |= nir_pass!(nir, nir_opt_undef);
657         progress |= nir_pass!(nir, nir_opt_constant_folding);
658         progress |= nir_pass!(nir, nir_opt_cse);
659         progress |= nir_pass!(nir, nir_lower_vars_to_ssa);
660         progress |= nir_pass!(nir, nir_opt_algebraic);
661         progress
662     } {}
663     nir.inline(lib_clc);
664     nir.cleanup_functions();
665     // that should free up tons of memory
666     nir.sweep_mem();
667 
668     nir_pass!(nir, nir_dedup_inline_samplers);
669 
670     let printf_opts = nir_lower_printf_options {
671         ptr_bit_size: 0,
672         use_printf_base_identifier: false,
673         hash_format_strings: false,
674         max_buffer_size: dev.printf_buffer_size() as u32,
675     };
676     nir_pass!(nir, nir_lower_printf, &printf_opts);
677 
678     opt_nir(&mut nir, dev, false);
679 
680     (KernelArg::from_spirv_nir(args, &mut nir), nir)
681 }
682 
683 fn compile_nir_prepare_for_variants(
684     dev: &Device,
685     nir: &mut NirShader,
686     compiled_args: &mut Vec<CompiledKernelArg>,
687 ) {
688     // assign locations for inline samplers.
689     // IMPORTANT: this needs to happen before nir_remove_dead_variables.
690     let mut last_loc = -1;
691     for v in nir
692         .variables_with_mode(nir_variable_mode::nir_var_uniform | nir_variable_mode::nir_var_image)
693     {
694         if unsafe { !glsl_type_is_sampler(v.type_) } {
695             last_loc = v.data.location;
696             continue;
697         }
698         let s = unsafe { v.data.anon_1.sampler };
699         if s.is_inline_sampler() != 0 {
700             last_loc += 1;
701             v.data.location = last_loc;
702 
703             compiled_args.push(CompiledKernelArg {
704                 kind: CompiledKernelArgType::InlineSampler(Sampler::nir_to_cl(
705                     s.addressing_mode(),
706                     s.filter_mode(),
707                     s.normalized_coordinates(),
708                 )),
709                 offset: 0,
710                 dead: true,
711             });
712         } else {
713             last_loc = v.data.location;
714         }
715     }
716 
717     nir_pass!(
718         nir,
719         nir_remove_dead_variables,
720         nir_variable_mode::nir_var_uniform
721             | nir_variable_mode::nir_var_image
722             | nir_variable_mode::nir_var_mem_constant
723             | nir_variable_mode::nir_var_mem_shared
724             | nir_variable_mode::nir_var_function_temp,
725         &DV_OPTS,
726     );
727 
728     nir_pass!(nir, nir_lower_readonly_images_to_tex, true);
729     nir_pass!(
730         nir,
731         nir_lower_cl_images,
732         !dev.images_as_deref(),
733         !dev.samplers_as_deref(),
734     );
735 
736     nir_pass!(
737         nir,
738         nir_lower_vars_to_explicit_types,
739         nir_variable_mode::nir_var_mem_constant,
740         Some(glsl_get_cl_type_size_align),
741     );
742 
743     // has to run before adding internal kernel arguments
744     nir.extract_constant_initializers();
745 
746     // needed to convert variables to load intrinsics
747     nir_pass!(nir, nir_lower_system_values);
748 
749     // Run here so we can decide if it makes sense to compile a variant, e.g. read system values.
750     nir.gather_info();
751 }
752 
753 fn compile_nir_variant(
754     res: &mut CompilationResult,
755     dev: &Device,
756     variant: NirKernelVariant,
757     args: &[KernelArg],
758     name: &str,
759 ) {
760     let mut lower_state = rusticl_lower_state::default();
761     let compiled_args = &mut res.compiled_args;
762     let nir = &mut res.nir;
763 
764     let address_bits_ptr_type;
765     let address_bits_base_type;
766     let global_address_format;
767     let shared_address_format;
768 
769     if dev.address_bits() == 64 {
770         address_bits_ptr_type = unsafe { glsl_uint64_t_type() };
771         address_bits_base_type = glsl_base_type::GLSL_TYPE_UINT64;
772         global_address_format = nir_address_format::nir_address_format_64bit_global;
773         shared_address_format = nir_address_format::nir_address_format_32bit_offset_as_64bit;
774     } else {
775         address_bits_ptr_type = unsafe { glsl_uint_type() };
776         address_bits_base_type = glsl_base_type::GLSL_TYPE_UINT;
777         global_address_format = nir_address_format::nir_address_format_32bit_global;
778         shared_address_format = nir_address_format::nir_address_format_32bit_offset;
779     }
780 
781     let nir_options = unsafe {
782         &*dev
783             .screen
784             .nir_shader_compiler_options(pipe_shader_type::PIPE_SHADER_COMPUTE)
785     };
786 
787     if variant == NirKernelVariant::Optimized {
788         let wgsh = nir.workgroup_size_hint();
789         if wgsh != [0; 3] {
790             nir.set_workgroup_size(wgsh);
791         }
792     }
793 
794     let mut compute_options = nir_lower_compute_system_values_options::default();
795     compute_options.set_has_global_size(true);
796     if variant != NirKernelVariant::Optimized {
797         compute_options.set_has_base_global_invocation_id(true);
798         compute_options.set_has_base_workgroup_id(true);
799     }
800     nir_pass!(nir, nir_lower_compute_system_values, &compute_options);
801     nir.gather_info();
802 
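    // Helper that registers an internal kernel argument and adds the matching uniform variable to
    // the shader, storing its location in `lower_state` for rusticl_lower_intrinsics.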
803     let mut add_var = |nir: &mut NirShader,
804                        var_loc: &mut usize,
805                        kind: CompiledKernelArgType,
806                        glsl_type: *const glsl_type,
807                        name| {
808         *var_loc = compiled_args.len();
809         compiled_args.push(CompiledKernelArg {
810             kind: kind,
811             offset: 0,
812             dead: true,
813         });
814         nir.add_var(
815             nir_variable_mode::nir_var_uniform,
816             glsl_type,
817             *var_loc,
818             name,
819         );
820     };
821 
822     if nir.reads_sysval(gl_system_value::SYSTEM_VALUE_BASE_GLOBAL_INVOCATION_ID) {
823         debug_assert_ne!(variant, NirKernelVariant::Optimized);
824         add_var(
825             nir,
826             &mut lower_state.base_global_invoc_id_loc,
827             CompiledKernelArgType::GlobalWorkOffsets,
828             unsafe { glsl_vector_type(address_bits_base_type, 3) },
829             c"base_global_invocation_id",
830         )
831     }
832 
833     if nir.reads_sysval(gl_system_value::SYSTEM_VALUE_GLOBAL_GROUP_SIZE) {
834         add_var(
835             nir,
836             &mut lower_state.global_size_loc,
837             CompiledKernelArgType::GlobalWorkSize,
838             unsafe { glsl_vector_type(address_bits_base_type, 3) },
839             c"global_size",
840         )
841     }
842 
843     if nir.reads_sysval(gl_system_value::SYSTEM_VALUE_BASE_WORKGROUP_ID) {
844         debug_assert_ne!(variant, NirKernelVariant::Optimized);
845         add_var(
846             nir,
847             &mut lower_state.base_workgroup_id_loc,
848             CompiledKernelArgType::WorkGroupOffsets,
849             unsafe { glsl_vector_type(address_bits_base_type, 3) },
850             c"base_workgroup_id",
851         );
852     }
853 
854     if nir.reads_sysval(gl_system_value::SYSTEM_VALUE_NUM_WORKGROUPS) {
855         add_var(
856             nir,
857             &mut lower_state.num_workgroups_loc,
858             CompiledKernelArgType::NumWorkgroups,
859             unsafe { glsl_vector_type(glsl_base_type::GLSL_TYPE_UINT, 3) },
860             c"num_workgroups",
861         );
862     }
863 
864     if nir.has_constant() {
865         add_var(
866             nir,
867             &mut lower_state.const_buf_loc,
868             CompiledKernelArgType::ConstantBuffer,
869             address_bits_ptr_type,
870             c"constant_buffer_addr",
871         );
872     }
873     if nir.has_printf() {
874         add_var(
875             nir,
876             &mut lower_state.printf_buf_loc,
877             CompiledKernelArgType::PrintfBuffer,
878             address_bits_ptr_type,
879             c"printf_buffer_addr",
880         );
881     }
882 
883     if nir.num_images() > 0 || nir.num_textures() > 0 {
884         let count = nir.num_images() + nir.num_textures();
885 
886         add_var(
887             nir,
888             &mut lower_state.format_arr_loc,
889             CompiledKernelArgType::FormatArray,
890             unsafe { glsl_array_type(glsl_int16_t_type(), count as u32, 2) },
891             c"image_formats",
892         );
893 
894         add_var(
895             nir,
896             &mut lower_state.order_arr_loc,
897             CompiledKernelArgType::OrderArray,
898             unsafe { glsl_array_type(glsl_int16_t_type(), count as u32, 2) },
899             c"image_orders",
900         );
901     }
902 
903     if nir.reads_sysval(gl_system_value::SYSTEM_VALUE_WORK_DIM) {
904         add_var(
905             nir,
906             &mut lower_state.work_dim_loc,
907             CompiledKernelArgType::WorkDim,
908             unsafe { glsl_uint8_t_type() },
909             c"work_dim",
910         );
911     }
912 
913     // need to run after first opt loop and remove_dead_variables to get rid of unnecessary scratch
914     // memory
915     nir_pass!(
916         nir,
917         nir_lower_vars_to_explicit_types,
918         nir_variable_mode::nir_var_mem_shared
919             | nir_variable_mode::nir_var_function_temp
920             | nir_variable_mode::nir_var_shader_temp
921             | nir_variable_mode::nir_var_uniform
922             | nir_variable_mode::nir_var_mem_global
923             | nir_variable_mode::nir_var_mem_generic,
924         Some(glsl_get_cl_type_size_align),
925     );
926 
927     opt_nir(nir, dev, true);
928     nir_pass!(nir, nir_lower_memcpy);
929 
930     // we might have got rid of more function_temp or shared memory
931     nir.reset_scratch_size();
932     nir.reset_shared_size();
933     nir_pass!(
934         nir,
935         nir_remove_dead_variables,
936         nir_variable_mode::nir_var_function_temp | nir_variable_mode::nir_var_mem_shared,
937         &DV_OPTS,
938     );
939     nir_pass!(
940         nir,
941         nir_lower_vars_to_explicit_types,
942         nir_variable_mode::nir_var_function_temp
943             | nir_variable_mode::nir_var_mem_shared
944             | nir_variable_mode::nir_var_mem_generic,
945         Some(glsl_get_cl_type_size_align),
946     );
947 
948     nir_pass!(
949         nir,
950         nir_lower_explicit_io,
951         nir_variable_mode::nir_var_mem_global | nir_variable_mode::nir_var_mem_constant,
952         global_address_format,
953     );
954 
955     nir_pass!(nir, rusticl_lower_intrinsics, &mut lower_state);
956     nir_pass!(
957         nir,
958         nir_lower_explicit_io,
959         nir_variable_mode::nir_var_mem_shared
960             | nir_variable_mode::nir_var_function_temp
961             | nir_variable_mode::nir_var_uniform,
962         shared_address_format,
963     );
964 
965     if nir_options.lower_int64_options.0 != 0 && !nir_options.late_lower_int64 {
966         nir_pass!(nir, nir_lower_int64);
967     }
968 
969     if nir_options.lower_uniforms_to_ubo {
970         nir_pass!(nir, rusticl_lower_inputs);
971     }
972 
973     nir_pass!(nir, nir_lower_convert_alu_types, None);
974 
975     opt_nir(nir, dev, true);
976 
977     /* before passing it into drivers, assign locations as drivers might remove nir_variables or
978      * other things we depend on
979      */
980     CompiledKernelArg::assign_locations(compiled_args, nir);
981 
982     /* update the has_variable_shared_mem info as we might have DCEed all of them */
983     nir.set_has_variable_shared_mem(compiled_args.iter().any(|arg| {
984         if let CompiledKernelArgType::APIArg(idx) = arg.kind {
985             args[idx as usize].kind == KernelArgType::MemLocal && !arg.dead
986         } else {
987             false
988         }
989     }));
990 
991     if Platform::dbg().nir {
992         eprintln!("=== Printing nir variant '{variant}' for '{name}' before driver finalization");
993         nir.print();
994     }
995 
996     if dev.screen.finalize_nir(nir) {
997         if Platform::dbg().nir {
998             eprintln!(
999                 "=== Printing nir variant '{variant}' for '{name}' after driver finalization"
1000             );
1001             nir.print();
1002         }
1003     }
1004 
1005     nir_pass!(nir, nir_opt_dce);
1006     nir.sweep_mem();
1007 }
1008 
1009 fn compile_nir_remaining(
1010     dev: &Device,
1011     mut nir: NirShader,
1012     args: &[KernelArg],
1013     name: &str,
1014 ) -> (CompilationResult, Option<CompilationResult>) {
1015     // add all API kernel args
1016     let mut compiled_args: Vec<_> = (0..args.len())
1017         .map(|idx| CompiledKernelArg {
1018             kind: CompiledKernelArgType::APIArg(idx as u32),
1019             offset: 0,
1020             dead: true,
1021         })
1022         .collect();
1023 
1024     compile_nir_prepare_for_variants(dev, &mut nir, &mut compiled_args);
1025     if Platform::dbg().nir {
1026         eprintln!("=== Printing nir for '{name}' before specialization");
1027         nir.print();
1028     }
1029 
1030     let mut default_build = CompilationResult {
1031         nir: nir,
1032         compiled_args: compiled_args,
1033     };
1034 
1035     // check if we even want to compile a variant before cloning the compilation state
1036     let has_wgs_hint = default_build.nir.workgroup_size_variable()
1037         && default_build.nir.workgroup_size_hint() != [0; 3];
1038     let has_offsets = default_build
1039         .nir
1040         .reads_sysval(gl_system_value::SYSTEM_VALUE_GLOBAL_INVOCATION_ID);
1041 
1042     let mut optimized = (!Platform::dbg().no_variants && (has_offsets || has_wgs_hint))
1043         .then(|| default_build.clone());
1044 
1045     compile_nir_variant(
1046         &mut default_build,
1047         dev,
1048         NirKernelVariant::Default,
1049         args,
1050         name,
1051     );
1052     if let Some(optimized) = &mut optimized {
1053         compile_nir_variant(optimized, dev, NirKernelVariant::Optimized, args, name);
1054     }
1055 
1056     (default_build, optimized)
1057 }
1058 
1059 pub struct SPIRVToNirResult {
1060     pub kernel_info: KernelInfo,
1061     pub nir_kernel_builds: NirKernelBuilds,
1062 }
1063 
1064 impl SPIRVToNirResult {
1065     fn new(
1066         dev: &'static Device,
1067         kernel_info: &clc_kernel_info,
1068         args: Vec<KernelArg>,
1069         default_build: CompilationResult,
1070         optimized: Option<CompilationResult>,
1071     ) -> Self {
1072         // TODO: we _should_ be able to parse them out of the SPIR-V, but clc doesn't handle
1073         //       indirections yet.
1074         let nir = &default_build.nir;
1075         let wgs = nir.workgroup_size();
1076         let subgroup_size = nir.subgroup_size();
1077         let num_subgroups = nir.num_subgroups();
1078 
1079         let default_build = NirKernelBuild::new(dev, default_build);
1080         let optimized = optimized.map(|opt| NirKernelBuild::new(dev, opt));
1081 
1082         let kernel_info = KernelInfo {
1083             args: args,
1084             attributes_string: kernel_info.attribute_str(),
1085             work_group_size: [wgs[0] as usize, wgs[1] as usize, wgs[2] as usize],
1086             work_group_size_hint: kernel_info.local_size_hint,
1087             subgroup_size: subgroup_size as usize,
1088             num_subgroups: num_subgroups as usize,
1089         };
1090 
1091         Self {
1092             kernel_info: kernel_info,
1093             nir_kernel_builds: NirKernelBuilds::new(default_build, optimized),
1094         }
1095     }
1096 
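    // Cache entry layout: the serialized KernelArgs, the default CompilationResult, then a u8 flag
    // followed by the optimized CompilationResult if the flag is non-zero.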
1097     fn deserialize(bin: &[u8], d: &'static Device, kernel_info: &clc_kernel_info) -> Option<Self> {
1098         let mut reader = blob_reader::default();
1099         unsafe {
1100             blob_reader_init(&mut reader, bin.as_ptr().cast(), bin.len());
1101         }
1102 
1103         let args = KernelArg::deserialize(&mut reader)?;
1104         let default_build = CompilationResult::deserialize(&mut reader, d)?;
1105 
1106         // SAFETY: on overrun this returns 0
1107         let optimized = match unsafe { blob_read_uint8(&mut reader) } {
1108             0 => None,
1109             _ => Some(CompilationResult::deserialize(&mut reader, d)?),
1110         };
1111 
1112         reader
1113             .overrun
1114             .not()
1115             .then(|| SPIRVToNirResult::new(d, kernel_info, args, default_build, optimized))
1116     }
1117 
1118     // we can't use Self here as the nir shader might be compiled to a cso already and we can't
1119     // cache that.
1120     fn serialize(
1121         blob: &mut blob,
1122         args: &[KernelArg],
1123         default_build: &CompilationResult,
1124         optimized: &Option<CompilationResult>,
1125     ) {
1126         KernelArg::serialize(args, blob);
1127         default_build.serialize(blob);
1128         match optimized {
1129             Some(variant) => {
1130                 unsafe { blob_write_uint8(blob, 1) };
1131                 variant.serialize(blob);
1132             }
1133             None => unsafe {
1134                 blob_write_uint8(blob, 0);
1135             },
1136         }
1137     }
1138 }
1139 
1140 pub(super) fn convert_spirv_to_nir(
1141     build: &ProgramBuild,
1142     name: &str,
1143     args: &[spirv::SPIRVKernelArg],
1144     dev: &'static Device,
1145 ) -> SPIRVToNirResult {
1146     let cache = dev.screen().shader_cache();
1147     let key = build.hash_key(dev, name);
1148     let spirv_info = build.spirv_info(name, dev).unwrap();
1149 
1150     cache
1151         .as_ref()
1152         .and_then(|cache| cache.get(&mut key?))
1153         .and_then(|entry| SPIRVToNirResult::deserialize(&entry, dev, spirv_info))
1154         .unwrap_or_else(|| {
1155             let nir = build.to_nir(name, dev);
1156 
1157             if Platform::dbg().nir {
1158                 eprintln!("=== Printing nir for '{name}' after spirv_to_nir");
1159                 nir.print();
1160             }
1161 
1162             let (mut args, nir) = compile_nir_to_args(dev, nir, args, &dev.lib_clc);
1163             let (default_build, optimized) = compile_nir_remaining(dev, nir, &args, name);
1164 
1165             for build in [Some(&default_build), optimized.as_ref()].into_iter() {
1166                 let Some(build) = build else {
1167                     continue;
1168                 };
1169 
1170                 for arg in &build.compiled_args {
1171                     if let CompiledKernelArgType::APIArg(idx) = arg.kind {
1172                         args[idx as usize].dead &= arg.dead;
1173                     }
1174                 }
1175             }
1176 
1177             if let Some(cache) = cache {
1178                 let mut blob = blob::default();
1179                 unsafe {
1180                     blob_init(&mut blob);
1181                     SPIRVToNirResult::serialize(&mut blob, &args, &default_build, &optimized);
1182                     let bin = slice::from_raw_parts(blob.data, blob.size);
1183                     cache.put(bin, &mut key.unwrap());
1184                     blob_finish(&mut blob);
1185                 }
1186             }
1187 
1188             SPIRVToNirResult::new(dev, spirv_info, args, default_build, optimized)
1189         })
1190 }
1191 
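/// Splits the first `S` bytes off the front of `buf` and advances `buf` past them.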
1192 fn extract<'a, const S: usize>(buf: &'a mut &[u8]) -> &'a [u8; S] {
1193     let val;
1194     (val, *buf) = (*buf).split_at(S);
1195     // we split off S bytes and convert to [u8; S], so this should be safe
1196     // use split_array_ref once it's stable
1197     val.try_into().unwrap()
1198 }
1199 
1200 impl Kernel {
1201     pub fn new(name: String, prog: Arc<Program>, prog_build: &ProgramBuild) -> Arc<Kernel> {
1202         let kernel_info = Arc::clone(prog_build.kernel_info.get(&name).unwrap());
1203         let builds = prog_build
1204             .builds
1205             .iter()
1206             .filter_map(|(&dev, b)| b.kernels.get(&name).map(|k| (dev, Arc::clone(k))))
1207             .collect();
1208 
1209         let values = vec![None; kernel_info.args.len()];
1210         Arc::new(Self {
1211             base: CLObjectBase::new(RusticlTypes::Kernel),
1212             prog: prog,
1213             name: name,
1214             values: Mutex::new(values),
1215             builds: builds,
1216             kernel_info: kernel_info,
1217         })
1218     }
1219 
1220     pub fn suggest_local_size(
1221         &self,
1222         d: &Device,
1223         work_dim: usize,
1224         grid: &mut [usize],
1225         block: &mut [usize],
1226     ) {
1227         let mut threads = self.max_threads_per_block(d);
1228         let dim_threads = d.max_block_sizes();
1229         let subgroups = self.preferred_simd_size(d);
1230 
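        // Greedy heuristic: per dimension pick the largest block size that divides the grid evenly
        // (gcd) while staying within the remaining thread budget.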
1231         for i in 0..work_dim {
1232             let t = cmp::min(threads, dim_threads[i]);
1233             let gcd = gcd(t, grid[i]);
1234 
1235             block[i] = gcd;
1236             grid[i] /= gcd;
1237 
1238             // update limits
1239             threads /= block[i];
1240         }
1241 
1242         // if we didn't fill the subgroup we can do a bit better if we have threads remaining
1243         let total_threads = block.iter().take(work_dim).product::<usize>();
1244         if threads != 1 && total_threads < subgroups {
1245             for i in 0..work_dim {
1246                 if grid[i] * total_threads < threads && grid[i] * block[i] <= dim_threads[i] {
1247                     block[i] *= grid[i];
1248                     grid[i] = 1;
1249                     // can only do it once as nothing is cleanly divisible
1250                     break;
1251                 }
1252             }
1253         }
1254     }
1255 
1256     fn optimize_local_size(&self, d: &Device, grid: &mut [usize; 3], block: &mut [u32; 3]) {
1257         if !block.contains(&0) {
1258             for i in 0..3 {
1259                 // we already made sure everything is fine
1260                 grid[i] /= block[i] as usize;
1261             }
1262             return;
1263         }
1264 
1265         let mut usize_block = [0usize; 3];
1266         for i in 0..3 {
1267             usize_block[i] = block[i] as usize;
1268         }
1269 
1270         self.suggest_local_size(d, 3, grid, &mut usize_block);
1271 
1272         for i in 0..3 {
1273             block[i] = usize_block[i] as u32;
1274         }
1275     }
1276 
1277     // The painful part is that host threads are allowed to modify the kernel object once it has been
1278     // enqueued, so return a closure with all required data included.
1279     pub fn launch(
1280         self: &Arc<Self>,
1281         q: &Arc<Queue>,
1282         work_dim: u32,
1283         block: &[usize],
1284         grid: &[usize],
1285         offsets: &[usize],
1286     ) -> CLResult<EventSig> {
1287         // Clone all the data we need to execute this kernel
1288         let kernel_info = Arc::clone(&self.kernel_info);
1289         let arg_values = self.arg_values().clone();
1290         let nir_kernel_builds = Arc::clone(&self.builds[q.device]);
1291 
1292         let mut buffer_arcs = HashMap::new();
1293         let mut image_arcs = HashMap::new();
1294 
1295         // need to preprocess buffer and image arguments so we hold a strong reference until the
1296         // event has been processed.
1297         for arg in arg_values.iter() {
1298             match arg {
1299                 Some(KernelArgValue::Buffer(buffer)) => {
1300                     buffer_arcs.insert(
1301                         // we use the ptr as the key, and also cast it to usize so we don't need to
1302                         // deal with Send + Sync here.
1303                         buffer.as_ptr() as usize,
1304                         buffer.upgrade().ok_or(CL_INVALID_KERNEL_ARGS)?,
1305                     );
1306                 }
1307                 Some(KernelArgValue::Image(image)) => {
1308                     image_arcs.insert(
1309                         image.as_ptr() as usize,
1310                         image.upgrade().ok_or(CL_INVALID_KERNEL_ARGS)?,
1311                     );
1312                 }
1313                 _ => {}
1314             }
1315         }
1316 
1317         // operations whose errors we want to report back to the client
1318         let mut block = create_kernel_arr::<u32>(block, 1)?;
1319         let mut grid = create_kernel_arr::<usize>(grid, 1)?;
1320         let offsets = create_kernel_arr::<usize>(offsets, 0)?;
1321 
1322         let api_grid = grid;
1323 
1324         self.optimize_local_size(q.device, &mut grid, &mut block);
1325 
1326         Ok(Box::new(move |q, ctx| {
1327             let hw_max_grid: Vec<usize> = q
1328                 .device
1329                 .max_grid_size()
1330                 .into_iter()
1331                 .map(|val| val.try_into().unwrap_or(usize::MAX))
1332                 // clamped as pipe_launch_grid::grid is only u32
1333                 .map(|val| cmp::min(val, u32::MAX as usize))
1334                 .collect();
1335 
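            // Pick the optimized variant only when its assumptions hold: no global offsets, the grid
            // fits the HW limits (so the launch doesn't need workgroup offsets) and the block matches
            // the work-group size hint (if any).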
1336             let variant = if offsets == [0; 3]
1337                 && grid[0] <= hw_max_grid[0]
1338                 && grid[1] <= hw_max_grid[1]
1339                 && grid[2] <= hw_max_grid[2]
1340                 && (kernel_info.work_group_size_hint == [0; 3]
1341                     || block == kernel_info.work_group_size_hint)
1342             {
1343                 NirKernelVariant::Optimized
1344             } else {
1345                 NirKernelVariant::Default
1346             };
1347 
1348             let nir_kernel_build = &nir_kernel_builds[variant];
1349             let mut workgroup_id_offset_loc = None;
1350             let mut input = Vec::new();
1351             // Set it once so we get the alignment padding right
1352             let static_local_size: u64 = nir_kernel_build.shared_size;
1353             let mut variable_local_size: u64 = static_local_size;
1354             let printf_size = q.device.printf_buffer_size() as u32;
1355             let mut samplers = Vec::new();
1356             let mut iviews = Vec::new();
1357             let mut sviews = Vec::new();
1358             let mut tex_formats: Vec<u16> = Vec::new();
1359             let mut tex_orders: Vec<u16> = Vec::new();
1360             let mut img_formats: Vec<u16> = Vec::new();
1361             let mut img_orders: Vec<u16> = Vec::new();
1362 
1363             let null_ptr;
1364             let null_ptr_v3;
1365             if q.device.address_bits() == 64 {
1366                 null_ptr = [0u8; 8].as_slice();
1367                 null_ptr_v3 = [0u8; 24].as_slice();
1368             } else {
1369                 null_ptr = [0u8; 4].as_slice();
1370                 null_ptr_v3 = [0u8; 12].as_slice();
1371             };
1372 
1373             let mut resource_info = Vec::new();
1374             fn add_global<'a>(
1375                 q: &Queue,
1376                 input: &mut Vec<u8>,
1377                 resource_info: &mut Vec<(&'a PipeResource, usize)>,
1378                 res: &'a PipeResource,
1379                 offset: usize,
1380             ) {
1381                 resource_info.push((res, input.len()));
1382                 if q.device.address_bits() == 64 {
1383                     let offset: u64 = offset as u64;
1384                     input.extend_from_slice(&offset.to_ne_bytes());
1385                 } else {
1386                     let offset: u32 = offset as u32;
1387                     input.extend_from_slice(&offset.to_ne_bytes());
1388                 }
1389             }
1390 
1391             fn add_sysval(q: &Queue, input: &mut Vec<u8>, vals: &[usize; 3]) {
1392                 if q.device.address_bits() == 64 {
1393                     input.extend_from_slice(unsafe { as_byte_slice(&vals.map(|v| v as u64)) });
1394                 } else {
1395                     input.extend_from_slice(unsafe { as_byte_slice(&vals.map(|v| v as u32)) });
1396                 }
1397             }
1398 
1399             let mut printf_buf = None;
1400             if nir_kernel_build.printf_info.is_some() {
1401                 let buf = q
1402                     .device
1403                     .screen
1404                     .resource_create_buffer(printf_size, ResourceType::Staging, PIPE_BIND_GLOBAL)
1405                     .unwrap();
1406 
1407                 let init_data: [u8; 1] = [4];
1408                 ctx.buffer_subdata(&buf, 0, init_data.as_ptr().cast(), init_data.len() as u32);
1409 
1410                 printf_buf = Some(buf);
1411             }
1412 
            for arg in &nir_kernel_build.compiled_args {
                let is_opaque = if let CompiledKernelArgType::APIArg(idx) = arg.kind {
                    kernel_info.args[idx as usize].kind.is_opaque()
                } else {
                    false
                };

                if !is_opaque && arg.offset as usize > input.len() {
                    input.resize(arg.offset as usize, 0);
                }

                match arg.kind {
                    CompiledKernelArgType::APIArg(idx) => {
                        let api_arg = &kernel_info.args[idx as usize];
                        if api_arg.dead {
                            continue;
                        }

                        let Some(value) = &arg_values[idx as usize] else {
                            continue;
                        };

                        match value {
                            KernelArgValue::Constant(c) => input.extend_from_slice(c),
                            KernelArgValue::Buffer(buffer) => {
                                let buffer = &buffer_arcs[&(buffer.as_ptr() as usize)];
                                let rw = if api_arg.spirv.address_qualifier
                                    == clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_CONSTANT
                                {
                                    RWFlags::RD
                                } else {
                                    RWFlags::RW
                                };

                                let res = buffer.get_res_for_access(ctx, rw)?;
                                add_global(q, &mut input, &mut resource_info, res, buffer.offset());
                            }
                            KernelArgValue::Image(image) => {
                                let image = &image_arcs[&(image.as_ptr() as usize)];
                                let (formats, orders) = if api_arg.kind == KernelArgType::Image {
                                    iviews.push(image.image_view(ctx, false)?);
                                    (&mut img_formats, &mut img_orders)
                                } else if api_arg.kind == KernelArgType::RWImage {
                                    iviews.push(image.image_view(ctx, true)?);
                                    (&mut img_formats, &mut img_orders)
                                } else {
                                    sviews.push(image.sampler_view(ctx)?);
                                    (&mut tex_formats, &mut tex_orders)
                                };

                                let binding = arg.offset as usize;
                                assert!(binding >= formats.len());

                                formats.resize(binding, 0);
                                orders.resize(binding, 0);

                                formats.push(image.image_format.image_channel_data_type as u16);
                                orders.push(image.image_format.image_channel_order as u16);
                            }
                            KernelArgValue::LocalMem(size) => {
                                // TODO 32 bit
                                let pot = cmp::min(*size, 0x80);
                                variable_local_size = variable_local_size
                                    .next_multiple_of(pot.next_power_of_two() as u64);
                                if q.device.address_bits() == 64 {
                                    let variable_local_size: [u8; 8] =
                                        variable_local_size.to_ne_bytes();
                                    input.extend_from_slice(&variable_local_size);
                                } else {
                                    let variable_local_size: [u8; 4] =
                                        (variable_local_size as u32).to_ne_bytes();
                                    input.extend_from_slice(&variable_local_size);
                                }
                                variable_local_size += *size as u64;
                            }
                            KernelArgValue::Sampler(sampler) => {
                                samplers.push(sampler.pipe());
                            }
                            KernelArgValue::None => {
                                assert!(
                                    api_arg.kind == KernelArgType::MemGlobal
                                        || api_arg.kind == KernelArgType::MemConstant
                                );
                                input.extend_from_slice(null_ptr);
                            }
                        }
                    }
                    CompiledKernelArgType::ConstantBuffer => {
                        assert!(nir_kernel_build.constant_buffer.is_some());
                        let res = nir_kernel_build.constant_buffer.as_ref().unwrap();
                        add_global(q, &mut input, &mut resource_info, res, 0);
                    }
                    CompiledKernelArgType::GlobalWorkOffsets => {
                        add_sysval(q, &mut input, &offsets);
                    }
                    CompiledKernelArgType::WorkGroupOffsets => {
                        workgroup_id_offset_loc = Some(input.len());
                        input.extend_from_slice(null_ptr_v3);
                    }
                    CompiledKernelArgType::GlobalWorkSize => {
                        add_sysval(q, &mut input, &api_grid);
                    }
                    CompiledKernelArgType::PrintfBuffer => {
                        let res = printf_buf.as_ref().unwrap();
                        add_global(q, &mut input, &mut resource_info, res, 0);
                    }
                    CompiledKernelArgType::InlineSampler(cl) => {
                        samplers.push(Sampler::cl_to_pipe(cl));
                    }
                    CompiledKernelArgType::FormatArray => {
                        input.extend_from_slice(unsafe { as_byte_slice(&tex_formats) });
                        input.extend_from_slice(unsafe { as_byte_slice(&img_formats) });
                    }
                    CompiledKernelArgType::OrderArray => {
                        input.extend_from_slice(unsafe { as_byte_slice(&tex_orders) });
                        input.extend_from_slice(unsafe { as_byte_slice(&img_orders) });
                    }
                    CompiledKernelArgType::WorkDim => {
                        input.extend_from_slice(&[work_dim as u8; 1]);
                    }
                    CompiledKernelArgType::NumWorkgroups => {
                        input.extend_from_slice(unsafe {
                            as_byte_slice(&[grid[0] as u32, grid[1] as u32, grid[2] as u32])
                        });
                    }
                }
            }

            // subtract the shader local_size as we only request something on top of that.
            variable_local_size -= static_local_size;

            let samplers: Vec<_> = samplers
                .iter()
                .map(|s| ctx.create_sampler_state(s))
                .collect();

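            // Split `resource_info` into the parallel arrays set_global_binding() expects:
            // the resources themselves and raw pointers into `input` at the recorded offsets.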
            let mut resources = Vec::with_capacity(resource_info.len());
            let mut globals: Vec<*mut u32> = Vec::with_capacity(resource_info.len());
            for (res, offset) in resource_info {
                resources.push(res);
                globals.push(unsafe { input.as_mut_ptr().byte_add(offset) }.cast());
            }

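            // Use the precompiled compute state if the build carries one, otherwise create a
            // temporary CSO from the NIR just for this launch.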
            let temp_cso;
            let cso = match &nir_kernel_build.nir_or_cso {
                KernelDevStateVariant::Cso(cso) => cso,
                KernelDevStateVariant::Nir(nir) => {
                    temp_cso = CSOWrapper::new(q.device, nir);
                    &temp_cso
                }
            };

            let sviews_len = sviews.len();
            ctx.bind_compute_state(cso.cso_ptr);
            ctx.bind_sampler_states(&samplers);
            ctx.set_sampler_views(sviews);
            ctx.set_shader_images(&iviews);
            ctx.set_global_binding(resources.as_slice(), &mut globals);

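            // The requested grid may exceed the device's per-launch limits, so split it into
            // chunks of at most `hw_max_grid` groups per dimension. For every chunk, patch the
            // workgroup id offset into `input` (if the kernel needs it), clamp the chunk size
            // at the grid boundary, re-upload cb0 and launch.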
            for z in 0..grid[2].div_ceil(hw_max_grid[2]) {
                for y in 0..grid[1].div_ceil(hw_max_grid[1]) {
                    for x in 0..grid[0].div_ceil(hw_max_grid[0]) {
                        if let Some(workgroup_id_offset_loc) = workgroup_id_offset_loc {
                            let this_offsets =
                                [x * hw_max_grid[0], y * hw_max_grid[1], z * hw_max_grid[2]];

                            if q.device.address_bits() == 64 {
                                let val = this_offsets.map(|v| v as u64);
                                input[workgroup_id_offset_loc..workgroup_id_offset_loc + 24]
                                    .copy_from_slice(unsafe { as_byte_slice(&val) });
                            } else {
                                let val = this_offsets.map(|v| v as u32);
                                input[workgroup_id_offset_loc..workgroup_id_offset_loc + 12]
                                    .copy_from_slice(unsafe { as_byte_slice(&val) });
                            }
                        }

                        let this_grid = [
                            cmp::min(hw_max_grid[0], grid[0] - hw_max_grid[0] * x) as u32,
                            cmp::min(hw_max_grid[1], grid[1] - hw_max_grid[1] * y) as u32,
                            cmp::min(hw_max_grid[2], grid[2] - hw_max_grid[2] * z) as u32,
                        ];

                        ctx.update_cb0(&input)?;
                        ctx.launch_grid(work_dim, block, this_grid, variable_local_size as u32);

                        if Platform::dbg().sync_every_event {
                            ctx.flush().wait();
                        }
                    }
                }
            }

            ctx.clear_global_binding(globals.len() as u32);
            ctx.clear_shader_images(iviews.len() as u32);
            ctx.clear_sampler_views(sviews_len as u32);
            ctx.clear_sampler_states(samplers.len() as u32);

            ctx.bind_compute_state(ptr::null_mut());

            ctx.memory_barrier(PIPE_BARRIER_GLOBAL_BUFFER);

            samplers.iter().for_each(|s| ctx.delete_sampler_state(*s));

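            // Read the printf output back: map the buffer, fetch the write cursor from the
            // first dword and hand everything after it to the kernel's printf info for
            // formatting.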
            if let Some(printf_buf) = &printf_buf {
                let tx = ctx
                    .buffer_map(printf_buf, 0, printf_size as i32, RWFlags::RD)
                    .ok_or(CL_OUT_OF_RESOURCES)?;
                let mut buf: &[u8] =
                    unsafe { slice::from_raw_parts(tx.ptr().cast(), printf_size as usize) };
                let length = u32::from_ne_bytes(*extract(&mut buf));

                // update our slice to make sure we don't go out of bounds
                buf = &buf[0..(length - 4) as usize];
                if let Some(pf) = &nir_kernel_build.printf_info {
                    pf.u_printf(buf)
                }
            }

            Ok(())
        }))
    }

    pub fn arg_values(&self) -> MutexGuard<Vec<Option<KernelArgValue>>> {
        self.values.lock().unwrap()
    }

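    /// Stores `arg` as the value for argument `idx`, replacing any previously set value, and
    /// fails with `CL_INVALID_ARG_INDEX` if `idx` is out of range.
    ///
    /// A minimal sketch of how a caller might use it (hypothetical variables, not the actual
    /// `clSetKernelArg` entrypoint):
    ///
    /// ```ignore
    /// // a value argument, passed as raw bytes
    /// kernel.set_kernel_arg(0, KernelArgValue::Constant(scale.to_ne_bytes().to_vec()))?;
    /// // a __local argument, where only the size is recorded
    /// kernel.set_kernel_arg(1, KernelArgValue::LocalMem(1024))?;
    /// ```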
    pub fn set_kernel_arg(&self, idx: usize, arg: KernelArgValue) -> CLResult<()> {
        self.values
            .lock()
            .unwrap()
            .get_mut(idx)
            .ok_or(CL_INVALID_ARG_INDEX)?
            .replace(arg);
        Ok(())
    }

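    /// Translates the SPIR-V access qualifier of argument `idx` into the corresponding
    /// `CL_KERNEL_ARG_ACCESS_*` value. The clc qualifier is a bitfield, so read-write is
    /// detected by comparing against the combination of the read and write bits.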
    pub fn access_qualifier(&self, idx: cl_uint) -> cl_kernel_arg_access_qualifier {
        let aq = self.kernel_info.args[idx as usize].spirv.access_qualifier;

        if aq
            == clc_kernel_arg_access_qualifier::CLC_KERNEL_ARG_ACCESS_READ
                | clc_kernel_arg_access_qualifier::CLC_KERNEL_ARG_ACCESS_WRITE
        {
            CL_KERNEL_ARG_ACCESS_READ_WRITE
        } else if aq == clc_kernel_arg_access_qualifier::CLC_KERNEL_ARG_ACCESS_READ {
            CL_KERNEL_ARG_ACCESS_READ_ONLY
        } else if aq == clc_kernel_arg_access_qualifier::CLC_KERNEL_ARG_ACCESS_WRITE {
            CL_KERNEL_ARG_ACCESS_WRITE_ONLY
        } else {
            CL_KERNEL_ARG_ACCESS_NONE
        }
    }

    pub fn address_qualifier(&self, idx: cl_uint) -> cl_kernel_arg_address_qualifier {
        match self.kernel_info.args[idx as usize].spirv.address_qualifier {
            clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_PRIVATE => {
                CL_KERNEL_ARG_ADDRESS_PRIVATE
            }
            clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_CONSTANT => {
                CL_KERNEL_ARG_ADDRESS_CONSTANT
            }
            clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_LOCAL => {
                CL_KERNEL_ARG_ADDRESS_LOCAL
            }
            clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_GLOBAL => {
                CL_KERNEL_ARG_ADDRESS_GLOBAL
            }
        }
    }

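    /// Builds the `CL_KERNEL_ARG_TYPE_*` bitfield for argument `idx` by checking the const,
    /// restrict and volatile bits of the SPIR-V type qualifier individually.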
    pub fn type_qualifier(&self, idx: cl_uint) -> cl_kernel_arg_type_qualifier {
        let tq = self.kernel_info.args[idx as usize].spirv.type_qualifier;
        let zero = clc_kernel_arg_type_qualifier(0);
        let mut res = CL_KERNEL_ARG_TYPE_NONE;

        if tq & clc_kernel_arg_type_qualifier::CLC_KERNEL_ARG_TYPE_CONST != zero {
            res |= CL_KERNEL_ARG_TYPE_CONST;
        }

        if tq & clc_kernel_arg_type_qualifier::CLC_KERNEL_ARG_TYPE_RESTRICT != zero {
            res |= CL_KERNEL_ARG_TYPE_RESTRICT;
        }

        if tq & clc_kernel_arg_type_qualifier::CLC_KERNEL_ARG_TYPE_VOLATILE != zero {
            res |= CL_KERNEL_ARG_TYPE_VOLATILE;
        }

        res.into()
    }

    pub fn work_group_size(&self) -> [usize; 3] {
        self.kernel_info.work_group_size
    }

    pub fn num_subgroups(&self) -> usize {
        self.kernel_info.num_subgroups
    }

    pub fn subgroup_size(&self) -> usize {
        self.kernel_info.subgroup_size
    }

    pub fn arg_name(&self, idx: cl_uint) -> Option<&CStr> {
        let name = &self.kernel_info.args[idx as usize].spirv.name;
        name.is_empty().not().then_some(name)
    }

    pub fn arg_type_name(&self, idx: cl_uint) -> Option<&CStr> {
        let type_name = &self.kernel_info.args[idx as usize].spirv.type_name;
        type_name.is_empty().not().then_some(type_name)
    }

    pub fn priv_mem_size(&self, dev: &Device) -> cl_ulong {
        self.builds.get(dev).unwrap().info.private_memory as cl_ulong
    }

    pub fn max_threads_per_block(&self, dev: &Device) -> usize {
        self.builds.get(dev).unwrap().info.max_threads as usize
    }

    pub fn preferred_simd_size(&self, dev: &Device) -> usize {
        self.builds.get(dev).unwrap().info.preferred_simd_size as usize
    }

    pub fn local_mem_size(&self, dev: &Device) -> cl_ulong {
        // TODO: take alignment into account?
        // this is purely informational so it shouldn't even matter
        let local =
            self.builds.get(dev).unwrap()[NirKernelVariant::Default].shared_size as cl_ulong;
        let args: cl_ulong = self
            .arg_values()
            .iter()
            .map(|arg| match arg {
                Some(KernelArgValue::LocalMem(val)) => *val as cl_ulong,
                // If the local memory size, for any pointer argument to the kernel declared with
                // the __local address qualifier, is not specified, its size is assumed to be 0.
                _ => 0,
            })
            .sum();

        local + args
    }

    pub fn has_svm_devs(&self) -> bool {
        self.prog.devs.iter().any(|dev| dev.svm_supported())
    }

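    /// Returns the subgroup sizes supported by this kernel's build on `dev`.
    /// `info.simd_sizes` is a bitmask where a set bit `i` corresponds to a subgroup size of
    /// `1 << i`, so e.g. a mask of 0b0101_0000 would yield `[64, 16]` (walked from the most
    /// significant bit down).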
    pub fn subgroup_sizes(&self, dev: &Device) -> Vec<usize> {
        SetBitIndices::from_msb(self.builds.get(dev).unwrap().info.simd_sizes)
            .map(|bit| 1 << bit)
            .collect()
    }

    pub fn subgroups_for_block(&self, dev: &Device, block: &[usize]) -> usize {
        let subgroup_size = self.subgroup_size_for_block(dev, block);
        if subgroup_size == 0 {
            return 0;
        }

        let threads: usize = block.iter().product();
        threads.div_ceil(subgroup_size)
    }

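    /// Picks the subgroup size the device would use for the given block: 0 if no subgroup
    /// sizes are reported, the only option if there is exactly one, and otherwise the compute
    /// state is queried directly (which requires a precompiled CSO).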
    pub fn subgroup_size_for_block(&self, dev: &Device, block: &[usize]) -> usize {
        let subgroup_sizes = self.subgroup_sizes(dev);
        if subgroup_sizes.is_empty() {
            return 0;
        }

        if subgroup_sizes.len() == 1 {
            return subgroup_sizes[0];
        }

        let block = [
            *block.first().unwrap_or(&1) as u32,
            *block.get(1).unwrap_or(&1) as u32,
            *block.get(2).unwrap_or(&1) as u32,
        ];

        // TODO: this _might_ bite us somewhere, but I think it probably doesn't matter
        match &self.builds.get(dev).unwrap()[NirKernelVariant::Default].nir_or_cso {
            KernelDevStateVariant::Cso(cso) => {
                dev.helper_ctx()
                    .compute_state_subgroup_size(cso.cso_ptr, &block) as usize
            }
            _ => {
                panic!()
            }
        }
    }
}

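// Cloning a kernel (as clCloneKernel relies on) hands out a copy with a fresh CL object
// identity that shares the program, builds and kernel info with the original and takes a
// snapshot of the currently set argument values.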
impl Clone for Kernel {
    fn clone(&self) -> Self {
        Self {
            base: CLObjectBase::new(RusticlTypes::Kernel),
            prog: Arc::clone(&self.prog),
            name: self.name.clone(),
            values: Mutex::new(self.arg_values().clone()),
            builds: self.builds.clone(),
            kernel_info: Arc::clone(&self.kernel_info),
        }
    }
}