//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include <BFloat16.hpp>
#include <Half.hpp>

#include <armnn/utility/Assert.hpp>

#include <aclCommon/ArmComputeTensorHandle.hpp>
#include <aclCommon/ArmComputeTensorUtils.hpp>
#include <armnn/utility/PolymorphicDowncast.hpp>

#include <arm_compute/runtime/MemoryGroup.h>
#include <arm_compute/runtime/IMemoryGroup.h>
#include <arm_compute/runtime/Tensor.h>
#include <arm_compute/runtime/SubTensor.h>
#include <arm_compute/core/TensorShape.h>
#include <arm_compute/core/Coordinates.h>

namespace armnn
{

class NeonTensorHandle : public IAclTensorHandle
{
public:
    NeonTensorHandle(const TensorInfo& tensorInfo)
                     : m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Malloc)),
                       m_Imported(false),
                       m_IsImportEnabled(false),
                       m_TypeAlignment(GetDataTypeSize(tensorInfo.GetDataType()))
    {
        armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
    }

    NeonTensorHandle(const TensorInfo& tensorInfo,
                     DataLayout dataLayout,
                     MemorySourceFlags importFlags = static_cast<MemorySourceFlags>(MemorySource::Malloc))
                     : m_ImportFlags(importFlags),
                       m_Imported(false),
                       m_IsImportEnabled(false),
                       m_TypeAlignment(GetDataTypeSize(tensorInfo.GetDataType()))
    {
        armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
    }

    arm_compute::ITensor& GetTensor() override { return m_Tensor; }
    arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }

    virtual void Allocate() override
    {
        // If importing is enabled the tensor will wrap caller-owned memory,
        // so don't allocate backing storage here.
        if (!m_IsImportEnabled)
        {
            armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);
        }
    }

    virtual void Manage() override
    {
        // If importing is enabled the tensor is not pooled, so there is
        // nothing for the memory group to manage.
        if (!m_IsImportEnabled)
        {
            ARMNN_ASSERT(m_MemoryGroup != nullptr);
            m_MemoryGroup->manage(&m_Tensor);
        }
    }

    virtual ITensorHandle* GetParent() const override { return nullptr; }

    virtual arm_compute::DataType GetDataType() const override
    {
        return m_Tensor.info()->data_type();
    }

    virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
    {
        m_MemoryGroup = PolymorphicPointerDowncast<arm_compute::MemoryGroup>(memoryGroup);
    }

    virtual const void* Map(bool /* blocking = true */) const override
    {
        return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
    }

    // Neon tensors live in host memory, so unmapping is a no-op.
    virtual void Unmap() const override {}

    TensorShape GetStrides() const override
    {
        return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
    }

    TensorShape GetShape() const override
    {
        return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
    }

    void SetImportFlags(MemorySourceFlags importFlags)
    {
        m_ImportFlags = importFlags;
    }

    MemorySourceFlags GetImportFlags() const override
    {
        return m_ImportFlags;
    }

    void SetImportEnabledFlag(bool importEnabledFlag)
    {
        m_IsImportEnabled = importEnabledFlag;
    }

    bool CanBeImported(void* memory, MemorySource source) override
    {
        // Only Malloc memory whose address is aligned to the element size of
        // the tensor's data type can be imported.
        if (source != MemorySource::Malloc || reinterpret_cast<uintptr_t>(memory) % m_TypeAlignment)
        {
            return false;
        }
        return true;
    }
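
    // Worked example (illustrative; the addresses are hypothetical): for a
    // Float32 tensor, GetDataTypeSize returns 4, so m_TypeAlignment == 4.
    // A buffer at address 0x1000 passes the check (0x1000 % 4 == 0), while
    // one at 0x1002 is rejected (0x1002 % 4 == 2).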

    virtual bool Import(void* memory, MemorySource source) override
    {
        if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
        {
            if (source == MemorySource::Malloc && m_IsImportEnabled)
            {
                if (!CanBeImported(memory, source))
                {
                    throw MemoryImportException("NeonTensorHandle::Import Attempting to import unaligned memory");
                }

                // m_Tensor not yet allocated: import the memory directly.
                if (!m_Imported && !m_Tensor.buffer())
                {
                    arm_compute::Status status = m_Tensor.allocator()->import_memory(memory);
                    // Use the overloaded bool operator of Status to check if the import worked;
                    // if not, throw an exception with the Status error message.
                    m_Imported = bool(status);
                    if (!m_Imported)
                    {
                        throw MemoryImportException(status.error_description());
                    }
                    return m_Imported;
                }

                // m_Tensor.buffer() was initially allocated with Allocate(): refuse to import over it.
                if (!m_Imported && m_Tensor.buffer())
                {
                    throw MemoryImportException(
                        "NeonTensorHandle::Import Attempting to import on an already allocated tensor");
                }

                // m_Tensor.buffer() was previously imported: re-import to point at the new memory.
                if (m_Imported)
                {
                    arm_compute::Status status = m_Tensor.allocator()->import_memory(memory);
                    // As above, check the Status and throw its error message on failure.
                    m_Imported = bool(status);
                    if (!m_Imported)
                    {
                        throw MemoryImportException(status.error_description());
                    }
                    return m_Imported;
                }
            }
            else
            {
                throw MemoryImportException("NeonTensorHandle::Import is disabled");
            }
        }
        else
        {
            throw MemoryImportException("NeonTensorHandle::Import: incorrect import flag");
        }
        return false;
    }

private:
    // Only used for testing
    void CopyOutTo(void* memory) const override
    {
        switch (this->GetDataType())
        {
            case arm_compute::DataType::F32:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<float*>(memory));
                break;
            case arm_compute::DataType::U8:
            case arm_compute::DataType::QASYMM8:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<uint8_t*>(memory));
                break;
            case arm_compute::DataType::QSYMM8:
            case arm_compute::DataType::QASYMM8_SIGNED:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<int8_t*>(memory));
                break;
            case arm_compute::DataType::BFLOAT16:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<armnn::BFloat16*>(memory));
                break;
            case arm_compute::DataType::F16:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<armnn::Half*>(memory));
                break;
            case arm_compute::DataType::S16:
            case arm_compute::DataType::QSYMM16:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<int16_t*>(memory));
                break;
            case arm_compute::DataType::S32:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<int32_t*>(memory));
                break;
            default:
            {
                throw armnn::UnimplementedException();
            }
        }
    }

    // Only used for testing
    void CopyInFrom(const void* memory) override
    {
        switch (this->GetDataType())
        {
            case arm_compute::DataType::F32:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::U8:
            case arm_compute::DataType::QASYMM8:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::QSYMM8:
            case arm_compute::DataType::QASYMM8_SIGNED:
            case arm_compute::DataType::QSYMM8_PER_CHANNEL:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::BFLOAT16:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::BFloat16*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::F16:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::S16:
            case arm_compute::DataType::QSYMM16:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::S32:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
                                                                 this->GetTensor());
                break;
            default:
            {
                throw armnn::UnimplementedException();
            }
        }
    }

    arm_compute::Tensor m_Tensor;
    std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup;
    MemorySourceFlags m_ImportFlags;
    bool m_Imported;
    bool m_IsImportEnabled;
    const uintptr_t m_TypeAlignment;
};
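
// Usage sketch (illustrative only; 'tensorInfo' and 'userBuffer' are
// hypothetical): the zero-copy import flow, where the handle wraps
// caller-owned, suitably aligned Malloc memory instead of allocating its own.
//
//     NeonTensorHandle handle(tensorInfo);
//     handle.SetImportEnabledFlag(true);     // Allocate()/Manage() become no-ops
//     if (handle.CanBeImported(userBuffer, MemorySource::Malloc))
//     {
//         handle.Import(userBuffer, MemorySource::Malloc); // tensor now wraps userBuffer
//     }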

class NeonSubTensorHandle : public IAclTensorHandle
{
public:
    NeonSubTensorHandle(IAclTensorHandle* parent,
                        const arm_compute::TensorShape& shape,
                        const arm_compute::Coordinates& coords)
     : m_Tensor(&parent->GetTensor(), shape, coords)
     , m_ParentHandle(parent)
    {
    }

    arm_compute::ITensor& GetTensor() override { return m_Tensor; }
    arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }

    // A sub-tensor is a view onto its parent's storage, so there is nothing
    // to allocate or manage.
    virtual void Allocate() override {}
    virtual void Manage() override {}

    virtual ITensorHandle* GetParent() const override { return m_ParentHandle; }

    virtual arm_compute::DataType GetDataType() const override
    {
        return m_Tensor.info()->data_type();
    }

    virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}

    virtual const void* Map(bool /* blocking = true */) const override
    {
        return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
    }
    virtual void Unmap() const override {}

    TensorShape GetStrides() const override
    {
        return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
    }

    TensorShape GetShape() const override
    {
        return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
    }

private:
    // Only used for testing
    void CopyOutTo(void* memory) const override
    {
        switch (this->GetDataType())
        {
            case arm_compute::DataType::F32:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<float*>(memory));
                break;
            case arm_compute::DataType::U8:
            case arm_compute::DataType::QASYMM8:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<uint8_t*>(memory));
                break;
            case arm_compute::DataType::QSYMM8:
            case arm_compute::DataType::QASYMM8_SIGNED:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<int8_t*>(memory));
                break;
            case arm_compute::DataType::S16:
            case arm_compute::DataType::QSYMM16:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<int16_t*>(memory));
                break;
            case arm_compute::DataType::S32:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<int32_t*>(memory));
                break;
            default:
            {
                throw armnn::UnimplementedException();
            }
        }
    }

    // Only used for testing
    void CopyInFrom(const void* memory) override
    {
        switch (this->GetDataType())
        {
            case arm_compute::DataType::F32:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::U8:
            case arm_compute::DataType::QASYMM8:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::QSYMM8:
            case arm_compute::DataType::QASYMM8_SIGNED:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::S16:
            case arm_compute::DataType::QSYMM16:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::S32:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
                                                                 this->GetTensor());
                break;
            default:
            {
                throw armnn::UnimplementedException();
            }
        }
    }

    arm_compute::SubTensor m_Tensor;
    ITensorHandle* m_ParentHandle = nullptr;
};
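
// Usage sketch (illustrative; 'parent' is a hypothetical IAclTensorHandle*
// wrapping a 1-D tensor of at least 8 elements): a zero-copy view of the
// first 8 elements at offset 0, using arm_compute shape/coordinate types.
//
//     arm_compute::TensorShape subShape(8u);
//     arm_compute::Coordinates coords(0u);
//     NeonSubTensorHandle subHandle(parent, subShape, coords);
//     // subHandle shares parent's buffer; Allocate() and Manage() are no-ops.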

} // namespace armnn