• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include "ITensorHandle.hpp"
9 
10 #include <armnn/TypesUtils.hpp>
11 #include <armnn/utility/Assert.hpp>
12 #include <armnnUtils/CompatibleTypes.hpp>
13 
14 #include <algorithm>
15 
16 namespace armnn
17 {
18 
19 // Get a TensorShape representing the strides (in bytes) for each dimension
20 // of a tensor, assuming fully packed data with no padding
21 TensorShape GetUnpaddedTensorStrides(const TensorInfo& tensorInfo);
22 
23 // Abstract tensor handles wrapping a readable region of memory, interpreting it as tensor data.
24 class ConstTensorHandle : public ITensorHandle
25 {
26 public:
27     template <typename T>
GetConstTensor() const28     const T* GetConstTensor() const
29     {
30         if (armnnUtils::CompatibleTypes<T>(GetTensorInfo().GetDataType()))
31         {
32             return reinterpret_cast<const T*>(m_Memory);
33         }
34         else
35         {
36             throw armnn::Exception("Attempting to get not compatible type tensor!");
37         }
38     }
39 
GetTensorInfo() const40     const TensorInfo& GetTensorInfo() const
41     {
42         return m_TensorInfo;
43     }
44 
Manage()45     virtual void Manage() override {}
46 
GetParent() const47     virtual ITensorHandle* GetParent() const override { return nullptr; }
48 
Map(bool) const49     virtual const void* Map(bool /* blocking = true */) const override { return m_Memory; }
Unmap() const50     virtual void Unmap() const override {}
51 
GetStrides() const52     TensorShape GetStrides() const override
53     {
54         return GetUnpaddedTensorStrides(m_TensorInfo);
55     }
GetShape() const56     TensorShape GetShape() const override { return m_TensorInfo.GetShape(); }
57 
58 protected:
59     ConstTensorHandle(const TensorInfo& tensorInfo);
60 
SetConstMemory(const void * mem)61     void SetConstMemory(const void* mem) { m_Memory = mem; }
62 
63 private:
64     // Only used for testing
CopyOutTo(void *) const65     void CopyOutTo(void *) const override { ARMNN_ASSERT_MSG(false, "Unimplemented"); }
CopyInFrom(const void *)66     void CopyInFrom(const void*) override { ARMNN_ASSERT_MSG(false, "Unimplemented"); }
67 
68     ConstTensorHandle(const ConstTensorHandle& other) = delete;
69     ConstTensorHandle& operator=(const ConstTensorHandle& other) = delete;
70 
71     TensorInfo m_TensorInfo;
72     const void* m_Memory;
73 };
74 
75 template<>
76 const void* ConstTensorHandle::GetConstTensor<void>() const;
77 
78 // Abstract specialization of ConstTensorHandle that allows write access to the same data.
79 class TensorHandle : public ConstTensorHandle
80 {
81 public:
82     template <typename T>
GetTensor() const83     T* GetTensor() const
84     {
85         if (armnnUtils::CompatibleTypes<T>(GetTensorInfo().GetDataType()))
86         {
87             return reinterpret_cast<T*>(m_MutableMemory);
88         }
89         else
90         {
91             throw armnn::Exception("Attempting to get not compatible type tensor!");
92         }
93     }
94 
95 protected:
96     TensorHandle(const TensorInfo& tensorInfo);
97 
SetMemory(void * mem)98     void SetMemory(void* mem)
99     {
100         m_MutableMemory = mem;
101         SetConstMemory(m_MutableMemory);
102     }
103 
104 private:
105 
106     TensorHandle(const TensorHandle& other) = delete;
107     TensorHandle& operator=(const TensorHandle& other) = delete;
108     void* m_MutableMemory;
109 };
110 
111 template <>
112 void* TensorHandle::GetTensor<void>() const;
113 
114 // A TensorHandle that owns the wrapped memory region.
115 class ScopedTensorHandle : public TensorHandle
116 {
117 public:
118     explicit ScopedTensorHandle(const TensorInfo& tensorInfo);
119 
120     // Copies contents from Tensor.
121     explicit ScopedTensorHandle(const ConstTensor& tensor);
122 
123     // Copies contents from ConstTensorHandle
124     explicit ScopedTensorHandle(const ConstTensorHandle& tensorHandle);
125 
126     ScopedTensorHandle(const ScopedTensorHandle& other);
127     ScopedTensorHandle& operator=(const ScopedTensorHandle& other);
128     ~ScopedTensorHandle();
129 
130     virtual void Allocate() override;
131 
132 private:
133     // Only used for testing
134     void CopyOutTo(void* memory) const override;
135     void CopyInFrom(const void* memory) override;
136 
137     void CopyFrom(const ScopedTensorHandle& other);
138     void CopyFrom(const void* srcMemory, unsigned int numBytes);
139 };
140 
141 // A TensorHandle that wraps an already allocated memory region.
142 //
143 // Clients must make sure the passed in memory region stays alive for the lifetime of
144 // the PassthroughTensorHandle instance.
145 //
146 // Note there is no polymorphism to/from ConstPassthroughTensorHandle.
147 class PassthroughTensorHandle : public TensorHandle
148 {
149 public:
PassthroughTensorHandle(const TensorInfo & tensorInfo,void * mem)150     PassthroughTensorHandle(const TensorInfo& tensorInfo, void* mem)
151     :   TensorHandle(tensorInfo)
152     {
153         SetMemory(mem);
154     }
155 
156     virtual void Allocate() override;
157 };
158 
159 // A ConstTensorHandle that wraps an already allocated memory region.
160 //
161 // This allows users to pass in const memory to a network.
162 // Clients must make sure the passed in memory region stays alive for the lifetime of
163 // the PassthroughTensorHandle instance.
164 //
165 // Note there is no polymorphism to/from PassthroughTensorHandle.
166 class ConstPassthroughTensorHandle : public ConstTensorHandle
167 {
168 public:
ConstPassthroughTensorHandle(const TensorInfo & tensorInfo,const void * mem)169     ConstPassthroughTensorHandle(const TensorInfo& tensorInfo, const void* mem)
170     :   ConstTensorHandle(tensorInfo)
171     {
172         SetConstMemory(mem);
173     }
174 
175     virtual void Allocate() override;
176 };
177 
178 
// Template specializations: ConstTensorHandle::GetConstTensor<void>() and
// TensorHandle::GetTensor<void>() are declared immediately after their respective classes,
// before any use, so they are not redeclared here.
186 
187 class ManagedConstTensorHandle
188 {
189 
190 public:
ManagedConstTensorHandle(std::shared_ptr<ConstTensorHandle> ptr)191     explicit ManagedConstTensorHandle(std::shared_ptr<ConstTensorHandle> ptr)
192         : m_Mapped(false)
193         , m_TensorHandle(std::move(ptr)) {};
194 
195     /// RAII Managed resource Unmaps MemoryArea once out of scope
Map(bool blocking=true)196     const void* Map(bool blocking = true)
197     {
198         if (m_TensorHandle)
199         {
200             auto pRet = m_TensorHandle->Map(blocking);
201             m_Mapped = true;
202             return pRet;
203         }
204         else
205         {
206             throw armnn::Exception("Attempting to Map null TensorHandle");
207         }
208 
209     }
210 
211     // Delete copy constructor as it's unnecessary
212     ManagedConstTensorHandle(const ConstTensorHandle& other) = delete;
213 
214     // Delete copy assignment as it's unnecessary
215     ManagedConstTensorHandle& operator=(const ManagedConstTensorHandle& other) = delete;
216 
217     // Delete move assignment as it's unnecessary
218     ManagedConstTensorHandle& operator=(ManagedConstTensorHandle&& other) noexcept = delete;
219 
~ManagedConstTensorHandle()220     ~ManagedConstTensorHandle()
221     {
222         // Bias tensor handles need to be initialized empty before entering scope of if statement checking if enabled
223         if (m_TensorHandle)
224         {
225             Unmap();
226         }
227     }
228 
Unmap()229     void Unmap()
230     {
231         // Only unmap if mapped and TensorHandle exists.
232         if (m_Mapped && m_TensorHandle)
233         {
234             m_TensorHandle->Unmap();
235             m_Mapped = false;
236         }
237     }
238 
GetTensorInfo() const239     const TensorInfo& GetTensorInfo() const
240     {
241         return m_TensorHandle->GetTensorInfo();
242     }
243 
IsMapped() const244     bool IsMapped() const
245     {
246         return m_Mapped;
247     }
248 
249 private:
250     bool m_Mapped;
251     std::shared_ptr<ConstTensorHandle> m_TensorHandle;
252 };
253 
254 } // namespace armnn
255