1 // 2 // Copyright © 2017 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #pragma once 7 8 #include <armnn/backends/CpuTensorHandleFwd.hpp> 9 #include <armnn/backends/ITensorHandle.hpp> 10 11 #include <armnn/TypesUtils.hpp> 12 13 #include <CompatibleTypes.hpp> 14 15 #include <algorithm> 16 17 #include <armnn/utility/Assert.hpp> 18 19 namespace armnn 20 { 21 22 // Get a TensorShape representing the strides (in bytes) for each dimension 23 // of a tensor, assuming fully packed data with no padding 24 TensorShape GetUnpaddedTensorStrides(const TensorInfo& tensorInfo); 25 26 // Abstract tensor handles wrapping a CPU-readable region of memory, interpreting it as tensor data. 27 class ConstCpuTensorHandle : public ITensorHandle 28 { 29 public: 30 template <typename T> GetConstTensor() const31 const T* GetConstTensor() const 32 { 33 ARMNN_ASSERT(CompatibleTypes<T>(GetTensorInfo().GetDataType())); 34 return reinterpret_cast<const T*>(m_Memory); 35 } 36 GetTensorInfo() const37 const TensorInfo& GetTensorInfo() const 38 { 39 return m_TensorInfo; 40 } 41 Manage()42 virtual void Manage() override {} 43 GetParent() const44 virtual ITensorHandle* GetParent() const override { return nullptr; } 45 Map(bool) const46 virtual const void* Map(bool /* blocking = true */) const override { return m_Memory; } Unmap() const47 virtual void Unmap() const override {} 48 GetStrides() const49 TensorShape GetStrides() const override 50 { 51 return GetUnpaddedTensorStrides(m_TensorInfo); 52 } GetShape() const53 TensorShape GetShape() const override { return m_TensorInfo.GetShape(); } 54 55 protected: 56 ConstCpuTensorHandle(const TensorInfo& tensorInfo); 57 SetConstMemory(const void * mem)58 void SetConstMemory(const void* mem) { m_Memory = mem; } 59 60 private: 61 // Only used for testing CopyOutTo(void *) const62 void CopyOutTo(void *) const override { ARMNN_ASSERT_MSG(false, "Unimplemented"); } CopyInFrom(const void *)63 void CopyInFrom(const void*) override { ARMNN_ASSERT_MSG(false, "Unimplemented"); } 64 65 ConstCpuTensorHandle(const ConstCpuTensorHandle& other) = delete; 66 ConstCpuTensorHandle& operator=(const ConstCpuTensorHandle& other) = delete; 67 68 TensorInfo m_TensorInfo; 69 const void* m_Memory; 70 }; 71 72 template<> 73 const void* ConstCpuTensorHandle::GetConstTensor<void>() const; 74 75 // Abstract specialization of ConstCpuTensorHandle that allows write access to the same data. 76 class CpuTensorHandle : public ConstCpuTensorHandle 77 { 78 public: 79 template <typename T> GetTensor() const80 T* GetTensor() const 81 { 82 ARMNN_ASSERT(CompatibleTypes<T>(GetTensorInfo().GetDataType())); 83 return reinterpret_cast<T*>(m_MutableMemory); 84 } 85 86 protected: 87 CpuTensorHandle(const TensorInfo& tensorInfo); 88 SetMemory(void * mem)89 void SetMemory(void* mem) 90 { 91 m_MutableMemory = mem; 92 SetConstMemory(m_MutableMemory); 93 } 94 95 private: 96 97 CpuTensorHandle(const CpuTensorHandle& other) = delete; 98 CpuTensorHandle& operator=(const CpuTensorHandle& other) = delete; 99 void* m_MutableMemory; 100 }; 101 102 template <> 103 void* CpuTensorHandle::GetTensor<void>() const; 104 105 // A CpuTensorHandle that owns the wrapped memory region. 106 class ScopedCpuTensorHandle : public CpuTensorHandle 107 { 108 public: 109 explicit ScopedCpuTensorHandle(const TensorInfo& tensorInfo); 110 111 // Copies contents from Tensor. 112 explicit ScopedCpuTensorHandle(const ConstTensor& tensor); 113 114 // Copies contents from ConstCpuTensorHandle 115 explicit ScopedCpuTensorHandle(const ConstCpuTensorHandle& tensorHandle); 116 117 ScopedCpuTensorHandle(const ScopedCpuTensorHandle& other); 118 ScopedCpuTensorHandle& operator=(const ScopedCpuTensorHandle& other); 119 ~ScopedCpuTensorHandle(); 120 121 virtual void Allocate() override; 122 123 private: 124 // Only used for testing 125 void CopyOutTo(void* memory) const override; 126 void CopyInFrom(const void* memory) override; 127 128 void CopyFrom(const ScopedCpuTensorHandle& other); 129 void CopyFrom(const void* srcMemory, unsigned int numBytes); 130 }; 131 132 // A CpuTensorHandle that wraps an already allocated memory region. 133 // 134 // Clients must make sure the passed in memory region stays alive for the lifetime of 135 // the PassthroughCpuTensorHandle instance. 136 // 137 // Note there is no polymorphism to/from ConstPassthroughCpuTensorHandle. 138 class PassthroughCpuTensorHandle : public CpuTensorHandle 139 { 140 public: PassthroughCpuTensorHandle(const TensorInfo & tensorInfo,void * mem)141 PassthroughCpuTensorHandle(const TensorInfo& tensorInfo, void* mem) 142 : CpuTensorHandle(tensorInfo) 143 { 144 SetMemory(mem); 145 } 146 147 virtual void Allocate() override; 148 }; 149 150 // A ConstCpuTensorHandle that wraps an already allocated memory region. 151 // 152 // This allows users to pass in const memory to a network. 153 // Clients must make sure the passed in memory region stays alive for the lifetime of 154 // the PassthroughCpuTensorHandle instance. 155 // 156 // Note there is no polymorphism to/from PassthroughCpuTensorHandle. 157 class ConstPassthroughCpuTensorHandle : public ConstCpuTensorHandle 158 { 159 public: ConstPassthroughCpuTensorHandle(const TensorInfo & tensorInfo,const void * mem)160 ConstPassthroughCpuTensorHandle(const TensorInfo& tensorInfo, const void* mem) 161 : ConstCpuTensorHandle(tensorInfo) 162 { 163 SetConstMemory(mem); 164 } 165 166 virtual void Allocate() override; 167 }; 168 169 170 // Template specializations. 171 172 template <> 173 const void* ConstCpuTensorHandle::GetConstTensor() const; 174 175 template <> 176 void* CpuTensorHandle::GetTensor() const; 177 178 } // namespace armnn 179