• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2021 The Khronos Group Inc.
3  * Copyright (c) 2021 Valve Corporation
4  * Copyright (c) 2021 LunarG, Inc.
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and/or associated documentation files (the "Materials"), to
8  * deal in the Materials without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Materials, and to permit persons to whom the Materials are
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice(s) and this permission notice shall be included in
14  * all copies or substantial portions of the Materials.
15  *
16  * THE MATERIALS ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
19  *
20  * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
21  * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
22  * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE MATERIALS OR THE
23  * USE OR OTHER DEALINGS IN THE MATERIALS.
24  *
25  * Author: Charles Giessen <charles@lunarg.com>
26  */
27 
28 // This needs to be defined first, or else we'll get redefinitions on NTSTATUS values
29 #ifdef _WIN32
30 #define UMDF_USING_NTSTATUS
31 #include <ntstatus.h>
32 #endif
33 
34 #include "shim.h"
35 
36 #include "detours.h"
37 
38 static PlatformShim platform_shim;
39 
40 extern "C" {
41 
42 static LibraryWrapper gdi32_dll;
43 
44 using PFN_GetSidSubAuthority = PDWORD(__stdcall *)(PSID pSid, DWORD nSubAuthority);
45 static PFN_GetSidSubAuthority fpGetSidSubAuthority = GetSidSubAuthority;
46 
ShimGetSidSubAuthority(PSID pSid,DWORD nSubAuthority)47 PDWORD __stdcall ShimGetSidSubAuthority(PSID pSid, DWORD nSubAuthority) { return &platform_shim.elevation_level; }
48 
49 static PFN_LoaderEnumAdapters2 fpEnumAdapters2 = nullptr;
50 static PFN_LoaderQueryAdapterInfo fpQueryAdapterInfo = nullptr;
51 
ShimEnumAdapters2(LoaderEnumAdapters2 * adapters)52 NTSTATUS APIENTRY ShimEnumAdapters2(LoaderEnumAdapters2 *adapters) {
53     if (adapters == nullptr) {
54         return STATUS_INVALID_PARAMETER;
55     }
56     if (platform_shim.d3dkmt_adapters.size() == 0) {
57         if (adapters->adapters != nullptr) adapters->adapter_count = 0;
58         return STATUS_SUCCESS;
59     }
60     if (adapters->adapters != nullptr) {
61         for (size_t i = 0; i < platform_shim.d3dkmt_adapters.size(); i++) {
62             adapters->adapters[i].handle = platform_shim.d3dkmt_adapters[i].hAdapter;
63             adapters->adapters[i].luid = platform_shim.d3dkmt_adapters[i].adapter_luid;
64         }
65         adapters->adapter_count = static_cast<ULONG>(platform_shim.d3dkmt_adapters.size());
66     } else {
67         adapters->adapter_count = static_cast<ULONG>(platform_shim.d3dkmt_adapters.size());
68     }
69     return STATUS_SUCCESS;
70 }
ShimQueryAdapterInfo(const LoaderQueryAdapterInfo * query_info)71 NTSTATUS APIENTRY ShimQueryAdapterInfo(const LoaderQueryAdapterInfo *query_info) {
72     if (query_info == nullptr || query_info->private_data == nullptr) {
73         return STATUS_INVALID_PARAMETER;
74     }
75     auto handle = query_info->handle;
76     auto it = std::find_if(platform_shim.d3dkmt_adapters.begin(), platform_shim.d3dkmt_adapters.end(),
77                            [handle](D3DKMT_Adapter const &adapter) { return handle == adapter.hAdapter; });
78     if (it == platform_shim.d3dkmt_adapters.end()) {
79         return STATUS_INVALID_PARAMETER;
80     }
81     auto &adapter = *it;
82     auto *reg_info = reinterpret_cast<LoaderQueryRegistryInfo *>(query_info->private_data);
83 
84     std::vector<std::wstring> *paths = nullptr;
85     if (reg_info->value_name[6] == L'D') {  // looking for drivers
86         paths = &adapter.driver_paths;
87     } else if (reg_info->value_name[6] == L'I') {  // looking for implicit layers
88         paths = &adapter.implicit_layer_paths;
89     } else if (reg_info->value_name[6] == L'E') {  // looking for explicit layers
90         paths = &adapter.explicit_layer_paths;
91     }
92 
93     reg_info->status = LOADER_QUERY_REGISTRY_STATUS_SUCCESS;
94     if (reg_info->output_value_size == 0) {
95         ULONG size = 2;  // final null terminator
96         for (auto const &path : *paths) size = static_cast<ULONG>(path.length() * sizeof(wchar_t));
97         // size in bytes, so multiply path size by two and add 2 for the null terminator
98         reg_info->output_value_size = size;
99         if (size != 2) {
100             // only want to write data if there is path data to write
101             reg_info->status = LOADER_QUERY_REGISTRY_STATUS_BUFFER_OVERFLOW;
102         }
103     } else if (reg_info->output_value_size > 2) {
104         size_t index = 0;
105         for (auto const &path : *paths) {
106             for (auto w : path) {
107                 reg_info->output_string[index++] = w;
108             }
109             reg_info->output_string[index++] = L'\0';
110         }
111         // make sure there is a null terminator
112         reg_info->output_string[index++] = L'\0';
113 
114         reg_info->status = LOADER_QUERY_REGISTRY_STATUS_SUCCESS;
115     }
116 
117     return STATUS_SUCCESS;
118 }
119 
120 // clang-format off
121 static CONFIGRET(WINAPI *REAL_CM_Get_Device_ID_List_SizeW)(PULONG pulLen, PCWSTR pszFilter, ULONG ulFlags) = CM_Get_Device_ID_List_SizeW;
122 static CONFIGRET(WINAPI *REAL_CM_Get_Device_ID_ListW)(PCWSTR pszFilter, PZZWSTR Buffer, ULONG BufferLen, ULONG ulFlags) = CM_Get_Device_ID_ListW;
123 static CONFIGRET(WINAPI *REAL_CM_Locate_DevNodeW)(PDEVINST pdnDevInst, DEVINSTID_W pDeviceID, ULONG ulFlags) =  CM_Locate_DevNodeW;
124 static CONFIGRET(WINAPI *REAL_CM_Get_DevNode_Status)(PULONG pulStatus, PULONG pulProblemNumber, DEVINST dnDevInst, ULONG ulFlags) =  CM_Get_DevNode_Status;
125 static CONFIGRET(WINAPI *REAL_CM_Get_Device_IDW)(DEVINST dnDevInst, PWSTR Buffer, ULONG BufferLen, ULONG ulFlags) =  CM_Get_Device_IDW;
126 static CONFIGRET(WINAPI *REAL_CM_Get_Child)(PDEVINST pdnDevInst, DEVINST dnDevInst, ULONG ulFlags) =  CM_Get_Child;
127 static CONFIGRET(WINAPI *REAL_CM_Get_DevNode_Registry_PropertyW)(DEVINST dnDevInst, ULONG ulProperty, PULONG pulRegDataType, PVOID Buffer, PULONG pulLength, ULONG ulFlags) =  CM_Get_DevNode_Registry_PropertyW;
128 static CONFIGRET(WINAPI *REAL_CM_Get_Sibling)(PDEVINST pdnDevInst, DEVINST dnDevInst, ULONG ulFlags) = CM_Get_Sibling;
129 // clang-format on
130 
SHIM_CM_Get_Device_ID_List_SizeW(PULONG pulLen,PCWSTR pszFilter,ULONG ulFlags)131 CONFIGRET WINAPI SHIM_CM_Get_Device_ID_List_SizeW(PULONG pulLen, PCWSTR pszFilter, ULONG ulFlags) {
132     if (pulLen == nullptr) {
133         return CR_INVALID_POINTER;
134     }
135     *pulLen = static_cast<ULONG>(platform_shim.CM_device_ID_list.size());
136     return CR_SUCCESS;
137 }
SHIM_CM_Get_Device_ID_ListW(PCWSTR pszFilter,PZZWSTR Buffer,ULONG BufferLen,ULONG ulFlags)138 CONFIGRET WINAPI SHIM_CM_Get_Device_ID_ListW(PCWSTR pszFilter, PZZWSTR Buffer, ULONG BufferLen, ULONG ulFlags) {
139     if (Buffer != NULL) {
140         if (BufferLen < platform_shim.CM_device_ID_list.size()) return CR_BUFFER_SMALL;
141         for (size_t i = 0; i < BufferLen; i++) {
142             Buffer[i] = platform_shim.CM_device_ID_list[i];
143         }
144     }
145     return CR_SUCCESS;
146 }
147 // TODO
SHIM_CM_Locate_DevNodeW(PDEVINST pdnDevInst,DEVINSTID_W pDeviceID,ULONG ulFlags)148 CONFIGRET WINAPI SHIM_CM_Locate_DevNodeW(PDEVINST pdnDevInst, DEVINSTID_W pDeviceID, ULONG ulFlags) { return CR_FAILURE; }
149 // TODO
SHIM_CM_Get_DevNode_Status(PULONG pulStatus,PULONG pulProblemNumber,DEVINST dnDevInst,ULONG ulFlags)150 CONFIGRET WINAPI SHIM_CM_Get_DevNode_Status(PULONG pulStatus, PULONG pulProblemNumber, DEVINST dnDevInst, ULONG ulFlags) {
151     return CR_FAILURE;
152 }
153 // TODO
SHIM_CM_Get_Device_IDW(DEVINST dnDevInst,PWSTR Buffer,ULONG BufferLen,ULONG ulFlags)154 CONFIGRET WINAPI SHIM_CM_Get_Device_IDW(DEVINST dnDevInst, PWSTR Buffer, ULONG BufferLen, ULONG ulFlags) { return CR_FAILURE; }
155 // TODO
SHIM_CM_Get_Child(PDEVINST pdnDevInst,DEVINST dnDevInst,ULONG ulFlags)156 CONFIGRET WINAPI SHIM_CM_Get_Child(PDEVINST pdnDevInst, DEVINST dnDevInst, ULONG ulFlags) { return CR_FAILURE; }
157 // TODO
SHIM_CM_Get_DevNode_Registry_PropertyW(DEVINST dnDevInst,ULONG ulProperty,PULONG pulRegDataType,PVOID Buffer,PULONG pulLength,ULONG ulFlags)158 CONFIGRET WINAPI SHIM_CM_Get_DevNode_Registry_PropertyW(DEVINST dnDevInst, ULONG ulProperty, PULONG pulRegDataType, PVOID Buffer,
159                                                         PULONG pulLength, ULONG ulFlags) {
160     return CR_FAILURE;
161 }
162 // TODO
SHIM_CM_Get_Sibling(PDEVINST pdnDevInst,DEVINST dnDevInst,ULONG ulFlags)163 CONFIGRET WINAPI SHIM_CM_Get_Sibling(PDEVINST pdnDevInst, DEVINST dnDevInst, ULONG ulFlags) { return CR_FAILURE; }
164 
165 static LibraryWrapper dxgi_module;
166 typedef HRESULT(APIENTRY *PFN_CreateDXGIFactory1)(REFIID riid, void **ppFactory);
167 
168 PFN_CreateDXGIFactory1 RealCreateDXGIFactory1;
169 
ShimGetDesc1(IDXGIAdapter1 * pAdapter,_Out_ DXGI_ADAPTER_DESC1 * pDesc)170 HRESULT __stdcall ShimGetDesc1(IDXGIAdapter1 *pAdapter,
171                                /* [annotation][out] */
172                                _Out_ DXGI_ADAPTER_DESC1 *pDesc) {
173     if (pAdapter == nullptr || pDesc == nullptr) return DXGI_ERROR_INVALID_CALL;
174     auto it = platform_shim.dxgi_adapter_map.find(pAdapter);
175     if (it == platform_shim.dxgi_adapter_map.end()) {
176         return DXGI_ERROR_INVALID_CALL;
177     }
178     *pDesc = platform_shim.dxgi_adapters[it->second].desc1;
179     return S_OK;
180 }
ShimIDXGIFactory1Release(IDXGIFactory1 * factory)181 ULONG __stdcall ShimIDXGIFactory1Release(IDXGIFactory1 *factory) {
182     if (factory != nullptr) {
183         if (factory->lpVtbl != nullptr) {
184             delete factory->lpVtbl;
185         }
186         delete factory;
187     }
188     return S_OK;
189 }
ShimIDXGIFactory6Release(IDXGIFactory6 * factory)190 ULONG __stdcall ShimIDXGIFactory6Release(IDXGIFactory6 *factory) {
191     if (factory != nullptr) {
192         if (factory->lpVtbl != nullptr) {
193             delete factory->lpVtbl;
194         }
195         delete factory;
196     }
197     return S_OK;
198 }
199 
ShimRelease(IDXGIAdapter1 * pAdapter)200 ULONG __stdcall ShimRelease(IDXGIAdapter1 *pAdapter) {
201     if (pAdapter != nullptr) {
202         if (pAdapter->lpVtbl != nullptr) {
203             delete pAdapter->lpVtbl;
204         }
205         delete pAdapter;
206     }
207     return S_OK;
208 }
209 
create_IDXGIAdapter1()210 IDXGIAdapter1 *create_IDXGIAdapter1() {
211     IDXGIAdapter1Vtbl *vtbl = new IDXGIAdapter1Vtbl();
212     vtbl->GetDesc1 = ShimGetDesc1;
213     vtbl->Release = ShimRelease;
214     IDXGIAdapter1 *adapter = new IDXGIAdapter1();
215     adapter->lpVtbl = vtbl;
216     return adapter;
217 }
218 
ShimEnumAdapters1_1(IDXGIFactory1 * This,UINT Adapter,_COM_Outptr_ IDXGIAdapter1 ** ppAdapter)219 HRESULT __stdcall ShimEnumAdapters1_1(IDXGIFactory1 *This,
220                                       /* [in] */ UINT Adapter,
221                                       /* [annotation][out] */
222                                       _COM_Outptr_ IDXGIAdapter1 **ppAdapter) {
223     if (Adapter >= platform_shim.dxgi_adapters.size()) {
224         return DXGI_ERROR_INVALID_CALL;
225     }
226     if (ppAdapter != nullptr) {
227         auto *pAdapter = create_IDXGIAdapter1();
228         *ppAdapter = pAdapter;
229         platform_shim.dxgi_adapter_map[pAdapter] = Adapter;
230     }
231     return S_OK;
232 }
233 
ShimEnumAdapters1_6(IDXGIFactory6 * This,UINT Adapter,_COM_Outptr_ IDXGIAdapter1 ** ppAdapter)234 HRESULT __stdcall ShimEnumAdapters1_6(IDXGIFactory6 *This,
235                                       /* [in] */ UINT Adapter,
236                                       /* [annotation][out] */
237                                       _COM_Outptr_ IDXGIAdapter1 **ppAdapter) {
238     if (Adapter >= platform_shim.dxgi_adapters.size()) {
239         return DXGI_ERROR_INVALID_CALL;
240     }
241     if (ppAdapter != nullptr) {
242         auto *pAdapter = create_IDXGIAdapter1();
243         *ppAdapter = pAdapter;
244         platform_shim.dxgi_adapter_map[pAdapter] = Adapter;
245     }
246     return S_OK;
247 }
248 
ShimEnumAdapterByGpuPreference(IDXGIFactory6 * This,_In_ UINT Adapter,_In_ DXGI_GPU_PREFERENCE GpuPreference,_In_ REFIID riid,_COM_Outptr_ void ** ppvAdapter)249 HRESULT __stdcall ShimEnumAdapterByGpuPreference(IDXGIFactory6 *This, _In_ UINT Adapter, _In_ DXGI_GPU_PREFERENCE GpuPreference,
250                                                  _In_ REFIID riid, _COM_Outptr_ void **ppvAdapter) {
251     if (Adapter >= platform_shim.dxgi_adapters.size()) {
252         return DXGI_ERROR_NOT_FOUND;
253     }
254     // loader always uses DXGI_GPU_PREFERENCE_UNSPECIFIED
255     // Update the shim if this isn't the case
256     assert(GpuPreference == DXGI_GPU_PREFERENCE::DXGI_GPU_PREFERENCE_UNSPECIFIED &&
257            "Test shim assumes the GpuPreference is unspecified.");
258     if (ppvAdapter != nullptr) {
259         auto *pAdapter = create_IDXGIAdapter1();
260         *ppvAdapter = pAdapter;
261         platform_shim.dxgi_adapter_map[pAdapter] = Adapter;
262     }
263     return S_OK;
264 }
265 
create_IDXGIFactory1()266 IDXGIFactory1 *create_IDXGIFactory1() {
267     IDXGIFactory1Vtbl *vtbl = new IDXGIFactory1Vtbl();
268     vtbl->EnumAdapters1 = ShimEnumAdapters1_1;
269     vtbl->Release = ShimIDXGIFactory1Release;
270     IDXGIFactory1 *factory = new IDXGIFactory1();
271     factory->lpVtbl = vtbl;
272     return factory;
273 }
274 
create_IDXGIFactory6()275 IDXGIFactory6 *create_IDXGIFactory6() {
276     IDXGIFactory6Vtbl *vtbl = new IDXGIFactory6Vtbl();
277     vtbl->EnumAdapters1 = ShimEnumAdapters1_6;
278     vtbl->EnumAdapterByGpuPreference = ShimEnumAdapterByGpuPreference;
279     vtbl->Release = ShimIDXGIFactory6Release;
280     IDXGIFactory6 *factory = new IDXGIFactory6();
281     factory->lpVtbl = vtbl;
282     return factory;
283 }
284 
ShimCreateDXGIFactory1(REFIID riid,void ** ppFactory)285 HRESULT __stdcall ShimCreateDXGIFactory1(REFIID riid, void **ppFactory) {
286     if (riid == IID_IDXGIFactory1) {
287         auto *factory = create_IDXGIFactory1();
288         *ppFactory = factory;
289         return S_OK;
290     }
291     if (riid == IID_IDXGIFactory6) {
292         auto *factory = create_IDXGIFactory6();
293         *ppFactory = factory;
294         return S_OK;
295     }
296     assert(false && "new riid, update shim code to handle");
297     return S_FALSE;
298 }
299 
300 // Windows Registry shims
301 using PFN_RegOpenKeyExA = LSTATUS(__stdcall *)(HKEY hKey, LPCSTR lpSubKey, DWORD ulOptions, REGSAM samDesired, PHKEY phkResult);
302 static PFN_RegOpenKeyExA fpRegOpenKeyExA = RegOpenKeyExA;
303 using PFN_RegQueryValueExA = LSTATUS(__stdcall *)(HKEY hKey, LPCSTR lpValueName, LPDWORD lpReserved, LPDWORD lpType, LPBYTE lpData,
304                                                   LPDWORD lpcbData);
305 static PFN_RegQueryValueExA fpRegQueryValueExA = RegQueryValueExA;
306 using PFN_RegEnumValueA = LSTATUS(__stdcall *)(HKEY hKey, DWORD dwIndex, LPSTR lpValueName, LPDWORD lpcchValueName,
307                                                LPDWORD lpReserved, LPDWORD lpType, LPBYTE lpData, LPDWORD lpcbData);
308 static PFN_RegEnumValueA fpRegEnumValueA = RegEnumValueA;
309 
310 using PFN_RegCloseKey = LSTATUS(__stdcall *)(HKEY hKey);
311 static PFN_RegCloseKey fpRegCloseKey = RegCloseKey;
312 
ShimRegOpenKeyExA(HKEY hKey,LPCSTR lpSubKey,DWORD ulOptions,REGSAM samDesired,PHKEY phkResult)313 LSTATUS __stdcall ShimRegOpenKeyExA(HKEY hKey, LPCSTR lpSubKey, DWORD ulOptions, REGSAM samDesired, PHKEY phkResult) {
314     if (HKEY_LOCAL_MACHINE != hKey && HKEY_CURRENT_USER != hKey) return ERROR_BADKEY;
315     std::string hive = "";
316     if (HKEY_LOCAL_MACHINE == hKey)
317         hive = "HKEY_LOCAL_MACHINE";
318     else if (HKEY_CURRENT_USER == hKey)
319         hive = "HKEY_CURRENT_USER";
320 
321     platform_shim.created_keys.emplace_back(platform_shim.created_key_count++, hive + "\\" + lpSubKey);
322     *phkResult = platform_shim.created_keys.back().get();
323     return 0;
324 }
get_path_of_created_key(HKEY hKey)325 const std::string *get_path_of_created_key(HKEY hKey) {
326     for (const auto &key : platform_shim.created_keys) {
327         if (key.key == hKey) {
328             return &key.path;
329         }
330     }
331     return nullptr;
332 }
get_registry_vector(std::string const & path)333 std::vector<RegistryEntry> *get_registry_vector(std::string const &path) {
334     if (path == "HKEY_LOCAL_MACHINE\\SOFTWARE\\Khronos\\Vulkan\\Drivers") return &platform_shim.hkey_local_machine_drivers;
335     if (path == "HKEY_LOCAL_MACHINE\\SOFTWARE\\Khronos\\Vulkan\\ExplicitLayers")
336         return &platform_shim.hkey_local_machine_explicit_layers;
337     if (path == "HKEY_LOCAL_MACHINE\\SOFTWARE\\Khronos\\Vulkan\\ImplicitLayers")
338         return &platform_shim.hkey_local_machine_implicit_layers;
339     if (path == "HKEY_CURRENT_USER\\SOFTWARE\\Khronos\\Vulkan\\ExplicitLayers")
340         return &platform_shim.hkey_current_user_explicit_layers;
341     if (path == "HKEY_CURRENT_USER\\SOFTWARE\\Khronos\\Vulkan\\ImplicitLayers")
342         return &platform_shim.hkey_current_user_implicit_layers;
343     return nullptr;
344 }
ShimRegQueryValueExA(HKEY hKey,LPCSTR lpValueName,LPDWORD lpReserved,LPDWORD lpType,LPBYTE lpData,LPDWORD lpcbData)345 LSTATUS __stdcall ShimRegQueryValueExA(HKEY hKey, LPCSTR lpValueName, LPDWORD lpReserved, LPDWORD lpType, LPBYTE lpData,
346                                        LPDWORD lpcbData) {
347     // TODO:
348     return ERROR_SUCCESS;
349 }
ShimRegEnumValueA(HKEY hKey,DWORD dwIndex,LPSTR lpValueName,LPDWORD lpcchValueName,LPDWORD lpReserved,LPDWORD lpType,LPBYTE lpData,LPDWORD lpcbData)350 LSTATUS __stdcall ShimRegEnumValueA(HKEY hKey, DWORD dwIndex, LPSTR lpValueName, LPDWORD lpcchValueName, LPDWORD lpReserved,
351                                     LPDWORD lpType, LPBYTE lpData, LPDWORD lpcbData) {
352     const std::string *path = get_path_of_created_key(hKey);
353     if (path == nullptr) return ERROR_NO_MORE_ITEMS;
354 
355     const auto *location_ptr = get_registry_vector(*path);
356     if (location_ptr == nullptr) return ERROR_NO_MORE_ITEMS;
357     const auto &location = *location_ptr;
358     if (dwIndex >= location.size()) return ERROR_NO_MORE_ITEMS;
359 
360     if (*lpcchValueName < location[dwIndex].name.size()) return ERROR_NO_MORE_ITEMS;
361     for (size_t i = 0; i < location[dwIndex].name.size(); i++) {
362         lpValueName[i] = location[dwIndex].name[i];
363     }
364     lpValueName[location[dwIndex].name.size()] = '\0';
365     *lpcchValueName = static_cast<DWORD>(location[dwIndex].name.size() + 1);
366     if (*lpcbData < sizeof(DWORD)) return ERROR_NO_MORE_ITEMS;
367     DWORD *lpcbData_dword = reinterpret_cast<DWORD *>(lpData);
368     *lpcbData_dword = location[dwIndex].value;
369     *lpcbData = sizeof(DWORD);
370     return ERROR_SUCCESS;
371 }
ShimRegCloseKey(HKEY hKey)372 LSTATUS __stdcall ShimRegCloseKey(HKEY hKey) {
373     for (size_t i = 0; i < platform_shim.created_keys.size(); i++) {
374         if (platform_shim.created_keys[i].get() == hKey) {
375             platform_shim.created_keys.erase(platform_shim.created_keys.begin() + i);
376             return ERROR_SUCCESS;
377         }
378     }
379     return ERROR_SUCCESS;
380 }
381 
382 // Windows app package shims
383 using PFN_GetPackagesByPackageFamily = LONG(WINAPI *)(PCWSTR, UINT32 *, PWSTR *, UINT32 *, WCHAR *);
384 static PFN_GetPackagesByPackageFamily fpGetPackagesByPackageFamily = GetPackagesByPackageFamily;
385 using PFN_GetPackagePathByFullName = LONG(WINAPI *)(PCWSTR, UINT32 *, PWSTR);
386 static PFN_GetPackagePathByFullName fpGetPackagePathByFullName = GetPackagePathByFullName;
387 
388 static constexpr wchar_t package_full_name[] = L"ThisIsARandomStringSinceTheNameDoesn'tMatter";
ShimGetPackagesByPackageFamily(_In_ PCWSTR packageFamilyName,_Inout_ UINT32 * count,_Out_writes_opt_ (* count)PWSTR * packageFullNames,_Inout_ UINT32 * bufferLength,_Out_writes_opt_ (* bufferLength)WCHAR * buffer)389 LONG WINAPI ShimGetPackagesByPackageFamily(_In_ PCWSTR packageFamilyName, _Inout_ UINT32 *count,
390                                            _Out_writes_opt_(*count) PWSTR *packageFullNames, _Inout_ UINT32 *bufferLength,
391                                            _Out_writes_opt_(*bufferLength) WCHAR *buffer) {
392     if (!packageFamilyName || !count || !bufferLength) return ERROR_INVALID_PARAMETER;
393     if (!platform_shim.app_package_path.empty() && wcscmp(packageFamilyName, L"Microsoft.D3DMappingLayers_8wekyb3d8bbwe") == 0) {
394         if (*count > 0 && !packageFullNames) return ERROR_INVALID_PARAMETER;
395         if (*bufferLength > 0 && !buffer) return ERROR_INVALID_PARAMETER;
396         if (*count > 1) return ERROR_INVALID_PARAMETER;
397         bool too_small = *count < 1 || *bufferLength < ARRAYSIZE(package_full_name);
398         *count = 1;
399         *bufferLength = ARRAYSIZE(package_full_name);
400         if (too_small) return ERROR_INSUFFICIENT_BUFFER;
401 
402         wcscpy(buffer, package_full_name);
403         *packageFullNames = buffer;
404         return 0;
405     }
406     *count = 0;
407     *bufferLength = 0;
408     return 0;
409 }
410 
ShimGetPackagePathByFullName(_In_ PCWSTR packageFullName,_Inout_ UINT32 * pathLength,_Out_writes_opt_ (* pathLength)PWSTR path)411 LONG WINAPI ShimGetPackagePathByFullName(_In_ PCWSTR packageFullName, _Inout_ UINT32 *pathLength,
412                                          _Out_writes_opt_(*pathLength) PWSTR path) {
413     if (!packageFullName || !pathLength) return ERROR_INVALID_PARAMETER;
414     if (*pathLength > 0 && !path) return ERROR_INVALID_PARAMETER;
415     if (wcscmp(packageFullName, package_full_name) != 0) {
416         *pathLength = 0;
417         return 0;
418     }
419     if (*pathLength < platform_shim.app_package_path.size() + 1) {
420         *pathLength = static_cast<UINT32>(platform_shim.app_package_path.size() + 1);
421         return ERROR_INSUFFICIENT_BUFFER;
422     }
423     wcscpy(path, platform_shim.app_package_path.c_str());
424     return 0;
425 }
426 
427 // Initialization
DetourFunctions()428 void WINAPI DetourFunctions() {
429     if (!gdi32_dll) {
430         gdi32_dll = LibraryWrapper("gdi32.dll");
431         fpEnumAdapters2 = gdi32_dll.get_symbol("D3DKMTEnumAdapters2");
432         if (fpEnumAdapters2 == nullptr) {
433             std::cerr << "Failed to load D3DKMTEnumAdapters2\n";
434             return;
435         }
436         fpQueryAdapterInfo = gdi32_dll.get_symbol("D3DKMTQueryAdapterInfo");
437         if (fpQueryAdapterInfo == nullptr) {
438             std::cerr << "Failed to load D3DKMTQueryAdapterInfo\n";
439             return;
440         }
441     }
442     if (!dxgi_module) {
443         TCHAR systemPath[MAX_PATH] = "";
444         GetSystemDirectory(systemPath, MAX_PATH);
445         StringCchCat(systemPath, MAX_PATH, TEXT("\\dxgi.dll"));
446         dxgi_module = LibraryWrapper(systemPath);
447         RealCreateDXGIFactory1 = dxgi_module.get_symbol("CreateDXGIFactory1");
448         if (RealCreateDXGIFactory1 == nullptr) {
449             std::cerr << "Failed to load CreateDXGIFactory1\n";
450         }
451     }
452 
453     DetourRestoreAfterWith();
454 
455     DetourTransactionBegin();
456     DetourUpdateThread(GetCurrentThread());
457     DetourAttach(&(PVOID &)fpGetSidSubAuthority, (PVOID)ShimGetSidSubAuthority);
458     DetourAttach(&(PVOID &)fpEnumAdapters2, (PVOID)ShimEnumAdapters2);
459     DetourAttach(&(PVOID &)fpQueryAdapterInfo, (PVOID)ShimQueryAdapterInfo);
460     DetourAttach(&(PVOID &)REAL_CM_Get_Device_ID_List_SizeW, (PVOID)SHIM_CM_Get_Device_ID_List_SizeW);
461     DetourAttach(&(PVOID &)REAL_CM_Get_Device_ID_ListW, (PVOID)SHIM_CM_Get_Device_ID_ListW);
462     DetourAttach(&(PVOID &)REAL_CM_Get_Device_ID_ListW, (PVOID)SHIM_CM_Get_Device_ID_ListW);
463     DetourAttach(&(PVOID &)REAL_CM_Locate_DevNodeW, (PVOID)SHIM_CM_Locate_DevNodeW);
464     DetourAttach(&(PVOID &)REAL_CM_Get_DevNode_Status, (PVOID)SHIM_CM_Get_DevNode_Status);
465     DetourAttach(&(PVOID &)REAL_CM_Get_Device_IDW, (PVOID)SHIM_CM_Get_Device_IDW);
466     DetourAttach(&(PVOID &)REAL_CM_Get_Child, (PVOID)SHIM_CM_Get_Child);
467     DetourAttach(&(PVOID &)REAL_CM_Get_DevNode_Registry_PropertyW, (PVOID)SHIM_CM_Get_DevNode_Registry_PropertyW);
468     DetourAttach(&(PVOID &)REAL_CM_Get_Sibling, (PVOID)SHIM_CM_Get_Sibling);
469     DetourAttach(&(PVOID &)RealCreateDXGIFactory1, (PVOID)ShimCreateDXGIFactory1);
470     DetourAttach(&(PVOID &)fpRegOpenKeyExA, (PVOID)ShimRegOpenKeyExA);
471     DetourAttach(&(PVOID &)fpRegQueryValueExA, (PVOID)ShimRegQueryValueExA);
472     DetourAttach(&(PVOID &)fpRegEnumValueA, (PVOID)ShimRegEnumValueA);
473     DetourAttach(&(PVOID &)fpRegCloseKey, (PVOID)ShimRegCloseKey);
474     DetourAttach(&(PVOID &)fpGetPackagesByPackageFamily, (PVOID)ShimGetPackagesByPackageFamily);
475     DetourAttach(&(PVOID &)fpGetPackagePathByFullName, (PVOID)ShimGetPackagePathByFullName);
476     LONG error = DetourTransactionCommit();
477 
478     if (error != NO_ERROR) {
479         std::cerr << "simple" << DETOURS_STRINGIFY(DETOURS_BITS) << ".dll:"
480                   << " Error detouring function(): " << error << "\n";
481     }
482 }
483 
DetachFunctions()484 void DetachFunctions() {
485     DetourTransactionBegin();
486     DetourUpdateThread(GetCurrentThread());
487     DetourDetach(&(PVOID &)fpGetSidSubAuthority, (PVOID)ShimGetSidSubAuthority);
488     DetourDetach(&(PVOID &)fpEnumAdapters2, (PVOID)ShimEnumAdapters2);
489     DetourDetach(&(PVOID &)fpQueryAdapterInfo, (PVOID)ShimQueryAdapterInfo);
490     DetourDetach(&(PVOID &)REAL_CM_Get_Device_ID_List_SizeW, (PVOID)SHIM_CM_Get_Device_ID_List_SizeW);
491     DetourDetach(&(PVOID &)REAL_CM_Get_Device_ID_ListW, (PVOID)SHIM_CM_Get_Device_ID_ListW);
492     DetourDetach(&(PVOID &)REAL_CM_Locate_DevNodeW, (PVOID)SHIM_CM_Locate_DevNodeW);
493     DetourDetach(&(PVOID &)REAL_CM_Get_DevNode_Status, (PVOID)SHIM_CM_Get_DevNode_Status);
494     DetourDetach(&(PVOID &)REAL_CM_Get_Device_IDW, (PVOID)SHIM_CM_Get_Device_IDW);
495     DetourDetach(&(PVOID &)REAL_CM_Get_Child, (PVOID)SHIM_CM_Get_Child);
496     DetourDetach(&(PVOID &)REAL_CM_Get_DevNode_Registry_PropertyW, (PVOID)SHIM_CM_Get_DevNode_Registry_PropertyW);
497     DetourDetach(&(PVOID &)REAL_CM_Get_Sibling, (PVOID)SHIM_CM_Get_Sibling);
498     DetourDetach(&(PVOID &)RealCreateDXGIFactory1, (PVOID)ShimCreateDXGIFactory1);
499     DetourDetach(&(PVOID &)fpRegOpenKeyExA, (PVOID)ShimRegOpenKeyExA);
500     DetourDetach(&(PVOID &)fpRegQueryValueExA, (PVOID)ShimRegQueryValueExA);
501     DetourDetach(&(PVOID &)fpRegEnumValueA, (PVOID)ShimRegEnumValueA);
502     DetourDetach(&(PVOID &)fpRegCloseKey, (PVOID)ShimRegCloseKey);
503     DetourDetach(&(PVOID &)fpGetPackagesByPackageFamily, (PVOID)ShimGetPackagesByPackageFamily);
504     DetourDetach(&(PVOID &)fpGetPackagePathByFullName, (PVOID)ShimGetPackagePathByFullName);
505     DetourTransactionCommit();
506 }
507 
DllMain(HINSTANCE hinst,DWORD dwReason,LPVOID reserved)508 BOOL WINAPI DllMain(HINSTANCE hinst, DWORD dwReason, LPVOID reserved) {
509     if (DetourIsHelperProcess()) {
510         return TRUE;
511     }
512 
513     if (dwReason == DLL_PROCESS_ATTACH) {
514         DetourFunctions();
515     } else if (dwReason == DLL_PROCESS_DETACH) {
516         DetachFunctions();
517     }
518     return TRUE;
519 }
get_platform_shim(std::vector<fs::FolderManager> * folders)520 FRAMEWORK_EXPORT PlatformShim *get_platform_shim(std::vector<fs::FolderManager> *folders) {
521     platform_shim = PlatformShim(folders);
522     return &platform_shim;
523 }
524 }
525