1 // Copyright 2011 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "base/win/iat_patch_function.h"
6 
7 #include "base/check_op.h"
8 #include "base/memory/raw_ptr_exclusion.h"
9 #include "base/notreached.h"
10 #include "base/win/patch_util.h"
11 #include "base/win/pe_image.h"
12 
13 namespace base {
14 namespace win {
15 
16 namespace {
17 
// State shared between InterceptImportedFunction() and InterceptEnumCallback().
// A pointer to one of these is passed through the PEImage enumeration as the
// |cookie| argument. NOTE: InterceptImportedFunction() aggregate-initializes
// this struct positionally, so the field order must not change.
struct InterceptFunctionInformation {
  // Set to true by the callback once the named import has been found and the
  // patch attempted; used to skip the delay-import search.
  bool finished_operation;
  // Module that exports the symbol. Not read by the callback itself; the
  // enumeration is filtered by module when the Enum*Imports call is made.
  const char* imported_from_module;
  // Import name to match (compared case-insensitively with lstrcmpiA).
  const char* function_name;
  // Replacement pointer written into the matching IAT slot.
  // This field is not a raw_ptr<> because it was filtered by the rewriter for:
  // #reinterpret-cast-trivial-type
  RAW_PTR_EXCLUSION void* new_function;
  // Optional out-param (may be null): receives the IAT slot's previous value.
  // This field is not a raw_ptr<> because it was filtered by the rewriter for:
  // #reinterpret-cast-trivial-type
  RAW_PTR_EXCLUSION void** old_function;
  // Optional out-param (may be null): receives the matching IAT thunk entry.
  // This field is not a raw_ptr<> because it was filtered by the rewriter for:
  // #reinterpret-cast-trivial-type
  RAW_PTR_EXCLUSION IMAGE_THUNK_DATA** iat_thunk;
  // Result of the patch. InterceptImportedFunction() seeds this with
  // ERROR_GEN_FAILURE, so it stays that value if no matching import is found.
  DWORD return_code;
};
33 
GetIATFunction(IMAGE_THUNK_DATA * iat_thunk)34 void* GetIATFunction(IMAGE_THUNK_DATA* iat_thunk) {
35   if (!iat_thunk) {
36     NOTREACHED();
37     return nullptr;
38   }
39 
40   // Works around the 64 bit portability warning:
41   // The Function member inside IMAGE_THUNK_DATA is really a pointer
42   // to the IAT function. IMAGE_THUNK_DATA correctly maps to IMAGE_THUNK_DATA32
43   // or IMAGE_THUNK_DATA64 for correct pointer size.
44   union FunctionThunk {
45     IMAGE_THUNK_DATA thunk;
46     // This field is not a raw_ptr<> because it was filtered by the rewriter
47     // for: #union
48     RAW_PTR_EXCLUSION void* pointer;
49   } iat_function;
50 
51   iat_function.thunk = *iat_thunk;
52   return iat_function.pointer;
53 }
54 
InterceptEnumCallback(const base::win::PEImage & image,const char * module,DWORD ordinal,const char * name,DWORD hint,IMAGE_THUNK_DATA * iat,void * cookie)55 bool InterceptEnumCallback(const base::win::PEImage& image,
56                            const char* module,
57                            DWORD ordinal,
58                            const char* name,
59                            DWORD hint,
60                            IMAGE_THUNK_DATA* iat,
61                            void* cookie) {
62   InterceptFunctionInformation* intercept_information =
63       reinterpret_cast<InterceptFunctionInformation*>(cookie);
64 
65   if (!intercept_information) {
66     NOTREACHED();
67     return false;
68   }
69 
70   DCHECK(module);
71 
72   if (name && (0 == lstrcmpiA(name, intercept_information->function_name))) {
73     // Save the old pointer.
74     if (intercept_information->old_function) {
75       *(intercept_information->old_function) = GetIATFunction(iat);
76     }
77 
78     if (intercept_information->iat_thunk) {
79       *(intercept_information->iat_thunk) = iat;
80     }
81 
82     // portability check
83     static_assert(
84         sizeof(iat->u1.Function) == sizeof(intercept_information->new_function),
85         "unknown IAT thunk format");
86 
87     // Patch the function.
88     intercept_information->return_code = internal::ModifyCode(
89         &(iat->u1.Function), &(intercept_information->new_function),
90         sizeof(intercept_information->new_function));
91 
92     // Terminate further enumeration.
93     intercept_information->finished_operation = true;
94     return false;
95   }
96 
97   return true;
98 }
99 
100 // Helper to intercept a function in an import table of a specific
101 // module.
102 //
103 // Arguments:
104 // module_handle          Module to be intercepted
105 // imported_from_module   Module that exports the symbol
106 // function_name          Name of the API to be intercepted
107 // new_function           Interceptor function
108 // old_function           Receives the original function pointer
109 // iat_thunk              Receives pointer to IAT_THUNK_DATA
110 //                        for the API from the import table.
111 //
112 // Returns: Returns NO_ERROR on success or Windows error code
113 //          as defined in winerror.h
InterceptImportedFunction(HMODULE module_handle,const char * imported_from_module,const char * function_name,void * new_function,void ** old_function,IMAGE_THUNK_DATA ** iat_thunk)114 DWORD InterceptImportedFunction(HMODULE module_handle,
115                                 const char* imported_from_module,
116                                 const char* function_name,
117                                 void* new_function,
118                                 void** old_function,
119                                 IMAGE_THUNK_DATA** iat_thunk) {
120   if (!module_handle || !imported_from_module || !function_name ||
121       !new_function) {
122     NOTREACHED();
123     return ERROR_INVALID_PARAMETER;
124   }
125 
126   base::win::PEImage target_image(module_handle);
127   if (!target_image.VerifyMagic()) {
128     NOTREACHED();
129     return ERROR_INVALID_PARAMETER;
130   }
131 
132   InterceptFunctionInformation intercept_information = {false,
133                                                         imported_from_module,
134                                                         function_name,
135                                                         new_function,
136                                                         old_function,
137                                                         iat_thunk,
138                                                         ERROR_GEN_FAILURE};
139 
140   // First go through the IAT. If we don't find the import we are looking
141   // for in IAT, search delay import table.
142   target_image.EnumAllImports(InterceptEnumCallback, &intercept_information,
143                               imported_from_module);
144   if (!intercept_information.finished_operation) {
145     target_image.EnumAllDelayImports(
146         InterceptEnumCallback, &intercept_information, imported_from_module);
147   }
148 
149   return intercept_information.return_code;
150 }
151 
152 // Restore intercepted IAT entry with the original function.
153 //
154 // Arguments:
155 // intercept_function     Interceptor function
156 // original_function      Receives the original function pointer
157 //
158 // Returns: Returns NO_ERROR on success or Windows error code
159 //          as defined in winerror.h
RestoreImportedFunction(void * intercept_function,void * original_function,IMAGE_THUNK_DATA * iat_thunk)160 DWORD RestoreImportedFunction(void* intercept_function,
161                               void* original_function,
162                               IMAGE_THUNK_DATA* iat_thunk) {
163   if (!intercept_function || !original_function || !iat_thunk) {
164     NOTREACHED();
165     return ERROR_INVALID_PARAMETER;
166   }
167 
168   if (GetIATFunction(iat_thunk) != intercept_function) {
169     // Check if someone else has intercepted on top of us.
170     // We cannot unpatch in this case, just raise a red flag.
171     NOTREACHED();
172     return ERROR_INVALID_FUNCTION;
173   }
174 
175   return internal::ModifyCode(&(iat_thunk->u1.Function), &original_function,
176                               sizeof(original_function));
177 }
178 
179 }  // namespace
180 
// Creates an unpatched object. PatchFromModule() DCHECKs that the patch-state
// members start out null; presumably they have default member initializers
// in the header (NOTE(review): confirm).
IATPatchFunction::IATPatchFunction() = default;
182 
~IATPatchFunction()183 IATPatchFunction::~IATPatchFunction() {
184   if (intercept_function_) {
185     DWORD error = Unpatch();
186     DCHECK_EQ(static_cast<DWORD>(NO_ERROR), error);
187   }
188 }
189 
Patch(const wchar_t * module,const char * imported_from_module,const char * function_name,void * new_function)190 DWORD IATPatchFunction::Patch(const wchar_t* module,
191                               const char* imported_from_module,
192                               const char* function_name,
193                               void* new_function) {
194   HMODULE module_handle = LoadLibraryW(module);
195   if (!module_handle) {
196     NOTREACHED();
197     return GetLastError();
198   }
199 
200   DWORD error = PatchFromModule(module_handle, imported_from_module,
201                                 function_name, new_function);
202   if (NO_ERROR == error) {
203     module_handle_ = module_handle;
204   } else {
205     FreeLibrary(module_handle);
206   }
207 
208   return error;
209 }
210 
PatchFromModule(HMODULE module,const char * imported_from_module,const char * function_name,void * new_function)211 DWORD IATPatchFunction::PatchFromModule(HMODULE module,
212                                         const char* imported_from_module,
213                                         const char* function_name,
214                                         void* new_function) {
215   DCHECK_EQ(nullptr, original_function_);
216   DCHECK_EQ(nullptr, iat_thunk_);
217   DCHECK_EQ(nullptr, intercept_function_);
218   DCHECK(module);
219 
220   DWORD error =
221       InterceptImportedFunction(module, imported_from_module, function_name,
222                                 new_function, &original_function_, &iat_thunk_);
223 
224   if (NO_ERROR == error) {
225     DCHECK_NE(original_function_, intercept_function_);
226     intercept_function_ = new_function;
227   }
228 
229   return error;
230 }
231 
Unpatch()232 DWORD IATPatchFunction::Unpatch() {
233   DWORD error = RestoreImportedFunction(intercept_function_, original_function_,
234                                         iat_thunk_);
235   DCHECK_EQ(static_cast<DWORD>(NO_ERROR), error);
236 
237   // Hands off the intercept if we fail to unpatch.
238   // If IATPatchFunction::Unpatch fails during RestoreImportedFunction
239   // it means that we cannot safely unpatch the import address table
240   // patch. In this case its better to be hands off the intercept as
241   // trying to unpatch again in the destructor of IATPatchFunction is
242   // not going to be any safer
243   if (module_handle_)
244     FreeLibrary(module_handle_);
245   module_handle_ = nullptr;
246   intercept_function_ = nullptr;
247   original_function_ = nullptr;
248   iat_thunk_ = nullptr;
249 
250   return error;
251 }
252 
// Returns the function pointer that occupied the IAT slot before patching.
// Must only be called while patched (DCHECKed).
void* IATPatchFunction::original_function() const {
  DCHECK(is_patched());
  return original_function_;
}
257 
258 }  // namespace win
259 }  // namespace base
260