1 // Copyright 2022 The Chromium OS Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 use std::{
6 io,
7 mem::MaybeUninit,
8 sync::Once,
9 time::{Duration, Instant},
10 };
11
12 use winapi::shared::minwindef::{
13 HINSTANCE, HMODULE, PULONG, {self},
14 };
15
16 use winapi::{
17 shared::{
18 ntdef::{NTSTATUS, ULONG},
19 ntstatus::STATUS_SUCCESS,
20 },
21 um::{libloaderapi, mmsystem::TIMERR_NOERROR},
22 };
23
24 use std::thread::sleep;
25 use win_util::{win32_string, win32_wide_string};
26 use winapi::um::{
27 timeapi::{timeBeginPeriod, timeEndPeriod},
28 winnt::BOOLEAN,
29 };
30
31 use super::super::{Error, Result};
32 use crate::warn;
33
34 static NT_INIT: Once = Once::new();
35 static mut NT_LIBRARY: MaybeUninit<HMODULE> = MaybeUninit::uninit();
36
37 #[inline]
init_ntdll() -> Result<HINSTANCE>38 fn init_ntdll() -> Result<HINSTANCE> {
39 NT_INIT.call_once(|| {
40 unsafe {
41 *NT_LIBRARY.as_mut_ptr() =
42 libloaderapi::LoadLibraryW(win32_wide_string("ntdll").as_ptr());
43
44 if NT_LIBRARY.assume_init().is_null() {
45 warn!("Failed to load ntdll: {}", Error::last());
46 }
47 };
48 });
49
50 let handle = unsafe { NT_LIBRARY.assume_init() };
51 if handle.is_null() {
52 Err(Error::from(io::Error::new(
53 io::ErrorKind::NotFound,
54 "ntdll failed to load",
55 )))
56 } else {
57 Ok(handle)
58 }
59 }
60
get_symbol(handle: HMODULE, proc_name: &str) -> Result<*mut minwindef::__some_function>61 fn get_symbol(handle: HMODULE, proc_name: &str) -> Result<*mut minwindef::__some_function> {
62 let symbol = unsafe { libloaderapi::GetProcAddress(handle, win32_string(proc_name).as_ptr()) };
63 if symbol.is_null() {
64 Err(Error::last())
65 } else {
66 Ok(symbol)
67 }
68 }
69
70 /// Returns the resolution of timers on the host (current_res, max_res).
nt_query_timer_resolution() -> Result<(Duration, Duration)>71 pub fn nt_query_timer_resolution() -> Result<(Duration, Duration)> {
72 let handle = init_ntdll()?;
73
74 let func = unsafe {
75 std::mem::transmute::<
76 *mut minwindef::__some_function,
77 extern "system" fn(PULONG, PULONG, PULONG) -> NTSTATUS,
78 >(get_symbol(handle, "NtQueryTimerResolution")?)
79 };
80
81 let mut min_res: u32 = 0;
82 let mut max_res: u32 = 0;
83 let mut current_res: u32 = 0;
84 let ret = func(
85 &mut min_res as *mut u32,
86 &mut max_res as *mut u32,
87 &mut current_res as *mut u32,
88 );
89
90 if ret != STATUS_SUCCESS {
91 Err(Error::from(io::Error::new(
92 io::ErrorKind::Other,
93 "NtQueryTimerResolution failed",
94 )))
95 } else {
96 Ok((
97 Duration::from_nanos((current_res as u64) * 100),
98 Duration::from_nanos((max_res as u64) * 100),
99 ))
100 }
101 }
102
nt_set_timer_resolution(resolution: Duration) -> Result<()>103 pub fn nt_set_timer_resolution(resolution: Duration) -> Result<()> {
104 let handle = init_ntdll()?;
105 let func = unsafe {
106 std::mem::transmute::<
107 *mut minwindef::__some_function,
108 extern "system" fn(ULONG, BOOLEAN, PULONG) -> NTSTATUS,
109 >(get_symbol(handle, "NtSetTimerResolution")?)
110 };
111
112 let requested_res: u32 = (resolution.as_nanos() / 100) as u32;
113 let mut current_res: u32 = 0;
114 let ret = func(
115 requested_res,
116 1, /* true */
117 &mut current_res as *mut u32,
118 );
119
120 if ret != STATUS_SUCCESS {
121 Err(Error::from(io::Error::new(
122 io::ErrorKind::Other,
123 "NtSetTimerResolution failed",
124 )))
125 } else {
126 Ok(())
127 }
128 }
129
/// Empirically estimates the effective timer resolution.
///
/// Performs one hundred 1ms sleeps and records the wall-clock time each one actually took.
/// Returns the 90th-percentile sample (index 89 once the 100 samples are sorted).
pub fn measure_timer_resolution() -> Duration {
    const SAMPLE_COUNT: usize = 100;

    let mut samples: Vec<Duration> = (0..SAMPLE_COUNT)
        .map(|_| {
            let begin = Instant::now();
            // Windows cannot support sleeps shorter than 1ms.
            sleep(Duration::from_millis(1));
            begin.elapsed()
        })
        .collect();

    samples.sort_unstable();
    samples[SAMPLE_COUNT * 9 / 10 - 1]
}
143
144 /// Note that Durations below 1ms are not supported and will panic.
set_time_period(res: Duration, begin: bool) -> Result<()>145 pub fn set_time_period(res: Duration, begin: bool) -> Result<()> {
146 if res.as_millis() < 1 {
147 panic!(
148 "time(Begin|End)Period does not support values below 1ms, but {:?} was requested.",
149 res
150 );
151 }
152 if res.as_millis() > u32::MAX as u128 {
153 panic!("time(Begin|End)Period does not support values above u32::MAX.",);
154 }
155
156 // Trivially safe. Note that the casts are safe because we know res is within u32's range.
157 let ret = if begin {
158 unsafe { timeBeginPeriod(res.as_millis() as u32) }
159 } else {
160 unsafe { timeEndPeriod(res.as_millis() as u32) }
161 };
162 if ret != TIMERR_NOERROR {
163 // These functions only have two return codes: NOERROR and NOCANDO.
164 Err(Error::from(io::Error::new(
165 io::ErrorKind::InvalidInput,
166 "timeBegin/EndPeriod failed",
167 )))
168 } else {
169 Ok(())
170 }
171 }
172
/// These tests are `#[ignore]`d: they cannot run on Kokoro due to random slowness in that
/// environment, which makes the timing measurements unreliable.
#[cfg(test)]
mod tests {
    use super::*;

    /// Largest measured sleep granularity accepted after requesting a 1ms resolution.
    const MAX_ACCEPTABLE_RES: Duration = Duration::from_millis(2);

    /// Checks that requesting 1ms through `NtSetTimerResolution` actually takes effect,
    /// then restores the previously reported current resolution.
    #[test]
    #[ignore]
    fn setting_nt_timer_resolution_changes_resolution() {
        let (previous, _) = nt_query_timer_resolution().unwrap();

        nt_set_timer_resolution(Duration::from_millis(1)).unwrap();
        let observed = measure_timer_resolution();
        assert_res_within_bound(observed);

        nt_set_timer_resolution(previous).unwrap();
    }

    /// Checks that `timeBeginPeriod(1)` takes effect, then releases it with `timeEndPeriod`.
    #[test]
    #[ignore]
    fn setting_timer_resolution_changes_resolution() {
        let one_ms = Duration::from_millis(1);

        set_time_period(one_ms, true).unwrap();
        let observed = measure_timer_resolution();
        assert_res_within_bound(observed);

        set_time_period(one_ms, false).unwrap();
    }

    /// Fails the test if `actual_res` exceeds `MAX_ACCEPTABLE_RES`.
    fn assert_res_within_bound(actual_res: Duration) {
        assert!(
            actual_res <= MAX_ACCEPTABLE_RES,
            "actual_res was {:?}, expected <= 2ms",
            actual_res
        );
    }
}
207