1 // 2 // Copyright © 2017 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 #pragma once 6 7 #include <BFloat16.hpp> 8 #include <Half.hpp> 9 10 #include <armnn/utility/Assert.hpp> 11 12 #include <aclCommon/ArmComputeTensorHandle.hpp> 13 #include <aclCommon/ArmComputeTensorUtils.hpp> 14 #include <armnn/utility/PolymorphicDowncast.hpp> 15 16 #include <arm_compute/runtime/MemoryGroup.h> 17 #include <arm_compute/runtime/IMemoryGroup.h> 18 #include <arm_compute/runtime/Tensor.h> 19 #include <arm_compute/runtime/SubTensor.h> 20 #include <arm_compute/core/TensorShape.h> 21 #include <arm_compute/core/Coordinates.h> 22 23 namespace armnn 24 { 25 26 class NeonTensorHandle : public IAclTensorHandle 27 { 28 public: NeonTensorHandle(const TensorInfo & tensorInfo)29 NeonTensorHandle(const TensorInfo& tensorInfo) 30 : m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Malloc)), 31 m_Imported(false), 32 m_IsImportEnabled(false), 33 m_TypeAlignment(GetDataTypeSize(tensorInfo.GetDataType())) 34 { 35 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo); 36 } 37 NeonTensorHandle(const TensorInfo & tensorInfo,DataLayout dataLayout,MemorySourceFlags importFlags=static_cast<MemorySourceFlags> (MemorySource::Malloc))38 NeonTensorHandle(const TensorInfo& tensorInfo, 39 DataLayout dataLayout, 40 MemorySourceFlags importFlags = static_cast<MemorySourceFlags>(MemorySource::Malloc)) 41 : m_ImportFlags(importFlags), 42 m_Imported(false), 43 m_IsImportEnabled(false), 44 m_TypeAlignment(GetDataTypeSize(tensorInfo.GetDataType())) 45 46 47 { 48 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout); 49 } 50 GetTensor()51 arm_compute::ITensor& GetTensor() override { return m_Tensor; } GetTensor() const52 arm_compute::ITensor const& GetTensor() const override { return m_Tensor; } 53 Allocate()54 virtual void Allocate() override 55 { 56 // If we have enabled Importing, don't Allocate the tensor 57 if (!m_IsImportEnabled) 58 { 59 armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor); 60 } 61 }; 62 Manage()63 virtual void Manage() override 64 { 65 // If we have enabled Importing, don't manage the tensor 66 if (!m_IsImportEnabled) 67 { 68 ARMNN_ASSERT(m_MemoryGroup != nullptr); 69 m_MemoryGroup->manage(&m_Tensor); 70 } 71 } 72 GetParent() const73 virtual ITensorHandle* GetParent() const override { return nullptr; } 74 GetDataType() const75 virtual arm_compute::DataType GetDataType() const override 76 { 77 return m_Tensor.info()->data_type(); 78 } 79 SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup> & memoryGroup)80 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override 81 { 82 m_MemoryGroup = PolymorphicPointerDowncast<arm_compute::MemoryGroup>(memoryGroup); 83 } 84 Map(bool) const85 virtual const void* Map(bool /* blocking = true */) const override 86 { 87 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes()); 88 } 89 Unmap() const90 virtual void Unmap() const override {} 91 GetStrides() const92 TensorShape GetStrides() const override 93 { 94 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes()); 95 } 96 GetShape() const97 TensorShape GetShape() const override 98 { 99 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape()); 100 } 101 SetImportFlags(MemorySourceFlags importFlags)102 void SetImportFlags(MemorySourceFlags importFlags) 103 { 104 m_ImportFlags = importFlags; 105 } 106 GetImportFlags() const107 MemorySourceFlags GetImportFlags() const override 108 { 109 return m_ImportFlags; 110 } 111 SetImportEnabledFlag(bool importEnabledFlag)112 void SetImportEnabledFlag(bool importEnabledFlag) 113 { 114 m_IsImportEnabled = importEnabledFlag; 115 } 116 CanBeImported(void * memory,MemorySource source)117 bool CanBeImported(void* memory, MemorySource source) override 118 { 119 if (source != MemorySource::Malloc || reinterpret_cast<uintptr_t>(memory) % m_TypeAlignment) 120 { 121 return false; 122 } 123 return true; 124 } 125 Import(void * memory,MemorySource source)126 virtual bool Import(void* memory, MemorySource source) override 127 { 128 if (m_ImportFlags & static_cast<MemorySourceFlags>(source)) 129 { 130 if (source == MemorySource::Malloc && m_IsImportEnabled) 131 { 132 if (!CanBeImported(memory, source)) 133 { 134 throw MemoryImportException("NeonTensorHandle::Import Attempting to import unaligned memory"); 135 } 136 137 // m_Tensor not yet Allocated 138 if (!m_Imported && !m_Tensor.buffer()) 139 { 140 arm_compute::Status status = m_Tensor.allocator()->import_memory(memory); 141 // Use the overloaded bool operator of Status to check if it worked, if not throw an exception 142 // with the Status error message 143 m_Imported = bool(status); 144 if (!m_Imported) 145 { 146 throw MemoryImportException(status.error_description()); 147 } 148 return m_Imported; 149 } 150 151 // m_Tensor.buffer() initially allocated with Allocate(). 152 if (!m_Imported && m_Tensor.buffer()) 153 { 154 throw MemoryImportException( 155 "NeonTensorHandle::Import Attempting to import on an already allocated tensor"); 156 } 157 158 // m_Tensor.buffer() previously imported. 159 if (m_Imported) 160 { 161 arm_compute::Status status = m_Tensor.allocator()->import_memory(memory); 162 // Use the overloaded bool operator of Status to check if it worked, if not throw an exception 163 // with the Status error message 164 m_Imported = bool(status); 165 if (!m_Imported) 166 { 167 throw MemoryImportException(status.error_description()); 168 } 169 return m_Imported; 170 } 171 } 172 else 173 { 174 throw MemoryImportException("NeonTensorHandle::Import is disabled"); 175 } 176 } 177 else 178 { 179 throw MemoryImportException("NeonTensorHandle::Incorrect import flag"); 180 } 181 return false; 182 } 183 184 private: 185 // Only used for testing CopyOutTo(void * memory) const186 void CopyOutTo(void* memory) const override 187 { 188 switch (this->GetDataType()) 189 { 190 case arm_compute::DataType::F32: 191 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 192 static_cast<float*>(memory)); 193 break; 194 case arm_compute::DataType::U8: 195 case arm_compute::DataType::QASYMM8: 196 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 197 static_cast<uint8_t*>(memory)); 198 break; 199 case arm_compute::DataType::QSYMM8: 200 case arm_compute::DataType::QASYMM8_SIGNED: 201 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 202 static_cast<int8_t*>(memory)); 203 break; 204 case arm_compute::DataType::BFLOAT16: 205 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 206 static_cast<armnn::BFloat16*>(memory)); 207 break; 208 case arm_compute::DataType::F16: 209 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 210 static_cast<armnn::Half*>(memory)); 211 break; 212 case arm_compute::DataType::S16: 213 case arm_compute::DataType::QSYMM16: 214 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 215 static_cast<int16_t*>(memory)); 216 break; 217 case arm_compute::DataType::S32: 218 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 219 static_cast<int32_t*>(memory)); 220 break; 221 default: 222 { 223 throw armnn::UnimplementedException(); 224 } 225 } 226 } 227 228 // Only used for testing CopyInFrom(const void * memory)229 void CopyInFrom(const void* memory) override 230 { 231 switch (this->GetDataType()) 232 { 233 case arm_compute::DataType::F32: 234 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory), 235 this->GetTensor()); 236 break; 237 case arm_compute::DataType::U8: 238 case arm_compute::DataType::QASYMM8: 239 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory), 240 this->GetTensor()); 241 break; 242 case arm_compute::DataType::QSYMM8: 243 case arm_compute::DataType::QASYMM8_SIGNED: 244 case arm_compute::DataType::QSYMM8_PER_CHANNEL: 245 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory), 246 this->GetTensor()); 247 break; 248 case arm_compute::DataType::BFLOAT16: 249 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::BFloat16*>(memory), 250 this->GetTensor()); 251 break; 252 case arm_compute::DataType::F16: 253 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory), 254 this->GetTensor()); 255 break; 256 case arm_compute::DataType::S16: 257 case arm_compute::DataType::QSYMM16: 258 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory), 259 this->GetTensor()); 260 break; 261 case arm_compute::DataType::S32: 262 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory), 263 this->GetTensor()); 264 break; 265 default: 266 { 267 throw armnn::UnimplementedException(); 268 } 269 } 270 } 271 272 arm_compute::Tensor m_Tensor; 273 std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup; 274 MemorySourceFlags m_ImportFlags; 275 bool m_Imported; 276 bool m_IsImportEnabled; 277 const uintptr_t m_TypeAlignment; 278 }; 279 280 class NeonSubTensorHandle : public IAclTensorHandle 281 { 282 public: NeonSubTensorHandle(IAclTensorHandle * parent,const arm_compute::TensorShape & shape,const arm_compute::Coordinates & coords)283 NeonSubTensorHandle(IAclTensorHandle* parent, 284 const arm_compute::TensorShape& shape, 285 const arm_compute::Coordinates& coords) 286 : m_Tensor(&parent->GetTensor(), shape, coords) 287 { 288 parentHandle = parent; 289 } 290 GetTensor()291 arm_compute::ITensor& GetTensor() override { return m_Tensor; } GetTensor() const292 arm_compute::ITensor const& GetTensor() const override { return m_Tensor; } 293 Allocate()294 virtual void Allocate() override {} Manage()295 virtual void Manage() override {} 296 GetParent() const297 virtual ITensorHandle* GetParent() const override { return parentHandle; } 298 GetDataType() const299 virtual arm_compute::DataType GetDataType() const override 300 { 301 return m_Tensor.info()->data_type(); 302 } 303 SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup> &)304 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {} 305 Map(bool) const306 virtual const void* Map(bool /* blocking = true */) const override 307 { 308 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes()); 309 } Unmap() const310 virtual void Unmap() const override {} 311 GetStrides() const312 TensorShape GetStrides() const override 313 { 314 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes()); 315 } 316 GetShape() const317 TensorShape GetShape() const override 318 { 319 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape()); 320 } 321 322 private: 323 // Only used for testing CopyOutTo(void * memory) const324 void CopyOutTo(void* memory) const override 325 { 326 switch (this->GetDataType()) 327 { 328 case arm_compute::DataType::F32: 329 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 330 static_cast<float*>(memory)); 331 break; 332 case arm_compute::DataType::U8: 333 case arm_compute::DataType::QASYMM8: 334 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 335 static_cast<uint8_t*>(memory)); 336 break; 337 case arm_compute::DataType::QSYMM8: 338 case arm_compute::DataType::QASYMM8_SIGNED: 339 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 340 static_cast<int8_t*>(memory)); 341 break; 342 case arm_compute::DataType::S16: 343 case arm_compute::DataType::QSYMM16: 344 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 345 static_cast<int16_t*>(memory)); 346 break; 347 case arm_compute::DataType::S32: 348 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 349 static_cast<int32_t*>(memory)); 350 break; 351 default: 352 { 353 throw armnn::UnimplementedException(); 354 } 355 } 356 } 357 358 // Only used for testing CopyInFrom(const void * memory)359 void CopyInFrom(const void* memory) override 360 { 361 switch (this->GetDataType()) 362 { 363 case arm_compute::DataType::F32: 364 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory), 365 this->GetTensor()); 366 break; 367 case arm_compute::DataType::U8: 368 case arm_compute::DataType::QASYMM8: 369 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory), 370 this->GetTensor()); 371 break; 372 case arm_compute::DataType::QSYMM8: 373 case arm_compute::DataType::QASYMM8_SIGNED: 374 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory), 375 this->GetTensor()); 376 break; 377 case arm_compute::DataType::S16: 378 case arm_compute::DataType::QSYMM16: 379 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory), 380 this->GetTensor()); 381 break; 382 case arm_compute::DataType::S32: 383 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory), 384 this->GetTensor()); 385 break; 386 default: 387 { 388 throw armnn::UnimplementedException(); 389 } 390 } 391 } 392 393 arm_compute::SubTensor m_Tensor; 394 ITensorHandle* parentHandle = nullptr; 395 }; 396 397 } // namespace armnn 398