• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include <BFloat16.hpp>
8 #include <Half.hpp>
9 
10 #include <armnn/utility/Assert.hpp>
11 
12 #include <aclCommon/ArmComputeTensorHandle.hpp>
13 #include <aclCommon/ArmComputeTensorUtils.hpp>
14 #include <armnn/utility/PolymorphicDowncast.hpp>
15 
16 #include <arm_compute/runtime/MemoryGroup.h>
17 #include <arm_compute/runtime/IMemoryGroup.h>
18 #include <arm_compute/runtime/Tensor.h>
19 #include <arm_compute/runtime/SubTensor.h>
20 #include <arm_compute/core/TensorShape.h>
21 #include <arm_compute/core/Coordinates.h>
22 
23 namespace armnn
24 {
25 
26 class NeonTensorHandle : public IAclTensorHandle
27 {
28 public:
NeonTensorHandle(const TensorInfo & tensorInfo)29     NeonTensorHandle(const TensorInfo& tensorInfo)
30                      : m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Malloc)),
31                        m_Imported(false),
32                        m_IsImportEnabled(false)
33     {
34         armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
35     }
36 
NeonTensorHandle(const TensorInfo & tensorInfo,DataLayout dataLayout,MemorySourceFlags importFlags=static_cast<MemorySourceFlags> (MemorySource::Malloc))37     NeonTensorHandle(const TensorInfo& tensorInfo,
38                      DataLayout dataLayout,
39                      MemorySourceFlags importFlags = static_cast<MemorySourceFlags>(MemorySource::Malloc))
40                      : m_ImportFlags(importFlags),
41                        m_Imported(false),
42                        m_IsImportEnabled(false)
43 
44     {
45         armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
46     }
47 
GetTensor()48     arm_compute::ITensor& GetTensor() override { return m_Tensor; }
GetTensor() const49     arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }
50 
Allocate()51     virtual void Allocate() override
52     {
53         // If we have enabled Importing, don't Allocate the tensor
54         if (!m_IsImportEnabled)
55         {
56             armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);
57         }
58     };
59 
Manage()60     virtual void Manage() override
61     {
62         // If we have enabled Importing, don't manage the tensor
63         if (!m_IsImportEnabled)
64         {
65             ARMNN_ASSERT(m_MemoryGroup != nullptr);
66             m_MemoryGroup->manage(&m_Tensor);
67         }
68     }
69 
GetParent() const70     virtual ITensorHandle* GetParent() const override { return nullptr; }
71 
GetDataType() const72     virtual arm_compute::DataType GetDataType() const override
73     {
74         return m_Tensor.info()->data_type();
75     }
76 
SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup> & memoryGroup)77     virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
78     {
79         m_MemoryGroup = PolymorphicPointerDowncast<arm_compute::MemoryGroup>(memoryGroup);
80     }
81 
Map(bool) const82     virtual const void* Map(bool /* blocking = true */) const override
83     {
84         return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
85     }
86 
Unmap() const87     virtual void Unmap() const override {}
88 
GetStrides() const89     TensorShape GetStrides() const override
90     {
91         return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
92     }
93 
GetShape() const94     TensorShape GetShape() const override
95     {
96         return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
97     }
98 
SetImportFlags(MemorySourceFlags importFlags)99     void SetImportFlags(MemorySourceFlags importFlags)
100     {
101         m_ImportFlags = importFlags;
102     }
103 
GetImportFlags() const104     MemorySourceFlags GetImportFlags() const override
105     {
106         return m_ImportFlags;
107     }
108 
SetImportEnabledFlag(bool importEnabledFlag)109     void SetImportEnabledFlag(bool importEnabledFlag)
110     {
111         m_IsImportEnabled = importEnabledFlag;
112     }
113 
Import(void * memory,MemorySource source)114     virtual bool Import(void* memory, MemorySource source) override
115     {
116         if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
117         {
118             if (source == MemorySource::Malloc && m_IsImportEnabled)
119             {
120                 // Checks the 16 byte memory alignment
121                 constexpr uintptr_t alignment = sizeof(size_t);
122                 if (reinterpret_cast<uintptr_t>(memory) % alignment)
123                 {
124                     throw MemoryImportException("NeonTensorHandle::Import Attempting to import unaligned memory");
125                 }
126 
127                 // m_Tensor not yet Allocated
128                 if (!m_Imported && !m_Tensor.buffer())
129                 {
130                     arm_compute::Status status = m_Tensor.allocator()->import_memory(memory);
131                     // Use the overloaded bool operator of Status to check if it worked, if not throw an exception
132                     // with the Status error message
133                     m_Imported = bool(status);
134                     if (!m_Imported)
135                     {
136                         throw MemoryImportException(status.error_description());
137                     }
138                     return m_Imported;
139                 }
140 
141                 // m_Tensor.buffer() initially allocated with Allocate().
142                 if (!m_Imported && m_Tensor.buffer())
143                 {
144                     throw MemoryImportException(
145                         "NeonTensorHandle::Import Attempting to import on an already allocated tensor");
146                 }
147 
148                 // m_Tensor.buffer() previously imported.
149                 if (m_Imported)
150                 {
151                     arm_compute::Status status = m_Tensor.allocator()->import_memory(memory);
152                     // Use the overloaded bool operator of Status to check if it worked, if not throw an exception
153                     // with the Status error message
154                     m_Imported = bool(status);
155                     if (!m_Imported)
156                     {
157                         throw MemoryImportException(status.error_description());
158                     }
159                     return m_Imported;
160                 }
161             }
162             else
163             {
164                 throw MemoryImportException("NeonTensorHandle::Import is disabled");
165             }
166         }
167         else
168         {
169             throw MemoryImportException("NeonTensorHandle::Incorrect import flag");
170         }
171         return false;
172     }
173 
174 private:
175     // Only used for testing
CopyOutTo(void * memory) const176     void CopyOutTo(void* memory) const override
177     {
178         switch (this->GetDataType())
179         {
180             case arm_compute::DataType::F32:
181                 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
182                                                                  static_cast<float*>(memory));
183                 break;
184             case arm_compute::DataType::U8:
185             case arm_compute::DataType::QASYMM8:
186                 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
187                                                                  static_cast<uint8_t*>(memory));
188                 break;
189             case arm_compute::DataType::QASYMM8_SIGNED:
190                 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
191                                                                  static_cast<int8_t*>(memory));
192                 break;
193             case arm_compute::DataType::BFLOAT16:
194                 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
195                                                                  static_cast<armnn::BFloat16*>(memory));
196                 break;
197             case arm_compute::DataType::F16:
198                 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
199                                                                  static_cast<armnn::Half*>(memory));
200                 break;
201             case arm_compute::DataType::S16:
202             case arm_compute::DataType::QSYMM16:
203                 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
204                                                                  static_cast<int16_t*>(memory));
205                 break;
206             case arm_compute::DataType::S32:
207                 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
208                                                                  static_cast<int32_t*>(memory));
209                 break;
210             default:
211             {
212                 throw armnn::UnimplementedException();
213             }
214         }
215     }
216 
217     // Only used for testing
CopyInFrom(const void * memory)218     void CopyInFrom(const void* memory) override
219     {
220         switch (this->GetDataType())
221         {
222             case arm_compute::DataType::F32:
223                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
224                                                                  this->GetTensor());
225                 break;
226             case arm_compute::DataType::U8:
227             case arm_compute::DataType::QASYMM8:
228                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
229                                                                  this->GetTensor());
230                 break;
231             case arm_compute::DataType::QASYMM8_SIGNED:
232                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
233                                                                  this->GetTensor());
234                 break;
235             case arm_compute::DataType::BFLOAT16:
236                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::BFloat16*>(memory),
237                                                                  this->GetTensor());
238                 break;
239             case arm_compute::DataType::F16:
240                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
241                                                                  this->GetTensor());
242                 break;
243             case arm_compute::DataType::S16:
244             case arm_compute::DataType::QSYMM16:
245                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
246                                                                  this->GetTensor());
247                 break;
248             case arm_compute::DataType::S32:
249                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
250                                                                  this->GetTensor());
251                 break;
252             default:
253             {
254                 throw armnn::UnimplementedException();
255             }
256         }
257     }
258 
259     arm_compute::Tensor m_Tensor;
260     std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup;
261     MemorySourceFlags m_ImportFlags;
262     bool m_Imported;
263     bool m_IsImportEnabled;
264 };
265 
266 class NeonSubTensorHandle : public IAclTensorHandle
267 {
268 public:
NeonSubTensorHandle(IAclTensorHandle * parent,const arm_compute::TensorShape & shape,const arm_compute::Coordinates & coords)269     NeonSubTensorHandle(IAclTensorHandle* parent,
270                         const arm_compute::TensorShape& shape,
271                         const arm_compute::Coordinates& coords)
272      : m_Tensor(&parent->GetTensor(), shape, coords)
273     {
274         parentHandle = parent;
275     }
276 
GetTensor()277     arm_compute::ITensor& GetTensor() override { return m_Tensor; }
GetTensor() const278     arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }
279 
Allocate()280     virtual void Allocate() override {}
Manage()281     virtual void Manage() override {}
282 
GetParent() const283     virtual ITensorHandle* GetParent() const override { return parentHandle; }
284 
GetDataType() const285     virtual arm_compute::DataType GetDataType() const override
286     {
287         return m_Tensor.info()->data_type();
288     }
289 
SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup> &)290     virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}
291 
Map(bool) const292     virtual const void* Map(bool /* blocking = true */) const override
293     {
294         return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
295     }
Unmap() const296     virtual void Unmap() const override {}
297 
GetStrides() const298     TensorShape GetStrides() const override
299     {
300         return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
301     }
302 
GetShape() const303     TensorShape GetShape() const override
304     {
305         return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
306     }
307 
308 private:
309     // Only used for testing
CopyOutTo(void * memory) const310     void CopyOutTo(void* memory) const override
311     {
312         switch (this->GetDataType())
313         {
314             case arm_compute::DataType::F32:
315                 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
316                                                                  static_cast<float*>(memory));
317                 break;
318             case arm_compute::DataType::U8:
319             case arm_compute::DataType::QASYMM8:
320                 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
321                                                                  static_cast<uint8_t*>(memory));
322                 break;
323             case arm_compute::DataType::QASYMM8_SIGNED:
324                 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
325                                                                  static_cast<int8_t*>(memory));
326                 break;
327             case arm_compute::DataType::S16:
328             case arm_compute::DataType::QSYMM16:
329                 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
330                                                                  static_cast<int16_t*>(memory));
331                 break;
332             case arm_compute::DataType::S32:
333                 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
334                                                                  static_cast<int32_t*>(memory));
335                 break;
336             default:
337             {
338                 throw armnn::UnimplementedException();
339             }
340         }
341     }
342 
343     // Only used for testing
CopyInFrom(const void * memory)344     void CopyInFrom(const void* memory) override
345     {
346         switch (this->GetDataType())
347         {
348             case arm_compute::DataType::F32:
349                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
350                                                                  this->GetTensor());
351                 break;
352             case arm_compute::DataType::U8:
353             case arm_compute::DataType::QASYMM8:
354                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
355                                                                  this->GetTensor());
356                 break;
357             case arm_compute::DataType::QASYMM8_SIGNED:
358                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
359                                                                  this->GetTensor());
360                 break;
361             case arm_compute::DataType::S16:
362             case arm_compute::DataType::QSYMM16:
363                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
364                                                                  this->GetTensor());
365                 break;
366             case arm_compute::DataType::S32:
367                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
368                                                                  this->GetTensor());
369                 break;
370             default:
371             {
372                 throw armnn::UnimplementedException();
373             }
374         }
375     }
376 
377     arm_compute::SubTensor m_Tensor;
378     ITensorHandle* parentHandle = nullptr;
379 };
380 
381 } // namespace armnn
382