• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use crate::compiler::nir::*;
2 use crate::pipe::screen::*;
3 use crate::util::disk_cache::*;
4 
5 use libc_rust_gen::malloc;
6 use mesa_rust_gen::*;
7 use mesa_rust_util::serialize::*;
8 use mesa_rust_util::string::*;
9 
10 use std::ffi::CString;
11 use std::fmt::Debug;
12 use std::os::raw::c_char;
13 use std::os::raw::c_void;
14 use std::ptr;
15 use std::slice;
16 
/// NUL-terminated file name reported to the compiler for the main OpenCL C source.
const INPUT_STR: *const c_char = b"input.cl\0".as_ptr().cast();
18 
/// Specialization-constant selector. Only the `None` variant is defined.
pub enum SpecConstant {
    /// No specialization constant.
    None,
}
22 
23 pub struct SPIRVBin {
24     spirv: clc_binary,
25     info: Option<clc_parsed_spirv>,
26 }
27 
28 // Safety: SPIRVBin is not mutable and is therefore Send and Sync, needed due to `clc_binary::data`
29 unsafe impl Send for SPIRVBin {}
30 unsafe impl Sync for SPIRVBin {}
31 
32 #[derive(PartialEq, Eq, Hash, Clone)]
33 pub struct SPIRVKernelArg {
34     pub name: String,
35     pub type_name: String,
36     pub access_qualifier: clc_kernel_arg_access_qualifier,
37     pub address_qualifier: clc_kernel_arg_address_qualifier,
38     pub type_qualifier: clc_kernel_arg_type_qualifier,
39 }
40 
/// A named, in-memory header made available to the OpenCL C compiler.
pub struct CLCHeader<'a> {
    /// Name under which the header can be included.
    pub name: CString,
    /// Header contents, borrowed from the caller.
    pub source: &'a CString,
}
45 
46 impl<'a> Debug for CLCHeader<'a> {
fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result47     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48         let name = self.name.to_string_lossy();
49         let source = self.source.to_string_lossy();
50 
51         f.write_fmt(format_args!("[{name}]:\n{source}"))
52     }
53 }
54 
callback_impl(data: *mut c_void, msg: *const c_char)55 unsafe fn callback_impl(data: *mut c_void, msg: *const c_char) {
56     let data = data as *mut Vec<String>;
57     let msgs = unsafe { data.as_mut() }.unwrap();
58     msgs.push(c_string_to_string(msg));
59 }
60 
spirv_msg_callback(data: *mut c_void, msg: *const c_char)61 unsafe extern "C" fn spirv_msg_callback(data: *mut c_void, msg: *const c_char) {
62     unsafe {
63         callback_impl(data, msg);
64     }
65 }
66 
spirv_to_nir_msg_callback( data: *mut c_void, dbg_level: nir_spirv_debug_level, _offset: usize, msg: *const c_char, )67 unsafe extern "C" fn spirv_to_nir_msg_callback(
68     data: *mut c_void,
69     dbg_level: nir_spirv_debug_level,
70     _offset: usize,
71     msg: *const c_char,
72 ) {
73     if dbg_level >= nir_spirv_debug_level::NIR_SPIRV_DEBUG_LEVEL_WARNING {
74         unsafe {
75             callback_impl(data, msg);
76         }
77     }
78 }
79 
create_clc_logger(msgs: &mut Vec<String>) -> clc_logger80 fn create_clc_logger(msgs: &mut Vec<String>) -> clc_logger {
81     clc_logger {
82         priv_: msgs as *mut Vec<String> as *mut c_void,
83         error: Some(spirv_msg_callback),
84         warning: Some(spirv_msg_callback),
85     }
86 }
87 
88 impl SPIRVBin {
from_clc( source: &CString, args: &[CString], headers: &[CLCHeader], cache: &Option<DiskCache>, features: clc_optional_features, spirv_extensions: &[CString], address_bits: u32, ) -> (Option<Self>, String)89     pub fn from_clc(
90         source: &CString,
91         args: &[CString],
92         headers: &[CLCHeader],
93         cache: &Option<DiskCache>,
94         features: clc_optional_features,
95         spirv_extensions: &[CString],
96         address_bits: u32,
97     ) -> (Option<Self>, String) {
98         let mut hash_key = None;
99         let has_includes = args.iter().any(|a| a.as_bytes()[0..2] == *b"-I");
100 
101         let mut spirv_extensions: Vec<_> = spirv_extensions.iter().map(|s| s.as_ptr()).collect();
102         spirv_extensions.push(ptr::null());
103 
104         if let Some(cache) = cache {
105             if !has_includes {
106                 let mut key = Vec::new();
107 
108                 key.extend_from_slice(source.as_bytes());
109                 args.iter()
110                     .for_each(|a| key.extend_from_slice(a.as_bytes()));
111                 headers.iter().for_each(|h| {
112                     key.extend_from_slice(h.name.as_bytes());
113                     key.extend_from_slice(h.source.as_bytes());
114                 });
115 
116                 // Safety: clc_optional_features is a struct of bools and contains no padding.
117                 // Sadly we can't guarentee this.
118                 key.extend(unsafe { as_byte_slice(slice::from_ref(&features)) });
119 
120                 let mut key = cache.gen_key(&key);
121                 if let Some(data) = cache.get(&mut key) {
122                     return (Some(Self::from_bin(&data)), String::from(""));
123                 }
124 
125                 hash_key = Some(key);
126             }
127         }
128 
129         let c_headers: Vec<_> = headers
130             .iter()
131             .map(|h| clc_named_value {
132                 name: h.name.as_ptr(),
133                 value: h.source.as_ptr(),
134             })
135             .collect();
136 
137         let c_args: Vec<_> = args.iter().map(|a| a.as_ptr()).collect();
138 
139         let args = clc_compile_args {
140             headers: c_headers.as_ptr(),
141             num_headers: c_headers.len() as u32,
142             source: clc_named_value {
143                 name: INPUT_STR,
144                 value: source.as_ptr(),
145             },
146             args: c_args.as_ptr(),
147             num_args: c_args.len() as u32,
148             spirv_version: clc_spirv_version::CLC_SPIRV_VERSION_MAX,
149             features: features,
150             use_llvm_spirv_target: false,
151             allowed_spirv_extensions: spirv_extensions.as_ptr(),
152             address_bits: address_bits,
153         };
154         let mut msgs: Vec<String> = Vec::new();
155         let logger = create_clc_logger(&mut msgs);
156         let mut out = clc_binary::default();
157 
158         let res = unsafe { clc_compile_c_to_spirv(&args, &logger, &mut out) };
159 
160         let res = if res {
161             let spirv = SPIRVBin {
162                 spirv: out,
163                 info: None,
164             };
165 
166             // add cache entry
167             if !has_includes {
168                 if let Some(mut key) = hash_key {
169                     cache.as_ref().unwrap().put(spirv.to_bin(), &mut key);
170                 }
171             }
172 
173             Some(spirv)
174         } else {
175             None
176         };
177 
178         (res, msgs.join("\n"))
179     }
180 
181     // TODO cache linking, parsing is around 25% of link time
link(spirvs: &[&SPIRVBin], library: bool) -> (Option<Self>, String)182     pub fn link(spirvs: &[&SPIRVBin], library: bool) -> (Option<Self>, String) {
183         let bins: Vec<_> = spirvs.iter().map(|s| &s.spirv as *const _).collect();
184 
185         let linker_args = clc_linker_args {
186             in_objs: bins.as_ptr(),
187             num_in_objs: bins.len() as u32,
188             create_library: library as u32,
189         };
190 
191         let mut msgs: Vec<String> = Vec::new();
192         let logger = create_clc_logger(&mut msgs);
193 
194         let mut out = clc_binary::default();
195         let res = unsafe { clc_link_spirv(&linker_args, &logger, &mut out) };
196 
197         let info = if !library && res {
198             let mut pspirv = clc_parsed_spirv::default();
199             let res = unsafe { clc_parse_spirv(&out, &logger, &mut pspirv) };
200             res.then_some(pspirv)
201         } else {
202             None
203         };
204 
205         let res = res.then_some(SPIRVBin {
206             spirv: out,
207             info: info,
208         });
209         (res, msgs.join("\n"))
210     }
211 
validate(&self, options: &clc_validator_options) -> (bool, String)212     pub fn validate(&self, options: &clc_validator_options) -> (bool, String) {
213         let mut msgs: Vec<String> = Vec::new();
214         let logger = create_clc_logger(&mut msgs);
215         let res = unsafe { clc_validate_spirv(&self.spirv, &logger, options) };
216 
217         (res, msgs.join("\n"))
218     }
219 
clone_on_validate(&self, options: &clc_validator_options) -> (Option<Self>, String)220     pub fn clone_on_validate(&self, options: &clc_validator_options) -> (Option<Self>, String) {
221         let (res, msgs) = self.validate(options);
222         (res.then(|| self.clone()), msgs)
223     }
224 
kernel_infos(&self) -> &[clc_kernel_info]225     fn kernel_infos(&self) -> &[clc_kernel_info] {
226         match self.info {
227             None => &[],
228             Some(info) => unsafe { slice::from_raw_parts(info.kernels, info.num_kernels as usize) },
229         }
230     }
231 
kernel_info(&self, name: &str) -> Option<&clc_kernel_info>232     fn kernel_info(&self, name: &str) -> Option<&clc_kernel_info> {
233         self.kernel_infos()
234             .iter()
235             .find(|i| c_string_to_string(i.name) == name)
236     }
237 
kernels(&self) -> Vec<String>238     pub fn kernels(&self) -> Vec<String> {
239         self.kernel_infos()
240             .iter()
241             .map(|i| i.name)
242             .map(c_string_to_string)
243             .collect()
244     }
245 
vec_type_hint(&self, name: &str) -> Option<String>246     pub fn vec_type_hint(&self, name: &str) -> Option<String> {
247         self.kernel_info(name)
248             .filter(|info| [1, 2, 3, 4, 8, 16].contains(&info.vec_hint_size))
249             .map(|info| {
250                 let cltype = match info.vec_hint_type {
251                     clc_vec_hint_type::CLC_VEC_HINT_TYPE_CHAR => "uchar",
252                     clc_vec_hint_type::CLC_VEC_HINT_TYPE_SHORT => "ushort",
253                     clc_vec_hint_type::CLC_VEC_HINT_TYPE_INT => "uint",
254                     clc_vec_hint_type::CLC_VEC_HINT_TYPE_LONG => "ulong",
255                     clc_vec_hint_type::CLC_VEC_HINT_TYPE_HALF => "half",
256                     clc_vec_hint_type::CLC_VEC_HINT_TYPE_FLOAT => "float",
257                     clc_vec_hint_type::CLC_VEC_HINT_TYPE_DOUBLE => "double",
258                 };
259 
260                 format!("vec_type_hint({}{})", cltype, info.vec_hint_size)
261             })
262     }
263 
local_size(&self, name: &str) -> Option<String>264     pub fn local_size(&self, name: &str) -> Option<String> {
265         self.kernel_info(name)
266             .filter(|info| info.local_size != [0; 3])
267             .map(|info| {
268                 format!(
269                     "reqd_work_group_size({},{},{})",
270                     info.local_size[0], info.local_size[1], info.local_size[2]
271                 )
272             })
273     }
274 
local_size_hint(&self, name: &str) -> Option<String>275     pub fn local_size_hint(&self, name: &str) -> Option<String> {
276         self.kernel_info(name)
277             .filter(|info| info.local_size_hint != [0; 3])
278             .map(|info| {
279                 format!(
280                     "work_group_size_hint({},{},{})",
281                     info.local_size_hint[0], info.local_size_hint[1], info.local_size_hint[2]
282                 )
283             })
284     }
285 
args(&self, name: &str) -> Vec<SPIRVKernelArg>286     pub fn args(&self, name: &str) -> Vec<SPIRVKernelArg> {
287         match self.kernel_info(name) {
288             None => Vec::new(),
289             Some(info) => unsafe { slice::from_raw_parts(info.args, info.num_args) }
290                 .iter()
291                 .map(|a| SPIRVKernelArg {
292                     name: c_string_to_string(a.name),
293                     type_name: c_string_to_string(a.type_name),
294                     access_qualifier: clc_kernel_arg_access_qualifier(a.access_qualifier),
295                     address_qualifier: a.address_qualifier,
296                     type_qualifier: clc_kernel_arg_type_qualifier(a.type_qualifier),
297                 })
298                 .collect(),
299         }
300     }
301 
get_spirv_options( library: bool, clc_shader: *const nir_shader, address_bits: u32, log: Option<&mut Vec<String>>, ) -> spirv_to_nir_options302     fn get_spirv_options(
303         library: bool,
304         clc_shader: *const nir_shader,
305         address_bits: u32,
306         log: Option<&mut Vec<String>>,
307     ) -> spirv_to_nir_options {
308         let global_addr_format;
309         let offset_addr_format;
310 
311         if address_bits == 32 {
312             global_addr_format = nir_address_format::nir_address_format_32bit_global;
313             offset_addr_format = nir_address_format::nir_address_format_32bit_offset;
314         } else {
315             global_addr_format = nir_address_format::nir_address_format_64bit_global;
316             offset_addr_format = nir_address_format::nir_address_format_32bit_offset_as_64bit;
317         }
318 
319         let debug = log.map(|log| spirv_to_nir_options__bindgen_ty_1 {
320             func: Some(spirv_to_nir_msg_callback),
321             private_data: (log as *mut Vec<String>).cast(),
322         });
323 
324         spirv_to_nir_options {
325             create_library: library,
326             environment: nir_spirv_execution_environment::NIR_SPIRV_OPENCL,
327             clc_shader: clc_shader,
328             float_controls_execution_mode: float_controls::FLOAT_CONTROLS_DENORM_FLUSH_TO_ZERO_FP32
329                 as u32,
330 
331             caps: spirv_supported_capabilities {
332                 address: true,
333                 float16: true,
334                 float64: true,
335                 generic_pointers: true,
336                 groups: true,
337                 subgroup_shuffle: true,
338                 int8: true,
339                 int16: true,
340                 int64: true,
341                 kernel: true,
342                 kernel_image: true,
343                 kernel_image_read_write: true,
344                 linkage: true,
345                 literal_sampler: true,
346                 printf: true,
347                 ..Default::default()
348             },
349 
350             constant_addr_format: global_addr_format,
351             global_addr_format: global_addr_format,
352             shared_addr_format: offset_addr_format,
353             temp_addr_format: offset_addr_format,
354             debug: debug.unwrap_or_default(),
355 
356             ..Default::default()
357         }
358     }
359 
to_nir( &self, entry_point: &str, nir_options: *const nir_shader_compiler_options, libclc: &NirShader, spec_constants: &mut [nir_spirv_specialization], address_bits: u32, log: Option<&mut Vec<String>>, ) -> Option<NirShader>360     pub fn to_nir(
361         &self,
362         entry_point: &str,
363         nir_options: *const nir_shader_compiler_options,
364         libclc: &NirShader,
365         spec_constants: &mut [nir_spirv_specialization],
366         address_bits: u32,
367         log: Option<&mut Vec<String>>,
368     ) -> Option<NirShader> {
369         let c_entry = CString::new(entry_point.as_bytes()).unwrap();
370         let spirv_options = Self::get_spirv_options(false, libclc.get_nir(), address_bits, log);
371 
372         let nir = unsafe {
373             spirv_to_nir(
374                 self.spirv.data.cast(),
375                 self.spirv.size / 4,
376                 spec_constants.as_mut_ptr(),
377                 spec_constants.len() as u32,
378                 gl_shader_stage::MESA_SHADER_KERNEL,
379                 c_entry.as_ptr(),
380                 &spirv_options,
381                 nir_options,
382             )
383         };
384 
385         NirShader::new(nir)
386     }
387 
get_lib_clc(screen: &PipeScreen) -> Option<NirShader>388     pub fn get_lib_clc(screen: &PipeScreen) -> Option<NirShader> {
389         let nir_options = screen.nir_shader_compiler_options(pipe_shader_type::PIPE_SHADER_COMPUTE);
390         let address_bits = screen.compute_param(pipe_compute_cap::PIPE_COMPUTE_CAP_ADDRESS_BITS);
391         let spirv_options = Self::get_spirv_options(true, ptr::null(), address_bits, None);
392         let shader_cache = DiskCacheBorrowed::as_ptr(&screen.shader_cache());
393 
394         NirShader::new(unsafe {
395             nir_load_libclc_shader(
396                 address_bits,
397                 shader_cache,
398                 &spirv_options,
399                 nir_options,
400                 true,
401             )
402         })
403     }
404 
to_bin(&self) -> &[u8]405     pub fn to_bin(&self) -> &[u8] {
406         unsafe { slice::from_raw_parts(self.spirv.data.cast(), self.spirv.size) }
407     }
408 
from_bin(bin: &[u8]) -> Self409     pub fn from_bin(bin: &[u8]) -> Self {
410         unsafe {
411             let ptr = malloc(bin.len());
412             ptr::copy_nonoverlapping(bin.as_ptr(), ptr.cast(), bin.len());
413             let spirv = clc_binary {
414                 data: ptr,
415                 size: bin.len(),
416             };
417 
418             let mut pspirv = clc_parsed_spirv::default();
419 
420             let info = if clc_parse_spirv(&spirv, ptr::null(), &mut pspirv) {
421                 Some(pspirv)
422             } else {
423                 None
424             };
425 
426             SPIRVBin {
427                 spirv: spirv,
428                 info: info,
429             }
430         }
431     }
432 
spec_constant(&self, spec_id: u32) -> Option<clc_spec_constant_type>433     pub fn spec_constant(&self, spec_id: u32) -> Option<clc_spec_constant_type> {
434         let info = self.info?;
435         let spec_constants =
436             unsafe { slice::from_raw_parts(info.spec_constants, info.num_spec_constants as usize) };
437 
438         spec_constants
439             .iter()
440             .find(|sc| sc.id == spec_id)
441             .map(|sc| sc.type_)
442     }
443 
print(&self)444     pub fn print(&self) {
445         unsafe {
446             clc_dump_spirv(&self.spirv, stderr_ptr());
447         }
448     }
449 }
450 
451 impl Clone for SPIRVBin {
clone(&self) -> Self452     fn clone(&self) -> Self {
453         Self::from_bin(self.to_bin())
454     }
455 }
456 
457 impl Drop for SPIRVBin {
drop(&mut self)458     fn drop(&mut self) {
459         unsafe {
460             clc_free_spirv(&mut self.spirv);
461             if let Some(info) = &mut self.info {
462                 clc_free_parsed_spirv(info);
463             }
464         }
465     }
466 }
467 
468 impl SPIRVKernelArg {
serialize(&self) -> Vec<u8>469     pub fn serialize(&self) -> Vec<u8> {
470         let mut res = Vec::new();
471 
472         let name_arr = self.name.as_bytes();
473         let type_name_arr = self.type_name.as_bytes();
474 
475         res.extend_from_slice(&name_arr.len().to_ne_bytes());
476         res.extend_from_slice(name_arr);
477         res.extend_from_slice(&type_name_arr.len().to_ne_bytes());
478         res.extend_from_slice(type_name_arr);
479         res.extend_from_slice(&u32::to_ne_bytes(self.access_qualifier.0));
480         res.extend_from_slice(&u32::to_ne_bytes(self.type_qualifier.0));
481         res.push(self.address_qualifier as u8);
482 
483         res
484     }
485 
deserialize(bin: &mut &[u8]) -> Option<Self>486     pub fn deserialize(bin: &mut &[u8]) -> Option<Self> {
487         let name_len = read_ne_usize(bin);
488         let name = read_string(bin, name_len)?;
489         let type_len = read_ne_usize(bin);
490         let type_name = read_string(bin, type_len)?;
491         let access_qualifier = read_ne_u32(bin);
492         let type_qualifier = read_ne_u32(bin);
493 
494         let address_qualifier = match read_ne_u8(bin) {
495             0 => clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_PRIVATE,
496             1 => clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_CONSTANT,
497             2 => clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_LOCAL,
498             3 => clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_GLOBAL,
499             _ => return None,
500         };
501 
502         Some(Self {
503             name: name,
504             type_name: type_name,
505             access_qualifier: clc_kernel_arg_access_qualifier(access_qualifier),
506             address_qualifier: address_qualifier,
507             type_qualifier: clc_kernel_arg_type_qualifier(type_qualifier),
508         })
509     }
510 }
511 
/// Size information for specialization-constant types.
pub trait CLCSpecConstantType {
    /// Size in bytes of a value of this type.
    fn size(self) -> u8;
}
515 
516 impl CLCSpecConstantType for clc_spec_constant_type {
size(self) -> u8517     fn size(self) -> u8 {
518         match self {
519             Self::CLC_SPEC_CONSTANT_INT64
520             | Self::CLC_SPEC_CONSTANT_UINT64
521             | Self::CLC_SPEC_CONSTANT_DOUBLE => 8,
522             Self::CLC_SPEC_CONSTANT_INT32
523             | Self::CLC_SPEC_CONSTANT_UINT32
524             | Self::CLC_SPEC_CONSTANT_FLOAT => 4,
525             Self::CLC_SPEC_CONSTANT_INT16 | Self::CLC_SPEC_CONSTANT_UINT16 => 2,
526             Self::CLC_SPEC_CONSTANT_INT8
527             | Self::CLC_SPEC_CONSTANT_UINT8
528             | Self::CLC_SPEC_CONSTANT_BOOL => 1,
529             Self::CLC_SPEC_CONSTANT_UNKNOWN => 0,
530         }
531     }
532 }
533