1 //
2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6
7 #include "LayersFwd.hpp"
8 #include "IGraphObservable.hpp"
9
10 #include <armnn/Types.hpp>
11 #include <armnn/TensorFwd.hpp>
12 #include <armnn/NetworkFwd.hpp>
13 #include <armnn/Exceptions.hpp>
14 #include <armnn/utility/Assert.hpp>
15 #include <armnn/utility/PolymorphicDowncast.hpp>
16 #include <armnn/utility/TransformIterator.hpp>
17
18 #include <list>
19 #include <map>
20 #include <unordered_map>
21 #include <unordered_set>
22 #include <vector>
23
24 namespace armnn
25 {
26
27 class SubgraphView;
28
29 class Graph
30 {
31 public:
32 template <typename LayerType>
PtrCast(Layer * const layer)33 static LayerType* PtrCast(Layer* const layer)
34 {
35 return PolymorphicDowncast<LayerType*>(layer);
36 }
37
38 template <typename Func>
ForEachLayer(Func func) const39 void ForEachLayer(Func func) const
40 {
41 for (auto it = m_Layers.begin(); it != m_Layers.end(); )
42 {
43 auto next = std::next(it);
44 func(*it);
45 it = next;
46 }
47 }
48
49 using LayerList = std::list<Layer*>;
50 using Iterator = LayerList::const_iterator; // Const so pointers in the list can't be modified externally.
51 using IteratorDifference = Iterator::difference_type;
52
53 using ConstIterator = TransformIterator<decltype(&PtrCast<const Layer>), Iterator>;
54 using ConstIteratorInputs = TransformIterator<decltype(&PtrCast<const InputLayer>), Iterator>;
55 using ConstIteratorOutputs = TransformIterator<decltype(&PtrCast<const OutputLayer>), Iterator>;
56
57 /// Wrapper class returned by Graph::GetInputLayers()
58 struct InputLayersAccessor
59 {
InputLayersAccessorarmnn::Graph::InputLayersAccessor60 explicit InputLayersAccessor(const Graph& graph) : m_Graph(graph) {}
61
beginarmnn::Graph::InputLayersAccessor62 ConstIteratorInputs begin() const
63 {
64 return { m_Graph.m_Layers.begin(), &(PtrCast<const InputLayer>) };
65 }
66
endarmnn::Graph::InputLayersAccessor67 ConstIteratorInputs end() const
68 {
69 return { std::next(m_Graph.m_Layers.begin(), static_cast<IteratorDifference>(m_Graph.GetNumInputs())),
70 &(PtrCast<const InputLayer>) };
71 }
72
73 const Graph& m_Graph;
74 };
75
76 /// Wrapper class returned by Graph::GetOutputLayers()
77 struct OutputLayersAccessor
78 {
OutputLayersAccessorarmnn::Graph::OutputLayersAccessor79 explicit OutputLayersAccessor(const Graph& graph) : m_Graph(graph) {}
80
beginarmnn::Graph::OutputLayersAccessor81 ConstIteratorOutputs begin() const
82 {
83 return { std::prev(m_Graph.m_Layers.end(), static_cast<IteratorDifference>(m_Graph.GetNumOutputs())),
84 &(PtrCast<const OutputLayer>) };
85 }
86
endarmnn::Graph::OutputLayersAccessor87 ConstIteratorOutputs end() const
88 {
89 return { m_Graph.m_Layers.end(), &(PtrCast<const OutputLayer>) };
90 }
91
92 const Graph& m_Graph;
93 };
94
Graph(bool shapeInferenceMethod=false)95 Graph(bool shapeInferenceMethod = false)
96 : m_LayersInOrder(true)
97 , m_ShapeInferenceMethod(shapeInferenceMethod ? ShapeInferenceMethod::InferAndValidate :
98 ShapeInferenceMethod::ValidateOnly)
99 {}
100
101 Graph(const Graph& other);
102
103 Graph& operator=(const Graph& other) = delete;
104
Graph(Graph && other)105 Graph(Graph&& other)
106 {
107 *this = std::move(other);
108 }
109
operator =(Graph && other)110 Graph& operator=(Graph&& other)
111 {
112 m_InputIds = std::move(other.m_InputIds);
113 m_OutputIds = std::move(other.m_OutputIds);
114 m_LayersInOrder = std::move(other.m_LayersInOrder);
115 m_Views = std::move(other.m_Views);
116
117 other.ForEachLayer([this](Layer* otherLayer)
118 {
119 otherLayer->Reparent(*this, m_Layers.end());
120 });
121
122 ARMNN_ASSERT(other.m_PosInGraphMap.empty());
123 ARMNN_ASSERT(other.m_Layers.empty());
124
125 return *this;
126 }
127
~Graph()128 ~Graph()
129 {
130 ForEachLayer([](Layer* layer)
131 {
132 delete layer;
133 });
134 }
135
136 Status Print() const;
137
138 Status SerializeToDot(std::ostream& stream);
139
140 /// Adds a new layer, of type LayerType, to the graph constructed with the arguments passed.
141 template <typename LayerT, typename... Args>
142 LayerT* AddLayer(Args&&... args);
143
144 /// Inserts a new layer between the output slot currently connected to insertBefore
145 /// and insertBefore itself.
146 template <typename LayerT, typename... Args>
147 LayerT* InsertNewLayer(InputSlot& insertBefore, Args&&... args);
148
149 /// Inserts a new layer between insertAfter and the input slot(s) currently connected to it
150 template <typename LayerT, typename... Args>
151 LayerT* InsertNewLayer(OutputSlot& insertAfter, Args&&... args);
152
153 /// Deletes the layer at the specified position.
154 void EraseLayer(Iterator pos);
155
156 /// Deletes the layer. Sets @a layer to nullptr on return.
157 /// Templated to support pointers to any layer type.
158 template <typename LayerT>
159 void EraseLayer(LayerT*& layer);
160
161 /// Returns iterator pointing to the beginning of the list. Lowercase for range-based for loops.
begin()162 Iterator begin() { return m_Layers.begin(); }
163 /// Returns iterator pointing to the end of the list. Lowercase for range-based for loops.
end()164 Iterator end() { return m_Layers.end(); }
165
166 /// Returns const iterator pointing to the beginning of the list. Lowercase for range-based for loops.
begin() const167 ConstIterator begin() const { return {m_Layers.begin(), &(PtrCast<const Layer>)}; }
168 /// Returns const iterator pointing to the end of the list. Lowercase for range-based for loops.
end() const169 ConstIterator end() const { return {m_Layers.end(), &(PtrCast<const Layer>)}; }
170
171 /// Returns const iterator pointing to the beginning of the list. Lowercase for range-based for loops.
cbegin() const172 ConstIterator cbegin() const { return begin(); }
173 /// Returns const iterator pointing to the end of the list. Lowercase for range-based for loops.
cend() const174 ConstIterator cend() const { return end(); }
175
176 /// Sorts layers in topological order and return this.
TopologicalSort()177 Graph& TopologicalSort() { const_cast<const Graph*>(this)->TopologicalSort(); return *this; }
178 const Graph& TopologicalSort() const;
179
GetNumInputs() const180 size_t GetNumInputs() const { return m_InputIds.size(); }
GetNumOutputs() const181 size_t GetNumOutputs() const { return m_OutputIds.size(); }
182
183 /// Returns a wrapper object with begin(), end() methods to iterate over the input layers
184 /// in a range-based for loop.
GetInputLayers() const185 InputLayersAccessor GetInputLayers() const { return InputLayersAccessor(*this); }
186
187 /// Returns a wrapper object with begin(), end() methods to iterate over the output layers
188 /// in a range-based for loop.
GetOutputLayers() const189 OutputLayersAccessor GetOutputLayers() const { return OutputLayersAccessor(*this); }
190
GetNumLayers() const191 size_t GetNumLayers() const { return m_Layers.size(); }
192
193 /// Allocates memory for all tensors under output tensor handers of each layer.
194 Status AllocateDynamicBuffers();
195
196 /// Modifies the graph in-place, removing edges connecting layers using different compute devices,
197 /// and relinking them via an intermediary copy layers.
198 void AddCompatibilityLayers(std::map<BackendId, std::unique_ptr<class IBackendInternal>>& backends,
199 TensorHandleFactoryRegistry& registry);
200
201 /// Substitutes the given sub-graph with either a new layer or a new sub-graph.
202 /// In either case, the given layer or all the layers in the given sub-graph must belong to this graph.
203 void SubstituteSubgraph(SubgraphView& subgraph, IConnectableLayer* substituteLayer);
204 void SubstituteSubgraph(SubgraphView& subgraph, const SubgraphView& substituteSubgraph);
205
206 void InferTensorInfos();
207
AttachObservable(IGraphObservable * const observable,GraphEvent notifyOnEvent)208 void AttachObservable(IGraphObservable* const observable, GraphEvent notifyOnEvent) {
209 m_Views[notifyOnEvent].emplace_back(observable);
210 }
211
DetachObservable(IGraphObservable * const observable,GraphEvent notifyOnEvent)212 void DetachObservable(IGraphObservable* const observable, GraphEvent notifyOnEvent) {
213 m_Views[notifyOnEvent].remove(observable);
214 }
215
216 /// Gets the position of a layer in the graph.
217 Iterator GetPosInGraph(Layer& layer);
218
219 private:
220 template <typename LayerT>
221 class LayerInGraphBase;
222
223 template <typename LayerT>
224 class LayerInGraph;
225
ForwardToEndOfInputs(Iterator it) const226 Iterator ForwardToEndOfInputs(Iterator it) const
227 {
228 while ((it != m_Layers.end()) && ((*it)->GetType() == LayerType::Input))
229 {
230 ++it;
231 }
232 return it;
233 }
234
RewindToBeginOfOutputs(Iterator it) const235 Iterator RewindToBeginOfOutputs(Iterator it) const
236 {
237 while ((it != m_Layers.begin()) && ((*std::prev(it))->GetType() == LayerType::Output))
238 {
239 --it;
240 }
241 return it;
242 }
243
NotifyObservables(GraphEvent event,Layer * graphState)244 void NotifyObservables(GraphEvent event, Layer* graphState)
245 {
246 // Iterate over all observables observing this event
247 for (auto& observable : m_Views[event])
248 {
249 observable->Update(graphState);
250 }
251 }
252
253 std::unordered_set<LayerBindingId> m_InputIds;
254 std::unordered_set<LayerBindingId> m_OutputIds;
255 std::unordered_map<const Layer*, Iterator> m_PosInGraphMap;
256
257 void ReplaceSubgraphConnections(const SubgraphView& subgraph, IConnectableLayer* substituteLayer);
258 void ReplaceSubgraphConnections(const SubgraphView& subgraph, const SubgraphView& substituteSubgraph);
259 void EraseSubgraphLayers(SubgraphView &subgraph);
260
261 /// Mutable to allow sorting on const object.
262 mutable LayerList m_Layers;
263 mutable bool m_LayersInOrder;
264
265 std::map<const GraphEvent, std::list<IGraphObservable*>> m_Views;
266 ShapeInferenceMethod m_ShapeInferenceMethod;
267 };
268
269 /// Common base class for layers in the graph.
270 template <typename LayerT>
271 class Graph::LayerInGraphBase : public LayerT
272 {
273 protected:
274 template <typename... Args>
LayerInGraphBase(Graph & graph,Iterator insertBefore,Args &&...args)275 LayerInGraphBase(Graph& graph, Iterator insertBefore, Args&&... args)
276 : LayerT(std::forward<Args>(args)...), m_Graph(&graph)
277 {
278 Insert(*m_Graph, insertBefore);
279 }
~LayerInGraphBase()280 ~LayerInGraphBase()
281 {
282 Remove(*m_Graph);
283 }
284
Reparent(Graph & destGraph,Iterator insertBefore)285 void Reparent(Graph& destGraph, Iterator insertBefore) override
286 {
287 Insert(destGraph, insertBefore);
288 Remove(*m_Graph);
289
290 m_Graph = &destGraph;
291 }
292
293 private:
Insert(Graph & graph,Iterator insertBefore)294 void Insert(Graph& graph, Iterator insertBefore)
295 {
296 graph.m_PosInGraphMap.emplace(this, graph.m_Layers.emplace(insertBefore, this));
297 }
298
Remove(Graph & graph)299 void Remove(Graph& graph)
300 {
301 auto layerIt = graph.GetPosInGraph(*this);
302 graph.m_Layers.erase(layerIt);
303
304 const size_t numErased = graph.m_PosInGraphMap.erase(this);
305 IgnoreUnused(numErased);
306 ARMNN_ASSERT(numErased == 1);
307 }
308
309 protected:
310 Graph* m_Graph;
311 };
312
313 /// Input/Output layers specialize this template.
314 template <typename LayerT>
315 class Graph::LayerInGraph final : public LayerInGraphBase<LayerT>
316 {
317 public:
318 template <typename... Args>
LayerInGraph(Graph & graph,Args &&...args)319 LayerInGraph(Graph& graph, Args&&... args)
320 : LayerInGraphBase<LayerT>(graph,
321 // Insert at the back of the intermediate layers (before outputs).
322 std::prev(graph.end(), IteratorDifference(graph.GetNumOutputs())),
323 std::forward<Args>(args)...)
324 {
325 }
326 template <typename... Args>
LayerInGraph(Graph & graph,Iterator insertBefore,Args &&...args)327 LayerInGraph(Graph& graph, Iterator insertBefore, Args&&... args)
328 : LayerInGraphBase<LayerT>(graph,
329 // Make sure it's inserted after all inputs and before all outputs.
330 graph.ForwardToEndOfInputs(graph.RewindToBeginOfOutputs(insertBefore)),
331 std::forward<Args>(args)...)
332 {
333 }
334 };
335
336 /// Inputs add/remove their binding id to m_InputIds in the graph.
337 template <>
338 class Graph::LayerInGraph<InputLayer> final : public LayerInGraphBase<InputLayer>
339 {
340 public:
341 template <typename... Args>
LayerInGraph(Graph & graph,Args &&...args)342 LayerInGraph(Graph& graph, Args&&... args)
343 : LayerInGraphBase<InputLayer>(graph,
344 // Always add to the back of the inputs.
345 std::next(graph.begin(), IteratorDifference(graph.GetNumInputs())),
346 std::forward<Args>(args)...)
347 {
348 const bool isNewId = m_Graph->m_InputIds.emplace(GetBindingId()).second;
349 if (!isNewId)
350 {
351 throw InvalidArgumentException("A layer already exists with the specified id");
352 }
353 }
354 template <typename... Args>
LayerInGraph(Graph & graph,Iterator,Args &&...args)355 LayerInGraph(Graph& graph, Iterator, Args&&... args)
356 // Ignore Iterator argument. Always add to the back of the inputs.
357 : LayerInGraph(graph, std::forward<Args>(args)...)
358 {
359 }
~LayerInGraph()360 ~LayerInGraph() override
361 {
362 const size_t numErased = m_Graph->m_InputIds.erase(GetBindingId());
363 IgnoreUnused(numErased);
364 ARMNN_ASSERT(numErased == 1);
365 }
366 };
367
368 /// Outputs add/remove their binding id to m_OutputIds in the graph.
369 template <>
370 class Graph::LayerInGraph<OutputLayer> final : public LayerInGraphBase<OutputLayer>
371 {
372 public:
373 template <typename... Args>
LayerInGraph(Graph & graph,Args &&...args)374 LayerInGraph(Graph& graph, Args&&... args)
375 : LayerInGraphBase<OutputLayer>(graph,
376 // Always add to the back of the outputs.
377 graph.end(),
378 std::forward<Args>(args)...)
379 {
380 const bool isNewId = m_Graph->m_OutputIds.emplace(GetBindingId()).second;
381 if (!isNewId)
382 {
383 throw InvalidArgumentException("A layer already exists with the specified id");
384 }
385 }
~LayerInGraph()386 ~LayerInGraph() override
387 {
388 const size_t numErased = m_Graph->m_OutputIds.erase(GetBindingId());
389 IgnoreUnused(numErased);
390 ARMNN_ASSERT(numErased == 1);
391 }
392 };
393
GetPosInGraph(Layer & layer)394 inline Graph::Iterator Graph::GetPosInGraph(Layer& layer)
395 {
396 auto it = m_PosInGraphMap.find(&layer);
397 ARMNN_ASSERT(it != m_PosInGraphMap.end());
398 return it->second;
399 }
400
401 template <typename LayerT, typename... Args>
AddLayer(Args &&...args)402 inline LayerT* Graph::AddLayer(Args&&... args)
403 {
404 m_LayersInOrder = m_LayersInOrder &&
405 ((LayerEnumOf<LayerT>() == LayerType::Input) || (LayerEnumOf<LayerT>() == LayerType::Output));
406 LayerT* const layer = new LayerInGraph<LayerT>(*this, std::forward<Args>(args)...);
407
408 layer->SetShapeInferenceMethod(m_ShapeInferenceMethod);
409
410 NotifyObservables(GraphEvent::LayerAdded, layer);
411
412 return layer;
413 }
414
415 template <typename LayerT, typename... Args>
InsertNewLayer(InputSlot & insertBefore,Args &&...args)416 inline LayerT* Graph::InsertNewLayer(InputSlot& insertBefore, Args&&... args)
417 {
418 // Insert after the parent if any, or before the child otherwise, so the topological order is kept.
419 OutputSlot* parentOut = insertBefore.GetConnectedOutputSlot();
420 const Iterator pos = (parentOut != nullptr)
421 ? std::next(GetPosInGraph(parentOut->GetOwningLayer()))
422 : GetPosInGraph(insertBefore.GetOwningLayer());
423 LayerT* const layer = new LayerInGraph<LayerT>(*this, pos, std::forward<Args>(args)...);
424 insertBefore.Insert(*layer);
425
426 NotifyObservables(GraphEvent::LayerAdded, layer);
427
428 return layer;
429 }
430
431 template <typename LayerT, typename... Args>
InsertNewLayer(OutputSlot & insertAfter,Args &&...args)432 inline LayerT* Graph::InsertNewLayer(OutputSlot& insertAfter, Args&&... args)
433 {
434 Layer& owningLayer = insertAfter.GetOwningLayer();
435
436 const Iterator pos = std::next(GetPosInGraph(owningLayer));
437 LayerT* const layer = new LayerInGraph<LayerT>(*this, pos, std::forward<Args>(args)...);
438
439 ARMNN_ASSERT(layer->GetNumInputSlots() == 1);
440
441 insertAfter.MoveAllConnections(layer->GetOutputSlot());
442 insertAfter.Connect(layer->GetInputSlot(0));
443
444 NotifyObservables(GraphEvent::LayerAdded, layer);
445
446 return layer;
447 }
448
EraseLayer(Iterator pos)449 inline void Graph::EraseLayer(Iterator pos)
450 {
451 NotifyObservables(GraphEvent::LayerErased, *pos);
452
453 delete *pos;
454 }
455
456 template <typename LayerT>
EraseLayer(LayerT * & layer)457 inline void Graph::EraseLayer(LayerT*& layer)
458 {
459 ARMNN_ASSERT(layer != nullptr);
460 EraseLayer(GetPosInGraph(*layer));
461 layer = nullptr;
462 }
463
464 } // namespace armnn
465