• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright © 2022 Collabora, Ltd.
2 // SPDX-License-Identifier: MIT
3 
4 use crate::from_nir::*;
5 use crate::ir::{ShaderInfo, ShaderIoInfo, ShaderModel, ShaderStageInfo};
6 use crate::sm50::ShaderModel50;
7 use crate::sm70::ShaderModel70;
8 use crate::sph;
9 
10 use compiler::bindings::*;
11 use nak_bindings::*;
12 
13 use std::cmp::max;
14 use std::env;
15 use std::ffi::{CStr, CString};
16 use std::fmt::Write;
17 use std::os::raw::c_void;
18 use std::panic;
19 use std::sync::OnceLock;
20 
/// Bit positions of the individual `NAK_DEBUG` flags inside [`Debug`]'s mask.
#[repr(u8)]
enum DebugFlags {
    /// Enabled by `print` in `NAK_DEBUG`.
    Print,
    /// Enabled by `serial` in `NAK_DEBUG`.
    Serial,
    /// Enabled by `spill` in `NAK_DEBUG`.
    Spill,
    /// Enabled by `annotate` in `NAK_DEBUG`.
    Annotate,
    /// Enabled by `nougpr` in `NAK_DEBUG`.
    NoUgpr,
}

/// Debug options parsed from the `NAK_DEBUG` environment variable.
pub struct Debug {
    /// Bit `DebugFlags::X as u8` is set for every flag X that was requested.
    flags: u32,
}

impl Debug {
    /// Parses `NAK_DEBUG`, a comma-separated list of flag names.
    ///
    /// Whitespace around each name is ignored. Empty entries are ignored as
    /// well, for example from `NAK_DEBUG=` or a trailing comma. Unknown names
    /// are reported on stderr and otherwise skipped. If the variable is unset
    /// or not valid Unicode, no flags are enabled.
    fn new() -> Debug {
        let Ok(debug_str) = env::var("NAK_DEBUG") else {
            return Debug { flags: 0 };
        };

        let mut flags = 0;
        for flag in debug_str.split(',').map(str::trim) {
            let bit = match flag {
                // `split` yields "" for empty lists and stray commas; that is
                // not a typo worth warning about.
                "" => continue,
                "print" => DebugFlags::Print,
                "serial" => DebugFlags::Serial,
                "spill" => DebugFlags::Spill,
                "annotate" => DebugFlags::Annotate,
                "nougpr" => DebugFlags::NoUgpr,
                unk => {
                    eprintln!("Unknown NAK_DEBUG flag \"{}\"", unk);
                    continue;
                }
            };
            flags |= 1 << bit as u8;
        }
        Debug { flags }
    }
}
58 
59 pub trait GetDebugFlags {
debug_flags(&self) -> u3260     fn debug_flags(&self) -> u32;
61 
print(&self) -> bool62     fn print(&self) -> bool {
63         self.debug_flags() & (1 << DebugFlags::Print as u8) != 0
64     }
65 
serial(&self) -> bool66     fn serial(&self) -> bool {
67         self.debug_flags() & (1 << DebugFlags::Serial as u8) != 0
68     }
69 
spill(&self) -> bool70     fn spill(&self) -> bool {
71         self.debug_flags() & (1 << DebugFlags::Spill as u8) != 0
72     }
73 
annotate(&self) -> bool74     fn annotate(&self) -> bool {
75         self.debug_flags() & (1 << DebugFlags::Annotate as u8) != 0
76     }
77 
no_ugpr(&self) -> bool78     fn no_ugpr(&self) -> bool {
79         self.debug_flags() & (1 << DebugFlags::NoUgpr as u8) != 0
80     }
81 }
82 
83 pub static DEBUG: OnceLock<Debug> = OnceLock::new();
84 
85 impl GetDebugFlags for OnceLock<Debug> {
debug_flags(&self) -> u3286     fn debug_flags(&self) -> u32 {
87         self.get_or_init(Debug::new).flags
88     }
89 }
90 
91 #[no_mangle]
nak_should_print_nir() -> bool92 pub extern "C" fn nak_should_print_nir() -> bool {
93     DEBUG.print()
94 }
95 
nir_options(dev: &nv_device_info) -> nir_shader_compiler_options96 fn nir_options(dev: &nv_device_info) -> nir_shader_compiler_options {
97     let mut op: nir_shader_compiler_options = unsafe { std::mem::zeroed() };
98 
99     op.lower_fdiv = true;
100     op.fuse_ffma16 = true;
101     op.fuse_ffma32 = true;
102     op.fuse_ffma64 = true;
103     op.lower_flrp16 = true;
104     op.lower_flrp32 = true;
105     op.lower_flrp64 = true;
106     op.lower_fsqrt = dev.sm < 52;
107     op.lower_bitfield_extract = dev.sm >= 70;
108     op.lower_bitfield_insert = true;
109     op.lower_pack_half_2x16 = true;
110     op.lower_pack_unorm_2x16 = true;
111     op.lower_pack_snorm_2x16 = true;
112     op.lower_pack_unorm_4x8 = true;
113     op.lower_pack_snorm_4x8 = true;
114     op.lower_unpack_half_2x16 = true;
115     op.lower_unpack_unorm_2x16 = true;
116     op.lower_unpack_snorm_2x16 = true;
117     op.lower_unpack_unorm_4x8 = true;
118     op.lower_unpack_snorm_4x8 = true;
119     op.lower_insert_byte = true;
120     op.lower_insert_word = true;
121     op.lower_cs_local_index_to_id = true;
122     op.lower_device_index_to_zero = true;
123     op.lower_isign = true;
124     op.lower_uadd_sat = dev.sm < 70;
125     op.lower_usub_sat = dev.sm < 70;
126     op.lower_iadd_sat = true; // TODO
127     op.lower_doubles_options = nir_lower_drcp
128         | nir_lower_dsqrt
129         | nir_lower_drsq
130         | nir_lower_dtrunc
131         | nir_lower_dfloor
132         | nir_lower_dceil
133         | nir_lower_dfract
134         | nir_lower_dround_even
135         | nir_lower_dsat;
136     if dev.sm >= 70 {
137         op.lower_doubles_options |= nir_lower_dminmax;
138     }
139     op.lower_int64_options = !(nir_lower_icmp64
140         | nir_lower_iadd64
141         | nir_lower_ineg64
142         | nir_lower_shift64
143         | nir_lower_imul_2x32_64
144         | nir_lower_conv64);
145     op.lower_ldexp = true;
146     op.lower_fmod = true;
147     op.lower_ffract = true;
148     op.lower_fpow = true;
149     op.lower_scmp = true;
150     op.lower_uadd_carry = true;
151     op.lower_usub_borrow = true;
152     op.has_iadd3 = dev.sm >= 70;
153     op.has_imad32 = dev.sm >= 70;
154     op.has_sdot_4x8 = dev.sm >= 70;
155     op.has_udot_4x8 = dev.sm >= 70;
156     op.has_sudot_4x8 = dev.sm >= 70;
157     // We set .ftz on f32 by default so we can support fmulz whenever the client
158     // doesn't explicitly request denorms.
159     op.has_fmulz_no_denorms = true;
160     op.has_find_msb_rev = true;
161     op.has_pack_half_2x16_rtz = true;
162     op.has_bfm = dev.sm >= 70;
163     op.discard_is_demote = true;
164 
165     op.max_unroll_iterations = 32;
166     op.scalarize_ddx = true;
167 
168     op
169 }
170 
171 #[no_mangle]
nak_compiler_create( dev: *const nv_device_info, ) -> *mut nak_compiler172 pub extern "C" fn nak_compiler_create(
173     dev: *const nv_device_info,
174 ) -> *mut nak_compiler {
175     assert!(!dev.is_null());
176     let dev = unsafe { &*dev };
177 
178     let nak = Box::new(nak_compiler {
179         sm: dev.sm,
180         warps_per_sm: dev.max_warps_per_mp,
181         nir_options: nir_options(dev),
182     });
183 
184     Box::into_raw(nak)
185 }
186 
187 #[no_mangle]
nak_compiler_destroy(nak: *mut nak_compiler)188 pub extern "C" fn nak_compiler_destroy(nak: *mut nak_compiler) {
189     unsafe { drop(Box::from_raw(nak)) };
190 }
191 
192 #[no_mangle]
nak_debug_flags(_nak: *const nak_compiler) -> u64193 pub extern "C" fn nak_debug_flags(_nak: *const nak_compiler) -> u64 {
194     DEBUG.debug_flags().into()
195 }
196 
197 #[no_mangle]
nak_nir_options( nak: *const nak_compiler, ) -> *const nir_shader_compiler_options198 pub extern "C" fn nak_nir_options(
199     nak: *const nak_compiler,
200 ) -> *const nir_shader_compiler_options {
201     assert!(!nak.is_null());
202     let nak = unsafe { &*nak };
203     &nak.nir_options
204 }
205 
206 #[repr(C)]
207 pub struct ShaderBin {
208     pub bin: nak_shader_bin,
209     code: Vec<u32>,
210     asm: CString,
211 }
212 
213 impl ShaderBin {
new( sm: &dyn ShaderModel, info: &ShaderInfo, fs_key: Option<&nak_fs_key>, code: Vec<u32>, asm: &str, ) -> ShaderBin214     pub fn new(
215         sm: &dyn ShaderModel,
216         info: &ShaderInfo,
217         fs_key: Option<&nak_fs_key>,
218         code: Vec<u32>,
219         asm: &str,
220     ) -> ShaderBin {
221         let asm = CString::new(asm)
222             .expect("NAK assembly has unexpected null characters");
223 
224         let c_info = nak_shader_info {
225             stage: match info.stage {
226                 ShaderStageInfo::Compute(_) => MESA_SHADER_COMPUTE,
227                 ShaderStageInfo::Vertex => MESA_SHADER_VERTEX,
228                 ShaderStageInfo::Fragment(_) => MESA_SHADER_FRAGMENT,
229                 ShaderStageInfo::Geometry(_) => MESA_SHADER_GEOMETRY,
230                 ShaderStageInfo::TessellationInit(_) => MESA_SHADER_TESS_CTRL,
231                 ShaderStageInfo::Tessellation(_) => MESA_SHADER_TESS_EVAL,
232             },
233             sm: sm.sm(),
234             num_gprs: {
235                 max(4, info.num_gprs as u32 + sm.hw_reserved_gprs())
236                     .try_into()
237                     .unwrap()
238             },
239             num_control_barriers: info.num_control_barriers,
240             _pad0: Default::default(),
241             num_instrs: info.num_instrs,
242             slm_size: info.slm_size,
243             crs_size: sm.crs_size(info.max_crs_depth),
244             __bindgen_anon_1: match &info.stage {
245                 ShaderStageInfo::Compute(cs_info) => {
246                     nak_shader_info__bindgen_ty_1 {
247                         cs: nak_shader_info__bindgen_ty_1__bindgen_ty_1 {
248                             local_size: [
249                                 cs_info.local_size[0],
250                                 cs_info.local_size[1],
251                                 cs_info.local_size[2],
252                             ],
253                             smem_size: cs_info.smem_size,
254                             _pad: Default::default(),
255                         },
256                     }
257                 }
258                 ShaderStageInfo::Fragment(fs_info) => {
259                     let fs_io_info = match &info.io {
260                         ShaderIoInfo::Fragment(io) => io,
261                         _ => unreachable!(),
262                     };
263                     nak_shader_info__bindgen_ty_1 {
264                         fs: nak_shader_info__bindgen_ty_1__bindgen_ty_2 {
265                             writes_depth: fs_io_info.writes_depth,
266                             reads_sample_mask: fs_io_info.reads_sample_mask,
267                             post_depth_coverage: fs_info.post_depth_coverage,
268                             uses_sample_shading: fs_info.uses_sample_shading,
269                             early_fragment_tests: fs_info.early_fragment_tests,
270                             _pad: Default::default(),
271                         },
272                     }
273                 }
274                 ShaderStageInfo::Tessellation(ts_info) => {
275                     nak_shader_info__bindgen_ty_1 {
276                         ts: nak_shader_info__bindgen_ty_1__bindgen_ty_3 {
277                             domain: ts_info.domain as u8,
278                             spacing: ts_info.spacing as u8,
279                             prims: ts_info.primitives as u8,
280                             _pad: Default::default(),
281                         },
282                     }
283                 }
284                 _ => nak_shader_info__bindgen_ty_1 {
285                     _pad: Default::default(),
286                 },
287             },
288             vtg: match &info.io {
289                 ShaderIoInfo::Vtg(io) => nak_shader_info__bindgen_ty_2 {
290                     writes_layer: io.attr_written(NAK_ATTR_RT_ARRAY_INDEX),
291                     writes_point_size: io.attr_written(NAK_ATTR_POINT_SIZE),
292                     writes_vprs_table_index: io
293                         .attr_written(NAK_ATTR_VPRS_TABLE_INDEX),
294                     clip_enable: io.clip_enable.try_into().unwrap(),
295                     cull_enable: io.cull_enable.try_into().unwrap(),
296                     xfb: if let Some(xfb) = &io.xfb {
297                         **xfb
298                     } else {
299                         unsafe { std::mem::zeroed() }
300                     },
301                     _pad: Default::default(),
302                 },
303                 _ => unsafe { std::mem::zeroed() },
304             },
305             hdr: sph::encode_header(sm, &info, fs_key),
306         };
307 
308         if DEBUG.print() {
309             let stage_name = unsafe {
310                 let c_name = _mesa_shader_stage_to_string(c_info.stage as u32);
311                 CStr::from_ptr(c_name).to_str().expect("Invalid UTF-8")
312             };
313 
314             eprintln!("Stage: {}", stage_name);
315             eprintln!("Instruction count: {}", c_info.num_instrs);
316             eprintln!("Num GPRs: {}", c_info.num_gprs);
317             eprintln!("SLM size: {}", c_info.slm_size);
318 
319             if c_info.stage != MESA_SHADER_COMPUTE {
320                 eprint_hex("Header", &c_info.hdr);
321             }
322 
323             eprint_hex("Encoded shader", &code);
324         }
325 
326         let bin = nak_shader_bin {
327             info: c_info,
328             code_size: (code.len() * 4).try_into().unwrap(),
329             code: code.as_ptr() as *const c_void,
330             asm_str: if asm.is_empty() {
331                 std::ptr::null()
332             } else {
333                 asm.as_ptr()
334             },
335         };
336         ShaderBin {
337             bin: bin,
338             code: code,
339             asm: asm,
340         }
341     }
342 }
343 
344 impl std::ops::Deref for ShaderBin {
345     type Target = nak_shader_bin;
346 
deref(&self) -> &nak_shader_bin347     fn deref(&self) -> &nak_shader_bin {
348         &self.bin
349     }
350 }
351 
352 #[no_mangle]
nak_shader_bin_destroy(bin: *mut nak_shader_bin)353 pub extern "C" fn nak_shader_bin_destroy(bin: *mut nak_shader_bin) {
354     unsafe {
355         _ = Box::from_raw(bin as *mut ShaderBin);
356     };
357 }
358 
/// Prints `label:` followed by `data` as 8-digit hex words, eight per line,
/// to stderr.
///
/// The whole dump is formatted into one string and emitted with a single
/// `eprint!`, so it is written under one stderr lock. Output from shaders
/// being dumped on other threads therefore cannot land in the middle of it.
fn eprint_hex(label: &str, data: &[u32]) {
    let mut out = format!("{}:", label);
    for row in data.chunks(8) {
        out.push_str("\n ");
        for word in row {
            // Formatting into a `String` cannot fail.
            write!(out, " {:08x}", word).unwrap();
        }
    }
    out.push('\n');
    eprint!("{}", out);
}
370 
/// Runs the NAK IR pass method `$name` on shader `$shader`. When `print` is
/// set in `NAK_DEBUG`, it then dumps the resulting IR to stderr.
macro_rules! pass {
    ($shader: expr, $name: ident) => {
        $shader.$name();
        if DEBUG.print() {
            eprintln!(concat!("NAK IR after ", stringify!($name), ":\n{}"), $shader);
        }
    };
}
379 
nak_compile_shader_internal( nir: *mut nir_shader, dump_asm: bool, nak: *const nak_compiler, robust2_modes: nir_variable_mode, fs_key: *const nak_fs_key, ) -> *mut nak_shader_bin380 fn nak_compile_shader_internal(
381     nir: *mut nir_shader,
382     dump_asm: bool,
383     nak: *const nak_compiler,
384     robust2_modes: nir_variable_mode,
385     fs_key: *const nak_fs_key,
386 ) -> *mut nak_shader_bin {
387     unsafe { nak_postprocess_nir(nir, nak, robust2_modes, fs_key) };
388     let nak = unsafe { &*nak };
389     let nir = unsafe { &*nir };
390     let fs_key = if fs_key.is_null() {
391         None
392     } else {
393         Some(unsafe { &*fs_key })
394     };
395 
396     let sm: Box<dyn ShaderModel> = if nak.sm >= 70 {
397         Box::new(ShaderModel70::new(nak.sm))
398     } else if nak.sm >= 50 {
399         Box::new(ShaderModel50::new(nak.sm))
400     } else {
401         panic!("Unsupported shader model");
402     };
403 
404     let mut s = nak_shader_from_nir(nak, nir, sm.as_ref());
405 
406     if DEBUG.print() {
407         eprintln!("NAK IR:\n{}", &s);
408     }
409 
410     pass!(s, opt_bar_prop);
411     pass!(s, opt_uniform_instrs);
412     pass!(s, opt_copy_prop);
413     pass!(s, opt_prmt);
414     pass!(s, opt_lop);
415     pass!(s, opt_copy_prop);
416     pass!(s, opt_dce);
417     pass!(s, opt_out);
418     pass!(s, legalize);
419     pass!(s, assign_regs);
420     pass!(s, lower_par_copies);
421     pass!(s, lower_copy_swap);
422     if nak.sm >= 70 {
423         pass!(s, opt_jump_thread);
424     } else {
425         pass!(s, opt_crs);
426     }
427 
428     s.remove_annotations();
429 
430     pass!(s, calc_instr_deps);
431 
432     s.gather_info();
433 
434     let mut asm = String::new();
435     if dump_asm {
436         write!(asm, "{}", s).expect("Failed to dump assembly");
437     }
438 
439     let code = sm.encode_shader(&s);
440     let bin =
441         Box::new(ShaderBin::new(sm.as_ref(), &s.info, fs_key, code, &asm));
442     Box::into_raw(bin) as *mut nak_shader_bin
443 }
444 
445 #[no_mangle]
nak_compile_shader( nir: *mut nir_shader, dump_asm: bool, nak: *const nak_compiler, robust2_modes: nir_variable_mode, fs_key: *const nak_fs_key, ) -> *mut nak_shader_bin446 pub extern "C" fn nak_compile_shader(
447     nir: *mut nir_shader,
448     dump_asm: bool,
449     nak: *const nak_compiler,
450     robust2_modes: nir_variable_mode,
451     fs_key: *const nak_fs_key,
452 ) -> *mut nak_shader_bin {
453     panic::catch_unwind(|| {
454         nak_compile_shader_internal(nir, dump_asm, nak, robust2_modes, fs_key)
455     })
456     .unwrap_or(std::ptr::null_mut())
457 }
458