1 // Copyright 2011 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "base/win/iat_patch_function.h"
6
7 #include "base/check_op.h"
8 #include "base/memory/raw_ptr_exclusion.h"
9 #include "base/notreached.h"
10 #include "base/win/patch_util.h"
11 #include "base/win/pe_image.h"
12
13 namespace base {
14 namespace win {
15
16 namespace {
17
// State shared between InterceptImportedFunction and InterceptEnumCallback.
// A pointer to this struct is passed as the opaque |cookie| to the PEImage
// import enumerators. The field order matters: InterceptImportedFunction
// builds it with positional aggregate initialization.
struct InterceptFunctionInformation {
  // Set to true by InterceptEnumCallback once a matching import was found and
  // a patch was attempted; used to skip the delay-import search.
  bool finished_operation;
  // Name of the exporting module. Not read by InterceptEnumCallback; the
  // module name is passed to EnumAllImports/EnumAllDelayImports separately.
  const char* imported_from_module;
  // Import name to patch (compared case-insensitively with lstrcmpiA).
  const char* function_name;
  // RAW_PTR_EXCLUSION: #reinterpret-cast-trivial-type
  // Value written into the matching IAT slot.
  RAW_PTR_EXCLUSION void* new_function;
  // Optional out-param (may be null): receives the IAT slot's previous value.
  RAW_PTR_EXCLUSION void** old_function;
  // Optional out-param (may be null): receives the matching thunk's address.
  RAW_PTR_EXCLUSION IMAGE_THUNK_DATA** iat_thunk;
  // Result of the patch write; left at its initial value if nothing matched.
  DWORD return_code;
};
28
GetIATFunction(IMAGE_THUNK_DATA * iat_thunk)29 void* GetIATFunction(IMAGE_THUNK_DATA* iat_thunk) {
30 if (!iat_thunk) {
31 NOTREACHED();
32 }
33
34 // Works around the 64 bit portability warning:
35 // The Function member inside IMAGE_THUNK_DATA is really a pointer
36 // to the IAT function. IMAGE_THUNK_DATA correctly maps to IMAGE_THUNK_DATA32
37 // or IMAGE_THUNK_DATA64 for correct pointer size.
38 union FunctionThunk {
39 IMAGE_THUNK_DATA thunk;
40 // This field is not a raw_ptr<> because it was filtered by the rewriter
41 // for: #union
42 RAW_PTR_EXCLUSION void* pointer;
43 } iat_function;
44
45 iat_function.thunk = *iat_thunk;
46 return iat_function.pointer;
47 }
48
InterceptEnumCallback(const base::win::PEImage & image,const char * module,DWORD ordinal,const char * name,DWORD hint,IMAGE_THUNK_DATA * iat,void * cookie)49 bool InterceptEnumCallback(const base::win::PEImage& image,
50 const char* module,
51 DWORD ordinal,
52 const char* name,
53 DWORD hint,
54 IMAGE_THUNK_DATA* iat,
55 void* cookie) {
56 InterceptFunctionInformation* intercept_information =
57 reinterpret_cast<InterceptFunctionInformation*>(cookie);
58
59 if (!intercept_information) {
60 NOTREACHED();
61 }
62
63 DCHECK(module);
64
65 if (name && (0 == lstrcmpiA(name, intercept_information->function_name))) {
66 // Save the old pointer.
67 if (intercept_information->old_function) {
68 *(intercept_information->old_function) = GetIATFunction(iat);
69 }
70
71 if (intercept_information->iat_thunk) {
72 *(intercept_information->iat_thunk) = iat;
73 }
74
75 // portability check
76 static_assert(
77 sizeof(iat->u1.Function) == sizeof(intercept_information->new_function),
78 "unknown IAT thunk format");
79
80 // Patch the function.
81 intercept_information->return_code = internal::ModifyCode(
82 &(iat->u1.Function), &(intercept_information->new_function),
83 sizeof(intercept_information->new_function));
84
85 // Terminate further enumeration.
86 intercept_information->finished_operation = true;
87 return false;
88 }
89
90 return true;
91 }
92
93 // Helper to intercept a function in an import table of a specific
94 // module.
95 //
96 // Arguments:
97 // module_handle Module to be intercepted
98 // imported_from_module Module that exports the symbol
99 // function_name Name of the API to be intercepted
100 // new_function Interceptor function
101 // old_function Receives the original function pointer
102 // iat_thunk Receives pointer to IAT_THUNK_DATA
103 // for the API from the import table.
104 //
105 // Returns: Returns NO_ERROR on success or Windows error code
106 // as defined in winerror.h
InterceptImportedFunction(HMODULE module_handle,const char * imported_from_module,const char * function_name,void * new_function,void ** old_function,IMAGE_THUNK_DATA ** iat_thunk)107 DWORD InterceptImportedFunction(HMODULE module_handle,
108 const char* imported_from_module,
109 const char* function_name,
110 void* new_function,
111 void** old_function,
112 IMAGE_THUNK_DATA** iat_thunk) {
113 if (!module_handle || !imported_from_module || !function_name ||
114 !new_function) {
115 NOTREACHED();
116 }
117
118 base::win::PEImage target_image(module_handle);
119 if (!target_image.VerifyMagic()) {
120 NOTREACHED();
121 }
122
123 InterceptFunctionInformation intercept_information = {false,
124 imported_from_module,
125 function_name,
126 new_function,
127 old_function,
128 iat_thunk,
129 ERROR_GEN_FAILURE};
130
131 // First go through the IAT. If we don't find the import we are looking
132 // for in IAT, search delay import table.
133 target_image.EnumAllImports(InterceptEnumCallback, &intercept_information,
134 imported_from_module);
135 if (!intercept_information.finished_operation) {
136 target_image.EnumAllDelayImports(
137 InterceptEnumCallback, &intercept_information, imported_from_module);
138 }
139
140 return intercept_information.return_code;
141 }
142
143 // Restore intercepted IAT entry with the original function.
144 //
145 // Arguments:
146 // intercept_function Interceptor function
147 // original_function Receives the original function pointer
148 //
149 // Returns: Returns NO_ERROR on success or Windows error code
150 // as defined in winerror.h
RestoreImportedFunction(void * intercept_function,void * original_function,IMAGE_THUNK_DATA * iat_thunk)151 DWORD RestoreImportedFunction(void* intercept_function,
152 void* original_function,
153 IMAGE_THUNK_DATA* iat_thunk) {
154 if (!intercept_function || !original_function || !iat_thunk) {
155 NOTREACHED();
156 }
157
158 if (GetIATFunction(iat_thunk) != intercept_function) {
159 // Check if someone else has intercepted on top of us.
160 // We cannot unpatch in this case, just raise a red flag.
161 NOTREACHED();
162 }
163
164 return internal::ModifyCode(&(iat_thunk->u1.Function), &original_function,
165 sizeof(original_function));
166 }
167
168 } // namespace
169
// Default-constructed with no patch installed; use Patch() or
// PatchFromModule() to install one.
IATPatchFunction::IATPatchFunction() = default;
171
~IATPatchFunction()172 IATPatchFunction::~IATPatchFunction() {
173 if (intercept_function_) {
174 DWORD error = Unpatch();
175 DCHECK_EQ(static_cast<DWORD>(NO_ERROR), error);
176 }
177 }
178
Patch(const wchar_t * module,const char * imported_from_module,const char * function_name,void * new_function)179 DWORD IATPatchFunction::Patch(const wchar_t* module,
180 const char* imported_from_module,
181 const char* function_name,
182 void* new_function) {
183 HMODULE module_handle = LoadLibraryW(module);
184 if (!module_handle) {
185 NOTREACHED();
186 }
187
188 DWORD error = PatchFromModule(module_handle, imported_from_module,
189 function_name, new_function);
190 if (NO_ERROR == error) {
191 module_handle_ = module_handle;
192 } else {
193 FreeLibrary(module_handle);
194 }
195
196 return error;
197 }
198
PatchFromModule(HMODULE module,const char * imported_from_module,const char * function_name,void * new_function)199 DWORD IATPatchFunction::PatchFromModule(HMODULE module,
200 const char* imported_from_module,
201 const char* function_name,
202 void* new_function) {
203 DCHECK_EQ(nullptr, original_function_);
204 DCHECK_EQ(nullptr, iat_thunk_);
205 DCHECK_EQ(nullptr, intercept_function_);
206 DCHECK(module);
207
208 DWORD error = InterceptImportedFunction(
209 module, imported_from_module, function_name, new_function,
210 &original_function_.AsEphemeralRawAddr(),
211 &iat_thunk_.AsEphemeralRawAddr());
212
213 if (NO_ERROR == error) {
214 DCHECK_NE(original_function_, intercept_function_);
215 intercept_function_ = new_function;
216 }
217
218 return error;
219 }
220
Unpatch()221 DWORD IATPatchFunction::Unpatch() {
222 DWORD error = RestoreImportedFunction(intercept_function_, original_function_,
223 iat_thunk_);
224 DCHECK_EQ(static_cast<DWORD>(NO_ERROR), error);
225
226 // Hands off the intercept if we fail to unpatch.
227 // If IATPatchFunction::Unpatch fails during RestoreImportedFunction
228 // it means that we cannot safely unpatch the import address table
229 // patch. In this case its better to be hands off the intercept as
230 // trying to unpatch again in the destructor of IATPatchFunction is
231 // not going to be any safer
232 if (module_handle_)
233 FreeLibrary(module_handle_);
234 module_handle_ = nullptr;
235 intercept_function_ = nullptr;
236 original_function_ = nullptr;
237 iat_thunk_ = nullptr;
238
239 return error;
240 }
241
original_function() const242 void* IATPatchFunction::original_function() const {
243 DCHECK(is_patched());
244 return original_function_;
245 }
246
247 } // namespace win
248 } // namespace base
249