1 // 2 // Copyright © 2017 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 #pragma once 6 7 #include <aclCommon/ArmComputeTensorHandle.hpp> 8 #include <aclCommon/ArmComputeTensorUtils.hpp> 9 10 #include <Half.hpp> 11 12 #include <armnn/utility/PolymorphicDowncast.hpp> 13 14 #include <arm_compute/runtime/CL/CLTensor.h> 15 #include <arm_compute/runtime/CL/CLSubTensor.h> 16 #include <arm_compute/runtime/IMemoryGroup.h> 17 #include <arm_compute/runtime/MemoryGroup.h> 18 #include <arm_compute/core/TensorShape.h> 19 #include <arm_compute/core/Coordinates.h> 20 21 namespace armnn 22 { 23 24 25 class IClTensorHandle : public IAclTensorHandle 26 { 27 public: 28 virtual arm_compute::ICLTensor& GetTensor() = 0; 29 virtual arm_compute::ICLTensor const& GetTensor() const = 0; 30 virtual arm_compute::DataType GetDataType() const = 0; 31 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) = 0; 32 }; 33 34 class ClTensorHandle : public IClTensorHandle 35 { 36 public: ClTensorHandle(const TensorInfo & tensorInfo)37 ClTensorHandle(const TensorInfo& tensorInfo) 38 { 39 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo); 40 } 41 ClTensorHandle(const TensorInfo & tensorInfo,DataLayout dataLayout)42 ClTensorHandle(const TensorInfo& tensorInfo, DataLayout dataLayout) 43 { 44 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout); 45 } 46 GetTensor()47 arm_compute::CLTensor& GetTensor() override { return m_Tensor; } GetTensor() const48 arm_compute::CLTensor const& GetTensor() const override { return m_Tensor; } Allocate()49 virtual void Allocate() override {armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);} 50 Manage()51 virtual void Manage() override 52 { 53 assert(m_MemoryGroup != nullptr); 54 m_MemoryGroup->manage(&m_Tensor); 55 } 56 Map(bool blocking=true) const57 virtual const void* Map(bool blocking = true) const override 58 { 59 const_cast<arm_compute::CLTensor*>(&m_Tensor)->map(blocking); 60 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes()); 61 } 62 Unmap() const63 virtual void Unmap() const override { const_cast<arm_compute::CLTensor*>(&m_Tensor)->unmap(); } 64 GetParent() const65 virtual ITensorHandle* GetParent() const override { return nullptr; } 66 GetDataType() const67 virtual arm_compute::DataType GetDataType() const override 68 { 69 return m_Tensor.info()->data_type(); 70 } 71 SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup> & memoryGroup)72 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override 73 { 74 m_MemoryGroup = PolymorphicPointerDowncast<arm_compute::MemoryGroup>(memoryGroup); 75 } 76 GetStrides() const77 TensorShape GetStrides() const override 78 { 79 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes()); 80 } 81 GetShape() const82 TensorShape GetShape() const override 83 { 84 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape()); 85 } 86 87 private: 88 // Only used for testing CopyOutTo(void * memory) const89 void CopyOutTo(void* memory) const override 90 { 91 const_cast<armnn::ClTensorHandle*>(this)->Map(true); 92 switch(this->GetDataType()) 93 { 94 case arm_compute::DataType::F32: 95 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 96 static_cast<float*>(memory)); 97 break; 98 case arm_compute::DataType::U8: 99 case arm_compute::DataType::QASYMM8: 100 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 101 static_cast<uint8_t*>(memory)); 102 break; 103 case arm_compute::DataType::QSYMM8_PER_CHANNEL: 104 case arm_compute::DataType::QASYMM8_SIGNED: 105 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 106 static_cast<int8_t*>(memory)); 107 break; 108 case arm_compute::DataType::F16: 109 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 110 static_cast<armnn::Half*>(memory)); 111 break; 112 case arm_compute::DataType::S16: 113 case arm_compute::DataType::QSYMM16: 114 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 115 static_cast<int16_t*>(memory)); 116 break; 117 case arm_compute::DataType::S32: 118 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 119 static_cast<int32_t*>(memory)); 120 break; 121 default: 122 { 123 throw armnn::UnimplementedException(); 124 } 125 } 126 const_cast<armnn::ClTensorHandle*>(this)->Unmap(); 127 } 128 129 // Only used for testing CopyInFrom(const void * memory)130 void CopyInFrom(const void* memory) override 131 { 132 this->Map(true); 133 switch(this->GetDataType()) 134 { 135 case arm_compute::DataType::F32: 136 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory), 137 this->GetTensor()); 138 break; 139 case arm_compute::DataType::U8: 140 case arm_compute::DataType::QASYMM8: 141 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory), 142 this->GetTensor()); 143 break; 144 case arm_compute::DataType::F16: 145 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory), 146 this->GetTensor()); 147 break; 148 case arm_compute::DataType::S16: 149 case arm_compute::DataType::QSYMM8_PER_CHANNEL: 150 case arm_compute::DataType::QASYMM8_SIGNED: 151 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory), 152 this->GetTensor()); 153 break; 154 case arm_compute::DataType::QSYMM16: 155 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory), 156 this->GetTensor()); 157 break; 158 case arm_compute::DataType::S32: 159 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory), 160 this->GetTensor()); 161 break; 162 default: 163 { 164 throw armnn::UnimplementedException(); 165 } 166 } 167 this->Unmap(); 168 } 169 170 arm_compute::CLTensor m_Tensor; 171 std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup; 172 }; 173 174 class ClSubTensorHandle : public IClTensorHandle 175 { 176 public: ClSubTensorHandle(IClTensorHandle * parent,const arm_compute::TensorShape & shape,const arm_compute::Coordinates & coords)177 ClSubTensorHandle(IClTensorHandle* parent, 178 const arm_compute::TensorShape& shape, 179 const arm_compute::Coordinates& coords) 180 : m_Tensor(&parent->GetTensor(), shape, coords) 181 { 182 parentHandle = parent; 183 } 184 GetTensor()185 arm_compute::CLSubTensor& GetTensor() override { return m_Tensor; } GetTensor() const186 arm_compute::CLSubTensor const& GetTensor() const override { return m_Tensor; } 187 Allocate()188 virtual void Allocate() override {} Manage()189 virtual void Manage() override {} 190 Map(bool blocking=true) const191 virtual const void* Map(bool blocking = true) const override 192 { 193 const_cast<arm_compute::CLSubTensor*>(&m_Tensor)->map(blocking); 194 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes()); 195 } Unmap() const196 virtual void Unmap() const override { const_cast<arm_compute::CLSubTensor*>(&m_Tensor)->unmap(); } 197 GetParent() const198 virtual ITensorHandle* GetParent() const override { return parentHandle; } 199 GetDataType() const200 virtual arm_compute::DataType GetDataType() const override 201 { 202 return m_Tensor.info()->data_type(); 203 } 204 SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup> &)205 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {} 206 GetStrides() const207 TensorShape GetStrides() const override 208 { 209 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes()); 210 } 211 GetShape() const212 TensorShape GetShape() const override 213 { 214 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape()); 215 } 216 217 private: 218 // Only used for testing CopyOutTo(void * memory) const219 void CopyOutTo(void* memory) const override 220 { 221 const_cast<ClSubTensorHandle*>(this)->Map(true); 222 switch(this->GetDataType()) 223 { 224 case arm_compute::DataType::F32: 225 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 226 static_cast<float*>(memory)); 227 break; 228 case arm_compute::DataType::U8: 229 case arm_compute::DataType::QASYMM8: 230 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 231 static_cast<uint8_t*>(memory)); 232 break; 233 case arm_compute::DataType::F16: 234 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 235 static_cast<armnn::Half*>(memory)); 236 break; 237 case arm_compute::DataType::QSYMM8_PER_CHANNEL: 238 case arm_compute::DataType::QASYMM8_SIGNED: 239 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 240 static_cast<int8_t*>(memory)); 241 break; 242 case arm_compute::DataType::S16: 243 case arm_compute::DataType::QSYMM16: 244 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 245 static_cast<int16_t*>(memory)); 246 break; 247 case arm_compute::DataType::S32: 248 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 249 static_cast<int32_t*>(memory)); 250 break; 251 default: 252 { 253 throw armnn::UnimplementedException(); 254 } 255 } 256 const_cast<ClSubTensorHandle*>(this)->Unmap(); 257 } 258 259 // Only used for testing CopyInFrom(const void * memory)260 void CopyInFrom(const void* memory) override 261 { 262 this->Map(true); 263 switch(this->GetDataType()) 264 { 265 case arm_compute::DataType::F32: 266 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory), 267 this->GetTensor()); 268 break; 269 case arm_compute::DataType::U8: 270 case arm_compute::DataType::QASYMM8: 271 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory), 272 this->GetTensor()); 273 break; 274 case arm_compute::DataType::F16: 275 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory), 276 this->GetTensor()); 277 break; 278 case arm_compute::DataType::QSYMM8_PER_CHANNEL: 279 case arm_compute::DataType::QASYMM8_SIGNED: 280 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory), 281 this->GetTensor()); 282 break; 283 case arm_compute::DataType::S16: 284 case arm_compute::DataType::QSYMM16: 285 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory), 286 this->GetTensor()); 287 break; 288 case arm_compute::DataType::S32: 289 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory), 290 this->GetTensor()); 291 break; 292 default: 293 { 294 throw armnn::UnimplementedException(); 295 } 296 } 297 this->Unmap(); 298 } 299 300 mutable arm_compute::CLSubTensor m_Tensor; 301 ITensorHandle* parentHandle = nullptr; 302 }; 303 304 } // namespace armnn 305