• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use crate::compiler::nir::*;
2 use crate::pipe::screen::*;
3 use crate::util::disk_cache::*;
4 
5 use libc_rust_gen::malloc;
6 use mesa_rust_gen::*;
7 use mesa_rust_util::serialize::*;
8 use mesa_rust_util::string::*;
9 
10 use std::ffi::CStr;
11 use std::ffi::CString;
12 use std::fmt::Debug;
13 use std::ops::Not;
14 use std::os::raw::c_char;
15 use std::os::raw::c_void;
16 use std::ptr;
17 use std::slice;
18 
/// File name under which the main OpenCL C source is handed to clc.
const INPUT_STR: &CStr = c"input.cl";

/// Specialization-constant selector; currently only the `None` variant exists.
pub enum SpecConstant {
    None,
}
24 
25 pub struct SPIRVBin {
26     spirv: clc_binary,
27     info: Option<clc_parsed_spirv>,
28 }
29 
30 // Safety: SPIRVBin is not mutable and is therefore Send and Sync, needed due to `clc_binary::data`
31 unsafe impl Send for SPIRVBin {}
32 unsafe impl Sync for SPIRVBin {}
33 
34 #[derive(PartialEq, Eq, Hash, Clone)]
35 pub struct SPIRVKernelArg {
36     pub name: CString,
37     pub type_name: CString,
38     pub access_qualifier: clc_kernel_arg_access_qualifier,
39     pub address_qualifier: clc_kernel_arg_address_qualifier,
40     pub type_qualifier: clc_kernel_arg_type_qualifier,
41 }
42 
/// An in-memory header made available to the OpenCL C compiler under `name`.
pub struct CLCHeader<'a> {
    /// Include name the source refers to.
    pub name: CString,
    /// Header contents, borrowed from the caller.
    pub source: &'a CString,
}

impl Debug for CLCHeader<'_> {
    /// Renders `[name]:` followed by the header source on the next line.
    ///
    /// Invalid UTF-8 in either string is replaced lossily.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "[{}]:\n{}",
            self.name.to_string_lossy(),
            self.source.to_string_lossy()
        )
    }
}
56 
callback_impl(data: *mut c_void, msg: *const c_char)57 unsafe fn callback_impl(data: *mut c_void, msg: *const c_char) {
58     let data = data as *mut Vec<String>;
59     let msgs = unsafe { data.as_mut() }.unwrap();
60     msgs.push(c_string_to_string(msg));
61 }
62 
spirv_msg_callback(data: *mut c_void, msg: *const c_char)63 unsafe extern "C" fn spirv_msg_callback(data: *mut c_void, msg: *const c_char) {
64     unsafe {
65         callback_impl(data, msg);
66     }
67 }
68 
spirv_to_nir_msg_callback( data: *mut c_void, dbg_level: nir_spirv_debug_level, _offset: usize, msg: *const c_char, )69 unsafe extern "C" fn spirv_to_nir_msg_callback(
70     data: *mut c_void,
71     dbg_level: nir_spirv_debug_level,
72     _offset: usize,
73     msg: *const c_char,
74 ) {
75     if dbg_level >= nir_spirv_debug_level::NIR_SPIRV_DEBUG_LEVEL_WARNING {
76         unsafe {
77             callback_impl(data, msg);
78         }
79     }
80 }
81 
create_clc_logger(msgs: &mut Vec<String>) -> clc_logger82 fn create_clc_logger(msgs: &mut Vec<String>) -> clc_logger {
83     clc_logger {
84         priv_: ptr::from_mut(msgs).cast(),
85         error: Some(spirv_msg_callback),
86         warning: Some(spirv_msg_callback),
87     }
88 }
89 
90 impl SPIRVBin {
from_clc( source: &CString, args: &[CString], headers: &[CLCHeader], cache: &Option<DiskCache>, features: clc_optional_features, spirv_extensions: &[&CStr], address_bits: u32, ) -> (Option<Self>, String)91     pub fn from_clc(
92         source: &CString,
93         args: &[CString],
94         headers: &[CLCHeader],
95         cache: &Option<DiskCache>,
96         features: clc_optional_features,
97         spirv_extensions: &[&CStr],
98         address_bits: u32,
99     ) -> (Option<Self>, String) {
100         let mut hash_key = None;
101         let has_includes = args.iter().any(|a| a.as_bytes()[0..2] == *b"-I");
102 
103         let mut spirv_extensions: Vec<_> = spirv_extensions.iter().map(|s| s.as_ptr()).collect();
104         spirv_extensions.push(ptr::null());
105 
106         if let Some(cache) = cache {
107             if !has_includes {
108                 let mut key = Vec::new();
109 
110                 key.extend_from_slice(source.as_bytes());
111                 args.iter()
112                     .for_each(|a| key.extend_from_slice(a.as_bytes()));
113                 headers.iter().for_each(|h| {
114                     key.extend_from_slice(h.name.as_bytes());
115                     key.extend_from_slice(h.source.as_bytes());
116                 });
117 
118                 // Safety: clc_optional_features is a struct of bools and contains no padding.
119                 // Sadly we can't guarentee this.
120                 key.extend(unsafe { as_byte_slice(slice::from_ref(&features)) });
121 
122                 let mut key = cache.gen_key(&key);
123                 if let Some(data) = cache.get(&mut key) {
124                     return (Some(Self::from_bin(&data)), String::from(""));
125                 }
126 
127                 hash_key = Some(key);
128             }
129         }
130 
131         let c_headers: Vec<_> = headers
132             .iter()
133             .map(|h| clc_named_value {
134                 name: h.name.as_ptr(),
135                 value: h.source.as_ptr(),
136             })
137             .collect();
138 
139         let c_args: Vec<_> = args.iter().map(|a| a.as_ptr()).collect();
140 
141         let args = clc_compile_args {
142             headers: c_headers.as_ptr(),
143             num_headers: c_headers.len() as u32,
144             source: clc_named_value {
145                 name: INPUT_STR.as_ptr(),
146                 value: source.as_ptr(),
147             },
148             args: c_args.as_ptr(),
149             num_args: c_args.len() as u32,
150             spirv_version: clc_spirv_version::CLC_SPIRV_VERSION_MAX,
151             features: features,
152             use_llvm_spirv_target: false,
153             allowed_spirv_extensions: spirv_extensions.as_ptr(),
154             address_bits: address_bits,
155         };
156         let mut msgs: Vec<String> = Vec::new();
157         let logger = create_clc_logger(&mut msgs);
158         let mut out = clc_binary::default();
159 
160         let res = unsafe { clc_compile_c_to_spirv(&args, &logger, &mut out, ptr::null_mut()) };
161 
162         let res = if res {
163             let spirv = SPIRVBin {
164                 spirv: out,
165                 info: None,
166             };
167 
168             // add cache entry
169             if !has_includes {
170                 if let Some(mut key) = hash_key {
171                     cache.as_ref().unwrap().put(spirv.to_bin(), &mut key);
172                 }
173             }
174 
175             Some(spirv)
176         } else {
177             None
178         };
179 
180         (res, msgs.join(""))
181     }
182 
183     // TODO cache linking, parsing is around 25% of link time
link(spirvs: &[&SPIRVBin], library: bool) -> (Option<Self>, String)184     pub fn link(spirvs: &[&SPIRVBin], library: bool) -> (Option<Self>, String) {
185         let bins: Vec<_> = spirvs.iter().map(|s| ptr::from_ref(&s.spirv)).collect();
186 
187         let linker_args = clc_linker_args {
188             in_objs: bins.as_ptr(),
189             num_in_objs: bins.len() as u32,
190             create_library: library as u32,
191         };
192 
193         let mut msgs: Vec<String> = Vec::new();
194         let logger = create_clc_logger(&mut msgs);
195 
196         let mut out = clc_binary::default();
197         let res = unsafe { clc_link_spirv(&linker_args, &logger, &mut out) };
198 
199         let info = if !library && res {
200             let mut pspirv = clc_parsed_spirv::default();
201             let res = unsafe { clc_parse_spirv(&out, &logger, &mut pspirv) };
202             res.then_some(pspirv)
203         } else {
204             None
205         };
206 
207         let res = res.then_some(SPIRVBin {
208             spirv: out,
209             info: info,
210         });
211         (res, msgs.join(""))
212     }
213 
validate(&self, options: &clc_validator_options) -> (bool, String)214     pub fn validate(&self, options: &clc_validator_options) -> (bool, String) {
215         let mut msgs: Vec<String> = Vec::new();
216         let logger = create_clc_logger(&mut msgs);
217         let res = unsafe { clc_validate_spirv(&self.spirv, &logger, options) };
218 
219         (res, msgs.join(""))
220     }
221 
clone_on_validate(&self, options: &clc_validator_options) -> (Option<Self>, String)222     pub fn clone_on_validate(&self, options: &clc_validator_options) -> (Option<Self>, String) {
223         let (res, msgs) = self.validate(options);
224         (res.then(|| self.clone()), msgs)
225     }
226 
kernel_infos(&self) -> &[clc_kernel_info]227     fn kernel_infos(&self) -> &[clc_kernel_info] {
228         match self.info {
229             Some(info) if info.num_kernels > 0 => unsafe {
230                 slice::from_raw_parts(info.kernels, info.num_kernels as usize)
231             },
232             _ => &[],
233         }
234     }
235 
kernel_info(&self, name: &str) -> Option<&clc_kernel_info>236     pub fn kernel_info(&self, name: &str) -> Option<&clc_kernel_info> {
237         self.kernel_infos()
238             .iter()
239             .find(|i| c_string_to_string(i.name) == name)
240     }
241 
kernels(&self) -> Vec<String>242     pub fn kernels(&self) -> Vec<String> {
243         self.kernel_infos()
244             .iter()
245             .map(|i| i.name)
246             .map(c_string_to_string)
247             .collect()
248     }
249 
args(&self, name: &str) -> Vec<SPIRVKernelArg>250     pub fn args(&self, name: &str) -> Vec<SPIRVKernelArg> {
251         match self.kernel_info(name) {
252             Some(info) if info.num_args > 0 => {
253                 unsafe { slice::from_raw_parts(info.args, info.num_args) }
254                     .iter()
255                     .map(|a| SPIRVKernelArg {
256                         // SAFETY: we have a valid C string pointer here
257                         name: a
258                             .name
259                             .is_null()
260                             .not()
261                             .then(|| unsafe { CStr::from_ptr(a.name) }.to_owned())
262                             .unwrap_or_default(),
263                         type_name: a
264                             .type_name
265                             .is_null()
266                             .not()
267                             .then(|| unsafe { CStr::from_ptr(a.type_name) }.to_owned())
268                             .unwrap_or_default(),
269                         access_qualifier: clc_kernel_arg_access_qualifier(a.access_qualifier),
270                         address_qualifier: a.address_qualifier,
271                         type_qualifier: clc_kernel_arg_type_qualifier(a.type_qualifier),
272                     })
273                     .collect()
274             }
275             _ => Vec::new(),
276         }
277     }
278 
get_spirv_capabilities() -> spirv_capabilities279     fn get_spirv_capabilities() -> spirv_capabilities {
280         spirv_capabilities {
281             Addresses: true,
282             Float16: true,
283             Float16Buffer: true,
284             Float64: true,
285             GenericPointer: true,
286             Groups: true,
287             GroupNonUniformShuffle: true,
288             GroupNonUniformShuffleRelative: true,
289             Int8: true,
290             Int16: true,
291             Int64: true,
292             Kernel: true,
293             ImageBasic: true,
294             ImageReadWrite: true,
295             Linkage: true,
296             LiteralSampler: true,
297             SampledBuffer: true,
298             Sampled1D: true,
299             Vector16: true,
300             ..Default::default()
301         }
302     }
303 
get_spirv_options( library: bool, clc_shader: *const nir_shader, address_bits: u32, caps: &spirv_capabilities, log: Option<&mut Vec<String>>, ) -> spirv_to_nir_options304     fn get_spirv_options(
305         library: bool,
306         clc_shader: *const nir_shader,
307         address_bits: u32,
308         caps: &spirv_capabilities,
309         log: Option<&mut Vec<String>>,
310     ) -> spirv_to_nir_options {
311         let global_addr_format;
312         let offset_addr_format;
313 
314         if address_bits == 32 {
315             global_addr_format = nir_address_format::nir_address_format_32bit_global;
316             offset_addr_format = nir_address_format::nir_address_format_32bit_offset;
317         } else {
318             global_addr_format = nir_address_format::nir_address_format_64bit_global;
319             offset_addr_format = nir_address_format::nir_address_format_32bit_offset_as_64bit;
320         }
321 
322         let debug = log.map(|log| spirv_to_nir_options__bindgen_ty_1 {
323             func: Some(spirv_to_nir_msg_callback),
324             private_data: ptr::from_mut(log).cast(),
325         });
326 
327         spirv_to_nir_options {
328             create_library: library,
329             environment: nir_spirv_execution_environment::NIR_SPIRV_OPENCL,
330             clc_shader: clc_shader,
331             float_controls_execution_mode: float_controls::FLOAT_CONTROLS_DENORM_FLUSH_TO_ZERO_FP32
332                 as u32,
333 
334             printf: true,
335             capabilities: caps,
336             constant_addr_format: global_addr_format,
337             global_addr_format: global_addr_format,
338             shared_addr_format: offset_addr_format,
339             temp_addr_format: offset_addr_format,
340             debug: debug.unwrap_or_default(),
341 
342             ..Default::default()
343         }
344     }
345 
to_nir( &self, entry_point: &str, nir_options: *const nir_shader_compiler_options, libclc: &NirShader, spec_constants: &mut [nir_spirv_specialization], address_bits: u32, log: Option<&mut Vec<String>>, ) -> Option<NirShader>346     pub fn to_nir(
347         &self,
348         entry_point: &str,
349         nir_options: *const nir_shader_compiler_options,
350         libclc: &NirShader,
351         spec_constants: &mut [nir_spirv_specialization],
352         address_bits: u32,
353         log: Option<&mut Vec<String>>,
354     ) -> Option<NirShader> {
355         let c_entry = CString::new(entry_point.as_bytes()).unwrap();
356         let spirv_caps = Self::get_spirv_capabilities();
357         let spirv_options =
358             Self::get_spirv_options(false, libclc.get_nir(), address_bits, &spirv_caps, log);
359 
360         let nir = unsafe {
361             spirv_to_nir(
362                 self.spirv.data.cast(),
363                 self.spirv.size / 4,
364                 spec_constants.as_mut_ptr(),
365                 spec_constants.len() as u32,
366                 gl_shader_stage::MESA_SHADER_KERNEL,
367                 c_entry.as_ptr(),
368                 &spirv_options,
369                 nir_options,
370             )
371         };
372 
373         NirShader::new(nir)
374     }
375 
get_lib_clc(screen: &PipeScreen) -> Option<NirShader>376     pub fn get_lib_clc(screen: &PipeScreen) -> Option<NirShader> {
377         let nir_options = screen.nir_shader_compiler_options(pipe_shader_type::PIPE_SHADER_COMPUTE);
378         let address_bits = screen.compute_param(pipe_compute_cap::PIPE_COMPUTE_CAP_ADDRESS_BITS);
379         let spirv_caps = Self::get_spirv_capabilities();
380         let spirv_options =
381             Self::get_spirv_options(false, ptr::null(), address_bits, &spirv_caps, None);
382         let shader_cache = DiskCacheBorrowed::as_ptr(&screen.shader_cache());
383 
384         NirShader::new(unsafe {
385             nir_load_libclc_shader(
386                 address_bits,
387                 shader_cache,
388                 &spirv_options,
389                 nir_options,
390                 true,
391             )
392         })
393     }
394 
to_bin(&self) -> &[u8]395     pub fn to_bin(&self) -> &[u8] {
396         unsafe { slice::from_raw_parts(self.spirv.data.cast(), self.spirv.size) }
397     }
398 
from_bin(bin: &[u8]) -> Self399     pub fn from_bin(bin: &[u8]) -> Self {
400         unsafe {
401             let ptr = malloc(bin.len());
402             ptr::copy_nonoverlapping(bin.as_ptr(), ptr.cast(), bin.len());
403             let spirv = clc_binary {
404                 data: ptr,
405                 size: bin.len(),
406             };
407 
408             let mut pspirv = clc_parsed_spirv::default();
409 
410             let info = if clc_parse_spirv(&spirv, ptr::null(), &mut pspirv) {
411                 Some(pspirv)
412             } else {
413                 None
414             };
415 
416             SPIRVBin {
417                 spirv: spirv,
418                 info: info,
419             }
420         }
421     }
422 
spec_constant(&self, spec_id: u32) -> Option<clc_spec_constant_type>423     pub fn spec_constant(&self, spec_id: u32) -> Option<clc_spec_constant_type> {
424         let info = self.info?;
425         if info.num_spec_constants == 0 {
426             return None;
427         }
428 
429         let spec_constants =
430             unsafe { slice::from_raw_parts(info.spec_constants, info.num_spec_constants as usize) };
431 
432         spec_constants
433             .iter()
434             .find(|sc| sc.id == spec_id)
435             .map(|sc| sc.type_)
436     }
437 
print(&self)438     pub fn print(&self) {
439         unsafe {
440             clc_dump_spirv(&self.spirv, stderr_ptr());
441         }
442     }
443 }
444 
445 impl Clone for SPIRVBin {
clone(&self) -> Self446     fn clone(&self) -> Self {
447         Self::from_bin(self.to_bin())
448     }
449 }
450 
451 impl Drop for SPIRVBin {
drop(&mut self)452     fn drop(&mut self) {
453         unsafe {
454             clc_free_spirv(&mut self.spirv);
455             if let Some(info) = &mut self.info {
456                 clc_free_parsed_spirv(info);
457             }
458         }
459     }
460 }
461 
462 impl SPIRVKernelArg {
serialize(&self, blob: &mut blob)463     pub fn serialize(&self, blob: &mut blob) {
464         unsafe {
465             blob_write_uint32(blob, self.access_qualifier.0);
466             blob_write_uint32(blob, self.type_qualifier.0);
467 
468             blob_write_string(blob, self.name.as_ptr());
469             blob_write_string(blob, self.type_name.as_ptr());
470 
471             blob_write_uint8(blob, self.address_qualifier as u8);
472         }
473     }
474 
deserialize(blob: &mut blob_reader) -> Option<Self>475     pub fn deserialize(blob: &mut blob_reader) -> Option<Self> {
476         unsafe {
477             let access_qualifier = blob_read_uint32(blob);
478             let type_qualifier = blob_read_uint32(blob);
479 
480             let name = blob_read_string(blob);
481             let type_name = blob_read_string(blob);
482 
483             let address_qualifier = match blob_read_uint8(blob) {
484                 0 => clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_PRIVATE,
485                 1 => clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_CONSTANT,
486                 2 => clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_LOCAL,
487                 3 => clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_GLOBAL,
488                 _ => return None,
489             };
490 
491             // check overrun to ensure nothing went wrong
492             blob.overrun.not().then(|| Self {
493                 // SAFETY: blob_read_string checks for a valid nul character already and sets the
494                 //         blob to overrun state if none was found.
495                 name: CStr::from_ptr(name).to_owned(),
496                 type_name: CStr::from_ptr(type_name).to_owned(),
497                 access_qualifier: clc_kernel_arg_access_qualifier(access_qualifier),
498                 address_qualifier: address_qualifier,
499                 type_qualifier: clc_kernel_arg_type_qualifier(type_qualifier),
500             })
501         }
502     }
503 }
504 
505 pub trait CLCSpecConstantType {
size(self) -> u8506     fn size(self) -> u8;
507 }
508 
509 impl CLCSpecConstantType for clc_spec_constant_type {
size(self) -> u8510     fn size(self) -> u8 {
511         match self {
512             Self::CLC_SPEC_CONSTANT_INT64
513             | Self::CLC_SPEC_CONSTANT_UINT64
514             | Self::CLC_SPEC_CONSTANT_DOUBLE => 8,
515             Self::CLC_SPEC_CONSTANT_INT32
516             | Self::CLC_SPEC_CONSTANT_UINT32
517             | Self::CLC_SPEC_CONSTANT_FLOAT => 4,
518             Self::CLC_SPEC_CONSTANT_INT16 | Self::CLC_SPEC_CONSTANT_UINT16 => 2,
519             Self::CLC_SPEC_CONSTANT_INT8
520             | Self::CLC_SPEC_CONSTANT_UINT8
521             | Self::CLC_SPEC_CONSTANT_BOOL => 1,
522             Self::CLC_SPEC_CONSTANT_UNKNOWN => 0,
523         }
524     }
525 }
526 
/// Kernel-attribute strings derived from parsed kernel metadata.
pub trait SpirvKernelInfo {
    /// Formatted `vec_type_hint` attribute, if any.
    fn vec_type_hint(&self) -> Option<String>;
    /// Formatted required work-group size attribute, if any.
    fn local_size(&self) -> Option<String>;
    /// Formatted work-group size hint attribute, if any.
    fn local_size_hint(&self) -> Option<String>;

    /// Joins every present attribute with `,`.
    ///
    /// Order is `vec_type_hint`, `local_size`, `local_size_hint`. Returns an empty string
    /// when none is present.
    fn attribute_str(&self) -> String {
        let mut present: Vec<String> = Vec::with_capacity(3);
        present.extend(self.vec_type_hint());
        present.extend(self.local_size());
        present.extend(self.local_size_hint());
        present.join(",")
    }
}
543 
544 impl SpirvKernelInfo for clc_kernel_info {
vec_type_hint(&self) -> Option<String>545     fn vec_type_hint(&self) -> Option<String> {
546         if ![1, 2, 3, 4, 8, 16].contains(&self.vec_hint_size) {
547             return None;
548         }
549         let cltype = match self.vec_hint_type {
550             clc_vec_hint_type::CLC_VEC_HINT_TYPE_CHAR => "uchar",
551             clc_vec_hint_type::CLC_VEC_HINT_TYPE_SHORT => "ushort",
552             clc_vec_hint_type::CLC_VEC_HINT_TYPE_INT => "uint",
553             clc_vec_hint_type::CLC_VEC_HINT_TYPE_LONG => "ulong",
554             clc_vec_hint_type::CLC_VEC_HINT_TYPE_HALF => "half",
555             clc_vec_hint_type::CLC_VEC_HINT_TYPE_FLOAT => "float",
556             clc_vec_hint_type::CLC_VEC_HINT_TYPE_DOUBLE => "double",
557         };
558 
559         Some(format!("vec_type_hint({}{})", cltype, self.vec_hint_size))
560     }
561 
local_size(&self) -> Option<String>562     fn local_size(&self) -> Option<String> {
563         if self.local_size == [0; 3] {
564             return None;
565         }
566         Some(format!(
567             "reqd_work_group_size({},{},{})",
568             self.local_size[0], self.local_size[1], self.local_size[2]
569         ))
570     }
571 
local_size_hint(&self) -> Option<String>572     fn local_size_hint(&self) -> Option<String> {
573         if self.local_size_hint == [0; 3] {
574             return None;
575         }
576         Some(format!(
577             "work_group_size_hint({},{},{})",
578             self.local_size_hint[0], self.local_size_hint[1], self.local_size_hint[2]
579         ))
580     }
581 }
582