1 // Copyright 2022 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 use std::io;
6 use std::mem::MaybeUninit;
7 use std::sync::Once;
8 use std::thread::sleep;
9 use std::time::Duration;
10 use std::time::Instant;
11
12 use win_util::win32_string;
13 use win_util::win32_wide_string;
14 use winapi::shared::minwindef;
15 use winapi::shared::minwindef::HINSTANCE;
16 use winapi::shared::minwindef::HMODULE;
17 use winapi::shared::minwindef::PULONG;
18 use winapi::shared::ntdef::NTSTATUS;
19 use winapi::shared::ntdef::ULONG;
20 use winapi::shared::ntstatus::STATUS_SUCCESS;
21 use winapi::um::libloaderapi;
22 use winapi::um::mmsystem::TIMERR_NOERROR;
23 use winapi::um::timeapi::timeBeginPeriod;
24 use winapi::um::timeapi::timeEndPeriod;
25 use winapi::um::winnt::BOOLEAN;
26
27 use super::super::Error;
28 use super::super::Result;
29 use crate::warn;
30
31 static NT_INIT: Once = Once::new();
32 static mut NT_LIBRARY: MaybeUninit<HMODULE> = MaybeUninit::uninit();
33
34 #[inline]
init_ntdll() -> Result<HINSTANCE>35 fn init_ntdll() -> Result<HINSTANCE> {
36 NT_INIT.call_once(|| {
37 unsafe {
38 *NT_LIBRARY.as_mut_ptr() =
39 libloaderapi::LoadLibraryW(win32_wide_string("ntdll").as_ptr());
40
41 if NT_LIBRARY.assume_init().is_null() {
42 warn!("Failed to load ntdll: {}", Error::last());
43 }
44 };
45 });
46
47 let handle = unsafe { NT_LIBRARY.assume_init() };
48 if handle.is_null() {
49 Err(Error::from(io::Error::new(
50 io::ErrorKind::NotFound,
51 "ntdll failed to load",
52 )))
53 } else {
54 Ok(handle)
55 }
56 }
57
get_symbol(handle: HMODULE, proc_name: &str) -> Result<*mut minwindef::__some_function>58 fn get_symbol(handle: HMODULE, proc_name: &str) -> Result<*mut minwindef::__some_function> {
59 let symbol = unsafe { libloaderapi::GetProcAddress(handle, win32_string(proc_name).as_ptr()) };
60 if symbol.is_null() {
61 Err(Error::last())
62 } else {
63 Ok(symbol)
64 }
65 }
66
67 /// Returns the resolution of timers on the host (current_res, max_res).
nt_query_timer_resolution() -> Result<(Duration, Duration)>68 pub fn nt_query_timer_resolution() -> Result<(Duration, Duration)> {
69 let handle = init_ntdll()?;
70
71 let func = unsafe {
72 std::mem::transmute::<
73 *mut minwindef::__some_function,
74 extern "system" fn(PULONG, PULONG, PULONG) -> NTSTATUS,
75 >(get_symbol(handle, "NtQueryTimerResolution")?)
76 };
77
78 let mut min_res: u32 = 0;
79 let mut max_res: u32 = 0;
80 let mut current_res: u32 = 0;
81 let ret = func(
82 &mut min_res as *mut u32,
83 &mut max_res as *mut u32,
84 &mut current_res as *mut u32,
85 );
86
87 if ret != STATUS_SUCCESS {
88 Err(Error::from(io::Error::new(
89 io::ErrorKind::Other,
90 "NtQueryTimerResolution failed",
91 )))
92 } else {
93 Ok((
94 Duration::from_nanos((current_res as u64) * 100),
95 Duration::from_nanos((max_res as u64) * 100),
96 ))
97 }
98 }
99
nt_set_timer_resolution(resolution: Duration) -> Result<()>100 pub fn nt_set_timer_resolution(resolution: Duration) -> Result<()> {
101 let handle = init_ntdll()?;
102 let func = unsafe {
103 std::mem::transmute::<
104 *mut minwindef::__some_function,
105 extern "system" fn(ULONG, BOOLEAN, PULONG) -> NTSTATUS,
106 >(get_symbol(handle, "NtSetTimerResolution")?)
107 };
108
109 let requested_res: u32 = (resolution.as_nanos() / 100) as u32;
110 let mut current_res: u32 = 0;
111 let ret = func(
112 requested_res,
113 1, /* true */
114 &mut current_res as *mut u32,
115 );
116
117 if ret != STATUS_SUCCESS {
118 Err(Error::from(io::Error::new(
119 io::ErrorKind::Other,
120 "NtSetTimerResolution failed",
121 )))
122 } else {
123 Ok(())
124 }
125 }
126
/// Estimates the host's effective timer resolution.
///
/// Performs 100 nominal 1ms sleeps, records the wall-clock time each one actually took, and
/// returns the 90th-percentile sample (the 90th smallest once sorted).
pub fn measure_timer_resolution() -> Duration {
    const SAMPLES: usize = 100;
    // Zero-based index of the 90th-smallest sample.
    const P90_INDEX: usize = 89;

    let mut samples: Vec<Duration> = (0..SAMPLES)
        .map(|_| {
            let started = Instant::now();
            // Windows cannot support sleeps shorter than 1ms.
            sleep(Duration::from_millis(1));
            started.elapsed()
        })
        .collect();

    samples.sort_unstable();
    samples[P90_INDEX]
}
140
141 /// Note that Durations below 1ms are not supported and will panic.
set_time_period(res: Duration, begin: bool) -> Result<()>142 pub fn set_time_period(res: Duration, begin: bool) -> Result<()> {
143 if res.as_millis() < 1 {
144 panic!(
145 "time(Begin|End)Period does not support values below 1ms, but {:?} was requested.",
146 res
147 );
148 }
149 if res.as_millis() > u32::MAX as u128 {
150 panic!("time(Begin|End)Period does not support values above u32::MAX.",);
151 }
152
153 // Trivially safe. Note that the casts are safe because we know res is within u32's range.
154 let ret = if begin {
155 unsafe { timeBeginPeriod(res.as_millis() as u32) }
156 } else {
157 unsafe { timeEndPeriod(res.as_millis() as u32) }
158 };
159 if ret != TIMERR_NOERROR {
160 // These functions only have two return codes: NOERROR and NOCANDO.
161 Err(Error::from(io::Error::new(
162 io::ErrorKind::InvalidInput,
163 "timeBegin/EndPeriod failed",
164 )))
165 } else {
166 Ok(())
167 }
168 }
169
/// Note that these tests cannot run on Kokoro due to random slowness in that environment.
#[cfg(test)]
mod tests {
    use super::*;

    /// We're testing whether NtSetTimerResolution does what it says on the tin.
    #[test]
    #[ignore]
    fn setting_nt_timer_resolution_changes_resolution() {
        let (old_res, _) = nt_query_timer_resolution().unwrap();

        nt_set_timer_resolution(Duration::from_millis(1)).unwrap();
        let measured = measure_timer_resolution();
        // Restore before asserting: a failed assertion must not leave the 1ms request active,
        // or it can skew (or spuriously pass) other tests running in this process.
        nt_set_timer_resolution(old_res).unwrap();
        assert_res_within_bound(measured);
    }

    #[test]
    #[ignore]
    fn setting_timer_resolution_changes_resolution() {
        let res = Duration::from_millis(1);

        set_time_period(res, true).unwrap();
        let measured = measure_timer_resolution();
        // Release the period before asserting so a failure does not leak it to other tests.
        set_time_period(res, false).unwrap();
        assert_res_within_bound(measured);
    }

    /// Asserts that a measured resolution is at most 2ms.
    fn assert_res_within_bound(actual_res: Duration) {
        assert!(
            actual_res <= Duration::from_millis(2),
            "actual_res was {:?}, expected <= 2ms",
            actual_res
        );
    }
}
204