1 // 2 // Copyright (C) Microsoft Corporation 3 // All rights reserved. 4 // Modified for native C++ WRL support by Gregory Morse 5 // 6 // Code in Details namespace is for internal usage within the library code 7 // 8 9 #ifndef _PLATFORM_AGILE_H_ 10 #define _PLATFORM_AGILE_H_ 11 12 #ifdef _MSC_VER 13 #pragma once 14 #endif // _MSC_VER 15 16 #include <algorithm> 17 #include <wrl\client.h> 18 19 template <typename T, bool TIsNotAgile> class Agile; 20 21 template <typename T> 22 struct UnwrapAgile 23 { 24 static const bool _IsAgile = false; 25 }; 26 template <typename T> 27 struct UnwrapAgile<Agile<T, false>> 28 { 29 static const bool _IsAgile = true; 30 }; 31 template <typename T> 32 struct UnwrapAgile<Agile<T, true>> 33 { 34 static const bool _IsAgile = true; 35 }; 36 37 #define IS_AGILE(T) UnwrapAgile<T>::_IsAgile 38 39 #define __is_winrt_agile(T) (std::is_same<T, HSTRING__>::value || std::is_base_of<Microsoft::WRL::FtmBase, T>::value || std::is_base_of<IAgileObject, T>::value) //derived from Microsoft::WRL::FtmBase or IAgileObject 40 41 #define __is_win_interface(T) (std::is_base_of<IUnknown, T>::value || std::is_base_of<IInspectable, T>::value) //derived from IUnknown or IInspectable 42 43 #define __is_win_class(T) (std::is_same<T, HSTRING__>::value || std::is_base_of<Microsoft::WRL::Details::RuntimeClassBase, T>::value) //derived from Microsoft::WRL::RuntimeClass or HSTRING 44 45 namespace Details 46 { 47 IUnknown* __stdcall GetObjectContext(); 48 HRESULT __stdcall GetProxyImpl(IUnknown*, REFIID, IUnknown*, IUnknown**); 49 HRESULT __stdcall ReleaseInContextImpl(IUnknown*, IUnknown*); 50 51 template <typename T> 52 #if _MSC_VER >= 1800 53 __declspec(no_refcount) inline HRESULT GetProxy(T *ObjectIn, IUnknown *ContextCallBack, T **Proxy) 54 #else 55 inline HRESULT GetProxy(T *ObjectIn, IUnknown *ContextCallBack, T **Proxy) 56 #endif 57 { 58 #if _MSC_VER >= 1800 59 return GetProxyImpl(*reinterpret_cast<IUnknown**>(&ObjectIn), __uuidof(T*), ContextCallBack, reinterpret_cast<IUnknown**>(Proxy)); 60 #else 61 return GetProxyImpl(*reinterpret_cast<IUnknown**>(&const_cast<T*>(ObjectIn)), __uuidof(T*), ContextCallBack, reinterpret_cast<IUnknown**>(Proxy)); 62 #endif 63 } 64 65 template <typename T> 66 inline HRESULT ReleaseInContext(T *ObjectIn, IUnknown *ContextCallBack) 67 { 68 return ReleaseInContextImpl(ObjectIn, ContextCallBack); 69 } 70 71 template <typename T> 72 class AgileHelper 73 { 74 __abi_IUnknown* _p; 75 bool _release; 76 public: 77 AgileHelper(__abi_IUnknown* p, bool release = true) : _p(p), _release(release) 78 { 79 } 80 AgileHelper(AgileHelper&& other) : _p(other._p), _release(other._release) 81 { 82 _other._p = nullptr; 83 _other._release = true; 84 } 85 AgileHelper operator=(AgileHelper&& other) 86 { 87 _p = other._p; 88 _release = other._release; 89 _other._p = nullptr; 90 _other._release = true; 91 return *this; 92 } 93 94 ~AgileHelper() 95 { 96 if (_release && _p) 97 { 98 _p->__abi_Release(); 99 } 100 } 101 102 __declspec(no_refcount) __declspec(no_release_return) 103 T* operator->() 104 { 105 return reinterpret_cast<T*>(_p); 106 } 107 108 __declspec(no_refcount) __declspec(no_release_return) 109 operator T * () 110 { 111 return reinterpret_cast<T*>(_p); 112 } 113 private: 114 AgileHelper(const AgileHelper&); 115 AgileHelper operator=(const AgileHelper&); 116 }; 117 template <typename T> 118 struct __remove_hat 119 { 120 typedef T type; 121 }; 122 template <typename T> 123 struct __remove_hat<T*> 124 { 125 typedef T type; 126 }; 127 template <typename T> 128 struct AgileTypeHelper 129 { 130 typename typedef __remove_hat<T>::type type; 131 typename typedef __remove_hat<T>::type* agileMemberType; 132 }; 133 } // namespace Details 134 135 #pragma warning(push) 136 #pragma warning(disable: 4451) // Usage of ref class inside this context can lead to invalid marshaling of object across contexts 137 138 template < 139 typename T, 140 bool TIsNotAgile = (__is_win_class(typename Details::AgileTypeHelper<T>::type) && !__is_winrt_agile(typename Details::AgileTypeHelper<T>::type)) || 141 __is_win_interface(typename Details::AgileTypeHelper<T>::type) 142 > 143 class Agile 144 { 145 static_assert(__is_win_class(typename Details::AgileTypeHelper<T>::type) || __is_win_interface(typename Details::AgileTypeHelper<T>::type), "Agile can only be used with ref class or interface class types"); 146 typename typedef Details::AgileTypeHelper<T>::agileMemberType TypeT; 147 TypeT _object; 148 ::Microsoft::WRL::ComPtr<IUnknown> _contextCallback; 149 ULONG_PTR _contextToken; 150 151 #if _MSC_VER >= 1800 152 enum class AgileState 153 { 154 NonAgilePointer = 0, 155 AgilePointer = 1, 156 Unknown = 2 157 }; 158 AgileState _agileState; 159 #endif 160 161 void CaptureContext() 162 { 163 _contextCallback = Details::GetObjectContext(); 164 __abi_ThrowIfFailed(CoGetContextToken(&_contextToken)); 165 } 166 167 void SetObject(TypeT object) 168 { 169 // Capture context before setting the pointer 170 // If context capture fails then nothing to cleanup 171 Release(); 172 if (object != nullptr) 173 { 174 ::Microsoft::WRL::ComPtr<IAgileObject> checkIfAgile; 175 HRESULT hr = reinterpret_cast<IUnknown*>(object)->QueryInterface(__uuidof(IAgileObject), &checkIfAgile); 176 // Don't Capture context if object is agile 177 if (hr != S_OK) 178 { 179 #if _MSC_VER >= 1800 180 _agileState = AgileState::NonAgilePointer; 181 #endif 182 CaptureContext(); 183 } 184 #if _MSC_VER >= 1800 185 else 186 { 187 _agileState = AgileState::AgilePointer; 188 } 189 #endif 190 } 191 _object = object; 192 } 193 194 public: 195 Agile() throw() : _object(nullptr), _contextToken(0) 196 #if _MSC_VER >= 1800 197 , _agileState(AgileState::Unknown) 198 #endif 199 { 200 } 201 202 Agile(nullptr_t) throw() : _object(nullptr), _contextToken(0) 203 #if _MSC_VER >= 1800 204 , _agileState(AgileState::Unknown) 205 #endif 206 { 207 } 208 209 explicit Agile(TypeT object) throw() : _object(nullptr), _contextToken(0) 210 #if _MSC_VER >= 1800 211 , _agileState(AgileState::Unknown) 212 #endif 213 { 214 // Assumes that the source object is from the current context 215 SetObject(object); 216 } 217 218 Agile(const Agile& object) throw() : _object(nullptr), _contextToken(0) 219 #if _MSC_VER >= 1800 220 , _agileState(AgileState::Unknown) 221 #endif 222 { 223 // Get returns pointer valid for current context 224 SetObject(object.Get()); 225 } 226 227 Agile(Agile&& object) throw() : _object(nullptr), _contextToken(0) 228 #if _MSC_VER >= 1800 229 , _agileState(AgileState::Unknown) 230 #endif 231 { 232 // Assumes that the source object is from the current context 233 Swap(object); 234 } 235 236 ~Agile() throw() 237 { 238 Release(); 239 } 240 241 TypeT Get() const 242 { 243 // Agile object, no proxy required 244 #if _MSC_VER >= 1800 245 if (_agileState == AgileState::AgilePointer || _object == nullptr) 246 #else 247 if (_contextToken == 0 || _contextCallback == nullptr || _object == nullptr) 248 #endif 249 { 250 return _object; 251 } 252 253 // Do the check for same context 254 ULONG_PTR currentContextToken; 255 __abi_ThrowIfFailed(CoGetContextToken(¤tContextToken)); 256 if (currentContextToken == _contextToken) 257 { 258 return _object; 259 } 260 261 #if _MSC_VER >= 1800 262 // Different context and holding on to a non agile object 263 // Do the costly work of getting a proxy 264 TypeT localObject; 265 __abi_ThrowIfFailed(Details::GetProxy(_object, _contextCallback.Get(), &localObject)); 266 267 if (_agileState == AgileState::Unknown) 268 #else 269 // Object is agile if it implements IAgileObject 270 // GetAddressOf captures the context with out knowing the type of object that it will hold 271 if (_object != nullptr) 272 #endif 273 { 274 #if _MSC_VER >= 1800 275 // Object is agile if it implements IAgileObject 276 // GetAddressOf captures the context with out knowing the type of object that it will hold 277 ::Microsoft::WRL::ComPtr<IAgileObject> checkIfAgile; 278 HRESULT hr = reinterpret_cast<IUnknown*>(localObject)->QueryInterface(__uuidof(IAgileObject), &checkIfAgile); 279 #else 280 ::Microsoft::WRL::ComPtr<IAgileObject> checkIfAgile; 281 HRESULT hr = reinterpret_cast<IUnknown*>(_object)->QueryInterface(__uuidof(IAgileObject), &checkIfAgile); 282 #endif 283 if (hr == S_OK) 284 { 285 auto pThis = const_cast<Agile*>(this); 286 #if _MSC_VER >= 1800 287 pThis->_agileState = AgileState::AgilePointer; 288 #endif 289 pThis->_contextToken = 0; 290 pThis->_contextCallback = nullptr; 291 return _object; 292 } 293 #if _MSC_VER >= 1800 294 else 295 { 296 auto pThis = const_cast<Agile*>(this); 297 pThis->_agileState = AgileState::NonAgilePointer; 298 } 299 #endif 300 } 301 302 #if _MSC_VER < 1800 303 // Different context and holding on to a non agile object 304 // Do the costly work of getting a proxy 305 TypeT localObject; 306 __abi_ThrowIfFailed(Details::GetProxy(_object, _contextCallback.Get(), &localObject)); 307 #endif 308 return localObject; 309 } 310 311 TypeT* GetAddressOf() throw() 312 { 313 Release(); 314 CaptureContext(); 315 return &_object; 316 } 317 318 TypeT* GetAddressOfForInOut() throw() 319 { 320 CaptureContext(); 321 return &_object; 322 } 323 324 TypeT operator->() const throw() 325 { 326 return Get(); 327 } 328 329 Agile& operator=(nullptr_t) throw() 330 { 331 Release(); 332 return *this; 333 } 334 335 Agile& operator=(TypeT object) throw() 336 { 337 Agile(object).Swap(*this); 338 return *this; 339 } 340 341 Agile& operator=(Agile object) throw() 342 { 343 // parameter is by copy which gets pointer valid for current context 344 object.Swap(*this); 345 return *this; 346 } 347 348 #if _MSC_VER < 1800 349 Agile& operator=(IUnknown* lp) throw() 350 { 351 // bump ref count 352 ::Microsoft::WRL::ComPtr<IUnknown> spObject(lp); 353 354 // put it into Platform Object 355 Platform::Object object; 356 *(IUnknown**)(&object) = spObject.Detach(); 357 358 SetObject(object); 359 return *this; 360 } 361 #endif 362 363 void Swap(Agile& object) 364 { 365 std::swap(_object, object._object); 366 std::swap(_contextCallback, object._contextCallback); 367 std::swap(_contextToken, object._contextToken); 368 #if _MSC_VER >= 1800 369 std::swap(_agileState, object._agileState); 370 #endif 371 } 372 373 // Release the interface and set to NULL 374 void Release() throw() 375 { 376 if (_object) 377 { 378 // Cast to IInspectable (no QI) 379 IUnknown* pObject = *(IUnknown**)(&_object); 380 // Set * to null without release 381 *(IUnknown**)(&_object) = nullptr; 382 383 ULONG_PTR currentContextToken; 384 __abi_ThrowIfFailed(CoGetContextToken(¤tContextToken)); 385 if (_contextToken == 0 || _contextCallback == nullptr || _contextToken == currentContextToken) 386 { 387 pObject->Release(); 388 } 389 else 390 { 391 Details::ReleaseInContext(pObject, _contextCallback.Get()); 392 } 393 _contextCallback = nullptr; 394 _contextToken = 0; 395 #if _MSC_VER >= 1800 396 _agileState = AgileState::Unknown; 397 #endif 398 } 399 } 400 401 bool operator==(nullptr_t) const throw() 402 { 403 return _object == nullptr; 404 } 405 406 bool operator==(const Agile& other) const throw() 407 { 408 return _object == other._object && _contextToken == other._contextToken; 409 } 410 411 bool operator<(const Agile& other) const throw() 412 { 413 if (reinterpret_cast<void*>(_object) < reinterpret_cast<void*>(other._object)) 414 { 415 return true; 416 } 417 418 return _object == other._object && _contextToken < other._contextToken; 419 } 420 }; 421 422 template <typename T> 423 class Agile<T, false> 424 { 425 static_assert(__is_win_class(typename Details::AgileTypeHelper<T>::type) || __is_win_interface(typename Details::AgileTypeHelper<T>::type), "Agile can only be used with ref class or interface class types"); 426 typename typedef Details::AgileTypeHelper<T>::agileMemberType TypeT; 427 TypeT _object; 428 429 public: 430 Agile() throw() : _object(nullptr) 431 { 432 } 433 434 Agile(nullptr_t) throw() : _object(nullptr) 435 { 436 } 437 438 explicit Agile(TypeT object) throw() : _object(object) 439 { 440 } 441 442 Agile(const Agile& object) throw() : _object(object._object) 443 { 444 } 445 446 Agile(Agile&& object) throw() : _object(nullptr) 447 { 448 Swap(object); 449 } 450 451 ~Agile() throw() 452 { 453 Release(); 454 } 455 456 TypeT Get() const 457 { 458 return _object; 459 } 460 461 TypeT* GetAddressOf() throw() 462 { 463 Release(); 464 return &_object; 465 } 466 467 TypeT* GetAddressOfForInOut() throw() 468 { 469 return &_object; 470 } 471 472 TypeT operator->() const throw() 473 { 474 return Get(); 475 } 476 477 Agile& operator=(nullptr_t) throw() 478 { 479 Release(); 480 return *this; 481 } 482 483 Agile& operator=(TypeT object) throw() 484 { 485 if (_object != object) 486 { 487 _object = object; 488 } 489 return *this; 490 } 491 492 Agile& operator=(Agile object) throw() 493 { 494 object.Swap(*this); 495 return *this; 496 } 497 498 #if _MSC_VER < 1800 499 Agile& operator=(IUnknown* lp) throw() 500 { 501 Release(); 502 // bump ref count 503 ::Microsoft::WRL::ComPtr<IUnknown> spObject(lp); 504 505 // put it into Platform Object 506 Platform::Object object; 507 *(IUnknown**)(&object) = spObject.Detach(); 508 509 _object = object; 510 return *this; 511 } 512 #endif 513 514 // Release the interface and set to NULL 515 void Release() throw() 516 { 517 _object = nullptr; 518 } 519 520 void Swap(Agile& object) 521 { 522 std::swap(_object, object._object); 523 } 524 525 bool operator==(nullptr_t) const throw() 526 { 527 return _object == nullptr; 528 } 529 530 bool operator==(const Agile& other) const throw() 531 { 532 return _object == other._object; 533 } 534 535 bool operator<(const Agile& other) const throw() 536 { 537 return reinterpret_cast<void*>(_object) < reinterpret_cast<void*>(other._object); 538 } 539 }; 540 541 #pragma warning(pop) 542 543 template<class U> 544 bool operator==(nullptr_t, const Agile<U>& a) throw() 545 { 546 return a == nullptr; 547 } 548 549 template<class U> 550 bool operator!=(const Agile<U>& a, nullptr_t) throw() 551 { 552 return !(a == nullptr); 553 } 554 555 template<class U> 556 bool operator!=(nullptr_t, const Agile<U>& a) throw() 557 { 558 return !(a == nullptr); 559 } 560 561 template<class U> 562 bool operator!=(const Agile<U>& a, const Agile<U>& b) throw() 563 { 564 return !(a == b); 565 } 566 567 568 #endif // _PLATFORM_AGILE_H_ 569