• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use crate::api::icd::*;
2 use crate::core::context::*;
3 use crate::core::device::*;
4 use crate::core::kernel::*;
5 use crate::core::platform::Platform;
6 use crate::impl_cl_type_trait;
7 
8 use mesa_rust::compiler::clc::spirv::SPIRVBin;
9 use mesa_rust::compiler::clc::*;
10 use mesa_rust::compiler::nir::*;
11 use mesa_rust::pipe::resource::*;
12 use mesa_rust::pipe::screen::ResourceType;
13 use mesa_rust::util::disk_cache::*;
14 use mesa_rust_gen::*;
15 use rusticl_llvm_gen::*;
16 use rusticl_opencl_gen::*;
17 
18 use std::collections::HashMap;
19 use std::collections::HashSet;
20 use std::ffi::CString;
21 use std::mem::size_of;
22 use std::ptr;
23 use std::slice;
24 use std::sync::Arc;
25 use std::sync::Mutex;
26 use std::sync::MutexGuard;
27 use std::sync::Once;
28 
29 const BIN_HEADER_SIZE_V1: usize =
30     // 1. format version
31     size_of::<u32>() +
32     // 2. spirv len
33     size_of::<u32>() +
34     // 3. binary_type
35     size_of::<cl_program_binary_type>();
36 
37 const BIN_HEADER_SIZE: usize = BIN_HEADER_SIZE_V1;
38 
39 // kernel cache
40 static mut DISK_CACHE: Option<DiskCache> = None;
41 static DISK_CACHE_ONCE: Once = Once::new();
42 
get_disk_cache() -> &'static Option<DiskCache>43 fn get_disk_cache() -> &'static Option<DiskCache> {
44     let func_ptrs = [
45         // ourselves
46         get_disk_cache as _,
47         // LLVM
48         llvm_LLVMContext_LLVMContext as _,
49         // clang
50         clang_getClangFullVersion as _,
51         // SPIRV-LLVM-Translator
52         llvm_writeSpirv1 as _,
53     ];
54     unsafe {
55         DISK_CACHE_ONCE.call_once(|| {
56             DISK_CACHE = DiskCache::new("rusticl", &func_ptrs, 0);
57         });
58         &DISK_CACHE
59     }
60 }
61 
clc_validator_options(dev: &Device) -> clc_validator_options62 fn clc_validator_options(dev: &Device) -> clc_validator_options {
63     clc_validator_options {
64         // has to match CL_DEVICE_MAX_PARAMETER_SIZE
65         limit_max_function_arg: dev.param_max_size() as u32,
66     }
67 }
68 
69 pub enum ProgramSourceType {
70     Binary,
71     Linked,
72     Src(CString),
73     Il(spirv::SPIRVBin),
74 }
75 
76 pub struct Program {
77     pub base: CLObjectBase<CL_INVALID_PROGRAM>,
78     pub context: Arc<Context>,
79     pub devs: Vec<&'static Device>,
80     pub src: ProgramSourceType,
81     build: Mutex<ProgramBuild>,
82 }
83 
84 impl_cl_type_trait!(cl_program, Program, CL_INVALID_PROGRAM);
85 
86 pub struct NirKernelBuild {
87     pub nir_or_cso: KernelDevStateVariant,
88     pub constant_buffer: Option<Arc<PipeResource>>,
89     pub info: pipe_compute_state_object_info,
90     pub shared_size: u64,
91     pub printf_info: Option<NirPrintfInfo>,
92 }
93 
94 // SAFETY: `CSOWrapper` is only safe to use if the device supports `PIPE_CAP_SHAREABLE_SHADERS` and
95 //         we make sure to set `nir_or_cso` to `KernelDevStateVariant::Cso` only if that's the case.
96 unsafe impl Send for NirKernelBuild {}
97 unsafe impl Sync for NirKernelBuild {}
98 
99 pub struct ProgramBuild {
100     pub builds: HashMap<&'static Device, ProgramDevBuild>,
101     pub kernel_info: HashMap<String, KernelInfo>,
102     spec_constants: HashMap<u32, nir_const_value>,
103     kernels: Vec<String>,
104 }
105 
106 impl NirKernelBuild {
new(dev: &'static Device, mut nir: NirShader) -> Self107     pub fn new(dev: &'static Device, mut nir: NirShader) -> Self {
108         let cso = CSOWrapper::new(dev, &nir);
109         let info = cso.get_cso_info();
110         let cb = Self::create_nir_constant_buffer(dev, &nir);
111         let shared_size = nir.shared_size() as u64;
112         let printf_info = nir.take_printf_info();
113 
114         let nir_or_cso = if !dev.shareable_shaders() {
115             KernelDevStateVariant::Nir(nir)
116         } else {
117             KernelDevStateVariant::Cso(cso)
118         };
119 
120         NirKernelBuild {
121             nir_or_cso: nir_or_cso,
122             constant_buffer: cb,
123             info: info,
124             shared_size: shared_size,
125             printf_info: printf_info,
126         }
127     }
128 
create_nir_constant_buffer(dev: &Device, nir: &NirShader) -> Option<Arc<PipeResource>>129     fn create_nir_constant_buffer(dev: &Device, nir: &NirShader) -> Option<Arc<PipeResource>> {
130         let buf = nir.get_constant_buffer();
131         let len = buf.len() as u32;
132 
133         if len > 0 {
134             // TODO bind as constant buffer
135             let res = dev
136                 .screen()
137                 .resource_create_buffer(len, ResourceType::Normal, PIPE_BIND_GLOBAL)
138                 .unwrap();
139 
140             dev.helper_ctx()
141                 .exec(|ctx| ctx.buffer_subdata(&res, 0, buf.as_ptr().cast(), len))
142                 .wait();
143 
144             Some(Arc::new(res))
145         } else {
146             None
147         }
148     }
149 }
150 
151 impl ProgramBuild {
attribute_str(&self, kernel: &str, d: &Device) -> String152     pub fn attribute_str(&self, kernel: &str, d: &Device) -> String {
153         let info = self.dev_build(d);
154 
155         let attributes_strings = [
156             info.spirv.as_ref().unwrap().vec_type_hint(kernel),
157             info.spirv.as_ref().unwrap().local_size(kernel),
158             info.spirv.as_ref().unwrap().local_size_hint(kernel),
159         ];
160 
161         let attributes_strings: Vec<_> = attributes_strings
162             .iter()
163             .flatten()
164             .map(String::as_str)
165             .collect();
166         attributes_strings.join(",")
167     }
168 
args(&self, dev: &Device, kernel: &str) -> Vec<spirv::SPIRVKernelArg>169     fn args(&self, dev: &Device, kernel: &str) -> Vec<spirv::SPIRVKernelArg> {
170         self.dev_build(dev).spirv.as_ref().unwrap().args(kernel)
171     }
172 
build_nirs(&mut self, is_src: bool)173     fn build_nirs(&mut self, is_src: bool) {
174         for kernel_name in &self.kernels.clone() {
175             let kernel_args: HashSet<_> = self
176                 .devs_with_build()
177                 .iter()
178                 .map(|d| self.args(d, kernel_name))
179                 .collect();
180 
181             let args = kernel_args.into_iter().next().unwrap();
182             let mut kernel_info_set = HashSet::new();
183 
184             // TODO: we could run this in parallel?
185             for dev in self.devs_with_build() {
186                 let (kernel_info, nir) = convert_spirv_to_nir(self, kernel_name, &args, dev);
187                 kernel_info_set.insert(kernel_info);
188 
189                 self.builds
190                     .get_mut(dev)
191                     .unwrap()
192                     .kernels
193                     .insert(kernel_name.clone(), Arc::new(NirKernelBuild::new(dev, nir)));
194             }
195 
196             // we want the same (internal) args for every compiled kernel, for now
197             assert!(kernel_info_set.len() == 1);
198             let mut kernel_info = kernel_info_set.into_iter().next().unwrap();
199 
200             // spec: For kernels not created from OpenCL C source and the clCreateProgramWithSource
201             // API call the string returned from this query [CL_KERNEL_ATTRIBUTES] will be empty.
202             if !is_src {
203                 kernel_info.attributes_string = String::new();
204             }
205 
206             self.kernel_info.insert(kernel_name.clone(), kernel_info);
207         }
208     }
209 
dev_build(&self, dev: &Device) -> &ProgramDevBuild210     fn dev_build(&self, dev: &Device) -> &ProgramDevBuild {
211         self.builds.get(dev).unwrap()
212     }
213 
dev_build_mut(&mut self, dev: &Device) -> &mut ProgramDevBuild214     fn dev_build_mut(&mut self, dev: &Device) -> &mut ProgramDevBuild {
215         self.builds.get_mut(dev).unwrap()
216     }
217 
devs_with_build(&self) -> Vec<&'static Device>218     fn devs_with_build(&self) -> Vec<&'static Device> {
219         self.builds
220             .iter()
221             .filter(|(_, build)| build.status == CL_BUILD_SUCCESS as cl_build_status)
222             .map(|(&d, _)| d)
223             .collect()
224     }
225 
hash_key(&self, dev: &Device, name: &str) -> Option<cache_key>226     pub fn hash_key(&self, dev: &Device, name: &str) -> Option<cache_key> {
227         if let Some(cache) = dev.screen().shader_cache() {
228             let info = self.dev_build(dev);
229             assert_eq!(info.status, CL_BUILD_SUCCESS as cl_build_status);
230 
231             let spirv = info.spirv.as_ref().unwrap();
232             let mut bin = spirv.to_bin().to_vec();
233             bin.extend_from_slice(name.as_bytes());
234 
235             for (k, v) in &self.spec_constants {
236                 bin.extend_from_slice(&k.to_ne_bytes());
237                 unsafe {
238                     // SAFETY: we fully initialize this union
239                     bin.extend_from_slice(&v.u64_.to_ne_bytes());
240                 }
241             }
242 
243             Some(cache.gen_key(&bin))
244         } else {
245             None
246         }
247     }
248 
to_nir(&self, kernel: &str, d: &Device) -> NirShader249     pub fn to_nir(&self, kernel: &str, d: &Device) -> NirShader {
250         let mut spec_constants: Vec<_> = self
251             .spec_constants
252             .iter()
253             .map(|(&id, &value)| nir_spirv_specialization {
254                 id: id,
255                 value: value,
256                 defined_on_module: true,
257             })
258             .collect();
259 
260         let info = self.dev_build(d);
261         assert_eq!(info.status, CL_BUILD_SUCCESS as cl_build_status);
262 
263         let mut log = Platform::dbg().program.then(Vec::new);
264         let nir = info.spirv.as_ref().unwrap().to_nir(
265             kernel,
266             d.screen
267                 .nir_shader_compiler_options(pipe_shader_type::PIPE_SHADER_COMPUTE),
268             &d.lib_clc,
269             &mut spec_constants,
270             d.address_bits(),
271             log.as_mut(),
272         );
273 
274         if let Some(log) = log {
275             for line in log {
276                 eprintln!("{}", line);
277             }
278         };
279 
280         nir.unwrap()
281     }
282 }
283 
284 pub struct ProgramDevBuild {
285     spirv: Option<spirv::SPIRVBin>,
286     status: cl_build_status,
287     options: String,
288     log: String,
289     bin_type: cl_program_binary_type,
290     pub kernels: HashMap<String, Arc<NirKernelBuild>>,
291 }
292 
prepare_options(options: &str, dev: &Device) -> Vec<CString>293 fn prepare_options(options: &str, dev: &Device) -> Vec<CString> {
294     let mut options = options.to_owned();
295     if !options.contains("-cl-std=CL") {
296         options.push_str(" -cl-std=CL");
297         options.push_str(dev.clc_version.api_str());
298     }
299     options.push_str(" -D__OPENCL_VERSION__=");
300     options.push_str(dev.cl_version.clc_str());
301 
302     let mut res = Vec::new();
303 
304     // we seperate on a ' ' unless we hit a "
305     let mut sep = ' ';
306     let mut old = 0;
307     for (i, c) in options.char_indices() {
308         if c == '"' {
309             if sep == ' ' {
310                 sep = '"';
311             } else {
312                 sep = ' ';
313             }
314         }
315 
316         if c == '"' || c == sep {
317             // beware of double seps
318             if old != i {
319                 res.push(&options[old..i]);
320             }
321             old = i + c.len_utf8();
322         }
323     }
324     // add end of the string
325     res.push(&options[old..]);
326 
327     res.iter()
328         .map(|&a| match a {
329             "-cl-denorms-are-zero" => "-fdenormal-fp-math=positive-zero",
330             _ => a,
331         })
332         .map(CString::new)
333         .map(Result::unwrap)
334         .collect()
335 }
336 
337 impl Program {
create_default_builds( devs: &[&'static Device], ) -> HashMap<&'static Device, ProgramDevBuild>338     fn create_default_builds(
339         devs: &[&'static Device],
340     ) -> HashMap<&'static Device, ProgramDevBuild> {
341         devs.iter()
342             .map(|&d| {
343                 (
344                     d,
345                     ProgramDevBuild {
346                         spirv: None,
347                         status: CL_BUILD_NONE,
348                         log: String::from(""),
349                         options: String::from(""),
350                         bin_type: CL_PROGRAM_BINARY_TYPE_NONE,
351                         kernels: HashMap::new(),
352                     },
353                 )
354             })
355             .collect()
356     }
357 
new(context: Arc<Context>, src: CString) -> Arc<Program>358     pub fn new(context: Arc<Context>, src: CString) -> Arc<Program> {
359         Arc::new(Self {
360             base: CLObjectBase::new(RusticlTypes::Program),
361             build: Mutex::new(ProgramBuild {
362                 builds: Self::create_default_builds(&context.devs),
363                 spec_constants: HashMap::new(),
364                 kernels: Vec::new(),
365                 kernel_info: HashMap::new(),
366             }),
367             devs: context.devs.to_vec(),
368             context: context,
369             src: ProgramSourceType::Src(src),
370         })
371     }
372 
from_bins( context: Arc<Context>, devs: Vec<&'static Device>, bins: &[&[u8]], ) -> Arc<Program>373     pub fn from_bins(
374         context: Arc<Context>,
375         devs: Vec<&'static Device>,
376         bins: &[&[u8]],
377     ) -> Arc<Program> {
378         let mut builds = HashMap::new();
379         let mut kernels = HashSet::new();
380 
381         for (&d, b) in devs.iter().zip(bins) {
382             let mut ptr = b.as_ptr();
383             let bin_type;
384             let spirv;
385 
386             unsafe {
387                 // 1. version
388                 let version = ptr.cast::<u32>().read();
389                 ptr = ptr.add(size_of::<u32>());
390 
391                 match version {
392                     1 => {
393                         // 2. size of the spirv
394                         let spirv_size = ptr.cast::<u32>().read();
395                         ptr = ptr.add(size_of::<u32>());
396 
397                         // 3. binary_type
398                         bin_type = ptr.cast::<cl_program_binary_type>().read();
399                         ptr = ptr.add(size_of::<cl_program_binary_type>());
400 
401                         // 4. the spirv
402                         assert!(b.as_ptr().add(BIN_HEADER_SIZE_V1) == ptr);
403                         assert!(b.len() == BIN_HEADER_SIZE_V1 + spirv_size as usize);
404                         spirv = Some(spirv::SPIRVBin::from_bin(slice::from_raw_parts(
405                             ptr,
406                             spirv_size as usize,
407                         )));
408                     }
409                     _ => panic!("unknown version"),
410                 }
411             }
412 
413             if let Some(spirv) = &spirv {
414                 for k in spirv.kernels() {
415                     kernels.insert(k);
416                 }
417             }
418 
419             builds.insert(
420                 d,
421                 ProgramDevBuild {
422                     spirv: spirv,
423                     status: CL_BUILD_SUCCESS as cl_build_status,
424                     log: String::from(""),
425                     options: String::from(""),
426                     bin_type: bin_type,
427                     kernels: HashMap::new(),
428                 },
429             );
430         }
431 
432         let mut build = ProgramBuild {
433             builds: builds,
434             spec_constants: HashMap::new(),
435             kernels: kernels.into_iter().collect(),
436             kernel_info: HashMap::new(),
437         };
438         build.build_nirs(false);
439 
440         Arc::new(Self {
441             base: CLObjectBase::new(RusticlTypes::Program),
442             context: context,
443             devs: devs,
444             src: ProgramSourceType::Binary,
445             build: Mutex::new(build),
446         })
447     }
448 
from_spirv(context: Arc<Context>, spirv: &[u8]) -> Arc<Program>449     pub fn from_spirv(context: Arc<Context>, spirv: &[u8]) -> Arc<Program> {
450         let builds = Self::create_default_builds(&context.devs);
451         Arc::new(Self {
452             base: CLObjectBase::new(RusticlTypes::Program),
453             devs: context.devs.clone(),
454             context: context,
455             src: ProgramSourceType::Il(SPIRVBin::from_bin(spirv)),
456             build: Mutex::new(ProgramBuild {
457                 builds: builds,
458                 spec_constants: HashMap::new(),
459                 kernels: Vec::new(),
460                 kernel_info: HashMap::new(),
461             }),
462         })
463     }
464 
build_info(&self) -> MutexGuard<ProgramBuild>465     pub fn build_info(&self) -> MutexGuard<ProgramBuild> {
466         self.build.lock().unwrap()
467     }
468 
status(&self, dev: &Device) -> cl_build_status469     pub fn status(&self, dev: &Device) -> cl_build_status {
470         self.build_info().dev_build(dev).status
471     }
472 
log(&self, dev: &Device) -> String473     pub fn log(&self, dev: &Device) -> String {
474         self.build_info().dev_build(dev).log.clone()
475     }
476 
bin_type(&self, dev: &Device) -> cl_program_binary_type477     pub fn bin_type(&self, dev: &Device) -> cl_program_binary_type {
478         self.build_info().dev_build(dev).bin_type
479     }
480 
options(&self, dev: &Device) -> String481     pub fn options(&self, dev: &Device) -> String {
482         self.build_info().dev_build(dev).options.clone()
483     }
484 
485     // we need to precalculate the size
bin_sizes(&self) -> Vec<usize>486     pub fn bin_sizes(&self) -> Vec<usize> {
487         let lock = self.build_info();
488         let mut res = Vec::new();
489         for d in &self.devs {
490             let info = lock.dev_build(d);
491 
492             res.push(
493                 info.spirv
494                     .as_ref()
495                     .map_or(0, |s| s.to_bin().len() + BIN_HEADER_SIZE),
496             );
497         }
498         res
499     }
500 
binaries(&self, vals: &[u8]) -> Vec<*mut u8>501     pub fn binaries(&self, vals: &[u8]) -> Vec<*mut u8> {
502         // if the application didn't provide any pointers, just return the length of devices
503         if vals.is_empty() {
504             return vec![std::ptr::null_mut(); self.devs.len()];
505         }
506 
507         // vals is an array of pointers where we should write the device binaries into
508         if vals.len() != self.devs.len() * size_of::<*const u8>() {
509             panic!("wrong size")
510         }
511 
512         let ptrs: &[*mut u8] = unsafe {
513             slice::from_raw_parts(vals.as_ptr().cast(), vals.len() / size_of::<*mut u8>())
514         };
515 
516         let lock = self.build_info();
517         for (i, d) in self.devs.iter().enumerate() {
518             let mut ptr = ptrs[i];
519             let info = lock.dev_build(d);
520 
521             // no spirv means nothing to write
522             let Some(spirv) = info.spirv.as_ref() else {
523                 continue;
524             };
525             let spirv = spirv.to_bin();
526 
527             unsafe {
528                 // 1. binary format version
529                 ptr.cast::<u32>().write(1);
530                 ptr = ptr.add(size_of::<u32>());
531 
532                 // 2. size of the spirv
533                 ptr.cast::<u32>().write(spirv.len() as u32);
534                 ptr = ptr.add(size_of::<u32>());
535 
536                 // 3. binary_type
537                 ptr.cast::<cl_program_binary_type>().write(info.bin_type);
538                 ptr = ptr.add(size_of::<cl_program_binary_type>());
539 
540                 // 4. the spirv
541                 assert!(ptrs[i].add(BIN_HEADER_SIZE) == ptr);
542                 ptr::copy_nonoverlapping(spirv.as_ptr(), ptr, spirv.len());
543             }
544         }
545 
546         ptrs.to_vec()
547     }
548 
kernel_signatures(&self, kernel_name: &str) -> HashSet<Vec<spirv::SPIRVKernelArg>>549     pub fn kernel_signatures(&self, kernel_name: &str) -> HashSet<Vec<spirv::SPIRVKernelArg>> {
550         let build = self.build_info();
551         let devs = build.devs_with_build();
552 
553         if devs.is_empty() {
554             return HashSet::new();
555         }
556 
557         devs.iter().map(|d| build.args(d, kernel_name)).collect()
558     }
559 
kernels(&self) -> Vec<String>560     pub fn kernels(&self) -> Vec<String> {
561         self.build_info().kernels.clone()
562     }
563 
active_kernels(&self) -> bool564     pub fn active_kernels(&self) -> bool {
565         self.build_info()
566             .builds
567             .values()
568             .any(|b| b.kernels.values().any(|b| Arc::strong_count(b) > 1))
569     }
570 
build(&self, dev: &Device, options: String) -> bool571     pub fn build(&self, dev: &Device, options: String) -> bool {
572         let lib = options.contains("-create-library");
573         let mut info = self.build_info();
574         if !self.do_compile(dev, options, &Vec::new(), &mut info) {
575             return false;
576         }
577 
578         let d = info.dev_build_mut(dev);
579 
580         // skip compilation if we already have the right thing.
581         if self.is_bin() {
582             if d.bin_type == CL_PROGRAM_BINARY_TYPE_EXECUTABLE && !lib
583                 || d.bin_type == CL_PROGRAM_BINARY_TYPE_LIBRARY && lib
584             {
585                 return true;
586             }
587         }
588 
589         let spirvs = [d.spirv.as_ref().unwrap()];
590         let (spirv, log) = spirv::SPIRVBin::link(&spirvs, lib);
591 
592         d.log.push_str(&log);
593         d.spirv = spirv;
594         if let Some(spirv) = &d.spirv {
595             d.bin_type = if lib {
596                 CL_PROGRAM_BINARY_TYPE_LIBRARY
597             } else {
598                 CL_PROGRAM_BINARY_TYPE_EXECUTABLE
599             };
600             d.status = CL_BUILD_SUCCESS as cl_build_status;
601             let mut kernels = spirv.kernels();
602             info.kernels.append(&mut kernels);
603             info.build_nirs(self.is_src());
604             true
605         } else {
606             d.status = CL_BUILD_ERROR;
607             d.bin_type = CL_PROGRAM_BINARY_TYPE_NONE;
608             false
609         }
610     }
611 
do_compile( &self, dev: &Device, options: String, headers: &[spirv::CLCHeader], info: &mut MutexGuard<ProgramBuild>, ) -> bool612     fn do_compile(
613         &self,
614         dev: &Device,
615         options: String,
616         headers: &[spirv::CLCHeader],
617         info: &mut MutexGuard<ProgramBuild>,
618     ) -> bool {
619         let d = info.dev_build_mut(dev);
620 
621         let val_options = clc_validator_options(dev);
622         let (spirv, log) = match &self.src {
623             ProgramSourceType::Il(spirv) => {
624                 if Platform::dbg().allow_invalid_spirv {
625                     (Some(spirv.clone()), String::new())
626                 } else {
627                     spirv.clone_on_validate(&val_options)
628                 }
629             }
630             ProgramSourceType::Src(src) => {
631                 let args = prepare_options(&options, dev);
632 
633                 if Platform::dbg().clc {
634                     let src = src.to_string_lossy();
635                     eprintln!("dumping compilation inputs:");
636                     eprintln!("compilation arguments: {args:?}");
637                     if !headers.is_empty() {
638                         eprintln!("headers: {headers:#?}");
639                     }
640                     eprintln!("source code:\n{src}");
641                 }
642 
643                 let (spirv, msgs) = spirv::SPIRVBin::from_clc(
644                     src,
645                     &args,
646                     headers,
647                     get_disk_cache(),
648                     dev.cl_features(),
649                     &dev.spirv_extensions,
650                     dev.address_bits(),
651                 );
652 
653                 if Platform::dbg().validate_spirv {
654                     if let Some(spirv) = spirv {
655                         let (res, spirv_msgs) = spirv.validate(&val_options);
656                         (res.then_some(spirv), format!("{}\n{}", msgs, spirv_msgs))
657                     } else {
658                         (None, msgs)
659                     }
660                 } else {
661                     (spirv, msgs)
662                 }
663             }
664             // do nothing if we got a library or binary
665             _ => {
666                 return true;
667             }
668         };
669 
670         d.spirv = spirv;
671         d.log = log;
672         d.options = options;
673 
674         if d.spirv.is_some() {
675             d.status = CL_BUILD_SUCCESS as cl_build_status;
676             d.bin_type = CL_PROGRAM_BINARY_TYPE_COMPILED_OBJECT;
677             true
678         } else {
679             d.status = CL_BUILD_ERROR;
680             false
681         }
682     }
683 
compile(&self, dev: &Device, options: String, headers: &[spirv::CLCHeader]) -> bool684     pub fn compile(&self, dev: &Device, options: String, headers: &[spirv::CLCHeader]) -> bool {
685         self.do_compile(dev, options, headers, &mut self.build_info())
686     }
687 
link( context: Arc<Context>, devs: &[&'static Device], progs: &[Arc<Program>], options: String, ) -> Arc<Program>688     pub fn link(
689         context: Arc<Context>,
690         devs: &[&'static Device],
691         progs: &[Arc<Program>],
692         options: String,
693     ) -> Arc<Program> {
694         let mut builds = HashMap::new();
695         let mut kernels = HashSet::new();
696         let mut locks: Vec<_> = progs.iter().map(|p| p.build_info()).collect();
697         let lib = options.contains("-create-library");
698 
699         for &d in devs {
700             let bins: Vec<_> = locks
701                 .iter_mut()
702                 .map(|l| l.dev_build(d).spirv.as_ref().unwrap())
703                 .collect();
704 
705             let (spirv, log) = spirv::SPIRVBin::link(&bins, lib);
706             let (spirv, log) = if Platform::dbg().validate_spirv {
707                 if let Some(spirv) = spirv {
708                     let val_options = clc_validator_options(d);
709                     let (res, spirv_msgs) = spirv.validate(&val_options);
710                     (res.then_some(spirv), format!("{}\n{}", log, spirv_msgs))
711                 } else {
712                     (None, log)
713                 }
714             } else {
715                 (spirv, log)
716             };
717 
718             let status;
719             let bin_type;
720             if let Some(spirv) = &spirv {
721                 for k in spirv.kernels() {
722                     kernels.insert(k);
723                 }
724                 status = CL_BUILD_SUCCESS as cl_build_status;
725                 bin_type = if lib {
726                     CL_PROGRAM_BINARY_TYPE_LIBRARY
727                 } else {
728                     CL_PROGRAM_BINARY_TYPE_EXECUTABLE
729                 };
730             } else {
731                 status = CL_BUILD_ERROR;
732                 bin_type = CL_PROGRAM_BINARY_TYPE_NONE;
733             };
734 
735             builds.insert(
736                 d,
737                 ProgramDevBuild {
738                     spirv: spirv,
739                     status: status,
740                     log: log,
741                     options: String::from(""),
742                     bin_type: bin_type,
743                     kernels: HashMap::new(),
744                 },
745             );
746         }
747 
748         let mut build = ProgramBuild {
749             builds: builds,
750             spec_constants: HashMap::new(),
751             kernels: kernels.into_iter().collect(),
752             kernel_info: HashMap::new(),
753         };
754 
755         // Pre build nir kernels
756         build.build_nirs(false);
757 
758         Arc::new(Self {
759             base: CLObjectBase::new(RusticlTypes::Program),
760             context: context,
761             devs: devs.to_owned(),
762             src: ProgramSourceType::Linked,
763             build: Mutex::new(build),
764         })
765     }
766 
is_bin(&self) -> bool767     pub fn is_bin(&self) -> bool {
768         matches!(self.src, ProgramSourceType::Binary)
769     }
770 
is_il(&self) -> bool771     pub fn is_il(&self) -> bool {
772         matches!(self.src, ProgramSourceType::Il(_))
773     }
774 
is_src(&self) -> bool775     pub fn is_src(&self) -> bool {
776         matches!(self.src, ProgramSourceType::Src(_))
777     }
778 
get_spec_constant_size(&self, spec_id: u32) -> u8779     pub fn get_spec_constant_size(&self, spec_id: u32) -> u8 {
780         match &self.src {
781             ProgramSourceType::Il(il) => il
782                 .spec_constant(spec_id)
783                 .map_or(0, spirv::CLCSpecConstantType::size),
784             _ => unreachable!(),
785         }
786     }
787 
set_spec_constant(&self, spec_id: u32, data: &[u8])788     pub fn set_spec_constant(&self, spec_id: u32, data: &[u8]) {
789         let mut lock = self.build_info();
790         let mut val = nir_const_value::default();
791 
792         match data.len() {
793             1 => val.u8_ = u8::from_ne_bytes(data.try_into().unwrap()),
794             2 => val.u16_ = u16::from_ne_bytes(data.try_into().unwrap()),
795             4 => val.u32_ = u32::from_ne_bytes(data.try_into().unwrap()),
796             8 => val.u64_ = u64::from_ne_bytes(data.try_into().unwrap()),
797             _ => unreachable!("Spec constant with invalid size!"),
798         };
799 
800         lock.spec_constants.insert(spec_id, val);
801     }
802 }
803