1 use crate::api::icd::*;
2 use crate::core::device::*;
3 use crate::core::event::*;
4 use crate::core::memory::*;
5 use crate::core::program::*;
6 use crate::core::queue::*;
7 use crate::impl_cl_type_trait;
8 
9 use mesa_rust::compiler::clc::*;
10 use mesa_rust::compiler::nir::*;
11 use mesa_rust::nir_pass;
12 use mesa_rust::pipe::context::RWFlags;
13 use mesa_rust::pipe::context::ResourceMapType;
14 use mesa_rust::pipe::resource::*;
15 use mesa_rust::pipe::screen::ResourceType;
16 use mesa_rust_gen::*;
17 use mesa_rust_util::math::*;
18 use mesa_rust_util::serialize::*;
19 use rusticl_opencl_gen::*;
20 
21 use std::cmp;
22 use std::collections::HashMap;
23 use std::convert::TryInto;
24 use std::os::raw::c_void;
25 use std::ptr;
26 use std::slice;
27 use std::sync::Arc;
28 use std::sync::Mutex;
29 use std::sync::MutexGuard;
30 
31 // ugh, we are not allowed to take refs, so...
32 #[derive(Clone)]
33 pub enum KernelArgValue {
34     None,
35     Buffer(Arc<Buffer>),
36     Constant(Vec<u8>),
37     Image(Arc<Image>),
38     LocalMem(usize),
39     Sampler(Arc<Sampler>),
40 }
41 
42 #[derive(Hash, PartialEq, Eq, Clone, Copy)]
43 pub enum KernelArgType {
44     Constant = 0, // for anything passed by value
45     Image = 1,
46     RWImage = 2,
47     Sampler = 3,
48     Texture = 4,
49     MemGlobal = 5,
50     MemConstant = 6,
51     MemLocal = 7,
52 }
53 
54 #[derive(Hash, PartialEq, Eq, Clone)]
55 pub enum InternalKernelArgType {
56     ConstantBuffer,
57     GlobalWorkOffsets,
58     PrintfBuffer,
59     InlineSampler((cl_addressing_mode, cl_filter_mode, bool)),
60     FormatArray,
61     OrderArray,
62     WorkDim,
63 }
64 
65 #[derive(Hash, PartialEq, Eq, Clone)]
66 pub struct KernelArg {
67     spirv: spirv::SPIRVKernelArg,
68     pub kind: KernelArgType,
69     pub size: usize,
70     /// The offset into the input buffer
71     pub offset: usize,
72     /// The actual binding slot
73     pub binding: u32,
74     pub dead: bool,
75 }
76 
77 #[derive(Hash, PartialEq, Eq, Clone)]
78 pub struct InternalKernelArg {
79     pub kind: InternalKernelArgType,
80     pub size: usize,
81     pub offset: usize,
82 }
83 
84 impl KernelArg {
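    /// Matches the SPIR-V argument info against the NIR uniform/image variables and
    /// classifies each argument. `offset`, `binding` and `dead` are only placeholders
    /// here; they get filled in later by `assign_locations`.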
85     fn from_spirv_nir(spirv: &[spirv::SPIRVKernelArg], nir: &mut NirShader) -> Vec<Self> {
86         let nir_arg_map: HashMap<_, _> = nir
87             .variables_with_mode(
88                 nir_variable_mode::nir_var_uniform | nir_variable_mode::nir_var_image,
89             )
90             .map(|v| (v.data.location, v))
91             .collect();
92         let mut res = Vec::new();
93 
94         for (i, s) in spirv.iter().enumerate() {
95             let nir = nir_arg_map.get(&(i as i32)).unwrap();
96             let kind = match s.address_qualifier {
97                 clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_PRIVATE => {
98                     if unsafe { glsl_type_is_sampler(nir.type_) } {
99                         KernelArgType::Sampler
100                     } else {
101                         KernelArgType::Constant
102                     }
103                 }
104                 clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_CONSTANT => {
105                     KernelArgType::MemConstant
106                 }
107                 clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_LOCAL => {
108                     KernelArgType::MemLocal
109                 }
110                 clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_GLOBAL => {
111                     if unsafe { glsl_type_is_image(nir.type_) } {
112                         let access = nir.data.access();
113                         if access == gl_access_qualifier::ACCESS_NON_WRITEABLE.0 {
114                             KernelArgType::Texture
115                         } else if access == gl_access_qualifier::ACCESS_NON_READABLE.0 {
116                             KernelArgType::Image
117                         } else {
118                             KernelArgType::RWImage
119                         }
120                     } else {
121                         KernelArgType::MemGlobal
122                     }
123                 }
124             };
125 
126             res.push(Self {
127                 spirv: s.clone(),
128                 size: unsafe { glsl_get_cl_size(nir.type_) } as usize,
129                 // we'll update it later in the 2nd pass
130                 kind: kind,
131                 offset: 0,
132                 binding: 0,
133                 dead: true,
134             });
135         }
136         res
137     }
138 
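    /// Copies the driver-assigned offsets and bindings from the lowered NIR variables
    /// back into the argument (and internal argument) descriptions and marks arguments
    /// whose variables survived as live.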
139     fn assign_locations(
140         args: &mut [Self],
141         internal_args: &mut [InternalKernelArg],
142         nir: &mut NirShader,
143     ) {
144         for var in nir.variables_with_mode(
145             nir_variable_mode::nir_var_uniform | nir_variable_mode::nir_var_image,
146         ) {
147             if let Some(arg) = args.get_mut(var.data.location as usize) {
148                 arg.offset = var.data.driver_location as usize;
149                 arg.binding = var.data.binding;
150                 arg.dead = false;
151             } else {
152                 internal_args
153                     .get_mut(var.data.location as usize - args.len())
154                     .unwrap()
155                     .offset = var.data.driver_location as usize;
156             }
157         }
158     }
159 
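    /// The (de)serialization below stores the argument metadata next to the NIR binary
    /// in the shader disk cache.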
160     fn serialize(&self) -> Vec<u8> {
161         let mut bin = Vec::new();
162 
163         bin.append(&mut self.spirv.serialize());
164         bin.extend_from_slice(&self.size.to_ne_bytes());
165         bin.extend_from_slice(&self.offset.to_ne_bytes());
166         bin.extend_from_slice(&self.binding.to_ne_bytes());
167         bin.extend_from_slice(&(self.dead as u8).to_ne_bytes());
168         bin.extend_from_slice(&(self.kind as u8).to_ne_bytes());
169 
170         bin
171     }
172 
173     fn deserialize(bin: &mut &[u8]) -> Option<Self> {
174         let spirv = spirv::SPIRVKernelArg::deserialize(bin)?;
175         let size = read_ne_usize(bin);
176         let offset = read_ne_usize(bin);
177         let binding = read_ne_u32(bin);
178         let dead = read_ne_u8(bin) == 1;
179 
180         let kind = match read_ne_u8(bin) {
181             0 => KernelArgType::Constant,
182             1 => KernelArgType::Image,
183             2 => KernelArgType::RWImage,
184             3 => KernelArgType::Sampler,
185             4 => KernelArgType::Texture,
186             5 => KernelArgType::MemGlobal,
187             6 => KernelArgType::MemConstant,
188             7 => KernelArgType::MemLocal,
189             _ => return None,
190         };
191 
192         Some(Self {
193             spirv: spirv,
194             kind: kind,
195             size: size,
196             offset: offset,
197             binding: binding,
198             dead: dead,
199         })
200     }
201 }
202 
203 impl InternalKernelArg {
204     fn serialize(&self) -> Vec<u8> {
205         let mut bin = Vec::new();
206 
207         bin.extend_from_slice(&self.size.to_ne_bytes());
208         bin.extend_from_slice(&self.offset.to_ne_bytes());
209 
210         match self.kind {
211             InternalKernelArgType::ConstantBuffer => bin.push(0),
212             InternalKernelArgType::GlobalWorkOffsets => bin.push(1),
213             InternalKernelArgType::PrintfBuffer => bin.push(2),
214             InternalKernelArgType::InlineSampler((addr_mode, filter_mode, norm)) => {
215                 bin.push(3);
216                 bin.extend_from_slice(&addr_mode.to_ne_bytes());
217                 bin.extend_from_slice(&filter_mode.to_ne_bytes());
218                 bin.push(norm as u8);
219             }
220             InternalKernelArgType::FormatArray => bin.push(4),
221             InternalKernelArgType::OrderArray => bin.push(5),
222             InternalKernelArgType::WorkDim => bin.push(6),
223         }
224 
225         bin
226     }
227 
228     fn deserialize(bin: &mut &[u8]) -> Option<Self> {
229         let size = read_ne_usize(bin);
230         let offset = read_ne_usize(bin);
231 
232         let kind = match read_ne_u8(bin) {
233             0 => InternalKernelArgType::ConstantBuffer,
234             1 => InternalKernelArgType::GlobalWorkOffsets,
235             2 => InternalKernelArgType::PrintfBuffer,
236             3 => {
237                 let addr_mode = read_ne_u32(bin);
238                 let filter_mode = read_ne_u32(bin);
239                 let norm = read_ne_u8(bin) == 1;
240                 InternalKernelArgType::InlineSampler((addr_mode, filter_mode, norm))
241             }
242             4 => InternalKernelArgType::FormatArray,
243             5 => InternalKernelArgType::OrderArray,
244             6 => InternalKernelArgType::WorkDim,
245             _ => return None,
246         };
247 
248         Some(Self {
249             kind: kind,
250             size: size,
251             offset: offset,
252         })
253     }
254 }
255 
256 #[derive(Clone, PartialEq, Eq, Hash)]
257 pub struct KernelInfo {
258     pub args: Vec<KernelArg>,
259     pub internal_args: Vec<InternalKernelArg>,
260     pub attributes_string: String,
261     pub work_group_size: [usize; 3],
262     pub subgroup_size: usize,
263     pub num_subgroups: usize,
264 }
265 
266 pub struct CSOWrapper {
267     pub cso_ptr: *mut c_void,
268     dev: &'static Device,
269 }
270 
271 impl CSOWrapper {
272     pub fn new(dev: &'static Device, nir: &NirShader) -> Self {
273         let cso_ptr = dev
274             .helper_ctx()
275             .create_compute_state(nir, nir.shared_size());
276 
277         Self {
278             cso_ptr: cso_ptr,
279             dev: dev,
280         }
281     }
282 
283     pub fn get_cso_info(&self) -> pipe_compute_state_object_info {
284         self.dev.helper_ctx().compute_state_info(self.cso_ptr)
285     }
286 }
287 
288 impl Drop for CSOWrapper {
289     fn drop(&mut self) {
290         self.dev.helper_ctx().delete_compute_state(self.cso_ptr);
291     }
292 }
293 
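/// Per-device kernel state: either an already compiled compute state object, or the NIR
/// that gets compiled into one at launch time.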
294 pub enum KernelDevStateVariant {
295     Cso(CSOWrapper),
296     Nir(NirShader),
297 }
298 
299 pub struct Kernel {
300     pub base: CLObjectBase<CL_INVALID_KERNEL>,
301     pub prog: Arc<Program>,
302     pub name: String,
303     values: Mutex<Vec<Option<KernelArgValue>>>,
304     builds: HashMap<&'static Device, Arc<NirKernelBuild>>,
305     pub kernel_info: KernelInfo,
306 }
307 
308 impl_cl_type_trait!(cl_kernel, Kernel, CL_INVALID_KERNEL);
309 
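/// Expands the user supplied values into a fixed 3 element array, filling unspecified
/// dimensions with `val`.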
310 fn create_kernel_arr<T>(vals: &[usize], val: T) -> [T; 3]
311 where
312     T: std::convert::TryFrom<usize> + Copy,
313     <T as std::convert::TryFrom<usize>>::Error: std::fmt::Debug,
314 {
315     let mut res = [val; 3];
316     for (i, v) in vals.iter().enumerate() {
317         res[i] = (*v).try_into().expect("64 bit work groups not supported");
318     }
319     res
320 }
321 
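/// Runs the common NIR optimization loop until no pass reports any more progress.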
322 fn opt_nir(nir: &mut NirShader, dev: &Device, has_explicit_types: bool) {
323     let nir_options = unsafe {
324         &*dev
325             .screen
326             .nir_shader_compiler_options(pipe_shader_type::PIPE_SHADER_COMPUTE)
327     };
328 
329     while {
330         let mut progress = false;
331 
332         progress |= nir_pass!(nir, nir_copy_prop);
333         progress |= nir_pass!(nir, nir_opt_copy_prop_vars);
334         progress |= nir_pass!(nir, nir_opt_dead_write_vars);
335 
336         if nir_options.lower_to_scalar {
337             nir_pass!(
338                 nir,
339                 nir_lower_alu_to_scalar,
340                 nir_options.lower_to_scalar_filter,
341                 ptr::null(),
342             );
343             nir_pass!(nir, nir_lower_phis_to_scalar, false);
344         }
345 
346         progress |= nir_pass!(nir, nir_opt_deref);
347         if has_explicit_types {
348             progress |= nir_pass!(nir, nir_opt_memcpy);
349         }
350         progress |= nir_pass!(nir, nir_opt_dce);
351         progress |= nir_pass!(nir, nir_opt_undef);
352         progress |= nir_pass!(nir, nir_opt_constant_folding);
353         progress |= nir_pass!(nir, nir_opt_cse);
354         nir_pass!(nir, nir_split_var_copies);
355         progress |= nir_pass!(nir, nir_lower_var_copies);
356         progress |= nir_pass!(nir, nir_lower_vars_to_ssa);
357         nir_pass!(nir, nir_lower_alu);
358         progress |= nir_pass!(nir, nir_opt_phi_precision);
359         progress |= nir_pass!(nir, nir_opt_algebraic);
360         progress |= nir_pass!(
361             nir,
362             nir_opt_if,
363             nir_opt_if_options::nir_opt_if_optimize_phi_true_false,
364         );
365         progress |= nir_pass!(nir, nir_opt_dead_cf);
366         progress |= nir_pass!(nir, nir_opt_remove_phis);
367         // we don't want to be too aggressive here, but it kills a bit of CFG
368         progress |= nir_pass!(nir, nir_opt_peephole_select, 8, true, true);
369         progress |= nir_pass!(
370             nir,
371             nir_lower_vec3_to_vec4,
372             nir_variable_mode::nir_var_mem_generic | nir_variable_mode::nir_var_uniform,
373         );
374 
375         if nir_options.max_unroll_iterations != 0 {
376             progress |= nir_pass!(nir, nir_opt_loop_unroll);
377         }
378         nir.sweep_mem();
379         progress
380     } {}
381 }
382 
383 /// # Safety
384 ///
385 /// Only safe to call when `var` is a valid pointer to a valid [`nir_variable`]
386 unsafe extern "C" fn can_remove_var(var: *mut nir_variable, _: *mut c_void) -> bool {
387     // SAFETY: It is the caller's responsibility to provide a valid and aligned pointer
388     let var_type = unsafe { (*var).type_ };
389     // SAFETY: `nir_variable`'s type invariant guarantees that the `type_` field is valid and
390     // properly aligned.
391     unsafe {
392         !glsl_type_is_image(var_type)
393             && !glsl_type_is_texture(var_type)
394             && !glsl_type_is_sampler(var_type)
395     }
396 }
397 
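/// Lowers and optimizes the kernel NIR for the given device: inlines libclc, lowers
/// printf, images and address spaces, and appends the internal (hidden) arguments
/// (constant buffer, global work offsets, printf buffer, ...) the kernel needs at
/// launch time. Returns the final argument layout.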
398 fn lower_and_optimize_nir(
399     dev: &Device,
400     nir: &mut NirShader,
401     args: &[spirv::SPIRVKernelArg],
402     lib_clc: &NirShader,
403 ) -> (Vec<KernelArg>, Vec<InternalKernelArg>) {
404     let address_bits_base_type;
405     let address_bits_ptr_type;
406     let global_address_format;
407     let shared_address_format;
408 
409     if dev.address_bits() == 64 {
410         address_bits_base_type = glsl_base_type::GLSL_TYPE_UINT64;
411         address_bits_ptr_type = unsafe { glsl_uint64_t_type() };
412         global_address_format = nir_address_format::nir_address_format_64bit_global;
413         shared_address_format = nir_address_format::nir_address_format_32bit_offset_as_64bit;
414     } else {
415         address_bits_base_type = glsl_base_type::GLSL_TYPE_UINT;
416         address_bits_ptr_type = unsafe { glsl_uint_type() };
417         global_address_format = nir_address_format::nir_address_format_32bit_global;
418         shared_address_format = nir_address_format::nir_address_format_32bit_offset;
419     }
420 
421     let mut lower_state = rusticl_lower_state::default();
422     let nir_options = unsafe {
423         &*dev
424             .screen
425             .nir_shader_compiler_options(pipe_shader_type::PIPE_SHADER_COMPUTE)
426     };
427 
428     nir_pass!(nir, nir_scale_fdiv);
429     nir.set_workgroup_size_variable_if_zero();
430     nir.structurize();
431     while {
432         let mut progress = false;
433         nir_pass!(nir, nir_split_var_copies);
434         progress |= nir_pass!(nir, nir_copy_prop);
435         progress |= nir_pass!(nir, nir_opt_copy_prop_vars);
436         progress |= nir_pass!(nir, nir_opt_dead_write_vars);
437         progress |= nir_pass!(nir, nir_opt_deref);
438         progress |= nir_pass!(nir, nir_opt_dce);
439         progress |= nir_pass!(nir, nir_opt_undef);
440         progress |= nir_pass!(nir, nir_opt_constant_folding);
441         progress |= nir_pass!(nir, nir_opt_cse);
442         progress |= nir_pass!(nir, nir_lower_vars_to_ssa);
443         progress |= nir_pass!(nir, nir_opt_algebraic);
444         progress
445     } {}
446     nir.inline(lib_clc);
447     nir.cleanup_functions();
448     // that should free up tons of memory
449     nir.sweep_mem();
450 
451     nir_pass!(nir, nir_dedup_inline_samplers);
452 
453     let printf_opts = nir_lower_printf_options {
454         max_buffer_size: dev.printf_buffer_size() as u32,
455     };
456     nir_pass!(nir, nir_lower_printf, &printf_opts);
457 
458     opt_nir(nir, dev, false);
459 
460     let mut args = KernelArg::from_spirv_nir(args, nir);
461     let mut internal_args = Vec::new();
462 
463     let dv_opts = nir_remove_dead_variables_options {
464         can_remove_var: Some(can_remove_var),
465         can_remove_var_data: ptr::null_mut(),
466     };
467     nir_pass!(
468         nir,
469         nir_remove_dead_variables,
470         nir_variable_mode::nir_var_uniform
471             | nir_variable_mode::nir_var_image
472             | nir_variable_mode::nir_var_mem_constant
473             | nir_variable_mode::nir_var_mem_shared
474             | nir_variable_mode::nir_var_function_temp,
475         &dv_opts,
476     );
477 
478     // assign locations for inline samplers
479     let mut last_loc = -1;
480     for v in nir
481         .variables_with_mode(nir_variable_mode::nir_var_uniform | nir_variable_mode::nir_var_image)
482     {
483         if unsafe { !glsl_type_is_sampler(v.type_) } {
484             last_loc = v.data.location;
485             continue;
486         }
487         let s = unsafe { v.data.anon_1.sampler };
488         if s.is_inline_sampler() != 0 {
489             last_loc += 1;
490             v.data.location = last_loc;
491 
492             internal_args.push(InternalKernelArg {
493                 kind: InternalKernelArgType::InlineSampler(Sampler::nir_to_cl(
494                     s.addressing_mode(),
495                     s.filter_mode(),
496                     s.normalized_coordinates(),
497                 )),
498                 offset: 0,
499                 size: 0,
500             });
501         } else {
502             last_loc = v.data.location;
503         }
504     }
505 
506     nir_pass!(nir, nir_lower_readonly_images_to_tex, true);
507     nir_pass!(
508         nir,
509         nir_lower_cl_images,
510         !dev.images_as_deref(),
511         !dev.samplers_as_deref(),
512     );
513 
514     nir_pass!(
515         nir,
516         nir_lower_vars_to_explicit_types,
517         nir_variable_mode::nir_var_mem_constant,
518         Some(glsl_get_cl_type_size_align),
519     );
520 
521     // has to run before adding internal kernel arguments
522     nir.extract_constant_initializers();
523 
524     // run before gather info
525     nir_pass!(nir, nir_lower_system_values);
526     let mut compute_options = nir_lower_compute_system_values_options::default();
527     compute_options.set_has_base_global_invocation_id(true);
528     nir_pass!(nir, nir_lower_compute_system_values, &compute_options);
529     nir.gather_info();
530 
531     if nir.reads_sysval(gl_system_value::SYSTEM_VALUE_BASE_GLOBAL_INVOCATION_ID) {
532         internal_args.push(InternalKernelArg {
533             kind: InternalKernelArgType::GlobalWorkOffsets,
534             offset: 0,
535             size: (3 * dev.address_bits() / 8) as usize,
536         });
537         lower_state.base_global_invoc_id_loc = args.len() + internal_args.len() - 1;
538         nir.add_var(
539             nir_variable_mode::nir_var_uniform,
540             unsafe { glsl_vector_type(address_bits_base_type, 3) },
541             lower_state.base_global_invoc_id_loc,
542             "base_global_invocation_id",
543         );
544     }
545 
546     if nir.has_constant() {
547         internal_args.push(InternalKernelArg {
548             kind: InternalKernelArgType::ConstantBuffer,
549             offset: 0,
550             size: (dev.address_bits() / 8) as usize,
551         });
552         lower_state.const_buf_loc = args.len() + internal_args.len() - 1;
553         nir.add_var(
554             nir_variable_mode::nir_var_uniform,
555             address_bits_ptr_type,
556             lower_state.const_buf_loc,
557             "constant_buffer_addr",
558         );
559     }
560     if nir.has_printf() {
561         internal_args.push(InternalKernelArg {
562             kind: InternalKernelArgType::PrintfBuffer,
563             offset: 0,
564             size: (dev.address_bits() / 8) as usize,
565         });
566         lower_state.printf_buf_loc = args.len() + internal_args.len() - 1;
567         nir.add_var(
568             nir_variable_mode::nir_var_uniform,
569             address_bits_ptr_type,
570             lower_state.printf_buf_loc,
571             "printf_buffer_addr",
572         );
573     }
574 
575     if nir.num_images() > 0 || nir.num_textures() > 0 {
576         let count = nir.num_images() + nir.num_textures();
577         internal_args.push(InternalKernelArg {
578             kind: InternalKernelArgType::FormatArray,
579             offset: 0,
580             size: 2 * count as usize,
581         });
582 
583         internal_args.push(InternalKernelArg {
584             kind: InternalKernelArgType::OrderArray,
585             offset: 0,
586             size: 2 * count as usize,
587         });
588 
589         lower_state.format_arr_loc = args.len() + internal_args.len() - 2;
590         nir.add_var(
591             nir_variable_mode::nir_var_uniform,
592             unsafe { glsl_array_type(glsl_int16_t_type(), count as u32, 2) },
593             lower_state.format_arr_loc,
594             "image_formats",
595         );
596 
597         lower_state.order_arr_loc = args.len() + internal_args.len() - 1;
598         nir.add_var(
599             nir_variable_mode::nir_var_uniform,
600             unsafe { glsl_array_type(glsl_int16_t_type(), count as u32, 2) },
601             lower_state.order_arr_loc,
602             "image_orders",
603         );
604     }
605 
606     if nir.reads_sysval(gl_system_value::SYSTEM_VALUE_WORK_DIM) {
607         internal_args.push(InternalKernelArg {
608             kind: InternalKernelArgType::WorkDim,
609             size: 1,
610             offset: 0,
611         });
612         lower_state.work_dim_loc = args.len() + internal_args.len() - 1;
613         nir.add_var(
614             nir_variable_mode::nir_var_uniform,
615             unsafe { glsl_uint8_t_type() },
616             lower_state.work_dim_loc,
617             "work_dim",
618         );
619     }
620 
621     // need to run after the first opt loop and remove_dead_variables to get rid of unnecessary scratch
622     // memory
623     nir_pass!(
624         nir,
625         nir_lower_vars_to_explicit_types,
626         nir_variable_mode::nir_var_mem_shared
627             | nir_variable_mode::nir_var_function_temp
628             | nir_variable_mode::nir_var_shader_temp
629             | nir_variable_mode::nir_var_uniform
630             | nir_variable_mode::nir_var_mem_global
631             | nir_variable_mode::nir_var_mem_generic,
632         Some(glsl_get_cl_type_size_align),
633     );
634 
635     opt_nir(nir, dev, true);
636     nir_pass!(nir, nir_lower_memcpy);
637 
638     // we might have got rid of more function_temp or shared memory
639     nir.reset_scratch_size();
640     nir.reset_shared_size();
641     nir_pass!(
642         nir,
643         nir_remove_dead_variables,
644         nir_variable_mode::nir_var_function_temp | nir_variable_mode::nir_var_mem_shared,
645         &dv_opts,
646     );
647     nir_pass!(
648         nir,
649         nir_lower_vars_to_explicit_types,
650         nir_variable_mode::nir_var_function_temp
651             | nir_variable_mode::nir_var_mem_shared
652             | nir_variable_mode::nir_var_mem_generic,
653         Some(glsl_get_cl_type_size_align),
654     );
655 
656     nir_pass!(
657         nir,
658         nir_lower_explicit_io,
659         nir_variable_mode::nir_var_mem_global | nir_variable_mode::nir_var_mem_constant,
660         global_address_format,
661     );
662 
663     nir_pass!(nir, rusticl_lower_intrinsics, &mut lower_state);
664     nir_pass!(
665         nir,
666         nir_lower_explicit_io,
667         nir_variable_mode::nir_var_mem_shared
668             | nir_variable_mode::nir_var_function_temp
669             | nir_variable_mode::nir_var_uniform,
670         shared_address_format,
671     );
672 
673     if nir_options.lower_int64_options.0 != 0 {
674         nir_pass!(nir, nir_lower_int64);
675     }
676 
677     if nir_options.lower_uniforms_to_ubo {
678         nir_pass!(nir, rusticl_lower_inputs);
679     }
680 
681     nir_pass!(nir, nir_lower_convert_alu_types, None);
682 
683     opt_nir(nir, dev, true);
684 
685     /* before passing it into drivers, assign locations as drivers might remove nir_variables or
686      * other things we depend on
687      */
688     KernelArg::assign_locations(&mut args, &mut internal_args, nir);
689 
690     /* update the has_variable_shared_mem info as we might have DCEed all of them */
691     nir.set_has_variable_shared_mem(
692         args.iter()
693             .any(|arg| arg.kind == KernelArgType::MemLocal && !arg.dead),
694     );
695     dev.screen.finalize_nir(nir);
696 
697     nir_pass!(nir, nir_opt_dce);
698     nir.sweep_mem();
699 
700     (args, internal_args)
701 }
702 
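/// Reads the NIR and its argument metadata back out of a shader cache entry; the
/// counterpart to the serialization done in `convert_spirv_to_nir`.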
703 fn deserialize_nir(
704     bin: &mut &[u8],
705     d: &Device,
706 ) -> Option<(NirShader, Vec<KernelArg>, Vec<InternalKernelArg>)> {
707     let nir_len = read_ne_usize(bin);
708 
709     let nir = NirShader::deserialize(
710         bin,
711         nir_len,
712         d.screen()
713             .nir_shader_compiler_options(pipe_shader_type::PIPE_SHADER_COMPUTE),
714     )?;
715 
716     let arg_len = read_ne_usize(bin);
717     let mut args = Vec::with_capacity(arg_len);
718     for _ in 0..arg_len {
719         args.push(KernelArg::deserialize(bin)?);
720     }
721 
722     let arg_len = read_ne_usize(bin);
723     let mut internal_args = Vec::with_capacity(arg_len);
724     for _ in 0..arg_len {
725         internal_args.push(InternalKernelArg::deserialize(bin)?);
726     }
727 
728     assert!(bin.is_empty());
729 
730     Some((nir, args, internal_args))
731 }
732 
733 pub(super) fn convert_spirv_to_nir(
734     build: &ProgramBuild,
735     name: &str,
736     args: &[spirv::SPIRVKernelArg],
737     dev: &Device,
738 ) -> (KernelInfo, NirShader) {
739     let cache = dev.screen().shader_cache();
740     let key = build.hash_key(dev, name);
741 
742     let res = if let Some(cache) = &cache {
743         cache.get(&mut key.unwrap()).and_then(|entry| {
744             let mut bin: &[u8] = &entry;
745             deserialize_nir(&mut bin, dev)
746         })
747     } else {
748         None
749     };
750 
751     let (nir, args, internal_args) = if let Some(res) = res {
752         res
753     } else {
754         let mut nir = build.to_nir(name, dev);
755 
756         /* this is a hack until we support fp16 properly and check for denorms inside
757          * vstore/vload_half
758          */
759         nir.preserve_fp16_denorms();
760 
761         // Set to rtne for now until drivers are able to report their preferred rounding mode; that
762         // also matches what we report via the API.
763         nir.set_fp_rounding_mode_rtne();
764 
765         let (args, internal_args) = lower_and_optimize_nir(dev, &mut nir, args, &dev.lib_clc);
766 
767         if let Some(cache) = cache {
768             let mut bin = Vec::new();
769             let mut nir = nir.serialize();
770 
771             bin.extend_from_slice(&nir.len().to_ne_bytes());
772             bin.append(&mut nir);
773 
774             bin.extend_from_slice(&args.len().to_ne_bytes());
775             for arg in &args {
776                 bin.append(&mut arg.serialize());
777             }
778 
779             bin.extend_from_slice(&internal_args.len().to_ne_bytes());
780             for arg in &internal_args {
781                 bin.append(&mut arg.serialize());
782             }
783 
784             cache.put(&bin, &mut key.unwrap());
785         }
786 
787         (nir, args, internal_args)
788     };
789 
790     let attributes_string = build.attribute_str(name, dev);
791     let wgs = nir.workgroup_size();
792     let kernel_info = KernelInfo {
793         args: args,
794         internal_args: internal_args,
795         attributes_string: attributes_string,
796         work_group_size: [wgs[0] as usize, wgs[1] as usize, wgs[2] as usize],
797         subgroup_size: nir.subgroup_size() as usize,
798         num_subgroups: nir.num_subgroups() as usize,
799     };
800 
801     (kernel_info, nir)
802 }
803 
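/// Splits the first `S` bytes off the front of `buf` (advancing the slice) and returns
/// them as a fixed-size array reference.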
804 fn extract<'a, const S: usize>(buf: &'a mut &[u8]) -> &'a [u8; S] {
805     let val;
806     (val, *buf) = (*buf).split_at(S);
807     // we split off S bytes and convert them to [u8; S], so this should be safe
808     // use split_array_ref once it's stable
809     val.try_into().unwrap()
810 }
811 
812 impl Kernel {
813     pub fn new(name: String, prog: Arc<Program>) -> Arc<Kernel> {
814         let prog_build = prog.build_info();
815         let kernel_info = prog_build.kernel_info.get(&name).unwrap().clone();
816         let builds = prog_build
817             .builds
818             .iter()
819             .filter_map(|(&dev, b)| b.kernels.get(&name).map(|k| (dev, k.clone())))
820             .collect();
821 
822         let values = vec![None; kernel_info.args.len()];
823         Arc::new(Self {
824             base: CLObjectBase::new(RusticlTypes::Kernel),
825             prog: prog.clone(),
826             name: name,
827             values: Mutex::new(values),
828             builds: builds,
829             kernel_info: kernel_info,
830         })
831     }
832 
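    /// Picks a local work-group size if the application didn't specify one, based on the
    /// device's per-dimension and total thread limits, and converts the global grid into
    /// a block count per dimension.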
833     fn optimize_local_size(&self, d: &Device, grid: &mut [u32; 3], block: &mut [u32; 3]) {
834         let mut threads = self.max_threads_per_block(d) as u32;
835         let dim_threads = d.max_block_sizes();
836         let subgroups = self.preferred_simd_size(d) as u32;
837 
838         if !block.contains(&0) {
839             for i in 0..3 {
840                 // we already made sure everything is fine
841                 grid[i] /= block[i];
842             }
843             return;
844         }
845 
846         for i in 0..3 {
847             let t = cmp::min(threads, dim_threads[i] as u32);
848             let gcd = gcd(t, grid[i]);
849 
850             block[i] = gcd;
851             grid[i] /= gcd;
852 
853             // update limits
854             threads /= block[i];
855         }
856 
857         // if we didn't fill the subgroup we can do a bit better if we have threads remaining
858         let total_threads = block[0] * block[1] * block[2];
859         if threads != 1 && total_threads < subgroups {
860             for i in 0..3 {
861                 if grid[i] * total_threads < threads {
862                     block[i] *= grid[i];
863                     grid[i] = 1;
864                     // can only do it once as nothing is cleanly divisible
865                     break;
866                 }
867             }
868         }
869     }
870 
871     // the painful part is that host threads are allowed to modify the kernel object once it has
872     // been enqueued, so we return a closure with all the required data included.
873     pub fn launch(
874         self: &Arc<Self>,
875         q: &Arc<Queue>,
876         work_dim: u32,
877         block: &[usize],
878         grid: &[usize],
879         offsets: &[usize],
880     ) -> CLResult<EventSig> {
881         let nir_kernel_build = self.builds.get(q.device).unwrap().clone();
882         let mut block = create_kernel_arr::<u32>(block, 1);
883         let mut grid = create_kernel_arr::<u32>(grid, 1);
884         let offsets = create_kernel_arr::<u64>(offsets, 0);
885         let mut input: Vec<u8> = Vec::new();
886         let mut resource_info = Vec::new();
887         // Set it once so we get the alignment padding right
888         let static_local_size: u64 = nir_kernel_build.shared_size;
889         let mut variable_local_size: u64 = static_local_size;
890         let printf_size = q.device.printf_buffer_size() as u32;
891         let mut samplers = Vec::new();
892         let mut iviews = Vec::new();
893         let mut sviews = Vec::new();
894         let mut tex_formats: Vec<u16> = Vec::new();
895         let mut tex_orders: Vec<u16> = Vec::new();
896         let mut img_formats: Vec<u16> = Vec::new();
897         let mut img_orders: Vec<u16> = Vec::new();
898         let null_ptr: &[u8] = if q.device.address_bits() == 64 {
899             &[0; 8]
900         } else {
901             &[0; 4]
902         };
903 
904         self.optimize_local_size(q.device, &mut grid, &mut block);
905 
906         let arg_values = self.arg_values();
907         for (arg, val) in self.kernel_info.args.iter().zip(arg_values.iter()) {
908             if arg.dead {
909                 continue;
910             }
911 
912             if arg.kind != KernelArgType::Image
913                 && arg.kind != KernelArgType::RWImage
914                 && arg.kind != KernelArgType::Texture
915                 && arg.kind != KernelArgType::Sampler
916             {
917                 input.resize(arg.offset, 0);
918             }
919             match val.as_ref().unwrap() {
920                 KernelArgValue::Constant(c) => input.extend_from_slice(c),
921                 KernelArgValue::Buffer(buffer) => {
922                     let res = buffer.get_res_of_dev(q.device)?;
923                     if q.device.address_bits() == 64 {
924                         input.extend_from_slice(&buffer.offset.to_ne_bytes());
925                     } else {
926                         input.extend_from_slice(&(buffer.offset as u32).to_ne_bytes());
927                     }
928                     resource_info.push((res.clone(), arg.offset));
929                 }
930                 KernelArgValue::Image(image) => {
931                     let res = image.get_res_of_dev(q.device)?;
932 
933                     // If the resource is a buffer, the image was created from a buffer, so use the
934                     // strides and dimensions of the image instead.
935                     let app_img_info =
936                         if res.as_ref().is_buffer() && image.mem_type == CL_MEM_OBJECT_IMAGE2D {
937                             Some(AppImgInfo::new(
938                                 image.image_desc.row_pitch()? / image.image_elem_size as u32,
939                                 image.image_desc.width()?,
940                                 image.image_desc.height()?,
941                             ))
942                         } else {
943                             None
944                         };
945 
946                     let format = image.pipe_format;
947                     let (formats, orders) = if arg.kind == KernelArgType::Image {
948                         iviews.push(res.pipe_image_view(
949                             format,
950                             false,
951                             image.pipe_image_host_access(),
952                             app_img_info.as_ref(),
953                         ));
954                         (&mut img_formats, &mut img_orders)
955                     } else if arg.kind == KernelArgType::RWImage {
956                         iviews.push(res.pipe_image_view(
957                             format,
958                             true,
959                             image.pipe_image_host_access(),
960                             app_img_info.as_ref(),
961                         ));
962                         (&mut img_formats, &mut img_orders)
963                     } else {
964                         sviews.push((res.clone(), format, app_img_info));
965                         (&mut tex_formats, &mut tex_orders)
966                     };
967 
968                     let binding = arg.binding as usize;
969                     assert!(binding >= formats.len());
970 
971                     formats.resize(binding, 0);
972                     orders.resize(binding, 0);
973 
974                     formats.push(image.image_format.image_channel_data_type as u16);
975                     orders.push(image.image_format.image_channel_order as u16);
976                 }
977                 KernelArgValue::LocalMem(size) => {
978                     // TODO 32 bit
979                     let pot = cmp::min(*size, 0x80);
980                     variable_local_size =
981                         align(variable_local_size, pot.next_power_of_two() as u64);
982                     if q.device.address_bits() == 64 {
983                         input.extend_from_slice(&variable_local_size.to_ne_bytes());
984                     } else {
985                         input.extend_from_slice(&(variable_local_size as u32).to_ne_bytes());
986                     }
987                     variable_local_size += *size as u64;
988                 }
989                 KernelArgValue::Sampler(sampler) => {
990                     samplers.push(sampler.pipe());
991                 }
992                 KernelArgValue::None => {
993                     assert!(
994                         arg.kind == KernelArgType::MemGlobal
995                             || arg.kind == KernelArgType::MemConstant
996                     );
997                     input.extend_from_slice(null_ptr);
998                 }
999             }
1000         }
1001 
1002         // subtract the shader local_size as we only request something on top of that.
1003         variable_local_size -= static_local_size;
1004 
1005         let mut printf_buf = None;
1006         for arg in &self.kernel_info.internal_args {
1007             if arg.offset > input.len() {
1008                 input.resize(arg.offset, 0);
1009             }
1010             match arg.kind {
1011                 InternalKernelArgType::ConstantBuffer => {
1012                     assert!(nir_kernel_build.constant_buffer.is_some());
1013                     input.extend_from_slice(null_ptr);
1014                     resource_info.push((
1015                         nir_kernel_build.constant_buffer.clone().unwrap(),
1016                         arg.offset,
1017                     ));
1018                 }
1019                 InternalKernelArgType::GlobalWorkOffsets => {
1020                     if q.device.address_bits() == 64 {
1021                         input.extend_from_slice(unsafe { as_byte_slice(&offsets) });
1022                     } else {
1023                         input.extend_from_slice(unsafe {
1024                             as_byte_slice(&[
1025                                 offsets[0] as u32,
1026                                 offsets[1] as u32,
1027                                 offsets[2] as u32,
1028                             ])
1029                         });
1030                     }
1031                 }
1032                 InternalKernelArgType::PrintfBuffer => {
1033                     let buf = Arc::new(
1034                         q.device
1035                             .screen
1036                             .resource_create_buffer(
1037                                 printf_size,
1038                                 ResourceType::Staging,
1039                                 PIPE_BIND_GLOBAL,
1040                             )
1041                             .unwrap(),
1042                     );
1043 
1044                     input.extend_from_slice(null_ptr);
1045                     resource_info.push((buf.clone(), arg.offset));
1046 
1047                     printf_buf = Some(buf);
1048                 }
1049                 InternalKernelArgType::InlineSampler(cl) => {
1050                     samplers.push(Sampler::cl_to_pipe(cl));
1051                 }
1052                 InternalKernelArgType::FormatArray => {
1053                     input.extend_from_slice(unsafe { as_byte_slice(&tex_formats) });
1054                     input.extend_from_slice(unsafe { as_byte_slice(&img_formats) });
1055                 }
1056                 InternalKernelArgType::OrderArray => {
1057                     input.extend_from_slice(unsafe { as_byte_slice(&tex_orders) });
1058                     input.extend_from_slice(unsafe { as_byte_slice(&img_orders) });
1059                 }
1060                 InternalKernelArgType::WorkDim => {
1061                     input.extend_from_slice(&[work_dim as u8; 1]);
1062                 }
1063             }
1064         }
1065 
1066         Ok(Box::new(move |q, ctx| {
1067             let mut input = input.clone();
1068             let mut resources = Vec::with_capacity(resource_info.len());
1069             let mut globals: Vec<*mut u32> = Vec::new();
1070             let printf_format = &nir_kernel_build.printf_info;
1071 
1072             let mut sviews: Vec<_> = sviews
1073                 .iter()
1074                 .map(|(s, f, aii)| ctx.create_sampler_view(s, *f, aii.as_ref()))
1075                 .collect();
1076             let samplers: Vec<_> = samplers
1077                 .iter()
1078                 .map(|s| ctx.create_sampler_state(s))
1079                 .collect();
1080 
1081             for (res, offset) in &resource_info {
1082                 resources.push(res);
1083                 globals.push(unsafe { input.as_mut_ptr().add(*offset) }.cast());
1084             }
1085 
1086             if let Some(printf_buf) = &printf_buf {
1087                 let init_data: [u8; 1] = [4];
1088                 ctx.buffer_subdata(
1089                     printf_buf,
1090                     0,
1091                     init_data.as_ptr().cast(),
1092                     init_data.len() as u32,
1093                 );
1094             }
1095 
1096             let temp_cso;
1097             let cso = match &nir_kernel_build.nir_or_cso {
1098                 KernelDevStateVariant::Cso(cso) => cso,
1099                 KernelDevStateVariant::Nir(nir) => {
1100                     temp_cso = CSOWrapper::new(q.device, nir);
1101                     &temp_cso
1102                 }
1103             };
1104 
1105             ctx.bind_compute_state(cso.cso_ptr);
1106             ctx.bind_sampler_states(&samplers);
1107             ctx.set_sampler_views(&mut sviews);
1108             ctx.set_shader_images(&iviews);
1109             ctx.set_global_binding(resources.as_slice(), &mut globals);
1110             ctx.update_cb0(&input);
1111 
1112             ctx.launch_grid(work_dim, block, grid, variable_local_size as u32);
1113 
1114             ctx.clear_global_binding(globals.len() as u32);
1115             ctx.clear_shader_images(iviews.len() as u32);
1116             ctx.clear_sampler_views(sviews.len() as u32);
1117             ctx.clear_sampler_states(samplers.len() as u32);
1118 
1119             ctx.bind_compute_state(ptr::null_mut());
1120 
1121             ctx.memory_barrier(PIPE_BARRIER_GLOBAL_BUFFER);
1122 
1123             samplers.iter().for_each(|s| ctx.delete_sampler_state(*s));
1124             sviews.iter().for_each(|v| ctx.sampler_view_destroy(*v));
1125 
1126             if let Some(printf_buf) = &printf_buf {
1127                 let tx = ctx
1128                     .buffer_map(
1129                         printf_buf,
1130                         0,
1131                         printf_size as i32,
1132                         RWFlags::RD,
1133                         ResourceMapType::Normal,
1134                     )
1135                     .ok_or(CL_OUT_OF_RESOURCES)?
1136                     .with_ctx(ctx);
1137                 let mut buf: &[u8] =
1138                     unsafe { slice::from_raw_parts(tx.ptr().cast(), printf_size as usize) };
1139                 let length = u32::from_ne_bytes(*extract(&mut buf));
1140 
1141                 // update our slice to make sure we don't go out of bounds
1142                 buf = &buf[0..(length - 4) as usize];
1143                 if let Some(pf) = printf_format.as_ref() {
1144                     pf.u_printf(buf)
1145                 }
1146             }
1147 
1148             Ok(())
1149         }))
1150     }
1151 
1152     pub fn arg_values(&self) -> MutexGuard<Vec<Option<KernelArgValue>>> {
1153         self.values.lock().unwrap()
1154     }
1155 
1156     pub fn set_kernel_arg(&self, idx: usize, arg: KernelArgValue) -> CLResult<()> {
1157         self.values
1158             .lock()
1159             .unwrap()
1160             .get_mut(idx)
1161             .ok_or(CL_INVALID_ARG_INDEX)?
1162             .replace(arg);
1163         Ok(())
1164     }
1165 
1166     pub fn access_qualifier(&self, idx: cl_uint) -> cl_kernel_arg_access_qualifier {
1167         let aq = self.kernel_info.args[idx as usize].spirv.access_qualifier;
1168 
1169         if aq
1170             == clc_kernel_arg_access_qualifier::CLC_KERNEL_ARG_ACCESS_READ
1171                 | clc_kernel_arg_access_qualifier::CLC_KERNEL_ARG_ACCESS_WRITE
1172         {
1173             CL_KERNEL_ARG_ACCESS_READ_WRITE
1174         } else if aq == clc_kernel_arg_access_qualifier::CLC_KERNEL_ARG_ACCESS_READ {
1175             CL_KERNEL_ARG_ACCESS_READ_ONLY
1176         } else if aq == clc_kernel_arg_access_qualifier::CLC_KERNEL_ARG_ACCESS_WRITE {
1177             CL_KERNEL_ARG_ACCESS_WRITE_ONLY
1178         } else {
1179             CL_KERNEL_ARG_ACCESS_NONE
1180         }
1181     }
1182 
1183     pub fn address_qualifier(&self, idx: cl_uint) -> cl_kernel_arg_address_qualifier {
1184         match self.kernel_info.args[idx as usize].spirv.address_qualifier {
1185             clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_PRIVATE => {
1186                 CL_KERNEL_ARG_ADDRESS_PRIVATE
1187             }
1188             clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_CONSTANT => {
1189                 CL_KERNEL_ARG_ADDRESS_CONSTANT
1190             }
1191             clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_LOCAL => {
1192                 CL_KERNEL_ARG_ADDRESS_LOCAL
1193             }
1194             clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_GLOBAL => {
1195                 CL_KERNEL_ARG_ADDRESS_GLOBAL
1196             }
1197         }
1198     }
1199 
1200     pub fn type_qualifier(&self, idx: cl_uint) -> cl_kernel_arg_type_qualifier {
1201         let tq = self.kernel_info.args[idx as usize].spirv.type_qualifier;
1202         let zero = clc_kernel_arg_type_qualifier(0);
1203         let mut res = CL_KERNEL_ARG_TYPE_NONE;
1204 
1205         if tq & clc_kernel_arg_type_qualifier::CLC_KERNEL_ARG_TYPE_CONST != zero {
1206             res |= CL_KERNEL_ARG_TYPE_CONST;
1207         }
1208 
1209         if tq & clc_kernel_arg_type_qualifier::CLC_KERNEL_ARG_TYPE_RESTRICT != zero {
1210             res |= CL_KERNEL_ARG_TYPE_RESTRICT;
1211         }
1212 
1213         if tq & clc_kernel_arg_type_qualifier::CLC_KERNEL_ARG_TYPE_VOLATILE != zero {
1214             res |= CL_KERNEL_ARG_TYPE_VOLATILE;
1215         }
1216 
1217         res.into()
1218     }
1219 
1220     pub fn work_group_size(&self) -> [usize; 3] {
1221         self.kernel_info.work_group_size
1222     }
1223 
1224     pub fn num_subgroups(&self) -> usize {
1225         self.kernel_info.num_subgroups
1226     }
1227 
1228     pub fn subgroup_size(&self) -> usize {
1229         self.kernel_info.subgroup_size
1230     }
1231 
1232     pub fn arg_name(&self, idx: cl_uint) -> &String {
1233         &self.kernel_info.args[idx as usize].spirv.name
1234     }
1235 
1236     pub fn arg_type_name(&self, idx: cl_uint) -> &String {
1237         &self.kernel_info.args[idx as usize].spirv.type_name
1238     }
1239 
1240     pub fn priv_mem_size(&self, dev: &Device) -> cl_ulong {
1241         self.builds.get(dev).unwrap().info.private_memory as cl_ulong
1242     }
1243 
1244     pub fn max_threads_per_block(&self, dev: &Device) -> usize {
1245         self.builds.get(dev).unwrap().info.max_threads as usize
1246     }
1247 
1248     pub fn preferred_simd_size(&self, dev: &Device) -> usize {
1249         self.builds.get(dev).unwrap().info.preferred_simd_size as usize
1250     }
1251 
1252     pub fn local_mem_size(&self, dev: &Device) -> cl_ulong {
1253         // TODO include args
1254         self.builds.get(dev).unwrap().shared_size as cl_ulong
1255     }
1256 
1257     pub fn has_svm_devs(&self) -> bool {
1258         self.prog.devs.iter().any(|dev| dev.svm_supported())
1259     }
1260 
1261     pub fn subgroup_sizes(&self, dev: &Device) -> Vec<usize> {
1262         SetBitIndices::from_msb(self.builds.get(dev).unwrap().info.simd_sizes)
1263             .map(|bit| 1 << bit)
1264             .collect()
1265     }
1266 
1267     pub fn subgroups_for_block(&self, dev: &Device, block: &[usize]) -> usize {
1268         let subgroup_size = self.subgroup_size_for_block(dev, block);
1269         if subgroup_size == 0 {
1270             return 0;
1271         }
1272 
1273         let threads = block.iter().product();
1274         div_round_up(threads, subgroup_size)
1275     }
1276 
1277     pub fn subgroup_size_for_block(&self, dev: &Device, block: &[usize]) -> usize {
1278         let subgroup_sizes = self.subgroup_sizes(dev);
1279         if subgroup_sizes.is_empty() {
1280             return 0;
1281         }
1282 
1283         if subgroup_sizes.len() == 1 {
1284             return subgroup_sizes[0];
1285         }
1286 
1287         let block = [
1288             *block.first().unwrap_or(&1) as u32,
1289             *block.get(1).unwrap_or(&1) as u32,
1290             *block.get(2).unwrap_or(&1) as u32,
1291         ];
1292 
1293         match &self.builds.get(dev).unwrap().nir_or_cso {
1294             KernelDevStateVariant::Cso(cso) => {
1295                 dev.helper_ctx()
1296                     .compute_state_subgroup_size(cso.cso_ptr, &block) as usize
1297             }
1298             _ => {
1299                 panic!()
1300             }
1301         }
1302     }
1303 }
1304 
1305 impl Clone for Kernel {
1306     fn clone(&self) -> Self {
1307         Self {
1308             base: CLObjectBase::new(RusticlTypes::Kernel),
1309             prog: self.prog.clone(),
1310             name: self.name.clone(),
1311             values: Mutex::new(self.arg_values().clone()),
1312             builds: self.builds.clone(),
1313             kernel_info: self.kernel_info.clone(),
1314         }
1315     }
1316 }
1317