1 use crate::api::icd::CLResult;
2 use crate::api::icd::DISPATCH;
3 use crate::core::device::*;
4 use crate::core::version::*;
5
6 use mesa_rust_gen::*;
7 use mesa_rust_util::string::char_arr_to_cstr;
8 use rusticl_opencl_gen::*;
9
10 use std::env;
11 use std::ptr;
12 use std::ptr::addr_of;
13 use std::ptr::addr_of_mut;
14 use std::sync::Once;
15
/// Upper bound, in bytes, on the size of one pixel for every image format rusticl supports.
///
/// Sixteen bytes (4 * 4) — presumably four channels of at most four bytes each.
/// NOTE(review): confirm against the supported image format table when adding formats.
pub const MAX_PIXEL_SIZE_BYTES: u64 = 16;
18
/// The rusticl OpenCL platform object. Its address doubles as the `cl_platform_id` handle
/// (see `Platform::as_ptr`).
///
/// NOTE(review): `#[repr(C)]` with `dispatch` as the first field looks required so that the
/// OpenCL ICD loader finds the dispatch table at offset 0 of the handle — keep `dispatch` first
/// and do not reorder fields.
#[repr(C)]
pub struct Platform {
    /// ICD dispatch table; the `PLATFORM` static initializes it with `&DISPATCH`.
    dispatch: &'static cl_icd_dispatch,
    /// Devices exposed by this platform, filled in by `Platform::init`.
    pub devs: Vec<Device>,
    /// The platform's extension names joined with single spaces.
    pub extension_string: String,
    /// The platform's extensions as `cl_name_version` entries.
    pub extensions: Vec<cl_name_version>,
}
26
/// Controls how `perf_warning!` reports performance warnings; selected via `RUSTICL_DEBUG`.
///
/// Derives `Copy`/`PartialEq`/`Debug` so callers can compare or print the level directly
/// instead of having to `match` on it.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PerfDebugLevel {
    /// No perf warnings are printed (the default).
    None,
    /// `RUSTICL_DEBUG=perf`: each `perf_warning!` call site prints at most once.
    Once,
    /// `RUSTICL_DEBUG=perfspam`: every `perf_warning!` invocation prints.
    Spam,
}
32
33 pub struct PlatformDebug {
34 pub allow_invalid_spirv: bool,
35 pub clc: bool,
36 pub max_grid_size: u64,
37 pub nir: bool,
38 pub no_variants: bool,
39 pub perf: PerfDebugLevel,
40 pub program: bool,
41 pub reuse_context: bool,
42 pub sync_every_event: bool,
43 pub validate_spirv: bool,
44 }
45
/// Optional features switched on through the comma-separated `RUSTICL_FEATURES` variable.
pub struct PlatformFeatures {
    /// Enabled by the `fp16` entry.
    pub fp16: bool,
    /// Enabled by the `fp64` entry.
    pub fp64: bool,
}
50
51 static PLATFORM_ENV_ONCE: Once = Once::new();
52 static PLATFORM_ONCE: Once = Once::new();
53
54 static mut PLATFORM: Platform = Platform {
55 dispatch: &DISPATCH,
56 devs: Vec::new(),
57 extension_string: String::new(),
58 extensions: Vec::new(),
59 };
60 static mut PLATFORM_DBG: PlatformDebug = PlatformDebug {
61 allow_invalid_spirv: false,
62 clc: false,
63 max_grid_size: 0,
64 nir: false,
65 no_variants: false,
66 perf: PerfDebugLevel::None,
67 program: false,
68 reuse_context: true,
69 sync_every_event: false,
70 validate_spirv: false,
71 };
72 static mut PLATFORM_FEATURES: PlatformFeatures = PlatformFeatures {
73 fp16: false,
74 fp64: false,
75 };
76
load_env()77 fn load_env() {
78 // SAFETY: no other references exist at this point
79 let debug = unsafe { &mut *addr_of_mut!(PLATFORM_DBG) };
80 if let Ok(debug_flags) = env::var("RUSTICL_DEBUG") {
81 for flag in debug_flags.split(',') {
82 match flag {
83 "allow_invalid_spirv" => debug.allow_invalid_spirv = true,
84 "clc" => debug.clc = true,
85 "nir" => debug.nir = true,
86 "no_reuse_context" => debug.reuse_context = false,
87 "no_variants" => debug.no_variants = true,
88 "perf" => debug.perf = PerfDebugLevel::Once,
89 "perfspam" => debug.perf = PerfDebugLevel::Spam,
90 "program" => debug.program = true,
91 "sync" => debug.sync_every_event = true,
92 "validate" => debug.validate_spirv = true,
93 "" => (),
94 _ => eprintln!("Unknown RUSTICL_DEBUG flag found: {}", flag),
95 }
96 }
97 }
98
99 debug.max_grid_size = env::var("RUSTICL_MAX_WORK_GROUPS")
100 .ok()
101 .and_then(|s| s.parse().ok())
102 .unwrap_or(u64::MAX);
103
104 // SAFETY: no other references exist at this point
105 let features = unsafe { &mut *addr_of_mut!(PLATFORM_FEATURES) };
106 if let Ok(feature_flags) = env::var("RUSTICL_FEATURES") {
107 for flag in feature_flags.split(',') {
108 match flag {
109 "fp16" => features.fp16 = true,
110 "fp64" => features.fp64 = true,
111 "" => (),
112 _ => eprintln!("Unknown RUSTICL_FEATURES flag found: {}", flag),
113 }
114 }
115 }
116 }
117
118 impl Platform {
as_ptr(&self) -> cl_platform_id119 pub fn as_ptr(&self) -> cl_platform_id {
120 ptr::from_ref(self) as cl_platform_id
121 }
122
get() -> &'static Self123 pub fn get() -> &'static Self {
124 debug_assert!(PLATFORM_ONCE.is_completed());
125 // SAFETY: no mut references exist at this point
126 unsafe { &*addr_of!(PLATFORM) }
127 }
128
dbg() -> &'static PlatformDebug129 pub fn dbg() -> &'static PlatformDebug {
130 debug_assert!(PLATFORM_ENV_ONCE.is_completed());
131 unsafe { &*addr_of!(PLATFORM_DBG) }
132 }
133
features() -> &'static PlatformFeatures134 pub fn features() -> &'static PlatformFeatures {
135 debug_assert!(PLATFORM_ENV_ONCE.is_completed());
136 unsafe { &*addr_of!(PLATFORM_FEATURES) }
137 }
138
init(&mut self)139 fn init(&mut self) {
140 unsafe {
141 glsl_type_singleton_init_or_ref();
142 }
143
144 self.devs = Device::all();
145
146 let mut exts_str: Vec<&str> = Vec::new();
147 let mut add_ext = |major, minor, patch, ext: &'static str| {
148 self.extensions
149 .push(mk_cl_version_ext(major, minor, patch, ext));
150 exts_str.push(ext);
151 };
152
153 // Add all platform extensions we don't expect devices to advertise.
154 add_ext(1, 0, 0, "cl_khr_icd");
155
156 let mut exts;
157 if let Some((first, rest)) = self.devs.split_first() {
158 exts = first.extensions.clone();
159
160 for dev in rest {
161 // This isn't fast, but the lists are small, so it doesn't really matter.
162 exts.retain(|ext| dev.extensions.contains(ext));
163 }
164
165 // Now that we found all extensions supported by all devices, we push them to the
166 // platform.
167 for ext in &exts {
168 exts_str.push(
169 // SAFETY: ext.name contains a nul terminated string.
170 unsafe { char_arr_to_cstr(&ext.name) }.to_str().unwrap(),
171 );
172 self.extensions.push(*ext);
173 }
174 }
175
176 self.extension_string = exts_str.join(" ");
177 }
178
init_once()179 pub fn init_once() {
180 PLATFORM_ENV_ONCE.call_once(load_env);
181 // SAFETY: no concurrent static mut access due to std::Once
182 #[allow(static_mut_refs)]
183 PLATFORM_ONCE.call_once(|| unsafe { PLATFORM.init() });
184 }
185 }
186
187 impl Drop for Platform {
drop(&mut self)188 fn drop(&mut self) {
189 unsafe {
190 glsl_type_singleton_decref();
191 }
192 }
193 }
194
/// Resolves an OpenCL platform handle to the rusticl [`Platform`] singleton.
pub trait GetPlatformRef {
    /// Returns the platform this handle refers to. The `cl_platform_id` implementation returns
    /// `Err(CL_INVALID_PLATFORM)` for any handle that is not rusticl's platform.
    fn get_ref(&self) -> CLResult<&'static Platform>;
}
198
199 impl GetPlatformRef for cl_platform_id {
get_ref(&self) -> CLResult<&'static Platform>200 fn get_ref(&self) -> CLResult<&'static Platform> {
201 if !self.is_null() && *self == Platform::get().as_ptr() {
202 Ok(Platform::get())
203 } else {
204 Err(CL_INVALID_PLATFORM)
205 }
206 }
207 }
208
/// Prints a performance warning to stderr, depending on `Platform::dbg().perf`.
///
/// * `PerfDebugLevel::None`: nothing is printed.
/// * `PerfDebugLevel::Once`: each invocation site prints at most once per process.
/// * `PerfDebugLevel::Spam`: every invocation prints.
///
/// The first argument must be a string literal (it is spliced into the message with
/// `concat!`); any further arguments are arbitrary expressions, as for `eprintln!`.
/// A trailing comma is accepted. The `@PRINT` rule is an internal helper.
#[macro_export]
macro_rules! perf_warning {
    (@PRINT $format:tt $(, $arg:expr)* $(,)?) => {
        eprintln!(std::concat!("=== Rusticl perf warning: ", $format, " ===") $(, $arg)*)
    };

    ($format:tt $(, $arg:expr)* $(,)?) => {
        match $crate::core::platform::Platform::dbg().perf {
            $crate::core::platform::PerfDebugLevel::Once => {
                // One `Once` per expansion, i.e. per call site.
                static PERF_WARN_ONCE: std::sync::Once = std::sync::Once::new();
                PERF_WARN_ONCE.call_once(|| {
                    // Re-insert the commas between arguments: forwarding `$($arg)*` without a
                    // separator would turn `a, b` into `a b` and break every call with two or
                    // more arguments. `$crate::` keeps the recursion working at call sites
                    // that did not import the macro by name.
                    $crate::perf_warning!(@PRINT $format $(, $arg)*);
                })
            }
            $crate::core::platform::PerfDebugLevel::Spam => {
                $crate::perf_warning!(@PRINT $format $(, $arg)*);
            }
            $crate::core::platform::PerfDebugLevel::None => (),
        }
    };
}
228