• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #include "RefTensorHandle.hpp"
6 
7 namespace armnn
8 {
9 
RefTensorHandle(const TensorInfo & tensorInfo,std::shared_ptr<RefMemoryManager> & memoryManager)10 RefTensorHandle::RefTensorHandle(const TensorInfo &tensorInfo, std::shared_ptr<RefMemoryManager> &memoryManager):
11     m_TensorInfo(tensorInfo),
12     m_MemoryManager(memoryManager),
13     m_Pool(nullptr),
14     m_UnmanagedMemory(nullptr),
15     m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Undefined)),
16     m_Imported(false),
17     m_IsImportEnabled(false)
18 {
19 
20 }
21 
RefTensorHandle(const TensorInfo & tensorInfo,MemorySourceFlags importFlags)22 RefTensorHandle::RefTensorHandle(const TensorInfo& tensorInfo,
23                                  MemorySourceFlags importFlags)
24                                  : m_TensorInfo(tensorInfo),
25                                    m_Pool(nullptr),
26                                    m_UnmanagedMemory(nullptr),
27                                    m_ImportFlags(importFlags),
28                                    m_Imported(false),
29                                    m_IsImportEnabled(true)
30 {
31 
32 }
33 
~RefTensorHandle()34 RefTensorHandle::~RefTensorHandle()
35 {
36     if (!m_Pool)
37     {
38         // unmanaged
39         if (!m_Imported)
40         {
41             ::operator delete(m_UnmanagedMemory);
42         }
43     }
44 }
45 
Manage()46 void RefTensorHandle::Manage()
47 {
48     if (!m_IsImportEnabled)
49     {
50         ARMNN_ASSERT_MSG(!m_Pool, "RefTensorHandle::Manage() called twice");
51         ARMNN_ASSERT_MSG(!m_UnmanagedMemory, "RefTensorHandle::Manage() called after Allocate()");
52 
53         m_Pool = m_MemoryManager->Manage(m_TensorInfo.GetNumBytes());
54     }
55 }
56 
Allocate()57 void RefTensorHandle::Allocate()
58 {
59     // If import is enabled, do not allocate the tensor
60     if (!m_IsImportEnabled)
61     {
62 
63         if (!m_UnmanagedMemory)
64         {
65             if (!m_Pool)
66             {
67                 // unmanaged
68                 m_UnmanagedMemory = ::operator new(m_TensorInfo.GetNumBytes());
69             }
70             else
71             {
72                 m_MemoryManager->Allocate(m_Pool);
73             }
74         }
75         else
76         {
77             throw InvalidArgumentException("RefTensorHandle::Allocate Trying to allocate a RefTensorHandle"
78                                            "that already has allocated memory.");
79         }
80     }
81 }
82 
Map(bool) const83 const void* RefTensorHandle::Map(bool /*unused*/) const
84 {
85     return GetPointer();
86 }
87 
GetPointer() const88 void* RefTensorHandle::GetPointer() const
89 {
90     if (m_UnmanagedMemory)
91     {
92         return m_UnmanagedMemory;
93     }
94     else if (m_Pool)
95     {
96         return m_MemoryManager->GetPointer(m_Pool);
97     }
98     else
99     {
100         throw NullPointerException("RefTensorHandle::GetPointer called on unmanaged, unallocated tensor handle");
101     }
102 }
103 
CopyOutTo(void * dest) const104 void RefTensorHandle::CopyOutTo(void* dest) const
105 {
106     const void *src = GetPointer();
107     ARMNN_ASSERT(src);
108     memcpy(dest, src, m_TensorInfo.GetNumBytes());
109 }
110 
CopyInFrom(const void * src)111 void RefTensorHandle::CopyInFrom(const void* src)
112 {
113     void *dest = GetPointer();
114     ARMNN_ASSERT(dest);
115     memcpy(dest, src, m_TensorInfo.GetNumBytes());
116 }
117 
Import(void * memory,MemorySource source)118 bool RefTensorHandle::Import(void* memory, MemorySource source)
119 {
120     if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
121     {
122         if (m_IsImportEnabled && source == MemorySource::Malloc)
123         {
124             // Check memory alignment
125             constexpr uintptr_t alignment = sizeof(size_t);
126             if (reinterpret_cast<uintptr_t>(memory) % alignment)
127             {
128                 if (m_Imported)
129                 {
130                     m_Imported = false;
131                     m_UnmanagedMemory = nullptr;
132                 }
133 
134                 return false;
135             }
136 
137             // m_UnmanagedMemory not yet allocated.
138             if (!m_Imported && !m_UnmanagedMemory)
139             {
140                 m_UnmanagedMemory = memory;
141                 m_Imported = true;
142                 return true;
143             }
144 
145             // m_UnmanagedMemory initially allocated with Allocate().
146             if (!m_Imported && m_UnmanagedMemory)
147             {
148                 return false;
149             }
150 
151             // m_UnmanagedMemory previously imported.
152             if (m_Imported)
153             {
154                 m_UnmanagedMemory = memory;
155                 return true;
156             }
157         }
158     }
159 
160     return false;
161 }
162 
163 }
164