• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2020 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
#include "Assert.hpp"

#include <armnn/Exceptions.hpp>

#include <memory>
#include <type_traits>
#include <typeinfo>
14 
15 namespace armnn
16 {
17 
// How a failed cast check is reported: when testing (ARMNN_POLYMORPHIC_CAST_TESTABLE) the
// check goes through ConditionalThrow<std::bad_cast> so tests can observe it as an exception;
// otherwise it is a regular ARMNN_ASSERT.
#if defined(ARMNN_POLYMORPHIC_CAST_TESTABLE)
#   define ARMNN_POLYMORPHIC_CAST_CHECK_METHOD(cond) ConditionalThrow<std::bad_cast>(cond)
#else
#   define ARMNN_POLYMORPHIC_CAST_CHECK_METHOD(cond) ARMNN_ASSERT(cond)
#endif

// Only check the condition in debug builds or during testing.
#if !defined(NDEBUG) || defined(ARMNN_POLYMORPHIC_CAST_TESTABLE)
#   define ARMNN_POLYMORPHIC_CAST_CHECK(cond) ARMNN_POLYMORPHIC_CAST_CHECK_METHOD(cond)
#else
// Release builds don't check the cast and never evaluate `cond`. Expand to a void expression
// (as the standard assert macro does under NDEBUG) rather than to nothing, so that
// `ARMNN_POLYMORPHIC_CAST_CHECK(x);` is always a well-formed expression statement and does not
// produce an empty `if` body (-Wempty-body) when used unbraced after an `if`.
#   define ARMNN_POLYMORPHIC_CAST_CHECK(cond) ((void)0)
#endif
31 
32 
namespace utility
{

// ---------------------------------------------------------------------------------------------
// Unchecked casts
// ---------------------------------------------------------------------------------------------

/// Unchecked cast of a std::shared_ptr: forwards to std::static_pointer_cast.
/// The result shares ownership with `source`.
template <typename Target, typename Source>
auto StaticPointerCast(const std::shared_ptr<Source>& source) -> std::shared_ptr<Target>
{
    return std::static_pointer_cast<Target>(source);
}

/// Unchecked cast of a built-in pointer: equivalent to static_cast<Target*>(source).
template <typename Target, typename Source>
inline auto StaticPointerCast(Source* source) -> Target*
{
    return static_cast<Target*>(source);
}

// ---------------------------------------------------------------------------------------------
// Runtime-checked casts
// ---------------------------------------------------------------------------------------------

/// Runtime-checked cast of a std::shared_ptr: forwards to std::dynamic_pointer_cast, which
/// yields an empty shared_ptr when the pointee is not a Target.
template <typename Target, typename Source>
auto DynamicPointerCast(const std::shared_ptr<Source>& source) -> std::shared_ptr<Target>
{
    return std::dynamic_pointer_cast<Target>(source);
}

/// Runtime-checked cast of a built-in pointer: equivalent to dynamic_cast<Target*>(source),
/// which yields nullptr when the pointee is not a Target.
template <typename Target, typename Source>
inline auto DynamicPointerCast(Source* source) -> Target*
{
    return dynamic_cast<Target*>(source);
}

} // namespace utility
64 
/// Downcast a built-in (raw) pointer from a base type to a derived type.
///
/// Usage: Child* pChild = PolymorphicDowncast<Child*>(pBase);
///
/// The returned pointer is always produced by static_cast. In debug builds, or when
/// ARMNN_POLYMORPHIC_CAST_TESTABLE is defined, the cast is first cross-checked with dynamic_cast
/// through ARMNN_POLYMORPHIC_CAST_CHECK; in other builds no runtime check is made.
///
/// \tparam DestType    Pointer type of the target (Child pointer type); must be a pointer type
/// \tparam SourceType  Pointee type of the argument (Base type); deduced from the argument
/// \param base         Pointer to the source object
/// \return             `base` converted to DestType
template<typename DestType, typename SourceType>
auto PolymorphicDowncast(SourceType* base) -> DestType
{
    static_assert(std::is_pointer<DestType>::value,
                  "PolymorphicDowncast only works with pointer types.");

    // A correct downcast round-trips to the same address; a wrong runtime type makes
    // dynamic_cast yield nullptr, which differs from a non-null `base`.
    ARMNN_POLYMORPHIC_CAST_CHECK(base == dynamic_cast<DestType>(base));

    return static_cast<DestType>(base);
}
82 
83 
84 /// Polymorphic downcast for shared pointers and build in pointers
85 ///
86 /// Usage: auto pChild = PolymorphicPointerDowncast<Child>(pBase)
87 ///
88 /// \tparam DestType    Type of the target object (Child type)
89 /// \tparam SourceType  Pointer type to the source object (Base (shared) pointer type)
90 /// \param value        Pointer to the source object
91 /// \return             Pointer of type DestType ((Shared) pointer of type child)
92 template<typename DestType, typename SourceType>
PolymorphicPointerDowncast(const SourceType & value)93 auto PolymorphicPointerDowncast(const SourceType& value)
94 {
95     ARMNN_POLYMORPHIC_CAST_CHECK(utility::DynamicPointerCast<DestType>(value)
96                                  == value);
97     return utility::StaticPointerCast<DestType>(value);
98 }
99 
100 } //namespace armnn
101