• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include "LayersFwd.hpp"
8 #include "IGraphObservable.hpp"
9 
10 #include <armnn/Types.hpp>
11 #include <armnn/TensorFwd.hpp>
12 #include <armnn/NetworkFwd.hpp>
13 #include <armnn/Exceptions.hpp>
14 #include <armnn/utility/Assert.hpp>
15 #include <armnn/utility/PolymorphicDowncast.hpp>
16 #include <armnn/utility/TransformIterator.hpp>
17 
18 #include <list>
19 #include <map>
20 #include <unordered_map>
21 #include <unordered_set>
22 #include <vector>
23 
24 namespace armnn
25 {
26 
27 class SubgraphView;
28 
29 class Graph
30 {
31 public:
32     template <typename LayerType>
PtrCast(Layer * const layer)33     static LayerType* PtrCast(Layer* const layer)
34     {
35         return PolymorphicDowncast<LayerType*>(layer);
36     }
37 
38     template <typename Func>
ForEachLayer(Func func) const39     void ForEachLayer(Func func) const
40     {
41         for (auto it = m_Layers.begin(); it != m_Layers.end(); )
42         {
43              auto next = std::next(it);
44              func(*it);
45              it = next;
46         }
47     }
48 
49     using LayerList = std::list<Layer*>;
50     using Iterator = LayerList::const_iterator; // Const so pointers in the list can't be modified externally.
51     using IteratorDifference = Iterator::difference_type;
52 
53     using ConstIterator        = TransformIterator<decltype(&PtrCast<const Layer>),       Iterator>;
54     using ConstIteratorInputs  = TransformIterator<decltype(&PtrCast<const InputLayer>),  Iterator>;
55     using ConstIteratorOutputs = TransformIterator<decltype(&PtrCast<const OutputLayer>), Iterator>;
56 
57     /// Wrapper class returned by Graph::GetInputLayers()
58     struct InputLayersAccessor
59     {
InputLayersAccessorarmnn::Graph::InputLayersAccessor60         explicit InputLayersAccessor(const Graph& graph) : m_Graph(graph) {}
61 
beginarmnn::Graph::InputLayersAccessor62         ConstIteratorInputs begin() const
63         {
64             return { m_Graph.m_Layers.begin(), &(PtrCast<const InputLayer>) };
65         }
66 
endarmnn::Graph::InputLayersAccessor67         ConstIteratorInputs end() const
68         {
69             return { std::next(m_Graph.m_Layers.begin(), static_cast<IteratorDifference>(m_Graph.GetNumInputs())),
70                      &(PtrCast<const InputLayer>) };
71         }
72 
73         const Graph& m_Graph;
74     };
75 
76     /// Wrapper class returned by Graph::GetOutputLayers()
77     struct OutputLayersAccessor
78     {
OutputLayersAccessorarmnn::Graph::OutputLayersAccessor79         explicit OutputLayersAccessor(const Graph& graph) : m_Graph(graph) {}
80 
beginarmnn::Graph::OutputLayersAccessor81         ConstIteratorOutputs begin() const
82         {
83             return { std::prev(m_Graph.m_Layers.end(), static_cast<IteratorDifference>(m_Graph.GetNumOutputs())),
84                      &(PtrCast<const OutputLayer>) };
85         }
86 
endarmnn::Graph::OutputLayersAccessor87         ConstIteratorOutputs end() const
88         {
89             return { m_Graph.m_Layers.end(), &(PtrCast<const OutputLayer>) };
90         }
91 
92         const Graph& m_Graph;
93     };
94 
Graph(bool shapeInferenceMethod=false)95     Graph(bool shapeInferenceMethod = false)
96         : m_LayersInOrder(true)
97         , m_ShapeInferenceMethod(shapeInferenceMethod ? ShapeInferenceMethod::InferAndValidate :
98                                                         ShapeInferenceMethod::ValidateOnly)
99         {}
100 
101     Graph(const Graph& other);
102 
103     Graph& operator=(const Graph& other) = delete;
104 
Graph(Graph && other)105     Graph(Graph&& other)
106     {
107         *this = std::move(other);
108     }
109 
operator =(Graph && other)110     Graph& operator=(Graph&& other)
111     {
112         m_InputIds      = std::move(other.m_InputIds);
113         m_OutputIds     = std::move(other.m_OutputIds);
114         m_LayersInOrder = std::move(other.m_LayersInOrder);
115         m_Views         = std::move(other.m_Views);
116 
117         other.ForEachLayer([this](Layer* otherLayer)
118         {
119             otherLayer->Reparent(*this, m_Layers.end());
120         });
121 
122         ARMNN_ASSERT(other.m_PosInGraphMap.empty());
123         ARMNN_ASSERT(other.m_Layers.empty());
124 
125         return *this;
126     }
127 
~Graph()128     ~Graph()
129     {
130         ForEachLayer([](Layer* layer)
131         {
132             delete layer;
133         });
134     }
135 
136     Status Print() const;
137 
138     Status SerializeToDot(std::ostream& stream);
139 
140     /// Adds a new layer, of type LayerType, to the graph constructed with the arguments passed.
141     template <typename LayerT, typename... Args>
142     LayerT* AddLayer(Args&&... args);
143 
144     /// Inserts a new layer between the output slot currently connected to insertBefore
145     /// and insertBefore itself.
146     template <typename LayerT, typename... Args>
147     LayerT* InsertNewLayer(InputSlot& insertBefore, Args&&... args);
148 
149     /// Inserts a new layer between insertAfter and the input slot(s) currently connected to it
150     template <typename LayerT, typename... Args>
151     LayerT* InsertNewLayer(OutputSlot& insertAfter, Args&&... args);
152 
153     /// Deletes the layer at the specified position.
154     void EraseLayer(Iterator pos);
155 
156     /// Deletes the layer. Sets @a layer to nullptr on return.
157     /// Templated to support pointers to any layer type.
158     template <typename LayerT>
159     void EraseLayer(LayerT*& layer);
160 
161     /// Returns iterator pointing to the beginning of the list. Lowercase for range-based for loops.
begin()162     Iterator begin() { return m_Layers.begin(); }
163     /// Returns iterator pointing to the end of the list. Lowercase for range-based for loops.
end()164     Iterator end() { return m_Layers.end(); }
165 
166     /// Returns const iterator pointing to the beginning of the list. Lowercase for range-based for loops.
begin() const167     ConstIterator begin() const { return {m_Layers.begin(), &(PtrCast<const Layer>)}; }
168     /// Returns const iterator pointing to the end of the list. Lowercase for range-based for loops.
end() const169     ConstIterator end() const { return {m_Layers.end(), &(PtrCast<const Layer>)}; }
170 
171     /// Returns const iterator pointing to the beginning of the list. Lowercase for range-based for loops.
cbegin() const172     ConstIterator cbegin() const { return begin(); }
173     /// Returns const iterator pointing to the end of the list. Lowercase for range-based for loops.
cend() const174     ConstIterator cend() const { return end(); }
175 
176     /// Sorts layers in topological order and return this.
TopologicalSort()177     Graph& TopologicalSort() { const_cast<const Graph*>(this)->TopologicalSort(); return *this; }
178     const Graph& TopologicalSort() const;
179 
GetNumInputs() const180     size_t GetNumInputs() const { return m_InputIds.size(); }
GetNumOutputs() const181     size_t GetNumOutputs() const { return m_OutputIds.size(); }
182 
183     /// Returns a wrapper object with begin(), end() methods to iterate over the input layers
184     /// in a range-based for loop.
GetInputLayers() const185     InputLayersAccessor GetInputLayers() const { return InputLayersAccessor(*this); }
186 
187     /// Returns a wrapper object with begin(), end() methods to iterate over the output layers
188     /// in a range-based for loop.
GetOutputLayers() const189     OutputLayersAccessor GetOutputLayers() const { return OutputLayersAccessor(*this); }
190 
GetNumLayers() const191     size_t GetNumLayers() const { return m_Layers.size(); }
192 
193     /// Allocates memory for all tensors under output tensor handers of each layer.
194     Status AllocateDynamicBuffers();
195 
196     /// Modifies the graph in-place, removing edges connecting layers using different compute devices,
197     /// and relinking them via an intermediary copy layers.
198     void AddCompatibilityLayers(std::map<BackendId, std::unique_ptr<class IBackendInternal>>& backends,
199                                 TensorHandleFactoryRegistry& registry);
200 
201     /// Substitutes the given sub-graph with either a new layer or a new sub-graph.
202     /// In either case, the given layer or all the layers in the given sub-graph must belong to this graph.
203     void SubstituteSubgraph(SubgraphView& subgraph, IConnectableLayer* substituteLayer);
204     void SubstituteSubgraph(SubgraphView& subgraph, const SubgraphView& substituteSubgraph);
205 
206     void InferTensorInfos();
207 
AttachObservable(IGraphObservable * const observable,GraphEvent notifyOnEvent)208     void AttachObservable(IGraphObservable* const observable, GraphEvent notifyOnEvent) {
209         m_Views[notifyOnEvent].emplace_back(observable);
210     }
211 
DetachObservable(IGraphObservable * const observable,GraphEvent notifyOnEvent)212     void DetachObservable(IGraphObservable* const observable, GraphEvent notifyOnEvent) {
213         m_Views[notifyOnEvent].remove(observable);
214     }
215 
216     /// Gets the position of a layer in the graph.
217     Iterator GetPosInGraph(Layer& layer);
218 
219 private:
220     template <typename LayerT>
221     class LayerInGraphBase;
222 
223     template <typename LayerT>
224     class LayerInGraph;
225 
ForwardToEndOfInputs(Iterator it) const226     Iterator ForwardToEndOfInputs(Iterator it) const
227     {
228         while ((it != m_Layers.end()) && ((*it)->GetType() == LayerType::Input))
229         {
230             ++it;
231         }
232         return it;
233     }
234 
RewindToBeginOfOutputs(Iterator it) const235     Iterator RewindToBeginOfOutputs(Iterator it) const
236     {
237         while ((it != m_Layers.begin()) && ((*std::prev(it))->GetType() == LayerType::Output))
238         {
239             --it;
240         }
241         return it;
242     }
243 
NotifyObservables(GraphEvent event,Layer * graphState)244     void NotifyObservables(GraphEvent event, Layer* graphState)
245     {
246         // Iterate over all observables observing this event
247         for (auto& observable : m_Views[event])
248         {
249             observable->Update(graphState);
250         }
251     }
252 
253     std::unordered_set<LayerBindingId> m_InputIds;
254     std::unordered_set<LayerBindingId> m_OutputIds;
255     std::unordered_map<const Layer*, Iterator> m_PosInGraphMap;
256 
257     void ReplaceSubgraphConnections(const SubgraphView& subgraph, IConnectableLayer* substituteLayer);
258     void ReplaceSubgraphConnections(const SubgraphView& subgraph, const SubgraphView& substituteSubgraph);
259     void EraseSubgraphLayers(SubgraphView &subgraph);
260 
261     /// Mutable to allow sorting on const object.
262     mutable LayerList m_Layers;
263     mutable bool m_LayersInOrder;
264 
265     std::map<const GraphEvent, std::list<IGraphObservable*>> m_Views;
266     ShapeInferenceMethod m_ShapeInferenceMethod;
267 };
268 
269 /// Common base class for layers in the graph.
270 template <typename LayerT>
271 class Graph::LayerInGraphBase : public LayerT
272 {
273 protected:
274     template <typename... Args>
LayerInGraphBase(Graph & graph,Iterator insertBefore,Args &&...args)275     LayerInGraphBase(Graph& graph, Iterator insertBefore, Args&&... args)
276         : LayerT(std::forward<Args>(args)...), m_Graph(&graph)
277     {
278         Insert(*m_Graph, insertBefore);
279     }
~LayerInGraphBase()280     ~LayerInGraphBase()
281     {
282         Remove(*m_Graph);
283     }
284 
Reparent(Graph & destGraph,Iterator insertBefore)285     void Reparent(Graph& destGraph, Iterator insertBefore) override
286     {
287         Insert(destGraph, insertBefore);
288         Remove(*m_Graph);
289 
290         m_Graph = &destGraph;
291     }
292 
293 private:
Insert(Graph & graph,Iterator insertBefore)294     void Insert(Graph& graph, Iterator insertBefore)
295     {
296         graph.m_PosInGraphMap.emplace(this, graph.m_Layers.emplace(insertBefore, this));
297     }
298 
Remove(Graph & graph)299     void Remove(Graph& graph)
300     {
301         auto layerIt = graph.GetPosInGraph(*this);
302         graph.m_Layers.erase(layerIt);
303 
304         const size_t numErased = graph.m_PosInGraphMap.erase(this);
305         IgnoreUnused(numErased);
306         ARMNN_ASSERT(numErased == 1);
307     }
308 
309 protected:
310     Graph* m_Graph;
311 };
312 
313 /// Input/Output layers specialize this template.
314 template <typename LayerT>
315 class Graph::LayerInGraph final : public LayerInGraphBase<LayerT>
316 {
317 public:
318     template <typename... Args>
LayerInGraph(Graph & graph,Args &&...args)319     LayerInGraph(Graph& graph, Args&&... args)
320         : LayerInGraphBase<LayerT>(graph,
321                                    // Insert at the back of the intermediate layers (before outputs).
322                                    std::prev(graph.end(), IteratorDifference(graph.GetNumOutputs())),
323                                    std::forward<Args>(args)...)
324     {
325     }
326     template <typename... Args>
LayerInGraph(Graph & graph,Iterator insertBefore,Args &&...args)327     LayerInGraph(Graph& graph, Iterator insertBefore, Args&&... args)
328         : LayerInGraphBase<LayerT>(graph,
329                                    // Make sure it's inserted after all inputs and before all outputs.
330                                    graph.ForwardToEndOfInputs(graph.RewindToBeginOfOutputs(insertBefore)),
331                                    std::forward<Args>(args)...)
332     {
333     }
334 };
335 
336 /// Inputs add/remove their binding id to m_InputIds in the graph.
337 template <>
338 class Graph::LayerInGraph<InputLayer> final : public LayerInGraphBase<InputLayer>
339 {
340 public:
341     template <typename... Args>
LayerInGraph(Graph & graph,Args &&...args)342     LayerInGraph(Graph& graph, Args&&... args)
343         : LayerInGraphBase<InputLayer>(graph,
344                                        // Always add to the back of the inputs.
345                                        std::next(graph.begin(), IteratorDifference(graph.GetNumInputs())),
346                                        std::forward<Args>(args)...)
347     {
348         const bool isNewId = m_Graph->m_InputIds.emplace(GetBindingId()).second;
349         if (!isNewId)
350         {
351             throw InvalidArgumentException("A layer already exists with the specified id");
352         }
353     }
354     template <typename... Args>
LayerInGraph(Graph & graph,Iterator,Args &&...args)355     LayerInGraph(Graph& graph, Iterator, Args&&... args)
356         // Ignore Iterator argument. Always add to the back of the inputs.
357         : LayerInGraph(graph, std::forward<Args>(args)...)
358     {
359     }
~LayerInGraph()360     ~LayerInGraph() override
361     {
362         const size_t numErased = m_Graph->m_InputIds.erase(GetBindingId());
363         IgnoreUnused(numErased);
364         ARMNN_ASSERT(numErased == 1);
365     }
366 };
367 
368 /// Outputs add/remove their binding id to m_OutputIds in the graph.
369 template <>
370 class Graph::LayerInGraph<OutputLayer> final : public LayerInGraphBase<OutputLayer>
371 {
372 public:
373     template <typename... Args>
LayerInGraph(Graph & graph,Args &&...args)374     LayerInGraph(Graph& graph, Args&&... args)
375         : LayerInGraphBase<OutputLayer>(graph,
376                                         // Always add to the back of the outputs.
377                                         graph.end(),
378                                         std::forward<Args>(args)...)
379     {
380         const bool isNewId = m_Graph->m_OutputIds.emplace(GetBindingId()).second;
381         if (!isNewId)
382         {
383             throw InvalidArgumentException("A layer already exists with the specified id");
384         }
385     }
~LayerInGraph()386     ~LayerInGraph() override
387     {
388         const size_t numErased = m_Graph->m_OutputIds.erase(GetBindingId());
389         IgnoreUnused(numErased);
390         ARMNN_ASSERT(numErased == 1);
391     }
392 };
393 
GetPosInGraph(Layer & layer)394 inline Graph::Iterator Graph::GetPosInGraph(Layer& layer)
395 {
396     auto it = m_PosInGraphMap.find(&layer);
397     ARMNN_ASSERT(it != m_PosInGraphMap.end());
398     return it->second;
399 }
400 
401 template <typename LayerT, typename... Args>
AddLayer(Args &&...args)402 inline LayerT* Graph::AddLayer(Args&&... args)
403 {
404     m_LayersInOrder = m_LayersInOrder &&
405         ((LayerEnumOf<LayerT>() == LayerType::Input) || (LayerEnumOf<LayerT>() == LayerType::Output));
406     LayerT* const layer = new LayerInGraph<LayerT>(*this, std::forward<Args>(args)...);
407 
408     layer->SetShapeInferenceMethod(m_ShapeInferenceMethod);
409 
410     NotifyObservables(GraphEvent::LayerAdded, layer);
411 
412     return layer;
413 }
414 
415 template <typename LayerT, typename... Args>
InsertNewLayer(InputSlot & insertBefore,Args &&...args)416 inline LayerT* Graph::InsertNewLayer(InputSlot& insertBefore, Args&&... args)
417 {
418     // Insert after the parent if any, or before the child otherwise, so the topological order is kept.
419     OutputSlot* parentOut = insertBefore.GetConnectedOutputSlot();
420     const Iterator pos = (parentOut != nullptr)
421                          ? std::next(GetPosInGraph(parentOut->GetOwningLayer()))
422                          : GetPosInGraph(insertBefore.GetOwningLayer());
423     LayerT* const layer = new LayerInGraph<LayerT>(*this, pos, std::forward<Args>(args)...);
424     insertBefore.Insert(*layer);
425 
426     NotifyObservables(GraphEvent::LayerAdded, layer);
427 
428     return layer;
429 }
430 
431 template <typename LayerT, typename... Args>
InsertNewLayer(OutputSlot & insertAfter,Args &&...args)432 inline LayerT* Graph::InsertNewLayer(OutputSlot& insertAfter, Args&&... args)
433 {
434     Layer& owningLayer = insertAfter.GetOwningLayer();
435 
436     const Iterator pos = std::next(GetPosInGraph(owningLayer));
437     LayerT* const layer = new LayerInGraph<LayerT>(*this, pos, std::forward<Args>(args)...);
438 
439     ARMNN_ASSERT(layer->GetNumInputSlots() == 1);
440 
441     insertAfter.MoveAllConnections(layer->GetOutputSlot());
442     insertAfter.Connect(layer->GetInputSlot(0));
443 
444     NotifyObservables(GraphEvent::LayerAdded, layer);
445 
446     return layer;
447 }
448 
EraseLayer(Iterator pos)449 inline void Graph::EraseLayer(Iterator pos)
450 {
451     NotifyObservables(GraphEvent::LayerErased, *pos);
452 
453     delete *pos;
454 }
455 
456 template <typename LayerT>
EraseLayer(LayerT * & layer)457 inline void Graph::EraseLayer(LayerT*& layer)
458 {
459     ARMNN_ASSERT(layer != nullptr);
460     EraseLayer(GetPosInGraph(*layer));
461     layer = nullptr;
462 }
463 
464 } // namespace armnn
465