//
// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "ITensorHandle.hpp"

#include <armnn/TypesUtils.hpp>
#include <armnn/utility/Assert.hpp>
#include <armnnUtils/CompatibleTypes.hpp>

#include <algorithm>
#include <memory>

namespace armnn
{

// Get a TensorShape representing the strides (in bytes) for each dimension
// of a tensor, assuming fully packed data with no padding.
TensorShape GetUnpaddedTensorStrides(const TensorInfo& tensorInfo);

// Abstract tensor handle wrapping a readable region of memory, interpreting it as tensor data.
class ConstTensorHandle : public ITensorHandle
{
public:
    template <typename T>
    const T* GetConstTensor() const
    {
        if (armnnUtils::CompatibleTypes<T>(GetTensorInfo().GetDataType()))
        {
            return reinterpret_cast<const T*>(m_Memory);
        }
        else
        {
            throw armnn::Exception("Attempting to get not compatible type tensor!");
        }
    }

    const TensorInfo& GetTensorInfo() const
    {
        return m_TensorInfo;
    }

    virtual void Manage() override {}

    virtual ITensorHandle* GetParent() const override { return nullptr; }

    virtual const void* Map(bool /* blocking = true */) const override { return m_Memory; }
    virtual void Unmap() const override {}

    TensorShape GetStrides() const override
    {
        return GetUnpaddedTensorStrides(m_TensorInfo);
    }

    TensorShape GetShape() const override { return m_TensorInfo.GetShape(); }

protected:
    ConstTensorHandle(const TensorInfo& tensorInfo);

    void SetConstMemory(const void* mem) { m_Memory = mem; }

private:
    // Only used for testing.
    void CopyOutTo(void*) const override { ARMNN_ASSERT_MSG(false, "Unimplemented"); }
    void CopyInFrom(const void*) override { ARMNN_ASSERT_MSG(false, "Unimplemented"); }

    ConstTensorHandle(const ConstTensorHandle& other) = delete;
    ConstTensorHandle& operator=(const ConstTensorHandle& other) = delete;

    TensorInfo m_TensorInfo;
    const void* m_Memory;
};

template<>
const void* ConstTensorHandle::GetConstTensor<void>() const;

// Abstract specialization of ConstTensorHandle that allows write access to the same data.
class TensorHandle : public ConstTensorHandle
{
public:
    template <typename T>
    T* GetTensor() const
    {
        if (armnnUtils::CompatibleTypes<T>(GetTensorInfo().GetDataType()))
        {
            return reinterpret_cast<T*>(m_MutableMemory);
        }
        else
        {
            throw armnn::Exception("Attempting to get not compatible type tensor!");
        }
    }

protected:
    TensorHandle(const TensorInfo& tensorInfo);

    void SetMemory(void* mem)
    {
        m_MutableMemory = mem;
        SetConstMemory(m_MutableMemory);
    }

private:
    TensorHandle(const TensorHandle& other) = delete;
    TensorHandle& operator=(const TensorHandle& other) = delete;

    void* m_MutableMemory;
};

template <>
void* TensorHandle::GetTensor<void>() const;
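
// --- Illustrative usage sketch (not part of the upstream header) ---
// A minimal example of reading and writing tensor data through the handle
// interfaces declared above. The helper names ExampleReadFloats and
// ExampleFillFloats are hypothetical and assume the handle's memory has
// already been set; GetConstTensor/GetTensor throw armnn::Exception if T is
// not compatible with the handle's DataType.
inline const float* ExampleReadFloats(const ConstTensorHandle& handle)
{
    // Succeeds only if the underlying DataType is compatible with float.
    return handle.GetConstTensor<float>();
}

inline void ExampleFillFloats(TensorHandle& handle, float value)
{
    float* data = handle.GetTensor<float>();
    std::fill(data, data + handle.GetTensorInfo().GetNumElements(), value);
}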

// A TensorHandle that owns the wrapped memory region.
class ScopedTensorHandle : public TensorHandle
{
public:
    explicit ScopedTensorHandle(const TensorInfo& tensorInfo);

    // Copies contents from Tensor.
    explicit ScopedTensorHandle(const ConstTensor& tensor);

    // Copies contents from ConstTensorHandle.
    explicit ScopedTensorHandle(const ConstTensorHandle& tensorHandle);

    ScopedTensorHandle(const ScopedTensorHandle& other);
    ScopedTensorHandle& operator=(const ScopedTensorHandle& other);
    ~ScopedTensorHandle();

    virtual void Allocate() override;

private:
    // Only used for testing.
    void CopyOutTo(void* memory) const override;
    void CopyInFrom(const void* memory) override;

    void CopyFrom(const ScopedTensorHandle& other);
    void CopyFrom(const void* srcMemory, unsigned int numBytes);
};

// A TensorHandle that wraps an already allocated memory region.
//
// Clients must make sure the passed in memory region stays alive for the lifetime of
// the PassthroughTensorHandle instance.
//
// Note there is no polymorphism to/from ConstPassthroughTensorHandle.
class PassthroughTensorHandle : public TensorHandle
{
public:
    PassthroughTensorHandle(const TensorInfo& tensorInfo, void* mem)
    : TensorHandle(tensorInfo)
    {
        SetMemory(mem);
    }

    virtual void Allocate() override;
};

// A ConstTensorHandle that wraps an already allocated memory region.
//
// This allows users to pass in const memory to a network.
// Clients must make sure the passed in memory region stays alive for the lifetime of
// the ConstPassthroughTensorHandle instance.
//
// Note there is no polymorphism to/from PassthroughTensorHandle.
class ConstPassthroughTensorHandle : public ConstTensorHandle
{
public:
    ConstPassthroughTensorHandle(const TensorInfo& tensorInfo, const void* mem)
    : ConstTensorHandle(tensorInfo)
    {
        SetConstMemory(mem);
    }

    virtual void Allocate() override;
};

// Template specializations.

template <>
const void* ConstTensorHandle::GetConstTensor() const;

template <>
void* TensorHandle::GetTensor() const;
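
// --- Illustrative usage sketch (not part of the upstream header) ---
// A minimal example contrasting a handle that wraps caller-owned memory with
// one that owns its allocation. ExampleWrapAndCopy is a hypothetical helper:
// the caller-provided buffer must outlive the passthrough handle, while
// ScopedTensorHandle allocates and owns its own storage.
inline void ExampleWrapAndCopy(const TensorInfo& info, const float* callerData)
{
    // Wraps callerData without copying; no allocation takes place here.
    ConstPassthroughTensorHandle wrapped(info, callerData);

    // Owns its memory; Allocate() reserves a buffer sized from 'info'.
    ScopedTensorHandle owned(info);
    owned.Allocate();

    const float* src = wrapped.GetConstTensor<float>();
    float* dst = owned.GetTensor<float>();
    std::copy(src, src + info.GetNumElements(), dst);
}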

class ManagedConstTensorHandle
{
public:
    explicit ManagedConstTensorHandle(std::shared_ptr<ConstTensorHandle> ptr)
        : m_Mapped(false)
        , m_TensorHandle(std::move(ptr)) {}

    /// RAII-managed resource: unmaps the memory area once the wrapper goes out of scope.
    const void* Map(bool blocking = true)
    {
        if (m_TensorHandle)
        {
            auto pRet = m_TensorHandle->Map(blocking);
            m_Mapped = true;
            return pRet;
        }
        else
        {
            throw armnn::Exception("Attempting to Map null TensorHandle");
        }
    }

    // Delete copy constructor as it's unnecessary.
    ManagedConstTensorHandle(const ManagedConstTensorHandle& other) = delete;

    // Delete copy assignment as it's unnecessary.
    ManagedConstTensorHandle& operator=(const ManagedConstTensorHandle& other) = delete;

    // Delete move assignment as it's unnecessary.
    ManagedConstTensorHandle& operator=(ManagedConstTensorHandle&& other) noexcept = delete;

    ~ManagedConstTensorHandle()
    {
        // The wrapped handle may be empty (e.g. an optional bias tensor), so only unmap if it exists.
        if (m_TensorHandle)
        {
            Unmap();
        }
    }

    void Unmap()
    {
        // Only unmap if mapped and the TensorHandle exists.
        if (m_Mapped && m_TensorHandle)
        {
            m_TensorHandle->Unmap();
            m_Mapped = false;
        }
    }

    const TensorInfo& GetTensorInfo() const
    {
        return m_TensorHandle->GetTensorInfo();
    }

    bool IsMapped() const
    {
        return m_Mapped;
    }

private:
    bool m_Mapped;
    std::shared_ptr<ConstTensorHandle> m_TensorHandle;
};

} // namespace armnn
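
// --- Illustrative usage sketch (not part of the upstream header) ---
// A minimal example of the RAII wrapper above: Map() marks the handle as
// mapped and the destructor (or an explicit Unmap()) releases the mapping.
// ExampleManagedMapping is a hypothetical helper; the const memory passed in
// must stay alive while the handles use it.
inline void ExampleManagedMapping(const armnn::TensorInfo& info, const void* constWeights)
{
    auto handle = std::make_shared<armnn::ConstPassthroughTensorHandle>(info, constWeights);

    armnn::ManagedConstTensorHandle managed(handle);
    const void* mapped = managed.Map(); // IsMapped() is now true.
    static_cast<void>(mapped);          // ... read the tensor data here ...
}                                       // 'managed' goes out of scope and calls Unmap().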