• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2020,2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include "Optimization.hpp"
8 
9 #include <armnn/utility/PolymorphicDowncast.hpp>
10 #include <armnnUtils/Transpose.hpp>
11 
12 namespace armnn
13 {
14 namespace optimizations
15 {
16 class MoveTransposeUpImpl
17 {
18 public:
19     /// Run for every connection between a base Layer (any) and a child TransposeLayer. If the type
20     /// of the base layer allows it, it moves the permutation to the inputs of the base layer.
21     /// I.e., adds equivalent permutations before the inputs of the base layer and moves the
22     /// connections in the output of the child transpose layer to the output of the base layer.
Run(Graph & graph,InputSlot & connection) const23     void Run(Graph& graph, InputSlot& connection) const
24     {
25         OutputSlot& baseOutput = *connection.GetConnectedOutputSlot();
26 
27         if (baseOutput.GetNumConnections() == 1U)
28         {
29             Layer& base = baseOutput.GetOwningLayer();
30 
31             if (CanMoveTransposeToInputs(base))
32             {
33                 auto transpose = PolymorphicDowncast<TransposeLayer*>(&connection.GetOwningLayer());
34                 const PermutationVector& perm = transpose->GetPermutation();
35 
36                 // Inserts an equivalent transpose before every input of the base layer.
37                 for (auto baseInput = base.BeginInputSlots(); baseInput != base.EndInputSlots(); ++baseInput)
38                 {
39                     // Inserts a new transpose layer.
40                     const std::string name = std::string("moved_up-") + transpose->GetName();
41                     TransposeLayer& permLayer = *graph.InsertNewLayer<TransposeLayer>(*baseInput, perm, name.c_str());
42 
43                     // Sets output tensor info for the new layer.
44                     OutputSlot& parentOutput = *permLayer.GetInputSlot(0).GetConnectedOutputSlot();
45                     const TensorInfo permOutInfo = armnnUtils::TransposeTensorShape(parentOutput.GetTensorInfo(), perm);
46                     permLayer.GetOutputHandler().SetTensorInfo(permOutInfo);
47                 }
48 
49                 // Bypasses transpose. It will be removed as it's left unconnected.
50                 transpose->GetOutputSlot().MoveAllConnections(base.GetOutputSlot());
51             }
52         }
53     }
54 
55 protected:
56     MoveTransposeUpImpl() = default;
57     ~MoveTransposeUpImpl() = default;
58 
59 private:
CanMoveTransposeToInputs(const Layer & base)60     static bool CanMoveTransposeToInputs(const Layer& base)
61     {
62         switch (base.GetType())
63         {
64             case LayerType::Activation:
65             case LayerType::Addition:
66             case LayerType::FakeQuantization:
67             case LayerType::Floor:
68             case LayerType::MemCopy:
69             case LayerType::Multiplication:
70                 return true;
71             case LayerType::ElementwiseBinary:
72             {
73                 auto descriptor = PolymorphicDowncast<const ElementwiseBinaryDescriptor*>(&base.GetParameters());
74                 return (descriptor->m_Operation == BinaryOperation::Add ||
75                         descriptor->m_Operation == BinaryOperation::Mul);
76             }
77             default:
78                 return false;
79         }
80     }
81 };
82 
/// Optimization type that binds MoveTransposeUpImpl to connections going from a base Layer
/// to a child TransposeLayer. OptimizeForConnection is declared elsewhere (presumably in
/// Optimization.hpp — confirm).
using MoveTransposeUp = OptimizeForConnection<Layer, TransposeLayer, MoveTransposeUpImpl>;
84 
85 } // namespace optimizations
86 } // namespace armnn
87