1 use crate::api::icd::*;
2 use crate::core::device::*;
3 use crate::core::event::*;
4 use crate::core::memory::*;
5 use crate::core::program::*;
6 use crate::core::queue::*;
7 use crate::impl_cl_type_trait;
8
9 use mesa_rust::compiler::clc::*;
10 use mesa_rust::compiler::nir::*;
11 use mesa_rust::nir_pass;
12 use mesa_rust::pipe::context::RWFlags;
13 use mesa_rust::pipe::context::ResourceMapType;
14 use mesa_rust::pipe::resource::*;
15 use mesa_rust::pipe::screen::ResourceType;
16 use mesa_rust_gen::*;
17 use mesa_rust_util::math::*;
18 use mesa_rust_util::serialize::*;
19 use rusticl_opencl_gen::*;
20
21 use std::cmp;
22 use std::collections::HashMap;
23 use std::convert::TryInto;
24 use std::os::raw::c_void;
25 use std::ptr;
26 use std::slice;
27 use std::sync::Arc;
28 use std::sync::Mutex;
29 use std::sync::MutexGuard;
30
31 // ugh, we are not allowed to take refs, so...
32 #[derive(Clone)]
33 pub enum KernelArgValue {
34 None,
35 Buffer(Arc<Buffer>),
36 Constant(Vec<u8>),
37 Image(Arc<Image>),
38 LocalMem(usize),
39 Sampler(Arc<Sampler>),
40 }
41
42 #[derive(Hash, PartialEq, Eq, Clone, Copy)]
43 pub enum KernelArgType {
44 Constant = 0, // for anything passed by value
45 Image = 1,
46 RWImage = 2,
47 Sampler = 3,
48 Texture = 4,
49 MemGlobal = 5,
50 MemConstant = 6,
51 MemLocal = 7,
52 }
53
54 #[derive(Hash, PartialEq, Eq, Clone)]
55 pub enum InternalKernelArgType {
56 ConstantBuffer,
57 GlobalWorkOffsets,
58 PrintfBuffer,
59 InlineSampler((cl_addressing_mode, cl_filter_mode, bool)),
60 FormatArray,
61 OrderArray,
62 WorkDim,
63 }
64
65 #[derive(Hash, PartialEq, Eq, Clone)]
66 pub struct KernelArg {
67 spirv: spirv::SPIRVKernelArg,
68 pub kind: KernelArgType,
69 pub size: usize,
70 /// The offset into the input buffer
71 pub offset: usize,
72 /// The actual binding slot
73 pub binding: u32,
74 pub dead: bool,
75 }
76
77 #[derive(Hash, PartialEq, Eq, Clone)]
78 pub struct InternalKernelArg {
79 pub kind: InternalKernelArgType,
80 pub size: usize,
81 pub offset: usize,
82 }
83
84 impl KernelArg {
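/// Builds the argument list by pairing each SPIR-V kernel argument with its
/// corresponding NIR variable (looked up via the variable's location, which at
/// this point equals the argument index) and deriving the argument kind from
/// the address qualifier and the GLSL type.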
85 fn from_spirv_nir(spirv: &[spirv::SPIRVKernelArg], nir: &mut NirShader) -> Vec<Self> {
86 let nir_arg_map: HashMap<_, _> = nir
87 .variables_with_mode(
88 nir_variable_mode::nir_var_uniform | nir_variable_mode::nir_var_image,
89 )
90 .map(|v| (v.data.location, v))
91 .collect();
92 let mut res = Vec::new();
93
94 for (i, s) in spirv.iter().enumerate() {
95 let nir = nir_arg_map.get(&(i as i32)).unwrap();
96 let kind = match s.address_qualifier {
97 clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_PRIVATE => {
98 if unsafe { glsl_type_is_sampler(nir.type_) } {
99 KernelArgType::Sampler
100 } else {
101 KernelArgType::Constant
102 }
103 }
104 clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_CONSTANT => {
105 KernelArgType::MemConstant
106 }
107 clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_LOCAL => {
108 KernelArgType::MemLocal
109 }
110 clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_GLOBAL => {
111 if unsafe { glsl_type_is_image(nir.type_) } {
112 let access = nir.data.access();
113 if access == gl_access_qualifier::ACCESS_NON_WRITEABLE.0 {
114 KernelArgType::Texture
115 } else if access == gl_access_qualifier::ACCESS_NON_READABLE.0 {
116 KernelArgType::Image
117 } else {
118 KernelArgType::RWImage
119 }
120 } else {
121 KernelArgType::MemGlobal
122 }
123 }
124 };
125
126 res.push(Self {
127 spirv: s.clone(),
128 size: unsafe { glsl_get_cl_size(nir.type_) } as usize,
129 // we'll update it later in the 2nd pass
130 kind: kind,
131 offset: 0,
132 binding: 0,
133 dead: true,
134 });
135 }
136 res
137 }
138
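/// Second pass: copies the driver-assigned offsets and bindings from the NIR
/// variables back onto the matching explicit arguments (clearing their dead
/// flag) and sets the offsets of the internal arguments that follow them.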
139 fn assign_locations(
140 args: &mut [Self],
141 internal_args: &mut [InternalKernelArg],
142 nir: &mut NirShader,
143 ) {
144 for var in nir.variables_with_mode(
145 nir_variable_mode::nir_var_uniform | nir_variable_mode::nir_var_image,
146 ) {
147 if let Some(arg) = args.get_mut(var.data.location as usize) {
148 arg.offset = var.data.driver_location as usize;
149 arg.binding = var.data.binding;
150 arg.dead = false;
151 } else {
152 internal_args
153 .get_mut(var.data.location as usize - args.len())
154 .unwrap()
155 .offset = var.data.driver_location as usize;
156 }
157 }
158 }
159
160 fn serialize(&self) -> Vec<u8> {
161 let mut bin = Vec::new();
162
163 bin.append(&mut self.spirv.serialize());
164 bin.extend_from_slice(&self.size.to_ne_bytes());
165 bin.extend_from_slice(&self.offset.to_ne_bytes());
166 bin.extend_from_slice(&self.binding.to_ne_bytes());
167 bin.extend_from_slice(&(self.dead as u8).to_ne_bytes());
168 bin.extend_from_slice(&(self.kind as u8).to_ne_bytes());
169
170 bin
171 }
172
173 fn deserialize(bin: &mut &[u8]) -> Option<Self> {
174 let spirv = spirv::SPIRVKernelArg::deserialize(bin)?;
175 let size = read_ne_usize(bin);
176 let offset = read_ne_usize(bin);
177 let binding = read_ne_u32(bin);
178 let dead = read_ne_u8(bin) == 1;
179
180 let kind = match read_ne_u8(bin) {
181 0 => KernelArgType::Constant,
182 1 => KernelArgType::Image,
183 2 => KernelArgType::RWImage,
184 3 => KernelArgType::Sampler,
185 4 => KernelArgType::Texture,
186 5 => KernelArgType::MemGlobal,
187 6 => KernelArgType::MemConstant,
188 7 => KernelArgType::MemLocal,
189 _ => return None,
190 };
191
192 Some(Self {
193 spirv: spirv,
194 kind: kind,
195 size: size,
196 offset: offset,
197 binding: binding,
198 dead: dead,
199 })
200 }
201 }
202
203 impl InternalKernelArg {
204 fn serialize(&self) -> Vec<u8> {
205 let mut bin = Vec::new();
206
207 bin.extend_from_slice(&self.size.to_ne_bytes());
208 bin.extend_from_slice(&self.offset.to_ne_bytes());
209
210 match self.kind {
211 InternalKernelArgType::ConstantBuffer => bin.push(0),
212 InternalKernelArgType::GlobalWorkOffsets => bin.push(1),
213 InternalKernelArgType::PrintfBuffer => bin.push(2),
214 InternalKernelArgType::InlineSampler((addr_mode, filter_mode, norm)) => {
215 bin.push(3);
216 bin.extend_from_slice(&addr_mode.to_ne_bytes());
217 bin.extend_from_slice(&filter_mode.to_ne_bytes());
218 bin.push(norm as u8);
219 }
220 InternalKernelArgType::FormatArray => bin.push(4),
221 InternalKernelArgType::OrderArray => bin.push(5),
222 InternalKernelArgType::WorkDim => bin.push(6),
223 }
224
225 bin
226 }
227
228 fn deserialize(bin: &mut &[u8]) -> Option<Self> {
229 let size = read_ne_usize(bin);
230 let offset = read_ne_usize(bin);
231
232 let kind = match read_ne_u8(bin) {
233 0 => InternalKernelArgType::ConstantBuffer,
234 1 => InternalKernelArgType::GlobalWorkOffsets,
235 2 => InternalKernelArgType::PrintfBuffer,
236 3 => {
237 let addr_mode = read_ne_u32(bin);
238 let filter_mode = read_ne_u32(bin);
239 let norm = read_ne_u8(bin) == 1;
240 InternalKernelArgType::InlineSampler((addr_mode, filter_mode, norm))
241 }
242 4 => InternalKernelArgType::FormatArray,
243 5 => InternalKernelArgType::OrderArray,
244 6 => InternalKernelArgType::WorkDim,
245 _ => return None,
246 };
247
248 Some(Self {
249 kind: kind,
250 size: size,
251 offset: offset,
252 })
253 }
254 }
255
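/// Reflection data for a compiled kernel: its explicit and internal arguments
/// plus the workgroup and subgroup metadata reported through the OpenCL API.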
256 #[derive(Clone, PartialEq, Eq, Hash)]
257 pub struct KernelInfo {
258 pub args: Vec<KernelArg>,
259 pub internal_args: Vec<InternalKernelArg>,
260 pub attributes_string: String,
261 pub work_group_size: [usize; 3],
262 pub subgroup_size: usize,
263 pub num_subgroups: usize,
264 }
265
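/// RAII wrapper around a gallium compute state object created on the device's
/// helper context; the CSO is destroyed again on drop.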
266 pub struct CSOWrapper {
267 pub cso_ptr: *mut c_void,
268 dev: &'static Device,
269 }
270
271 impl CSOWrapper {
272 pub fn new(dev: &'static Device, nir: &NirShader) -> Self {
273 let cso_ptr = dev
274 .helper_ctx()
275 .create_compute_state(nir, nir.shared_size());
276
277 Self {
278 cso_ptr: cso_ptr,
279 dev: dev,
280 }
281 }
282
283 pub fn get_cso_info(&self) -> pipe_compute_state_object_info {
284 self.dev.helper_ctx().compute_state_info(self.cso_ptr)
285 }
286 }
287
288 impl Drop for CSOWrapper {
289 fn drop(&mut self) {
290 self.dev.helper_ctx().delete_compute_state(self.cso_ptr);
291 }
292 }
293
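/// Per-device kernel state: either a precompiled compute state object or the
/// NIR shader from which a temporary CSO gets created at launch time.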
294 pub enum KernelDevStateVariant {
295 Cso(CSOWrapper),
296 Nir(NirShader),
297 }
298
299 pub struct Kernel {
300 pub base: CLObjectBase<CL_INVALID_KERNEL>,
301 pub prog: Arc<Program>,
302 pub name: String,
303 values: Mutex<Vec<Option<KernelArgValue>>>,
304 builds: HashMap<&'static Device, Arc<NirKernelBuild>>,
305 pub kernel_info: KernelInfo,
306 }
307
308 impl_cl_type_trait!(cl_kernel, Kernel, CL_INVALID_KERNEL);
309
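/// Expands `vals` into a fixed-size `[T; 3]`, filling unused dimensions with
/// `val`. Panics if a value does not fit into `T`.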
310 fn create_kernel_arr<T>(vals: &[usize], val: T) -> [T; 3]
311 where
312 T: std::convert::TryFrom<usize> + Copy,
313 <T as std::convert::TryFrom<usize>>::Error: std::fmt::Debug,
314 {
315 let mut res = [val; 3];
316 for (i, v) in vals.iter().enumerate() {
317 res[i] = (*v).try_into().expect("64 bit work groups not supported");
318 }
319 res
320 }
321
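/// Runs the generic NIR optimization loop until no pass reports progress.
/// `has_explicit_types` enables passes that rely on explicit type information
/// (e.g. nir_opt_memcpy).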
322 fn opt_nir(nir: &mut NirShader, dev: &Device, has_explicit_types: bool) {
323 let nir_options = unsafe {
324 &*dev
325 .screen
326 .nir_shader_compiler_options(pipe_shader_type::PIPE_SHADER_COMPUTE)
327 };
328
329 while {
330 let mut progress = false;
331
332 progress |= nir_pass!(nir, nir_copy_prop);
333 progress |= nir_pass!(nir, nir_opt_copy_prop_vars);
334 progress |= nir_pass!(nir, nir_opt_dead_write_vars);
335
336 if nir_options.lower_to_scalar {
337 nir_pass!(
338 nir,
339 nir_lower_alu_to_scalar,
340 nir_options.lower_to_scalar_filter,
341 ptr::null(),
342 );
343 nir_pass!(nir, nir_lower_phis_to_scalar, false);
344 }
345
346 progress |= nir_pass!(nir, nir_opt_deref);
347 if has_explicit_types {
348 progress |= nir_pass!(nir, nir_opt_memcpy);
349 }
350 progress |= nir_pass!(nir, nir_opt_dce);
351 progress |= nir_pass!(nir, nir_opt_undef);
352 progress |= nir_pass!(nir, nir_opt_constant_folding);
353 progress |= nir_pass!(nir, nir_opt_cse);
354 nir_pass!(nir, nir_split_var_copies);
355 progress |= nir_pass!(nir, nir_lower_var_copies);
356 progress |= nir_pass!(nir, nir_lower_vars_to_ssa);
357 nir_pass!(nir, nir_lower_alu);
358 progress |= nir_pass!(nir, nir_opt_phi_precision);
359 progress |= nir_pass!(nir, nir_opt_algebraic);
360 progress |= nir_pass!(
361 nir,
362 nir_opt_if,
363 nir_opt_if_options::nir_opt_if_optimize_phi_true_false,
364 );
365 progress |= nir_pass!(nir, nir_opt_dead_cf);
366 progress |= nir_pass!(nir, nir_opt_remove_phis);
367 // we don't want to be too aggressive here, but it kills a bit of CFG
368 progress |= nir_pass!(nir, nir_opt_peephole_select, 8, true, true);
369 progress |= nir_pass!(
370 nir,
371 nir_lower_vec3_to_vec4,
372 nir_variable_mode::nir_var_mem_generic | nir_variable_mode::nir_var_uniform,
373 );
374
375 if nir_options.max_unroll_iterations != 0 {
376 progress |= nir_pass!(nir, nir_opt_loop_unroll);
377 }
378 nir.sweep_mem();
379 progress
380 } {}
381 }
382
383 /// # Safety
384 ///
385 /// Only safe to call when `var` is a valid pointer to a valid [`nir_variable`]
386 unsafe extern "C" fn can_remove_var(var: *mut nir_variable, _: *mut c_void) -> bool {
387 // SAFETY: It is the caller's responsibility to provide a valid and aligned pointer
388 let var_type = unsafe { (*var).type_ };
389 // SAFETY: `nir_variable`'s type invariant guarantees that the `type_` field is valid and
390 // properly aligned.
391 unsafe {
392 !glsl_type_is_image(var_type)
393 && !glsl_type_is_texture(var_type)
394 && !glsl_type_is_sampler(var_type)
395 }
396 }
397
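/// Lowers the freshly translated NIR into something the driver can consume:
/// inlines the libclc library, lowers images, samplers and printf, adds the
/// internal kernel arguments rusticl needs (constant buffer, global work
/// offsets, printf buffer, ...) and assigns the final argument locations.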
398 fn lower_and_optimize_nir(
399 dev: &Device,
400 nir: &mut NirShader,
401 args: &[spirv::SPIRVKernelArg],
402 lib_clc: &NirShader,
403 ) -> (Vec<KernelArg>, Vec<InternalKernelArg>) {
404 let address_bits_base_type;
405 let address_bits_ptr_type;
406 let global_address_format;
407 let shared_address_format;
408
409 if dev.address_bits() == 64 {
410 address_bits_base_type = glsl_base_type::GLSL_TYPE_UINT64;
411 address_bits_ptr_type = unsafe { glsl_uint64_t_type() };
412 global_address_format = nir_address_format::nir_address_format_64bit_global;
413 shared_address_format = nir_address_format::nir_address_format_32bit_offset_as_64bit;
414 } else {
415 address_bits_base_type = glsl_base_type::GLSL_TYPE_UINT;
416 address_bits_ptr_type = unsafe { glsl_uint_type() };
417 global_address_format = nir_address_format::nir_address_format_32bit_global;
418 shared_address_format = nir_address_format::nir_address_format_32bit_offset;
419 }
420
421 let mut lower_state = rusticl_lower_state::default();
422 let nir_options = unsafe {
423 &*dev
424 .screen
425 .nir_shader_compiler_options(pipe_shader_type::PIPE_SHADER_COMPUTE)
426 };
427
428 nir_pass!(nir, nir_scale_fdiv);
429 nir.set_workgroup_size_variable_if_zero();
430 nir.structurize();
431 while {
432 let mut progress = false;
433 nir_pass!(nir, nir_split_var_copies);
434 progress |= nir_pass!(nir, nir_copy_prop);
435 progress |= nir_pass!(nir, nir_opt_copy_prop_vars);
436 progress |= nir_pass!(nir, nir_opt_dead_write_vars);
437 progress |= nir_pass!(nir, nir_opt_deref);
438 progress |= nir_pass!(nir, nir_opt_dce);
439 progress |= nir_pass!(nir, nir_opt_undef);
440 progress |= nir_pass!(nir, nir_opt_constant_folding);
441 progress |= nir_pass!(nir, nir_opt_cse);
442 progress |= nir_pass!(nir, nir_lower_vars_to_ssa);
443 progress |= nir_pass!(nir, nir_opt_algebraic);
444 progress
445 } {}
446 nir.inline(lib_clc);
447 nir.cleanup_functions();
448 // that should free up tons of memory
449 nir.sweep_mem();
450
451 nir_pass!(nir, nir_dedup_inline_samplers);
452
453 let printf_opts = nir_lower_printf_options {
454 max_buffer_size: dev.printf_buffer_size() as u32,
455 };
456 nir_pass!(nir, nir_lower_printf, &printf_opts);
457
458 opt_nir(nir, dev, false);
459
460 let mut args = KernelArg::from_spirv_nir(args, nir);
461 let mut internal_args = Vec::new();
462
463 let dv_opts = nir_remove_dead_variables_options {
464 can_remove_var: Some(can_remove_var),
465 can_remove_var_data: ptr::null_mut(),
466 };
467 nir_pass!(
468 nir,
469 nir_remove_dead_variables,
470 nir_variable_mode::nir_var_uniform
471 | nir_variable_mode::nir_var_image
472 | nir_variable_mode::nir_var_mem_constant
473 | nir_variable_mode::nir_var_mem_shared
474 | nir_variable_mode::nir_var_function_temp,
475 &dv_opts,
476 );
477
478 // assign locations for inline samplers
479 let mut last_loc = -1;
480 for v in nir
481 .variables_with_mode(nir_variable_mode::nir_var_uniform | nir_variable_mode::nir_var_image)
482 {
483 if unsafe { !glsl_type_is_sampler(v.type_) } {
484 last_loc = v.data.location;
485 continue;
486 }
487 let s = unsafe { v.data.anon_1.sampler };
488 if s.is_inline_sampler() != 0 {
489 last_loc += 1;
490 v.data.location = last_loc;
491
492 internal_args.push(InternalKernelArg {
493 kind: InternalKernelArgType::InlineSampler(Sampler::nir_to_cl(
494 s.addressing_mode(),
495 s.filter_mode(),
496 s.normalized_coordinates(),
497 )),
498 offset: 0,
499 size: 0,
500 });
501 } else {
502 last_loc = v.data.location;
503 }
504 }
505
506 nir_pass!(nir, nir_lower_readonly_images_to_tex, true);
507 nir_pass!(
508 nir,
509 nir_lower_cl_images,
510 !dev.images_as_deref(),
511 !dev.samplers_as_deref(),
512 );
513
514 nir_pass!(
515 nir,
516 nir_lower_vars_to_explicit_types,
517 nir_variable_mode::nir_var_mem_constant,
518 Some(glsl_get_cl_type_size_align),
519 );
520
521 // has to run before adding internal kernel arguments
522 nir.extract_constant_initializers();
523
524 // run before gather info
525 nir_pass!(nir, nir_lower_system_values);
526 let mut compute_options = nir_lower_compute_system_values_options::default();
527 compute_options.set_has_base_global_invocation_id(true);
528 nir_pass!(nir, nir_lower_compute_system_values, &compute_options);
529 nir.gather_info();
530
531 if nir.reads_sysval(gl_system_value::SYSTEM_VALUE_BASE_GLOBAL_INVOCATION_ID) {
532 internal_args.push(InternalKernelArg {
533 kind: InternalKernelArgType::GlobalWorkOffsets,
534 offset: 0,
535 size: (3 * dev.address_bits() / 8) as usize,
536 });
537 lower_state.base_global_invoc_id_loc = args.len() + internal_args.len() - 1;
538 nir.add_var(
539 nir_variable_mode::nir_var_uniform,
540 unsafe { glsl_vector_type(address_bits_base_type, 3) },
541 lower_state.base_global_invoc_id_loc,
542 "base_global_invocation_id",
543 );
544 }
545
546 if nir.has_constant() {
547 internal_args.push(InternalKernelArg {
548 kind: InternalKernelArgType::ConstantBuffer,
549 offset: 0,
550 size: (dev.address_bits() / 8) as usize,
551 });
552 lower_state.const_buf_loc = args.len() + internal_args.len() - 1;
553 nir.add_var(
554 nir_variable_mode::nir_var_uniform,
555 address_bits_ptr_type,
556 lower_state.const_buf_loc,
557 "constant_buffer_addr",
558 );
559 }
560 if nir.has_printf() {
561 internal_args.push(InternalKernelArg {
562 kind: InternalKernelArgType::PrintfBuffer,
563 offset: 0,
564 size: (dev.address_bits() / 8) as usize,
565 });
566 lower_state.printf_buf_loc = args.len() + internal_args.len() - 1;
567 nir.add_var(
568 nir_variable_mode::nir_var_uniform,
569 address_bits_ptr_type,
570 lower_state.printf_buf_loc,
571 "printf_buffer_addr",
572 );
573 }
574
575 if nir.num_images() > 0 || nir.num_textures() > 0 {
576 let count = nir.num_images() + nir.num_textures();
577 internal_args.push(InternalKernelArg {
578 kind: InternalKernelArgType::FormatArray,
579 offset: 0,
580 size: 2 * count as usize,
581 });
582
583 internal_args.push(InternalKernelArg {
584 kind: InternalKernelArgType::OrderArray,
585 offset: 0,
586 size: 2 * count as usize,
587 });
588
589 lower_state.format_arr_loc = args.len() + internal_args.len() - 2;
590 nir.add_var(
591 nir_variable_mode::nir_var_uniform,
592 unsafe { glsl_array_type(glsl_int16_t_type(), count as u32, 2) },
593 lower_state.format_arr_loc,
594 "image_formats",
595 );
596
597 lower_state.order_arr_loc = args.len() + internal_args.len() - 1;
598 nir.add_var(
599 nir_variable_mode::nir_var_uniform,
600 unsafe { glsl_array_type(glsl_int16_t_type(), count as u32, 2) },
601 lower_state.order_arr_loc,
602 "image_orders",
603 );
604 }
605
606 if nir.reads_sysval(gl_system_value::SYSTEM_VALUE_WORK_DIM) {
607 internal_args.push(InternalKernelArg {
608 kind: InternalKernelArgType::WorkDim,
609 size: 1,
610 offset: 0,
611 });
612 lower_state.work_dim_loc = args.len() + internal_args.len() - 1;
613 nir.add_var(
614 nir_variable_mode::nir_var_uniform,
615 unsafe { glsl_uint8_t_type() },
616 lower_state.work_dim_loc,
617 "work_dim",
618 );
619 }
620
621 // needs to run after the first opt loop and remove_dead_variables to get rid of unnecessary scratch
622 // memory
623 nir_pass!(
624 nir,
625 nir_lower_vars_to_explicit_types,
626 nir_variable_mode::nir_var_mem_shared
627 | nir_variable_mode::nir_var_function_temp
628 | nir_variable_mode::nir_var_shader_temp
629 | nir_variable_mode::nir_var_uniform
630 | nir_variable_mode::nir_var_mem_global
631 | nir_variable_mode::nir_var_mem_generic,
632 Some(glsl_get_cl_type_size_align),
633 );
634
635 opt_nir(nir, dev, true);
636 nir_pass!(nir, nir_lower_memcpy);
637
638 // we might have got rid of more function_temp or shared memory
639 nir.reset_scratch_size();
640 nir.reset_shared_size();
641 nir_pass!(
642 nir,
643 nir_remove_dead_variables,
644 nir_variable_mode::nir_var_function_temp | nir_variable_mode::nir_var_mem_shared,
645 &dv_opts,
646 );
647 nir_pass!(
648 nir,
649 nir_lower_vars_to_explicit_types,
650 nir_variable_mode::nir_var_function_temp
651 | nir_variable_mode::nir_var_mem_shared
652 | nir_variable_mode::nir_var_mem_generic,
653 Some(glsl_get_cl_type_size_align),
654 );
655
656 nir_pass!(
657 nir,
658 nir_lower_explicit_io,
659 nir_variable_mode::nir_var_mem_global | nir_variable_mode::nir_var_mem_constant,
660 global_address_format,
661 );
662
663 nir_pass!(nir, rusticl_lower_intrinsics, &mut lower_state);
664 nir_pass!(
665 nir,
666 nir_lower_explicit_io,
667 nir_variable_mode::nir_var_mem_shared
668 | nir_variable_mode::nir_var_function_temp
669 | nir_variable_mode::nir_var_uniform,
670 shared_address_format,
671 );
672
673 if nir_options.lower_int64_options.0 != 0 {
674 nir_pass!(nir, nir_lower_int64);
675 }
676
677 if nir_options.lower_uniforms_to_ubo {
678 nir_pass!(nir, rusticl_lower_inputs);
679 }
680
681 nir_pass!(nir, nir_lower_convert_alu_types, None);
682
683 opt_nir(nir, dev, true);
684
685 /* before passing it into drivers, assign locations as drivers might remove nir_variables or
686 * other things we depend on
687 */
688 KernelArg::assign_locations(&mut args, &mut internal_args, nir);
689
690 /* update the has_variable_shared_mem info as we might have DCEed all of them */
691 nir.set_has_variable_shared_mem(
692 args.iter()
693 .any(|arg| arg.kind == KernelArgType::MemLocal && !arg.dead),
694 );
695 dev.screen.finalize_nir(nir);
696
697 nir_pass!(nir, nir_opt_dce);
698 nir.sweep_mem();
699
700 (args, internal_args)
701 }
702
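/// Reads back what convert_spirv_to_nir stored in the shader cache: the NIR
/// length and blob followed by the serialized explicit and internal kernel
/// arguments.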
703 fn deserialize_nir(
704 bin: &mut &[u8],
705 d: &Device,
706 ) -> Option<(NirShader, Vec<KernelArg>, Vec<InternalKernelArg>)> {
707 let nir_len = read_ne_usize(bin);
708
709 let nir = NirShader::deserialize(
710 bin,
711 nir_len,
712 d.screen()
713 .nir_shader_compiler_options(pipe_shader_type::PIPE_SHADER_COMPUTE),
714 )?;
715
716 let arg_len = read_ne_usize(bin);
717 let mut args = Vec::with_capacity(arg_len);
718 for _ in 0..arg_len {
719 args.push(KernelArg::deserialize(bin)?);
720 }
721
722 let arg_len = read_ne_usize(bin);
723 let mut internal_args = Vec::with_capacity(arg_len);
724 for _ in 0..arg_len {
725 internal_args.push(InternalKernelArg::deserialize(bin)?);
726 }
727
728 assert!(bin.is_empty());
729
730 Some((nir, args, internal_args))
731 }
732
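/// Returns the lowered NIR and its kernel info for `name`, either loaded from
/// the on-disk shader cache or compiled (and then cached) on the spot.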
733 pub(super) fn convert_spirv_to_nir(
734 build: &ProgramBuild,
735 name: &str,
736 args: &[spirv::SPIRVKernelArg],
737 dev: &Device,
738 ) -> (KernelInfo, NirShader) {
739 let cache = dev.screen().shader_cache();
740 let key = build.hash_key(dev, name);
741
742 let res = if let Some(cache) = &cache {
743 cache.get(&mut key.unwrap()).and_then(|entry| {
744 let mut bin: &[u8] = &entry;
745 deserialize_nir(&mut bin, dev)
746 })
747 } else {
748 None
749 };
750
751 let (nir, args, internal_args) = if let Some(res) = res {
752 res
753 } else {
754 let mut nir = build.to_nir(name, dev);
755
756 /* this is a hack until we support fp16 properly and check for denorms inside
757 * vstore/vload_half
758 */
759 nir.preserve_fp16_denorms();
760
761 // Set to rtne for now until drivers are able to report their preferred rounding mode, which
762 // also matches what we report via the API.
763 nir.set_fp_rounding_mode_rtne();
764
765 let (args, internal_args) = lower_and_optimize_nir(dev, &mut nir, args, &dev.lib_clc);
766
767 if let Some(cache) = cache {
768 let mut bin = Vec::new();
769 let mut nir = nir.serialize();
770
771 bin.extend_from_slice(&nir.len().to_ne_bytes());
772 bin.append(&mut nir);
773
774 bin.extend_from_slice(&args.len().to_ne_bytes());
775 for arg in &args {
776 bin.append(&mut arg.serialize());
777 }
778
779 bin.extend_from_slice(&internal_args.len().to_ne_bytes());
780 for arg in &internal_args {
781 bin.append(&mut arg.serialize());
782 }
783
784 cache.put(&bin, &mut key.unwrap());
785 }
786
787 (nir, args, internal_args)
788 };
789
790 let attributes_string = build.attribute_str(name, dev);
791 let wgs = nir.workgroup_size();
792 let kernel_info = KernelInfo {
793 args: args,
794 internal_args: internal_args,
795 attributes_string: attributes_string,
796 work_group_size: [wgs[0] as usize, wgs[1] as usize, wgs[2] as usize],
797 subgroup_size: nir.subgroup_size() as usize,
798 num_subgroups: nir.num_subgroups() as usize,
799 };
800
801 (kernel_info, nir)
802 }
803
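/// Splits the first `S` bytes off `buf` and returns them as a fixed-size array
/// reference, advancing the slice past them.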
804 fn extract<'a, const S: usize>(buf: &'a mut &[u8]) -> &'a [u8; S] {
805 let val;
806 (val, *buf) = (*buf).split_at(S);
807 // we split off S bytes and convert them to [u8; S], so this should be safe
808 // use split_array_ref once it's stable
809 val.try_into().unwrap()
810 }
811
812 impl Kernel {
813 pub fn new(name: String, prog: Arc<Program>) -> Arc<Kernel> {
814 let prog_build = prog.build_info();
815 let kernel_info = prog_build.kernel_info.get(&name).unwrap().clone();
816 let builds = prog_build
817 .builds
818 .iter()
819 .filter_map(|(&dev, b)| b.kernels.get(&name).map(|k| (dev, k.clone())))
820 .collect();
821
822 let values = vec![None; kernel_info.args.len()];
823 Arc::new(Self {
824 base: CLObjectBase::new(RusticlTypes::Kernel),
825 prog: prog.clone(),
826 name: name,
827 values: Mutex::new(values),
828 builds: builds,
829 kernel_info: kernel_info,
830 })
831 }
832
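/// If the application provided a block size, only converts the grid from
/// global sizes to block counts; otherwise derives a block size per dimension
/// from the remaining thread budget, the per-dimension limits and the GCD with
/// the grid, with a small fixup when the chosen block does not fill a subgroup.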
833 fn optimize_local_size(&self, d: &Device, grid: &mut [u32; 3], block: &mut [u32; 3]) {
834 let mut threads = self.max_threads_per_block(d) as u32;
835 let dim_threads = d.max_block_sizes();
836 let subgroups = self.preferred_simd_size(d) as u32;
837
838 if !block.contains(&0) {
839 for i in 0..3 {
840 // we already made sure everything is fine
841 grid[i] /= block[i];
842 }
843 return;
844 }
845
846 for i in 0..3 {
847 let t = cmp::min(threads, dim_threads[i] as u32);
848 let gcd = gcd(t, grid[i]);
849
850 block[i] = gcd;
851 grid[i] /= gcd;
852
853 // update limits
854 threads /= block[i];
855 }
856
857 // if we didn't fill the subgroup we can do a bit better if we have threads remaining
858 let total_threads = block[0] * block[1] * block[2];
859 if threads != 1 && total_threads < subgroups {
860 for i in 0..3 {
861 if grid[i] * total_threads < threads {
862 block[i] *= grid[i];
863 grid[i] = 1;
864 // can only do it once as nothing is cleanly divisible
865 break;
866 }
867 }
868 }
869 }
870
871 // The painful part is that host threads are allowed to modify the kernel object once it has
872 // been enqueued, so return a closure with all required data included.
873 pub fn launch(
874 self: &Arc<Self>,
875 q: &Arc<Queue>,
876 work_dim: u32,
877 block: &[usize],
878 grid: &[usize],
879 offsets: &[usize],
880 ) -> CLResult<EventSig> {
881 let nir_kernel_build = self.builds.get(q.device).unwrap().clone();
882 let mut block = create_kernel_arr::<u32>(block, 1);
883 let mut grid = create_kernel_arr::<u32>(grid, 1);
884 let offsets = create_kernel_arr::<u64>(offsets, 0);
885 let mut input: Vec<u8> = Vec::new();
886 let mut resource_info = Vec::new();
887 // Set it once so we get the alignment padding right
888 let static_local_size: u64 = nir_kernel_build.shared_size;
889 let mut variable_local_size: u64 = static_local_size;
890 let printf_size = q.device.printf_buffer_size() as u32;
891 let mut samplers = Vec::new();
892 let mut iviews = Vec::new();
893 let mut sviews = Vec::new();
894 let mut tex_formats: Vec<u16> = Vec::new();
895 let mut tex_orders: Vec<u16> = Vec::new();
896 let mut img_formats: Vec<u16> = Vec::new();
897 let mut img_orders: Vec<u16> = Vec::new();
898 let null_ptr: &[u8] = if q.device.address_bits() == 64 {
899 &[0; 8]
900 } else {
901 &[0; 4]
902 };
903
904 self.optimize_local_size(q.device, &mut grid, &mut block);
905
906 let arg_values = self.arg_values();
907 for (arg, val) in self.kernel_info.args.iter().zip(arg_values.iter()) {
908 if arg.dead {
909 continue;
910 }
911
912 if arg.kind != KernelArgType::Image
913 && arg.kind != KernelArgType::RWImage
914 && arg.kind != KernelArgType::Texture
915 && arg.kind != KernelArgType::Sampler
916 {
917 input.resize(arg.offset, 0);
918 }
919 match val.as_ref().unwrap() {
920 KernelArgValue::Constant(c) => input.extend_from_slice(c),
921 KernelArgValue::Buffer(buffer) => {
922 let res = buffer.get_res_of_dev(q.device)?;
923 if q.device.address_bits() == 64 {
924 input.extend_from_slice(&buffer.offset.to_ne_bytes());
925 } else {
926 input.extend_from_slice(&(buffer.offset as u32).to_ne_bytes());
927 }
928 resource_info.push((res.clone(), arg.offset));
929 }
930 KernelArgValue::Image(image) => {
931 let res = image.get_res_of_dev(q.device)?;
932
933 // If the resource is a buffer, the image was created from a buffer. Use the strides and
934 // dimensions of the image in that case.
935 let app_img_info =
936 if res.as_ref().is_buffer() && image.mem_type == CL_MEM_OBJECT_IMAGE2D {
937 Some(AppImgInfo::new(
938 image.image_desc.row_pitch()? / image.image_elem_size as u32,
939 image.image_desc.width()?,
940 image.image_desc.height()?,
941 ))
942 } else {
943 None
944 };
945
946 let format = image.pipe_format;
947 let (formats, orders) = if arg.kind == KernelArgType::Image {
948 iviews.push(res.pipe_image_view(
949 format,
950 false,
951 image.pipe_image_host_access(),
952 app_img_info.as_ref(),
953 ));
954 (&mut img_formats, &mut img_orders)
955 } else if arg.kind == KernelArgType::RWImage {
956 iviews.push(res.pipe_image_view(
957 format,
958 true,
959 image.pipe_image_host_access(),
960 app_img_info.as_ref(),
961 ));
962 (&mut img_formats, &mut img_orders)
963 } else {
964 sviews.push((res.clone(), format, app_img_info));
965 (&mut tex_formats, &mut tex_orders)
966 };
967
968 let binding = arg.binding as usize;
969 assert!(binding >= formats.len());
970
971 formats.resize(binding, 0);
972 orders.resize(binding, 0);
973
974 formats.push(image.image_format.image_channel_data_type as u16);
975 orders.push(image.image_format.image_channel_order as u16);
976 }
977 KernelArgValue::LocalMem(size) => {
978 // TODO 32 bit
979 let pot = cmp::min(*size, 0x80);
980 variable_local_size =
981 align(variable_local_size, pot.next_power_of_two() as u64);
982 if q.device.address_bits() == 64 {
983 input.extend_from_slice(&variable_local_size.to_ne_bytes());
984 } else {
985 input.extend_from_slice(&(variable_local_size as u32).to_ne_bytes());
986 }
987 variable_local_size += *size as u64;
988 }
989 KernelArgValue::Sampler(sampler) => {
990 samplers.push(sampler.pipe());
991 }
992 KernelArgValue::None => {
993 assert!(
994 arg.kind == KernelArgType::MemGlobal
995 || arg.kind == KernelArgType::MemConstant
996 );
997 input.extend_from_slice(null_ptr);
998 }
999 }
1000 }
1001
1002 // subtract the shader local_size as we only request something on top of that.
1003 variable_local_size -= static_local_size;
1004
1005 let mut printf_buf = None;
1006 for arg in &self.kernel_info.internal_args {
1007 if arg.offset > input.len() {
1008 input.resize(arg.offset, 0);
1009 }
1010 match arg.kind {
1011 InternalKernelArgType::ConstantBuffer => {
1012 assert!(nir_kernel_build.constant_buffer.is_some());
1013 input.extend_from_slice(null_ptr);
1014 resource_info.push((
1015 nir_kernel_build.constant_buffer.clone().unwrap(),
1016 arg.offset,
1017 ));
1018 }
1019 InternalKernelArgType::GlobalWorkOffsets => {
1020 if q.device.address_bits() == 64 {
1021 input.extend_from_slice(unsafe { as_byte_slice(&offsets) });
1022 } else {
1023 input.extend_from_slice(unsafe {
1024 as_byte_slice(&[
1025 offsets[0] as u32,
1026 offsets[1] as u32,
1027 offsets[2] as u32,
1028 ])
1029 });
1030 }
1031 }
1032 InternalKernelArgType::PrintfBuffer => {
1033 let buf = Arc::new(
1034 q.device
1035 .screen
1036 .resource_create_buffer(
1037 printf_size,
1038 ResourceType::Staging,
1039 PIPE_BIND_GLOBAL,
1040 )
1041 .unwrap(),
1042 );
1043
1044 input.extend_from_slice(null_ptr);
1045 resource_info.push((buf.clone(), arg.offset));
1046
1047 printf_buf = Some(buf);
1048 }
1049 InternalKernelArgType::InlineSampler(cl) => {
1050 samplers.push(Sampler::cl_to_pipe(cl));
1051 }
1052 InternalKernelArgType::FormatArray => {
1053 input.extend_from_slice(unsafe { as_byte_slice(&tex_formats) });
1054 input.extend_from_slice(unsafe { as_byte_slice(&img_formats) });
1055 }
1056 InternalKernelArgType::OrderArray => {
1057 input.extend_from_slice(unsafe { as_byte_slice(&tex_orders) });
1058 input.extend_from_slice(unsafe { as_byte_slice(&img_orders) });
1059 }
1060 InternalKernelArgType::WorkDim => {
1061 input.extend_from_slice(&[work_dim as u8; 1]);
1062 }
1063 }
1064 }
1065
1066 Ok(Box::new(move |q, ctx| {
1067 let mut input = input.clone();
1068 let mut resources = Vec::with_capacity(resource_info.len());
1069 let mut globals: Vec<*mut u32> = Vec::new();
1070 let printf_format = &nir_kernel_build.printf_info;
1071
1072 let mut sviews: Vec<_> = sviews
1073 .iter()
1074 .map(|(s, f, aii)| ctx.create_sampler_view(s, *f, aii.as_ref()))
1075 .collect();
1076 let samplers: Vec<_> = samplers
1077 .iter()
1078 .map(|s| ctx.create_sampler_state(s))
1079 .collect();
1080
1081 for (res, offset) in &resource_info {
1082 resources.push(res);
1083 globals.push(unsafe { input.as_mut_ptr().add(*offset) }.cast());
1084 }
1085
1086 if let Some(printf_buf) = &printf_buf {
1087 let init_data: [u8; 1] = [4];
1088 ctx.buffer_subdata(
1089 printf_buf,
1090 0,
1091 init_data.as_ptr().cast(),
1092 init_data.len() as u32,
1093 );
1094 }
1095
1096 let temp_cso;
1097 let cso = match &nir_kernel_build.nir_or_cso {
1098 KernelDevStateVariant::Cso(cso) => cso,
1099 KernelDevStateVariant::Nir(nir) => {
1100 temp_cso = CSOWrapper::new(q.device, nir);
1101 &temp_cso
1102 }
1103 };
1104
1105 ctx.bind_compute_state(cso.cso_ptr);
1106 ctx.bind_sampler_states(&samplers);
1107 ctx.set_sampler_views(&mut sviews);
1108 ctx.set_shader_images(&iviews);
1109 ctx.set_global_binding(resources.as_slice(), &mut globals);
1110 ctx.update_cb0(&input);
1111
1112 ctx.launch_grid(work_dim, block, grid, variable_local_size as u32);
1113
1114 ctx.clear_global_binding(globals.len() as u32);
1115 ctx.clear_shader_images(iviews.len() as u32);
1116 ctx.clear_sampler_views(sviews.len() as u32);
1117 ctx.clear_sampler_states(samplers.len() as u32);
1118
1119 ctx.bind_compute_state(ptr::null_mut());
1120
1121 ctx.memory_barrier(PIPE_BARRIER_GLOBAL_BUFFER);
1122
1123 samplers.iter().for_each(|s| ctx.delete_sampler_state(*s));
1124 sviews.iter().for_each(|v| ctx.sampler_view_destroy(*v));
1125
1126 if let Some(printf_buf) = &printf_buf {
1127 let tx = ctx
1128 .buffer_map(
1129 printf_buf,
1130 0,
1131 printf_size as i32,
1132 RWFlags::RD,
1133 ResourceMapType::Normal,
1134 )
1135 .ok_or(CL_OUT_OF_RESOURCES)?
1136 .with_ctx(ctx);
1137 let mut buf: &[u8] =
1138 unsafe { slice::from_raw_parts(tx.ptr().cast(), printf_size as usize) };
1139 let length = u32::from_ne_bytes(*extract(&mut buf));
1140
1141 // update our slice to make sure we don't go out of bounds
1142 buf = &buf[0..(length - 4) as usize];
1143 if let Some(pf) = printf_format.as_ref() {
1144 pf.u_printf(buf)
1145 }
1146 }
1147
1148 Ok(())
1149 }))
1150 }
1151
1152 pub fn arg_values(&self) -> MutexGuard<Vec<Option<KernelArgValue>>> {
1153 self.values.lock().unwrap()
1154 }
1155
1156 pub fn set_kernel_arg(&self, idx: usize, arg: KernelArgValue) -> CLResult<()> {
1157 self.values
1158 .lock()
1159 .unwrap()
1160 .get_mut(idx)
1161 .ok_or(CL_INVALID_ARG_INDEX)?
1162 .replace(arg);
1163 Ok(())
1164 }
1165
1166 pub fn access_qualifier(&self, idx: cl_uint) -> cl_kernel_arg_access_qualifier {
1167 let aq = self.kernel_info.args[idx as usize].spirv.access_qualifier;
1168
1169 if aq
1170 == clc_kernel_arg_access_qualifier::CLC_KERNEL_ARG_ACCESS_READ
1171 | clc_kernel_arg_access_qualifier::CLC_KERNEL_ARG_ACCESS_WRITE
1172 {
1173 CL_KERNEL_ARG_ACCESS_READ_WRITE
1174 } else if aq == clc_kernel_arg_access_qualifier::CLC_KERNEL_ARG_ACCESS_READ {
1175 CL_KERNEL_ARG_ACCESS_READ_ONLY
1176 } else if aq == clc_kernel_arg_access_qualifier::CLC_KERNEL_ARG_ACCESS_WRITE {
1177 CL_KERNEL_ARG_ACCESS_WRITE_ONLY
1178 } else {
1179 CL_KERNEL_ARG_ACCESS_NONE
1180 }
1181 }
1182
1183 pub fn address_qualifier(&self, idx: cl_uint) -> cl_kernel_arg_address_qualifier {
1184 match self.kernel_info.args[idx as usize].spirv.address_qualifier {
1185 clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_PRIVATE => {
1186 CL_KERNEL_ARG_ADDRESS_PRIVATE
1187 }
1188 clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_CONSTANT => {
1189 CL_KERNEL_ARG_ADDRESS_CONSTANT
1190 }
1191 clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_LOCAL => {
1192 CL_KERNEL_ARG_ADDRESS_LOCAL
1193 }
1194 clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_GLOBAL => {
1195 CL_KERNEL_ARG_ADDRESS_GLOBAL
1196 }
1197 }
1198 }
1199
1200 pub fn type_qualifier(&self, idx: cl_uint) -> cl_kernel_arg_type_qualifier {
1201 let tq = self.kernel_info.args[idx as usize].spirv.type_qualifier;
1202 let zero = clc_kernel_arg_type_qualifier(0);
1203 let mut res = CL_KERNEL_ARG_TYPE_NONE;
1204
1205 if tq & clc_kernel_arg_type_qualifier::CLC_KERNEL_ARG_TYPE_CONST != zero {
1206 res |= CL_KERNEL_ARG_TYPE_CONST;
1207 }
1208
1209 if tq & clc_kernel_arg_type_qualifier::CLC_KERNEL_ARG_TYPE_RESTRICT != zero {
1210 res |= CL_KERNEL_ARG_TYPE_RESTRICT;
1211 }
1212
1213 if tq & clc_kernel_arg_type_qualifier::CLC_KERNEL_ARG_TYPE_VOLATILE != zero {
1214 res |= CL_KERNEL_ARG_TYPE_VOLATILE;
1215 }
1216
1217 res.into()
1218 }
1219
1220 pub fn work_group_size(&self) -> [usize; 3] {
1221 self.kernel_info.work_group_size
1222 }
1223
1224 pub fn num_subgroups(&self) -> usize {
1225 self.kernel_info.num_subgroups
1226 }
1227
1228 pub fn subgroup_size(&self) -> usize {
1229 self.kernel_info.subgroup_size
1230 }
1231
1232 pub fn arg_name(&self, idx: cl_uint) -> &String {
1233 &self.kernel_info.args[idx as usize].spirv.name
1234 }
1235
1236 pub fn arg_type_name(&self, idx: cl_uint) -> &String {
1237 &self.kernel_info.args[idx as usize].spirv.type_name
1238 }
1239
1240 pub fn priv_mem_size(&self, dev: &Device) -> cl_ulong {
1241 self.builds.get(dev).unwrap().info.private_memory as cl_ulong
1242 }
1243
1244 pub fn max_threads_per_block(&self, dev: &Device) -> usize {
1245 self.builds.get(dev).unwrap().info.max_threads as usize
1246 }
1247
1248 pub fn preferred_simd_size(&self, dev: &Device) -> usize {
1249 self.builds.get(dev).unwrap().info.preferred_simd_size as usize
1250 }
1251
1252 pub fn local_mem_size(&self, dev: &Device) -> cl_ulong {
1253 // TODO include args
1254 self.builds.get(dev).unwrap().shared_size as cl_ulong
1255 }
1256
1257 pub fn has_svm_devs(&self) -> bool {
1258 self.prog.devs.iter().any(|dev| dev.svm_supported())
1259 }
1260
1261 pub fn subgroup_sizes(&self, dev: &Device) -> Vec<usize> {
1262 SetBitIndices::from_msb(self.builds.get(dev).unwrap().info.simd_sizes)
1263 .map(|bit| 1 << bit)
1264 .collect()
1265 }
1266
1267 pub fn subgroups_for_block(&self, dev: &Device, block: &[usize]) -> usize {
1268 let subgroup_size = self.subgroup_size_for_block(dev, block);
1269 if subgroup_size == 0 {
1270 return 0;
1271 }
1272
1273 let threads = block.iter().product();
1274 div_round_up(threads, subgroup_size)
1275 }
1276
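/// Returns the subgroup size the driver would pick for the given block, or 0
/// if the device reports no subgroup sizes. When more than one size is
/// possible, this requires a compiled CSO and panics otherwise.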
1277 pub fn subgroup_size_for_block(&self, dev: &Device, block: &[usize]) -> usize {
1278 let subgroup_sizes = self.subgroup_sizes(dev);
1279 if subgroup_sizes.is_empty() {
1280 return 0;
1281 }
1282
1283 if subgroup_sizes.len() == 1 {
1284 return subgroup_sizes[0];
1285 }
1286
1287 let block = [
1288 *block.first().unwrap_or(&1) as u32,
1289 *block.get(1).unwrap_or(&1) as u32,
1290 *block.get(2).unwrap_or(&1) as u32,
1291 ];
1292
1293 match &self.builds.get(dev).unwrap().nir_or_cso {
1294 KernelDevStateVariant::Cso(cso) => {
1295 dev.helper_ctx()
1296 .compute_state_subgroup_size(cso.cso_ptr, &block) as usize
1297 }
1298 _ => {
1299 panic!()
1300 }
1301 }
1302 }
1303 }
1304
1305 impl Clone for Kernel {
1306 fn clone(&self) -> Self {
1307 Self {
1308 base: CLObjectBase::new(RusticlTypes::Kernel),
1309 prog: self.prog.clone(),
1310 name: self.name.clone(),
1311 values: Mutex::new(self.arg_values().clone()),
1312 builds: self.builds.clone(),
1313 kernel_info: self.kernel_info.clone(),
1314 }
1315 }
1316 }
1317