1 // Copyright 2017 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "base/win/com_init_check_hook.h"
6
7 #include <objbase.h>
8 #include <shlobj.h>
9 #include <wrl/client.h>
10
11 #include "base/test/gtest_util.h"
12 #include "base/win/com_init_util.h"
13 #include "base/win/patch_util.h"
14 #include "base/win/scoped_com_initializer.h"
15 #include "testing/gtest/include/gtest/gtest.h"
16
17 namespace base {
18 namespace win {
19
20 using Microsoft::WRL::ComPtr;
21
// Verifies that, with the hook installed on a thread where COM has not been
// initialized, CoCreateInstance is caught by a DCHECK (hook builds) or fails
// with CO_E_NOTINITIALIZED (non-hook builds).
TEST(ComInitCheckHook, AssertNotInitialized) {
  ComInitCheckHook hook;
  // COM is still uninitialized on this thread after installing the hook.
  AssertComApartmentType(ComApartmentType::NONE);

  ComPtr<IUnknown> link;
#if defined(COM_INIT_CHECK_HOOK_ENABLED)
  EXPECT_DCHECK_DEATH(::CoCreateInstance(CLSID_ShellLink, nullptr, CLSCTX_ALL,
                                         IID_PPV_ARGS(&link)));
#else
  // Without the hook, COM reports the missing initialization via HRESULT.
  const HRESULT hr = ::CoCreateInstance(CLSID_ShellLink, nullptr, CLSCTX_ALL,
                                        IID_PPV_ARGS(&link));
  EXPECT_EQ(CO_E_NOTINITIALIZED, hr);
#endif
}
35
// Verifies that once a hook has been installed and destroyed, CoCreateInstance
// on a COM-uninitialized thread simply returns CO_E_NOTINITIALIZED.
TEST(ComInitCheckHook, HookRemoval) {
  AssertComApartmentType(ComApartmentType::NONE);
  {
    // Install the hook and remove it again at the end of this scope.
    ComInitCheckHook scoped_hook;
  }

  ComPtr<IUnknown> link;
  const HRESULT result = ::CoCreateInstance(CLSID_ShellLink, nullptr,
                                            CLSCTX_ALL, IID_PPV_ARGS(&link));
  EXPECT_EQ(CO_E_NOTINITIALIZED, result);
}
44
// Verifies that the hook lets CoCreateInstance through when COM has been
// initialized on the calling thread.
TEST(ComInitCheckHook, NoAssertComInitialized) {
  ComInitCheckHook com_check_hook;
  ScopedCOMInitializer com_initializer;
  // Fail early if COM initialization itself failed; otherwise the call below
  // would exercise the uninitialized path instead of the one under test.
  ASSERT_TRUE(com_initializer.Succeeded());

  ComPtr<IUnknown> shell_link;
  // EXPECT_HRESULT_SUCCEEDED prints the failing HRESULT, unlike
  // EXPECT_TRUE(SUCCEEDED(...)).
  EXPECT_HRESULT_SUCCEEDED(::CoCreateInstance(CLSID_ShellLink, nullptr,
                                              CLSCTX_ALL,
                                              IID_PPV_ARGS(&shell_link)));
}
52
// Verifies that stacking two hooks behaves the same as installing one: COM
// stays uninitialized and CoCreateInstance is still checked.
TEST(ComInitCheckHook, MultipleHooks) {
  ComInitCheckHook first_hook;
  ComInitCheckHook second_hook;
  AssertComApartmentType(ComApartmentType::NONE);

  ComPtr<IUnknown> link;
#if defined(COM_INIT_CHECK_HOOK_ENABLED)
  EXPECT_DCHECK_DEATH(::CoCreateInstance(CLSID_ShellLink, nullptr, CLSCTX_ALL,
                                         IID_PPV_ARGS(&link)));
#else
  const HRESULT hr = ::CoCreateInstance(CLSID_ShellLink, nullptr, CLSCTX_ALL,
                                        IID_PPV_ARGS(&link));
  EXPECT_EQ(CO_E_NOTINITIALIZED, hr);
#endif
}
67
// Verifies that installing the hook DCHECKs when the padding bytes in front of
// CoCreateInstance contain something unexpected.
TEST(ComInitCheckHook, UnexpectedHook) {
#if defined(COM_INIT_CHECK_HOOK_ENABLED)
  const HMODULE ole32_library = ::LoadLibrary(L"ole32.dll");
  ASSERT_TRUE(ole32_library);

  // Fail the test cleanly if the export cannot be resolved; otherwise the
  // arithmetic and read below would fault on a bogus address.
  const uint32_t co_create_instance_address = reinterpret_cast<uint32_t>(
      GetProcAddress(ole32_library, "CoCreateInstance"));
  ASSERT_NE(0u, co_create_instance_address);

  // Target the 5 padding bytes immediately preceding CoCreateInstance.
  const uint32_t co_create_instance_padded_address =
      co_create_instance_address - 5;
  const unsigned char* co_create_instance_bytes =
      reinterpret_cast<const unsigned char*>(co_create_instance_padded_address);
  const unsigned char original_byte = co_create_instance_bytes[0];
  // Arbitrary byte that is not expected to match any padding the hook accepts.
  const unsigned char unexpected_byte = 0xdb;
  ASSERT_EQ(static_cast<DWORD>(NO_ERROR),
            internal::ModifyCode(
                reinterpret_cast<void*>(co_create_instance_padded_address),
                reinterpret_cast<const void*>(&unexpected_byte),
                sizeof(unexpected_byte)));

  EXPECT_DCHECK_DEATH({ ComInitCheckHook com_check_hook; });

  // If this call fails, really bad things are going to happen to other tests
  // so CHECK here.
  CHECK_EQ(static_cast<DWORD>(NO_ERROR),
           internal::ModifyCode(
               reinterpret_cast<void*>(co_create_instance_padded_address),
               reinterpret_cast<const void*>(&original_byte),
               sizeof(original_byte)));

  ::FreeLibrary(ole32_library);
#endif
}
101
// Verifies that installing the hook DCHECKs when CoCreateInstance looks as if
// it was already patched by some other component.
TEST(ComInitCheckHook, ExternallyHooked) {
#if defined(COM_INIT_CHECK_HOOK_ENABLED)
  const HMODULE ole32_library = ::LoadLibrary(L"ole32.dll");
  ASSERT_TRUE(ole32_library);

  // Fail the test cleanly if the export cannot be resolved; otherwise the read
  // below would fault on a null address.
  const uint32_t co_create_instance_address = reinterpret_cast<uint32_t>(
      GetProcAddress(ole32_library, "CoCreateInstance"));
  ASSERT_NE(0u, co_create_instance_address);

  // Overwrite the first byte of CoCreateInstance itself with a JMP opcode.
  const unsigned char* co_create_instance_bytes =
      reinterpret_cast<const unsigned char*>(co_create_instance_address);
  const unsigned char original_byte = co_create_instance_bytes[0];
  const unsigned char jmp_byte = 0xe9;
  ASSERT_EQ(static_cast<DWORD>(NO_ERROR),
            internal::ModifyCode(
                reinterpret_cast<void*>(co_create_instance_address),
                reinterpret_cast<const void*>(&jmp_byte), sizeof(jmp_byte)));

  // Externally patched instances should crash so we catch these cases on bots.
  EXPECT_DCHECK_DEATH({ ComInitCheckHook com_check_hook; });

  // If this call fails, really bad things are going to happen to other tests
  // so CHECK here.
  CHECK_EQ(
      static_cast<DWORD>(NO_ERROR),
      internal::ModifyCode(reinterpret_cast<void*>(co_create_instance_address),
                           reinterpret_cast<const void*>(&original_byte),
                           sizeof(original_byte)));

  ::FreeLibrary(ole32_library);
#endif
}
133
// Verifies that the hook DCHECKs when the bytes it patched are modified by
// someone else while the hook is installed.
TEST(ComInitCheckHook, UnexpectedChangeDuringHook) {
#if defined(COM_INIT_CHECK_HOOK_ENABLED)
  const HMODULE ole32_library = ::LoadLibrary(L"ole32.dll");
  ASSERT_TRUE(ole32_library);

  // Fail the test cleanly if the export cannot be resolved; otherwise the
  // arithmetic and read below would fault on a bogus address.
  const uint32_t co_create_instance_address = reinterpret_cast<uint32_t>(
      GetProcAddress(ole32_library, "CoCreateInstance"));
  ASSERT_NE(0u, co_create_instance_address);

  // Target the 5 padding bytes immediately preceding CoCreateInstance.
  const uint32_t co_create_instance_padded_address =
      co_create_instance_address - 5;
  const unsigned char* co_create_instance_bytes =
      reinterpret_cast<const unsigned char*>(co_create_instance_padded_address);
  const unsigned char original_byte = co_create_instance_bytes[0];
  const unsigned char unexpected_byte = 0xdb;

  // Deliberately do NOT corrupt the padding before installing the hook: with
  // unexpected bytes already present, constructing the hook DCHECKs on its own
  // and the in-hook modification below never runs. Instead, install the hook
  // on pristine bytes and tamper with them while it is live; the hook is
  // expected to notice (presumably when it reverts its patch at scope exit).
  EXPECT_DCHECK_DEATH({
    ComInitCheckHook com_check_hook;

    internal::ModifyCode(
        reinterpret_cast<void*>(co_create_instance_padded_address),
        reinterpret_cast<const void*>(&unexpected_byte),
        sizeof(unexpected_byte));
  });

  // Death-test statements run in a child process, so this should be a no-op,
  // but restore the original byte defensively. If this call fails, really bad
  // things are going to happen to other tests so CHECK here.
  CHECK_EQ(static_cast<DWORD>(NO_ERROR),
           internal::ModifyCode(
               reinterpret_cast<void*>(co_create_instance_padded_address),
               reinterpret_cast<const void*>(&original_byte),
               sizeof(original_byte)));

  ::FreeLibrary(ole32_library);
#endif
}
174
175 } // namespace win
176 } // namespace base
177