• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Test the internal _wmi module on Windows
2# This is used by the platform module, and potentially others
3
4import time
5import unittest
6from test.support import import_helper, requires_resource, LOOPBACK_TIMEOUT
7
8
# Import _wmi before anything else so that, when the module is unavailable,
# the whole test module is skipped instead of failing with ImportError.
# required_on=['win'] means a missing _wmi is presumably treated as a real
# error on Windows (where it should always exist) rather than a skip —
# NOTE(review): confirm against test.support.import_helper.import_module.
_wmi = import_helper.import_module('_wmi', required_on=['win'])
11
12
# Win32 system error code WAIT_TIMEOUT; presumably what _wmi reports when a
# query does not complete in time.
_WAIT_TIMEOUT = 258


def wmi_exec_query(query):
    """Run *query* via _wmi.exec_query, retrying once after a delay.

    gh-112278: WMI may respond slowly to the first call. If the first attempt
    raises BrokenPipeError, or an OSError whose winerror is WAIT_TIMEOUT
    (258), sleep for LOOPBACK_TIMEOUT seconds and try exactly once more.
    Any other OSError from the first attempt is re-raised immediately; any
    error from the second attempt propagates to the caller unchanged.
    Exceptions that are not OSError (e.g. ValueError) are never retried.
    """
    try:
        return _wmi.exec_query(query)
    except BrokenPipeError:
        # Must be handled before OSError: BrokenPipeError is an OSError
        # subclass and should be retried regardless of its winerror.
        pass
    except OSError as e:
        # OSError is the canonical name; WindowsError is only a Windows-only
        # alias of it, so this catches exactly the same exceptions.
        if e.winerror != _WAIT_TIMEOUT:
            raise
    time.sleep(LOOPBACK_TIMEOUT)
    return _wmi.exec_query(query)
24
25
class WmiTests(unittest.TestCase):
    # Exercises _wmi.exec_query through the wmi_exec_query() retry helper.
    # Results are "Name=Value" records joined by "\0"; separate instances are
    # divided by an extra "\0".

    def test_wmi_query_os_version(self):
        # A single-instance, single-column query yields exactly one record.
        r = wmi_exec_query("SELECT Version FROM Win32_OperatingSystem").split("\0")
        self.assertEqual(1, len(r))
        k, eq, v = r[0].partition("=")
        self.assertEqual("=", eq, r[0])
        self.assertEqual("Version", k, r[0])
        # Best we can check for the version is that it's digits, dot, digits, anything
        # Otherwise, we are likely checking the result of the query against itself
        self.assertRegex(v, r"\d+\.\d+.+$", r[0])

    def test_wmi_query_repeated(self):
        # Repeated queries should not break
        for _ in range(10):
            self.test_wmi_query_os_version()

    def test_wmi_query_error(self):
        # Invalid queries fail with OSError carrying WBEM_E_INVALID_CLASS.
        # The 32-bit mask normalizes winerror in case it arrives as a signed
        # (negative) value. Any other code, or no exception at all, fails the
        # test with a message that shows what actually happened.
        WBEM_E_INVALID_CLASS = 0x80041010
        with self.assertRaises(OSError) as cm:
            wmi_exec_query("SELECT InvalidColumnName FROM InvalidTableName")
        winerror = cm.exception.winerror
        self.assertIsNotNone(winerror, cm.exception)
        self.assertEqual(winerror & 0xFFFFFFFF, WBEM_E_INVALID_CLASS, cm.exception)

    def test_wmi_query_repeated_error(self):
        # Repeated failing queries should keep failing the same way
        for _ in range(10):
            self.test_wmi_query_error()

    def test_wmi_query_not_select(self):
        # Queries other than SELECT are blocked to avoid potential exploits
        with self.assertRaises(ValueError):
            wmi_exec_query("not select, just in case someone tries something")

    @requires_resource('cpu')
    def test_wmi_query_overflow(self):
        # Ensure very big queries fail
        # Test multiple times to ensure consistency
        for _ in range(2):
            with self.assertRaises(OSError):
                wmi_exec_query("SELECT * FROM CIM_DataFile")

    def test_wmi_query_multiple_rows(self):
        # Multiple instances should have an extra null separator
        r = wmi_exec_query("SELECT ProcessId FROM Win32_Process WHERE ProcessId < 1000")
        self.assertFalse(r.startswith("\0"), r)
        self.assertFalse(r.endswith("\0"), r)
        fields = r.split("\0")
        # Splitting "A\0\0B" gives ["A", "", "B"]: records sit at even
        # indices, and the doubled separator leaves "" at every odd index.
        for record in fields[::2]:
            self.assertRegex(record, r"ProcessId=\d+")
        for separator in fields[1::2]:
            self.assertEqual("", separator)

    def test_wmi_query_threads(self):
        # Concurrent queries from several threads must all succeed
        from concurrent.futures import ThreadPoolExecutor
        query = "SELECT ProcessId FROM Win32_Process WHERE ProcessId < 1000"
        with ThreadPoolExecutor(4) as pool:
            tasks = [pool.submit(wmi_exec_query, query) for _ in range(32)]
            for t in tasks:
                self.assertRegex(t.result(), "ProcessId=")
89