1 // 2 // Copyright © 2020,2023 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 #pragma once 6 7 #include "Optimization.hpp" 8 9 #include <armnn/utility/PolymorphicDowncast.hpp> 10 #include <armnnUtils/Transpose.hpp> 11 12 namespace armnn 13 { 14 namespace optimizations 15 { 16 class MoveTransposeUpImpl 17 { 18 public: 19 /// Run for every connection between a base Layer (any) and a child TransposeLayer. If the type 20 /// of the base layer allows it, it moves the permutation to the inputs of the base layer. 21 /// I.e., adds equivalent permutations before the inputs of the base layer and moves the 22 /// connections in the output of the child transpose layer to the output of the base layer. Run(Graph & graph,InputSlot & connection) const23 void Run(Graph& graph, InputSlot& connection) const 24 { 25 OutputSlot& baseOutput = *connection.GetConnectedOutputSlot(); 26 27 if (baseOutput.GetNumConnections() == 1U) 28 { 29 Layer& base = baseOutput.GetOwningLayer(); 30 31 if (CanMoveTransposeToInputs(base)) 32 { 33 auto transpose = PolymorphicDowncast<TransposeLayer*>(&connection.GetOwningLayer()); 34 const PermutationVector& perm = transpose->GetPermutation(); 35 36 // Inserts an equivalent transpose before every input of the base layer. 37 for (auto baseInput = base.BeginInputSlots(); baseInput != base.EndInputSlots(); ++baseInput) 38 { 39 // Inserts a new transpose layer. 40 const std::string name = std::string("moved_up-") + transpose->GetName(); 41 TransposeLayer& permLayer = *graph.InsertNewLayer<TransposeLayer>(*baseInput, perm, name.c_str()); 42 43 // Sets output tensor info for the new layer. 44 OutputSlot& parentOutput = *permLayer.GetInputSlot(0).GetConnectedOutputSlot(); 45 const TensorInfo permOutInfo = armnnUtils::TransposeTensorShape(parentOutput.GetTensorInfo(), perm); 46 permLayer.GetOutputHandler().SetTensorInfo(permOutInfo); 47 } 48 49 // Bypasses transpose. It will be removed as it's left unconnected. 50 transpose->GetOutputSlot().MoveAllConnections(base.GetOutputSlot()); 51 } 52 } 53 } 54 55 protected: 56 MoveTransposeUpImpl() = default; 57 ~MoveTransposeUpImpl() = default; 58 59 private: CanMoveTransposeToInputs(const Layer & base)60 static bool CanMoveTransposeToInputs(const Layer& base) 61 { 62 switch (base.GetType()) 63 { 64 case LayerType::Activation: 65 case LayerType::Addition: 66 case LayerType::FakeQuantization: 67 case LayerType::Floor: 68 case LayerType::MemCopy: 69 case LayerType::Multiplication: 70 return true; 71 case LayerType::ElementwiseBinary: 72 { 73 auto descriptor = PolymorphicDowncast<const ElementwiseBinaryDescriptor*>(&base.GetParameters()); 74 return (descriptor->m_Operation == BinaryOperation::Add || 75 descriptor->m_Operation == BinaryOperation::Mul); 76 } 77 default: 78 return false; 79 } 80 } 81 }; 82 83 using MoveTransposeUp = OptimizeForConnection<Layer, TransposeLayer, MoveTransposeUpImpl>; 84 85 } // namespace optimizations 86 } // namespace armnn 87