1 // 2 // Copyright © 2017 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 #pragma once 6 7 #include <BFloat16.hpp> 8 #include <Half.hpp> 9 10 #include <armnn/utility/Assert.hpp> 11 12 #include <aclCommon/ArmComputeTensorHandle.hpp> 13 #include <aclCommon/ArmComputeTensorUtils.hpp> 14 #include <armnn/utility/PolymorphicDowncast.hpp> 15 16 #include <arm_compute/runtime/MemoryGroup.h> 17 #include <arm_compute/runtime/IMemoryGroup.h> 18 #include <arm_compute/runtime/Tensor.h> 19 #include <arm_compute/runtime/SubTensor.h> 20 #include <arm_compute/core/TensorShape.h> 21 #include <arm_compute/core/Coordinates.h> 22 23 namespace armnn 24 { 25 26 class NeonTensorHandle : public IAclTensorHandle 27 { 28 public: NeonTensorHandle(const TensorInfo & tensorInfo)29 NeonTensorHandle(const TensorInfo& tensorInfo) 30 : m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Malloc)), 31 m_Imported(false), 32 m_IsImportEnabled(false) 33 { 34 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo); 35 } 36 NeonTensorHandle(const TensorInfo & tensorInfo,DataLayout dataLayout,MemorySourceFlags importFlags=static_cast<MemorySourceFlags> (MemorySource::Malloc))37 NeonTensorHandle(const TensorInfo& tensorInfo, 38 DataLayout dataLayout, 39 MemorySourceFlags importFlags = static_cast<MemorySourceFlags>(MemorySource::Malloc)) 40 : m_ImportFlags(importFlags), 41 m_Imported(false), 42 m_IsImportEnabled(false) 43 44 { 45 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout); 46 } 47 GetTensor()48 arm_compute::ITensor& GetTensor() override { return m_Tensor; } GetTensor() const49 arm_compute::ITensor const& GetTensor() const override { return m_Tensor; } 50 Allocate()51 virtual void Allocate() override 52 { 53 // If we have enabled Importing, don't Allocate the tensor 54 if (!m_IsImportEnabled) 55 { 56 armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor); 57 } 58 }; 59 Manage()60 virtual void Manage() override 61 { 62 // If we have enabled Importing, don't manage the tensor 63 if (!m_IsImportEnabled) 64 { 65 ARMNN_ASSERT(m_MemoryGroup != nullptr); 66 m_MemoryGroup->manage(&m_Tensor); 67 } 68 } 69 GetParent() const70 virtual ITensorHandle* GetParent() const override { return nullptr; } 71 GetDataType() const72 virtual arm_compute::DataType GetDataType() const override 73 { 74 return m_Tensor.info()->data_type(); 75 } 76 SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup> & memoryGroup)77 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override 78 { 79 m_MemoryGroup = PolymorphicPointerDowncast<arm_compute::MemoryGroup>(memoryGroup); 80 } 81 Map(bool) const82 virtual const void* Map(bool /* blocking = true */) const override 83 { 84 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes()); 85 } 86 Unmap() const87 virtual void Unmap() const override {} 88 GetStrides() const89 TensorShape GetStrides() const override 90 { 91 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes()); 92 } 93 GetShape() const94 TensorShape GetShape() const override 95 { 96 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape()); 97 } 98 SetImportFlags(MemorySourceFlags importFlags)99 void SetImportFlags(MemorySourceFlags importFlags) 100 { 101 m_ImportFlags = importFlags; 102 } 103 GetImportFlags() const104 MemorySourceFlags GetImportFlags() const override 105 { 106 return m_ImportFlags; 107 } 108 SetImportEnabledFlag(bool importEnabledFlag)109 void SetImportEnabledFlag(bool importEnabledFlag) 110 { 111 m_IsImportEnabled = importEnabledFlag; 112 } 113 Import(void * memory,MemorySource source)114 virtual bool Import(void* memory, MemorySource source) override 115 { 116 if (m_ImportFlags & static_cast<MemorySourceFlags>(source)) 117 { 118 if (source == MemorySource::Malloc && m_IsImportEnabled) 119 { 120 // Checks the 16 byte memory alignment 121 constexpr uintptr_t alignment = sizeof(size_t); 122 if (reinterpret_cast<uintptr_t>(memory) % alignment) 123 { 124 throw MemoryImportException("NeonTensorHandle::Import Attempting to import unaligned memory"); 125 } 126 127 // m_Tensor not yet Allocated 128 if (!m_Imported && !m_Tensor.buffer()) 129 { 130 arm_compute::Status status = m_Tensor.allocator()->import_memory(memory); 131 // Use the overloaded bool operator of Status to check if it worked, if not throw an exception 132 // with the Status error message 133 m_Imported = bool(status); 134 if (!m_Imported) 135 { 136 throw MemoryImportException(status.error_description()); 137 } 138 return m_Imported; 139 } 140 141 // m_Tensor.buffer() initially allocated with Allocate(). 142 if (!m_Imported && m_Tensor.buffer()) 143 { 144 throw MemoryImportException( 145 "NeonTensorHandle::Import Attempting to import on an already allocated tensor"); 146 } 147 148 // m_Tensor.buffer() previously imported. 149 if (m_Imported) 150 { 151 arm_compute::Status status = m_Tensor.allocator()->import_memory(memory); 152 // Use the overloaded bool operator of Status to check if it worked, if not throw an exception 153 // with the Status error message 154 m_Imported = bool(status); 155 if (!m_Imported) 156 { 157 throw MemoryImportException(status.error_description()); 158 } 159 return m_Imported; 160 } 161 } 162 else 163 { 164 throw MemoryImportException("NeonTensorHandle::Import is disabled"); 165 } 166 } 167 else 168 { 169 throw MemoryImportException("NeonTensorHandle::Incorrect import flag"); 170 } 171 return false; 172 } 173 174 private: 175 // Only used for testing CopyOutTo(void * memory) const176 void CopyOutTo(void* memory) const override 177 { 178 switch (this->GetDataType()) 179 { 180 case arm_compute::DataType::F32: 181 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 182 static_cast<float*>(memory)); 183 break; 184 case arm_compute::DataType::U8: 185 case arm_compute::DataType::QASYMM8: 186 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 187 static_cast<uint8_t*>(memory)); 188 break; 189 case arm_compute::DataType::QASYMM8_SIGNED: 190 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 191 static_cast<int8_t*>(memory)); 192 break; 193 case arm_compute::DataType::BFLOAT16: 194 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 195 static_cast<armnn::BFloat16*>(memory)); 196 break; 197 case arm_compute::DataType::F16: 198 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 199 static_cast<armnn::Half*>(memory)); 200 break; 201 case arm_compute::DataType::S16: 202 case arm_compute::DataType::QSYMM16: 203 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 204 static_cast<int16_t*>(memory)); 205 break; 206 case arm_compute::DataType::S32: 207 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 208 static_cast<int32_t*>(memory)); 209 break; 210 default: 211 { 212 throw armnn::UnimplementedException(); 213 } 214 } 215 } 216 217 // Only used for testing CopyInFrom(const void * memory)218 void CopyInFrom(const void* memory) override 219 { 220 switch (this->GetDataType()) 221 { 222 case arm_compute::DataType::F32: 223 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory), 224 this->GetTensor()); 225 break; 226 case arm_compute::DataType::U8: 227 case arm_compute::DataType::QASYMM8: 228 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory), 229 this->GetTensor()); 230 break; 231 case arm_compute::DataType::QASYMM8_SIGNED: 232 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory), 233 this->GetTensor()); 234 break; 235 case arm_compute::DataType::BFLOAT16: 236 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::BFloat16*>(memory), 237 this->GetTensor()); 238 break; 239 case arm_compute::DataType::F16: 240 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory), 241 this->GetTensor()); 242 break; 243 case arm_compute::DataType::S16: 244 case arm_compute::DataType::QSYMM16: 245 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory), 246 this->GetTensor()); 247 break; 248 case arm_compute::DataType::S32: 249 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory), 250 this->GetTensor()); 251 break; 252 default: 253 { 254 throw armnn::UnimplementedException(); 255 } 256 } 257 } 258 259 arm_compute::Tensor m_Tensor; 260 std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup; 261 MemorySourceFlags m_ImportFlags; 262 bool m_Imported; 263 bool m_IsImportEnabled; 264 }; 265 266 class NeonSubTensorHandle : public IAclTensorHandle 267 { 268 public: NeonSubTensorHandle(IAclTensorHandle * parent,const arm_compute::TensorShape & shape,const arm_compute::Coordinates & coords)269 NeonSubTensorHandle(IAclTensorHandle* parent, 270 const arm_compute::TensorShape& shape, 271 const arm_compute::Coordinates& coords) 272 : m_Tensor(&parent->GetTensor(), shape, coords) 273 { 274 parentHandle = parent; 275 } 276 GetTensor()277 arm_compute::ITensor& GetTensor() override { return m_Tensor; } GetTensor() const278 arm_compute::ITensor const& GetTensor() const override { return m_Tensor; } 279 Allocate()280 virtual void Allocate() override {} Manage()281 virtual void Manage() override {} 282 GetParent() const283 virtual ITensorHandle* GetParent() const override { return parentHandle; } 284 GetDataType() const285 virtual arm_compute::DataType GetDataType() const override 286 { 287 return m_Tensor.info()->data_type(); 288 } 289 SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup> &)290 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {} 291 Map(bool) const292 virtual const void* Map(bool /* blocking = true */) const override 293 { 294 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes()); 295 } Unmap() const296 virtual void Unmap() const override {} 297 GetStrides() const298 TensorShape GetStrides() const override 299 { 300 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes()); 301 } 302 GetShape() const303 TensorShape GetShape() const override 304 { 305 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape()); 306 } 307 308 private: 309 // Only used for testing CopyOutTo(void * memory) const310 void CopyOutTo(void* memory) const override 311 { 312 switch (this->GetDataType()) 313 { 314 case arm_compute::DataType::F32: 315 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 316 static_cast<float*>(memory)); 317 break; 318 case arm_compute::DataType::U8: 319 case arm_compute::DataType::QASYMM8: 320 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 321 static_cast<uint8_t*>(memory)); 322 break; 323 case arm_compute::DataType::QASYMM8_SIGNED: 324 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 325 static_cast<int8_t*>(memory)); 326 break; 327 case arm_compute::DataType::S16: 328 case arm_compute::DataType::QSYMM16: 329 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 330 static_cast<int16_t*>(memory)); 331 break; 332 case arm_compute::DataType::S32: 333 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 334 static_cast<int32_t*>(memory)); 335 break; 336 default: 337 { 338 throw armnn::UnimplementedException(); 339 } 340 } 341 } 342 343 // Only used for testing CopyInFrom(const void * memory)344 void CopyInFrom(const void* memory) override 345 { 346 switch (this->GetDataType()) 347 { 348 case arm_compute::DataType::F32: 349 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory), 350 this->GetTensor()); 351 break; 352 case arm_compute::DataType::U8: 353 case arm_compute::DataType::QASYMM8: 354 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory), 355 this->GetTensor()); 356 break; 357 case arm_compute::DataType::QASYMM8_SIGNED: 358 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory), 359 this->GetTensor()); 360 break; 361 case arm_compute::DataType::S16: 362 case arm_compute::DataType::QSYMM16: 363 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory), 364 this->GetTensor()); 365 break; 366 case arm_compute::DataType::S32: 367 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory), 368 this->GetTensor()); 369 break; 370 default: 371 { 372 throw armnn::UnimplementedException(); 373 } 374 } 375 } 376 377 arm_compute::SubTensor m_Tensor; 378 ITensorHandle* parentHandle = nullptr; 379 }; 380 381 } // namespace armnn 382