1 /**************************************************************************
2 *
3 * Copyright 2009-2013 VMware, Inc.
4 * All Rights Reserved.
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a
7 * copy of this software and associated documentation files (the
8 * "Software"), to deal in the Software without restriction, including
9 * without limitation the rights to use, copy, modify, merge, publish,
10 * distribute, sub license, and/or sell copies of the Software, and to
11 * permit persons to whom the Software is furnished to do so, subject to
12 * the following conditions:
13 *
14 * The above copyright notice and this permission notice (including the
15 * next paragraph) shall be included in all copies or substantial portions
16 * of the Software.
17 *
18 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
19 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
20 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT.
21 * IN NO EVENT SHALL VMWARE AND/OR ITS SUPPLIERS BE LIABLE FOR
22 * ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
23 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
24 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
25 *
26 **************************************************************************/
27
28 #include <windows.h>
29 #include <tlhelp32.h>
30
31 #include "pipe/p_compiler.h"
32 #include "util/u_debug.h"
33 #include "stw_tls.h"
34
/**
 * TLS slot that holds each thread's struct stw_tls_data pointer.
 *
 * TLS_OUT_OF_INDEXES means stw_tls_init() has not run or failed, or that
 * stw_tls_cleanup() has already released the slot.
 */
static DWORD tlsIndex = TLS_OUT_OF_INDEXES;


/**
 * Static mutex to protect the access to g_pendingTlsData global and
 * stw_tls_data::next member.
 *
 * NOTE(review): this fills in the CRITICAL_SECTION fields directly (debug-info
 * pointer = -1, second field = -1, the rest zero) instead of calling
 * InitializeCriticalSection(), presumably so the lock is usable without any
 * explicit initialization call. That relies on the internal layout of
 * CRITICAL_SECTION -- confirm it holds on every supported Windows version.
 */
static CRITICAL_SECTION g_mutex = {
   (PCRITICAL_SECTION_DEBUG)-1, -1, 0, 0, 0, 0
};

/**
 * There is no way to invoke TlsSetValue for a different thread, so we
 * temporarily put the thread data for non-current threads here.
 *
 * Entries are moved into the owning thread's TLS slot by stw_tls_get_data(),
 * or destroyed by stw_tls_cleanup_thread() / stw_tls_cleanup().
 * Guarded by g_mutex.
 */
static struct stw_tls_data *g_pendingTlsData = NULL;


static struct stw_tls_data *
stw_tls_data_create(DWORD dwThreadId);

static struct stw_tls_data *
stw_tls_lookup_pending_data(DWORD dwThreadId);
58
59
60 boolean
stw_tls_init(void)61 stw_tls_init(void)
62 {
63 tlsIndex = TlsAlloc();
64 if (tlsIndex == TLS_OUT_OF_INDEXES) {
65 return FALSE;
66 }
67
68 /*
69 * DllMain is called with DLL_THREAD_ATTACH only for threads created after
70 * the DLL is loaded by the process. So enumerate and add our hook to all
71 * previously existing threads.
72 *
73 * XXX: Except for the current thread since it there is an explicit
74 * stw_tls_init_thread() call for it later on.
75 */
76 if (1) {
77 DWORD dwCurrentProcessId = GetCurrentProcessId();
78 DWORD dwCurrentThreadId = GetCurrentThreadId();
79 HANDLE hSnapshot = CreateToolhelp32Snapshot(TH32CS_SNAPTHREAD, dwCurrentProcessId);
80 if (hSnapshot != INVALID_HANDLE_VALUE) {
81 THREADENTRY32 te;
82 te.dwSize = sizeof te;
83 if (Thread32First(hSnapshot, &te)) {
84 do {
85 if (te.dwSize >= FIELD_OFFSET(THREADENTRY32, th32OwnerProcessID) +
86 sizeof te.th32OwnerProcessID) {
87 if (te.th32OwnerProcessID == dwCurrentProcessId) {
88 if (te.th32ThreadID != dwCurrentThreadId) {
89 struct stw_tls_data *data;
90 data = stw_tls_data_create(te.th32ThreadID);
91 if (data) {
92 EnterCriticalSection(&g_mutex);
93 data->next = g_pendingTlsData;
94 g_pendingTlsData = data;
95 LeaveCriticalSection(&g_mutex);
96 }
97 }
98 }
99 }
100 te.dwSize = sizeof te;
101 } while (Thread32Next(hSnapshot, &te));
102 }
103 CloseHandle(hSnapshot);
104 }
105 }
106
107 return TRUE;
108 }
109
110
111 /**
112 * Install windows hook for a given thread (not necessarily the current one).
113 */
114 static struct stw_tls_data *
stw_tls_data_create(DWORD dwThreadId)115 stw_tls_data_create(DWORD dwThreadId)
116 {
117 struct stw_tls_data *data;
118
119 if (0) {
120 debug_printf("%s(0x%04lx)\n", __FUNCTION__, dwThreadId);
121 }
122
123 data = calloc(1, sizeof *data);
124 if (!data) {
125 goto no_data;
126 }
127
128 data->dwThreadId = dwThreadId;
129
130 data->hCallWndProcHook = SetWindowsHookEx(WH_CALLWNDPROC,
131 stw_call_window_proc,
132 NULL,
133 dwThreadId);
134 if (data->hCallWndProcHook == NULL) {
135 goto no_hook;
136 }
137
138 return data;
139
140 no_hook:
141 free(data);
142 no_data:
143 return NULL;
144 }
145
146 /**
147 * Destroy the per-thread data/hook.
148 *
149 * It is important to remove all hooks when unloading our DLL, otherwise our
150 * hook function might be called after it is no longer there.
151 */
152 static void
stw_tls_data_destroy(struct stw_tls_data * data)153 stw_tls_data_destroy(struct stw_tls_data *data)
154 {
155 assert(data);
156 if (!data) {
157 return;
158 }
159
160 if (0) {
161 debug_printf("%s(0x%04lx)\n", __FUNCTION__, data->dwThreadId);
162 }
163
164 if (data->hCallWndProcHook) {
165 UnhookWindowsHookEx(data->hCallWndProcHook);
166 data->hCallWndProcHook = NULL;
167 }
168
169 free(data);
170 }
171
172 boolean
stw_tls_init_thread(void)173 stw_tls_init_thread(void)
174 {
175 struct stw_tls_data *data;
176
177 if (tlsIndex == TLS_OUT_OF_INDEXES) {
178 return FALSE;
179 }
180
181 data = stw_tls_data_create(GetCurrentThreadId());
182 if (!data) {
183 return FALSE;
184 }
185
186 TlsSetValue(tlsIndex, data);
187
188 return TRUE;
189 }
190
191 void
stw_tls_cleanup_thread(void)192 stw_tls_cleanup_thread(void)
193 {
194 struct stw_tls_data *data;
195
196 if (tlsIndex == TLS_OUT_OF_INDEXES) {
197 return;
198 }
199
200 data = (struct stw_tls_data *) TlsGetValue(tlsIndex);
201 if (data) {
202 TlsSetValue(tlsIndex, NULL);
203 } else {
204 /* See if there this thread's data in on the pending list */
205 data = stw_tls_lookup_pending_data(GetCurrentThreadId());
206 }
207
208 if (data) {
209 stw_tls_data_destroy(data);
210 }
211 }
212
213 void
stw_tls_cleanup(void)214 stw_tls_cleanup(void)
215 {
216 if (tlsIndex != TLS_OUT_OF_INDEXES) {
217 /*
218 * Destroy all items in g_pendingTlsData linked list.
219 */
220 EnterCriticalSection(&g_mutex);
221 while (g_pendingTlsData) {
222 struct stw_tls_data * data = g_pendingTlsData;
223 g_pendingTlsData = data->next;
224 stw_tls_data_destroy(data);
225 }
226 LeaveCriticalSection(&g_mutex);
227
228 TlsFree(tlsIndex);
229 tlsIndex = TLS_OUT_OF_INDEXES;
230 }
231 }
232
233 /*
234 * Search for the current thread in the g_pendingTlsData linked list.
235 *
236 * It will remove and return the node on success, or return NULL on failure.
237 */
238 static struct stw_tls_data *
stw_tls_lookup_pending_data(DWORD dwThreadId)239 stw_tls_lookup_pending_data(DWORD dwThreadId)
240 {
241 struct stw_tls_data ** p_data;
242 struct stw_tls_data *data = NULL;
243
244 EnterCriticalSection(&g_mutex);
245 for (p_data = &g_pendingTlsData; *p_data; p_data = &(*p_data)->next) {
246 if ((*p_data)->dwThreadId == dwThreadId) {
247 data = *p_data;
248
249 /*
250 * Unlink the node.
251 */
252 *p_data = data->next;
253 data->next = NULL;
254
255 break;
256 }
257 }
258 LeaveCriticalSection(&g_mutex);
259
260 return data;
261 }
262
263 struct stw_tls_data *
stw_tls_get_data(void)264 stw_tls_get_data(void)
265 {
266 struct stw_tls_data *data;
267
268 if (tlsIndex == TLS_OUT_OF_INDEXES) {
269 return NULL;
270 }
271
272 data = (struct stw_tls_data *) TlsGetValue(tlsIndex);
273 if (!data) {
274 DWORD dwCurrentThreadId = GetCurrentThreadId();
275
276 /*
277 * Search for the current thread in the g_pendingTlsData linked list.
278 */
279 data = stw_tls_lookup_pending_data(dwCurrentThreadId);
280
281 if (!data) {
282 /*
283 * This should be impossible now.
284 */
285 assert(!"Failed to find thread data for thread id");
286
287 /*
288 * DllMain is called with DLL_THREAD_ATTACH only by threads created
289 * after the DLL is loaded by the process
290 */
291 data = stw_tls_data_create(dwCurrentThreadId);
292 if (!data) {
293 return NULL;
294 }
295 }
296
297 TlsSetValue(tlsIndex, data);
298 }
299
300 assert(data);
301 assert(data->dwThreadId = GetCurrentThreadId());
302 assert(data->next == NULL);
303
304 return data;
305 }
306