1 // 2 // Copyright © 2017 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 #pragma once 6 7 #include "Optimization.hpp" 8 9 #include <armnn/utility/IgnoreUnused.hpp> 10 #include <armnn/utility/PolymorphicDowncast.hpp> 11 12 namespace armnn 13 { 14 namespace optimizations 15 { 16 17 template <typename Comparable> 18 class SquashEqualSiblingsImpl 19 { 20 public: 21 /// Run for every connection between a base Layer (any) and a child ComparableLayer. 22 /// For all siblings of the child layer that compare equal to it, bypasses and removes 23 /// them. I.e., moves the connections in the outputs of the siblings to the outputs of 24 /// the child layer, so the siblings are left unconnected (and later removed). Run(Graph & graph,InputSlot & connection) const25 void Run(Graph& graph, InputSlot& connection) const 26 { 27 IgnoreUnused(graph); 28 auto& child = connection.GetOwningLayer(); 29 30 if (!child.IsOutputUnconnected()) 31 { 32 OutputSlot& baseOutput = *connection.GetConnectedOutputSlot(); 33 34 if (baseOutput.GetNumConnections() > 1) 35 { 36 auto& comparableChild = *PolymorphicDowncast<Comparable*>(&child); 37 38 Layer* lowestPriorityChild = &child; 39 for (auto&& it : baseOutput.GetConnections()) 40 { 41 Layer* sibling = &it->GetOwningLayer(); 42 if ((sibling != lowestPriorityChild) && comparableChild.IsEqual(*sibling)) 43 { 44 if (sibling->GetPriority() < lowestPriorityChild->GetPriority()) 45 { 46 std::swap(sibling, lowestPriorityChild); 47 } 48 // Bypasses sibling. It will be removed as it's left unconnected. 49 auto siblingOut = sibling->BeginOutputSlots(); 50 for (auto lowestPriorityChildOut = lowestPriorityChild->BeginOutputSlots(); 51 lowestPriorityChildOut != lowestPriorityChild->EndOutputSlots(); ++lowestPriorityChildOut) 52 { 53 siblingOut->MoveAllConnections(*lowestPriorityChildOut); 54 ++siblingOut; 55 } 56 } 57 } 58 } 59 } 60 } 61 62 protected: 63 SquashEqualSiblingsImpl() = default; 64 ~SquashEqualSiblingsImpl() = default; 65 }; 66 67 using SquashEqualPermuteSiblings = OptimizeForConnection<Layer, PermuteLayer, SquashEqualSiblingsImpl<PermuteLayer>>; 68 using SquashEqualTransposeSiblings = OptimizeForConnection<Layer, TransposeLayer, 69 SquashEqualSiblingsImpl<TransposeLayer>>; 70 using SquashEqualReshapeSiblings = OptimizeForConnection<Layer, ReshapeLayer, SquashEqualSiblingsImpl<ReshapeLayer>>; 71 72 } // namespace optimizations 73 } // namespace armnn 74