• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #include <armnn/Exceptions.hpp>
6 #include <armnn/utility/IgnoreUnused.hpp>
7 
8 #include <armnn/backends/TensorHandle.hpp>
9 
10 #include <cstring>
11 
12 namespace armnn
13 {
14 
GetUnpaddedTensorStrides(const TensorInfo & tensorInfo)15 TensorShape GetUnpaddedTensorStrides(const TensorInfo& tensorInfo)
16 {
17     TensorShape shape(tensorInfo.GetShape());
18     auto size = GetDataTypeSize(tensorInfo.GetDataType());
19     auto runningSize = size;
20     std::vector<unsigned int> strides(shape.GetNumDimensions());
21     auto lastIdx = shape.GetNumDimensions()-1;
22     for (unsigned int i=0; i < lastIdx ; i++)
23     {
24         strides[lastIdx-i] = runningSize;
25         runningSize *= shape[lastIdx-i];
26     }
27     strides[0] = runningSize;
28     return TensorShape(shape.GetNumDimensions(), strides.data());
29 }
30 
ConstTensorHandle(const TensorInfo & tensorInfo)31 ConstTensorHandle::ConstTensorHandle(const TensorInfo& tensorInfo)
32 : m_TensorInfo(tensorInfo)
33 , m_Memory(nullptr)
34 {
35 }
36 
37 template <>
GetConstTensor() const38 const void* ConstTensorHandle::GetConstTensor<void>() const
39 {
40     return m_Memory;
41 }
42 
TensorHandle(const TensorInfo & tensorInfo)43 TensorHandle::TensorHandle(const TensorInfo& tensorInfo)
44 : ConstTensorHandle(tensorInfo)
45 , m_MutableMemory(nullptr)
46 {
47 }
48 
49 template <>
GetTensor() const50 void* TensorHandle::GetTensor<void>() const
51 {
52     return m_MutableMemory;
53 }
54 
ScopedTensorHandle(const TensorInfo & tensorInfo)55 ScopedTensorHandle::ScopedTensorHandle(const TensorInfo& tensorInfo)
56 : TensorHandle(tensorInfo)
57 {
58 }
59 
ScopedTensorHandle(const ConstTensor & tensor)60 ScopedTensorHandle::ScopedTensorHandle(const ConstTensor& tensor)
61 : ScopedTensorHandle(tensor.GetInfo())
62 {
63     CopyFrom(tensor.GetMemoryArea(), tensor.GetNumBytes());
64 }
65 
ScopedTensorHandle(const ConstTensorHandle & tensorHandle)66 ScopedTensorHandle::ScopedTensorHandle(const ConstTensorHandle& tensorHandle)
67 : ScopedTensorHandle(tensorHandle.GetTensorInfo())
68 {
69     CopyFrom(tensorHandle.GetConstTensor<void>(), tensorHandle.GetTensorInfo().GetNumBytes());
70 }
71 
/// Copy constructor.
/// The new handle gets its own buffer containing the same bytes as @p other.
/// If @p other has no memory, the copy is left unallocated.
ScopedTensorHandle::ScopedTensorHandle(const ScopedTensorHandle& other)
    : TensorHandle(other.GetTensorInfo())
{
    const auto otherBytes = other.GetTensorInfo().GetNumBytes();
    CopyFrom(other.GetTensor<void>(), otherBytes);
}
77 
/// Copy assignment: releases this handle's buffer and replaces it with a fresh
/// copy of @p other's bytes.
///
/// This operator replaces only the memory; this handle's TensorInfo is left
/// unchanged. The byte sizes of the two handles must therefore match.
///
/// @param other Handle to copy from; self-assignment is a no-op.
/// @return *this
/// @throws InvalidArgumentException if @p other's byte size differs from this
///         handle's. In that case this handle is left untouched.
ScopedTensorHandle& ScopedTensorHandle::operator=(const ScopedTensorHandle& other)
{
    // Without this guard the buffer would be freed before being copied from,
    // silently discarding the tensor's contents.
    if (this == &other)
    {
        return *this;
    }

    // The new buffer is sized from this handle's TensorInfo, but the copy uses
    // other's byte count. Reject a mismatch up front (before freeing anything)
    // rather than overrunning the buffer.
    if (GetTensorInfo().GetNumBytes() != other.GetTensorInfo().GetNumBytes())
    {
        throw InvalidArgumentException("ScopedTensorHandle::operator= Cannot assign from a tensor handle "
                                       "with a different size in bytes.");
    }

    ::operator delete(GetTensor<void>());
    SetMemory(nullptr);
    CopyFrom(other);
    return *this;
}
85 
~ScopedTensorHandle()86 ScopedTensorHandle::~ScopedTensorHandle()
87 {
88     ::operator delete(GetTensor<void>());
89 }
90 
Allocate()91 void ScopedTensorHandle::Allocate()
92 {
93     if (GetTensor<void>() == nullptr)
94     {
95         SetMemory(::operator new(GetTensorInfo().GetNumBytes()));
96     }
97     else
98     {
99         throw InvalidArgumentException("TensorHandle::Allocate Trying to allocate a TensorHandle"
100             "that already has allocated memory.");
101     }
102 }
103 
CopyOutTo(void * memory) const104 void ScopedTensorHandle::CopyOutTo(void* memory) const
105 {
106     memcpy(memory, GetTensor<void>(), GetTensorInfo().GetNumBytes());
107 }
108 
CopyInFrom(const void * memory)109 void ScopedTensorHandle::CopyInFrom(const void* memory)
110 {
111     memcpy(GetTensor<void>(), memory, GetTensorInfo().GetNumBytes());
112 }
113 
CopyFrom(const ScopedTensorHandle & other)114 void ScopedTensorHandle::CopyFrom(const ScopedTensorHandle& other)
115 {
116     CopyFrom(other.GetTensor<void>(), other.GetTensorInfo().GetNumBytes());
117 }
118 
CopyFrom(const void * srcMemory,unsigned int numBytes)119 void ScopedTensorHandle::CopyFrom(const void* srcMemory, unsigned int numBytes)
120 {
121     ARMNN_ASSERT(GetTensor<void>() == nullptr);
122     ARMNN_ASSERT(GetTensorInfo().GetNumBytes() == numBytes);
123 
124     if (srcMemory)
125     {
126         Allocate();
127         memcpy(GetTensor<void>(), srcMemory, numBytes);
128     }
129 }
130 
Allocate()131 void PassthroughTensorHandle::Allocate()
132 {
133     throw InvalidArgumentException("PassthroughTensorHandle::Allocate() should never be called");
134 }
135 
Allocate()136 void ConstPassthroughTensorHandle::Allocate()
137 {
138     throw InvalidArgumentException("ConstPassthroughTensorHandle::Allocate() should never be called");
139 }
140 
141 } // namespace armnn
142