• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include "Optimization.hpp"
8 
9 #include <armnn/utility/IgnoreUnused.hpp>
10 #include <armnn/utility/PolymorphicDowncast.hpp>
11 
12 namespace armnn
13 {
14 namespace optimizations
15 {
16 
17 template <typename Comparable>
18 class SquashEqualSiblingsImpl
19 {
20 public:
21     /// Run for every connection between a base Layer (any) and a child ComparableLayer.
22     /// For all siblings of the child layer that compare equal to it, bypasses and removes
23     /// them. I.e., moves the connections in the outputs of the siblings to the outputs of
24     /// the child layer, so the siblings are left unconnected (and later removed).
Run(Graph & graph,InputSlot & connection) const25     void Run(Graph& graph, InputSlot& connection) const
26     {
27         IgnoreUnused(graph);
28         auto& child = connection.GetOwningLayer();
29 
30         if (!child.IsOutputUnconnected())
31         {
32             OutputSlot& baseOutput = *connection.GetConnectedOutputSlot();
33 
34             if (baseOutput.GetNumConnections() > 1)
35             {
36                 auto& comparableChild = *PolymorphicDowncast<Comparable*>(&child);
37 
38                 Layer* lowestPriorityChild = &child;
39                 for (auto&& it : baseOutput.GetConnections())
40                 {
41                     Layer* sibling = &it->GetOwningLayer();
42                     if ((sibling != lowestPriorityChild) && comparableChild.IsEqual(*sibling))
43                     {
44                         if (sibling->GetPriority() < lowestPriorityChild->GetPriority())
45                         {
46                             std::swap(sibling, lowestPriorityChild);
47                         }
48                         // Bypasses sibling. It will be removed as it's left unconnected.
49                         auto siblingOut = sibling->BeginOutputSlots();
50                         for (auto lowestPriorityChildOut = lowestPriorityChild->BeginOutputSlots();
51                              lowestPriorityChildOut != lowestPriorityChild->EndOutputSlots(); ++lowestPriorityChildOut)
52                         {
53                             siblingOut->MoveAllConnections(*lowestPriorityChildOut);
54                             ++siblingOut;
55                         }
56                     }
57                 }
58             }
59         }
60     }
61 
62 protected:
63     SquashEqualSiblingsImpl() = default;
64     ~SquashEqualSiblingsImpl() = default;
65 };
66 
67 using SquashEqualPermuteSiblings = OptimizeForConnection<Layer, PermuteLayer, SquashEqualSiblingsImpl<PermuteLayer>>;
68 using SquashEqualTransposeSiblings = OptimizeForConnection<Layer, TransposeLayer,
69     SquashEqualSiblingsImpl<TransposeLayer>>;
70 using SquashEqualReshapeSiblings = OptimizeForConnection<Layer, ReshapeLayer, SquashEqualSiblingsImpl<ReshapeLayer>>;
71 
72 } // namespace optimizations
73 } // namespace armnn
74