1 //
2 // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #include <armnn/Exceptions.hpp>
6 #include <armnn/utility/IgnoreUnused.hpp>
7
8 #include <armnn/backends/TensorHandle.hpp>
9
10 #include <cstring>
11
12 namespace armnn
13 {
14
GetUnpaddedTensorStrides(const TensorInfo & tensorInfo)15 TensorShape GetUnpaddedTensorStrides(const TensorInfo& tensorInfo)
16 {
17 TensorShape shape(tensorInfo.GetShape());
18 auto size = GetDataTypeSize(tensorInfo.GetDataType());
19 auto runningSize = size;
20 std::vector<unsigned int> strides(shape.GetNumDimensions());
21 auto lastIdx = shape.GetNumDimensions()-1;
22 for (unsigned int i=0; i < lastIdx ; i++)
23 {
24 strides[lastIdx-i] = runningSize;
25 runningSize *= shape[lastIdx-i];
26 }
27 strides[0] = runningSize;
28 return TensorShape(shape.GetNumDimensions(), strides.data());
29 }
30
ConstTensorHandle(const TensorInfo & tensorInfo)31 ConstTensorHandle::ConstTensorHandle(const TensorInfo& tensorInfo)
32 : m_TensorInfo(tensorInfo)
33 , m_Memory(nullptr)
34 {
35 }
36
37 template <>
GetConstTensor() const38 const void* ConstTensorHandle::GetConstTensor<void>() const
39 {
40 return m_Memory;
41 }
42
TensorHandle(const TensorInfo & tensorInfo)43 TensorHandle::TensorHandle(const TensorInfo& tensorInfo)
44 : ConstTensorHandle(tensorInfo)
45 , m_MutableMemory(nullptr)
46 {
47 }
48
49 template <>
GetTensor() const50 void* TensorHandle::GetTensor<void>() const
51 {
52 return m_MutableMemory;
53 }
54
ScopedTensorHandle(const TensorInfo & tensorInfo)55 ScopedTensorHandle::ScopedTensorHandle(const TensorInfo& tensorInfo)
56 : TensorHandle(tensorInfo)
57 {
58 }
59
ScopedTensorHandle(const ConstTensor & tensor)60 ScopedTensorHandle::ScopedTensorHandle(const ConstTensor& tensor)
61 : ScopedTensorHandle(tensor.GetInfo())
62 {
63 CopyFrom(tensor.GetMemoryArea(), tensor.GetNumBytes());
64 }
65
ScopedTensorHandle(const ConstTensorHandle & tensorHandle)66 ScopedTensorHandle::ScopedTensorHandle(const ConstTensorHandle& tensorHandle)
67 : ScopedTensorHandle(tensorHandle.GetTensorInfo())
68 {
69 CopyFrom(tensorHandle.GetConstTensor<void>(), tensorHandle.GetTensorInfo().GetNumBytes());
70 }
71
/// Copy constructor: initialises the base with the other handle's TensorInfo, then passes the
/// other handle's memory pointer and byte count to CopyFrom().
ScopedTensorHandle::ScopedTensorHandle(const ScopedTensorHandle& other)
    : TensorHandle(other.GetTensorInfo())
{
    // Same call the CopyFrom(const ScopedTensorHandle&) overload makes, spelled out in place.
    CopyFrom(other.GetTensor<void>(), other.GetTensorInfo().GetNumBytes());
}
77
/// Copy assignment: releases the currently held buffer and copies the contents of @p other.
///
/// Self-assignment is a no-op. Without the identity check the buffer would be freed first and
/// CopyFrom() would then see a null source pointer, leaving the handle empty.
///
/// Only the memory is replaced here; this operator does not touch the stored TensorInfo.
/// CopyFrom() asserts that the byte counts of both handles agree.
///
/// @param other Handle whose contents are copied.
/// @return Reference to this handle.
ScopedTensorHandle& ScopedTensorHandle::operator=(const ScopedTensorHandle& other)
{
    if (this != &other)
    {
        ::operator delete(GetTensor<void>());
        // CopyFrom() expects the handle to hold no memory before it allocates.
        SetMemory(nullptr);
        CopyFrom(other);
    }
    return *this;
}
85
~ScopedTensorHandle()86 ScopedTensorHandle::~ScopedTensorHandle()
87 {
88 ::operator delete(GetTensor<void>());
89 }
90
Allocate()91 void ScopedTensorHandle::Allocate()
92 {
93 if (GetTensor<void>() == nullptr)
94 {
95 SetMemory(::operator new(GetTensorInfo().GetNumBytes()));
96 }
97 else
98 {
99 throw InvalidArgumentException("TensorHandle::Allocate Trying to allocate a TensorHandle"
100 "that already has allocated memory.");
101 }
102 }
103
CopyOutTo(void * memory) const104 void ScopedTensorHandle::CopyOutTo(void* memory) const
105 {
106 memcpy(memory, GetTensor<void>(), GetTensorInfo().GetNumBytes());
107 }
108
CopyInFrom(const void * memory)109 void ScopedTensorHandle::CopyInFrom(const void* memory)
110 {
111 memcpy(GetTensor<void>(), memory, GetTensorInfo().GetNumBytes());
112 }
113
CopyFrom(const ScopedTensorHandle & other)114 void ScopedTensorHandle::CopyFrom(const ScopedTensorHandle& other)
115 {
116 CopyFrom(other.GetTensor<void>(), other.GetTensorInfo().GetNumBytes());
117 }
118
CopyFrom(const void * srcMemory,unsigned int numBytes)119 void ScopedTensorHandle::CopyFrom(const void* srcMemory, unsigned int numBytes)
120 {
121 ARMNN_ASSERT(GetTensor<void>() == nullptr);
122 ARMNN_ASSERT(GetTensorInfo().GetNumBytes() == numBytes);
123
124 if (srcMemory)
125 {
126 Allocate();
127 memcpy(GetTensor<void>(), srcMemory, numBytes);
128 }
129 }
130
Allocate()131 void PassthroughTensorHandle::Allocate()
132 {
133 throw InvalidArgumentException("PassthroughTensorHandle::Allocate() should never be called");
134 }
135
Allocate()136 void ConstPassthroughTensorHandle::Allocate()
137 {
138 throw InvalidArgumentException("ConstPassthroughTensorHandle::Allocate() should never be called");
139 }
140
141 } // namespace armnn
142