• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2019 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // system_utils_win32.cpp: Implementation of OS-specific functions for Windows.
7 
8 #include "common/FastVector.h"
9 #include "system_utils.h"
10 
11 #include <array>
12 
13 // Must be included in this order.
14 // clang-format off
15 #include <windows.h>
16 #include <psapi.h>
17 // clang-format on
18 
19 namespace angle
20 {
UnsetEnvironmentVar(const char * variableName)21 bool UnsetEnvironmentVar(const char *variableName)
22 {
23     return (SetEnvironmentVariableW(Widen(variableName).c_str(), nullptr) == TRUE);
24 }
25 
SetEnvironmentVar(const char * variableName,const char * value)26 bool SetEnvironmentVar(const char *variableName, const char *value)
27 {
28     return (SetEnvironmentVariableW(Widen(variableName).c_str(), Widen(value).c_str()) == TRUE);
29 }
30 
GetEnvironmentVar(const char * variableName)31 std::string GetEnvironmentVar(const char *variableName)
32 {
33     std::wstring variableNameUtf16 = Widen(variableName);
34     FastVector<wchar_t, MAX_PATH> value;
35 
36     DWORD result;
37 
38     // First get the length of the variable, including the null terminator
39     result = GetEnvironmentVariableW(variableNameUtf16.c_str(), nullptr, 0);
40 
41     // Zero means the variable was not found, so return now.
42     if (result == 0)
43     {
44         return std::string();
45     }
46 
47     // Now size the vector to fit the data, and read the environment variable.
48     value.resize(result, 0);
49     result = GetEnvironmentVariableW(variableNameUtf16.c_str(), value.data(), result);
50 
51     return Narrow(value.data());
52 }
53 
OpenSystemLibraryWithExtensionAndGetError(const char * libraryName,SearchType searchType,std::string * errorOut)54 void *OpenSystemLibraryWithExtensionAndGetError(const char *libraryName,
55                                                 SearchType searchType,
56                                                 std::string *errorOut)
57 {
58     char buffer[MAX_PATH];
59     int ret = snprintf(buffer, MAX_PATH, "%s.%s", libraryName, GetSharedLibraryExtension());
60     if (ret <= 0 || ret >= MAX_PATH)
61     {
62         fprintf(stderr, "Error loading shared library: 0x%x", ret);
63         return nullptr;
64     }
65 
66     HMODULE libraryModule = nullptr;
67 
68     switch (searchType)
69     {
70         case SearchType::ModuleDir:
71         {
72             std::string moduleRelativePath = ConcatenatePath(GetModuleDirectory(), libraryName);
73             if (errorOut)
74             {
75                 *errorOut = moduleRelativePath;
76             }
77             libraryModule = LoadLibraryW(Widen(moduleRelativePath).c_str());
78             break;
79         }
80 
81         case SearchType::SystemDir:
82         {
83             if (errorOut)
84             {
85                 *errorOut = libraryName;
86             }
87             libraryModule =
88                 LoadLibraryExW(Widen(libraryName).c_str(), nullptr, LOAD_LIBRARY_SEARCH_SYSTEM32);
89             break;
90         }
91 
92         case SearchType::AlreadyLoaded:
93         {
94             if (errorOut)
95             {
96                 *errorOut = libraryName;
97             }
98             libraryModule = GetModuleHandleW(Widen(libraryName).c_str());
99             break;
100         }
101     }
102 
103     return reinterpret_cast<void *>(libraryModule);
104 }
105 
106 namespace
107 {
108 class Win32PageFaultHandler : public PageFaultHandler
109 {
110   public:
Win32PageFaultHandler(PageFaultCallback callback)111     Win32PageFaultHandler(PageFaultCallback callback) : PageFaultHandler(callback) {}
~Win32PageFaultHandler()112     ~Win32PageFaultHandler() override {}
113 
114     bool enable() override;
115     bool disable() override;
116 
117     LONG handle(PEXCEPTION_POINTERS pExceptionInfo);
118 
119   private:
120     void *mVectoredExceptionHandler = nullptr;
121 };
122 
123 Win32PageFaultHandler *gWin32PageFaultHandler = nullptr;
VectoredExceptionHandler(PEXCEPTION_POINTERS info)124 static LONG CALLBACK VectoredExceptionHandler(PEXCEPTION_POINTERS info)
125 {
126     return gWin32PageFaultHandler->handle(info);
127 }
128 
SetMemoryProtection(uintptr_t start,size_t size,DWORD protections)129 bool SetMemoryProtection(uintptr_t start, size_t size, DWORD protections)
130 {
131     DWORD oldProtect;
132     BOOL res = VirtualProtect(reinterpret_cast<LPVOID>(start), size, protections, &oldProtect);
133     if (!res)
134     {
135         DWORD lastError = GetLastError();
136         fprintf(stderr, "VirtualProtect failed: 0x%lx\n", lastError);
137         return false;
138     }
139 
140     return true;
141 }
142 
handle(PEXCEPTION_POINTERS info)143 LONG Win32PageFaultHandler::handle(PEXCEPTION_POINTERS info)
144 {
145     bool found = false;
146 
147     if (info->ExceptionRecord->ExceptionCode == EXCEPTION_ACCESS_VIOLATION &&
148         info->ExceptionRecord->NumberParameters >= 2 &&
149         info->ExceptionRecord->ExceptionInformation[0] == 1)
150     {
151         found = mCallback(static_cast<uintptr_t>(info->ExceptionRecord->ExceptionInformation[1])) ==
152                 PageFaultHandlerRangeType::InRange;
153     }
154 
155     if (found)
156     {
157         return EXCEPTION_CONTINUE_EXECUTION;
158     }
159     else
160     {
161         return EXCEPTION_CONTINUE_SEARCH;
162     }
163 }
164 
disable()165 bool Win32PageFaultHandler::disable()
166 {
167     if (mVectoredExceptionHandler)
168     {
169         ULONG res                 = RemoveVectoredExceptionHandler(mVectoredExceptionHandler);
170         mVectoredExceptionHandler = nullptr;
171         if (res == 0)
172         {
173             DWORD lastError = GetLastError();
174             fprintf(stderr, "RemoveVectoredExceptionHandler failed: 0x%lx\n", lastError);
175             return false;
176         }
177     }
178     return true;
179 }
180 
enable()181 bool Win32PageFaultHandler::enable()
182 {
183     if (mVectoredExceptionHandler)
184     {
185         return true;
186     }
187 
188     PVECTORED_EXCEPTION_HANDLER handler =
189         reinterpret_cast<PVECTORED_EXCEPTION_HANDLER>(&VectoredExceptionHandler);
190 
191     mVectoredExceptionHandler = AddVectoredExceptionHandler(1, handler);
192 
193     if (!mVectoredExceptionHandler)
194     {
195         DWORD lastError = GetLastError();
196         fprintf(stderr, "AddVectoredExceptionHandler failed: 0x%lx\n", lastError);
197         return false;
198     }
199     return true;
200 }
201 }  // namespace
202 
203 // Set write protection
ProtectMemory(uintptr_t start,size_t size)204 bool ProtectMemory(uintptr_t start, size_t size)
205 {
206     return SetMemoryProtection(start, size, PAGE_READONLY);
207 }
208 
209 // Allow reading and writing
UnprotectMemory(uintptr_t start,size_t size)210 bool UnprotectMemory(uintptr_t start, size_t size)
211 {
212     return SetMemoryProtection(start, size, PAGE_READWRITE);
213 }
214 
GetPageSize()215 size_t GetPageSize()
216 {
217     SYSTEM_INFO info;
218     GetSystemInfo(&info);
219     return static_cast<size_t>(info.dwPageSize);
220 }
221 
CreatePageFaultHandler(PageFaultCallback callback)222 PageFaultHandler *CreatePageFaultHandler(PageFaultCallback callback)
223 {
224     gWin32PageFaultHandler = new Win32PageFaultHandler(callback);
225     return gWin32PageFaultHandler;
226 }
227 
GetProcessMemoryUsageKB()228 uint64_t GetProcessMemoryUsageKB()
229 {
230     PROCESS_MEMORY_COUNTERS_EX pmc;
231     ::GetProcessMemoryInfo(::GetCurrentProcess(), reinterpret_cast<PROCESS_MEMORY_COUNTERS *>(&pmc),
232                            sizeof(pmc));
233     return static_cast<uint64_t>(pmc.PrivateUsage) / 1024ull;
234 }
235 }  // namespace angle
236