• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include <aclCommon/ArmComputeTensorHandle.hpp>
8 #include <aclCommon/ArmComputeTensorUtils.hpp>
9 
10 #include <Half.hpp>
11 
12 #include <armnn/utility/PolymorphicDowncast.hpp>
13 
14 #include <arm_compute/runtime/CL/CLTensor.h>
15 #include <arm_compute/runtime/CL/CLSubTensor.h>
16 #include <arm_compute/runtime/IMemoryGroup.h>
17 #include <arm_compute/runtime/MemoryGroup.h>
18 #include <arm_compute/core/TensorShape.h>
19 #include <arm_compute/core/Coordinates.h>
20 
21 namespace armnn
22 {
23 
24 
25 class IClTensorHandle : public IAclTensorHandle
26 {
27 public:
28     virtual arm_compute::ICLTensor& GetTensor() = 0;
29     virtual arm_compute::ICLTensor const& GetTensor() const = 0;
30     virtual arm_compute::DataType GetDataType() const = 0;
31     virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) = 0;
32 };
33 
34 class ClTensorHandle : public IClTensorHandle
35 {
36 public:
ClTensorHandle(const TensorInfo & tensorInfo)37     ClTensorHandle(const TensorInfo& tensorInfo)
38     {
39         armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
40     }
41 
ClTensorHandle(const TensorInfo & tensorInfo,DataLayout dataLayout)42     ClTensorHandle(const TensorInfo& tensorInfo, DataLayout dataLayout)
43     {
44         armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
45     }
46 
GetTensor()47     arm_compute::CLTensor& GetTensor() override { return m_Tensor; }
GetTensor() const48     arm_compute::CLTensor const& GetTensor() const override { return m_Tensor; }
Allocate()49     virtual void Allocate() override {armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);}
50 
Manage()51     virtual void Manage() override
52     {
53         assert(m_MemoryGroup != nullptr);
54         m_MemoryGroup->manage(&m_Tensor);
55     }
56 
Map(bool blocking=true) const57     virtual const void* Map(bool blocking = true) const override
58     {
59         const_cast<arm_compute::CLTensor*>(&m_Tensor)->map(blocking);
60         return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
61     }
62 
Unmap() const63     virtual void Unmap() const override { const_cast<arm_compute::CLTensor*>(&m_Tensor)->unmap(); }
64 
GetParent() const65     virtual ITensorHandle* GetParent() const override { return nullptr; }
66 
GetDataType() const67     virtual arm_compute::DataType GetDataType() const override
68     {
69         return m_Tensor.info()->data_type();
70     }
71 
SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup> & memoryGroup)72     virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
73     {
74         m_MemoryGroup = PolymorphicPointerDowncast<arm_compute::MemoryGroup>(memoryGroup);
75     }
76 
GetStrides() const77     TensorShape GetStrides() const override
78     {
79         return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
80     }
81 
GetShape() const82     TensorShape GetShape() const override
83     {
84         return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
85     }
86 
87 private:
88     // Only used for testing
CopyOutTo(void * memory) const89     void CopyOutTo(void* memory) const override
90     {
91         const_cast<armnn::ClTensorHandle*>(this)->Map(true);
92         switch(this->GetDataType())
93         {
94             case arm_compute::DataType::F32:
95                 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
96                                                                  static_cast<float*>(memory));
97                 break;
98             case arm_compute::DataType::U8:
99             case arm_compute::DataType::QASYMM8:
100                 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
101                                                                  static_cast<uint8_t*>(memory));
102                 break;
103             case arm_compute::DataType::QSYMM8_PER_CHANNEL:
104             case arm_compute::DataType::QASYMM8_SIGNED:
105                 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
106                                                                  static_cast<int8_t*>(memory));
107                 break;
108             case arm_compute::DataType::F16:
109                 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
110                                                                  static_cast<armnn::Half*>(memory));
111                 break;
112             case arm_compute::DataType::S16:
113             case arm_compute::DataType::QSYMM16:
114                 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
115                                                                  static_cast<int16_t*>(memory));
116                 break;
117             case arm_compute::DataType::S32:
118                 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
119                                                                  static_cast<int32_t*>(memory));
120                 break;
121             default:
122             {
123                 throw armnn::UnimplementedException();
124             }
125         }
126         const_cast<armnn::ClTensorHandle*>(this)->Unmap();
127     }
128 
129     // Only used for testing
CopyInFrom(const void * memory)130     void CopyInFrom(const void* memory) override
131     {
132         this->Map(true);
133         switch(this->GetDataType())
134         {
135             case arm_compute::DataType::F32:
136                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
137                                                                  this->GetTensor());
138                 break;
139             case arm_compute::DataType::U8:
140             case arm_compute::DataType::QASYMM8:
141                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
142                                                                  this->GetTensor());
143                 break;
144             case arm_compute::DataType::F16:
145                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
146                                                                  this->GetTensor());
147                 break;
148             case arm_compute::DataType::S16:
149             case arm_compute::DataType::QSYMM8_PER_CHANNEL:
150             case arm_compute::DataType::QASYMM8_SIGNED:
151                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
152                                                                  this->GetTensor());
153                 break;
154             case arm_compute::DataType::QSYMM16:
155                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
156                                                                  this->GetTensor());
157                 break;
158             case arm_compute::DataType::S32:
159                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
160                                                                  this->GetTensor());
161                 break;
162             default:
163             {
164                 throw armnn::UnimplementedException();
165             }
166         }
167         this->Unmap();
168     }
169 
170     arm_compute::CLTensor m_Tensor;
171     std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup;
172 };
173 
174 class ClSubTensorHandle : public IClTensorHandle
175 {
176 public:
ClSubTensorHandle(IClTensorHandle * parent,const arm_compute::TensorShape & shape,const arm_compute::Coordinates & coords)177     ClSubTensorHandle(IClTensorHandle* parent,
178                       const arm_compute::TensorShape& shape,
179                       const arm_compute::Coordinates& coords)
180     : m_Tensor(&parent->GetTensor(), shape, coords)
181     {
182         parentHandle = parent;
183     }
184 
GetTensor()185     arm_compute::CLSubTensor& GetTensor() override { return m_Tensor; }
GetTensor() const186     arm_compute::CLSubTensor const& GetTensor() const override { return m_Tensor; }
187 
Allocate()188     virtual void Allocate() override {}
Manage()189     virtual void Manage() override {}
190 
Map(bool blocking=true) const191     virtual const void* Map(bool blocking = true) const override
192     {
193         const_cast<arm_compute::CLSubTensor*>(&m_Tensor)->map(blocking);
194         return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
195     }
Unmap() const196     virtual void Unmap() const override { const_cast<arm_compute::CLSubTensor*>(&m_Tensor)->unmap(); }
197 
GetParent() const198     virtual ITensorHandle* GetParent() const override { return parentHandle; }
199 
GetDataType() const200     virtual arm_compute::DataType GetDataType() const override
201     {
202         return m_Tensor.info()->data_type();
203     }
204 
SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup> &)205     virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}
206 
GetStrides() const207     TensorShape GetStrides() const override
208     {
209         return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
210     }
211 
GetShape() const212     TensorShape GetShape() const override
213     {
214         return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
215     }
216 
217 private:
218     // Only used for testing
CopyOutTo(void * memory) const219     void CopyOutTo(void* memory) const override
220     {
221         const_cast<ClSubTensorHandle*>(this)->Map(true);
222         switch(this->GetDataType())
223         {
224             case arm_compute::DataType::F32:
225                 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
226                                                                  static_cast<float*>(memory));
227                 break;
228             case arm_compute::DataType::U8:
229             case arm_compute::DataType::QASYMM8:
230                 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
231                                                                  static_cast<uint8_t*>(memory));
232                 break;
233             case arm_compute::DataType::F16:
234                 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
235                                                                  static_cast<armnn::Half*>(memory));
236                 break;
237             case arm_compute::DataType::QSYMM8_PER_CHANNEL:
238             case arm_compute::DataType::QASYMM8_SIGNED:
239             armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
240                                                              static_cast<int8_t*>(memory));
241                 break;
242             case arm_compute::DataType::S16:
243             case arm_compute::DataType::QSYMM16:
244                 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
245                                                                  static_cast<int16_t*>(memory));
246                 break;
247             case arm_compute::DataType::S32:
248                 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
249                                                                  static_cast<int32_t*>(memory));
250                 break;
251             default:
252             {
253                 throw armnn::UnimplementedException();
254             }
255         }
256         const_cast<ClSubTensorHandle*>(this)->Unmap();
257     }
258 
259     // Only used for testing
CopyInFrom(const void * memory)260     void CopyInFrom(const void* memory) override
261     {
262         this->Map(true);
263         switch(this->GetDataType())
264         {
265             case arm_compute::DataType::F32:
266                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
267                                                                  this->GetTensor());
268                 break;
269             case arm_compute::DataType::U8:
270             case arm_compute::DataType::QASYMM8:
271                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
272                                                                  this->GetTensor());
273                 break;
274             case arm_compute::DataType::F16:
275                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
276                                                                  this->GetTensor());
277                 break;
278             case arm_compute::DataType::QSYMM8_PER_CHANNEL:
279             case arm_compute::DataType::QASYMM8_SIGNED:
280                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
281                                                                  this->GetTensor());
282                 break;
283             case arm_compute::DataType::S16:
284             case arm_compute::DataType::QSYMM16:
285                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
286                                                                  this->GetTensor());
287                 break;
288             case arm_compute::DataType::S32:
289                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
290                                                                  this->GetTensor());
291                 break;
292             default:
293             {
294                 throw armnn::UnimplementedException();
295             }
296         }
297         this->Unmap();
298     }
299 
300     mutable arm_compute::CLSubTensor m_Tensor;
301     ITensorHandle* parentHandle = nullptr;
302 };
303 
304 } // namespace armnn
305