• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include <armnn/backends/CpuTensorHandleFwd.hpp>
9 #include <armnn/backends/ITensorHandle.hpp>
10 
11 #include <armnn/TypesUtils.hpp>
12 
13 #include <CompatibleTypes.hpp>
14 
15 #include <algorithm>
16 
17 #include <armnn/utility/Assert.hpp>
18 
19 namespace armnn
20 {
21 
22 // Get a TensorShape representing the strides (in bytes) for each dimension
23 // of a tensor, assuming fully packed data with no padding
24 TensorShape GetUnpaddedTensorStrides(const TensorInfo& tensorInfo);
25 
26 // Abstract tensor handles wrapping a CPU-readable region of memory, interpreting it as tensor data.
27 class ConstCpuTensorHandle : public ITensorHandle
28 {
29 public:
30     template <typename T>
GetConstTensor() const31     const T* GetConstTensor() const
32     {
33         ARMNN_ASSERT(CompatibleTypes<T>(GetTensorInfo().GetDataType()));
34         return reinterpret_cast<const T*>(m_Memory);
35     }
36 
GetTensorInfo() const37     const TensorInfo& GetTensorInfo() const
38     {
39         return m_TensorInfo;
40     }
41 
Manage()42     virtual void Manage() override {}
43 
GetParent() const44     virtual ITensorHandle* GetParent() const override { return nullptr; }
45 
Map(bool) const46     virtual const void* Map(bool /* blocking = true */) const override { return m_Memory; }
Unmap() const47     virtual void Unmap() const override {}
48 
GetStrides() const49     TensorShape GetStrides() const override
50     {
51         return GetUnpaddedTensorStrides(m_TensorInfo);
52     }
GetShape() const53     TensorShape GetShape() const override { return m_TensorInfo.GetShape(); }
54 
55 protected:
56     ConstCpuTensorHandle(const TensorInfo& tensorInfo);
57 
SetConstMemory(const void * mem)58     void SetConstMemory(const void* mem) { m_Memory = mem; }
59 
60 private:
61     // Only used for testing
CopyOutTo(void *) const62     void CopyOutTo(void *) const override { ARMNN_ASSERT_MSG(false, "Unimplemented"); }
CopyInFrom(const void *)63     void CopyInFrom(const void*) override { ARMNN_ASSERT_MSG(false, "Unimplemented"); }
64 
65     ConstCpuTensorHandle(const ConstCpuTensorHandle& other) = delete;
66     ConstCpuTensorHandle& operator=(const ConstCpuTensorHandle& other) = delete;
67 
68     TensorInfo m_TensorInfo;
69     const void* m_Memory;
70 };
71 
72 template<>
73 const void* ConstCpuTensorHandle::GetConstTensor<void>() const;
74 
75 // Abstract specialization of ConstCpuTensorHandle that allows write access to the same data.
76 class CpuTensorHandle : public ConstCpuTensorHandle
77 {
78 public:
79     template <typename T>
GetTensor() const80     T* GetTensor() const
81     {
82         ARMNN_ASSERT(CompatibleTypes<T>(GetTensorInfo().GetDataType()));
83         return reinterpret_cast<T*>(m_MutableMemory);
84     }
85 
86 protected:
87     CpuTensorHandle(const TensorInfo& tensorInfo);
88 
SetMemory(void * mem)89     void SetMemory(void* mem)
90     {
91         m_MutableMemory = mem;
92         SetConstMemory(m_MutableMemory);
93     }
94 
95 private:
96 
97     CpuTensorHandle(const CpuTensorHandle& other) = delete;
98     CpuTensorHandle& operator=(const CpuTensorHandle& other) = delete;
99     void* m_MutableMemory;
100 };
101 
102 template <>
103 void* CpuTensorHandle::GetTensor<void>() const;
104 
105 // A CpuTensorHandle that owns the wrapped memory region.
106 class ScopedCpuTensorHandle : public CpuTensorHandle
107 {
108 public:
109     explicit ScopedCpuTensorHandle(const TensorInfo& tensorInfo);
110 
111     // Copies contents from Tensor.
112     explicit ScopedCpuTensorHandle(const ConstTensor& tensor);
113 
114     // Copies contents from ConstCpuTensorHandle
115     explicit ScopedCpuTensorHandle(const ConstCpuTensorHandle& tensorHandle);
116 
117     ScopedCpuTensorHandle(const ScopedCpuTensorHandle& other);
118     ScopedCpuTensorHandle& operator=(const ScopedCpuTensorHandle& other);
119     ~ScopedCpuTensorHandle();
120 
121     virtual void Allocate() override;
122 
123 private:
124     // Only used for testing
125     void CopyOutTo(void* memory) const override;
126     void CopyInFrom(const void* memory) override;
127 
128     void CopyFrom(const ScopedCpuTensorHandle& other);
129     void CopyFrom(const void* srcMemory, unsigned int numBytes);
130 };
131 
132 // A CpuTensorHandle that wraps an already allocated memory region.
133 //
134 // Clients must make sure the passed in memory region stays alive for the lifetime of
135 // the PassthroughCpuTensorHandle instance.
136 //
137 // Note there is no polymorphism to/from ConstPassthroughCpuTensorHandle.
138 class PassthroughCpuTensorHandle : public CpuTensorHandle
139 {
140 public:
PassthroughCpuTensorHandle(const TensorInfo & tensorInfo,void * mem)141     PassthroughCpuTensorHandle(const TensorInfo& tensorInfo, void* mem)
142     :   CpuTensorHandle(tensorInfo)
143     {
144         SetMemory(mem);
145     }
146 
147     virtual void Allocate() override;
148 };
149 
150 // A ConstCpuTensorHandle that wraps an already allocated memory region.
151 //
152 // This allows users to pass in const memory to a network.
153 // Clients must make sure the passed in memory region stays alive for the lifetime of
154 // the PassthroughCpuTensorHandle instance.
155 //
156 // Note there is no polymorphism to/from PassthroughCpuTensorHandle.
157 class ConstPassthroughCpuTensorHandle : public ConstCpuTensorHandle
158 {
159 public:
ConstPassthroughCpuTensorHandle(const TensorInfo & tensorInfo,const void * mem)160     ConstPassthroughCpuTensorHandle(const TensorInfo& tensorInfo, const void* mem)
161     :   ConstCpuTensorHandle(tensorInfo)
162     {
163         SetConstMemory(mem);
164     }
165 
166     virtual void Allocate() override;
167 };
168 
169 
// The void specializations of ConstCpuTensorHandle::GetConstTensor and CpuTensorHandle::GetTensor
// are declared immediately after their respective class definitions; they are not repeated here.
177 
178 } // namespace armnn
179