• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use crate::api::icd::CLResult;
2 use crate::api::icd::DISPATCH;
3 use crate::core::device::*;
4 use crate::core::version::*;
5 
6 use mesa_rust_gen::*;
7 use mesa_rust_util::string::char_arr_to_cstr;
8 use rusticl_opencl_gen::*;
9 
10 use std::env;
11 use std::ptr;
12 use std::ptr::addr_of;
13 use std::ptr::addr_of_mut;
14 use std::sync::Once;
15 
/// Maximum size a pixel can be across all supported image formats.
///
/// NOTE(review): `4 * 4` presumably means four channels of at most four bytes each
/// (e.g. RGBA with 32-bit channels) — confirm against the supported image format list.
pub const MAX_PIXEL_SIZE_BYTES: u64 = 4 * 4;

/// The rusticl OpenCL platform object.
///
/// This is `#[repr(C)]` with `dispatch` as the very first field on purpose: `as_ptr` hands out
/// the address of this struct as the `cl_platform_id`, and the OpenCL ICD loader
/// (`cl_khr_icd`) reads the dispatch table pointer from the start of every object handle.
/// Do not reorder the fields.
#[repr(C)]
pub struct Platform {
    /// ICD dispatch table; must stay the first field (see above).
    dispatch: &'static cl_icd_dispatch,
    /// Devices of this platform, filled in by `Device::all()` during initialization.
    pub devs: Vec<Device>,
    /// Platform extension names joined with single spaces.
    pub extension_string: String,
    /// Platform extensions together with their versions.
    pub extensions: Vec<cl_name_version>,
}
26 
/// How performance warnings emitted through `perf_warning!` are reported.
///
/// Selected by the `perf` / `perfspam` entries of `RUSTICL_DEBUG`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PerfDebugLevel {
    /// Performance warnings are not printed.
    None,
    /// Each warning call site prints at most once.
    Once,
    /// Every warning is printed each time it is hit.
    Spam,
}

/// Debug options, parsed once from the `RUSTICL_DEBUG` and `RUSTICL_MAX_WORK_GROUPS`
/// environment variables.
///
/// Each boolean below is documented by the `RUSTICL_DEBUG` entry that sets it.
#[derive(Clone, Copy, Debug)]
pub struct PlatformDebug {
    /// Set by `allow_invalid_spirv`.
    pub allow_invalid_spirv: bool,
    /// Set by `clc`.
    pub clc: bool,
    /// Parsed from `RUSTICL_MAX_WORK_GROUPS`; `u64::MAX` when not given.
    pub max_grid_size: u64,
    /// Set by `nir`.
    pub nir: bool,
    /// Set by `no_variants`.
    pub no_variants: bool,
    /// Set by `perf` (`Once`) or `perfspam` (`Spam`).
    pub perf: PerfDebugLevel,
    /// Set by `program`.
    pub program: bool,
    /// Enabled by default; cleared by `no_reuse_context`.
    pub reuse_context: bool,
    /// Set by `sync`.
    pub sync_every_event: bool,
    /// Set by `validate`.
    pub validate_spirv: bool,
}

/// Opt-in features, parsed once from the comma-separated `RUSTICL_FEATURES` variable.
#[derive(Clone, Copy, Debug)]
pub struct PlatformFeatures {
    /// Set by the `fp16` entry.
    pub fp16: bool,
    /// Set by the `fp64` entry.
    pub fp64: bool,
}
50 
51 static PLATFORM_ENV_ONCE: Once = Once::new();
52 static PLATFORM_ONCE: Once = Once::new();
53 
54 static mut PLATFORM: Platform = Platform {
55     dispatch: &DISPATCH,
56     devs: Vec::new(),
57     extension_string: String::new(),
58     extensions: Vec::new(),
59 };
60 static mut PLATFORM_DBG: PlatformDebug = PlatformDebug {
61     allow_invalid_spirv: false,
62     clc: false,
63     max_grid_size: 0,
64     nir: false,
65     no_variants: false,
66     perf: PerfDebugLevel::None,
67     program: false,
68     reuse_context: true,
69     sync_every_event: false,
70     validate_spirv: false,
71 };
72 static mut PLATFORM_FEATURES: PlatformFeatures = PlatformFeatures {
73     fp16: false,
74     fp64: false,
75 };
76 
load_env()77 fn load_env() {
78     // SAFETY: no other references exist at this point
79     let debug = unsafe { &mut *addr_of_mut!(PLATFORM_DBG) };
80     if let Ok(debug_flags) = env::var("RUSTICL_DEBUG") {
81         for flag in debug_flags.split(',') {
82             match flag {
83                 "allow_invalid_spirv" => debug.allow_invalid_spirv = true,
84                 "clc" => debug.clc = true,
85                 "nir" => debug.nir = true,
86                 "no_reuse_context" => debug.reuse_context = false,
87                 "no_variants" => debug.no_variants = true,
88                 "perf" => debug.perf = PerfDebugLevel::Once,
89                 "perfspam" => debug.perf = PerfDebugLevel::Spam,
90                 "program" => debug.program = true,
91                 "sync" => debug.sync_every_event = true,
92                 "validate" => debug.validate_spirv = true,
93                 "" => (),
94                 _ => eprintln!("Unknown RUSTICL_DEBUG flag found: {}", flag),
95             }
96         }
97     }
98 
99     debug.max_grid_size = env::var("RUSTICL_MAX_WORK_GROUPS")
100         .ok()
101         .and_then(|s| s.parse().ok())
102         .unwrap_or(u64::MAX);
103 
104     // SAFETY: no other references exist at this point
105     let features = unsafe { &mut *addr_of_mut!(PLATFORM_FEATURES) };
106     if let Ok(feature_flags) = env::var("RUSTICL_FEATURES") {
107         for flag in feature_flags.split(',') {
108             match flag {
109                 "fp16" => features.fp16 = true,
110                 "fp64" => features.fp64 = true,
111                 "" => (),
112                 _ => eprintln!("Unknown RUSTICL_FEATURES flag found: {}", flag),
113             }
114         }
115     }
116 }
117 
118 impl Platform {
as_ptr(&self) -> cl_platform_id119     pub fn as_ptr(&self) -> cl_platform_id {
120         ptr::from_ref(self) as cl_platform_id
121     }
122 
get() -> &'static Self123     pub fn get() -> &'static Self {
124         debug_assert!(PLATFORM_ONCE.is_completed());
125         // SAFETY: no mut references exist at this point
126         unsafe { &*addr_of!(PLATFORM) }
127     }
128 
dbg() -> &'static PlatformDebug129     pub fn dbg() -> &'static PlatformDebug {
130         debug_assert!(PLATFORM_ENV_ONCE.is_completed());
131         unsafe { &*addr_of!(PLATFORM_DBG) }
132     }
133 
features() -> &'static PlatformFeatures134     pub fn features() -> &'static PlatformFeatures {
135         debug_assert!(PLATFORM_ENV_ONCE.is_completed());
136         unsafe { &*addr_of!(PLATFORM_FEATURES) }
137     }
138 
init(&mut self)139     fn init(&mut self) {
140         unsafe {
141             glsl_type_singleton_init_or_ref();
142         }
143 
144         self.devs = Device::all();
145 
146         let mut exts_str: Vec<&str> = Vec::new();
147         let mut add_ext = |major, minor, patch, ext: &'static str| {
148             self.extensions
149                 .push(mk_cl_version_ext(major, minor, patch, ext));
150             exts_str.push(ext);
151         };
152 
153         // Add all platform extensions we don't expect devices to advertise.
154         add_ext(1, 0, 0, "cl_khr_icd");
155 
156         let mut exts;
157         if let Some((first, rest)) = self.devs.split_first() {
158             exts = first.extensions.clone();
159 
160             for dev in rest {
161                 // This isn't fast, but the lists are small, so it doesn't really matter.
162                 exts.retain(|ext| dev.extensions.contains(ext));
163             }
164 
165             // Now that we found all extensions supported by all devices, we push them to the
166             // platform.
167             for ext in &exts {
168                 exts_str.push(
169                     // SAFETY: ext.name contains a nul terminated string.
170                     unsafe { char_arr_to_cstr(&ext.name) }.to_str().unwrap(),
171                 );
172                 self.extensions.push(*ext);
173             }
174         }
175 
176         self.extension_string = exts_str.join(" ");
177     }
178 
init_once()179     pub fn init_once() {
180         PLATFORM_ENV_ONCE.call_once(load_env);
181         // SAFETY: no concurrent static mut access due to std::Once
182         #[allow(static_mut_refs)]
183         PLATFORM_ONCE.call_once(|| unsafe { PLATFORM.init() });
184     }
185 }
186 
187 impl Drop for Platform {
drop(&mut self)188     fn drop(&mut self) {
189         unsafe {
190             glsl_type_singleton_decref();
191         }
192     }
193 }
194 
195 pub trait GetPlatformRef {
get_ref(&self) -> CLResult<&'static Platform>196     fn get_ref(&self) -> CLResult<&'static Platform>;
197 }
198 
199 impl GetPlatformRef for cl_platform_id {
get_ref(&self) -> CLResult<&'static Platform>200     fn get_ref(&self) -> CLResult<&'static Platform> {
201         if !self.is_null() && *self == Platform::get().as_ptr() {
202             Ok(Platform::get())
203         } else {
204             Err(CL_INVALID_PLATFORM)
205         }
206     }
207 }
208 
/// Prints a performance warning to stderr, depending on `Platform::dbg().perf`: nothing for
/// `PerfDebugLevel::None`, at most once per call site for `Once`, and every time for `Spam`.
///
/// The first argument must be a string literal (it is spliced into `concat!`). Any further
/// arguments are ordinary format arguments, e.g.
/// `perf_warning!("slow path for {} ({} bytes)", name, buf.len())`.
#[macro_export]
macro_rules! perf_warning {
    // Internal rule: the actual print, shared by the `Once` and `Spam` paths.
    (@PRINT $format:tt $(, $arg:expr)*) => {
        eprintln!(std::concat!("=== Rusticl perf warning: ", $format, " ===") $(, $arg)*)
    };

    ($format:tt $(, $arg:expr)*) => {
        match $crate::core::platform::Platform::dbg().perf {
            $crate::core::platform::PerfDebugLevel::Once => {
                // One guard per expansion site, so each call site warns at most once.
                static PERF_WARN_ONCE: std::sync::Once = std::sync::Once::new();
                PERF_WARN_ONCE.call_once(|| {
                    $crate::perf_warning!(@PRINT $format $(, $arg)*);
                })
            },
            $crate::core::platform::PerfDebugLevel::Spam => {
                $crate::perf_warning!(@PRINT $format $(, $arg)*)
            },
            $crate::core::platform::PerfDebugLevel::None => (),
        }
    };
}
228