1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "SubgraphView.hpp"
7 #include "Graph.hpp"
8
#include <armnn/utility/IgnoreUnused.hpp>
#include <armnn/utility/NumericCast.hpp>
#include <armnn/utility/PolymorphicDowncast.hpp>

#include <unordered_set>
#include <utility>
14
15 namespace armnn
16 {
17
18 namespace
19 {
20
21 template <class C>
AssertIfNullsOrDuplicates(const C & container,const std::string & errorMessage)22 void AssertIfNullsOrDuplicates(const C& container, const std::string& errorMessage)
23 {
24 using T = typename C::value_type;
25 std::unordered_set<T> duplicateSet;
26 std::for_each(container.begin(), container.end(), [&duplicateSet, &errorMessage](const T& i)
27 {
28 // Ignore unused for release builds
29 IgnoreUnused(errorMessage);
30
31 // Check if the item is valid
32 ARMNN_ASSERT_MSG(i, errorMessage.c_str());
33
34 // Check if a duplicate has been found
35 ARMNN_ASSERT_MSG(duplicateSet.find(i) == duplicateSet.end(), errorMessage.c_str());
36
37 duplicateSet.insert(i);
38 });
39 }
40
41 } // anonymous namespace
42
SubgraphView(Graph & graph)43 SubgraphView::SubgraphView(Graph& graph)
44 : m_InputSlots{}
45 , m_OutputSlots{}
46 , m_Layers(graph.begin(), graph.end())
47 {
48 CheckSubgraph();
49 }
50
SubgraphView(InputSlots && inputs,OutputSlots && outputs,Layers && layers)51 SubgraphView::SubgraphView(InputSlots&& inputs, OutputSlots&& outputs, Layers&& layers)
52 : m_InputSlots{inputs}
53 , m_OutputSlots{outputs}
54 , m_Layers{layers}
55 {
56 CheckSubgraph();
57 }
58
/// Copy-constructs a view: slot and layer pointer lists are duplicated element-wise
/// (the pointed-to slots/layers are shared, not cloned), then re-validated.
SubgraphView::SubgraphView(const SubgraphView& subgraph)
    : m_InputSlots(subgraph.m_InputSlots)
    , m_OutputSlots(subgraph.m_OutputSlots)
    , m_Layers(subgraph.m_Layers)
{
    CheckSubgraph();
}
66
SubgraphView(SubgraphView && subgraph)67 SubgraphView::SubgraphView(SubgraphView&& subgraph)
68 : m_InputSlots(std::move(subgraph.m_InputSlots))
69 , m_OutputSlots(std::move(subgraph.m_OutputSlots))
70 , m_Layers(std::move(subgraph.m_Layers))
71 {
72 CheckSubgraph();
73 }
74
SubgraphView(IConnectableLayer * layer)75 SubgraphView::SubgraphView(IConnectableLayer* layer)
76 : m_InputSlots{}
77 , m_OutputSlots{}
78 , m_Layers{PolymorphicDowncast<Layer*>(layer)}
79 {
80 unsigned int numInputSlots = layer->GetNumInputSlots();
81 m_InputSlots.resize(numInputSlots);
82 for (unsigned int i = 0; i < numInputSlots; i++)
83 {
84 m_InputSlots.at(i) = PolymorphicDowncast<InputSlot*>(&(layer->GetInputSlot(i)));
85 }
86
87 unsigned int numOutputSlots = layer->GetNumOutputSlots();
88 m_OutputSlots.resize(numOutputSlots);
89 for (unsigned int i = 0; i < numOutputSlots; i++)
90 {
91 m_OutputSlots.at(i) = PolymorphicDowncast<OutputSlot*>(&(layer->GetOutputSlot(i)));
92 }
93
94 CheckSubgraph();
95 }
96
operator =(SubgraphView && other)97 SubgraphView& SubgraphView::operator=(SubgraphView&& other)
98 {
99 m_InputSlots = std::move(other.m_InputSlots);
100 m_OutputSlots = std::move(other.m_OutputSlots);
101 m_Layers = std::move(other.m_Layers);
102
103 CheckSubgraph();
104
105 return *this;
106 }
107
CheckSubgraph()108 void SubgraphView::CheckSubgraph()
109 {
110 // Check for invalid or duplicate input slots
111 AssertIfNullsOrDuplicates(m_InputSlots, "Sub-graphs cannot contain null or duplicate input slots");
112
113 // Check for invalid or duplicate output slots
114 AssertIfNullsOrDuplicates(m_OutputSlots, "Sub-graphs cannot contain null or duplicate output slots");
115
116 // Check for invalid or duplicate layers
117 AssertIfNullsOrDuplicates(m_Layers, "Sub-graphs cannot contain null or duplicate layers");
118 }
119
GetInputSlots() const120 const SubgraphView::InputSlots& SubgraphView::GetInputSlots() const
121 {
122 return m_InputSlots;
123 }
124
GetOutputSlots() const125 const SubgraphView::OutputSlots& SubgraphView::GetOutputSlots() const
126 {
127 return m_OutputSlots;
128 }
129
GetInputSlot(unsigned int index) const130 const InputSlot* SubgraphView::GetInputSlot(unsigned int index) const
131 {
132 return m_InputSlots.at(index);
133 }
134
GetInputSlot(unsigned int index)135 InputSlot* SubgraphView::GetInputSlot(unsigned int index)
136 {
137 return m_InputSlots.at(index);
138 }
139
GetOutputSlot(unsigned int index) const140 const OutputSlot* SubgraphView::GetOutputSlot(unsigned int index) const
141 {
142 return m_OutputSlots.at(index);
143 }
144
GetOutputSlot(unsigned int index)145 OutputSlot* SubgraphView::GetOutputSlot(unsigned int index)
146 {
147 return m_OutputSlots.at(index);
148 }
149
GetNumInputSlots() const150 unsigned int SubgraphView::GetNumInputSlots() const
151 {
152 return armnn::numeric_cast<unsigned int>(m_InputSlots.size());
153 }
154
GetNumOutputSlots() const155 unsigned int SubgraphView::GetNumOutputSlots() const
156 {
157 return armnn::numeric_cast<unsigned int>(m_OutputSlots.size());
158 }
159
GetLayers() const160 const SubgraphView::Layers& SubgraphView::GetLayers() const
161 {
162 return m_Layers;
163 }
164
begin()165 SubgraphView::Iterator SubgraphView::begin()
166 {
167 return m_Layers.begin();
168 }
169
end()170 SubgraphView::Iterator SubgraphView::end()
171 {
172 return m_Layers.end();
173 }
174
begin() const175 SubgraphView::ConstIterator SubgraphView::begin() const
176 {
177 return m_Layers.begin();
178 }
179
end() const180 SubgraphView::ConstIterator SubgraphView::end() const
181 {
182 return m_Layers.end();
183 }
184
cbegin() const185 SubgraphView::ConstIterator SubgraphView::cbegin() const
186 {
187 return begin();
188 }
189
cend() const190 SubgraphView::ConstIterator SubgraphView::cend() const
191 {
192 return end();
193 }
194
Clear()195 void SubgraphView::Clear()
196 {
197 m_InputSlots.clear();
198 m_OutputSlots.clear();
199 m_Layers.clear();
200 }
201
202 } // namespace armnn
203