1 // Copyright 2017 The Chromium Authors 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "base/win/com_init_util.h" 6 7 #include <stdint.h> 8 #include <windows.h> 9 #include <winternl.h> 10 11 #include "base/logging.h" 12 #include "base/notreached.h" 13 14 namespace base { 15 namespace win { 16 17 namespace { 18 19 #if DCHECK_IS_ON() 20 const char kComNotInitialized[] = "COM is not initialized on this thread."; 21 #endif // DCHECK_IS_ON() 22 23 // Derived from combase.dll. 24 struct OleTlsData { 25 enum ApartmentFlags { 26 LOGICAL_THREAD_REGISTERED = 0x2, 27 STA = 0x80, 28 MTA = 0x140, 29 }; 30 31 uintptr_t thread_base; 32 uintptr_t sm_allocator; 33 DWORD apartment_id; 34 DWORD apartment_flags; 35 // There are many more fields than this, but for our purposes, we only care 36 // about |apartment_flags|. Correctly declaring the previous types allows this 37 // to work between x86 and x64 builds. 38 }; 39 GetOleTlsData()40OleTlsData* GetOleTlsData() { 41 TEB* teb = NtCurrentTeb(); 42 return reinterpret_cast<OleTlsData*>(teb->ReservedForOle); 43 } 44 45 } // namespace 46 GetComApartmentTypeForThread()47ComApartmentType GetComApartmentTypeForThread() { 48 OleTlsData* ole_tls_data = GetOleTlsData(); 49 if (!ole_tls_data) 50 return ComApartmentType::NONE; 51 52 if (ole_tls_data->apartment_flags & OleTlsData::ApartmentFlags::STA) 53 return ComApartmentType::STA; 54 55 if ((ole_tls_data->apartment_flags & OleTlsData::ApartmentFlags::MTA) == 56 OleTlsData::ApartmentFlags::MTA) { 57 return ComApartmentType::MTA; 58 } 59 60 return ComApartmentType::NONE; 61 } 62 63 #if DCHECK_IS_ON() 64 AssertComInitialized(const char * message)65void AssertComInitialized(const char* message) { 66 if (GetComApartmentTypeForThread() != ComApartmentType::NONE) 67 return; 68 69 // COM worker threads don't always set up the apartment, but they do perform 70 // some thread registration, so we allow those. 71 OleTlsData* ole_tls_data = GetOleTlsData(); 72 if (ole_tls_data && (ole_tls_data->apartment_flags & 73 OleTlsData::ApartmentFlags::LOGICAL_THREAD_REGISTERED)) { 74 return; 75 } 76 77 NOTREACHED() << (message ? message : kComNotInitialized); 78 } 79 AssertComApartmentType(ComApartmentType apartment_type)80void AssertComApartmentType(ComApartmentType apartment_type) { 81 DCHECK_EQ(apartment_type, GetComApartmentTypeForThread()); 82 } 83 84 #endif // DCHECK_IS_ON() 85 86 } // namespace win 87 } // namespace base 88