• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2022 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 use std::io;
6 use std::mem::MaybeUninit;
7 use std::sync::Once;
8 use std::thread::sleep;
9 use std::time::Duration;
10 use std::time::Instant;
11 
12 use win_util::win32_string;
13 use win_util::win32_wide_string;
14 use winapi::shared::minwindef;
15 use winapi::shared::minwindef::HINSTANCE;
16 use winapi::shared::minwindef::HMODULE;
17 use winapi::shared::minwindef::PULONG;
18 use winapi::shared::ntdef::NTSTATUS;
19 use winapi::shared::ntdef::ULONG;
20 use winapi::shared::ntstatus::STATUS_SUCCESS;
21 use winapi::um::libloaderapi;
22 use winapi::um::mmsystem::TIMERR_NOERROR;
23 use winapi::um::timeapi::timeBeginPeriod;
24 use winapi::um::timeapi::timeEndPeriod;
25 use winapi::um::winnt::BOOLEAN;
26 
27 use super::super::Error;
28 use super::super::Result;
29 use crate::warn;
30 
31 static NT_INIT: Once = Once::new();
32 static mut NT_LIBRARY: MaybeUninit<HMODULE> = MaybeUninit::uninit();
33 
34 #[inline]
init_ntdll() -> Result<HINSTANCE>35 fn init_ntdll() -> Result<HINSTANCE> {
36     NT_INIT.call_once(|| {
37         unsafe {
38             *NT_LIBRARY.as_mut_ptr() =
39                 libloaderapi::LoadLibraryW(win32_wide_string("ntdll").as_ptr());
40 
41             if NT_LIBRARY.assume_init().is_null() {
42                 warn!("Failed to load ntdll: {}", Error::last());
43             }
44         };
45     });
46 
47     let handle = unsafe { NT_LIBRARY.assume_init() };
48     if handle.is_null() {
49         Err(Error::from(io::Error::new(
50             io::ErrorKind::NotFound,
51             "ntdll failed to load",
52         )))
53     } else {
54         Ok(handle)
55     }
56 }
57 
get_symbol(handle: HMODULE, proc_name: &str) -> Result<*mut minwindef::__some_function>58 fn get_symbol(handle: HMODULE, proc_name: &str) -> Result<*mut minwindef::__some_function> {
59     let symbol = unsafe { libloaderapi::GetProcAddress(handle, win32_string(proc_name).as_ptr()) };
60     if symbol.is_null() {
61         Err(Error::last())
62     } else {
63         Ok(symbol)
64     }
65 }
66 
67 /// Returns the resolution of timers on the host (current_res, max_res).
nt_query_timer_resolution() -> Result<(Duration, Duration)>68 pub fn nt_query_timer_resolution() -> Result<(Duration, Duration)> {
69     let handle = init_ntdll()?;
70 
71     let func = unsafe {
72         std::mem::transmute::<
73             *mut minwindef::__some_function,
74             extern "system" fn(PULONG, PULONG, PULONG) -> NTSTATUS,
75         >(get_symbol(handle, "NtQueryTimerResolution")?)
76     };
77 
78     let mut min_res: u32 = 0;
79     let mut max_res: u32 = 0;
80     let mut current_res: u32 = 0;
81     let ret = func(
82         &mut min_res as *mut u32,
83         &mut max_res as *mut u32,
84         &mut current_res as *mut u32,
85     );
86 
87     if ret != STATUS_SUCCESS {
88         Err(Error::from(io::Error::new(
89             io::ErrorKind::Other,
90             "NtQueryTimerResolution failed",
91         )))
92     } else {
93         Ok((
94             Duration::from_nanos((current_res as u64) * 100),
95             Duration::from_nanos((max_res as u64) * 100),
96         ))
97     }
98 }
99 
nt_set_timer_resolution(resolution: Duration) -> Result<()>100 pub fn nt_set_timer_resolution(resolution: Duration) -> Result<()> {
101     let handle = init_ntdll()?;
102     let func = unsafe {
103         std::mem::transmute::<
104             *mut minwindef::__some_function,
105             extern "system" fn(ULONG, BOOLEAN, PULONG) -> NTSTATUS,
106         >(get_symbol(handle, "NtSetTimerResolution")?)
107     };
108 
109     let requested_res: u32 = (resolution.as_nanos() / 100) as u32;
110     let mut current_res: u32 = 0;
111     let ret = func(
112         requested_res,
113         1, /* true */
114         &mut current_res as *mut u32,
115     );
116 
117     if ret != STATUS_SUCCESS {
118         Err(Error::from(io::Error::new(
119             io::ErrorKind::Other,
120             "NtSetTimerResolution failed",
121         )))
122     } else {
123         Ok(())
124     }
125 }
126 
/// Estimates the effective timer resolution empirically.
///
/// Times 100 requested 1ms sleeps and returns the 90th-percentile wall-clock duration
/// (the 90th smallest sample).
pub fn measure_timer_resolution() -> Duration {
    const SAMPLE_COUNT: usize = 100;
    // Zero-based index of the 90th-percentile sample once the samples are sorted.
    const P90_INDEX: usize = 89;

    let mut samples: Vec<Duration> = (0..SAMPLE_COUNT)
        .map(|_| {
            let began = Instant::now();
            // Windows cannot support sleeps shorter than 1ms.
            sleep(Duration::from_millis(1));
            began.elapsed()
        })
        .collect();

    samples.sort_unstable();
    samples[P90_INDEX]
}
140 
141 /// Note that Durations below 1ms are not supported and will panic.
set_time_period(res: Duration, begin: bool) -> Result<()>142 pub fn set_time_period(res: Duration, begin: bool) -> Result<()> {
143     if res.as_millis() < 1 {
144         panic!(
145             "time(Begin|End)Period does not support values below 1ms, but {:?} was requested.",
146             res
147         );
148     }
149     if res.as_millis() > u32::MAX as u128 {
150         panic!("time(Begin|End)Period does not support values above u32::MAX.",);
151     }
152 
153     // Trivially safe. Note that the casts are safe because we know res is within u32's range.
154     let ret = if begin {
155         unsafe { timeBeginPeriod(res.as_millis() as u32) }
156     } else {
157         unsafe { timeEndPeriod(res.as_millis() as u32) }
158     };
159     if ret != TIMERR_NOERROR {
160         // These functions only have two return codes: NOERROR and NOCANDO.
161         Err(Error::from(io::Error::new(
162             io::ErrorKind::InvalidInput,
163             "timeBegin/EndPeriod failed",
164         )))
165     } else {
166         Ok(())
167     }
168 }
169 
/// Note that these tests cannot run on Kokoro due to random slowness in that environment.
#[cfg(test)]
mod tests {
    use super::*;

    /// Largest measured resolution accepted once a 1ms resolution has been requested.
    const MAX_ACCEPTABLE_RES: Duration = Duration::from_millis(2);

    /// Fails the current test if `actual_res` exceeds [`MAX_ACCEPTABLE_RES`].
    fn assert_res_within_bound(actual_res: Duration) {
        assert!(
            actual_res <= MAX_ACCEPTABLE_RES,
            "actual_res was {:?}, expected <= 2ms",
            actual_res
        );
    }

    /// We're testing whether NtSetTimerResolution does what it says on the tin.
    #[test]
    #[ignore]
    fn setting_nt_timer_resolution_changes_resolution() {
        let (previous_res, _) = nt_query_timer_resolution().unwrap();

        nt_set_timer_resolution(Duration::from_millis(1)).unwrap();
        let measured = measure_timer_resolution();
        assert_res_within_bound(measured);

        // Put the original resolution back once the check has passed.
        nt_set_timer_resolution(previous_res).unwrap();
    }

    /// Checks that timeBeginPeriod(1ms) brings the measured resolution within bound.
    #[test]
    #[ignore]
    fn setting_timer_resolution_changes_resolution() {
        let one_ms = Duration::from_millis(1);

        set_time_period(one_ms, true).unwrap();
        let measured = measure_timer_resolution();
        assert_res_within_bound(measured);

        // Every timeBeginPeriod must be paired with a matching timeEndPeriod.
        set_time_period(one_ms, false).unwrap();
    }
}
204