• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2022 The Chromium OS Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 use std::{
6     io,
7     mem::MaybeUninit,
8     sync::Once,
9     time::{Duration, Instant},
10 };
11 
12 use winapi::shared::minwindef::{
13     HINSTANCE, HMODULE, PULONG, {self},
14 };
15 
16 use winapi::{
17     shared::{
18         ntdef::{NTSTATUS, ULONG},
19         ntstatus::STATUS_SUCCESS,
20     },
21     um::{libloaderapi, mmsystem::TIMERR_NOERROR},
22 };
23 
24 use std::thread::sleep;
25 use win_util::{win32_string, win32_wide_string};
26 use winapi::um::{
27     timeapi::{timeBeginPeriod, timeEndPeriod},
28     winnt::BOOLEAN,
29 };
30 
31 use super::super::{Error, Result};
32 use crate::warn;
33 
34 static NT_INIT: Once = Once::new();
35 static mut NT_LIBRARY: MaybeUninit<HMODULE> = MaybeUninit::uninit();
36 
37 #[inline]
init_ntdll() -> Result<HINSTANCE>38 fn init_ntdll() -> Result<HINSTANCE> {
39     NT_INIT.call_once(|| {
40         unsafe {
41             *NT_LIBRARY.as_mut_ptr() =
42                 libloaderapi::LoadLibraryW(win32_wide_string("ntdll").as_ptr());
43 
44             if NT_LIBRARY.assume_init().is_null() {
45                 warn!("Failed to load ntdll: {}", Error::last());
46             }
47         };
48     });
49 
50     let handle = unsafe { NT_LIBRARY.assume_init() };
51     if handle.is_null() {
52         Err(Error::from(io::Error::new(
53             io::ErrorKind::NotFound,
54             "ntdll failed to load",
55         )))
56     } else {
57         Ok(handle)
58     }
59 }
60 
get_symbol(handle: HMODULE, proc_name: &str) -> Result<*mut minwindef::__some_function>61 fn get_symbol(handle: HMODULE, proc_name: &str) -> Result<*mut minwindef::__some_function> {
62     let symbol = unsafe { libloaderapi::GetProcAddress(handle, win32_string(proc_name).as_ptr()) };
63     if symbol.is_null() {
64         Err(Error::last())
65     } else {
66         Ok(symbol)
67     }
68 }
69 
70 /// Returns the resolution of timers on the host (current_res, max_res).
nt_query_timer_resolution() -> Result<(Duration, Duration)>71 pub fn nt_query_timer_resolution() -> Result<(Duration, Duration)> {
72     let handle = init_ntdll()?;
73 
74     let func = unsafe {
75         std::mem::transmute::<
76             *mut minwindef::__some_function,
77             extern "system" fn(PULONG, PULONG, PULONG) -> NTSTATUS,
78         >(get_symbol(handle, "NtQueryTimerResolution")?)
79     };
80 
81     let mut min_res: u32 = 0;
82     let mut max_res: u32 = 0;
83     let mut current_res: u32 = 0;
84     let ret = func(
85         &mut min_res as *mut u32,
86         &mut max_res as *mut u32,
87         &mut current_res as *mut u32,
88     );
89 
90     if ret != STATUS_SUCCESS {
91         Err(Error::from(io::Error::new(
92             io::ErrorKind::Other,
93             "NtQueryTimerResolution failed",
94         )))
95     } else {
96         Ok((
97             Duration::from_nanos((current_res as u64) * 100),
98             Duration::from_nanos((max_res as u64) * 100),
99         ))
100     }
101 }
102 
nt_set_timer_resolution(resolution: Duration) -> Result<()>103 pub fn nt_set_timer_resolution(resolution: Duration) -> Result<()> {
104     let handle = init_ntdll()?;
105     let func = unsafe {
106         std::mem::transmute::<
107             *mut minwindef::__some_function,
108             extern "system" fn(ULONG, BOOLEAN, PULONG) -> NTSTATUS,
109         >(get_symbol(handle, "NtSetTimerResolution")?)
110     };
111 
112     let requested_res: u32 = (resolution.as_nanos() / 100) as u32;
113     let mut current_res: u32 = 0;
114     let ret = func(
115         requested_res,
116         1, /* true */
117         &mut current_res as *mut u32,
118     );
119 
120     if ret != STATUS_SUCCESS {
121         Err(Error::from(io::Error::new(
122             io::ErrorKind::Other,
123             "NtSetTimerResolution failed",
124         )))
125     } else {
126         Ok(())
127     }
128 }
129 
/// Estimates the effective timer resolution from 100 timed sleeps of 1ms each.
///
/// Returns the 90th smallest of the 100 observed wall-clock durations (the 90th percentile).
pub fn measure_timer_resolution() -> Duration {
    let mut samples: Vec<Duration> = (0..100)
        .map(|_| {
            let begin = Instant::now();
            // Windows cannot support sleeps shorter than 1ms.
            sleep(Duration::from_millis(1));
            begin.elapsed()
        })
        .collect();

    samples.sort();
    // Index 89 of 100 sorted samples is the 90th percentile.
    samples[89]
}
143 
144 /// Note that Durations below 1ms are not supported and will panic.
set_time_period(res: Duration, begin: bool) -> Result<()>145 pub fn set_time_period(res: Duration, begin: bool) -> Result<()> {
146     if res.as_millis() < 1 {
147         panic!(
148             "time(Begin|End)Period does not support values below 1ms, but {:?} was requested.",
149             res
150         );
151     }
152     if res.as_millis() > u32::MAX as u128 {
153         panic!("time(Begin|End)Period does not support values above u32::MAX.",);
154     }
155 
156     // Trivially safe. Note that the casts are safe because we know res is within u32's range.
157     let ret = if begin {
158         unsafe { timeBeginPeriod(res.as_millis() as u32) }
159     } else {
160         unsafe { timeEndPeriod(res.as_millis() as u32) }
161     };
162     if ret != TIMERR_NOERROR {
163         // These functions only have two return codes: NOERROR and NOCANDO.
164         Err(Error::from(io::Error::new(
165             io::ErrorKind::InvalidInput,
166             "timeBegin/EndPeriod failed",
167         )))
168     } else {
169         Ok(())
170     }
171 }
172 
/// Note that these tests cannot run on Kokoro due to random slowness in that environment.
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that a measured timer resolution is no coarser than 2ms.
    fn assert_res_within_bound(actual_res: Duration) {
        let upper_bound = Duration::from_millis(2);
        assert!(
            actual_res <= upper_bound,
            "actual_res was {:?}, expected <= 2ms",
            actual_res
        );
    }

    /// We're testing whether NtSetTimerResolution does what it says on the tin.
    #[test]
    #[ignore]
    fn setting_nt_timer_resolution_changes_resolution() {
        // Remember the current resolution so it can be put back afterwards.
        let (previous_res, _max_res) = nt_query_timer_resolution().unwrap();

        nt_set_timer_resolution(Duration::from_millis(1)).unwrap();
        let measured = measure_timer_resolution();
        assert_res_within_bound(measured);
        nt_set_timer_resolution(previous_res).unwrap();
    }

    /// Checks that a 1ms period requested via `set_time_period` is reflected in measured sleeps.
    #[test]
    #[ignore]
    fn setting_timer_resolution_changes_resolution() {
        let one_ms = Duration::from_millis(1);

        set_time_period(one_ms, true).unwrap();
        let measured = measure_timer_resolution();
        assert_res_within_bound(measured);
        set_time_period(one_ms, false).unwrap();
    }
}
207